06_q_learning.py

Download
python 103 lines 2.9 KB
  1"""
Q-Learning and SARSA implementation.
  3"""
  4import numpy as np
  5import gymnasium as gym
  6
  7
  8class QLearning:
  9    """Q-Learning ์—์ด์ „ํŠธ"""
 10
 11    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
 12        self.q_table = np.zeros((n_states, n_actions))
 13        self.alpha = alpha
 14        self.gamma = gamma
 15        self.epsilon = epsilon
 16        self.n_actions = n_actions
 17
 18    def choose_action(self, state):
 19        if np.random.random() < self.epsilon:
 20            return np.random.randint(self.n_actions)
 21        return np.argmax(self.q_table[state])
 22
 23    def update(self, state, action, reward, next_state, done):
 24        if done:
 25            target = reward
 26        else:
 27            target = reward + self.gamma * np.max(self.q_table[next_state])
 28
 29        td_error = target - self.q_table[state, action]
 30        self.q_table[state, action] += self.alpha * td_error
 31        return td_error
 32
 33
 34class SARSA:
 35    """SARSA ์—์ด์ „ํŠธ"""
 36
 37    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
 38        self.q_table = np.zeros((n_states, n_actions))
 39        self.alpha = alpha
 40        self.gamma = gamma
 41        self.epsilon = epsilon
 42        self.n_actions = n_actions
 43
 44    def choose_action(self, state):
 45        if np.random.random() < self.epsilon:
 46            return np.random.randint(self.n_actions)
 47        return np.argmax(self.q_table[state])
 48
 49    def update(self, state, action, reward, next_state, next_action, done):
 50        if done:
 51            target = reward
 52        else:
 53            target = reward + self.gamma * self.q_table[next_state, next_action]
 54
 55        td_error = target - self.q_table[state, action]
 56        self.q_table[state, action] += self.alpha * td_error
 57        return td_error
 58
 59
 60def train_qlearning():
 61    """Q-Learning์œผ๋กœ FrozenLake ํ•™์Šต"""
 62    env = gym.make('FrozenLake-v1', is_slippery=True)
 63    agent = QLearning(
 64        n_states=env.observation_space.n,
 65        n_actions=env.action_space.n,
 66        alpha=0.1,
 67        gamma=0.99,
 68        epsilon=1.0
 69    )
 70
 71    n_episodes = 10000
 72    rewards = []
 73
 74    for episode in range(n_episodes):
 75        state, _ = env.reset()
 76        total_reward = 0
 77        done = False
 78
 79        while not done:
 80            action = agent.choose_action(state)
 81            next_state, reward, terminated, truncated, _ = env.step(action)
 82            done = terminated or truncated
 83
 84            agent.update(state, action, reward, next_state, done)
 85            state = next_state
 86            total_reward += reward
 87
 88        rewards.append(total_reward)
 89        agent.epsilon = max(0.01, agent.epsilon * 0.9995)
 90
 91        if (episode + 1) % 1000 == 0:
 92            avg = np.mean(rewards[-100:])
 93            print(f"Episode {episode + 1}, Avg Reward: {avg:.3f}")
 94
 95    env.close()
 96    return agent
 97
 98
 99if __name__ == "__main__":
100    agent = train_qlearning()
101    print("\nํ•™์Šต ์™„๋ฃŒ!")
102    print(f"์ตœ์ข… Q ํ…Œ์ด๋ธ” shape: {agent.q_table.shape}")