05_td_learning.py

Download
python 331 lines 10.5 KB
  1"""
TD Learning (Temporal Difference Learning) implementation.
Includes TD(0), SARSA, Q-Learning, and Expected SARSA.
  4"""
  5import numpy as np
  6import gymnasium as gym
  7from collections import defaultdict
  8import matplotlib.pyplot as plt
  9
 10
 11class TD0Prediction:
 12    """TD(0) μ •μ±… 평가 μ•Œκ³ λ¦¬μ¦˜"""
 13
 14    def __init__(self, alpha=0.1, gamma=0.99):
 15        self.V = defaultdict(float)
 16        self.alpha = alpha
 17        self.gamma = gamma
 18
 19    def update(self, state, reward, next_state, done):
 20        """TD(0) κ°€μΉ˜ ν•¨μˆ˜ μ—…λ°μ΄νŠΈ"""
 21        if done:
 22            td_target = reward
 23        else:
 24            td_target = reward + self.gamma * self.V[next_state]
 25
 26        td_error = td_target - self.V[state]
 27        self.V[state] += self.alpha * td_error
 28        return td_error
 29
 30    def get_value(self, state):
 31        return self.V[state]
 32
 33
 34class SARSA:
 35    """SARSA (On-policy TD Control)"""
 36
 37    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
 38        self.Q = defaultdict(lambda: np.zeros(n_actions))
 39        self.alpha = alpha
 40        self.gamma = gamma
 41        self.epsilon = epsilon
 42        self.n_actions = n_actions
 43
 44    def choose_action(self, state):
 45        """Ξ΅-greedy μ •μ±…"""
 46        if np.random.random() < self.epsilon:
 47            return np.random.randint(self.n_actions)
 48        return np.argmax(self.Q[state])
 49
 50    def update(self, state, action, reward, next_state, next_action, done):
 51        """SARSA μ—…λ°μ΄νŠΈ: Q(s,a) ← Q(s,a) + Ξ±[r + Ξ³Q(s',a') - Q(s,a)]"""
 52        if done:
 53            td_target = reward
 54        else:
 55            td_target = reward + self.gamma * self.Q[next_state][next_action]
 56
 57        td_error = td_target - self.Q[state][action]
 58        self.Q[state][action] += self.alpha * td_error
 59        return td_error
 60
 61
 62class QLearning:
 63    """Q-Learning (Off-policy TD Control)"""
 64
 65    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
 66        self.Q = defaultdict(lambda: np.zeros(n_actions))
 67        self.alpha = alpha
 68        self.gamma = gamma
 69        self.epsilon = epsilon
 70        self.n_actions = n_actions
 71
 72    def choose_action(self, state):
 73        """Ξ΅-greedy μ •μ±…"""
 74        if np.random.random() < self.epsilon:
 75            return np.random.randint(self.n_actions)
 76        return np.argmax(self.Q[state])
 77
 78    def update(self, state, action, reward, next_state, done):
 79        """Q-Learning μ—…λ°μ΄νŠΈ: Q(s,a) ← Q(s,a) + Ξ±[r + Ξ³ max Q(s',a') - Q(s,a)]"""
 80        if done:
 81            td_target = reward
 82        else:
 83            td_target = reward + self.gamma * np.max(self.Q[next_state])
 84
 85        td_error = td_target - self.Q[state][action]
 86        self.Q[state][action] += self.alpha * td_error
 87        return td_error
 88
 89
 90class ExpectedSARSA:
 91    """Expected SARSA"""
 92
 93    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
 94        self.Q = defaultdict(lambda: np.zeros(n_actions))
 95        self.alpha = alpha
 96        self.gamma = gamma
 97        self.epsilon = epsilon
 98        self.n_actions = n_actions
 99
