1"""
2PyTorch Low-Level LSTM/GRU ๊ตฌํ
3
4nn.LSTM, nn.GRU ๋์ F.linear, torch.sigmoid, torch.tanh ์ฌ์ฉ
5ํ๋ผ๋ฏธํฐ๋ฅผ ์๋์ผ๋ก ๊ด๋ฆฌ
6"""
7
8import torch
9import torch.nn.functional as F
10import math
11from typing import Tuple, List, Optional
12
13
class LSTMCellLowLevel:
    """
    A single LSTM cell implemented with low-level PyTorch (no nn.LSTMCell).

    The four gate projections are packed into one weight matrix each for the
    input and hidden paths, ordered [forget, input, candidate, output].
    """

    def __init__(self, input_size: int, hidden_size: int, device: Optional[torch.device] = None):
        """
        Args:
            input_size: dimensionality of the input vector x
            hidden_size: dimensionality of the hidden and cell states
            device: target device; defaults to CUDA when available
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Xavier-style initialization from the concatenated fan-in/fan-out.
        concat_size = input_size + hidden_size
        std = math.sqrt(2.0 / (concat_size + hidden_size))

        # BUGFIX: the original wrote `torch.randn(..., requires_grad=True) * std`,
        # which stores a NON-leaf tensor: autograd never populates its .grad,
        # so manual SGD updates silently do nothing. Scale the raw tensor
        # first, then mark the resulting leaf as trainable.
        self.W_ih = (torch.randn(
            4 * hidden_size, input_size, device=self.device
        ) * std).requires_grad_()
        self.W_hh = (torch.randn(
            4 * hidden_size, hidden_size, device=self.device
        ) * std).requires_grad_()
        self.bias = torch.zeros(
            4 * hidden_size,
            requires_grad=True, device=self.device
        )

    def forward(
        self,
        x: torch.Tensor,
        hx: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one LSTM time step.

        Args:
            x: (batch_size, input_size)
            hx: (h_prev, c_prev), each (batch_size, hidden_size)

        Returns:
            h_t, c_t: each (batch_size, hidden_size)
        """
        h_prev, c_prev = hx
        H = self.hidden_size

        # All four gate pre-activations in a single pass.
        gates = x @ self.W_ih.t() + h_prev @ self.W_hh.t() + self.bias

        f = torch.sigmoid(gates[:, 0:H])        # forget gate
        i = torch.sigmoid(gates[:, H:2*H])      # input gate
        g = torch.tanh(gates[:, 2*H:3*H])       # candidate state
        o = torch.sigmoid(gates[:, 3*H:4*H])    # output gate

        # Cell and hidden state updates.
        c_t = f * c_prev + i * g
        h_t = o * torch.tanh(c_t)

        return h_t, c_t

    def parameters(self) -> List[torch.Tensor]:
        """Return the trainable leaf tensors of this cell."""
        return [self.W_ih, self.W_hh, self.bias]
