07. Deep Q-Network (DQN)
07. Deep Q-Network (DQN)¶
๋์ด๋: โญโญโญ (์ค๊ธ)
ํ์ต ๋ชฉํ¶
- DQN์ ํต์ฌ ์์ด๋์ด์ ๊ตฌ์กฐ ์ดํด
- Experience Replay์ ์๋ฆฌ์ ๊ตฌํ
- Target Network์ ํ์์ฑ๊ณผ ๋์ ๋ฐฉ์
- Double DQN, Dueling DQN ๋ฑ ๊ฐ์ ๊ธฐ๋ฒ
- PyTorch๋ก DQN ๊ตฌํ
1. Q-Learning์ ํ๊ณ์ DQN¶
1.1 ํ ์ด๋ธ ๊ธฐ๋ฐ Q-Learning์ ํ๊ณ¶
๋ฌธ์ ์ :
1. ์ํ ๊ณต๊ฐ์ด ํฌ๋ฉด ํ
์ด๋ธ ์ ์ฅ ๋ถ๊ฐ (Atari: 256^(84*84*4) ์ํ)
2. ์ฐ์ ์ํ ๊ณต๊ฐ ์ฒ๋ฆฌ ๋ถ๊ฐ
3. ๋น์ทํ ์ํ ๊ฐ ์ผ๋ฐํ ๋ถ๊ฐ
1.2 ํจ์ ๊ทผ์ฌ (Function Approximation)¶
# Approximate the Q function with a neural network instead of a table
# Q(s, a) ≈ Q(s, a; θ)
import torch
import torch.nn as nn
class QNetwork(nn.Module):
    """Multi-layer perceptron that approximates Q(s, a; θ).

    The network takes a state vector and emits one Q-value per action.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dim: Number of actions, i.e. the size of the output vector.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Layers are listed first and then wrapped so the parameter names
        # stay ``network.0``, ``network.2`` and ``network.4``.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Return the Q-values of every action for ``state``."""
        q_values = self.network(state)
        return q_values
2. DQN์ ํต์ฌ ๊ธฐ๋ฒ¶
2.1 Experience Replay¶
๊ฒฝํ์ ๋ฒํผ์ ์ ์ฅํ๊ณ ๋ฌด์์๋ก ์ํ๋งํ์ฌ ํ์ตํฉ๋๋ค.
์ฅ์ : - ๋ฐ์ดํฐ ํจ์จ์ฑ ํฅ์ (๊ฒฝํ ์ฌ์ฌ์ฉ) - ์ฐ์ ์ํ์ ์๊ด๊ด๊ณ ์ ๊ฑฐ - ํ์ต ์์ ํ
from collections import deque
import random
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, pushing a new one silently
    discards the oldest (``deque(maxlen=...)`` semantics).

    Args:
        capacity: Maximum number of transitions kept.
    """

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)`` of
            tensors; ``actions`` is int64, all others are float32.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        # Local import keeps this snippet self-contained even where the
        # module-level numpy import appears later in the file.
        import numpy as np

        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into one contiguous ndarray per field before handing over to
        # torch: building a tensor directly from a sequence of ndarrays is
        # converted element by element and is dramatically slower.
        return (
            torch.as_tensor(np.asarray(states, dtype=np.float32)),
            torch.as_tensor(np.asarray(actions, dtype=np.int64)),
            torch.as_tensor(np.asarray(rewards, dtype=np.float32)),
            torch.as_tensor(np.asarray(next_states, dtype=np.float32)),
            torch.as_tensor(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
2.2 Target Network¶
별도의 타겟 네트워크를 사용하여 학습 안정화합니다.
문제: Q(s,a;θ) 업데이트 시 타겟 y = r + γ max Q(s',a';θ)도 변함 해결: 타겟 네트워크 θ⁻를 고정하고 주기적으로 업데이트
class DQNAgent:
    """Minimal agent skeleton showing how a target network is kept in sync.

    Holds an online Q-network (trained by the optimizer) and a target
    network that starts as an exact copy and is refreshed either all at
    once (hard update) or gradually (soft update).

    Args:
        state_dim: Dimension of the state vector passed to ``QNetwork``.
        action_dim: Number of actions passed to ``QNetwork``.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor stored on the agent.
    """

    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99):
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        # Start the target network with the same weights as the online one.
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma

    def update_target_network(self):
        """Hard update: copy all online-network weights into the target."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def soft_update_target(self, tau=0.005):
        """Soft update: ``target <- tau * online + (1 - tau) * target``.

        Args:
            tau: Interpolation factor; ``tau=1`` equals a hard update for
                the parameters.
        """
        # no_grad() replaces direct ``.data`` mutation: the in-place copy is
        # kept out of the autograd graph without bypassing version tracking.
        with torch.no_grad():
            for target_param, param in zip(
                self.target_network.parameters(),
                self.q_network.parameters(),
            ):
                target_param.copy_(tau * param + (1 - tau) * target_param)
