lstm_numpy.py

Download
python 457 lines 13.4 KB
  1"""
  2NumPy LSTM From-Scratch ๊ตฌํ˜„
  3
  4๋ชจ๋“  ๊ฒŒ์ดํŠธ ์—ฐ์‚ฐ๊ณผ BPTT๋ฅผ ์ง์ ‘ ๊ตฌํ˜„
  5"""
  6
  7import numpy as np
  8from typing import Tuple, Dict, List
  9
 10
 11def sigmoid(x: np.ndarray) -> np.ndarray:
 12    """Sigmoid ํ™œ์„ฑํ™” (์ˆ˜์น˜ ์•ˆ์ •์„ฑ ๊ณ ๋ ค)"""
 13    return np.where(x >= 0,
 14                    1 / (1 + np.exp(-x)),
 15                    np.exp(x) / (1 + np.exp(x)))
 16
 17
 18def sigmoid_derivative(s: np.ndarray) -> np.ndarray:
 19    """Sigmoid์˜ derivative: s * (1 - s)"""
 20    return s * (1 - s)
 21
 22
 23def tanh_derivative(t: np.ndarray) -> np.ndarray:
 24    """Tanh์˜ derivative: 1 - t^2"""
 25    return 1 - t ** 2
 26
 27
 28class LSTMCellNumPy:
 29    """
 30    ๋‹จ์ผ LSTM Cell (NumPy ๊ตฌํ˜„)
 31
 32    ์ˆ˜์‹:
 33        f_t = ฯƒ(W_f ยท [h_{t-1}, x_t] + b_f)
 34        i_t = ฯƒ(W_i ยท [h_{t-1}, x_t] + b_i)
 35        cฬƒ_t = tanh(W_c ยท [h_{t-1}, x_t] + b_c)
 36        c_t = f_t โŠ™ c_{t-1} + i_t โŠ™ cฬƒ_t
 37        o_t = ฯƒ(W_o ยท [h_{t-1}, x_t] + b_o)
 38        h_t = o_t โŠ™ tanh(c_t)
 39    """
 40
 41    def __init__(self, input_size: int, hidden_size: int):
 42        self.input_size = input_size
 43        self.hidden_size = hidden_size
 44
 45        # Xavier ์ดˆ๊ธฐํ™”
 46        concat_size = input_size + hidden_size
 47        scale = np.sqrt(2.0 / (concat_size + hidden_size))
 48
 49        # 4๊ฐœ ๊ฒŒ์ดํŠธ๋ฅผ ํ•˜๋‚˜์˜ ๊ฐ€์ค‘์น˜๋กœ ๊ด€๋ฆฌ (ํšจ์œจ์„ฑ)
 50        # ์ˆœ์„œ: forget, input, candidate, output
 51        self.W = np.random.randn(4 * hidden_size, concat_size) * scale
 52        self.b = np.zeros(4 * hidden_size)
 53
 54        # Gradient ์ €์žฅ
 55        self.dW = np.zeros_like(self.W)
 56        self.db = np.zeros_like(self.b)
 57
 58        # Forward pass ์บ์‹œ
 59        self.cache = {}
 60
 61    def forward(
 62        self,
 63        x: np.ndarray,
 64        h_prev: np.ndarray,
 65        c_prev: np.ndarray
 66    ) -> Tuple[np.ndarray, np.ndarray]:
 67        """
 68        Forward pass
 69
 70        Args:
 71            x: (batch_size, input_size) ํ˜„์žฌ ์ž…๋ ฅ
 72            h_prev: (batch_size, hidden_size) ์ด์ „ hidden
 73            c_prev: (batch_size, hidden_size) ์ด์ „ cell
 74
 75        Returns:
 76            h_t: (batch_size, hidden_size) ํ˜„์žฌ hidden
 77            c_t: (batch_size, hidden_size) ํ˜„์žฌ cell
 78        """
 79        batch_size = x.shape[0]
 80        H = self.hidden_size
 81
 82        # Concatenate [h_prev, x]
 83        concat = np.concatenate([h_prev, x], axis=1)  # (batch, hidden+input)
 84
 85        # ๋ชจ๋“  ๊ฒŒ์ดํŠธ ํ•œ๋ฒˆ์— ๊ณ„์‚ฐ
 86        gates = concat @ self.W.T + self.b  # (batch, 4*hidden)
 87
 88        # ๋ถ„๋ฆฌ
 89        f_gate = sigmoid(gates[:, 0:H])           # Forget gate
 90        i_gate = sigmoid(gates[:, H:2*H])         # Input gate
 91        c_tilde = np.tanh(gates[:, 2*H:3*H])      # Candidate
 92        o_gate = sigmoid(gates[:, 3*H:4*H])       # Output gate
 93
 94        # Cell state ์—…๋ฐ์ดํŠธ
 95        c_t = f_gate * c_prev + i_gate * c_tilde
 96
 97        # Hidden state
 98        h_t = o_gate * np.tanh(c_t)
 99