100    def choose_action(self, state):
101        """Ξ΅-greedy μ •μ±…"""
102        if np.random.random() < self.epsilon:
103            return np.random.randint(self.n_actions)
104        return np.argmax(self.Q[state])
105
106    def update(self, state, action, reward, next_state, done):
107        """Expected SARSA μ—…λ°μ΄νŠΈ: λ‹€μŒ μƒνƒœμ—μ„œ μ •μ±…μ˜ κΈ°λŒ“κ°’ μ‚¬μš©"""
108        if done:
109            td_target = reward
110        else:
111            # Ξ΅-greedy μ •μ±… ν•˜μ—μ„œ κΈ°λŒ“κ°’ 계산
112            best_action = np.argmax(self.Q[next_state])
113            expected_q = 0.0
114            for a in range(self.n_actions):
115                if a == best_action:
116                    prob = 1 - self.epsilon + self.epsilon / self.n_actions
117                else:
118                    prob = self.epsilon / self.n_actions
119                expected_q += prob * self.Q[next_state][a]
120
121            td_target = reward + self.gamma * expected_q
122
123        td_error = td_target - self.Q[state][action]
124        self.Q[state][action] += self.alpha * td_error
125        return td_error
126
127
def train_sarsa(env_name='CliffWalking-v0', n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    """Train a SARSA agent on a Gymnasium environment.

    Args:
        env_name: Gymnasium environment id. It must have a discrete action
            space (``env.action_space.n``) and hashable observations, because
            observations are used as keys of the agent's Q table.
        n_episodes: number of training episodes.
        alpha, gamma, epsilon: forwarded to ``SARSA``.

    Returns:
        (agent, episode_rewards): the trained agent and the list of
        undiscounted per-episode returns. Prints the 100-episode moving
        average every 100 episodes.
    """
    env = gym.make(env_name)
    try:
        agent = SARSA(env.action_space.n, alpha=alpha, gamma=gamma, epsilon=epsilon)
        episode_rewards = []

        for episode in range(n_episodes):
            state, _ = env.reset()
            action = agent.choose_action(state)
            total_reward = 0
            done = False

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                next_action = agent.choose_action(next_state)
                # Only true termination removes the bootstrap term; a
                # time-limit truncation still bootstraps from Q(s', a').
                agent.update(state, action, reward, next_state, next_action, terminated)

                state = next_state
                action = next_action
                total_reward += reward

            episode_rewards.append(total_reward)

            if (episode + 1) % 100 == 0:
                avg = np.mean(episode_rewards[-100:])
                print(f"SARSA - Episode {episode + 1}: avg_reward = {avg:.1f}")
    finally:
        # Release the environment even if training raises.
        env.close()

    return agent, episode_rewards
160
161
def train_qlearning(env_name='CliffWalking-v0', n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    """Train a Q-learning agent on a Gymnasium environment.

    Args:
        env_name: Gymnasium environment id. It must have a discrete action
            space (``env.action_space.n``) and hashable observations, because
            observations are used as keys of the agent's Q table.
        n_episodes: number of training episodes.
        alpha, gamma, epsilon: forwarded to ``QLearning``.

    Returns:
        (agent, episode_rewards): the trained agent and the list of
        undiscounted per-episode returns. Prints the 100-episode moving
        average every 100 episodes.
    """
    env = gym.make(env_name)
    try:
        agent = QLearning(env.action_space.n, alpha=alpha, gamma=gamma, epsilon=epsilon)
        episode_rewards = []

        for episode in range(n_episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Only true termination removes the bootstrap term; a
                # time-limit truncation still bootstraps from max_a Q(s', a).
                agent.update(state, action, reward, next_state, terminated)

                state = next_state
                total_reward += reward

            episode_rewards.append(total_reward)

            if (episode + 1) % 100 == 0:
                avg = np.mean(episode_rewards[-100:])
                print(f"Q-Learning - Episode {episode + 1}: avg_reward = {avg:.1f}")
    finally:
        # Release the environment even if training raises.
        env.close()

    return agent, episode_rewards
192
193
def compare_td_methods():
    """Compare TD control methods on CliffWalking: SARSA vs Q-Learning vs Expected SARSA.

    Each method is trained for 500 episodes with alpha=0.5, gamma=1.0 and
    epsilon=0.1. Progress is printed every 100 episodes, and the learning
    curves are saved via ``plot_comparison``.

    Returns:
        (sarsa_rewards, qlearning_rewards, expected_rewards): per-episode
        returns for each method.
    """
    print("=== CliffWalking ν™˜κ²½μ—μ„œ TD 방법 비ꡐ ===\n")

    # SARSA (prefers the safer path)
    print("SARSA ν•™μŠ΅ 쀑...")
    _, sarsa_rewards = train_sarsa(n_episodes=500)

    # Q-Learning (learns the optimal path, riskier)
    print("\nQ-Learning ν•™μŠ΅ 쀑...")
    _, qlearning_rewards = train_qlearning(n_episodes=500)

    # Expected SARSA
    print("\nExpected SARSA ν•™μŠ΅ 쀑...")
    env = gym.make('CliffWalking-v0')
    try:
        expected_sarsa = ExpectedSARSA(env.action_space.n, alpha=0.5, gamma=1.0, epsilon=0.1)
        expected_rewards = []

        for episode in range(500):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action = expected_sarsa.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Bootstrap through truncation; only termination ends the return.
                expected_sarsa.update(state, action, reward, next_state, terminated)

                state = next_state
                total_reward += reward

            expected_rewards.append(total_reward)

            if (episode + 1) % 100 == 0:
                avg = np.mean(expected_rewards[-100:])
                print(f"Expected SARSA - Episode {episode + 1}: avg_reward = {avg:.1f}")
    finally:
        # Release the environment even if training raises.
        env.close()

    # Visualize the learning curves
    plot_comparison(sarsa_rewards, qlearning_rewards, expected_rewards)

    return sarsa_rewards, qlearning_rewards, expected_rewards
239
240
def plot_comparison(sarsa_rewards, qlearning_rewards, expected_rewards,
                    save_path='td_methods_comparison.png'):
    """Plot raw and smoothed learning curves for the three TD methods.

    The figure has two side-by-side panels: raw per-episode returns, and a
    10-episode trailing moving average. It is saved to ``save_path`` at
    150 dpi, closed, and the save location is printed.

    Args:
        sarsa_rewards, qlearning_rewards, expected_rewards: per-episode
            returns for each method.
        save_path: output image path (the default keeps the original filename).
    """
    window = 10

    def smooth(data, window):
        """Return (x, y) for a trailing moving average of ``data``.

        ``y[i]`` averages episodes ``i .. i+w-1`` and is placed at episode
        ``i+w-1``, so it lines up with the raw panel. For short runs the
        window shrinks to the data length. Empty input yields empty arrays
        instead of the ValueError that ``np.convolve`` would raise.
        """
        values = np.asarray(data, dtype=float)
        if values.size == 0:
            return np.arange(0), values
        w = min(window, values.size)
        y = np.convolve(values, np.ones(w) / w, mode='valid')
        return np.arange(w - 1, values.size), y

    fig = plt.figure(figsize=(12, 5))
    try:
        # Raw data
        plt.subplot(1, 2, 1)
        plt.plot(sarsa_rewards, alpha=0.3, label='SARSA (raw)')
        plt.plot(qlearning_rewards, alpha=0.3, label='Q-Learning (raw)')
        plt.plot(expected_rewards, alpha=0.3, label='Expected SARSA (raw)')
        plt.xlabel('Episode')
        plt.ylabel('Episode Reward')
        plt.title('TD Methods Comparison - Raw Data')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Smoothed data
        plt.subplot(1, 2, 2)
        for rewards, label in (
            (sarsa_rewards, 'SARSA (smoothed)'),
            (qlearning_rewards, 'Q-Learning (smoothed)'),
            (expected_rewards, 'Expected SARSA (smoothed)'),
        ):
            x, y = smooth(rewards, window)
            plt.plot(x, y, label=label, linewidth=2)
        plt.xlabel('Episode')
        plt.ylabel('Episode Reward (smoothed)')
        plt.title(f'TD Methods Comparison - Smoothed (window={window})')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
    finally:
        # pyplot keeps every figure alive until closed; release it so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    print(f"\nν•™μŠ΅ 곑선이 '{save_path}'둜 μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
275
276
def visualize_policy(agent, env_name='CliffWalking-v0'):
    """Print an agent's greedy policy on the 4x12 CliffWalking grid.

    Cells are shown as S (state 36), G (state 47), C (states 37-46), or an
    arrow for the argmax action: 0 '^', 1 '>', 2 'v', 3 '<'.

    The agent's Q table is only read, never modified. States missing from
    ``agent.Q`` are shown as action 0, which is what argmax over an all-zero
    row would give.

    Args:
        agent: object with a ``Q`` mapping from integer state to a vector of
            action values indexable by action id 0-3.
        env_name: only 'CliffWalking-v0' is supported. Any other value prints
            a notice and returns None.
    """
    if env_name != 'CliffWalking-v0':
        print("μ •μ±… μ‹œκ°ν™”λŠ” CliffWalking ν™˜κ²½λ§Œ μ§€μ›ν•©λ‹ˆλ‹€.")
        return

    print("\n=== ν•™μŠ΅λœ μ •μ±… (4x12 κ·Έλ¦¬λ“œ) ===")
    arrows = {0: '^', 1: '>', 2: 'v', 3: '<'}

    for row in range(4):
        line = ""
        for col in range(12):
            state = row * 12 + col
            if state == 36:  # start
                line += " S "
            elif state == 47:  # goal
                line += " G "
            elif 37 <= state <= 46:  # cliff
                line += " C "
            else:
                # Membership test first: indexing a defaultdict would insert
                # a zero row for every unvisited state as a side effect.
                action = int(np.argmax(agent.Q[state])) if state in agent.Q else 0
                line += f" {arrows[action]} "
        print(line)

    print("\n(S: μ‹œμž‘, G: λͺ©ν‘œ, C: 절벽, ^>v<: 행동 λ°©ν–₯)")
302
303
if __name__ == "__main__":
    # Compare the TD methods
    sarsa_rewards, qlearning_rewards, expected_rewards = compare_td_methods()

    # Visualize a SARSA policy
    print("\n" + "="*50)
    env = gym.make('CliffWalking-v0')
    try:
        sarsa_agent = SARSA(env.action_space.n)

        # Train a fresh agent for the visualization; compare_td_methods()
        # returns only reward curves, not its trained agents.
        for episode in range(500):
            state, _ = env.reset()
            action = sarsa_agent.choose_action(state)
            done = False

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                next_action = sarsa_agent.choose_action(next_state)
                # Bootstrap through truncation; only termination ends the return.
                sarsa_agent.update(state, action, reward, next_state, next_action, terminated)
                state = next_state
                action = next_action
    finally:
        # Release the environment even if training raises.
        env.close()

    print("\nSARSA ν•™μŠ΅ μ™„λ£Œ - μ•ˆμ „ν•œ 경둜 μ„ ν˜Έ")
    visualize_policy(sarsa_agent)