3. DQN 전체 구현¶
3.1 에이전트 클래스¶
import numpy as np
class DQNAgent:
    """DQN agent with ε-greedy exploration, replay buffer and target network.

    Args:
        state_dim: Dimension of the state vector.
        action_dim: Number of discrete actions.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor used in the TD target.
        epsilon: Initial exploration probability.
        epsilon_min: Lower bound for ``epsilon``.
        epsilon_decay: Multiplicative decay applied to ``epsilon`` after
            every successful call to :meth:`learn`.
        buffer_size: Capacity of the replay buffer.
        batch_size: Number of transitions sampled per learning step.
        target_update_freq: Number of learning steps between hard copies of
            the online weights into the target network.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        lr=1e-4,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        buffer_size=100000,
        batch_size=64,
        target_update_freq=1000,
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.learn_step = 0

        # Online and target networks start with identical weights.
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)

    def choose_action(self, state, training=True):
        """Pick an action ε-greedily (purely greedy when ``training`` is False).

        Returns:
            The chosen action index as a Python/NumPy integer.
        """
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)

        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            q_values = self.q_network(state_tensor)
        return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Push one transition into the replay buffer."""
        self.buffer.push(state, action, reward, next_state, done)

    def learn(self):
        """Run one gradient step on a sampled mini-batch.

        Returns:
            The scalar loss as a float, or ``None`` when the buffer does not
            yet hold ``batch_size`` transitions.
        """
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(
            self.batch_size
        )

        # Q(s, a) of the actions actually taken. squeeze(1) removes only the
        # gathered column; a bare squeeze() would also drop the batch axis
        # when batch_size == 1 and make the loss broadcast against the target.
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # TD target from the (frozen) target network; no bootstrap on done.
        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)

        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stability.
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
        self.optimizer.step()

        # Periodic hard update of the target network.
        self.learn_step += 1
        if self.learn_step % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Epsilon decays once per learning step, floored at epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()
3.2 학습 루프¶
import gymnasium as gym
def train_dqn(env_name='CartPole-v1', n_episodes=500):
    """Train a ``DQNAgent`` on a Gymnasium environment.

    Args:
        env_name: Environment id passed to ``gym.make``.
        n_episodes: Number of episodes to run.

    Returns:
        A tuple ``(agent, rewards_history)`` where ``rewards_history`` holds
        the undiscounted return of every episode.
    """
    env = gym.make(env_name)
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        agent = DQNAgent(state_dim, action_dim)
        rewards_history = []

        for episode in range(n_episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # The episode loop ends on either signal ...
                done = terminated or truncated

                # ... but only true termination is stored as "done": a
                # time-limit truncation is not a terminal state, so the TD
                # target must still bootstrap from next_state.
                agent.store_transition(state, action, reward, next_state, terminated)
                agent.learn()

                state = next_state
                total_reward += reward

            rewards_history.append(total_reward)

            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(rewards_history[-10:])
                print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, "
                      f"Epsilon: {agent.epsilon:.3f}")
    finally:
        # Release environment resources even if training raises.
        env.close()

    return agent, rewards_history
