06. LSTM / GRU

Overview

LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are recurrent neural network (RNN) variants designed to solve the vanishing gradient problem. Through gating mechanisms, they learn long-term dependencies effectively.


Mathematical Background

1. The Problem with Vanilla RNNs

Vanilla RNN:
  h_t = tanh(W_h Β· h_{t-1} + W_x Β· x_t + b)

Problem: Backpropagation Through Time (BPTT)

βˆ‚L/βˆ‚h_0 = βˆ‚L/βˆ‚h_T Β· βˆ‚h_T/βˆ‚h_{T-1} Β· ... Β· βˆ‚h_1/βˆ‚h_0
        = βˆ‚L/βˆ‚h_T Β· Ξ _{t=1}^{T} βˆ‚h_t/βˆ‚h_{t-1}

βˆ‚h_t/βˆ‚h_{t-1} = diag(1 - tanhΒ²(Β·)) Β· W_h

Result:
- |eigenvalue(W_h)| < 1 β†’ vanishing gradient
- |eigenvalue(W_h)| > 1 β†’ exploding gradient

β†’ On long sequences, information from early time steps cannot be learned
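
A minimal NumPy sketch of this effect (sizes and scale factors are illustrative assumptions): the BPTT Jacobian is applied once per step, and its linear part W_h^T alone drives the gradient norm toward zero or infinity depending on the spectral radius; the tanh derivative (always ≀ 1) only pushes further toward vanishing.

  import numpy as np

  np.random.seed(0)
  H, T = 16, 100                        # hidden size, sequence length (illustrative)

  def grad_norm_through_time(scale):
      # Random recurrent matrix with spectral radius roughly `scale`
      W_h = np.random.randn(H, H) * scale / np.sqrt(H)
      grad = np.ones(H)                 # stand-in for βˆ‚L/βˆ‚h_T
      for _ in range(T):
          grad = W_h.T @ grad           # one BPTT step (linear part of the Jacobian)
      return np.linalg.norm(grad)

  print(grad_norm_through_time(0.5))    # shrinks toward 0 β†’ vanishing
  print(grad_norm_through_time(1.5))    # blows up        β†’ exploding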

2. LSTM Equations

Inputs: x_t (current input), h_{t-1} (previous hidden state), c_{t-1} (previous cell state)
Outputs: h_t (current hidden state), c_t (current cell state)

1. Forget Gate (what do we discard?)
   f_t = Οƒ(W_f Β· [h_{t-1}, x_t] + b_f)

2. Input Gate (what do we store?)
   i_t = Οƒ(W_i Β· [h_{t-1}, x_t] + b_i)

3. Candidate Cell (the new information)
   c̃_t = tanh(W_c · [h_{t-1}, x_t] + b_c)

4. Cell State Update
   c_t = f_t βŠ™ c_{t-1} + i_t βŠ™ cΜƒ_t
        ↑ old info     ↑ new info

5. Output Gate (what do we output?)
   o_t = Οƒ(W_o Β· [h_{t-1}, x_t] + b_o)

6. Hidden State
   h_t = o_t βŠ™ tanh(c_t)

Οƒ: sigmoid (0~1)
βŠ™: element-wise product
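
These six equations translate almost line-for-line into NumPy. A minimal single-step sketch (shapes and the stacked-weight layout are illustrative assumptions; the full forward/backward version lives in numpy/lstm_numpy.py):

  import numpy as np

  def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

  def lstm_step(x_t, h_prev, c_prev, W, b):
      """One LSTM time step. W: (4H, I+H), b: (4H,), the four gates stacked."""
      H = h_prev.shape[0]
      z = W @ np.concatenate([h_prev, x_t]) + b   # all four pre-activations at once
      f = sigmoid(z[0*H:1*H])                     # forget gate
      i = sigmoid(z[1*H:2*H])                     # input gate
      c_tilde = np.tanh(z[2*H:3*H])               # candidate cell
      o = sigmoid(z[3*H:4*H])                     # output gate
      c_t = f * c_prev + i * c_tilde              # cell state update
      h_t = o * np.tanh(c_t)                      # hidden state
      return h_t, c_t

  # Usage with illustrative sizes
  I, H = 8, 16
  rng = np.random.default_rng(0)
  W = rng.standard_normal((4 * H, I + H)) * 0.1
  b = np.zeros(4 * H)
  h_t, c_t = lstm_step(rng.standard_normal(I), np.zeros(H), np.zeros(H), W, b)
  print(h_t.shape, c_t.shape)   # (16,) (16,)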

3. GRU Equations

GRU: a simplified version of the LSTM (no separate cell state)

1. Reset Gate (how much of the past do we ignore?)
   r_t = Οƒ(W_r Β· [h_{t-1}, x_t] + b_r)

2. Update Gate (the mix of old vs. new information)
   z_t = Οƒ(W_z Β· [h_{t-1}, x_t] + b_z)

3. Candidate Hidden
   hΜƒ_t = tanh(W_h Β· [r_t βŠ™ h_{t-1}, x_t] + b_h)

4. Hidden State
   h_t = (1 - z_t) βŠ™ h_{t-1} + z_t βŠ™ hΜƒ_t
        ↑ keep old info        ↑ new info
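
The same style of sketch for one GRU step (shapes are illustrative assumptions; note the reset gate scales h_{t-1} before the candidate's matrix multiply, exactly as in equation 3):

  import numpy as np

  def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

  def gru_step(x_t, h_prev, W_r, W_z, W_h, b_r, b_z, b_h):
      """One GRU time step. Each W: (H, I+H); inputs concatenated as [h, x]."""
      hx = np.concatenate([h_prev, x_t])
      r = sigmoid(W_r @ hx + b_r)     # reset gate
      z = sigmoid(W_z @ hx + b_z)     # update gate
      # The reset gate is applied to h_{t-1} before the candidate is computed
      h_tilde = np.tanh(W_h @ np.concatenate([r * h_prev, x_t]) + b_h)
      return (1 - z) * h_prev + z * h_tilde   # interpolate old and new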

LSTM vs GRU:
- GRU: 2 gates (reset, update)
- LSTM: 3 gates (forget, input, output) plus a cell state
- GRU has about 25% fewer parameters
- Performance is task-dependent but usually comparable

4. Why Is the Gradient Preserved?

LSTM cell state update:
  c_t = f_t βŠ™ c_{t-1} + i_t βŠ™ cΜƒ_t

Gradient:
  βˆ‚c_t/βˆ‚c_{t-1} = f_t  (the forget gate, treating the gates and candidate as constants)

If f_t β‰ˆ 1, the gradient propagates almost unchanged!

This is the "highway" role:
- The cell state can flow through time without being transformed
- The gradient survives even across long sequences
- The model learns f_t and thereby decides which information to keep

μ•„ν‚€ν…μ²˜

LSTM Structure Diagram

                      β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
                      β”‚           Cell State c_t         β”‚
c_{t-1} ─────────────►│   βŠ™β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€+────────────►  c_t
                      β”‚   ↑ forget     ↑ input           β”‚
                      β”‚   f_t        i_t βŠ™ cΜƒ_t          β”‚
                      β”‚                                  β”‚
                      β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
                      β”‚   β”‚  Οƒ   Οƒ   tanh   Οƒ   β”‚       β”‚
                      │   │  f   i    c̃    o   │       │
                      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚
                      β”‚         ↑                        β”‚
                      β”‚    [h_{t-1}, x_t]                β”‚
h_{t-1} ─────────────►│                                  β”œβ”€β”€β”€β–Ί h_t
                      β”‚                 βŠ™ ◄── tanh(c_t)  β”‚
                      β”‚                 o_t              β”‚
                      β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                    ↑
                                   x_t

GRU Structure Diagram

                      β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
                      β”‚                                  β”‚
h_{t-1} ─────────────►│ βŠ™β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€+────────────────►│ h_t
                      β”‚ (1-z)        z βŠ™ hΜƒ             β”‚
                      β”‚              ↑                   β”‚
                      β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
                      β”‚   β”‚    Οƒ   Οƒ   tanh     β”‚       β”‚
                      │   │    r   z    h̃      │       │
                      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚
                      β”‚         ↑                        β”‚
                      β”‚    [h_{t-1}, x_t]                β”‚
                      β”‚    [rβŠ™h_{t-1}, x_t]              β”‚
                      β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                    ↑
                                   x_t

νŒŒλΌλ―Έν„° 수

LSTM:
  4 gates Γ— (input_size Γ— hidden_size + hidden_size Γ— hidden_size + hidden_size)
  = 4 Γ— (input_size + hidden_size + 1) Γ— hidden_size

Example: input=128, hidden=256
  = 4 Γ— (128 + 256 + 1) Γ— 256 = 394,240

GRU:
  3 gates
  = 3 Γ— (input_size + hidden_size + 1) Γ— hidden_size

Example: input=128, hidden=256
  = 3 Γ— (128 + 256 + 1) Γ— 256 = 295,680  (25% fewer)
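
The formula as a sanity check in plain Python. (Note: this counts one bias vector per gate; PyTorch's nn.LSTM and nn.GRU keep two separate bias vectors, bias_ih and bias_hh, so their totals come out slightly higher, at n_gates Γ— (I + H + 2) Γ— H.)

  def rnn_param_count(input_size, hidden_size, n_gates):
      # n_gates = 4 for LSTM, 3 for GRU; one bias vector per gate
      return n_gates * (input_size + hidden_size + 1) * hidden_size

  print(rnn_param_count(128, 256, n_gates=4))   # 394240 (LSTM)
  print(rnn_param_count(128, 256, n_gates=3))   # 295680 (GRU)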

File Structure

06_LSTM_GRU/
β”œβ”€β”€ README.md                      # this file
β”œβ”€β”€ numpy/
β”‚   β”œβ”€β”€ lstm_numpy.py             # NumPy LSTM (forward + backward)
β”‚   └── gru_numpy.py              # NumPy GRU
β”œβ”€β”€ pytorch_lowlevel/
β”‚   └── lstm_gru_lowlevel.py      # uses F.linear, no nn.LSTM
β”œβ”€β”€ paper/
β”‚   β”œβ”€β”€ lstm_paper.py             # implementation of the original 1997 paper
β”‚   └── gru_paper.py              # implementation of the 2014 paper
└── exercises/
    β”œβ”€β”€ 01_gradient_flow.md       # BPTT gradient analysis
    └── 02_sequence_tasks.md      # sequence classification/generation

Key Concepts

1. The Role of Each Gate

Forget Gate (f):
- Close to 1: keep the previous information
- Close to 0: erase the previous information
- Example: reset the previous context at the start of a new sentence

Input Gate (i):
- Determines how important the new information is
- Works together with the candidate (cΜƒ)

Output Gate (o):
- Determines what part of the cell state is exposed as the hidden state
- Example: remember something internally without outputting it

2. Peephole Connections (optional)

Standard LSTM: the gates see only [h_{t-1}, x_t]
Peephole: the gates also see c_{t-1} (and c_t, for the output gate)

f_t = Οƒ(W_f Β· [h_{t-1}, x_t] + W_{cf} Β· c_{t-1} + b_f)
i_t = Οƒ(W_i Β· [h_{t-1}, x_t] + W_{ci} Β· c_{t-1} + b_i)
o_t = Οƒ(W_o Β· [h_{t-1}, x_t] + W_{co} Β· c_t + b_o)

Effect: cell state information feeds directly into the gate decisions
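
A sketch of how this changes the step function from section 2 (an assumed helper, not the repo's code; the peephole weights w_cf, w_ci, w_co are applied element-wise here, i.e. the diagonal form):

  import numpy as np

  def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

  def peephole_lstm_step(x_t, h_prev, c_prev, W, b, w_cf, w_ci, w_co):
      """LSTM step with diagonal peepholes: w_c* are (H,) vectors, element-wise."""
      H = h_prev.shape[0]
      z = W @ np.concatenate([h_prev, x_t]) + b      # stacked gate pre-activations
      f = sigmoid(z[0*H:1*H] + w_cf * c_prev)        # forget gate also sees c_{t-1}
      i = sigmoid(z[1*H:2*H] + w_ci * c_prev)        # input gate also sees c_{t-1}
      c_t = f * c_prev + i * np.tanh(z[2*H:3*H])     # cell update unchanged
      o = sigmoid(z[3*H:4*H] + w_co * c_t)           # output gate sees the new c_t
      return o * np.tanh(c_t), c_t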

3. Bidirectional LSTM

The sequence is processed in both directions:

Forward:  β†’ h_1 β†’ h_2 β†’ h_3 β†’ h_4 β†’
Backward: ← h_4 ← h_3 ← h_2 ← h_1 ←

Output: [forward_h_t; backward_h_t] (concatenated)

Advantages:
- Makes use of future context as well
- Effective for NER and POS tagging
- The NLP standard before the Transformer arrived
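
In PyTorch this is a single constructor flag; a minimal sketch with illustrative sizes, showing the concatenated output dimension:

  import torch
  import torch.nn as nn

  # 32-dim inputs, 64-dim hidden, batch of 4, sequence length 10 (illustrative)
  bilstm = nn.LSTM(input_size=32, hidden_size=64,
                   bidirectional=True, batch_first=True)
  x = torch.randn(4, 10, 32)
  out, (h_n, c_n) = bilstm(x)
  print(out.shape)   # torch.Size([4, 10, 128]): forward/backward h concatenated
  print(h_n.shape)   # torch.Size([2, 4, 64]):  one final state per direction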

4. Stacked LSTM

Stacking multiple LSTM layers:

x_t β†’ LSTM_1 β†’ h_t^1 β†’ LSTM_2 β†’ h_t^2 β†’ ... β†’ output

각 λ ˆμ΄μ–΄:
- 이전 λ ˆμ΄μ–΄μ˜ hidden을 μž…λ ₯으둜
- 더 좔상적인 ν‘œν˜„ ν•™μŠ΅

주의: κΉŠμ–΄μ§ˆμˆ˜λ‘ ν•™μŠ΅ 어렀움
- Dropout ν•„μˆ˜ (특히 λ ˆμ΄μ–΄ κ°„)
- Residual connection 도움
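
The sketch referenced above, with illustrative sizes; nn.LSTM's dropout argument applies exactly between layers (to the outputs of every layer except the last):

  import torch
  import torch.nn as nn

  # Two stacked LSTM layers; dropout acts on layer-1 outputs before layer 2
  stacked = nn.LSTM(input_size=32, hidden_size=64, num_layers=2,
                    dropout=0.5, batch_first=True)
  x = torch.randn(4, 10, 32)
  out, (h_n, c_n) = stacked(x)
  print(out.shape)   # torch.Size([4, 10, 64]): hidden states of the top layer
  print(h_n.shape)   # torch.Size([2, 4, 64]):  final state of each layer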

Implementation Levels

Level 1: NumPy From-Scratch (numpy/)

  • λͺ¨λ“  게이트 μ—°μ‚° 직접 κ΅¬ν˜„
  • BPTT gradient μˆ˜λ™ 계산
  • Cell state gradient μœ λ„

Level 2: PyTorch Low-Level (pytorch_lowlevel/)

  • F.linear, torch.sigmoid, torch.tanh μ‚¬μš©
  • nn.LSTM λ―Έμ‚¬μš©
  • νŒŒλΌλ―Έν„° μˆ˜λ™ 관리
  • Bidirectional, Stacked κ΅¬ν˜„

Level 3: Paper Implementation (paper/)

  • Hochreiter & Schmidhuber (1997) LSTM
  • Cho et al. (2014) GRU
  • Peephole connections

Learning Checklist

  • [ ] Vanilla RNN의 vanishing gradient 문제
  • [ ] LSTM 4개 μˆ˜μ‹ μ•”κΈ°
  • [ ] GRU 3개 μˆ˜μ‹ μ•”κΈ°
  • [ ] Cell stateκ°€ gradientλ₯Ό λ³΄μ‘΄ν•˜λŠ” 이유
  • [ ] 각 게이트의 μ—­ν•  μ„€λͺ…
  • [ ] LSTM vs GRU μž₯단점
  • [ ] BPTT κ΅¬ν˜„
  • [ ] Bidirectional, Stacked ꡬ쑰
