08_reinforce.py

Download
python 115 lines 3.1 KB
"""
REINFORCE (Monte-Carlo policy gradient) implementation, with a training loop
for Gymnasium's CartPole-v1.
"""
  4import torch
  5import torch.nn as nn
  6import torch.nn.functional as F
  7import numpy as np
  8import gymnasium as gym
  9
 10
 11class PolicyNetwork(nn.Module):
 12    def __init__(self, state_dim, action_dim, hidden_dim=128):
 13        super().__init__()
 14        self.network = nn.Sequential(
 15            nn.Linear(state_dim, hidden_dim),
 16            nn.ReLU(),
 17            nn.Linear(hidden_dim, hidden_dim),
 18            nn.ReLU(),
 19            nn.Linear(hidden_dim, action_dim)
 20        )
 21
 22    def forward(self, state):
 23        return F.softmax(self.network(state), dim=-1)
 24
 25    def get_action(self, state):
 26        probs = self.forward(state)
 27        dist = torch.distributions.Categorical(probs)
 28        action = dist.sample()
 29        return action.item(), dist.log_prob(action)
 30
 31
 32class REINFORCE:
 33    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99):
 34        self.policy = PolicyNetwork(state_dim, action_dim)
 35        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
 36        self.gamma = gamma
 37        self.log_probs = []
 38        self.rewards = []
 39
 40    def choose_action(self, state):
 41        state_tensor = torch.FloatTensor(state).unsqueeze(0)
 42        action, log_prob = self.policy.get_action(state_tensor)
 43        self.log_probs.append(log_prob)
 44        return action
 45
 46    def store_reward(self, reward):
 47        self.rewards.append(reward)
 48
 49    def compute_returns(self):
 50        returns = []
 51        G = 0
 52        for r in reversed(self.rewards):
 53            G = r + self.gamma * G
 54            returns.insert(0, G)
 55        returns = torch.tensor(returns)
 56        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
 57        return returns
 58
 59    def update(self):
 60        returns = self.compute_returns()
 61
 62        policy_loss = []
 63        for log_prob, G in zip(self.log_probs, returns):
 64            policy_loss.append(-log_prob * G)
 65
 66        loss = torch.stack(policy_loss).sum()
 67
 68        self.optimizer.zero_grad()
 69        loss.backward()
 70        self.optimizer.step()
 71
 72        self.log_probs = []
 73        self.rewards = []
 74        return loss.item()
 75
 76
 77def train():
 78    env = gym.make('CartPole-v1')
 79    state_dim = env.observation_space.shape[0]
 80    action_dim = env.action_space.n
 81
 82    agent = REINFORCE(state_dim, action_dim, lr=1e-3)
 83    scores = []
 84
 85    for episode in range(1000):
 86        state, _ = env.reset()
 87        total_reward = 0
 88        done = False
 89
 90        while not done:
 91            action = agent.choose_action(state)
 92            next_state, reward, terminated, truncated, _ = env.step(action)
 93            done = terminated or truncated
 94
 95            agent.store_reward(reward)
 96            state = next_state
 97            total_reward += reward
 98
 99        agent.update()
100        scores.append(total_reward)
101
102        if (episode + 1) % 100 == 0:
103            print(f"Episode {episode + 1}, Avg: {np.mean(scores[-100:]):.2f}")
104
105        if np.mean(scores[-100:]) >= 475:
106            print(f"Solved in {episode + 1} episodes!")
107            break
108
109    env.close()
110    return agent, scores
111
112
if __name__ == "__main__":
    # Script entry point: train the agent and leave the trained agent and the
    # per-episode scores bound at module level (handy under `python -i`).
    (agent, scores) = train()