4. DQN ๊ฐ์ ๊ธฐ๋ฒ¶
4.1 Double DQN¶
์ผ๋ฐ DQN์ Q๊ฐ ๊ณผ๋์ถ์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค.
# ์ผ๋ฐ DQN: y = r + ฮณ max_a' Q(s', a'; ฮธโป)
# Double DQN: y = r + ฮณ Q(s', argmax_a' Q(s', a'; ฮธ); ฮธโป)
def compute_double_dqn_target(self, rewards, next_states, dones):
    """Compute the Double DQN TD target for a batch.

    The online network selects the next action and the target network
    evaluates it: ``y = r + gamma * Q_target(s', argmax_a Q_online(s', a))``.
    The bootstrap term is zeroed where ``dones`` is 1.

    Args:
        rewards: Tensor of rewards; assumed shape ``(batch,)`` to match the
            squeezed Q-values.
        next_states: Batch of next states fed to both networks.
        dones: Tensor of 0/1 done flags with the same shape as ``rewards``.

    Returns:
        Tensor of TD targets with the batch dimension preserved.
    """
    with torch.no_grad():
        # Action selection with the online network.
        next_actions = self.q_network(next_states).argmax(1, keepdim=True)
        # Action evaluation with the target network. squeeze(1) drops only
        # the gathered column, so a batch of one stays 1-D instead of
        # collapsing to a 0-d scalar.
        next_q = self.target_network(next_states).gather(1, next_actions).squeeze(1)
        target_q = rewards + self.gamma * next_q * (1 - dones)
    return target_q
