# 07_dqn_cartpole.py (python, 144 lines, 4.2 KB)
  1"""
Train an agent on CartPole with DQN (DQN으로 CartPole 학습).
  3"""
  4import torch
  5import torch.nn as nn
  6import torch.optim as optim
  7import numpy as np
  8import gymnasium as gym
  9from collections import deque
 10import random
 11
 12
 13class QNetwork(nn.Module):
 14    def __init__(self, state_dim, action_dim, hidden_dim=128):
 15        super().__init__()
 16        self.network = nn.Sequential(
 17            nn.Linear(state_dim, hidden_dim),
 18            nn.ReLU(),
 19            nn.Linear(hidden_dim, hidden_dim),
 20            nn.ReLU(),
 21            nn.Linear(hidden_dim, action_dim)
 22        )
 23
 24    def forward(self, x):
 25        return self.network(x)
 26
 27
 28class ReplayBuffer:
 29    def __init__(self, capacity=100000):
 30        self.buffer = deque(maxlen=capacity)
 31
 32    def push(self, state, action, reward, next_state, done):
 33        self.buffer.append((state, action, reward, next_state, done))
 34
 35    def sample(self, batch_size):
 36        batch = random.sample(self.buffer, batch_size)
 37        states, actions, rewards, next_states, dones = zip(*batch)
 38        return (
 39            torch.FloatTensor(np.array(states)),
 40            torch.LongTensor(actions),
 41            torch.FloatTensor(rewards),
 42            torch.FloatTensor(np.array(next_states)),
 43            torch.FloatTensor(dones)
 44        )
 45
 46    def __len__(self):
 47        return len(self.buffer)
 48
 49
 50class DQNAgent:
 51    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
 52                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
 53                 batch_size=64, target_update_freq=100):
 54        self.q_network = QNetwork(state_dim, action_dim)
 55        self.target_network = QNetwork(state_dim, action_dim)
 56        self.target_network.load_state_dict(self.q_network.state_dict())
 57
 58        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
 59        self.buffer = ReplayBuffer()
 60
 61        self.gamma = gamma
 62        self.epsilon = epsilon
 63        self.epsilon_min = epsilon_min
 64        self.epsilon_decay = epsilon_decay
 65        self.batch_size = batch_size
 66        self.target_update_freq = target_update_freq
 67        self.action_dim = action_dim
 68        self.learn_step = 0
 69
 70    def choose_action(self, state):
 71        if np.random.random() < self.epsilon:
 72            return np.random.randint(self.action_dim)
 73        with torch.no_grad():
 74            q_values = self.q_network(torch.FloatTensor(state).unsqueeze(0))
 75            return q_values.argmax().item()
 76
 77    def learn(self):
 78        if len(self.buffer) < self.batch_size:
 79            return None
 80
 81        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
 82
 83        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
 84
 85        with torch.no_grad():
 86            next_q = self.target_network(next_states).max(1)[0]
 87            target_q = rewards + self.gamma * next_q * (1 - dones)
 88
 89        loss = nn.MSELoss()(current_q, target_q)
 90
 91        self.optimizer.zero_grad()
 92        loss.backward()
 93        self.optimizer.step()
 94
 95        self.learn_step += 1
 96        if self.learn_step % self.target_update_freq == 0:
 97            self.target_network.load_state_dict(self.q_network.state_dict())
 98
 99        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
100        return loss.item()
101
102
def train(num_episodes=300, max_steps=500, solve_score=475.0, solve_window=100):
    """Train a DQNAgent on Gymnasium's CartPole-v1.

    Training stops early once the mean score over the last ``solve_window``
    episodes reaches ``solve_score``. At least ``solve_window`` episodes must
    have been played before that check can succeed.

    Args:
        num_episodes: Maximum number of episodes to run.
        max_steps: Per-episode step cap, applied on top of the environment's
            own time limit.
        solve_score: Mean-score threshold that counts as "solved".
        solve_window: Number of most recent episodes averaged for that check.

    Returns:
        ``(agent, scores)``, where ``scores`` is the list of per-episode total
        rewards.
    """
    env = gym.make('CartPole-v1')
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        agent = DQNAgent(state_dim, action_dim)
        scores = []

        for episode in range(num_episodes):
            state, _ = env.reset()
            score = 0

            for _ in range(max_steps):
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)

                # Store only true termination as "done". A time-limit truncation
                # is not a terminal state, so its target must still bootstrap
                # from next_state.
                agent.buffer.push(state, action, reward, next_state, terminated)
                agent.learn()

                state = next_state
                score += reward

                if terminated or truncated:
                    break

            scores.append(score)

            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1}, Score: {np.mean(scores[-10:]):.2f}, "
                      f"Epsilon: {agent.epsilon:.3f}")

            # Require a full window; a shorter prefix would make "solved" trivially easy.
            if len(scores) >= solve_window and np.mean(scores[-solve_window:]) >= solve_score:
                print(f"Solved in {episode + 1} episodes!")
                break
    finally:
        # Release the environment even if training raises.
        env.close()
    return agent, scores
140
141
if __name__ == "__main__":
    # Script entry point: run training and keep the trained agent and the
    # per-episode scores bound as module globals (useful under `python -i`).
    (agent, scores) = train()