15. LSTM / GRU



Overview

LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are Recurrent Neural Network (RNN) variants designed to mitigate the vanishing gradient problem. Their gating mechanisms let them learn long-term dependencies that a vanilla RNN cannot.


Mathematical Background

1. Vanilla RNN Problem

Vanilla RNN:
  h_t = tanh(W_h Β· h_{t-1} + W_x Β· x_t + b)

The problem appears during Backpropagation Through Time (BPTT):

βˆ‚L/βˆ‚h_0 = βˆ‚L/βˆ‚h_T Β· βˆ‚h_T/βˆ‚h_{T-1} Β· ... Β· βˆ‚h_1/βˆ‚h_0
        = βˆ‚L/βˆ‚h_T Β· Ξ _{t=1}^{T} βˆ‚h_t/βˆ‚h_{t-1}

βˆ‚h_t/βˆ‚h_{t-1} = diag(1 - tanhΒ²(Β·)) Β· W_h

Result (repeated multiplication by W_h, with |tanh'| ≤ 1):
- largest singular value of W_h < 1 → vanishing gradient
- largest singular value of W_h > 1 → exploding gradient possible

→ Information from early time steps cannot influence learning in long sequences
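The decay is easy to observe numerically. The sketch below builds a toy vanilla RNN whose recurrent matrix has largest singular value 0.9 (an assumed setup for illustration), accumulates the BPTT Jacobian product Π ∂h_t/∂h_{t-1}, and watches its norm collapse:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 16
W_h = rng.standard_normal((hidden, hidden))
W_h *= 0.9 / np.linalg.svd(W_h, compute_uv=False)[0]  # largest singular value = 0.9

h = rng.standard_normal(hidden)
grad = np.eye(hidden)  # accumulates Π ∂h_t/∂h_{t-1}
norms = []
for t in range(50):
    h = np.tanh(W_h @ h)                 # vanilla RNN update (no input, for brevity)
    jac = np.diag(1.0 - h**2) @ W_h      # ∂h_t/∂h_{t-1} = diag(1 - tanh²) · W_h
    grad = jac @ grad
    norms.append(np.linalg.norm(grad))

print(norms[0], norms[-1])  # the norm decays toward 0
```

With the singular value above 1 instead, the same loop shows the norm blowing up.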

2. LSTM Equations

Input: x_t (current input), h_{t-1} (previous hidden), c_{t-1} (previous cell)
Output: h_t (current hidden), c_t (current cell)

1. Forget Gate (what to discard?)
   f_t = Οƒ(W_f Β· [h_{t-1}, x_t] + b_f)

2. Input Gate (what to store?)
   i_t = Οƒ(W_i Β· [h_{t-1}, x_t] + b_i)

3. Candidate Cell (new information)
   c̃_t = tanh(W_c · [h_{t-1}, x_t] + b_c)

4. Cell State Update
   c_t = f_t βŠ™ c_{t-1} + i_t βŠ™ cΜƒ_t
        ↑ previous info   ↑ new info

5. Output Gate (what to output?)
   o_t = Οƒ(W_o Β· [h_{t-1}, x_t] + b_o)

6. Hidden State
   h_t = o_t βŠ™ tanh(c_t)

σ: sigmoid, squashing values into (0, 1)
⊙: element-wise (Hadamard) multiplication
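The six equations translate almost line-for-line into NumPy. This is a minimal forward-only sketch (the full forward + backward version lives in numpy/lstm_numpy.py); stacking the four weight blocks into one matrix, in f/i/c̃/o order, is a convention chosen here for brevity:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell(x_t, h_prev, c_prev, W, b):
    """One LSTM step. W: (4*hidden, input+hidden), b: (4*hidden,).
    Rows of W hold the f, i, c-tilde, o blocks in that order."""
    hidden = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0*hidden:1*hidden])          # 1. forget gate
    i = sigmoid(z[1*hidden:2*hidden])          # 2. input gate
    c_tilde = np.tanh(z[2*hidden:3*hidden])    # 3. candidate cell
    c_t = f * c_prev + i * c_tilde             # 4. cell state update
    o = sigmoid(z[3*hidden:4*hidden])          # 5. output gate
    h_t = o * np.tanh(c_t)                     # 6. hidden state
    return h_t, c_t

# usage: run a short random sequence through the cell
rng = np.random.default_rng(0)
input_size, hidden = 4, 8
W = rng.standard_normal((4 * hidden, input_size + hidden)) * 0.1
b = np.zeros(4 * hidden)
h, c = np.zeros(hidden), np.zeros(hidden)
for x_t in rng.standard_normal((5, input_size)):  # 5 time steps
    h, c = lstm_cell(x_t, h, c, W, b)
print(h.shape, c.shape)  # (8,) (8,)
```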

3. GRU Equations

GRU: a simplified LSTM variant with no separate cell state (the hidden state plays both roles)

1. Reset Gate (how much to ignore previous info?)
   r_t = Οƒ(W_r Β· [h_{t-1}, x_t] + b_r)

2. Update Gate (ratio of previous vs new info)
   z_t = Οƒ(W_z Β· [h_{t-1}, x_t] + b_z)

3. Candidate Hidden
   hΜƒ_t = tanh(W_h Β· [r_t βŠ™ h_{t-1}, x_t] + b_h)

4. Hidden State
   h_t = (1 - z_t) βŠ™ h_{t-1} + z_t βŠ™ hΜƒ_t
        ↑ keep previous      ↑ new info

LSTM vs GRU:
- GRU: 2 gates (reset, update), no separate cell state
- LSTM: 3 gates (forget, input, output) + cell state
- GRU has 25% fewer recurrent parameters (3 weight blocks vs 4)
- Performance is often comparable; which works better depends on the task
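Mirroring the LSTM sketch, here is a forward-only GRU step in NumPy (the parameter-dict layout is this sketch's convention; the full version lives in numpy/gru_numpy.py):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell(x_t, h_prev, params):
    """One GRU step: one weight matrix and bias per gate."""
    hx = np.concatenate([h_prev, x_t])
    r = sigmoid(params["W_r"] @ hx + params["b_r"])   # 1. reset gate
    z = sigmoid(params["W_z"] @ hx + params["b_z"])   # 2. update gate
    h_tilde = np.tanh(params["W_h"] @ np.concatenate([r * h_prev, x_t])
                      + params["b_h"])                # 3. candidate hidden
    return (1.0 - z) * h_prev + z * h_tilde           # 4. interpolate old/new

# usage
rng = np.random.default_rng(0)
input_size, hidden = 4, 8
params = {k: rng.standard_normal((hidden, input_size + hidden)) * 0.1
          for k in ("W_r", "W_z", "W_h")}
params.update({k: np.zeros(hidden) for k in ("b_r", "b_z", "b_h")})
h = np.zeros(hidden)
for x_t in rng.standard_normal((5, input_size)):
    h = gru_cell(x_t, h, params)
print(h.shape)  # (8,)
```

Note that the final line is a per-element convex combination: z_t decides, element by element, how much of the old hidden state survives.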

4. Why Gradient is Preserved

LSTM Cell State Update:
  c_t = f_t βŠ™ c_{t-1} + i_t βŠ™ cΜƒ_t

Gradient along the cell path (treating the gate values as constants):
  ∂c_t/∂c_{t-1} = diag(f_t)  (forget gate)

If f_t ≈ 1, the gradient propagates almost unchanged!

This acts as a "highway":
- Cell state can flow without transformation
- Gradient maintained even in long sequences
- Model learns f_t to decide what information to retain
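The highway effect is just a product of gate values: along the cell path, ∂c_T/∂c_0 ≈ Π f_t element-wise. A two-line comparison (with illustrative gate values) shows why keeping f_t near 1 matters:

```python
import numpy as np

T = 100
f_open  = np.full(T, 0.99)   # forget gate held near 1 for 100 steps
f_mixed = np.full(T, 0.5)    # forget gate at 0.5

print(np.prod(f_open))   # ~0.37: gradient survives 100 steps
print(np.prod(f_mixed))  # ~8e-31: gradient effectively gone
```

This is exactly the contrast with the vanilla RNN, where every step forcibly multiplies by diag(1 - tanh²) · W_h whether or not the information should be kept.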

Architecture

LSTM Structure Diagram

                      β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
                      β”‚           Cell State c_t         β”‚
c_{t-1} ─────────────►│   βŠ™β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€+────────────►  c_t
                      β”‚   ↑ forget     ↑ input           β”‚
                      β”‚   f_t        i_t βŠ™ cΜƒ_t          β”‚
                      β”‚                                  β”‚
                      β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
                      β”‚   β”‚  Οƒ   Οƒ   tanh   Οƒ   β”‚       β”‚
                      │   │  f   i    c̃    o   │       │
                      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚
                      β”‚         ↑                        β”‚
                      β”‚    [h_{t-1}, x_t]                β”‚
h_{t-1} ─────────────►│                                  β”œβ”€β”€β”€β–Ί h_t
                      β”‚                 βŠ™ ◄── tanh(c_t)  β”‚
                      β”‚                 o_t              β”‚
                      β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                    ↑
                                   x_t

GRU Structure Diagram

                      β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
                      β”‚                                  β”‚
h_{t-1} ─────────────►│ βŠ™β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€+────────────────►│ h_t
                      β”‚ (1-z)        z βŠ™ hΜƒ             β”‚
                      β”‚              ↑                   β”‚
                      β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
                      β”‚   β”‚    Οƒ   Οƒ   tanh     β”‚       β”‚
                      │   │    r   z    h̃      │       │
                      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚
                      β”‚         ↑                        β”‚
                      β”‚    [h_{t-1}, x_t]                β”‚
                      β”‚    [rβŠ™h_{t-1}, x_t]              β”‚
                      β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                    ↑
                                   x_t

Parameter Count

LSTM (4 weight blocks: f, i, c̃, o):
  4 × (input_size × hidden_size + hidden_size × hidden_size + hidden_size)
  = 4 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
  = 4 Γ— (128 + 256 + 1) Γ— 256 = 394,240

GRU (3 weight blocks: r, z, h̃):
  = 3 × (input_size + hidden_size + 1) × hidden_size

Example: input=128, hidden=256
  = 3 Γ— (128 + 256 + 1) Γ— 256 = 295,680  (25% less)
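Both formulas can be checked with a few lines of Python:

```python
def lstm_params(input_size, hidden_size):
    # 4 weight blocks, each: W_x (in*hid) + W_h (hid*hid) + bias (hid)
    return 4 * (input_size + hidden_size + 1) * hidden_size

def gru_params(input_size, hidden_size):
    # same shape per block, but only 3 blocks (r, z, h-tilde)
    return 3 * (input_size + hidden_size + 1) * hidden_size

print(lstm_params(128, 256))  # 394240
print(gru_params(128, 256))   # 295680
```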

File Structure

06_LSTM_GRU/
β”œβ”€β”€ README.md                      # This file
β”œβ”€β”€ numpy/
β”‚   β”œβ”€β”€ lstm_numpy.py             # NumPy LSTM (forward + backward)
β”‚   └── gru_numpy.py              # NumPy GRU
β”œβ”€β”€ pytorch_lowlevel/
β”‚   └── lstm_gru_lowlevel.py      # Using F.linear, not nn.LSTM
β”œβ”€β”€ paper/
β”‚   β”œβ”€β”€ lstm_paper.py             # Original 1997 paper implementation
β”‚   └── gru_paper.py              # 2014 paper implementation
└── exercises/
    β”œβ”€β”€ 01_gradient_flow.md       # BPTT gradient analysis
    └── 02_sequence_tasks.md      # Sequence classification/generation

Core Concepts

1. Role of Gates

Forget Gate (f):
- Close to 1: retain previous information
- Close to 0: discard previous information
- Example: reset previous context when new sentence starts

Input Gate (i):
- Determines importance of new information
- Works with Candidate (c̃)

Output Gate (o):
- What from cell state to expose as hidden
- Example: remember internally but don't output

2. Peephole Connection (Optional)

Basic LSTM: gates only reference [h_{t-1}, x_t]
Peephole: gates also reference c_{t-1}

f_t = Οƒ(W_f Β· [h_{t-1}, x_t] + W_{cf} Β· c_{t-1} + b_f)
i_t = Οƒ(W_i Β· [h_{t-1}, x_t] + W_{ci} Β· c_{t-1} + b_i)
o_t = Οƒ(W_o Β· [h_{t-1}, x_t] + W_{co} Β· c_t + b_o)

Effect: directly use cell state information in gate decisions
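A peephole forward step differs from the basic cell only in the three gate lines. The sketch below uses diagonal peephole weights (vectors w_cf, w_ci, w_co rather than full matrices), a common simplification; all parameter names are this sketch's convention:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def peephole_lstm_cell(x_t, h_prev, c_prev, P):
    """LSTM step with (diagonal) peephole connections."""
    hx = np.concatenate([h_prev, x_t])
    f = sigmoid(P["W_f"] @ hx + P["w_cf"] * c_prev + P["b_f"])  # peeks at c_{t-1}
    i = sigmoid(P["W_i"] @ hx + P["w_ci"] * c_prev + P["b_i"])  # peeks at c_{t-1}
    c_tilde = np.tanh(P["W_c"] @ hx + P["b_c"])
    c_t = f * c_prev + i * c_tilde
    o = sigmoid(P["W_o"] @ hx + P["w_co"] * c_t + P["b_o"])     # peeks at the NEW c_t
    return o * np.tanh(c_t), c_t

# usage
rng = np.random.default_rng(0)
input_size, hidden = 4, 8
P = {k: rng.standard_normal((hidden, input_size + hidden)) * 0.1
     for k in ("W_f", "W_i", "W_c", "W_o")}
P.update({k: np.zeros(hidden) for k in ("b_f", "b_i", "b_c", "b_o")})
P.update({k: rng.standard_normal(hidden) * 0.1 for k in ("w_cf", "w_ci", "w_co")})
h, c = peephole_lstm_cell(rng.standard_normal(input_size),
                          np.zeros(hidden), np.zeros(hidden), P)
print(h.shape, c.shape)  # (8,) (8,)
```

Note the ordering constraint: f and i must peek at the old cell state, while o peeks at the freshly computed c_t.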

3. Bidirectional LSTM

Process sequence in both directions:

Forward:  β†’ h_1 β†’ h_2 β†’ h_3 β†’ h_4 β†’
Backward: ← h_4 ← h_3 ← h_2 ← h_1 ←

Output: [forward_h_t; backward_h_t] (concatenate)

Advantages:
- Use future context too
- Effective for NER, POS tagging
- Standard in NLP before Transformer
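The bidirectional wrapper is independent of the cell type: run any recurrent step forward and backward, then concatenate per time step. A minimal sketch, using a toy tanh step as a stand-in for an LSTM/GRU cell:

```python
import numpy as np

def bidirectional(step, xs, hidden):
    """Run `step(x_t, h) -> h` in both directions, concat per time step."""
    h = np.zeros(hidden)
    fwd = []
    for x in xs:                 # left to right
        h = step(x, h)
        fwd.append(h)
    h = np.zeros(hidden)
    bwd = []
    for x in xs[::-1]:           # right to left
        h = step(x, h)
        bwd.append(h)
    bwd = bwd[::-1]              # re-align to forward time order
    return [np.concatenate([f, b]) for f, b in zip(fwd, bwd)]

# usage with a toy recurrent step (stand-in for lstm_cell / gru_cell)
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 8 + 4)) * 0.1
step = lambda x, h: np.tanh(W @ np.concatenate([h, x]))
out = bidirectional(step, rng.standard_normal((5, 4)), hidden=8)
print(len(out), out[0].shape)  # 5 (16,)
```

The output at each position is [forward_h_t; backward_h_t], so its width is 2 × hidden, exactly as described above.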

4. Stacked LSTM

Stack multiple LSTM layers:

x_t β†’ LSTM_1 β†’ h_t^1 β†’ LSTM_2 β†’ h_t^2 β†’ ... β†’ output

Each layer:
- Takes previous layer's hidden as input
- Learns more abstract representations

Caution: harder to train as it gets deeper
- Dropout essential (especially between layers)
- Residual connections help

Implementation Levels

Level 1: NumPy From-Scratch (numpy/)

  • Direct implementation of all gate operations
  • Manual BPTT gradient computation
  • Derive cell state gradient

Level 2: PyTorch Low-Level (pytorch_lowlevel/)

  • Use F.linear, torch.sigmoid, torch.tanh
  • Don't use nn.LSTM
  • Manual parameter management
  • Implement Bidirectional, Stacked

Level 3: Paper Implementation (paper/)

  • Hochreiter & Schmidhuber (1997) LSTM
  • Cho et al. (2014) GRU
  • Peephole connections

Learning Checklist

  • [ ] Vanilla RNN vanishing gradient problem
  • [ ] Memorize 4 LSTM equations
  • [ ] Memorize 3 GRU equations
  • [ ] Why cell state preserves gradient
  • [ ] Explain role of each gate
  • [ ] LSTM vs GRU pros/cons
  • [ ] BPTT implementation
  • [ ] Bidirectional, Stacked structures

References

- Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
- Cho, K., et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. EMNLP 2014.
- Gers, F. A., & Schmidhuber, J. (2000). Recurrent Nets that Time and Count. IJCNN 2000. (peephole connections)