# 09_lstm_gru.py — PyTorch LSTM/GRU tutorial script
# (download-page residue: "Download / python 395 lines 11.3 KB")
  1"""
  209. LSTM๊ณผ GRU
  3
  4LSTM๊ณผ GRU์˜ ๊ตฌํ˜„๊ณผ ํ™œ์šฉ์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10import numpy as np
 11import matplotlib.pyplot as plt
 12
 13print("=" * 60)
 14print("PyTorch LSTM/GRU")
 15print("=" * 60)
 16
 17device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 18
 19
 20# ============================================
 21# 1. LSTM ๊ธฐ๋ณธ
 22# ============================================
 23print("\n[1] LSTM ๊ธฐ๋ณธ")
 24print("-" * 40)
 25
 26lstm = nn.LSTM(
 27    input_size=10,
 28    hidden_size=20,
 29    num_layers=2,
 30    batch_first=True,
 31    dropout=0.1
 32)
 33
 34# ์ž…๋ ฅ
 35x = torch.randn(4, 8, 10)  # (batch, seq, features)
 36
 37# ์ˆœ์ „ํŒŒ
 38output, (h_n, c_n) = lstm(x)
 39
 40print(f"์ž…๋ ฅ: {x.shape}")
 41print(f"output: {output.shape}")  # (4, 8, 20)
 42print(f"h_n (์€๋‹‰): {h_n.shape}")  # (2, 4, 20)
 43print(f"c_n (์…€): {c_n.shape}")    # (2, 4, 20)
 44
 45# ์ดˆ๊ธฐ ์ƒํƒœ ์ง€์ •
 46h0 = torch.zeros(2, 4, 20)
 47c0 = torch.zeros(2, 4, 20)
 48output, (h_n, c_n) = lstm(x, (h0, c0))
 49print(f"\n์ดˆ๊ธฐ ์ƒํƒœ ์ง€์ •: h0={h0.shape}, c0={c0.shape}")
 50
 51
 52# ============================================
 53# 2. GRU ๊ธฐ๋ณธ
 54# ============================================
 55print("\n[2] GRU ๊ธฐ๋ณธ")
 56print("-" * 40)
 57
 58gru = nn.GRU(
 59    input_size=10,
 60    hidden_size=20,
 61    num_layers=2,
 62    batch_first=True
 63)
 64
 65output, h_n = gru(x)
 66
 67print(f"GRU output: {output.shape}")
 68print(f"GRU h_n: {h_n.shape}")  # ์…€ ์ƒํƒœ ์—†์Œ
 69
 70
 71# ============================================
 72# 3. ์–‘๋ฐฉํ–ฅ LSTM
 73# ============================================
 74print("\n[3] ์–‘๋ฐฉํ–ฅ LSTM")
 75print("-" * 40)
 76
 77lstm_bi = nn.LSTM(
 78    input_size=10,
 79    hidden_size=20,
 80    num_layers=2,
 81    batch_first=True,
 82    bidirectional=True
 83)
 84
 85output_bi, (h_n_bi, c_n_bi) = lstm_bi(x)
 86
 87print(f"์–‘๋ฐฉํ–ฅ LSTM:")
 88print(f"  output: {output_bi.shape}")  # (4, 8, 40)
 89print(f"  h_n: {h_n_bi.shape}")        # (4, 4, 20)
 90
 91# ์ •๋ฐฉํ–ฅ/์—ญ๋ฐฉํ–ฅ ๋ถ„๋ฆฌ
 92forward_out = output_bi[:, :, :20]
 93backward_out = output_bi[:, :, 20:]
 94print(f"  ์ •๋ฐฉํ–ฅ: {forward_out.shape}")
 95print(f"  ์—ญ๋ฐฉํ–ฅ: {backward_out.shape}")
 96
 97
 98# ============================================
 99# 4. LSTM ๋ถ„๋ฅ˜๊ธฐ
100# ============================================
101print("\n[4] LSTM ๋ถ„๋ฅ˜๊ธฐ")
102print("-" * 40)
103
class LSTMClassifier(nn.Module):
    """Sequence classifier: multi-layer (optionally bidirectional) LSTM + linear head."""

    def __init__(self, input_size, hidden_size, num_classes,
                 num_layers=2, bidirectional=True, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # PyTorch warns if dropout is set on a single-layer RNN, so only
        # enable inter-layer dropout when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)

    def forward(self, x):
        """x: (batch, seq, features) -> logits of shape (batch, num_classes)."""
        _, (h_n, _) = self.lstm(x)

        if not self.bidirectional:
            # Final hidden state of the top layer.
            summary = h_n[-1]
        else:
            # h_n is (num_layers * 2, batch, hidden); its last two rows are the
            # top layer's forward and backward final states — concatenate them.
            summary = torch.cat([h_n[-2], h_n[-1]], dim=1)

        return self.fc(self.dropout(summary))
139
# Smoke test on the dummy batch defined above.
model = LSTMClassifier(input_size=10, hidden_size=32, num_classes=5)
out = model(x)
print(f"๋ถ„๋ฅ˜๊ธฐ ์ถœ๋ ฅ: {out.shape}")
143
144
# ============================================
# 5. Time-series prediction comparison (RNN vs LSTM vs GRU)
# ============================================
print("\n[5] RNN vs LSTM vs GRU ๋น„๊ต")
print("-" * 40)

def generate_complex_series(seq_len=100, n_samples=1000):
    """Generate a synthetic next-step-prediction dataset.

    Each sample is two superimposed sines plus a linear trend and Gaussian
    noise. Returns ``X`` of shape (n_samples, seq_len, 1) and ``y`` of shape
    (n_samples,), where ``y`` is the value one step past the input window.
    """
    X, y = [], []
    for _ in range(n_samples):
        t = np.linspace(0, 10 * np.pi, seq_len + 1)
        signal = (np.sin(t) + 0.5 * np.sin(3 * t) + 0.1 * t
                  + 0.1 * np.random.randn(seq_len + 1))
        X.append(signal[:-1].reshape(-1, 1))
        y.append(signal[-1])
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

X, y = generate_complex_series(seq_len=100, n_samples=2000)
split = 1600  # 80/20 train/test split
X_train, y_train = torch.from_numpy(X[:split]), torch.from_numpy(y[:split])
X_test, y_test = torch.from_numpy(X[split:]), torch.from_numpy(y[split:])
165
class TimeSeriesModel(nn.Module):
    """Single-layer recurrent regressor; the backbone is picked by ``model_type``."""

    def __init__(self, model_type='lstm', hidden_size=64):
        super().__init__()
        # Dispatch table instead of an if/elif chain.
        backbones = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}
        if model_type in backbones:
            self.rnn = backbones[model_type](1, hidden_size, batch_first=True)

        self.model_type = model_type
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """x: (batch, seq, 1) -> one scalar prediction per sample, shape (batch,)."""
        if self.model_type == 'lstm':
            _, (h_n, _) = self.rnn(x)  # LSTM also returns a cell state
        else:
            _, h_n = self.rnn(x)
        # Project the final hidden state of the (only) layer to a scalar.
        return self.fc(h_n[-1]).squeeze(-1)
185
def train_model(model_type, epochs=30):
    """Train one TimeSeriesModel variant and return (per-epoch losses, test MSE).

    Relies on the module-level ``X_train``/``y_train``/``X_test``/``y_test``
    tensors and ``device``.
    """
    model = TimeSeriesModel(model_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    dataset = torch.utils.data.TensorDataset(X_train, y_train)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    losses = []
    for _ in range(epochs):
        model.train()
        running = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            loss = criterion(model(batch_x), batch_y)

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients — recurrent nets are prone to exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running += loss.item()
        losses.append(running / len(train_loader))

    # Held-out evaluation.
    model.eval()
    with torch.no_grad():
        test_loss = criterion(model(X_test.to(device)), y_test.to(device)).item()

    return losses, test_loss
222
# Train all three backbones and collect their curves.
print("๋ชจ๋ธ ํ•™์Šต ์ค‘...")
results = {}
for model_type in ('rnn', 'lstm', 'gru'):
    curve, mse = train_model(model_type)
    results[model_type] = {'losses': curve, 'test_loss': mse}
    print(f"  {model_type.upper()}: Test MSE = {mse:.6f}")

# Plot the training curves and save the figure to disk.
plt.figure(figsize=(10, 5))
for name, data in results.items():
    plt.plot(data['losses'], label=f"{name.upper()} (test={data['test_loss']:.4f})")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('RNN vs LSTM vs GRU')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('rnn_lstm_gru_comparison.png', dpi=100)
plt.close()
print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: rnn_lstm_gru_comparison.png")
243
244
# ============================================
# 6. Text classification example
# ============================================
print("\n[6] ํ…์ŠคํŠธ ๋ถ„๋ฅ˜ ์˜ˆ์ œ")
print("-" * 40)

class TextClassifier(nn.Module):
    """Token-index classifier: embedding -> 2-layer BiLSTM -> dropout + linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # padding_idx=0 keeps the pad token's embedding fixed at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True,
                            bidirectional=True, num_layers=2, dropout=0.3)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_dim * 2, num_classes),
        )

    def forward(self, x):
        """x: (batch, seq) of token indices -> (batch, num_classes) logits."""
        embedded = self.embedding(x)
        _, (h_n, _) = self.lstm(embedded)
        # Top layer's forward (h_n[-2]) and backward (h_n[-1]) final states.
        combined = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(combined)

