"""
09. LSTM and GRU

Walks through the implementation and practical usage of LSTM and GRU layers.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

print("=" * 60)
print("PyTorch LSTM/GRU")
print("=" * 60)

# Run on GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
# ============================================
# 1. LSTM basics
# ============================================
print("\n[1] LSTM basics")
print("-" * 40)

lstm = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True,
    dropout=0.1
)

# Input batch
x = torch.randn(4, 8, 10)  # (batch, seq, features)

# Forward pass: output is per-step hidden states; (h_n, c_n) are final states
output, (h_n, c_n) = lstm(x)

print(f"input: {x.shape}")
print(f"output: {output.shape}")     # (4, 8, 20)
print(f"h_n (hidden): {h_n.shape}")  # (2, 4, 20) — one per layer
print(f"c_n (cell): {c_n.shape}")    # (2, 4, 20)

# Passing an explicit initial state
h0 = torch.zeros(2, 4, 20)
c0 = torch.zeros(2, 4, 20)
output, (h_n, c_n) = lstm(x, (h0, c0))
print(f"\nexplicit initial state: h0={h0.shape}, c0={c0.shape}")
50
51
# ============================================
# 2. GRU basics
# ============================================
print("\n[2] GRU basics")
print("-" * 40)

gru = nn.GRU(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True
)

# Unlike LSTM, a GRU returns only the hidden state (no cell state).
output, h_n = gru(x)

print(f"GRU output: {output.shape}")
print(f"GRU h_n: {h_n.shape}")  # no cell state
69
70
# ============================================
# 3. Bidirectional LSTM
# ============================================
print("\n[3] Bidirectional LSTM")
print("-" * 40)

lstm_bi = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True,
    bidirectional=True
)

output_bi, (h_n_bi, c_n_bi) = lstm_bi(x)

print(f"Bidirectional LSTM:")
print(f"  output: {output_bi.shape}")  # (4, 8, 40) — directions concatenated
print(f"  h_n: {h_n_bi.shape}")        # (4, 4, 20) — num_layers * 2 directions

# Split the concatenated features back into forward/backward halves
forward_out = output_bi[:, :, :20]
backward_out = output_bi[:, :, 20:]
print(f"  forward: {forward_out.shape}")
print(f"  backward: {backward_out.shape}")
96
97
# ============================================
# 4. LSTM classifier
# ============================================
print("\n[4] LSTM classifier")
print("-" * 40)

class LSTMClassifier(nn.Module):
    """Sequence classifier: (bi)LSTM encoder -> dropout -> linear head.

    Args:
        input_size: feature dimension of each time step.
        hidden_size: LSTM hidden dimension per direction.
        num_classes: number of output logits.
        num_layers: stacked LSTM layers.
        bidirectional: run the LSTM in both directions.
        dropout: dropout rate (between LSTM layers and before the head).
    """

    def __init__(self, input_size, hidden_size, num_classes,
                 num_layers=2, bidirectional=True, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            input_size, hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # nn.LSTM warns when dropout is set with a single layer
            dropout=dropout if num_layers > 1 else 0
        )

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)

    def forward(self, x):
        # x: (batch, seq, features)
        output, (h_n, c_n) = self.lstm(x)

        # Combine the top layer's final hidden state(s)
        if self.bidirectional:
            # last-layer forward direction + last-layer backward direction
            forward_last = h_n[-2]
            backward_last = h_n[-1]
            combined = torch.cat([forward_last, backward_last], dim=1)
        else:
            combined = h_n[-1]

        dropped = self.dropout(combined)
        return self.fc(dropped)  # (batch, num_classes) logits
139
# Smoke-test the classifier on the batch from section 1
model = LSTMClassifier(input_size=10, hidden_size=32, num_classes=5)
out = model(x)
print(f"classifier output: {out.shape}")
143
144
# ============================================
# 5. Time-series prediction comparison (RNN vs LSTM vs GRU)
# ============================================
print("\n[5] RNN vs LSTM vs GRU comparison")
print("-" * 40)

# Generate a fairly complex synthetic series
def generate_complex_series(seq_len=100, n_samples=1000):
    """Build a next-step-prediction dataset.

    Each sample is seq_len points of a composite signal
    (sine + higher-frequency sine + linear trend + noise);
    the target is the signal's next value.

    Returns:
        X: float32 array of shape (n_samples, seq_len, 1)
        y: float32 array of shape (n_samples,)
    """
    X, y = [], []
    for _ in range(n_samples):
        t = np.linspace(0, 10 * np.pi, seq_len + 1)
        signal = (np.sin(t) + 0.5 * np.sin(3 * t)
                  + 0.1 * t + np.random.randn(seq_len + 1) * 0.1)
        X.append(signal[:-1].reshape(-1, 1))
        y.append(signal[-1])
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

X, y = generate_complex_series(seq_len=100, n_samples=2000)
# 80/20 train/test split
X_train, y_train = torch.from_numpy(X[:1600]), torch.from_numpy(y[:1600])
X_test, y_test = torch.from_numpy(X[1600:]), torch.from_numpy(y[1600:])
165
class TimeSeriesModel(nn.Module):
    """One-step-ahead forecaster with a selectable recurrent core.

    Args:
        model_type: one of 'rnn', 'lstm', 'gru'.
        hidden_size: hidden dimension of the recurrent layer.
    """

    def __init__(self, model_type='lstm', hidden_size=64):
        super().__init__()
        if model_type == 'rnn':
            self.rnn = nn.RNN(1, hidden_size, batch_first=True)
        elif model_type == 'lstm':
            self.rnn = nn.LSTM(1, hidden_size, batch_first=True)
        elif model_type == 'gru':
            self.rnn = nn.GRU(1, hidden_size, batch_first=True)
        else:
            # Fail fast instead of leaving self.rnn unset
            raise ValueError(f"unknown model_type: {model_type!r}")

        self.model_type = model_type
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # LSTM returns (output, (h_n, c_n)); RNN/GRU return (output, h_n)
        if self.model_type == 'lstm':
            _, (h_n, _) = self.rnn(x)
        else:
            _, h_n = self.rnn(x)
        # Predict from the last layer's final hidden state -> (batch,)
        return self.fc(h_n[-1]).squeeze(-1)
185
def train_model(model_type, epochs=30):
    """Train a TimeSeriesModel on the global (X_train, y_train) split.

    Args:
        model_type: recurrent core to use ('rnn', 'lstm', 'gru').
        epochs: number of passes over the training set.

    Returns:
        (losses, test_loss): per-epoch mean training MSE list and the
        final MSE on the held-out (X_test, y_test) split.
    """
    model = TimeSeriesModel(model_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_train, y_train),
        batch_size=64, shuffle=True
    )

    losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            pred = model(X_batch)
            loss = criterion(pred, y_batch)

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients — standard stabilizer for recurrent training
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()
        losses.append(epoch_loss / len(train_loader))

    # Evaluate on the held-out split
    model.eval()
    with torch.no_grad():
        test_pred = model(X_test.to(device))
        test_loss = criterion(test_pred, y_test.to(device)).item()

    return losses, test_loss
222
# Run the comparison
print("Training models...")
results = {}
for model_type in ['rnn', 'lstm', 'gru']:
    losses, test_loss = train_model(model_type)
    results[model_type] = {'losses': losses, 'test_loss': test_loss}
    print(f"  {model_type.upper()}: Test MSE = {test_loss:.6f}")

# Plot training curves for all three models
plt.figure(figsize=(10, 5))
for name, data in results.items():
    plt.plot(data['losses'], label=f"{name.upper()} (test={data['test_loss']:.4f})")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('RNN vs LSTM vs GRU')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('rnn_lstm_gru_comparison.png', dpi=100)
plt.close()
print("Saved figure: rnn_lstm_gru_comparison.png")
243
244
# ============================================
# 6. Text classification example
# ============================================
print("\n[6] Text classification example")
print("-" * 40)

class TextClassifier(nn.Module):
    """Token-id sequences -> class logits via embedding + 2-layer biLSTM.

    Token index 0 is treated as padding (its embedding stays zero).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True,
                            bidirectional=True, num_layers=2, dropout=0.3)
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_dim * 2, num_classes)
        )

    def forward(self, x):
        # x: (batch, seq) of token indices
        embedded = self.embedding(x)
        _, (h_n, _) = self.lstm(embedded)
        # Top layer's final forward + backward hidden states
        combined = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(combined)  # (batch, num_classes) logits
