# 06. LSTM / GRU
## Overview
LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are recurrent neural network (RNN) variants that address the vanishing gradient problem. Through gate mechanisms, they learn long-term dependencies effectively.
## Mathematical Background

### 1. The Problem with Vanilla RNNs
Vanilla RNN:

```
h_t = tanh(W_h · h_{t-1} + W_x · x_t + b)
```

The problem shows up in Backpropagation Through Time (BPTT):

```
∂L/∂h_0 = ∂L/∂h_T · ∂h_T/∂h_{T-1} · ... · ∂h_1/∂h_0
        = ∂L/∂h_T · Π_{t=1}^{T} ∂h_t/∂h_{t-1}

∂h_t/∂h_{t-1} = diag(1 - tanh²(·)) · W_h
```

Result:

- |eigenvalue(W_h)| < 1 → vanishing gradient
- |eigenvalue(W_h)| > 1 → exploding gradient
- → early-sequence information cannot be learned on long sequences
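To make this concrete, here is a minimal NumPy sketch (all names illustrative) that accumulates the Jacobian product Π ∂h_t/∂h_{t-1} for a random vanilla RNN and prints its shrinking norm:

```python
import numpy as np

rng = np.random.default_rng(0)
H, T = 32, 50
W_h = rng.normal(0, 0.1, (H, H))     # spectral radius ≈ 0.1·√32, well below 1
h = np.zeros(H)
grad = np.eye(H)                     # ∂h_0/∂h_0

norms = []
for t in range(T):
    x = rng.normal(0, 1, H)
    h = np.tanh(W_h @ h + x)
    J = np.diag(1 - h**2) @ W_h      # ∂h_t/∂h_{t-1}
    grad = J @ grad                  # accumulate Π ∂h_t/∂h_{t-1}
    norms.append(np.linalg.norm(grad))

print([f"{n:.1e}" for n in norms[::10]])  # norms decay roughly geometrically
```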
### 2. LSTM Equations

Inputs:  x_t (current input), h_{t-1} (previous hidden state), c_{t-1} (previous cell state)
Outputs: h_t (current hidden state), c_t (current cell state)

1. Forget gate (what do we throw away?)

```
f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
```

2. Input gate (what do we store?)

```
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
```

3. Candidate cell (new information)

```
c̃_t = tanh(W_c · [h_{t-1}, x_t] + b_c)
```

4. Cell state update

```
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
      ↑ old information  ↑ new information
```

5. Output gate (what do we reveal?)

```
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
```

6. Hidden state

```
h_t = o_t ⊙ tanh(c_t)
```

σ: sigmoid (0~1), ⊙: element-wise product
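These equations map one-to-one onto code. A minimal NumPy sketch of a single forward step (function and weight names are illustrative, not the repo's actual API):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    hx = np.concatenate([h_prev, x_t])       # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ hx + b_f)            # forget gate
    i_t = sigmoid(W_i @ hx + b_i)            # input gate
    c_tilde = np.tanh(W_c @ hx + b_c)        # candidate cell
    c_t = f_t * c_prev + i_t * c_tilde       # cell state update
    o_t = sigmoid(W_o @ hx + b_o)            # output gate
    h_t = o_t * np.tanh(c_t)                 # hidden state
    return h_t, c_t

# Shape check with input_size=4, hidden_size=8
I, H = 4, 8
rng = np.random.default_rng(0)
Ws = [rng.normal(0, 0.1, (H, H + I)) for _ in range(4)]
bs = [np.zeros(H) for _ in range(4)]
h, c = lstm_step(rng.normal(size=I), np.zeros(H), np.zeros(H), *Ws, *bs)
print(h.shape, c.shape)                      # (8,) (8,)
```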
### 3. GRU Equations

GRU is a simplified variant of LSTM (no separate cell state).

1. Reset gate (how much of the previous state do we ignore?)

```
r_t = σ(W_r · [h_{t-1}, x_t] + b_r)
```

2. Update gate (ratio of old to new information)

```
z_t = σ(W_z · [h_{t-1}, x_t] + b_z)
```

3. Candidate hidden state

```
h̃_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t] + b_h)
```

4. Hidden state

```
h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t
      ↑ keep old information ↑ new information
```
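And the matching GRU step, again as an illustrative NumPy sketch:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_r, W_z, W_h, b_r, b_z, b_h):
    hx = np.concatenate([h_prev, x_t])           # [h_{t-1}, x_t]
    r_t = sigmoid(W_r @ hx + b_r)                # reset gate
    z_t = sigmoid(W_z @ hx + b_z)                # update gate
    rhx = np.concatenate([r_t * h_prev, x_t])    # [r_t ⊙ h_{t-1}, x_t]
    h_tilde = np.tanh(W_h @ rhx + b_h)           # candidate hidden state
    return (1 - z_t) * h_prev + z_t * h_tilde    # interpolate old vs. new
```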
LSTM vs GRU:

- GRU: 2 gates (reset, update)
- LSTM: 3 gates (forget, input, output) plus a cell state
- GRU has ~25% fewer parameters
- Performance is similar; which wins depends on the task
### 4. Why Is the Gradient Preserved?

LSTM cell state update:

```
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
```

Gradient along the cell-state path:

```
∂c_t/∂c_{t-1} = f_t   (the forget gate)
```

When f_t ≈ 1, the gradient propagates almost unchanged. This is the cell state's "highway" role:

- The cell state can flow across time steps without being transformed
- Gradients survive even over long sequences
- The model learns f_t, i.e., it decides which information to keep (see the sketch below)
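A tiny numerical illustration: along the cell-state path the per-step Jacobian is just diag(f_t), so after T steps the gradient scale is the product of the forget gates:

```python
T = 100
for f in (0.99, 0.9, 0.5):
    print(f"f_t = {f}: gradient scale after {T} steps ≈ {f**T:.2e}")
# f_t = 0.99: ≈ 3.66e-01  (mostly preserved)
# f_t = 0.9:  ≈ 2.66e-05
# f_t = 0.5:  ≈ 7.89e-31  (vanished, like a vanilla RNN)
```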
## Architecture

### LSTM Structure Diagram

```
             ┌───────────────────────────────────┐
             │           Cell State c_t          │
c_{t-1} ────►│───⊙────────────(+)───────────────►│────► c_t
             │   │ forget       ↑ input          │
             │  f_t         i_t ⊙ c̃_t            │
             │                                   │
             │   ┌───────────────────────┐       │
             │   │   σ     σ    tanh   σ │       │
             │   │   f     i     c̃    o │       │
             │   └───────────────────────┘       │
             │             │                     │
             │       [h_{t-1}, x_t]              │
h_{t-1} ────►│                       ┌──────────►│────► h_t
             │   o_t ⊙ tanh(c_t) ────┘           │
             └───────────────────────────────────┘
                            ▲
                           x_t
```
### GRU Structure Diagram

```
             ┌───────────────────────────────────┐
             │                                   │
h_{t-1} ────►│───⊙────────────(+)───────────────►│────► h_t
             │   │ (1-z)        ↑ z ⊙ h̃          │
             │                                   │
             │   ┌───────────────────────┐       │
             │   │    σ      σ     tanh  │       │
             │   │    r      z      h̃   │       │
             │   └───────────────────────┘       │
             │             │                     │
             │       [h_{t-1}, x_t]              │
             │       [r ⊙ h_{t-1}, x_t]          │
             └───────────────────────────────────┘
                            ▲
                           x_t
```
### Parameter Count

LSTM (4 gate blocks):

```
4 × (input_size × hidden_size + hidden_size × hidden_size + hidden_size)
= 4 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
= 4 × (128 + 256 + 1) × 256 = 394,240
```

GRU (3 gate blocks):

```
3 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
= 3 × (128 + 256 + 1) × 256 = 295,680 (25% fewer)
```
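A quick sanity check with PyTorch. Note that nn.LSTM keeps two bias vectors per gate block (b_ih and b_hh), so its count comes out slightly higher than the single-bias formula above, but the 25% ratio holds exactly:

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=128, hidden_size=256)
gru = nn.GRU(input_size=128, hidden_size=256)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(lstm))   # 395264 = 4 × (128·256 + 256·256 + 2·256)
print(count(gru))    # 296448, exactly 25% fewer than the LSTM
```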
## File Structure

```
06_LSTM_GRU/
├── README.md                    # this file
├── numpy/
│   ├── lstm_numpy.py            # NumPy LSTM (forward + backward)
│   └── gru_numpy.py             # NumPy GRU
├── pytorch_lowlevel/
│   └── lstm_gru_lowlevel.py     # uses F.linear, not nn.LSTM
├── paper/
│   ├── lstm_paper.py            # original 1997 paper implementation
│   └── gru_paper.py             # 2014 paper implementation
└── exercises/
    ├── 01_gradient_flow.md      # BPTT gradient analysis
    └── 02_sequence_tasks.md     # sequence classification/generation
```
## Core Concepts

### 1. Role of the Gates

Forget gate (f):

- Close to 1: keep the previous information
- Close to 0: erase the previous information
- Example: reset the previous context at the start of a new sentence

Input gate (i):

- Decides how important the new information is
- Works together with the candidate (c̃)

Output gate (o):

- Decides which part of the cell state is exposed as the hidden state
- Example: remember something internally without emitting it
### 2. Peephole Connections (Optional)

Basic LSTM: the gates see only [h_{t-1}, x_t].
Peephole LSTM: the gates also see the cell state.

```
f_t = σ(W_f · [h_{t-1}, x_t] + W_{cf} · c_{t-1} + b_f)
i_t = σ(W_i · [h_{t-1}, x_t] + W_{ci} · c_{t-1} + b_i)
o_t = σ(W_o · [h_{t-1}, x_t] + W_{co} · c_t + b_o)
```

Effect: gate decisions can use cell-state information directly.
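A minimal sketch of the peephole gates (names illustrative). The peephole weights are commonly taken to be diagonal, so they appear here as vectors w_cf, w_ci, w_co applied element-wise:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def peephole_gates(hx, c_prev, c_t,
                   W_f, W_i, W_o, w_cf, w_ci, w_co, b_f, b_i, b_o):
    """hx = [h_{t-1}, x_t]; diagonal peepholes -> element-wise products."""
    f_t = sigmoid(W_f @ hx + w_cf * c_prev + b_f)   # forget gate sees c_{t-1}
    i_t = sigmoid(W_i @ hx + w_ci * c_prev + b_i)   # input gate sees c_{t-1}
    o_t = sigmoid(W_o @ hx + w_co * c_t + b_o)      # output gate sees the new c_t
    return f_t, i_t, o_t
```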
### 3. Bidirectional LSTM

Process the sequence in both directions:

```
Forward:  → h_1 → h_2 → h_3 → h_4 →
Backward: ← h_4 ← h_3 ← h_2 ← h_1 ←
```

Output: [forward_h_t; backward_h_t] (concatenated)

Advantages (see the sketch below):

- Future context is also available
- Effective for NER and POS tagging
- The NLP standard before the Transformer era
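A short PyTorch sketch: with bidirectional=True the forward and backward hidden states are concatenated, doubling the output feature dimension:

```python
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=128, hidden_size=256, bidirectional=True)
x = torch.randn(50, 8, 128)    # (seq_len, batch, input_size)
out, (h_n, c_n) = bilstm(x)
print(out.shape)               # torch.Size([50, 8, 512]) = [forward; backward]
```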
### 4. Stacked LSTM

Stack several LSTM layers:

```
x_t → LSTM_1 → h_t^1 → LSTM_2 → h_t^2 → ... → output
```

Each layer:

- Takes the previous layer's hidden states as input
- Learns increasingly abstract representations

Caution: deeper stacks are harder to train (see the sketch below).

- Dropout is needed, especially between layers
- Residual connections help
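A short PyTorch sketch of a stacked LSTM; note that the dropout argument is applied between layers only, matching the caution above:

```python
import torch
import torch.nn as nn

# dropout=0.2 acts between layers (not after the last layer)
stacked = nn.LSTM(input_size=128, hidden_size=256, num_layers=3, dropout=0.2)
x = torch.randn(50, 8, 128)
out, (h_n, c_n) = stacked(x)
print(out.shape, h_n.shape)    # [50, 8, 256] and [3, 8, 256] (one h per layer)
```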
## Implementation Levels

### Level 1: NumPy From Scratch (numpy/)

- All gate operations implemented directly
- BPTT gradients computed by hand
- Cell-state gradient derivation

### Level 2: PyTorch Low-Level (pytorch_lowlevel/)

- Uses F.linear, torch.sigmoid, torch.tanh
- No nn.LSTM
- Manual parameter management
- Bidirectional and stacked variants

### Level 3: Paper Implementation (paper/)

- Hochreiter & Schmidhuber (1997) LSTM
- Cho et al. (2014) GRU
- Peephole connections
## Learning Checklist

- [ ] Understand the vanishing gradient problem in vanilla RNNs
- [ ] Memorize the four LSTM gate equations
- [ ] Memorize the three GRU gate equations
- [ ] Explain why the cell state preserves gradients
- [ ] Explain the role of each gate
- [ ] Know the trade-offs between LSTM and GRU
- [ ] Implement BPTT
- [ ] Implement bidirectional and stacked architectures
## References

- Hochreiter & Schmidhuber (1997). "Long Short-Term Memory"
- Cho et al. (2014). "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
- colah's blog: Understanding LSTM Networks
- d2l.ai: Long Short-Term Memory (LSTM)
- ../02_MLP/README.md