model = TextClassifier(vocab_size=10000, embed_dim=128,
                       hidden_dim=256, num_classes=5)
print(f"TextClassifier ํŒŒ๋ผ๋ฏธํ„ฐ: {sum(p.numel() for p in model.parameters()):,}")

# Dummy batch: 8 sentences of 50 tokens each.
x = torch.randint(0, 10000, (8, 50))
out = model(x)
print(f"์ž…๋ ฅ: {x.shape} โ†’ ์ถœ๋ ฅ: {out.shape}")
277
278
# ============================================
# 7. Language model (text generation)
# ============================================
print("\n[7] ์–ธ์–ด ๋ชจ๋ธ")
print("-" * 40)

class CharLSTM(nn.Module):
    """Character-level language model: embedding -> LSTM -> per-step vocab logits."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """x: (batch, seq) token ids -> (logits, hidden); logits is (batch, seq, vocab)."""
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, hidden)
        return self.fc(output), hidden

    def generate(self, start_tokens, max_len=50, temperature=1.0):
        """Extend ``start_tokens`` with up to ``max_len`` sampled tokens.

        Lower temperature sharpens the distribution; higher flattens it.
        The LSTM hidden state is threaded through so each step only needs
        to feed the most recent token.
        """
        self.eval()
        tokens = list(start_tokens)
        hidden = None

        with torch.no_grad():
            for _ in range(max_len):
                last = torch.tensor([[tokens[-1]]])
                logits, hidden = self(last, hidden)

                # Temperature-scaled sampling from the final step's logits.
                probs = F.softmax(logits[0, -1] / temperature, dim=0)
                tokens.append(torch.multinomial(probs, 1).item())

        return tokens

char_lm = CharLSTM(vocab_size=128, embed_dim=64, hidden_dim=256)
print(f"CharLSTM ํŒŒ๋ผ๋ฏธํ„ฐ: {sum(p.numel() for p in char_lm.parameters()):,}")

# Generation smoke test, primed with token ids for "ABC".
generated = char_lm.generate([65, 66, 67], max_len=20)
print(f"์ƒ์„ฑ๋œ ํ† ํฐ: {generated[:10]}...")
322
323
# ============================================
# 8. LSTM internals visualization
# ============================================
print("\n[8] LSTM ๊ฒŒ์ดํŠธ ๋ถ„์„")
print("-" * 40)

class LSTMWithGates(nn.Module):
    """Unrolled LSTMCell that also records per-step gate activations.

    The previous version allocated a ``gates`` dict but never filled it, so
    no gate values were ever produced despite the class name. nn.LSTMCell
    does not expose its gates, so they are recomputed here from the cell's
    own parameters (PyTorch chunk order: input, forget, cell(g), output)
    and stored in ``self.last_gates`` after every forward pass.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        # Filled by forward(): {'input'|'forget'|'output': (batch, seq, hidden)}.
        self.last_gates = None

    def forward(self, x):
        """x: (batch, seq, features) -> stacked hidden states (batch, seq, hidden)."""
        batch_size, seq_len, _ = x.shape
        # Create initial states on x's device/dtype (the old code hard-coded
        # CPU float32 tensors, which breaks for CUDA inputs).
        h = x.new_zeros(batch_size, self.hidden_size)
        c = x.new_zeros(batch_size, self.hidden_size)

        cell = self.lstm_cell
        outputs = []
        gates = {'input': [], 'forget': [], 'output': []}

        for t in range(seq_len):
            x_t = x[:, t]
            # Gate pre-activations: W_ih x_t + b_ih + W_hh h + b_hh, split into
            # the four (i, f, g, o) chunks; the sigmoid gates are recorded.
            # no_grad: recording is for analysis only, not for training.
            with torch.no_grad():
                pre = (F.linear(x_t, cell.weight_ih, cell.bias_ih)
                       + F.linear(h, cell.weight_hh, cell.bias_hh))
                i_gate, f_gate, _, o_gate = pre.chunk(4, dim=1)
                gates['input'].append(torch.sigmoid(i_gate))
                gates['forget'].append(torch.sigmoid(f_gate))
                gates['output'].append(torch.sigmoid(o_gate))

            h, c = cell(x_t, (h, c))
            outputs.append(h)

        self.last_gates = {k: torch.stack(v, dim=1) for k, v in gates.items()}
        return torch.stack(outputs, dim=1)

