# 12_practical_project.py
  1"""
  2μ‹€μ „ RL ν”„λ‘œμ νŠΈ: μ™„μ „ν•œ ꡬ쑰의 νŠΈλ ˆμ΄λ”© ν™˜κ²½ κ΅¬ν˜„
  3ν”„λ‘œμ νŠΈ ꡬ쑰, ν•™μŠ΅ νŒŒμ΄ν”„λΌμΈ, λͺ¨λΈ μ €μž₯/λ‘œλ“œ, 평가 및 μ‹œκ°ν™” 포함
  4"""
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8import numpy as np
  9import matplotlib.pyplot as plt
 10from collections import deque
 11import json
 12import os
 13
 14
 15# =============================================================================
 16# 1. μ‚¬μš©μž μ •μ˜ ν™˜κ²½: κ°„λ‹¨ν•œ νŠΈλ ˆμ΄λ”© ν™˜κ²½
 17# =============================================================================
 18
 19class SimpleTradingEnv:
 20    """
 21    κ°„λ‹¨ν•œ 주식 νŠΈλ ˆμ΄λ”© ν™˜κ²½
 22    - κ΄€μΈ‘: ν˜„μž¬ 가격, 이동평균, 보유 주식 수
 23    - 행동: 0(맀도), 1(보유), 2(맀수)
 24    - 보상: 포트폴리였 κ°€μΉ˜ λ³€ν™”
 25    """
 26
 27    def __init__(self, initial_balance=10000, stock_dim=1, max_steps=100):
 28        self.initial_balance = initial_balance
 29        self.stock_dim = stock_dim
 30        self.max_steps = max_steps
 31
 32        # κ΄€μΈ‘ 곡간: [가격, 5일 평균, 20일 평균, 보유 주식 수, ν˜„κΈˆ λΉ„μœ¨]
 33        self.obs_dim = 5
 34        # 행동 곡간: 맀도(0), 보유(1), 맀수(2)
 35        self.action_dim = 3
 36
 37        self.reset()
 38
 39    def _generate_price_series(self):
 40        """가격 μ‹œκ³„μ—΄ 생성 (랜덀 μ›Œν¬ + νŠΈλ Œλ“œ)"""
 41        trend = np.random.choice([-1, 0, 1])  # ν•˜λ½, 횑보, μƒμŠΉ
 42        prices = [100.0]
 43
 44        for _ in range(self.max_steps):
 45            # 랜덀 μ›Œν¬ + νŠΈλ Œλ“œ
 46            change = np.random.randn() * 2 + trend * 0.5
 47            new_price = max(50.0, prices[-1] + change)  # μ΅œμ†Œ 가격 μ œν•œ
 48            prices.append(new_price)
 49
 50        return np.array(prices)
 51
 52    def reset(self):
 53        """ν™˜κ²½ μ΄ˆκΈ°ν™”"""
 54        self.prices = self._generate_price_series()
 55        self.current_step = 0
 56
 57        self.balance = self.initial_balance
 58        self.shares_held = 0
 59        self.total_shares_bought = 0
 60        self.total_shares_sold = 0
 61
 62        return self._get_observation()
 63
 64    def _get_observation(self):
 65        """ν˜„μž¬ κ΄€μΈ‘ λ°˜ν™˜"""
 66        # 가격 정보
 67        current_price = self.prices[self.current_step]
 68
 69        # 이동평균 계산
 70        start_5 = max(0, self.current_step - 5)
 71        start_20 = max(0, self.current_step - 20)
 72        ma5 = np.mean(self.prices[start_5:self.current_step + 1])
 73        ma20 = np.mean(self.prices[start_20:self.current_step + 1])
 74
 75        # 포트폴리였 정보
 76        total_value = self.balance + self.shares_held * current_price
 77        cash_ratio = self.balance / total_value if total_value > 0 else 0
 78
 79        obs = np.array([
 80            current_price / 100.0,  # μ •κ·œν™”
 81            ma5 / 100.0,
 82            ma20 / 100.0,
 83            self.shares_held / 100.0,
 84            cash_ratio
 85        ], dtype=np.float32)
 86
 87        return obs
 88
 89    def step(self, action):
 90        """ν™˜κ²½ μŠ€ν…"""
 91        current_price = self.prices[self.current_step]
 92        prev_value = self.balance + self.shares_held * current_price
 93
 94        # 행동 μ‹€ν–‰
 95        if action == 0:  # 맀도
 96            if self.shares_held > 0:
 97                self.balance += self.shares_held * current_price * 0.99  # 수수료 1%
 98                self.total_shares_sold += self.shares_held
 99                self.shares_held = 0
