1"""
2Q-Learning๊ณผ SARSA ๊ตฌํ
3"""
4import numpy as np
5import gymnasium as gym
6
7
8class QLearning:
9 """Q-Learning ์์ด์ ํธ"""
10
11 def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
12 self.q_table = np.zeros((n_states, n_actions))
13 self.alpha = alpha
14 self.gamma = gamma
15 self.epsilon = epsilon
16 self.n_actions = n_actions
17
18 def choose_action(self, state):
19 if np.random.random() < self.epsilon:
20 return np.random.randint(self.n_actions)
21 return np.argmax(self.q_table[state])
22
23 def update(self, state, action, reward, next_state, done):
24 if done:
25 target = reward
26 else:
27 target = reward + self.gamma * np.max(self.q_table[next_state])
28
29 td_error = target - self.q_table[state, action]
30 self.q_table[state, action] += self.alpha * td_error
31 return td_error
32
33
34class SARSA:
35 """SARSA ์์ด์ ํธ"""
36
37 def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
38 self.q_table = np.zeros((n_states, n_actions))
39 self.alpha = alpha
40 self.gamma = gamma
41 self.epsilon = epsilon
42 self.n_actions = n_actions
43
44 def choose_action(self, state):
45 if np.random.random() < self.epsilon:
46 return np.random.randint(self.n_actions)
47 return np.argmax(self.q_table[state])
48
49 def update(self, state, action, reward, next_state, next_action, done):
50 if done:
51 target = reward
52 else:
53 target = reward + self.gamma * self.q_table[next_state, next_action]
54
55 td_error = target - self.q_table[state, action]
56 self.q_table[state, action] += self.alpha * td_error
57 return td_error
58
59
60def train_qlearning():
61 """Q-Learning์ผ๋ก FrozenLake ํ์ต"""
62 env = gym.make('FrozenLake-v1', is_slippery=True)
63 agent = QLearning(
64 n_states=env.observation_space.n,
65 n_actions=env.action_space.n,
66 alpha=0.1,
67 gamma=0.99,
68 epsilon=1.0
69 )
70
71 n_episodes = 10000
72 rewards = []
73
74 for episode in range(n_episodes):
75 state, _ = env.reset()
76 total_reward = 0
77 done = False
78
79 while not done:
80 action = agent.choose_action(state)
81 next_state, reward, terminated, truncated, _ = env.step(action)
82 done = terminated or truncated
83
84 agent.update(state, action, reward, next_state, done)
85 state = next_state
86 total_reward += reward
87
88 rewards.append(total_reward)
89 agent.epsilon = max(0.01, agent.epsilon * 0.9995)
90
91 if (episode + 1) % 1000 == 0:
92 avg = np.mean(rewards[-100:])
93 print(f"Episode {episode + 1}, Avg Reward: {avg:.3f}")
94
95 env.close()
96 return agent
97
98
99if __name__ == "__main__":
100 agent = train_qlearning()
101 print("\nํ์ต ์๋ฃ!")
102 print(f"์ต์ข
Q ํ
์ด๋ธ shape: {agent.q_table.shape}")