268
model = TextClassifier(vocab_size=10000, embed_dim=128,
                       hidden_dim=256, num_classes=5)
print(f"TextClassifier parameters: {sum(p.numel() for p in model.parameters()):,}")

# Dummy input
x = torch.randint(0, 10000, (8, 50))  # 8 sentences, 50 tokens each
out = model(x)
print(f"input: {x.shape} -> output: {out.shape}")
277
278
# ============================================
# 7. Language model (text generation)
# ============================================
print("\n[7] Language model")
print("-" * 40)

class CharLSTM(nn.Module):
    """Character-level language model (embedding -> LSTM -> vocab logits)."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, hidden)
        logits = self.fc(output)  # (batch, seq, vocab_size)
        return logits, hidden

    def generate(self, start_tokens, max_len=50, temperature=1.0):
        """Extend start_tokens with up to max_len sampled tokens.

        Fix: the previous version fed only the LAST prompt token, so the
        rest of start_tokens never conditioned the hidden state. We now
        prime the model on the full prompt before sampling.
        """
        self.eval()
        tokens = list(start_tokens)
        hidden = None

        with torch.no_grad():
            # Prime on the whole prompt; the final position's logits
            # predict the first new token.
            x = torch.tensor([tokens])
            logits, hidden = self(x, hidden)

            for _ in range(max_len):
                # Temperature sampling over the last position's logits
                probs = F.softmax(logits[0, -1] / temperature, dim=0)
                next_token = torch.multinomial(probs, 1).item()
                tokens.append(next_token)

                # Feed the sampled token back in, carrying the state
                x = torch.tensor([[next_token]])
                logits, hidden = self(x, hidden)

        return tokens
315
char_lm = CharLSTM(vocab_size=128, embed_dim=64, hidden_dim=256)
print(f"CharLSTM parameters: {sum(p.numel() for p in char_lm.parameters()):,}")

# Generation smoke test (untrained model, so output is random)
generated = char_lm.generate([65, 66, 67], max_len=20)  # ASCII 'A','B','C'
print(f"generated tokens: {generated[:10]}...")
322
323
# ============================================
# 8. LSTM internals
# ============================================
print("\n[8] LSTM gate analysis")
print("-" * 40)

class LSTMWithGates(nn.Module):
    """Manually unrolled LSTM built from nn.LSTMCell.

    Stepping one LSTMCell call per time step makes the per-step states
    available for inspection. Returns the stacked hidden states with
    shape (batch, seq, hidden). (Note: gate activations themselves are
    internal to nn.LSTMCell and are not exposed here.)
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        # new_zeros keeps the initial states on x's device/dtype
        h = x.new_zeros(batch_size, self.hidden_size)
        c = x.new_zeros(batch_size, self.hidden_size)

        outputs = []
        for t in range(seq_len):
            h, c = self.lstm_cell(x[:, t], (h, c))
            outputs.append(h)

        return torch.stack(outputs, dim=1)
350
# Smoke test
lstm_gates = LSTMWithGates(10, 20)
x = torch.randn(1, 30, 10)
out = lstm_gates(x)
print(f"gate-analysis LSTM output: {out.shape}")
356
357
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("LSTM/GRU summary")
print("=" * 60)

summary = """
LSTM:
    output, (h_n, c_n) = lstm(x)
    - cell state (c) carries long-term memory
    - forget, input, output gates

GRU:
    output, h_n = gru(x)
    - no cell state, simpler
    - reset, update gates

Classification pattern:
    # bidirectional LSTM
    forward_last = h_n[-2]
    backward_last = h_n[-1]
    combined = torch.cat([forward_last, backward_last], dim=1)
    output = fc(combined)

Text classification:
    embedded = embedding(x)   # tokens -> vectors
    _, (h_n, _) = lstm(embedded)
    output = fc(h_n[-1])

Choosing a model:
    - long sequences, complex dependencies -> LSTM
    - faster training, fewer parameters -> GRU
    - simple patterns -> plain RNN may suffice
"""
print(summary)
print("=" * 60)