100
101        elif action == 2:  # 맀수
102            shares_to_buy = self.balance // current_price
103            if shares_to_buy > 0:
104                cost = shares_to_buy * current_price * 1.01  # 수수료 1%
105                if cost <= self.balance:
106                    self.shares_held += shares_to_buy
107                    self.balance -= cost
108                    self.total_shares_bought += shares_to_buy
109
110        # λ‹€μŒ μŠ€ν…μœΌλ‘œ
111        self.current_step += 1
112
113        # 보상 계산: 포트폴리였 κ°€μΉ˜ λ³€ν™”
114        next_price = self.prices[self.current_step]
115        current_value = self.balance + self.shares_held * next_price
116        reward = (current_value - prev_value) / prev_value
117
118        # μ’…λ£Œ 쑰건
119        done = self.current_step >= self.max_steps - 1
120
121        # μ΅œμ’… 보상 κ°€μ‚°
122        if done:
123            # μ΅œμ’… 수읡λ₯ μ— λ”°λ₯Έ λ³΄λ„ˆμŠ€/νŽ˜λ„ν‹°
124            total_return = (current_value - self.initial_balance) / self.initial_balance
125            reward += total_return * 10
126
127        return self._get_observation(), reward, done, {}
128
129
130# =============================================================================
131# 2. λ„€νŠΈμ›Œν¬ μ•„ν‚€ν…μ²˜
132# =============================================================================
133
class TradingPolicyNetwork(nn.Module):
    """Actor-critic network for the trading agent.

    A shared two-layer MLP feeds both a softmax policy head (actor)
    and a scalar state-value head (critic).
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=128):
        super().__init__()

        self.feature_extractor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

        self._init_weights()

    def _init_weights(self):
        # Orthogonal weights (gain sqrt(2)) and zero biases on every linear layer.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)

    def forward(self, obs):
        hidden = self.feature_extractor(obs)
        action_probs = F.softmax(self.actor(hidden), dim=-1)
        state_value = self.critic(hidden)
        return action_probs, state_value

    def get_action_and_value(self, obs, action=None):
        """Sample (or score a given) action; return (action, log_prob, entropy, value)."""
        action_probs, state_value = self.forward(obs)
        dist = torch.distributions.Categorical(action_probs)

        if action is None:
            action = dist.sample()

        return action, dist.log_prob(action), dist.entropy(), state_value.squeeze(-1)
171
172
173# =============================================================================
174# 3. PPO μ—μ΄μ „νŠΈ (ν”„λ‘œμ νŠΈμš©)
175# =============================================================================
176
class PPOAgent:
    """PPO agent for the trading project.

    Wraps the policy network, the optimizer, rollout collection, GAE
    advantage estimation, the clipped-surrogate update, and checkpointing.
    """

    def __init__(self, config):
        """Build the network and optimizer from a config dict.

        Required keys: 'obs_dim', 'action_dim'. Every other hyperparameter
        falls back to a standard PPO default.
        """
        self.config = config

        self.network = TradingPolicyNetwork(
            obs_dim=config['obs_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config.get('hidden_dim', 128)
        )

        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=config.get('lr', 3e-4)
        )

        # Hyperparameters
        self.gamma = config.get('gamma', 0.99)
        self.gae_lambda = config.get('gae_lambda', 0.95)
        self.clip_epsilon = config.get('clip_epsilon', 0.2)
        self.value_coef = config.get('value_coef', 0.5)
        self.entropy_coef = config.get('entropy_coef', 0.01)
        self.max_grad_norm = config.get('max_grad_norm', 0.5)
        self.n_epochs = config.get('n_epochs', 10)
        self.batch_size = config.get('batch_size', 64)

    def collect_rollout(self, env, n_steps):
        """Collect n_steps of experience, resetting the env when an episode ends.

        Returns a dict of NumPy arrays plus 'last_value' for bootstrapping.
        """
        rollout = {
            'obs': [], 'actions': [], 'rewards': [], 'dones': [],
            'values': [], 'log_probs': []
        }

        obs = env.reset()

        for _ in range(n_steps):
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0)

            with torch.no_grad():
                action, log_prob, _, value = self.network.get_action_and_value(obs_tensor)

            next_obs, reward, done, _ = env.step(action.item())

            rollout['obs'].append(obs)
            rollout['actions'].append(action.item())
            rollout['rewards'].append(reward)
            rollout['dones'].append(done)
            rollout['values'].append(value.item())
            rollout['log_probs'].append(log_prob.item())

            obs = next_obs if not done else env.reset()

        # Bootstrap value for the state following the last transition. When
        # that transition ended an episode this evaluates the freshly reset
        # state, but GAE masks it out via the (1 - done) factor below.
        with torch.no_grad():
            _, _, _, last_value = self.network.get_action_and_value(
                torch.FloatTensor(obs).unsqueeze(0)
            )
            rollout['last_value'] = last_value.item()

        # Convert lists to NumPy arrays for vectorized post-processing.
        for key in ['obs', 'actions', 'rewards', 'dones', 'values', 'log_probs']:
            rollout[key] = np.array(rollout[key])

        return rollout

    def compute_gae(self, rollout):
        """Compute GAE advantages and discounted returns for a rollout."""
        rewards = rollout['rewards']
        values = rollout['values']
        dones = rollout['dones']
        last_value = rollout['last_value']

        advantages = np.zeros_like(rewards)
        last_gae = 0

        # Backward recursion: A_t = delta_t + gamma*lambda*(1-done_t)*A_{t+1}
        for t in reversed(range(len(rewards))):
            next_val = last_value if t == len(rewards) - 1 else values[t + 1]
            delta = rewards[t] + self.gamma * next_val * (1 - dones[t]) - values[t]
            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * last_gae

        returns = advantages + values
        return advantages, returns

    def update(self, rollout):
        """Run the clipped-surrogate PPO update; return the mean minibatch loss."""
        advantages, returns = self.compute_gae(rollout)
        # Normalize advantages for a better-conditioned policy gradient.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Tensor conversion
        obs = torch.FloatTensor(rollout['obs'])
        actions = torch.LongTensor(rollout['actions'])
        old_log_probs = torch.FloatTensor(rollout['log_probs'])
        advantages_tensor = torch.FloatTensor(advantages)
        returns_tensor = torch.FloatTensor(returns)

        total_loss = 0
        n_updates = 0

        for _ in range(self.n_epochs):
            # Fresh shuffle each epoch so minibatches differ between passes.
            indices = np.random.permutation(len(obs))

            for start in range(0, len(obs), self.batch_size):
                idx = indices[start:start + self.batch_size]

                _, new_log_probs, entropy, values = self.network.get_action_and_value(
                    obs[idx], actions[idx]
                )

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - old_log_probs[idx])
                surr1 = ratio * advantages_tensor[idx]
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages_tensor[idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(values, returns_tensor[idx])
                entropy_loss = -entropy.mean()  # maximize entropy => minimize its negative

                loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_loss += loss.item()
                n_updates += 1

        return total_loss / n_updates if n_updates > 0 else 0

    def save(self, filepath):
        """Save network/optimizer state and the config to a checkpoint file."""
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }, filepath)

    def load(self, filepath):
        """Load a checkpoint produced by save().

        BUG FIX: map storages to CPU so checkpoints saved on a GPU machine
        also load on CPU-only hosts (torch.load would otherwise try to
        restore tensors onto the original CUDA device).
        """
        checkpoint = torch.load(filepath, map_location='cpu')
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
322
323
324# =============================================================================
325# 4. ν•™μŠ΅ νŒŒμ΄ν”„λΌμΈ
326# =============================================================================
327
class TrainingLogger:
    """Records per-episode training metrics, persists them as JSON, and plots them."""

    def __init__(self, log_dir='logs'):
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        # One list per tracked metric, all indexed by episode.
        self.metrics = {key: [] for key in ('episodes', 'rewards', 'losses', 'returns')}

    def log(self, episode, reward, loss, portfolio_return):
        """Append one episode's metrics."""
        for key, value in zip(('episodes', 'rewards', 'losses', 'returns'),
                              (episode, reward, loss, portfolio_return)):
            self.metrics[key].append(value)

    def save(self):
        """Write the accumulated metrics to <log_dir>/training_log.json."""
        with open(os.path.join(self.log_dir, 'training_log.json'), 'w') as f:
            json.dump(self.metrics, f, indent=2)

    def _plot_with_smoothing(self, ax, series):
        # Raw series (faint) plus a moving-average overlay once enough data exists.
        ax.plot(self.metrics['episodes'], series, alpha=0.3)
        if len(series) > 10:
            window = min(50, len(series) // 10)
            smoothed = np.convolve(series, np.ones(window) / window, mode='valid')
            ax.plot(range(window - 1, len(series)), smoothed, linewidth=2)

    def plot(self):
        """Render a 2x2 metrics dashboard and save it as a PNG in the log dir."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Episode rewards
        self._plot_with_smoothing(axes[0, 0], self.metrics['rewards'])
        axes[0, 0].set_title('μ—ν”Όμ†Œλ“œ 보상')
        axes[0, 0].set_xlabel('Episode')
        axes[0, 0].set_ylabel('Reward')
        axes[0, 0].grid(True, alpha=0.3)

        # Training loss
        axes[0, 1].plot(self.metrics['episodes'], self.metrics['losses'])
        axes[0, 1].set_title('ν•™μŠ΅ 손싀')
        axes[0, 1].set_xlabel('Episode')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].grid(True, alpha=0.3)

        # Portfolio returns
        self._plot_with_smoothing(axes[1, 0], self.metrics['returns'])
        axes[1, 0].axhline(y=0, color='r', linestyle='--', alpha=0.3)
        axes[1, 0].set_title('포트폴리였 수읡λ₯ ')
        axes[1, 0].set_xlabel('Episode')
        axes[1, 0].set_ylabel('Return (%)')
        axes[1, 0].grid(True, alpha=0.3)

        # Return distribution
        axes[1, 1].hist(self.metrics['returns'], bins=30, alpha=0.7, edgecolor='black')
        axes[1, 1].axvline(x=0, color='r', linestyle='--', linewidth=2)
        axes[1, 1].set_title('수읡λ₯  뢄포')
        axes[1, 1].set_xlabel('Return (%)')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'training_progress.png'), dpi=100, bbox_inches='tight')
        print(f"ν•™μŠ΅ κ·Έλž˜ν”„ μ €μž₯: {self.log_dir}/training_progress.png")
400
401
def estimate_return_pct(obs, initial_balance):
    """Estimate the portfolio return (%) from a single observation.

    The observation stores price/100 at index 0, shares/100 at index 3 and
    the cash ratio at index 4, so the total portfolio value can be
    reconstructed as stock_value / (1 - cash_ratio).

    BUG FIX: the previous inline formula treated shares/100 * 100 as a
    dollar amount and scaled the cash *ratio* by the initial balance,
    producing a dimensionally meaningless metric.
    """
    price = float(obs[0]) * 100.0
    shares = float(obs[3]) * 100.0
    cash_ratio = float(obs[4])
    stock_value = shares * price
    if cash_ratio < 1.0 - 1e-6:
        total_value = stock_value / (1.0 - cash_ratio)
    else:
        # All-cash position: the balance is not recoverable from the
        # observation alone; assume it is unchanged from the start.
        total_value = initial_balance
    return (total_value - initial_balance) / initial_balance * 100.0


def train_agent(config):
    """Train a PPO trading agent end-to-end.

    Builds the environment, agent and logger from `config`, runs the
    collect/update loop with periodic console logging and checkpointing,
    then saves the final model, the metric log and the training plots.

    Returns (agent, logger).
    """
    # Environment
    env = SimpleTradingEnv(
        initial_balance=config['initial_balance'],
        max_steps=config['max_steps']
    )

    # Agent hyperparameters derived from the environment + config.
    agent_config = {
        'obs_dim': env.obs_dim,
        'action_dim': env.action_dim,
        'hidden_dim': config['hidden_dim'],
        'lr': config['lr'],
        'gamma': config['gamma'],
        'gae_lambda': config['gae_lambda'],
        'clip_epsilon': config['clip_epsilon'],
        'n_epochs': config['n_epochs'],
        'batch_size': config['batch_size']
    }
    agent = PPOAgent(agent_config)

    # Logger
    logger = TrainingLogger(log_dir=config['log_dir'])

    n_episodes = config['n_episodes']
    n_steps = config['n_steps']

    print("ν•™μŠ΅ μ‹œμž‘...\n")

    for episode in range(n_episodes):
        # Collect a rollout
        rollout = agent.collect_rollout(env, n_steps)

        # Episode statistics
        episode_reward = rollout['rewards'].sum()
        final_obs = rollout['obs'][-1]
        portfolio_return = estimate_return_pct(final_obs, config['initial_balance'])

        # Update
        loss = agent.update(rollout)

        # Logging
        logger.log(episode, episode_reward, loss, portfolio_return)

        if (episode + 1) % config['log_interval'] == 0:
            avg_reward = np.mean(logger.metrics['rewards'][-config['log_interval']:])
            avg_return = np.mean(logger.metrics['returns'][-config['log_interval']:])
            print(f"Episode {episode + 1}/{n_episodes} | "
                  f"Avg Reward: {avg_reward:.2f} | "
                  f"Avg Return: {avg_return:.2f}% | "
                  f"Loss: {loss:.4f}")

        # Periodic checkpoint
        if (episode + 1) % config['save_interval'] == 0:
            save_path = os.path.join(config['checkpoint_dir'], f'agent_ep{episode + 1}.pt')
            agent.save(save_path)
            print(f"  체크포인트 μ €μž₯: {save_path}")

    # Final model
    final_path = os.path.join(config['checkpoint_dir'], 'agent_final.pt')
    agent.save(final_path)

    # Persist the log and render the plots
    logger.save()
    logger.plot()

    print("\nν•™μŠ΅ μ™„λ£Œ!")
    return agent, logger
472
473
474# =============================================================================
475# 5. 평가
476# =============================================================================
477
def evaluate_agent(agent, n_episodes=10, render=False):
    """Run the trained agent for several episodes and report its returns."""
    env = SimpleTradingEnv()
    episode_returns = []

    print("\n=== μ—μ΄μ „νŠΈ 평가 ===\n")

    for episode in range(n_episodes):
        observation = env.reset()
        done = False

        # Roll out one full episode, sampling actions from the policy.
        while not done:
            with torch.no_grad():
                action, _, _, _ = agent.network.get_action_and_value(
                    torch.FloatTensor(observation).unsqueeze(0)
                )
            observation, _, done, _ = env.step(action.item())

        # Mark the portfolio to market at the final price.
        final_value = env.balance + env.shares_held * env.prices[env.current_step]
        portfolio_return = (final_value - env.initial_balance) / env.initial_balance * 100

        episode_returns.append(portfolio_return)
        print(f"Episode {episode + 1}: Return = {portfolio_return:.2f}%")

    mean_return = np.mean(episode_returns)
    std_return = np.std(episode_returns)

    print(f"\n평균 수읡λ₯ : {mean_return:.2f}% Β± {std_return:.2f}%")

    return episode_returns
512
513
514# =============================================================================
515# 6. 메인
516# =============================================================================
517
if __name__ == "__main__":
    # Project configuration
    config = {
        # Environment
        'initial_balance': 10000,
        'max_steps': 100,

        # Network
        'hidden_dim': 128,

        # Optimization
        'lr': 3e-4,
        'gamma': 0.99,
        'gae_lambda': 0.95,
        'clip_epsilon': 0.2,
        'n_epochs': 10,
        'batch_size': 64,

        # Training schedule
        'n_episodes': 500,
        'n_steps': 100,
        'log_interval': 50,
        'save_interval': 100,

        # Output directories
        'log_dir': 'logs',
        'checkpoint_dir': 'checkpoints'
    }

    # Create output directories up front
    os.makedirs(config['log_dir'], exist_ok=True)
    os.makedirs(config['checkpoint_dir'], exist_ok=True)

    print("=" * 60)
    print("μ‹€μ „ RL ν”„λ‘œμ νŠΈ: νŠΈλ ˆμ΄λ”© μ—μ΄μ „νŠΈ")
    print("=" * 60)

    # Echo the configuration
    print("\nμ„€μ •:")
    for key, value in config.items():
        print(f"  {key}: {value}")

    # Train
    agent, logger = train_agent(config)

    # Evaluate
    returns = evaluate_agent(agent, n_episodes=20)

    print("\nν”„λ‘œμ νŠΈ μ™„λ£Œ!")
    print("\nμƒμ„±λœ 파일:")
    print(f"  - {config['log_dir']}/training_log.json")
    print(f"  - {config['log_dir']}/training_progress.png")
    print(f"  - {config['checkpoint_dir']}/agent_final.pt")

    print("\nμ£Όμš” ν•™μŠ΅ λ‚΄μš©:")
    print("  1. μ‚¬μš©μž μ •μ˜ ν™˜κ²½ κ΅¬ν˜„ (Gymnasium μŠ€νƒ€μΌ)")
    print("  2. λͺ¨λ“ˆν™”λœ PPO μ—μ΄μ „νŠΈ")
    print("  3. ν•™μŠ΅ νŒŒμ΄ν”„λΌμΈ (μˆ˜μ§‘-μ—…λ°μ΄νŠΈ-λ‘œκΉ…)")
    print("  4. λͺ¨λΈ μ €μž₯/λ‘œλ“œ")
    print("  5. 평가 및 μ‹œκ°ν™”")
    print("  6. ν”„λ‘œμ νŠΈ ꡬ쑰 λͺ¨λ²” 사둀")