1"""
2NumPy LSTM From-Scratch ๊ตฌํ
3
4๋ชจ๋ ๊ฒ์ดํธ ์ฐ์ฐ๊ณผ BPTT๋ฅผ ์ง์ ๊ตฌํ
5"""
6
7import numpy as np
8from typing import Tuple, Dict, List
9
10
11def sigmoid(x: np.ndarray) -> np.ndarray:
12 """Sigmoid ํ์ฑํ (์์น ์์ ์ฑ ๊ณ ๋ ค)"""
13 return np.where(x >= 0,
14 1 / (1 + np.exp(-x)),
15 np.exp(x) / (1 + np.exp(x)))
16
17
18def sigmoid_derivative(s: np.ndarray) -> np.ndarray:
19 """Sigmoid์ derivative: s * (1 - s)"""
20 return s * (1 - s)
21
22
23def tanh_derivative(t: np.ndarray) -> np.ndarray:
24 """Tanh์ derivative: 1 - t^2"""
25 return 1 - t ** 2
26
27
28class LSTMCellNumPy:
29 """
30 ๋จ์ผ LSTM Cell (NumPy ๊ตฌํ)
31
32 ์์:
33 f_t = ฯ(W_f ยท [h_{t-1}, x_t] + b_f)
34 i_t = ฯ(W_i ยท [h_{t-1}, x_t] + b_i)
35 cฬ_t = tanh(W_c ยท [h_{t-1}, x_t] + b_c)
36 c_t = f_t โ c_{t-1} + i_t โ cฬ_t
37 o_t = ฯ(W_o ยท [h_{t-1}, x_t] + b_o)
38 h_t = o_t โ tanh(c_t)
39 """
40
41 def __init__(self, input_size: int, hidden_size: int):
42 self.input_size = input_size
43 self.hidden_size = hidden_size
44
45 # Xavier ์ด๊ธฐํ
46 concat_size = input_size + hidden_size
47 scale = np.sqrt(2.0 / (concat_size + hidden_size))
48
49 # 4๊ฐ ๊ฒ์ดํธ๋ฅผ ํ๋์ ๊ฐ์ค์น๋ก ๊ด๋ฆฌ (ํจ์จ์ฑ)
50 # ์์: forget, input, candidate, output
51 self.W = np.random.randn(4 * hidden_size, concat_size) * scale
52 self.b = np.zeros(4 * hidden_size)
53
54 # Gradient ์ ์ฅ
55 self.dW = np.zeros_like(self.W)
56 self.db = np.zeros_like(self.b)
57
58 # Forward pass ์บ์
59 self.cache = {}
60
61 def forward(
62 self,
63 x: np.ndarray,
64 h_prev: np.ndarray,
65 c_prev: np.ndarray
66 ) -> Tuple[np.ndarray, np.ndarray]:
67 """
68 Forward pass
69
70 Args:
71 x: (batch_size, input_size) ํ์ฌ ์
๋ ฅ
72 h_prev: (batch_size, hidden_size) ์ด์ hidden
73 c_prev: (batch_size, hidden_size) ์ด์ cell
74
75 Returns:
76 h_t: (batch_size, hidden_size) ํ์ฌ hidden
77 c_t: (batch_size, hidden_size) ํ์ฌ cell
78 """
79 batch_size = x.shape[0]
80 H = self.hidden_size
81
82 # Concatenate [h_prev, x]
83 concat = np.concatenate([h_prev, x], axis=1) # (batch, hidden+input)
84
85 # ๋ชจ๋ ๊ฒ์ดํธ ํ๋ฒ์ ๊ณ์ฐ
86 gates = concat @ self.W.T + self.b # (batch, 4*hidden)
87
88 # ๋ถ๋ฆฌ
89 f_gate = sigmoid(gates[:, 0:H]) # Forget gate
90 i_gate = sigmoid(gates[:, H:2*H]) # Input gate
91 c_tilde = np.tanh(gates[:, 2*H:3*H]) # Candidate
92 o_gate = sigmoid(gates[:, 3*H:4*H]) # Output gate
93
94 # Cell state ์
๋ฐ์ดํธ
95 c_t = f_gate * c_prev + i_gate * c_tilde
96
97 # Hidden state
98 h_t = o_gate * np.tanh(c_t)
99
100 # Backward๋ฅผ ์ํ ์บ์
101 self.cache = {
102 'x': x,
103 'h_prev': h_prev,
104 'c_prev': c_prev,
105 'concat': concat,
106 'f_gate': f_gate,
107 'i_gate': i_gate,
108 'c_tilde': c_tilde,
109 'o_gate': o_gate,
110 'c_t': c_t,
111 'h_t': h_t,
112 'tanh_c_t': np.tanh(c_t),
113 }
114
115 return h_t, c_t
116
117 def backward(
118 self,
119 dh_next: np.ndarray,
120 dc_next: np.ndarray
121 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
122 """
123 Backward pass (BPTT ํ ์คํ
)
124
125 Args:
126 dh_next: (batch_size, hidden_size) ๋ค์ ์์ ์์ ์จ h gradient
127 dc_next: (batch_size, hidden_size) ๋ค์ ์์ ์์ ์จ c gradient
128
129 Returns:
130 dx: (batch_size, input_size) ์
๋ ฅ์ ๋ํ gradient
131 dh_prev: (batch_size, hidden_size) ์ด์ hidden gradient
132 dc_prev: (batch_size, hidden_size) ์ด์ cell gradient
133 """
134 cache = self.cache
135 H = self.hidden_size
136
137 # Cell state gradient (๋ ๊ฒฝ๋ก์์ ์ด)
138 # 1. dh_next โ o_gate โ tanh(c_t) โ c_t
139 # 2. dc_next (๋ค์ ์์ ์์ ์ง์ )
140 do = dh_next * cache['tanh_c_t']
141 dc = dh_next * cache['o_gate'] * tanh_derivative(cache['tanh_c_t'])
142 dc = dc + dc_next # ๋ ๊ฒฝ๋ก ํฉ์นจ
143
144 # ๊ฐ ๊ฒ์ดํธ gradient
145 df = dc * cache['c_prev']
146 di = dc * cache['c_tilde']
147 dc_tilde = dc * cache['i_gate']
148
149 # ์ด์ cell state gradient (ํต์ฌ: forget gate๋ฅผ ํตํด ์ง์ ์ ํ)
150 dc_prev = dc * cache['f_gate']
151
152 # ํ์ฑํ ํจ์ derivative
153 df_gate = df * sigmoid_derivative(cache['f_gate'])
154 di_gate = di * sigmoid_derivative(cache['i_gate'])
155 dc_tilde_gate = dc_tilde * tanh_derivative(cache['c_tilde'])
156 do_gate = do * sigmoid_derivative(cache['o_gate'])
157
158 # ๋ชจ๋ ๊ฒ์ดํธ gradient ํฉ์น๊ธฐ
159 dgates = np.concatenate([df_gate, di_gate, dc_tilde_gate, do_gate], axis=1)
160
161 # ๊ฐ์ค์น gradient
162 self.dW += dgates.T @ cache['concat']
163 self.db += dgates.sum(axis=0)
164
165 # Concat gradient โ h_prev, x gradient
166 dconcat = dgates @ self.W
167 dh_prev = dconcat[:, :H]
168 dx = dconcat[:, H:]
169
170 return dx, dh_prev, dc_prev
171
172 def zero_grad(self):
173 """Gradient ์ด๊ธฐํ"""
174 self.dW.fill(0)
175 self.db.fill(0)
176
177
178class LSTMNumPy:
179 """
180 ์ ์ฒด LSTM (์ฌ๋ฌ ์์ )
181 """
182
183 def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
184 self.input_size = input_size
185 self.hidden_size = hidden_size
186 self.num_layers = num_layers
187
188 # ๋ ์ด์ด๋ณ LSTM Cell
189 self.cells = []
190 for i in range(num_layers):
191 in_size = input_size if i == 0 else hidden_size
192 self.cells.append(LSTMCellNumPy(in_size, hidden_size))
193
194 def forward(
195 self,
196 x: np.ndarray,
197 h_0: np.ndarray = None,
198 c_0: np.ndarray = None
199 ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
200 """
201 Forward pass (์ ์ฒด ์ํ์ค)
202
203 Args:
204 x: (seq_len, batch_size, input_size)
205 h_0: (num_layers, batch_size, hidden_size) ์ด๊ธฐ hidden
206 c_0: (num_layers, batch_size, hidden_size) ์ด๊ธฐ cell
207
208 Returns:
209 output: (seq_len, batch_size, hidden_size) ๋ชจ๋ ์์ ์ hidden
210 (h_n, c_n): ๋ง์ง๋ง hidden/cell
211 """
212 seq_len, batch_size, _ = x.shape
213
214 # ์ด๊ธฐ ์ํ
215 if h_0 is None:
216 h_0 = np.zeros((self.num_layers, batch_size, self.hidden_size))
217 if c_0 is None:
218 c_0 = np.zeros((self.num_layers, batch_size, self.hidden_size))
219
220 # ์ถ๋ ฅ ์ ์ฅ
221 outputs = []
222 h_states = [h_0[i] for i in range(self.num_layers)]
223 c_states = [c_0[i] for i in range(self.num_layers)]
224
225 # ์๊ฐ์ ๋ฐ๋ฅธ ์บ์ (backward์ฉ)
226 self.time_cache = []
227
228 for t in range(seq_len):
229 layer_input = x[t]
230
231 for layer_idx, cell in enumerate(self.cells):
232 h_states[layer_idx], c_states[layer_idx] = cell.forward(
233 layer_input, h_states[layer_idx], c_states[layer_idx]
234 )
235 layer_input = h_states[layer_idx]
236
237 outputs.append(h_states[-1])
238 self.time_cache.append([cell.cache.copy() for cell in self.cells])
239
240 output = np.stack(outputs, axis=0)
241 h_n = np.stack(h_states, axis=0)
242 c_n = np.stack(c_states, axis=0)
243
244 return output, (h_n, c_n)
245
246 def backward(self, doutput: np.ndarray) -> np.ndarray:
247 """
248 Backward pass (BPTT)
249
250 Args:
251 doutput: (seq_len, batch_size, hidden_size) ์ถ๋ ฅ์ ๋ํ gradient
252
253 Returns:
254 dx: (seq_len, batch_size, input_size) ์
๋ ฅ์ ๋ํ gradient
255 """
256 seq_len, batch_size, _ = doutput.shape
257
258 # Gradient ์ด๊ธฐํ
259 for cell in self.cells:
260 cell.zero_grad()
261
262 dx = np.zeros((seq_len, batch_size, self.input_size))
263
264 # ๋ ์ด์ด๋ณ gradient ์ ํ
265 dh_next = [np.zeros((batch_size, self.hidden_size))
266 for _ in range(self.num_layers)]
267 dc_next = [np.zeros((batch_size, self.hidden_size))
268 for _ in range(self.num_layers)]
269
270 # ์๊ฐ ์ญ์
271 for t in reversed(range(seq_len)):
272 # ๋ง์ง๋ง ๋ ์ด์ด์ ์ถ๋ ฅ gradient ๋ํจ
273 dh_next[-1] += doutput[t]
274
275 # ๋ ์ด์ด ์ญ์ (๊น์ ๋ ์ด์ด โ ์์ ๋ ์ด์ด)
276 for layer_idx in reversed(range(self.num_layers)):
277 cell = self.cells[layer_idx]
278 cell.cache = self.time_cache[t][layer_idx]
279
280 dx_layer, dh_prev, dc_prev = cell.backward(
281 dh_next[layer_idx], dc_next[layer_idx]
282 )
283
284 # ๋ค์ ์์ ์ผ๋ก ์ ํ
285 dh_next[layer_idx] = dh_prev
286 dc_next[layer_idx] = dc_prev
287
288 # ์ด์ ๋ ์ด์ด๋ก ์ ํ
289 if layer_idx > 0:
290 dh_next[layer_idx - 1] += dx_layer
291 else:
292 dx[t] = dx_layer
293
294 return dx
295
296 def parameters(self) -> List[np.ndarray]:
297 """๋ชจ๋ ํ๋ผ๋ฏธํฐ ๋ฐํ"""
298 params = []
299 for cell in self.cells:
300 params.extend([cell.W, cell.b])
301 return params
302
303 def gradients(self) -> List[np.ndarray]:
304 """๋ชจ๋ gradient ๋ฐํ"""
305 grads = []
306 for cell in self.cells:
307 grads.extend([cell.dW, cell.db])
308 return grads
309
310
311def sgd_update(params: List[np.ndarray], grads: List[np.ndarray], lr: float):
312 """SGD ์
๋ฐ์ดํธ"""
313 for param, grad in zip(params, grads):
314 param -= lr * grad
315
316
317def clip_gradients(grads: List[np.ndarray], max_norm: float = 5.0):
318 """Gradient clipping (exploding gradient ๋ฐฉ์ง)"""
319 total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
320 if total_norm > max_norm:
321 scale = max_norm / (total_norm + 1e-6)
322 for g in grads:
323 g *= scale
324
325
326# ๊ฐ๋จํ ํ
์คํธ
327def test_lstm():
328 print("=== LSTM NumPy Test ===\n")
329
330 # ํ์ดํผํ๋ผ๋ฏธํฐ
331 batch_size = 2
332 seq_len = 5
333 input_size = 10
334 hidden_size = 20
335 num_layers = 2
336
337 # ๋ชจ๋ธ
338 lstm = LSTMNumPy(input_size, hidden_size, num_layers)
339
340 # ๋๋ฏธ ์
๋ ฅ
341 x = np.random.randn(seq_len, batch_size, input_size)
342
343 # Forward
344 output, (h_n, c_n) = lstm.forward(x)
345
346 print(f"Input shape: {x.shape}")
347 print(f"Output shape: {output.shape}")
348 print(f"h_n shape: {h_n.shape}")
349 print(f"c_n shape: {c_n.shape}")
350
351 # Backward (๋ง์ง๋ง ์ถ๋ ฅ์ ๋ํ loss ๊ฐ์ )
352 loss = np.sum(output[-1] ** 2) # ๋๋ฏธ loss
353 doutput = np.zeros_like(output)
354 doutput[-1] = 2 * output[-1]
355
356 dx = lstm.backward(doutput)
357
358 print(f"\ndx shape: {dx.shape}")
359 print(f"Gradient norms:")
360 for i, (param, grad) in enumerate(zip(lstm.parameters(), lstm.gradients())):
361 print(f" Layer {i//2}, {'W' if i%2==0 else 'b'}: "
362 f"param norm={np.linalg.norm(param):.4f}, "
363 f"grad norm={np.linalg.norm(grad):.4f}")
364
365
366def train_sequence_classification():
367 """๊ฐ๋จํ ์ํ์ค ๋ถ๋ฅ ์์ """
368 print("\n=== Sequence Classification ===\n")
369
370 np.random.seed(42)
371
372 # ๋ฐ์ดํฐ: ์ํ์ค์ ํ๊ท ์ด ์์๋ฉด 1, ์์๋ฉด 0
373 def generate_data(n_samples, seq_len, input_size):
374 X = np.random.randn(n_samples, seq_len, input_size)
375 y = (X.mean(axis=(1, 2)) > 0).astype(int)
376 return X, y
377
378 X_train, y_train = generate_data(100, 10, 5)
379 X_test, y_test = generate_data(20, 10, 5)
380
381 # ๋ชจ๋ธ
382 lstm = LSTMNumPy(input_size=5, hidden_size=16, num_layers=1)
383
384 # ์ถ๋ ฅ ๋ ์ด์ด
385 W_out = np.random.randn(2, 16) * 0.1
386 b_out = np.zeros(2)
387
388 lr = 0.01
389 epochs = 50
390
391 for epoch in range(epochs):
392 total_loss = 0
393 correct = 0
394
395 for i in range(len(X_train)):
396 x = X_train[i:i+1].transpose(1, 0, 2) # (seq, 1, input)
397 target = y_train[i]
398
399 # Forward
400 output, _ = lstm.forward(x)
401 last_hidden = output[-1] # (1, hidden)
402
403 # ๋ถ๋ฅ
404 logits = last_hidden @ W_out.T + b_out
405 probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
406
407 # Loss (cross entropy)
408 loss = -np.log(probs[0, target] + 1e-7)
409 total_loss += loss
410
411 # Accuracy
412 pred = logits.argmax()
413 correct += (pred == target)
414
415 # Backward
416 dlogits = probs.copy()
417 dlogits[0, target] -= 1
418
419 dW_out = dlogits.T @ last_hidden
420 db_out = dlogits.sum(axis=0)
421 dlast_hidden = dlogits @ W_out
422
423 doutput = np.zeros_like(output)
424 doutput[-1] = dlast_hidden
425
426 lstm.backward(doutput)
427
428 # Gradient clipping
429 clip_gradients(lstm.gradients())
430
431 # Update
432 sgd_update(lstm.parameters(), lstm.gradients(), lr)
433 W_out -= lr * dW_out
434 b_out -= lr * db_out
435
436 if (epoch + 1) % 10 == 0:
437 acc = correct / len(X_train)
438 print(f"Epoch {epoch+1}: Loss={total_loss/len(X_train):.4f}, Acc={acc:.2f}")
439
440 # ํ
์คํธ
441 test_correct = 0
442 for i in range(len(X_test)):
443 x = X_test[i:i+1].transpose(1, 0, 2)
444 target = y_test[i]
445
446 output, _ = lstm.forward(x)
447 logits = output[-1] @ W_out.T + b_out
448 pred = logits.argmax()
449 test_correct += (pred == target)
450
451 print(f"\nTest Accuracy: {test_correct/len(X_test):.2f}")
452
453
454if __name__ == "__main__":
455 test_lstm()
456 train_sequence_classification()