# 15. LSTM / GRU
## Overview
LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are Recurrent Neural Network (RNN) variants designed to mitigate the vanishing gradient problem. Through gating mechanisms, they learn long-term dependencies far more effectively than a vanilla RNN.
## Mathematical Background

### 1. Vanilla RNN Problem

Vanilla RNN:

```
h_t = tanh(W_h · h_{t-1} + W_x · x_t + b)
```
Problem: Backpropagation Through Time (BPTT)

```
∂L/∂h_0 = ∂L/∂h_T · ∂h_T/∂h_{T-1} · ... · ∂h_1/∂h_0
        = ∂L/∂h_T · ∏_{t=1}^{T} ∂h_t/∂h_{t-1}

∂h_t/∂h_{t-1} = diag(1 - tanh²(·)) · W_h
```

Result:

- |eigenvalue(W_h)| < 1 → vanishing gradient
- |eigenvalue(W_h)| > 1 → exploding gradient

→ Information from early time steps cannot be learned in long sequences
### 2. LSTM Equations

Input: x_t (current input), h_{t-1} (previous hidden), c_{t-1} (previous cell)
Output: h_t (current hidden), c_t (current cell)

```
1. Forget gate (what to discard?)
   f_t = σ(W_f · [h_{t-1}, x_t] + b_f)

2. Input gate (what to store?)
   i_t = σ(W_i · [h_{t-1}, x_t] + b_i)

3. Candidate cell (new information)
   c̃_t = tanh(W_c · [h_{t-1}, x_t] + b_c)

4. Cell state update
   c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
         (previous info)  (new info)

5. Output gate (what to output?)
   o_t = σ(W_o · [h_{t-1}, x_t] + b_o)

6. Hidden state
   h_t = o_t ⊙ tanh(c_t)
```

σ: sigmoid (output in (0, 1))
⊙: element-wise multiplication
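The six steps above translate almost line-for-line into NumPy. A minimal forward-step sketch (the stacked-weight layout and all names are illustrative, not the actual API of the repo's lstm_numpy.py):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step. W: (4H, H+I), gates stacked [W_f; W_i; W_c; W_o] over [h, x]."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0*H:1*H])         # 1. forget gate
    i = sigmoid(z[1*H:2*H])         # 2. input gate
    c_tilde = np.tanh(z[2*H:3*H])   # 3. candidate cell
    o = sigmoid(z[3*H:4*H])         # 5. output gate
    c = f * c_prev + i * c_tilde    # 4. cell state update
    h = o * np.tanh(c)              # 6. hidden state
    return h, c

# Tiny smoke test with zero initial states
rng = np.random.default_rng(0)
I, H = 3, 4
W = rng.standard_normal((4 * H, H + I)) * 0.1
b = np.zeros(4 * H)
h, c = lstm_step(rng.standard_normal(I), np.zeros(H), np.zeros(H), W, b)
print(h.shape, c.shape)  # (4,) (4,)
```

Stacking all four gate weights into one matrix lets a single matrix multiply compute every gate pre-activation at once, which is also how fast implementations are organized.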
### 3. GRU Equations

GRU: a simplified LSTM (no separate cell state)

```
1. Reset gate (how much previous info to ignore?)
   r_t = σ(W_r · [h_{t-1}, x_t] + b_r)

2. Update gate (ratio of previous vs. new info)
   z_t = σ(W_z · [h_{t-1}, x_t] + b_z)

3. Candidate hidden
   h̃_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t] + b_h)

4. Hidden state
   h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t
         (keep previous)       (new info)
```

LSTM vs. GRU:

- GRU: 2 gates (reset, update)
- LSTM: 3 gates (forget, input, output) + cell state
- GRU has 25% fewer parameters
- Performance is often comparable, depending on the task
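The four GRU steps can be sketched the same way (weight shapes and names are illustrative, not the repo's gru_numpy.py API):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_r, W_z, W_h, b_r, b_z, b_h):
    """One GRU step; each W has shape (H, H+I) over the input [h, x]."""
    hx = np.concatenate([h_prev, x_t])
    r = sigmoid(W_r @ hx + b_r)   # 1. reset gate
    z = sigmoid(W_z @ hx + b_z)   # 2. update gate
    # 3. candidate hidden: the reset gate scales h_prev before mixing with x_t
    h_tilde = np.tanh(W_h @ np.concatenate([r * h_prev, x_t]) + b_h)
    # 4. hidden state: interpolate between previous state and candidate
    return (1 - z) * h_prev + z * h_tilde

rng = np.random.default_rng(0)
I, H = 3, 4
Ws = [rng.standard_normal((H, H + I)) * 0.1 for _ in range(3)]
bs = [np.zeros(H) for _ in range(3)]
h = gru_step(rng.standard_normal(I), np.zeros(H), *Ws, *bs)
print(h.shape)  # (4,)
```

Note there is no cell state: the update gate z_t plays the combined role of the LSTM's forget and input gates.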
### 4. Why the Gradient Is Preserved

LSTM cell state update:

```
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
```

Gradient (ignoring the indirect paths through the gates):

```
∂c_t/∂c_{t-1} ≈ f_t (forget gate)
```

If f_t ≈ 1, the gradient propagates almost unchanged!

This acts as a "highway":

- The cell state can flow without a nonlinear transformation
- The gradient is maintained even over long sequences
- The model learns f_t to decide what information to retain
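The contrast can be made concrete with scalars: holding the gates fixed, ∂c_T/∂c_0 is just the product of the forget gates, while the vanilla RNN gradient picks up a factor w · tanh′ per step. The gate and weight values below are illustrative:

```python
import numpy as np

T = 100
f = np.full(T, 0.99)             # forget gates close to 1
grad_lstm = np.prod(f)           # ∂c_T/∂c_0 ≈ ∏ f_t (gates held fixed)

w, tanh_deriv = 0.9, 0.5         # typical vanilla RNN per-step factors
grad_rnn = (w * tanh_deriv) ** T

print(grad_lstm)  # ≈ 0.366: the gradient survives 100 steps
print(grad_rnn)   # ≈ 2e-35: the gradient has vanished
```

The additive cell update is what makes the difference: multiplication by f_t ≈ 1 replaces multiplication by a contractive Jacobian.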
## Architecture

### LSTM Structure Diagram

```
              ┌──────────────────────────────────┐
              │            Cell State            │
c_{t-1} ─────►│──⊙─────────────(+)──────────────►│────► c_t
              │  │ forget       │ input          │
              │ f_t        i_t ⊙ c̃_t             │
              │                                  │
              │   ┌────────────────────────┐     │
              │   │   σ     σ    tanh   σ  │     │
              │   │   f     i     c̃     o  │     │
              │   └────────────────────────┘     │
              │             │                    │
              │       [h_{t-1}, x_t]             │
h_{t-1} ─────►│─────────────┘                    │────► h_t
              │        h_t = o_t ⊙ tanh(c_t)     │
              └──────────────────────────────────┘
                            ▲
                            x_t
```
### GRU Structure Diagram

```
              ┌──────────────────────────────────┐
h_{t-1} ─────►│──⊙─────────────(+)──────────────►│────► h_t
              │  │ (1-z)        │ z ⊙ h̃          │
              │                                  │
              │   ┌────────────────────────┐     │
              │   │     σ      σ     tanh  │     │
              │   │     r      z      h̃    │     │
              │   └────────────────────────┘     │
              │             │                    │
              │       [h_{t-1}, x_t]             │
              │       [r ⊙ h_{t-1}, x_t]         │
              └──────────────────────────────────┘
                            ▲
                            x_t
```
### Parameter Count

LSTM:

```
4 gates × (input_size × hidden_size + hidden_size × hidden_size + hidden_size)
= 4 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
= 4 × (128 + 256 + 1) × 256 = 394,240
```

GRU:

```
3 gates
= 3 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
= 3 × (128 + 256 + 1) × 256 = 295,680 (25% fewer)
```
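A sketch that reproduces the counts above (note this follows the single-bias convention used here; frameworks such as PyTorch's nn.LSTM keep two bias vectors per gate stack, so their counts are slightly higher):

```python
def rnn_params(input_size, hidden_size, n_gates):
    # Each gate block: a weight over [h, x] plus one bias vector
    return n_gates * (input_size + hidden_size + 1) * hidden_size

lstm = rnn_params(128, 256, n_gates=4)  # f, i, c̃, o
gru = rnn_params(128, 256, n_gates=3)   # r, z, h̃
print(lstm)            # 394240
print(gru)             # 295680
print(1 - gru / lstm)  # 0.25 -> 25% fewer
```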
## File Structure

```
06_LSTM_GRU/
├── README.md                     # This file
├── numpy/
│   ├── lstm_numpy.py             # NumPy LSTM (forward + backward)
│   └── gru_numpy.py              # NumPy GRU
├── pytorch_lowlevel/
│   └── lstm_gru_lowlevel.py      # Using F.linear, not nn.LSTM
├── paper/
│   ├── lstm_paper.py             # Original 1997 paper implementation
│   └── gru_paper.py              # 2014 paper implementation
└── exercises/
    ├── 01_gradient_flow.md       # BPTT gradient analysis
    └── 02_sequence_tasks.md      # Sequence classification/generation
```
## Core Concepts

### 1. Role of Gates

Forget gate (f):

- Close to 1: retain previous information
- Close to 0: discard previous information
- Example: reset the previous context when a new sentence starts

Input gate (i):

- Determines how important the new information is
- Works together with the candidate (c̃)

Output gate (o):

- Decides what part of the cell state to expose as the hidden state
- Example: remember internally but don't output
### 2. Peephole Connections (Optional)

Basic LSTM: gates only see [h_{t-1}, x_t]
Peephole: gates also see the cell state

```
f_t = σ(W_f · [h_{t-1}, x_t] + W_{cf} · c_{t-1} + b_f)
i_t = σ(W_i · [h_{t-1}, x_t] + W_{ci} · c_{t-1} + b_i)
o_t = σ(W_o · [h_{t-1}, x_t] + W_{co} · c_t + b_o)
```

Effect: gate decisions can use the cell state information directly
### 3. Bidirectional LSTM

Process the sequence in both directions:

```
Forward:  → h_1 → h_2 → h_3 → h_4 →
Backward: ← h_4 ← h_3 ← h_2 ← h_1 ←

Output: [forward_h_t; backward_h_t] (concatenated)
```

Advantages:

- Uses future context as well
- Effective for NER and POS tagging
- The standard in NLP before the Transformer
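The concatenation scheme above can be sketched with any recurrent step; here a plain tanh RNN step stands in for the LSTM cell (shapes and names are illustrative):

```python
import numpy as np

def run_rnn(xs, W, H):
    """Run a tanh RNN over xs of shape (T, I); returns hidden states (T, H)."""
    h = np.zeros(H)
    out = []
    for x in xs:
        h = np.tanh(W @ np.concatenate([h, x]))
        out.append(h)
    return np.array(out)

rng = np.random.default_rng(0)
T, I, H = 5, 3, 4
xs = rng.standard_normal((T, I))
W_fwd = rng.standard_normal((H, H + I)) * 0.1
W_bwd = rng.standard_normal((H, H + I)) * 0.1

h_fwd = run_rnn(xs, W_fwd, H)               # processes x_1 .. x_T
h_bwd = run_rnn(xs[::-1], W_bwd, H)[::-1]   # processes x_T .. x_1, then re-aligned
h_bi = np.concatenate([h_fwd, h_bwd], axis=1)  # [forward_h_t; backward_h_t]
print(h_bi.shape)  # (5, 8)
```

After the flip-back, h_bi[t] combines a summary of x_1..x_t with a summary of x_t..x_T, which is exactly why future context becomes available at every position.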
### 4. Stacked LSTM

Stack multiple LSTM layers:

```
x_t → LSTM_1 → h_t^1 → LSTM_2 → h_t^2 → ... → output
```

Each layer:

- Takes the previous layer's hidden states as input
- Learns progressively more abstract representations

Caution: training gets harder with depth

- Dropout is essential (especially between layers)
- Residual connections help
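Stacking amounts to feeding one layer's hidden sequence into the next. A sketch with a stand-in tanh step, including inverted dropout between layers as recommended above (all sizes and the dropout rate are illustrative):

```python
import numpy as np

def run_layer(xs, W, H):
    """Run one recurrent layer over xs of shape (T, in_size)."""
    h = np.zeros(H)
    out = []
    for x in xs:
        h = np.tanh(W @ np.concatenate([h, x]))
        out.append(h)
    return np.array(out)

rng = np.random.default_rng(0)
T, I, H, n_layers, p_drop = 5, 3, 4, 2, 0.5

seq = rng.standard_normal((T, I))
for layer in range(n_layers):
    in_size = I if layer == 0 else H       # layer l > 0 consumes h_t^{l-1}
    W = rng.standard_normal((H, H + in_size)) * 0.1
    seq = run_layer(seq, W, H)
    if layer < n_layers - 1:               # dropout between layers only
        mask = rng.random(seq.shape) > p_drop
        seq = seq * mask / (1 - p_drop)    # inverted dropout (train-time scaling)

print(seq.shape)  # (5, 4)
```

Dropout is applied only between layers, never on the recurrent h→h path within a layer, which would disturb the state from step to step.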
## Implementation Levels

### Level 1: NumPy From Scratch (numpy/)

- Implement all gate operations directly
- Compute BPTT gradients manually
- Derive the cell-state gradient

### Level 2: PyTorch Low-Level (pytorch_lowlevel/)

- Use F.linear, torch.sigmoid, torch.tanh
- Do not use nn.LSTM
- Manage parameters manually
- Implement bidirectional and stacked variants

### Level 3: Paper Implementations (paper/)

- Hochreiter & Schmidhuber (1997) LSTM
- Cho et al. (2014) GRU
- Peephole connections
## Learning Checklist

- [ ] Explain the vanilla RNN vanishing gradient problem
- [ ] Write the LSTM equations from memory
- [ ] Write the GRU equations from memory
- [ ] Explain why the cell state preserves the gradient
- [ ] Explain the role of each gate
- [ ] Compare LSTM vs. GRU pros and cons
- [ ] Implement BPTT
- [ ] Implement bidirectional and stacked structures
## References

- Hochreiter & Schmidhuber (1997). "Long Short-Term Memory"
- Cho et al. (2014). "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
- colah's blog: "Understanding LSTM Networks"
- d2l.ai: LSTM
- ../02_MLP/README.md