79
80
class GRUCellLowLevel:
    """
    A single GRU cell implemented with low-level PyTorch (no nn.GRUCell).

    The three gate projections are packed into one weight matrix each for the
    input and hidden paths, ordered [reset, update, candidate].
    """

    def __init__(self, input_size: int, hidden_size: int, device: Optional[torch.device] = None):
        """
        Args:
            input_size: dimensionality of the input vector x
            hidden_size: dimensionality of the hidden state
            device: target device; defaults to CUDA when available
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Xavier-style initialization from the concatenated fan-in/fan-out.
        concat_size = input_size + hidden_size
        std = math.sqrt(2.0 / (concat_size + hidden_size))

        # BUGFIX: the original wrote `torch.randn(..., requires_grad=True) * std`,
        # which stores a NON-leaf tensor: autograd never populates its .grad,
        # so manual SGD updates silently do nothing. Scale the raw tensor
        # first, then mark the resulting leaf as trainable.
        self.W_ih = (torch.randn(
            3 * hidden_size, input_size, device=self.device
        ) * std).requires_grad_()
        self.W_hh = (torch.randn(
            3 * hidden_size, hidden_size, device=self.device
        ) * std).requires_grad_()
        self.bias = torch.zeros(
            3 * hidden_size,
            requires_grad=True, device=self.device
        )

    def forward(
        self,
        x: torch.Tensor,
        h_prev: torch.Tensor
    ) -> torch.Tensor:
        """
        Run one GRU time step.

        Args:
            x: (batch_size, input_size)
            h_prev: (batch_size, hidden_size)

        Returns:
            h_t: (batch_size, hidden_size)
        """
        H = self.hidden_size

        # Input- and hidden-path projections for all three gates at once.
        ih = x @ self.W_ih.t()
        hh = h_prev @ self.W_hh.t()

        r = torch.sigmoid(ih[:, 0:H] + hh[:, 0:H] + self.bias[0:H])          # reset gate
        z = torch.sigmoid(ih[:, H:2*H] + hh[:, H:2*H] + self.bias[H:2*H])    # update gate

        # Candidate state: reset gate scales only the hidden-path contribution.
        n = torch.tanh(ih[:, 2*H:3*H] + r * hh[:, 2*H:3*H] + self.bias[2*H:3*H])

        # Interpolate between the previous hidden state and the candidate.
        h_t = (1 - z) * h_prev + z * n

        return h_t

    def parameters(self) -> List[torch.Tensor]:
        """Return the trainable leaf tensors of this cell."""
        return [self.W_ih, self.W_hh, self.bias]
144
145
class LSTMLowLevel:
    """
    Multi-layer, optionally bidirectional LSTM built from LSTMCellLowLevel.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            input_size: feature size of each input time step
            hidden_size: hidden state size per direction
            num_layers: number of stacked layers
            bidirectional: also run a reversed pass per layer and concat outputs
            dropout: inter-layer dropout probability (never after the last layer)
            device: target device; defaults to CUDA when available
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.num_directions = 2 if bidirectional else 1

        # BUGFIX: dropout was hard-coded to training=True, so it also fired
        # at inference time. Track a training flag instead; the default True
        # keeps the previous behavior for existing callers.
        self.training = True

        # One cell per (layer, direction); layers above the first consume the
        # (possibly direction-concatenated) output of the layer below.
        self.cells = []
        for layer in range(num_layers):
            for direction in range(self.num_directions):
                in_size = input_size if layer == 0 else hidden_size * self.num_directions
                cell = LSTMCellLowLevel(in_size, hidden_size, self.device)
                self.cells.append(cell)

    def train(self):
        """Enable inter-layer dropout (training mode)."""
        self.training = True

    def eval(self):
        """Disable inter-layer dropout (inference mode)."""
        self.training = False

    def forward(
        self,
        x: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the full sequence through all layers.

        Args:
            x: (seq_len, batch_size, input_size)
            hx: optional (h_0, c_0), each (num_layers * num_directions, batch, hidden);
                zeros are used when omitted

        Returns:
            output: (seq_len, batch, hidden * num_directions)
            (h_n, c_n): final states, each (num_layers * num_directions, batch, hidden)
        """
        seq_len, batch_size, _ = x.shape

        # Initial states default to zeros.
        if hx is None:
            h_0 = torch.zeros(
                self.num_layers * self.num_directions, batch_size, self.hidden_size,
                device=self.device
            )
            c_0 = torch.zeros_like(h_0)
        else:
            h_0, c_0 = hx

        h_states = list(h_0)
        c_states = list(c_0)

        output = x
        new_h_states = []
        new_c_states = []

        for layer in range(self.num_layers):
            cell_idx = layer * self.num_directions

            # Forward direction: t = 0 .. seq_len-1.
            cell = self.cells[cell_idx]
            h, c = h_states[cell_idx], c_states[cell_idx]
            forward_outputs = []
            for t in range(seq_len):
                h, c = cell.forward(output[t], (h, c))
                forward_outputs.append(h)
            new_h_states.append(h)
            new_c_states.append(c)

            if self.bidirectional:
                # Backward direction: t = seq_len-1 .. 0 over the same layer input.
                cell = self.cells[cell_idx + 1]
                h, c = h_states[cell_idx + 1], c_states[cell_idx + 1]
                backward_outputs = []
                for t in reversed(range(seq_len)):
                    h, c = cell.forward(output[t], (h, c))
                    backward_outputs.append(h)
                backward_outputs.reverse()  # restore chronological order
                new_h_states.append(h)
                new_c_states.append(c)

                # Concatenate both directions per time step.
                output = torch.stack([
                    torch.cat([f, b], dim=-1)
                    for f, b in zip(forward_outputs, backward_outputs)
                ])
            else:
                output = torch.stack(forward_outputs)

            # Inter-layer dropout (skipped after the last layer and in eval mode).
            if self.dropout > 0 and layer < self.num_layers - 1:
                output = F.dropout(output, p=self.dropout, training=self.training)

        h_n = torch.stack(new_h_states)
        c_n = torch.stack(new_c_states)

        return output, (h_n, c_n)

    def parameters(self) -> List[torch.Tensor]:
        """Return all trainable tensors from every cell."""
        params = []
        for cell in self.cells:
            params.extend(cell.parameters())
        return params

    def zero_grad(self):
        """Zero accumulated gradients on all parameters."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()
267
268
class GRULowLevel:
    """
    Multi-layer, optionally bidirectional GRU built from GRUCellLowLevel.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            input_size: feature size of each input time step
            hidden_size: hidden state size per direction
            num_layers: number of stacked layers
            bidirectional: also run a reversed pass per layer and concat outputs
            dropout: inter-layer dropout probability (never after the last layer)
            device: target device; defaults to CUDA when available
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.num_directions = 2 if bidirectional else 1

        # BUGFIX: dropout was hard-coded to training=True, so it also fired
        # at inference time. Track a training flag instead; the default True
        # keeps the previous behavior for existing callers.
        self.training = True

        # One cell per (layer, direction); layers above the first consume the
        # (possibly direction-concatenated) output of the layer below.
        self.cells = []
        for layer in range(num_layers):
            for direction in range(self.num_directions):
                in_size = input_size if layer == 0 else hidden_size * self.num_directions
                cell = GRUCellLowLevel(in_size, hidden_size, self.device)
                self.cells.append(cell)

    def train(self):
        """Enable inter-layer dropout (training mode)."""
        self.training = True

    def eval(self):
        """Disable inter-layer dropout (inference mode)."""
        self.training = False

    def forward(
        self,
        x: torch.Tensor,
        h_0: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the full sequence through all layers.

        Args:
            x: (seq_len, batch_size, input_size)
            h_0: optional initial hidden state,
                 (num_layers * num_directions, batch, hidden); zeros when omitted

        Returns:
            output: (seq_len, batch, hidden * num_directions)
            h_n: final hidden states, (num_layers * num_directions, batch, hidden)
        """
        seq_len, batch_size, _ = x.shape

        if h_0 is None:
            h_0 = torch.zeros(
                self.num_layers * self.num_directions, batch_size, self.hidden_size,
                device=self.device
            )

        h_states = list(h_0)
        output = x
        new_h_states = []

        for layer in range(self.num_layers):
            cell_idx = layer * self.num_directions

            # Forward direction: t = 0 .. seq_len-1.
            cell = self.cells[cell_idx]
            h = h_states[cell_idx]
            forward_outputs = []
            for t in range(seq_len):
                h = cell.forward(output[t], h)
                forward_outputs.append(h)
            new_h_states.append(h)

            if self.bidirectional:
                # Backward direction: t = seq_len-1 .. 0 over the same layer input.
                cell = self.cells[cell_idx + 1]
                h = h_states[cell_idx + 1]
                backward_outputs = []
                for t in reversed(range(seq_len)):
                    h = cell.forward(output[t], h)
                    backward_outputs.append(h)
                backward_outputs.reverse()  # restore chronological order
                new_h_states.append(h)

                # Concatenate both directions per time step.
                output = torch.stack([
                    torch.cat([f, b], dim=-1)
                    for f, b in zip(forward_outputs, backward_outputs)
                ])
            else:
                output = torch.stack(forward_outputs)

            # Inter-layer dropout (skipped after the last layer and in eval mode).
            if self.dropout > 0 and layer < self.num_layers - 1:
                output = F.dropout(output, p=self.dropout, training=self.training)

        h_n = torch.stack(new_h_states)

        return output, h_n

    def parameters(self) -> List[torch.Tensor]:
        """Return all trainable tensors from every cell."""
        params = []
        for cell in self.cells:
            params.extend(cell.parameters())
        return params

    def zero_grad(self):
        """Zero accumulated gradients on all parameters."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()
375
376
class SequenceClassifier:
    """
    Sequence classifier: token embedding -> (bi)LSTM/GRU -> linear head.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        num_classes: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        rnn_type: str = 'lstm',
        device: Optional[torch.device] = None
    ):
        """
        Args:
            vocab_size: size of the token vocabulary
            embed_size: embedding dimension
            hidden_size: RNN hidden size per direction
            num_classes: number of output classes
            num_layers: number of stacked RNN layers
            bidirectional: use a bidirectional RNN
            rnn_type: 'lstm' selects LSTM; any other value selects GRU
            device: target device; defaults to CUDA when available
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # BUGFIX: `torch.randn(..., requires_grad=True) * scale` stored a
        # NON-leaf tensor whose .grad autograd never fills, so the embedding
        # and classifier weights were never updated by the manual SGD step.
        # Scale first, then mark the resulting leaf as trainable.
        self.embedding = (torch.randn(
            vocab_size, embed_size, device=self.device
        ) * 0.1).requires_grad_()

        # Recurrent backbone.
        if rnn_type == 'lstm':
            self.rnn = LSTMLowLevel(
                embed_size, hidden_size, num_layers,
                bidirectional, dropout=0.3, device=self.device
            )
        else:
            self.rnn = GRULowLevel(
                embed_size, hidden_size, num_layers,
                bidirectional, dropout=0.3, device=self.device
            )

        # Linear classification head (Xavier-style init).
        fc_in = hidden_size * (2 if bidirectional else 1)
        std = math.sqrt(2.0 / (fc_in + num_classes))
        self.fc_weight = (torch.randn(
            num_classes, fc_in, device=self.device
        ) * std).requires_grad_()
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len) token indices

        Returns:
            logits: (batch_size, num_classes)
        """
        # Embedding lookup, then to time-major layout for the RNN.
        embedded = F.embedding(x, self.embedding)  # (batch, seq, embed)
        embedded = embedded.transpose(0, 1)        # (seq, batch, embed)

        # RNN (LSTM returns a state tuple, GRU a single tensor).
        if isinstance(self.rnn, LSTMLowLevel):
            output, (h_n, c_n) = self.rnn.forward(embedded)
        else:
            output, h_n = self.rnn.forward(embedded)

        # Final hidden state; for bidirectional, concat the last layer's
        # forward (h_n[-2]) and backward (h_n[-1]) states.
        if self.rnn.bidirectional:
            last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        else:
            last_hidden = h_n[-1]

        # Linear head.
        logits = last_hidden @ self.fc_weight.t() + self.fc_bias

        return logits

    def parameters(self) -> List[torch.Tensor]:
        """All trainable tensors: embedding, RNN parameters, classifier head."""
        params = [self.embedding]
        params.extend(self.rnn.parameters())
        params.extend([self.fc_weight, self.fc_bias])
        return params

    def zero_grad(self):
        """Zero accumulated gradients on all parameters."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()
462
463
def train_imdb_sentiment():
    """Train a toy sentiment classifier on synthetic IMDB-like data."""
    print("=== LSTM/GRU Sentiment Analysis ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Synthetic stand-in data; a real run would load IMDB via torchtext.
    vocab_size = 10000
    seq_len = 100
    batch_size = 32
    num_samples = 1000

    X_train = torch.randint(0, vocab_size, (num_samples, seq_len), device=device)
    y_train = torch.randint(0, 2, (num_samples,), device=device)

    # Bidirectional 2-layer LSTM classifier.
    model = SequenceClassifier(
        vocab_size=vocab_size,
        embed_size=128,
        hidden_size=256,
        num_classes=2,
        num_layers=2,
        bidirectional=True,
        rnn_type='lstm',
        device=device
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {param_count:,}")

    # Plain SGD training loop.
    lr = 0.001
    epochs = 5

    for epoch in range(epochs):
        running_loss = 0.0
        running_correct = 0

        for start in range(0, num_samples, batch_size):
            xb = X_train[start:start + batch_size]
            yb = y_train[start:start + batch_size]

            # Forward pass and loss.
            logits = model.forward(xb)
            loss = F.cross_entropy(logits, yb)

            # Backward pass with gradient clipping.
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Manual SGD parameter update.
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is not None:
                        p -= lr * p.grad

            running_loss += loss.item() * len(yb)
            running_correct += (logits.argmax(dim=1) == yb).sum().item()

        avg_loss = running_loss / num_samples
        accuracy = running_correct / num_samples
        print(f"Epoch {epoch+1}/{epochs}: Loss={avg_loss:.4f}, Acc={accuracy:.4f}")
531
532
def main():
    """Smoke-test the low-level LSTM/GRU implementations, then run training."""
    print("=== LSTM/GRU Low-Level Test ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # LSTM smoke test
    print("Testing LSTM...")
    lstm = LSTMLowLevel(
        input_size=10, hidden_size=20,
        num_layers=2, bidirectional=True, device=device
    )

    x = torch.randn(5, 3, 10, device=device)  # (seq, batch, input)
    output, (h_n, c_n) = lstm.forward(x)

    print(f" Input: {x.shape}")
    print(f" Output: {output.shape}")  # (5, 3, 40) bidirectional
    print(f" h_n: {h_n.shape}")  # (4, 3, 20): 2 layers * 2 directions

    # GRU smoke test
    print("\nTesting GRU...")
    gru = GRULowLevel(
        input_size=10, hidden_size=20,
        num_layers=2, bidirectional=False, device=device
    )

    output, h_n = gru.forward(x)
    print(f" Output: {output.shape}")  # (5, 3, 20)
    print(f" h_n: {h_n.shape}")  # (2, 3, 20)

    # Sentiment-analysis training demo
    print()
    train_imdb_sentiment()


if __name__ == "__main__":
    main()