4.2 Dueling DQN¶
Q 함수를 V(상태 가치)와 A(어드밴티지)로 분해합니다.
Q(s, a) = V(s) + A(s, a) - mean(A(s, ·))
class DuelingQNetwork(nn.Module):
    """Dueling architecture: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

    A shared feature layer feeds two heads, a scalar state-value stream and
    a per-action advantage stream, which are recombined in ``forward``.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of actions (width of the advantage head).
        hidden_dim: Width of the shared and per-stream hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Shared feature extractor.
        shared_layers = [nn.Linear(state_dim, hidden_dim), nn.ReLU()]
        self.feature = nn.Sequential(*shared_layers)

        # Value stream V(s): one scalar per state.
        value_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.value_stream = nn.Sequential(*value_layers)

        # Advantage stream A(s, a): one entry per action.
        advantage_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.advantage_stream = nn.Sequential(*advantage_layers)

    def forward(self, state):
        """Return Q-values for a batch of states of shape (batch, state_dim)."""
        shared = self.feature(state)
        state_value = self.value_stream(shared)
        adv = self.advantage_stream(shared)

        # Subtract the per-state mean advantage so V and A are identifiable.
        mean_adv = adv.mean(dim=1, keepdim=True)
        return state_value + adv - mean_adv
4.3 Prioritized Experience Replay (PER)¶
TD 오류가 큰 경험을 더 자주 샘플링합니다.
class PrioritizedReplayBuffer:
    """Ring buffer that samples experiences in proportion to their priority.

    Args:
        capacity: Maximum number of experiences stored.
        alpha: Priority exponent applied to every priority before storage.
        beta: Importance-sampling exponent used for the returned weights.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha  # priority exponent
        self.beta = beta  # importance-sampling exponent
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.position = 0

    def push(self, *experience, priority=1.0):
        """Store ``experience`` (a tuple of the positional args) at the write cursor."""
        slot = self.position
        if len(self.buffer) >= self.capacity:
            # Buffer is full: overwrite the slot under the cursor.
            self.buffer[slot] = experience
        else:
            self.buffer.append(experience)

        self.priorities[slot] = priority ** self.alpha
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Sample ``batch_size`` experiences (with replacement) by priority.

        Returns:
            A tuple ``(batch, indices, weights)``: the sampled experiences,
            their buffer indices, and float32 importance-sampling weights
            normalised so that the largest weight is 1.
        """
        size = len(self.buffer)
        active = self.priorities[:size]
        probs = active / active.sum()

        indices = np.random.choice(size, batch_size, p=probs)

        # Importance-sampling weights: (N * P(i)) ** -beta, scaled by the max.
        weights = (size * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        return batch, indices, torch.FloatTensor(weights)

    def update_priorities(self, indices, td_errors):
        """Set the priority of each index to ``(|td_error| + 1e-6) ** alpha``."""
        for slot, err in zip(indices, td_errors):
            # The small constant keeps every priority strictly positive.
            self.priorities[slot] = (abs(err) + 1e-6) ** self.alpha
5. CNN 기반 DQN (Atari)¶
5.1 이미지 입력 네트워크¶
class AtariDQN(nn.Module):
    """Convolutional Q-network for stacked 84x84 frames.

    Args:
        n_actions: Number of discrete actions (size of the output vector).
        in_channels: Number of stacked frames per observation. Defaults to 4;
            set it to the ``frame_stack`` used by the preprocessor when a
            different stack depth is chosen.
    """

    def __init__(self, n_actions, in_channels=4):
        super().__init__()
        # Input: (batch, in_channels, 84, 84) — a stack of grayscale frames.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # 84x84 input -> 20x20 -> 9x9 -> 7x7 feature maps with 64 channels.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        """Return Q-values for a batch of frame stacks.

        Args:
            x: Tensor of shape (batch, in_channels, 84, 84) holding pixel
                intensities in the 0-255 range.
        """
        x = x / 255.0  # scale pixel values to [0, 1]
        x = self.conv(x)
        # flatten(1) keeps the batch axis and works for any memory layout.
        x = x.flatten(1)
        return self.fc(x)
5.2 프레임 전처리¶
import cv2
class AtariPreprocessor:
    """Turns RGB frames into 84x84 grayscale and keeps a rolling frame stack.

    Args:
        frame_stack: Number of most recent processed frames kept in the stack.
    """

    def __init__(self, frame_stack=4):
        self.frame_stack = frame_stack
        # maxlen makes the deque drop the oldest frame automatically.
        self.frames = deque(maxlen=frame_stack)

    def preprocess_frame(self, frame):
        """Convert one RGB frame to an 84x84 grayscale image."""
        grayscale = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return cv2.resize(grayscale, (84, 84), interpolation=cv2.INTER_AREA)

    def reset(self, frame):
        """Fill the whole stack with the first frame and return it as an array."""
        first = self.preprocess_frame(frame)
        self.frames.extend([first] * self.frame_stack)
        return np.array(self.frames)

    def step(self, frame):
        """Push a new frame onto the stack and return the stack as an array."""
        self.frames.append(self.preprocess_frame(frame))
        return np.array(self.frames)
6. 실습: CartPole-v1¶
def main():
    """Train a DQN agent on CartPole-v1 and stop early once it is solved.

    Returns:
        A tuple ``(agent, scores)`` with the trained agent and the list of
        per-episode scores.
    """
    # Environment setup.
    env = gym.make('CartPole-v1')
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        # DQN agent.
        agent = DQNAgent(
            state_dim=state_dim,
            action_dim=action_dim,
            lr=1e-3,
            gamma=0.99,
            epsilon=1.0,
            epsilon_min=0.01,
            epsilon_decay=0.995,
            batch_size=64,
            target_update_freq=100,
        )

        # Training.
        n_episodes = 300
        scores = []

        for episode in range(n_episodes):
            state, _ = env.reset()
            score = 0

            for _ in range(500):
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)

                # Store only true termination as "done": hitting the time
                # limit (truncation) is not a terminal state, so the TD
                # target must keep bootstrapping from next_state.
                agent.store_transition(state, action, reward, next_state, terminated)
                agent.learn()

                state = next_state
                score += reward

                if terminated or truncated:
                    break

            scores.append(score)

            if (episode + 1) % 10 == 0:
                print(f"Episode {episode + 1}, Score: {np.mean(scores[-10:]):.2f}")

            # Solved criterion: average of the last 100 episodes >= 475.
            # Require a full 100-episode window so a few lucky early
            # episodes cannot trigger it.
            if len(scores) >= 100 and np.mean(scores[-100:]) >= 475:
                print(f"Solved in {episode + 1} episodes!")
                break
    finally:
        # Always release the environment, even if training raises.
        env.close()

    return agent, scores


if __name__ == "__main__":
    agent, scores = main()
요약¶
| 기법 | 목적 | 핵심 아이디어 |
|---|---|---|
| Experience Replay | 데이터 효율성, 상관관계 제거 | 버퍼에서 무작위 샘플링 |
| Target Network | 학습 안정화 | 타겟 고정, 주기적 업데이트 |
| Double DQN | 과대추정 방지 | 행동 선택/평가 분리 |
| Dueling DQN | 효율적 학습 | V와 A 분리 |
| PER | 효율적 샘플링 | TD 오류 기반 우선순위 |
다음 단계¶
- 08_Policy_Gradient.md - 정책 기반 방법론