14. LSTM and GRU¶
Learning Objectives¶
- Understand LSTM and GRU structures
- Learn gate mechanisms
- Learn long-term dependencies
- Implement with PyTorch
1. LSTM (Long Short-Term Memory)¶
Problem: RNN's Vanishing Gradient¶
h(100) ← W × W × ... × W × h(1)
                ↓
      Gradient converges to 0
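A quick way to see this numerically: repeatedly multiply a vector by the same small recurrent weight matrix and watch its norm collapse. This is a minimal sketch; the sizes and the 0.05 scale are arbitrary.

import torch

torch.manual_seed(0)
W = torch.randn(20, 20) * 0.05   # small weights: largest singular value < 1
g = torch.ones(20)               # stand-in for a gradient signal

for _ in range(100):
    g = W @ g                    # the same multiplication a 100-step RNN applies

print(g.norm())                  # ~0 (on the order of 1e-35 here): the signal has vanished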
Solution: Cell State¶
LSTM = Cell State (long-term memory) + Hidden State (short-term memory)
LSTM Structure¶
┌─────────────────────────────────────────────┐
│               Cell State (C)                │
│    ×─────(+)────────────────────────────►   │
│    │      │                                 │
│ forget  input                               │
│  gate    gate                               │
│    │      │                                 │
h(t-1)─┴──►[σ]  [σ][tanh]  [σ]───►×────────►h(t)
           f(t) i(t) g(t)  o(t)   │
                              output gate
Gate Formulas¶
# Forget Gate: How much to forget from previous memory
f(t) = σ(W_f × [h(t-1), x(t)] + b_f)

# Input Gate: How much new information to store
i(t) = σ(W_i × [h(t-1), x(t)] + b_i)

# Cell Candidate: New candidate information
g(t) = tanh(W_g × [h(t-1), x(t)] + b_g)

# Cell State Update
C(t) = f(t) × C(t-1) + i(t) × g(t)

# Output Gate: How much of cell state to output
o(t) = σ(W_o × [h(t-1), x(t)] + b_o)

# Hidden State
h(t) = o(t) × tanh(C(t))
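These formulas translate almost line for line into code. A minimal single-step sketch with random placeholder weights (nn.LSTMCell is the built-in equivalent):

import torch

input_size, hidden_size = 10, 20
torch.manual_seed(0)

# One weight matrix and bias per gate, each acting on [h(t-1), x(t)]
W_f, W_i, W_g, W_o = [torch.randn(hidden_size, hidden_size + input_size) * 0.1
                      for _ in range(4)]
b_f, b_i, b_g, b_o = [torch.zeros(hidden_size) for _ in range(4)]

def lstm_step(x_t, h_prev, c_prev):
    hx = torch.cat([h_prev, x_t])       # [h(t-1), x(t)]
    f = torch.sigmoid(W_f @ hx + b_f)   # forget gate
    i = torch.sigmoid(W_i @ hx + b_i)   # input gate
    g = torch.tanh(W_g @ hx + b_g)      # cell candidate
    o = torch.sigmoid(W_o @ hx + b_o)   # output gate
    c = f * c_prev + i * g              # cell state update
    h = o * torch.tanh(c)               # hidden state
    return h, c

h, c = lstm_step(torch.randn(input_size),
                 torch.zeros(hidden_size), torch.zeros(hidden_size))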
2. GRU (Gated Recurrent Unit)¶
Simplified Version of LSTM¶
GRU = Reset Gate + Update Gate
(merges the cell state and hidden state into a single state)
GRU Structure¶
              Update Gate (z)
h(t-1)──┬──────────(1-z(t))×────(+)──►h(t)
        │                        ▲
        ├──►[σ]──►z(t)           │
        │                        │
        ├──►[σ]──►r(t)           │
        │          │             │
        └────×◄────┘             │
             │                   │
          [tanh]──►h̃(t)──×z(t)───┘
              Reset Gate (r)
Gate Formulas¶
# Update Gate: Ratio of previous state vs new state
z(t) = σ(W_z × [h(t-1), x(t)] + b_z)

# Reset Gate: How much to forget previous state
r(t) = σ(W_r × [h(t-1), x(t)] + b_r)

# Candidate Hidden
h̃(t) = tanh(W × [r(t) × h(t-1), x(t)] + b)

# Hidden State Update
h(t) = (1 - z(t)) × h(t-1) + z(t) × h̃(t)
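The GRU step in the same style, again with random placeholder weights (nn.GRUCell is the built-in equivalent; note that PyTorch's docs write the final interpolation with z(t) multiplying h(t-1), the opposite convention to the one above):

import torch

input_size, hidden_size = 10, 20
torch.manual_seed(0)

W_z, W_r, W_h = [torch.randn(hidden_size, hidden_size + input_size) * 0.1
                 for _ in range(3)]
b_z, b_r, b_h = [torch.zeros(hidden_size) for _ in range(3)]

def gru_step(x_t, h_prev):
    hx = torch.cat([h_prev, x_t])
    z = torch.sigmoid(W_z @ hx + b_z)   # update gate
    r = torch.sigmoid(W_r @ hx + b_r)   # reset gate
    # Candidate hidden: the reset gate scales h(t-1) before the tanh
    h_tilde = torch.tanh(W_h @ torch.cat([r * h_prev, x_t]) + b_h)
    # Interpolate between the previous state and the candidate
    return (1 - z) * h_prev + z * h_tilde

h = gru_step(torch.randn(input_size), torch.zeros(hidden_size))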
3. PyTorch LSTM/GRU¶
LSTM¶
import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True,
    dropout=0.1,        # applied between stacked layers (needs num_layers > 1)
    bidirectional=False,
)

# Forward pass
x = torch.randn(32, 5, 10)  # (batch, seq_len, input_size)
# output: hidden states at all time steps, (batch, seq_len, hidden_size)
# (h_n, c_n): last (hidden, cell) states, (num_layers, batch, hidden_size)
output, (h_n, c_n) = lstm(x)
GRU¶
gru = nn.GRU(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True,
)

# Forward pass (no cell state)
output, h_n = gru(x)
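One useful sanity check: for a unidirectional RNN, the last time step of output equals the top layer's final hidden state. A minimal sketch, redefining the same modules:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = torch.randn(32, 5, 10)    # (batch, seq_len, input_size)
output, h_n = gru(x)
# output: (32, 5, 20); h_n: (2, 32, 20), one slice per layer
assert torch.allclose(output[:, -1], h_n[-1])  # last step == top layer's final state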
4. LSTM Classifier¶
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )
        # Bidirectional, so hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq) - token indices
        embedded = self.embedding(x)
        # LSTM
        output, (h_n, c_n) = self.lstm(embedded)
        # Combine bidirectional last hidden states
        # h_n: (num_layers*2, batch, hidden)
        forward_last = h_n[-2]   # Forward, last layer
        backward_last = h_n[-1]  # Backward, last layer
        combined = torch.cat([forward_last, backward_last], dim=1)
        return self.fc(combined)
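A quick shape check, reusing the class above (the sizes here are made up: a vocabulary of 1000 and a batch of 32 sequences of 50 tokens):

model = LSTMClassifier(vocab_size=1000, embed_dim=64, hidden_dim=128, num_classes=2)
x = torch.randint(0, 1000, (32, 50))   # (batch, seq) token indices
logits = model(x)
print(logits.shape)                    # torch.Size([32, 2])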
5. Sequence Generation (Language Model)¶
import torch.nn.functional as F

class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, hidden)
        logits = self.fc(output)
        return logits, hidden

    def generate(self, start_token, max_len, temperature=1.0):
        self.eval()
        tokens = [start_token]
        hidden = None
        with torch.no_grad():
            for _ in range(max_len):
                x = torch.tensor([[tokens[-1]]])
                logits, hidden = self(x, hidden)
                # Temperature sampling: lower temperature, greedier choices
                probs = F.softmax(logits[0, -1] / temperature, dim=0)
                next_token = torch.multinomial(probs, 1).item()
                tokens.append(next_token)
        return tokens
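Sampling from an untrained toy model (the sizes and start_token=0 are arbitrary, and the output is random noise until the model is trained):

model = LSTMLanguageModel(vocab_size=100, embed_dim=32, hidden_dim=64)
tokens = model.generate(start_token=0, max_len=20, temperature=0.8)
print(tokens)   # 21 token ids; random for an untrained model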
6. LSTM vs GRU Comparison¶
| Item | LSTM | GRU |
|---|---|---|
| Number of Gates | 3 (f, i, o) | 2 (r, z) |
| States | Cell + Hidden | Hidden only |
| Parameters | More | Fewer |
| Training Speed | Slower | Faster |
| Performance | Slight edge on complex, long-range patterns | Often comparable, sometimes slightly lower |
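The parameter gap is easy to verify: for the same sizes, an LSTM layer carries 4 gate blocks to a GRU's 3, so it has exactly 4/3 the parameters. A quick check:

import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20)
gru = nn.GRU(input_size=10, hidden_size=20)

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(lstm), count(gru))   # 2560 vs 1920, the 4:3 gate ratio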
Selection Guide¶
- LSTM: Long sequences, complex dependencies
- GRU: Fast training, limited resources
7. Practical Tips¶
Initialization¶
# Initialize hidden state
device = "cuda" if torch.cuda.is_available() else "cpu"

def init_hidden(batch_size, hidden_size, num_layers, bidirectional):
    num_directions = 2 if bidirectional else 1
    h = torch.zeros(num_layers * num_directions, batch_size, hidden_size)
    c = torch.zeros(num_layers * num_directions, batch_size, hidden_size)
    return (h.to(device), c.to(device))
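Passing the initialized states into the forward call, reusing the sizes from section 3 with a batch of 32 (if you omit the tuple, PyTorch defaults to zeros anyway):

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True).to(device)
x = torch.randn(32, 5, 10, device=device)
h0_c0 = init_hidden(batch_size=32, hidden_size=20, num_layers=2, bidirectional=False)
output, (h_n, c_n) = lstm(x, h0_c0)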
Dropout Pattern¶
class LSTMWithDropout(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, dropout=0.5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        output, (h_n, _) = self.lstm(x)
        # Apply dropout to last hidden state
        dropped = self.dropout(h_n[-1])
        return self.fc(dropped)
Summary¶
Core Concepts¶
- LSTM: Maintain long-term memory with cell state, 3 gates
- GRU: Simplified LSTM, 2 gates
- Gates: Control information flow (sigmoid × value)
Key Code¶
# LSTM
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
output, (h_n, c_n) = lstm(x)
# GRU
gru = nn.GRU(input_size, hidden_size, batch_first=True)
output, h_n = gru(x)
Next Steps¶
In 16_Attention_Transformer.md, we'll learn Seq2Seq and Attention.