1"""
Train a CartPole agent with DQN (Deep Q-Network).
3"""
4import torch
5import torch.nn as nn
6import torch.optim as optim
7import numpy as np
8import gymnasium as gym
9from collections import deque
10import random
11
12
class QNetwork(nn.Module):
    """Multilayer perceptron that maps a state vector to one Q-value per action.

    The stack is Linear -> ReLU -> Linear -> ReLU -> Linear, with both hidden
    layers sharing the same width.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of outputs (one per action).
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Layers are created in input-to-output order and wrapped in a single
        # Sequential stored under ``network``.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        q_values = self.network(x)
        return q_values
26
27
class ReplayBuffer:
    """Bounded store of transitions with uniform random mini-batch sampling.

    Transitions are kept as ``(state, action, reward, next_state, done)``
    tuples in a ``deque`` created with ``maxlen=capacity``, so once the buffer
    is full each new transition pushes out the oldest one.
    """

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as tensors.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``.
            ``actions`` is a LongTensor; the other four are FloatTensors.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, dones = zip(*picked)

        state_batch = torch.FloatTensor(np.array(states))
        action_batch = torch.LongTensor(actions)
        reward_batch = torch.FloatTensor(rewards)
        next_state_batch = torch.FloatTensor(np.array(next_states))
        done_batch = torch.FloatTensor(dones)
        return (state_batch, action_batch, reward_batch,
                next_state_batch, done_batch)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
48
49
class DQNAgent:
    """DQN agent: online Q-network, target network, replay buffer, epsilon-greedy.

    The target network starts as a copy of the online network and is re-synced
    every ``target_update_freq`` calls to :meth:`learn` that perform an update.
    Epsilon is multiplied by ``epsilon_decay`` after each such update (i.e. per
    learning step, not per episode) and is never lowered below ``epsilon_min``.

    Args:
        state_dim: Size of the state vector fed to the Q-network.
        action_dim: Number of discrete actions.
        lr: Learning rate for the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial probability of picking a random action.
        epsilon_min: Lower bound for epsilon.
        epsilon_decay: Multiplicative epsilon decay applied per learning step.
        batch_size: Number of transitions sampled per learning step.
        target_update_freq: Learning steps between target-network syncs.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 batch_size=64, target_update_freq=100):
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        # Start the target network with exactly the online network's weights.
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Only the online network is optimized; the target is updated by copy.
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer()

        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.action_dim = action_dim
        self.learn_step = 0

    def choose_action(self, state):
        """Pick an action epsilon-greedily for a single state.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the index of the largest Q-value predicted by the
        online network.
        """
        if np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        with torch.no_grad():
            # Add a leading batch dimension of size 1 before the forward pass.
            q_values = self.q_network(torch.FloatTensor(state).unsqueeze(0))
        return q_values.argmax().item()

    def learn(self):
        """Run one gradient step on a sampled mini-batch.

        Returns:
            The scalar loss as a Python float, or ``None`` when the buffer
            holds fewer than ``batch_size`` transitions (no update is made).
        """
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Q(s, a) for the actions actually taken. squeeze(1) removes only the
        # gathered column; a bare squeeze() would also drop the batch
        # dimension when batch_size == 1, leaving a 0-d tensor whose shape no
        # longer matches target_q.
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            # (1 - dones) zeroes the bootstrapped term where done == 1.
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.learn_step += 1
        if self.learn_step % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()
101
102
def train(num_episodes=300, max_steps=500, solved_score=475):
    """Train a DQNAgent on ``CartPole-v1`` and return it with the score history.

    Progress (mean of the last 10 scores and the current epsilon) is printed
    every 10 episodes. Training stops early once the mean of the most recent
    (up to) 100 episode scores reaches ``solved_score``.

    Args:
        num_episodes: Maximum number of episodes to run.
        max_steps: Maximum number of environment steps per episode.
        solved_score: Mean-score threshold that ends training early.

    Returns:
        ``(agent, scores)`` where ``scores`` holds the total reward of each
        completed episode in order.
    """
    env = gym.make('CartPole-v1')
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        agent = DQNAgent(state_dim, action_dim)
        scores = []

        for episode in range(num_episodes):
            state, _ = env.reset()
            score = 0

            for _ in range(max_steps):
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)

                # Store only true termination as the done flag. A truncated
                # episode (e.g. cut off by a time limit) did not reach a
                # terminal state, so its next_state value must still be
                # bootstrapped in the TD target rather than zeroed.
                agent.buffer.push(state, action, reward, next_state, terminated)
                agent.learn()

                state = next_state
                score += reward

                # Either signal ends the episode for the purposes of the loop.
                if terminated or truncated:
                    break

            scores.append(score)

            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1}, Score: {np.mean(scores[-10:]):.2f}, "
                      f"Epsilon: {agent.epsilon:.3f}")

            if np.mean(scores[-100:]) >= solved_score:
                print(f"Solved in {episode + 1} episodes!")
                break
    finally:
        # Release the environment even if training raises.
        env.close()
    return agent, scores
140
141
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    agent, scores = train()