1"""
2REINFORCE (Policy Gradient) 구현
3"""
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7import numpy as np
8import gymnasium as gym
9
10
11class PolicyNetwork(nn.Module):
12 def __init__(self, state_dim, action_dim, hidden_dim=128):
13 super().__init__()
14 self.network = nn.Sequential(
15 nn.Linear(state_dim, hidden_dim),
16 nn.ReLU(),
17 nn.Linear(hidden_dim, hidden_dim),
18 nn.ReLU(),
19 nn.Linear(hidden_dim, action_dim)
20 )
21
22 def forward(self, state):
23 return F.softmax(self.network(state), dim=-1)
24
25 def get_action(self, state):
26 probs = self.forward(state)
27 dist = torch.distributions.Categorical(probs)
28 action = dist.sample()
29 return action.item(), dist.log_prob(action)
30
31
32class REINFORCE:
33 def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99):
34 self.policy = PolicyNetwork(state_dim, action_dim)
35 self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
36 self.gamma = gamma
37 self.log_probs = []
38 self.rewards = []
39
40 def choose_action(self, state):
41 state_tensor = torch.FloatTensor(state).unsqueeze(0)
42 action, log_prob = self.policy.get_action(state_tensor)
43 self.log_probs.append(log_prob)
44 return action
45
46 def store_reward(self, reward):
47 self.rewards.append(reward)
48
49 def compute_returns(self):
50 returns = []
51 G = 0
52 for r in reversed(self.rewards):
53 G = r + self.gamma * G
54 returns.insert(0, G)
55 returns = torch.tensor(returns)
56 returns = (returns - returns.mean()) / (returns.std() + 1e-8)
57 return returns
58
59 def update(self):
60 returns = self.compute_returns()
61
62 policy_loss = []
63 for log_prob, G in zip(self.log_probs, returns):
64 policy_loss.append(-log_prob * G)
65
66 loss = torch.stack(policy_loss).sum()
67
68 self.optimizer.zero_grad()
69 loss.backward()
70 self.optimizer.step()
71
72 self.log_probs = []
73 self.rewards = []
74 return loss.item()
75
76
77def train():
78 env = gym.make('CartPole-v1')
79 state_dim = env.observation_space.shape[0]
80 action_dim = env.action_space.n
81
82 agent = REINFORCE(state_dim, action_dim, lr=1e-3)
83 scores = []
84
85 for episode in range(1000):
86 state, _ = env.reset()
87 total_reward = 0
88 done = False
89
90 while not done:
91 action = agent.choose_action(state)
92 next_state, reward, terminated, truncated, _ = env.step(action)
93 done = terminated or truncated
94
95 agent.store_reward(reward)
96 state = next_state
97 total_reward += reward
98
99 agent.update()
100 scores.append(total_reward)
101
102 if (episode + 1) % 100 == 0:
103 print(f"Episode {episode + 1}, Avg: {np.mean(scores[-100:]):.2f}")
104
105 if np.mean(scores[-100:]) >= 475:
106 print(f"Solved in {episode + 1} episodes!")
107 break
108
109 env.close()
110 return agent, scores
111
112
113if __name__ == "__main__":
114 agent, scores = train()