# Smoke test
lstm_gates = LSTMWithGates(10, 20)
x = torch.randn(1, 30, 10)
out = lstm_gates(x)
print(f"๊ฒŒ์ดํŠธ ๋ถ„์„์šฉ LSTM ์ถœ๋ ฅ: {out.shape}")
356
357
# ============================================
# Wrap-up
# ============================================
print("\n" + "=" * 60)
print("LSTM/GRU ์ •๋ฆฌ")
print("=" * 60)

# Cheat-sheet of the patterns shown above (printed verbatim).
summary = """
LSTM:
    output, (h_n, c_n) = lstm(x)
    - ์…€ ์ƒํƒœ(c)๋กœ ์žฅ๊ธฐ ๊ธฐ์–ต ์œ ์ง€
    - Forget, Input, Output ๊ฒŒ์ดํŠธ

GRU:
    output, h_n = gru(x)
    - ์…€ ์ƒํƒœ ์—†์Œ, ๋” ๋‹จ์ˆœ
    - Reset, Update ๊ฒŒ์ดํŠธ

๋ถ„๋ฅ˜ ํŒจํ„ด:
    # ์–‘๋ฐฉํ–ฅ LSTM
    forward_last = h_n[-2]
    backward_last = h_n[-1]
    combined = torch.cat([forward_last, backward_last], dim=1)
    output = fc(combined)

ํ…์ŠคํŠธ ๋ถ„๋ฅ˜:
    embedded = embedding(x)  # ํ† ํฐ โ†’ ๋ฒกํ„ฐ
    _, (h_n, _) = lstm(embedded)
    output = fc(h_n[-1])

์„ ํƒ ๊ธฐ์ค€:
    - ๊ธด ์‹œํ€€์Šค, ๋ณต์žกํ•œ ์˜์กด์„ฑ โ†’ LSTM
    - ๋น ๋ฅธ ํ•™์Šต, ์ œํ•œ๋œ ์ž์› โ†’ GRU
    - ๋‹จ์ˆœํ•œ ํŒจํ„ด โ†’ RNN๋„ ๊ฐ€๋Šฅ
"""
print(summary)
print("=" * 60)