100        # Backward๋ฅผ ์œ„ํ•œ ์บ์‹œ
101        self.cache = {
102            'x': x,
103            'h_prev': h_prev,
104            'c_prev': c_prev,
105            'concat': concat,
106            'f_gate': f_gate,
107            'i_gate': i_gate,
108            'c_tilde': c_tilde,
109            'o_gate': o_gate,
110            'c_t': c_t,
111            'h_t': h_t,
112            'tanh_c_t': np.tanh(c_t),
113        }
114
115        return h_t, c_t
116
117    def backward(
118        self,
119        dh_next: np.ndarray,
120        dc_next: np.ndarray
121    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
122        """
123        Backward pass (BPTT ํ•œ ์Šคํ…)
124
125        Args:
126            dh_next: (batch_size, hidden_size) ๋‹ค์Œ ์‹œ์ ์—์„œ ์˜จ h gradient
127            dc_next: (batch_size, hidden_size) ๋‹ค์Œ ์‹œ์ ์—์„œ ์˜จ c gradient
128
129        Returns:
130            dx: (batch_size, input_size) ์ž…๋ ฅ์— ๋Œ€ํ•œ gradient
131            dh_prev: (batch_size, hidden_size) ์ด์ „ hidden gradient
132            dc_prev: (batch_size, hidden_size) ์ด์ „ cell gradient
133        """
134        cache = self.cache
135        H = self.hidden_size
136
137        # Cell state gradient (๋‘ ๊ฒฝ๋กœ์—์„œ ์˜ด)
138        # 1. dh_next โ†’ o_gate โ†’ tanh(c_t) โ†’ c_t
139        # 2. dc_next (๋‹ค์Œ ์‹œ์ ์—์„œ ์ง์ ‘)
140        do = dh_next * cache['tanh_c_t']
141        dc = dh_next * cache['o_gate'] * tanh_derivative(cache['tanh_c_t'])
142        dc = dc + dc_next  # ๋‘ ๊ฒฝ๋กœ ํ•ฉ์นจ
143
144        # ๊ฐ ๊ฒŒ์ดํŠธ gradient
145        df = dc * cache['c_prev']
146        di = dc * cache['c_tilde']
147        dc_tilde = dc * cache['i_gate']
148
149        # ์ด์ „ cell state gradient (ํ•ต์‹ฌ: forget gate๋ฅผ ํ†ตํ•ด ์ง์ ‘ ์ „ํŒŒ)
150        dc_prev = dc * cache['f_gate']
151
152        # ํ™œ์„ฑํ™” ํ•จ์ˆ˜ derivative
153        df_gate = df * sigmoid_derivative(cache['f_gate'])
154        di_gate = di * sigmoid_derivative(cache['i_gate'])
155        dc_tilde_gate = dc_tilde * tanh_derivative(cache['c_tilde'])
156        do_gate = do * sigmoid_derivative(cache['o_gate'])
157
158        # ๋ชจ๋“  ๊ฒŒ์ดํŠธ gradient ํ•ฉ์น˜๊ธฐ
159        dgates = np.concatenate([df_gate, di_gate, dc_tilde_gate, do_gate], axis=1)
160
161        # ๊ฐ€์ค‘์น˜ gradient
162        self.dW += dgates.T @ cache['concat']
163        self.db += dgates.sum(axis=0)
164
165        # Concat gradient โ†’ h_prev, x gradient
166        dconcat = dgates @ self.W
167        dh_prev = dconcat[:, :H]
168        dx = dconcat[:, H:]
169
170        return dx, dh_prev, dc_prev
171
172    def zero_grad(self):
173        """Gradient ์ดˆ๊ธฐํ™”"""
174        self.dW.fill(0)
175        self.db.fill(0)
176
177
178class LSTMNumPy:
179    """
180    ์ „์ฒด LSTM (์—ฌ๋Ÿฌ ์‹œ์ )
181    """
182
183    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
184        self.input_size = input_size
185        self.hidden_size = hidden_size
186        self.num_layers = num_layers
187
188        # ๋ ˆ์ด์–ด๋ณ„ LSTM Cell
189        self.cells = []
190        for i in range(num_layers):
191            in_size = input_size if i == 0 else hidden_size
192            self.cells.append(LSTMCellNumPy(in_size, hidden_size))
193
194    def forward(
195        self,
196        x: np.ndarray,
197        h_0: np.ndarray = None,
198        c_0: np.ndarray = None
199    ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
200        """
201        Forward pass (์ „์ฒด ์‹œํ€€์Šค)
202
203        Args:
204            x: (seq_len, batch_size, input_size)
205            h_0: (num_layers, batch_size, hidden_size) ์ดˆ๊ธฐ hidden
206            c_0: (num_layers, batch_size, hidden_size) ์ดˆ๊ธฐ cell
207
208        Returns:
209            output: (seq_len, batch_size, hidden_size) ๋ชจ๋“  ์‹œ์ ์˜ hidden
210            (h_n, c_n): ๋งˆ์ง€๋ง‰ hidden/cell
211        """
212        seq_len, batch_size, _ = x.shape
213
214        # ์ดˆ๊ธฐ ์ƒํƒœ
215        if h_0 is None:
216            h_0 = np.zeros((self.num_layers, batch_size, self.hidden_size))
217        if c_0 is None:
218            c_0 = np.zeros((self.num_layers, batch_size, self.hidden_size))
219
220        # ์ถœ๋ ฅ ์ €์žฅ
221        outputs = []
222        h_states = [h_0[i] for i in range(self.num_layers)]
223        c_states = [c_0[i] for i in range(self.num_layers)]
224
225        # ์‹œ๊ฐ„์— ๋”ฐ๋ฅธ ์บ์‹œ (backward์šฉ)
226        self.time_cache = []
227
228        for t in range(seq_len):
229            layer_input = x[t]
230
231            for layer_idx, cell in enumerate(self.cells):
232                h_states[layer_idx], c_states[layer_idx] = cell.forward(
233                    layer_input, h_states[layer_idx], c_states[layer_idx]
234                )
235                layer_input = h_states[layer_idx]
236
237            outputs.append(h_states[-1])
238            self.time_cache.append([cell.cache.copy() for cell in self.cells])
239
240        output = np.stack(outputs, axis=0)
241        h_n = np.stack(h_states, axis=0)
242        c_n = np.stack(c_states, axis=0)
243
244        return output, (h_n, c_n)
245
246    def backward(self, doutput: np.ndarray) -> np.ndarray:
247        """
248        Backward pass (BPTT)
249
250        Args:
251            doutput: (seq_len, batch_size, hidden_size) ์ถœ๋ ฅ์— ๋Œ€ํ•œ gradient
252
253        Returns:
254            dx: (seq_len, batch_size, input_size) ์ž…๋ ฅ์— ๋Œ€ํ•œ gradient
255        """
256        seq_len, batch_size, _ = doutput.shape
257
258        # Gradient ์ดˆ๊ธฐํ™”
259        for cell in self.cells:
260            cell.zero_grad()
261
262        dx = np.zeros((seq_len, batch_size, self.input_size))
263
264        # ๋ ˆ์ด์–ด๋ณ„ gradient ์ „ํŒŒ
265        dh_next = [np.zeros((batch_size, self.hidden_size))
266                   for _ in range(self.num_layers)]
267        dc_next = [np.zeros((batch_size, self.hidden_size))
268                   for _ in range(self.num_layers)]
269
270        # ์‹œ๊ฐ„ ์—ญ์ˆœ
271        for t in reversed(range(seq_len)):
272            # ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด์— ์ถœ๋ ฅ gradient ๋”ํ•จ
273            dh_next[-1] += doutput[t]
274
275            # ๋ ˆ์ด์–ด ์—ญ์ˆœ (๊นŠ์€ ๋ ˆ์ด์–ด โ†’ ์–•์€ ๋ ˆ์ด์–ด)
276            for layer_idx in reversed(range(self.num_layers)):
277                cell = self.cells[layer_idx]
278                cell.cache = self.time_cache[t][layer_idx]
279
280                dx_layer, dh_prev, dc_prev = cell.backward(
281                    dh_next[layer_idx], dc_next[layer_idx]
282                )
283
284                # ๋‹ค์Œ ์‹œ์ ์œผ๋กœ ์ „ํŒŒ
285                dh_next[layer_idx] = dh_prev
286                dc_next[layer_idx] = dc_prev
287
288                # ์ด์ „ ๋ ˆ์ด์–ด๋กœ ์ „ํŒŒ
289                if layer_idx > 0:
290                    dh_next[layer_idx - 1] += dx_layer
291                else:
292                    dx[t] = dx_layer
293
294        return dx
295
296    def parameters(self) -> List[np.ndarray]:
297        """๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ ๋ฐ˜ํ™˜"""
298        params = []
299        for cell in self.cells:
300            params.extend([cell.W, cell.b])
301        return params
302
303    def gradients(self) -> List[np.ndarray]:
304        """๋ชจ๋“  gradient ๋ฐ˜ํ™˜"""
305        grads = []
306        for cell in self.cells:
307            grads.extend([cell.dW, cell.db])
308        return grads
309
310
311def sgd_update(params: List[np.ndarray], grads: List[np.ndarray], lr: float):
312    """SGD ์—…๋ฐ์ดํŠธ"""
313    for param, grad in zip(params, grads):
314        param -= lr * grad
315
316
317def clip_gradients(grads: List[np.ndarray], max_norm: float = 5.0):
318    """Gradient clipping (exploding gradient ๋ฐฉ์ง€)"""
319    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
320    if total_norm > max_norm:
321        scale = max_norm / (total_norm + 1e-6)
322        for g in grads:
323            g *= scale
324
325
326# ๊ฐ„๋‹จํ•œ ํ…Œ์ŠคํŠธ
327def test_lstm():
328    print("=== LSTM NumPy Test ===\n")
329
330    # ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ
331    batch_size = 2
332    seq_len = 5
333    input_size = 10
334    hidden_size = 20
335    num_layers = 2
336
337    # ๋ชจ๋ธ
338    lstm = LSTMNumPy(input_size, hidden_size, num_layers)
339
340    # ๋”๋ฏธ ์ž…๋ ฅ
341    x = np.random.randn(seq_len, batch_size, input_size)
342
343    # Forward
344    output, (h_n, c_n) = lstm.forward(x)
345
346    print(f"Input shape: {x.shape}")
347    print(f"Output shape: {output.shape}")
348    print(f"h_n shape: {h_n.shape}")
349    print(f"c_n shape: {c_n.shape}")
350
351    # Backward (๋งˆ์ง€๋ง‰ ์ถœ๋ ฅ์— ๋Œ€ํ•œ loss ๊ฐ€์ •)
352    loss = np.sum(output[-1] ** 2)  # ๋”๋ฏธ loss
353    doutput = np.zeros_like(output)
354    doutput[-1] = 2 * output[-1]
355
356    dx = lstm.backward(doutput)
357
358    print(f"\ndx shape: {dx.shape}")
359    print(f"Gradient norms:")
360    for i, (param, grad) in enumerate(zip(lstm.parameters(), lstm.gradients())):
361        print(f"  Layer {i//2}, {'W' if i%2==0 else 'b'}: "
362              f"param norm={np.linalg.norm(param):.4f}, "
363              f"grad norm={np.linalg.norm(grad):.4f}")
364
365
366def train_sequence_classification():
367    """๊ฐ„๋‹จํ•œ ์‹œํ€€์Šค ๋ถ„๋ฅ˜ ์˜ˆ์ œ"""
368    print("\n=== Sequence Classification ===\n")
369
370    np.random.seed(42)
371
372    # ๋ฐ์ดํ„ฐ: ์‹œํ€€์Šค์˜ ํ‰๊ท ์ด ์–‘์ˆ˜๋ฉด 1, ์Œ์ˆ˜๋ฉด 0
373    def generate_data(n_samples, seq_len, input_size):
374        X = np.random.randn(n_samples, seq_len, input_size)
375        y = (X.mean(axis=(1, 2)) > 0).astype(int)
376        return X, y
377
378    X_train, y_train = generate_data(100, 10, 5)
379    X_test, y_test = generate_data(20, 10, 5)
380
381    # ๋ชจ๋ธ
382    lstm = LSTMNumPy(input_size=5, hidden_size=16, num_layers=1)
383
384    # ์ถœ๋ ฅ ๋ ˆ์ด์–ด
385    W_out = np.random.randn(2, 16) * 0.1
386    b_out = np.zeros(2)
387
388    lr = 0.01
389    epochs = 50
390
391    for epoch in range(epochs):
392        total_loss = 0
393        correct = 0
394
395        for i in range(len(X_train)):
396            x = X_train[i:i+1].transpose(1, 0, 2)  # (seq, 1, input)
397            target = y_train[i]
398
399            # Forward
400            output, _ = lstm.forward(x)
401            last_hidden = output[-1]  # (1, hidden)
402
403            # ๋ถ„๋ฅ˜
404            logits = last_hidden @ W_out.T + b_out
405            probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
406
407            # Loss (cross entropy)
408            loss = -np.log(probs[0, target] + 1e-7)
409            total_loss += loss
410
411            # Accuracy
412            pred = logits.argmax()
413            correct += (pred == target)
414
415            # Backward
416            dlogits = probs.copy()
417            dlogits[0, target] -= 1
418
419            dW_out = dlogits.T @ last_hidden
420            db_out = dlogits.sum(axis=0)
421            dlast_hidden = dlogits @ W_out
422
423            doutput = np.zeros_like(output)
424            doutput[-1] = dlast_hidden
425
426            lstm.backward(doutput)
427
428            # Gradient clipping
429            clip_gradients(lstm.gradients())
430
431            # Update
432            sgd_update(lstm.parameters(), lstm.gradients(), lr)
433            W_out -= lr * dW_out
434            b_out -= lr * db_out
435
436        if (epoch + 1) % 10 == 0:
437            acc = correct / len(X_train)
438            print(f"Epoch {epoch+1}: Loss={total_loss/len(X_train):.4f}, Acc={acc:.2f}")
439
440    # ํ…Œ์ŠคํŠธ
441    test_correct = 0
442    for i in range(len(X_test)):
443        x = X_test[i:i+1].transpose(1, 0, 2)
444        target = y_test[i]
445
446        output, _ = lstm.forward(x)
447        logits = output[-1] @ W_out.T + b_out
448        pred = logits.argmax()
449        test_correct += (pred == target)
450
451    print(f"\nTest Accuracy: {test_correct/len(X_test):.2f}")
452
453
454if __name__ == "__main__":
455    test_lstm()
456    train_sequence_classification()