# lstm_gru_lowlevel.py
#
# Source listing: python, 571 lines, 16.3 KB
"""
PyTorch low-level LSTM/GRU implementation.

Builds LSTM and GRU layers from raw tensor operations (matrix products,
torch.sigmoid, torch.tanh) instead of nn.LSTM / nn.GRU, and manages all
parameters manually as plain tensors.
"""
  7
  8import torch
  9import torch.nn.functional as F
 10import math
 11from typing import Tuple, List, Optional
 12
 13
 14class LSTMCellLowLevel:
 15    """
 16    ๋‹จ์ผ LSTM Cell (Low-Level PyTorch)
 17
 18    nn.LSTMCell ๋ฏธ์‚ฌ์šฉ
 19    """
 20
 21    def __init__(self, input_size: int, hidden_size: int, device: torch.device = None):
 22        self.input_size = input_size
 23        self.hidden_size = hidden_size
 24        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 25
 26        # Xavier ์ดˆ๊ธฐํ™”
 27        concat_size = input_size + hidden_size
 28        std = math.sqrt(2.0 / (concat_size + hidden_size))
 29
 30        # 4๊ฐœ ๊ฒŒ์ดํŠธ๋ฅผ ํ•˜๋‚˜๋กœ: [forget, input, candidate, output]
 31        self.W_ih = torch.randn(
 32            4 * hidden_size, input_size,
 33            requires_grad=True, device=self.device
 34        ) * std
 35        self.W_hh = torch.randn(
 36            4 * hidden_size, hidden_size,
 37            requires_grad=True, device=self.device
 38        ) * std
 39        self.bias = torch.zeros(
 40            4 * hidden_size,
 41            requires_grad=True, device=self.device
 42        )
 43
 44    def forward(
 45        self,
 46        x: torch.Tensor,
 47        hx: Tuple[torch.Tensor, torch.Tensor]
 48    ) -> Tuple[torch.Tensor, torch.Tensor]:
 49        """
 50        Forward pass
 51
 52        Args:
 53            x: (batch_size, input_size)
 54            hx: (h_prev, c_prev) ๊ฐ๊ฐ (batch_size, hidden_size)
 55
 56        Returns:
 57            h_t, c_t: ๊ฐ๊ฐ (batch_size, hidden_size)
 58        """
 59        h_prev, c_prev = hx
 60        H = self.hidden_size
 61
 62        # ๊ฒŒ์ดํŠธ ๊ณ„์‚ฐ
 63        gates = (x @ self.W_ih.t() + h_prev @ self.W_hh.t() + self.bias)
 64
 65        # ๋ถ„๋ฆฌ
 66        f = torch.sigmoid(gates[:, 0:H])           # Forget
 67        i = torch.sigmoid(gates[:, H:2*H])         # Input
 68        g = torch.tanh(gates[:, 2*H:3*H])          # Candidate
 69        o = torch.sigmoid(gates[:, 3*H:4*H])       # Output
 70
 71        # Cell & Hidden
 72        c_t = f * c_prev + i * g
 73        h_t = o * torch.tanh(c_t)
 74
 75        return h_t, c_t
 76
 77    def parameters(self) -> List[torch.Tensor]:
 78        return [self.W_ih, self.W_hh, self.bias]
 79
 80
 81class GRUCellLowLevel:
 82    """
 83    ๋‹จ์ผ GRU Cell (Low-Level PyTorch)
 84
 85    nn.GRUCell ๋ฏธ์‚ฌ์šฉ
 86    """
 87
 88    def __init__(self, input_size: int, hidden_size: int, device: torch.device = None):
 89        self.input_size = input_size
 90        self.hidden_size = hidden_size
 91        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 92
 93        concat_size = input_size + hidden_size
 94        std = math.sqrt(2.0 / (concat_size + hidden_size))
 95
 96        # 3๊ฐœ ๊ฒŒ์ดํŠธ: [reset, update, candidate]
 97        self.W_ih = torch.randn(
 98            3 * hidden_size, input_size,
 99            requires_grad=True, device=self.device
100        ) * std
101        self.W_hh = torch.randn(
102            3 * hidden_size, hidden_size,
103            requires_grad=True, device=self.device
104        ) * std
105        self.bias = torch.zeros(
106            3 * hidden_size,
107            requires_grad=True, device=self.device
108        )
109
110    def forward(
111        self,
112        x: torch.Tensor,
113        h_prev: torch.Tensor
114    ) -> torch.Tensor:
115        """
116        Forward pass
117
118        Args:
119            x: (batch_size, input_size)
120            h_prev: (batch_size, hidden_size)
121
122        Returns:
123            h_t: (batch_size, hidden_size)
124        """
125        H = self.hidden_size
126
127        # Reset, Update ๊ฒŒ์ดํŠธ
128        ih = x @ self.W_ih.t()
129        hh = h_prev @ self.W_hh.t()
130
131        r = torch.sigmoid(ih[:, 0:H] + hh[:, 0:H] + self.bias[0:H])
132        z = torch.sigmoid(ih[:, H:2*H] + hh[:, H:2*H] + self.bias[H:2*H])
133
134        # Candidate (reset ์ ์šฉ)
135        n = torch.tanh(ih[:, 2*H:3*H] + r * hh[:, 2*H:3*H] + self.bias[2*H:3*H])
136
137        # Hidden
138        h_t = (1 - z) * h_prev + z * n
139
140        return h_t
141
142    def parameters(self) -> List[torch.Tensor]:
143        return [self.W_ih, self.W_hh, self.bias]
144
145
class LSTMLowLevel:
    """
    Multi-layer (optionally bidirectional) LSTM built from LSTMCellLowLevel.

    Inputs are sequence-first: ``(seq_len, batch, input_size)``. Dropout is
    applied to every layer's output except the last, and only while
    ``self.training`` is True (the default). Call ``eval()`` to disable it
    for inference and ``train()`` to re-enable it.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        device: torch.device = None
    ):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.num_directions = 2 if bidirectional else 1

        # Dropout starts enabled; toggle with train() / eval().
        self.training = True

        # One cell per (layer, direction), stored flat at index
        # layer * num_directions + direction, matching the rows of h_0 / h_n.
        # Layers above the first consume the (possibly concatenated) previous output.
        self.cells = []
        for layer in range(num_layers):
            in_size = input_size if layer == 0 else hidden_size * self.num_directions
            for _direction in range(self.num_directions):
                self.cells.append(LSTMCellLowLevel(in_size, hidden_size, self.device))

    def forward(
        self,
        x: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the stacked LSTM over a full sequence.

        Args:
            x: (seq_len, batch_size, input_size)
            hx: optional (h_0, c_0), each (num_layers * num_directions, batch, hidden).
                Zeros are used when omitted.

        Returns:
            output: (seq_len, batch, hidden * num_directions), from the last layer.
            (h_n, c_n): final states, each (num_layers * num_directions, batch, hidden).
                Rows are ordered by layer, then direction (forward, backward).
        """
        seq_len, batch_size, _ = x.shape

        # Initial states.
        if hx is None:
            h_0 = torch.zeros(
                self.num_layers * self.num_directions, batch_size, self.hidden_size,
                device=self.device
            )
            c_0 = torch.zeros_like(h_0)
        else:
            h_0, c_0 = hx

        h_states = list(h_0)
        c_states = list(c_0)

        output = x
        new_h_states = []
        new_c_states = []

        for layer in range(self.num_layers):
            # Forward direction.
            cell_idx = layer * self.num_directions
            cell = self.cells[cell_idx]

            h, c = h_states[cell_idx], c_states[cell_idx]
            forward_outputs = []
            for t in range(seq_len):
                h, c = cell.forward(output[t], (h, c))
                forward_outputs.append(h)

            new_h_states.append(h)
            new_c_states.append(c)

            if self.bidirectional:
                # Backward direction: walk the sequence from the end.
                cell = self.cells[cell_idx + 1]
                h, c = h_states[cell_idx + 1], c_states[cell_idx + 1]
                backward_outputs = []
                for t in reversed(range(seq_len)):
                    h, c = cell.forward(output[t], (h, c))
                    backward_outputs.append(h)
                # Collected last-to-first; flip once (O(n)) so index t lines up
                # with forward_outputs[t] (avoids repeated list.insert(0, ...)).
                backward_outputs.reverse()

                new_h_states.append(h)
                new_c_states.append(c)

                # Concatenate forward and backward features per time step.
                output = torch.cat(
                    [torch.stack(forward_outputs), torch.stack(backward_outputs)],
                    dim=-1,
                )
            else:
                output = torch.stack(forward_outputs)

            # Dropout between layers (never after the last layer), only in training mode.
            if self.dropout > 0 and layer < self.num_layers - 1:
                output = F.dropout(output, p=self.dropout, training=self.training)

        h_n = torch.stack(new_h_states)
        c_n = torch.stack(new_c_states)

        return output, (h_n, c_n)

    def parameters(self) -> List[torch.Tensor]:
        """Return the trainable tensors of every cell, in cell order."""
        params = []
        for cell in self.cells:
            params.extend(cell.parameters())
        return params

    def zero_grad(self):
        """Zero the gradients of all parameters that currently have one."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def train(self, mode: bool = True) -> 'LSTMLowLevel':
        """Enable (or, with mode=False, disable) inter-layer dropout; returns self."""
        self.training = mode
        return self

    def eval(self) -> 'LSTMLowLevel':
        """Disable inter-layer dropout for inference; returns self."""
        return self.train(False)
267
268
class GRULowLevel:
    """
    Multi-layer (optionally bidirectional) GRU built from GRUCellLowLevel.

    Inputs are sequence-first: ``(seq_len, batch, input_size)``. Dropout is
    applied to every layer's output except the last, and only while
    ``self.training`` is True (the default). Call ``eval()`` to disable it
    for inference and ``train()`` to re-enable it.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        dropout: float = 0.0,
        device: torch.device = None
    ):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.num_directions = 2 if bidirectional else 1

        # Dropout starts enabled; toggle with train() / eval().
        self.training = True

        # One cell per (layer, direction), stored flat at index
        # layer * num_directions + direction, matching the rows of h_0 / h_n.
        self.cells = []
        for layer in range(num_layers):
            in_size = input_size if layer == 0 else hidden_size * self.num_directions
            for _direction in range(self.num_directions):
                self.cells.append(GRUCellLowLevel(in_size, hidden_size, self.device))

    def forward(
        self,
        x: torch.Tensor,
        h_0: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the stacked GRU over a full sequence.

        Args:
            x: (seq_len, batch_size, input_size)
            h_0: optional (num_layers * num_directions, batch, hidden); zeros if omitted.

        Returns:
            output: (seq_len, batch, hidden * num_directions), from the last layer.
            h_n: final hidden states, (num_layers * num_directions, batch, hidden).
                Rows are ordered by layer, then direction (forward, backward).
        """
        seq_len, batch_size, _ = x.shape

        if h_0 is None:
            h_0 = torch.zeros(
                self.num_layers * self.num_directions, batch_size, self.hidden_size,
                device=self.device
            )

        h_states = list(h_0)
        output = x
        new_h_states = []

        for layer in range(self.num_layers):
            # Forward direction.
            cell_idx = layer * self.num_directions
            cell = self.cells[cell_idx]

            h = h_states[cell_idx]
            forward_outputs = []
            for t in range(seq_len):
                h = cell.forward(output[t], h)
                forward_outputs.append(h)

            new_h_states.append(h)

            if self.bidirectional:
                # Backward direction: walk the sequence from the end.
                cell = self.cells[cell_idx + 1]
                h = h_states[cell_idx + 1]
                backward_outputs = []
                for t in reversed(range(seq_len)):
                    h = cell.forward(output[t], h)
                    backward_outputs.append(h)
                # Collected last-to-first; flip once so index t matches forward_outputs[t].
                backward_outputs.reverse()

                new_h_states.append(h)

                # Concatenate forward and backward features per time step.
                output = torch.cat(
                    [torch.stack(forward_outputs), torch.stack(backward_outputs)],
                    dim=-1,
                )
            else:
                output = torch.stack(forward_outputs)

            # Dropout between layers (never after the last layer), only in training mode.
            if self.dropout > 0 and layer < self.num_layers - 1:
                output = F.dropout(output, p=self.dropout, training=self.training)

        h_n = torch.stack(new_h_states)

        return output, h_n

    def parameters(self) -> List[torch.Tensor]:
        """Return the trainable tensors of every cell, in cell order."""
        params = []
        for cell in self.cells:
            params.extend(cell.parameters())
        return params

    def zero_grad(self):
        """Zero the gradients of all parameters that currently have one."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def train(self, mode: bool = True) -> 'GRULowLevel':
        """Enable (or, with mode=False, disable) inter-layer dropout; returns self."""
        self.training = mode
        return self

    def eval(self) -> 'GRULowLevel':
        """Disable inter-layer dropout for inference; returns self."""
        return self.train(False)
375
376
class SequenceClassifier:
    """
    Sequence classifier: embedding -> LSTM/GRU -> linear head (low-level PyTorch).

    Token ids are looked up in a manually managed embedding matrix and run
    through ``LSTMLowLevel`` or ``GRULowLevel``. The last layer's final hidden
    state (both directions concatenated when bidirectional) is projected to
    class logits.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        num_classes: int,
        num_layers: int = 1,
        bidirectional: bool = False,
        rnn_type: str = 'lstm',
        device: torch.device = None,
        dropout: float = 0.3,
    ):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: embedding dimension (RNN input size).
            hidden_size: RNN hidden size per direction.
            num_classes: number of output logits.
            num_layers: number of stacked RNN layers.
            bidirectional: run each RNN layer in both directions.
            rnn_type: 'lstm' or 'gru' (case-insensitive).
            device: target device; defaults to CUDA when available, else CPU.
            dropout: inter-layer dropout probability passed to the RNN
                (only applied between layers, so it has no effect when num_layers == 1).

        Raises:
            ValueError: if rnn_type is neither 'lstm' nor 'gru'.
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Validate up front. Previously any value other than exactly 'lstm'
        # (even 'LSTM') silently produced a GRU.
        kind = rnn_type.lower()
        if kind not in ('lstm', 'gru'):
            raise ValueError(f"rnn_type must be 'lstm' or 'gru', got {rnn_type!r}")

        # Embedding table. Scale first, then enable grad, so the tensor stays a
        # leaf and actually receives .grad during backward().
        self.embedding = (torch.randn(
            vocab_size, embed_size, device=self.device
        ) * 0.1).requires_grad_()

        # Recurrent encoder.
        rnn_cls = LSTMLowLevel if kind == 'lstm' else GRULowLevel
        self.rnn = rnn_cls(
            embed_size, hidden_size, num_layers,
            bidirectional, dropout=dropout, device=self.device
        )

        # Linear classifier head (Xavier/Glorot-normal scale; leaf tensor as above).
        fc_in = hidden_size * (2 if bidirectional else 1)
        std = math.sqrt(2.0 / (fc_in + num_classes))
        self.fc_weight = (torch.randn(
            num_classes, fc_in, device=self.device
        ) * std).requires_grad_()
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len) token indices

        Returns:
            logits: (batch_size, num_classes)
        """
        # (batch, seq, embed) -> (seq, batch, embed): the RNN is sequence-first.
        embedded = F.embedding(x, self.embedding).transpose(0, 1)

        if isinstance(self.rnn, LSTMLowLevel):
            _, (h_n, _) = self.rnn.forward(embedded)
        else:
            _, h_n = self.rnn.forward(embedded)

        if self.rnn.bidirectional:
            # h_n rows are ordered layer, then direction. The last two rows are the top
            # layer's forward state (after the last step) and backward state
            # (after the first step).
            last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        else:
            last_hidden = h_n[-1]

        return last_hidden @ self.fc_weight.t() + self.fc_bias

    def parameters(self) -> List[torch.Tensor]:
        """Return [embedding, *rnn parameters, fc_weight, fc_bias]."""
        params = [self.embedding]
        params.extend(self.rnn.parameters())
        params.extend([self.fc_weight, self.fc_bias])
        return params

    def zero_grad(self):
        """Zero the gradients of all parameters that currently have one."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()
462
463
def train_imdb_sentiment():
    """Train the sequence classifier on a simplified, IMDB-style sentiment task.

    Random token ids and binary labels stand in for real data (a real setup
    would load IMDB, e.g. via torchtext). Trains a 2-layer bidirectional LSTM
    classifier with plain SGD and gradient-norm clipping, and prints the loss
    and accuracy after each epoch.
    """
    print("=== LSTM/GRU Sentiment Analysis ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Synthetic dataset configuration.
    vocab_size, seq_len, batch_size, num_samples = 10000, 100, 32, 1000

    # Fake training data: inputs are drawn first, then labels.
    X_train = torch.randint(0, vocab_size, (num_samples, seq_len), device=device)
    y_train = torch.randint(0, 2, (num_samples,), device=device)

    model = SequenceClassifier(
        vocab_size=vocab_size, embed_size=128, hidden_size=256,
        num_classes=2, num_layers=2, bidirectional=True,
        rnn_type='lstm', device=device,
    )

    param_count = sum(tensor.numel() for tensor in model.parameters())
    print(f"Parameters: {param_count:,}")

    lr, epochs = 0.001, 5

    for epoch in range(epochs):
        total_loss, total_correct = 0, 0

        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            batch_x, batch_y = X_train[start:stop], y_train[start:stop]

            logits = model.forward(batch_x)
            loss = F.cross_entropy(logits, batch_y)

            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Manual SGD step; tensors without a gradient are left untouched.
            with torch.no_grad():
                for weight in model.parameters():
                    if weight.grad is None:
                        continue
                    weight -= lr * weight.grad

            total_loss += loss.item() * len(batch_y)
            total_correct += (logits.argmax(dim=1) == batch_y).sum().item()

        avg_loss = total_loss / num_samples
        accuracy = total_correct / num_samples
        print(f"Epoch {epoch+1}/{epochs}: Loss={avg_loss:.4f}, Acc={accuracy:.4f}")
531
532
def main():
    """Smoke-test LSTM/GRU output shapes, then run the synthetic sentiment demo."""
    print("=== LSTM/GRU Low-Level Test ===\n")

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 2-layer bidirectional LSTM.
    print("Testing LSTM...")
    lstm = LSTMLowLevel(
        input_size=10, hidden_size=20, num_layers=2,
        bidirectional=True, device=dev,
    )
    seq = torch.randn(5, 3, 10, device=dev)  # (seq, batch, input)
    lstm_out, (lstm_h, lstm_c) = lstm.forward(seq)
    print(f"  Input: {seq.shape}")
    print(f"  Output: {lstm_out.shape}")  # expected (5, 3, 40): 2 directions * hidden 20
    print(f"  h_n: {lstm_h.shape}")       # expected (4, 3, 20): 2 layers * 2 directions

    # 2-layer unidirectional GRU on the same input.
    print("\nTesting GRU...")
    gru = GRULowLevel(
        input_size=10, hidden_size=20, num_layers=2,
        bidirectional=False, device=dev,
    )
    gru_out, gru_h = gru.forward(seq)
    print(f"  Output: {gru_out.shape}")   # expected (5, 3, 20)
    print(f"  h_n: {gru_h.shape}")        # expected (2, 3, 20)

    # Sentiment-analysis training demo.
    print()
    train_imdb_sentiment()
567
568
if __name__ == "__main__":
    # Runs only when executed as a script, not when imported.
    main()