07. Deep Q-Network (DQN)

๋‚œ์ด๋„: โญโญโญ (์ค‘๊ธ‰)

Learning Objectives

  • Understand the core idea and architecture of DQN
  • Principles and implementation of Experience Replay
  • Why a Target Network is needed and how it works
  • Improvements such as Double DQN and Dueling DQN
  • Implementing DQN in PyTorch

1. Limitations of Q-Learning and DQN

1.1 ํ…Œ์ด๋ธ” ๊ธฐ๋ฐ˜ Q-Learning์˜ ํ•œ๊ณ„

Problems:
1. A table cannot be stored when the state space is large (Atari: 256^(84*84*4) states; see the quick check below)
2. Continuous state spaces cannot be handled
3. No generalization across similar states
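
A quick back-of-the-envelope sketch of the state count quoted in problem 1:

# Rough size of the Atari state space: 84x84 pixels, 4 stacked frames,
# 256 gray levels per pixel -> 256^(84*84*4) distinct states.
import math

n_pixels = 84 * 84 * 4
log10_states = n_pixels * math.log10(256)     # log10 of 256 ** n_pixels
print(f"about 10^{log10_states:.0f} states")  # roughly 10^67970, far beyond any table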

1.2 Function Approximation

# ํ…Œ์ด๋ธ” ๋Œ€์‹  ์‹ ๊ฒฝ๋ง์œผ๋กœ Q ํ•จ์ˆ˜ ๊ทผ์‚ฌ
# Q(s, a) โ‰ˆ Q(s, a; ฮธ)

import torch
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state):
        return self.network(state)  # Q-values for every action
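
For a CartPole-sized problem this maps a 4-dimensional state to one Q-value per action; a minimal usage sketch (the dimensions below are example values):

net = QNetwork(state_dim=4, action_dim=2)
state = torch.randn(1, 4)      # dummy batch of one 4-dimensional state
q_values = net(state)          # shape (1, 2): one Q-value per action
print(q_values.argmax(dim=1))  # greedy action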

2. Core Techniques of DQN

2.1 Experience Replay

๊ฒฝํ—˜์„ ๋ฒ„ํผ์— ์ €์žฅํ•˜๊ณ  ๋ฌด์ž‘์œ„๋กœ ์ƒ˜ํ”Œ๋งํ•˜์—ฌ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.

Advantages:
  • Better data efficiency (experiences are reused)
  • Breaks the correlation between consecutive samples
  • Stabilizes training

from collections import deque
import random

import numpy as np

class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert the tuples of numpy arrays to arrays first (faster than
        # building tensors directly from a tuple of arrays)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones)
        )

    def __len__(self):
        return len(self.buffer)

2.2 Target Network

๋ณ„๋„์˜ ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ•™์Šต ์•ˆ์ •ํ™”ํ•ฉ๋‹ˆ๋‹ค.

Problem: when Q(s, a; θ) is updated, the target y = r + γ max_a' Q(s', a'; θ) moves as well.
Solution: keep a frozen target network θ⁻ and update it only periodically.

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4):
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)

        # Initialize the target network with the same weights
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = 0.99

    def update_target_network(self):
        """ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ ํ•˜๋“œ ์—…๋ฐ์ดํŠธ"""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def soft_update_target(self, tau=0.005):
        """ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ ์†Œํ”„ํŠธ ์—…๋ฐ์ดํŠธ"""
        for target_param, param in zip(
            self.target_network.parameters(),
            self.q_network.parameters()
        ):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data
            )

3. Full DQN Implementation

3.1 Agent Class

import numpy as np

class DQNAgent:
    def __init__(
        self,
        state_dim,
        action_dim,
        lr=1e-4,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        buffer_size=100000,
        batch_size=64,
        target_update_freq=1000
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.learn_step = 0

        # ๋„คํŠธ์›Œํฌ
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)

    def choose_action(self, state, training=True):
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)

        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_network(state_tensor)
            return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        self.buffer.push(state, action, reward, next_state, done)

    def learn(self):
        if len(self.buffer) < self.batch_size:
            return None

        # Sample a minibatch
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Current Q-values for the actions that were taken
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()

        # Target Q-values (computed with the target network)
        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)

        # Compute loss and update
        loss = nn.MSELoss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # ๊ทธ๋ž˜๋””์–ธํŠธ ํด๋ฆฌํ•‘ (์•ˆ์ •์„ฑ)
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
        self.optimizer.step()

        # ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ ์—…๋ฐ์ดํŠธ
        self.learn_step += 1
        if self.learn_step % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        return loss.item()

3.2 Training Loop

import gymnasium as gym

def train_dqn(env_name='CartPole-v1', n_episodes=500):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = DQNAgent(state_dim, action_dim)

    rewards_history = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(state, action, reward, next_state, done)
            loss = agent.learn()

            state = next_state
            total_reward += reward

        rewards_history.append(total_reward)

        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, "
                  f"Epsilon: {agent.epsilon:.3f}")

    return agent, rewards_history
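
A minimal way to run this loop and inspect the learning curve, assuming matplotlib is available:

agent, rewards_history = train_dqn('CartPole-v1', n_episodes=500)

import matplotlib.pyplot as plt
plt.plot(rewards_history)
plt.xlabel('Episode')
plt.ylabel('Total reward')
plt.title('DQN on CartPole-v1')
plt.show()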

4. DQN Improvements

4.1 Double DQN

Double DQN addresses the Q-value overestimation problem of vanilla DQN.

# Vanilla DQN: y = r + γ max_a' Q(s', a'; θ⁻)
# Double DQN:  y = r + γ Q(s', argmax_a' Q(s', a'; θ); θ⁻)

def compute_double_dqn_target(self, rewards, next_states, dones):
    with torch.no_grad():
        # Q ๋„คํŠธ์›Œํฌ๋กœ ํ–‰๋™ ์„ ํƒ
        next_actions = self.q_network(next_states).argmax(1, keepdim=True)

        # ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ๋กœ Q๊ฐ’ ํ‰๊ฐ€
        next_q = self.target_network(next_states).gather(1, next_actions).squeeze()

        target_q = rewards + self.gamma * next_q * (1 - dones)

    return target_q
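
If compute_double_dqn_target is attached to DQNAgent as a method, the only change inside learn() is the target computation; a sketch:

# Inside DQNAgent.learn(), the Double DQN variant of the target block:
current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
target_q = self.compute_double_dqn_target(rewards, next_states, dones)
loss = nn.MSELoss()(current_q, target_q)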

4.2 Dueling DQN

Q ํ•จ์ˆ˜๋ฅผ V(์ƒํƒœ ๊ฐ€์น˜)์™€ A(์–ด๋“œ๋ฐดํ‹ฐ์ง€)๋กœ ๋ถ„ํ•ดํ•ฉ๋‹ˆ๋‹ค.

Q(s, a) = V(s) + A(s, a) - mean(A(s, ·))

class DuelingQNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Shared feature extractor
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Value stream (V)
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Advantage stream (A)
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state):
        features = self.feature(state)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)

        # Q = V + A - mean(A)
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values
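
Because DuelingQNetwork exposes the same interface as QNetwork, it can be swapped into the agent; a minimal sketch (CartPole dimensions used as example values):

agent = DQNAgent(state_dim=4, action_dim=2)
agent.q_network = DuelingQNetwork(4, 2)
agent.target_network = DuelingQNetwork(4, 2)
agent.target_network.load_state_dict(agent.q_network.state_dict())
agent.optimizer = torch.optim.Adam(agent.q_network.parameters(), lr=1e-4)  # rebuild the optimizer for the new parameters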

4.3 Prioritized Experience Replay (PER)

TD ์˜ค๋ฅ˜๊ฐ€ ํฐ ๊ฒฝํ—˜์„ ๋” ์ž์ฃผ ์ƒ˜ํ”Œ๋งํ•ฉ๋‹ˆ๋‹ค.

class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha  # priority exponent
        self.beta = beta    # importance-sampling exponent
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.position = 0

    def push(self, *experience, priority=1.0):
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience

        self.priorities[self.position] = priority ** self.alpha
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        total_priority = self.priorities[:len(self.buffer)].sum()
        probs = self.priorities[:len(self.buffer)] / total_priority

        indices = np.random.choice(len(self.buffer), batch_size, p=probs)

        # ์ค‘์š”๋„ ์ƒ˜ํ”Œ๋ง ๊ฐ€์ค‘์น˜
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        return batch, indices, torch.FloatTensor(weights)

    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = (abs(td_error) + 1e-6) ** self.alpha
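
The weights returned by sample() rescale each sample's loss, and the resulting TD errors feed back into update_priorities(). A sketch of one learning step with PER; per_buffer, q_network, target_network, optimizer and gamma are assumed to exist with the same meaning as in the agent above:

batch, indices, weights = per_buffer.sample(batch_size=64)
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.FloatTensor(np.array(states))
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(np.array(next_states))
dones = torch.FloatTensor(dones)

current_q = q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
with torch.no_grad():
    next_q = target_network(next_states).max(1)[0]
    target_q = rewards + gamma * next_q * (1 - dones)

td_errors = target_q - current_q
loss = (weights * td_errors.pow(2)).mean()  # importance-sampling-weighted MSE

optimizer.zero_grad()
loss.backward()
optimizer.step()

per_buffer.update_priorities(indices, td_errors.detach().abs().numpy())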

5. CNN-Based DQN (Atari)

5.1 Image-Input Network

class AtariDQN(nn.Module):
    def __init__(self, n_actions):
        super().__init__()

        # Input: 84x84x4 (stack of 4 frames)
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        # x shape: (batch, 4, 84, 84)
        x = x / 255.0  # Normalize pixel values to [0, 1]
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
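
A quick shape check (a sketch; 6 actions is just an example) confirming that the conv stack reduces an 84x84 input to the 64 * 7 * 7 features consumed by the first linear layer:

net = AtariDQN(n_actions=6)
dummy = torch.zeros(1, 4, 84, 84)    # batch of one stacked observation
with torch.no_grad():
    print(net(dummy).shape)          # expected: torch.Size([1, 6])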

5.2 Frame Preprocessing

import cv2

class AtariPreprocessor:
    def __init__(self, frame_stack=4):
        self.frame_stack = frame_stack
        self.frames = deque(maxlen=frame_stack)

    def preprocess_frame(self, frame):
        """84x84 ๊ทธ๋ ˆ์ด์Šค์ผ€์ผ๋กœ ๋ณ€ํ™˜"""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        return resized

    def reset(self, frame):
        processed = self.preprocess_frame(frame)
        for _ in range(self.frame_stack):
            self.frames.append(processed)
        return np.array(self.frames)

    def step(self, frame):
        processed = self.preprocess_frame(frame)
        self.frames.append(processed)
        return np.array(self.frames)
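
A minimal usage sketch with a synthetic 210x160 RGB frame (a real Atari environment would supply the frames):

preprocessor = AtariPreprocessor(frame_stack=4)
dummy_frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)

state = preprocessor.reset(dummy_frame)      # shape: (4, 84, 84)
next_state = preprocessor.step(dummy_frame)  # shape: (4, 84, 84)
print(state.shape, next_state.shape)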

6. Hands-On: CartPole-v1

def main():
    # Environment setup
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # DQN agent
    agent = DQNAgent(
        state_dim=state_dim,
        action_dim=action_dim,
        lr=1e-3,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        batch_size=64,
        target_update_freq=100
    )

    # Training
    n_episodes = 300
    scores = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        score = 0

        for t in range(500):
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(state, action, reward, next_state, done)
            agent.learn()

            state = next_state
            score += reward

            if done:
                break

        scores.append(score)

        if (episode + 1) % 10 == 0:
            print(f"Episode {episode + 1}, Score: {np.mean(scores[-10:]):.2f}")

        # Solved condition (average over the last 100 episodes)
        if len(scores) >= 100 and np.mean(scores[-100:]) >= 475:
            print(f"Solved in {episode + 1} episodes!")
            break

    env.close()
    return agent, scores

if __name__ == "__main__":
    agent, scores = main()
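
After training, the greedy policy can be inspected by disabling exploration with training=False; a sketch:

env = gym.make('CartPole-v1')
state, _ = env.reset()
total_reward, done = 0, False
while not done:
    action = agent.choose_action(state, training=False)  # greedy, no epsilon
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated
print(f"Evaluation score: {total_reward}")
env.close()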

Summary

Technique | Purpose | Key Idea
Experience Replay | Data efficiency, decorrelation | Random sampling from a buffer
Target Network | Training stability | Frozen target, periodic updates
Double DQN | Reduces overestimation | Separate action selection and evaluation
Dueling DQN | More efficient learning | Split Q into V and A
PER | More efficient sampling | TD-error-based priorities

๋‹ค์Œ ๋‹จ๊ณ„
