09. LSTM and GRU

Learning objectives

  • Understand the structure of LSTM and GRU
  • Gate mechanisms
  • Learning long-term dependencies
  • Implementation in PyTorch

1. LSTM (Long Short-Term Memory)

Problem: vanishing gradients in RNNs

h100 ← W × W × ... × W × h1
            ↑
    the gradient shrinks toward 0
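This shrinkage is easy to demonstrate numerically. A rough sketch (the sizes and the 0.1 scale are arbitrary choices, not the actual backward pass):

```python
import torch

# Repeatedly multiplying by the same small recurrent matrix W collapses a
# vector toward zero, mirroring how the gradient decays over ~100 time steps.
torch.manual_seed(0)
W = torch.randn(20, 20) * 0.1   # scaled so the largest singular value < 1
v = torch.randn(20)

norms = []
for _ in range(100):
    v = W @ v
    norms.append(v.norm().item())

print(norms[0], norms[-1])      # the norm decays toward 0
```

With weights scaled the other way (largest singular value > 1), the same loop explodes instead, which is the mirror-image problem.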

Solution: the cell state

LSTM = cell state (long-term memory) + hidden state (short-term memory)

LSTM structure

       ┌──────────────────────────────────────┐
       │             Cell state (C)           │
       │     ×─────(+)─────────────────────►  │
       │     ↑      ↑                         │
       │    forget  input                     │
       │    gate    gate                      │
       │     ↑      ↑                         │
h(t-1)─┴──►[σ]   [σ][tanh]    [σ]──►×──────►h(t)
           f(t)   i(t) g(t)   o(t)     ↑
                               output gate

๊ฒŒ์ดํŠธ ์ˆ˜์‹

# Forget Gate: ์ด์ „ ๊ธฐ์–ต ์ค‘ ์–ผ๋งˆ๋‚˜ ์žŠ์„์ง€
f(t) = ฯƒ(W_f ร— [h(t-1), x(t)] + b_f)

# Input Gate: ์ƒˆ ์ •๋ณด ์ค‘ ์–ผ๋งˆ๋‚˜ ์ €์žฅํ• ์ง€
i(t) = ฯƒ(W_i ร— [h(t-1), x(t)] + b_i)

# Cell Candidate: ์ƒˆ๋กœ์šด ํ›„๋ณด ์ •๋ณด
g(t) = tanh(W_g ร— [h(t-1), x(t)] + b_g)

# Cell State Update
C(t) = f(t) ร— C(t-1) + i(t) ร— g(t)

# Output Gate: ์…€ ์ƒํƒœ ์ค‘ ์–ผ๋งˆ๋‚˜ ์ถœ๋ ฅํ• ์ง€
o(t) = ฯƒ(W_o ร— [h(t-1), x(t)] + b_o)

# Hidden State
h(t) = o(t) ร— tanh(C(t))
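These equations can be checked directly against PyTorch, which stacks the four gates' weights in i, f, g, o order inside nn.LSTMCell. A minimal sketch (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 4, 3
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(1, input_size)
h_prev = torch.zeros(1, hidden_size)
c_prev = torch.zeros(1, hidden_size)

# Unpack PyTorch's stacked gate weights: rows are [i | f | g | o]
gates = x @ cell.weight_ih.T + h_prev @ cell.weight_hh.T \
        + cell.bias_ih + cell.bias_hh
i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)

i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
g_t = torch.tanh(g_t)

c_t = f_t * c_prev + i_t * g_t   # C(t) = f(t) × C(t-1) + i(t) × g(t)
h_t = o_t * torch.tanh(c_t)      # h(t) = o(t) × tanh(C(t))

h_ref, c_ref = cell(x, (h_prev, c_prev))
print(torch.allclose(h_t, h_ref, atol=1e-6))  # True
```

Note that PyTorch uses separate weight matrices for x(t) and h(t-1) rather than one matrix over the concatenation [h(t-1), x(t)]; the two formulations are mathematically equivalent.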

2. GRU (Gated Recurrent Unit)

A simplified version of the LSTM

GRU = Reset Gate + Update Gate
(the cell state and hidden state are merged into one)

GRU structure

       Update Gate (z)
       โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
       โ”‚                            โ”‚
h(t-1)โ”€โ”ดโ”€โ”€โ–บ[ฯƒ]โ”€โ”€โ”€z(t)โ”€โ”€โ”€โ”€โ”€โ”€ร—โ”€โ”€(+)โ”€โ”€โ–บh(t)
              โ”‚           โ†‘    โ†‘
              โ”‚      โ”Œโ”€โ”€โ”€โ”€โ”˜    โ”‚
              โ”‚      โ”‚   ร—โ”€โ”€โ”€โ”€โ”€โ”˜
              โ”‚      โ”‚   โ†‘
              โ”œโ”€โ”€โ–บ[ฯƒ]   [tanh]
              โ”‚   r(t)    โ”‚
              โ”‚    โ”‚      โ”‚
              โ””โ”€โ”€โ”€โ”€ร—โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
                Reset Gate (r)

๊ฒŒ์ดํŠธ ์ˆ˜์‹

# Update Gate: ์ด์ „ ์ƒํƒœ vs ์ƒˆ ์ƒํƒœ ๋น„์œจ
z(t) = ฯƒ(W_z ร— [h(t-1), x(t)] + b_z)

# Reset Gate: ์ด์ „ ์ƒํƒœ๋ฅผ ์–ผ๋งˆ๋‚˜ ์žŠ์„์ง€
r(t) = ฯƒ(W_r ร— [h(t-1), x(t)] + b_r)

# Candidate Hidden
hฬƒ(t) = tanh(W ร— [r(t) ร— h(t-1), x(t)] + b)

# Hidden State Update
h(t) = (1 - z(t)) ร— h(t-1) + z(t) ร— hฬƒ(t)
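These equations can likewise be verified against nn.GRUCell, with two convention caveats: PyTorch keeps the candidate's recurrent bias inside the reset-gate product, and its z weights the old state, i.e. h(t) = (1 − z)·h̃(t) + z·h(t-1), the mirror of the formula above. A sketch with arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 4, 3
cell = nn.GRUCell(input_size, hidden_size)

x = torch.randn(1, input_size)
h_prev = torch.zeros(1, hidden_size)

# PyTorch stacks the gate weights in r, z, n (candidate) order
Wi_r, Wi_z, Wi_n = cell.weight_ih.chunk(3)
Wh_r, Wh_z, Wh_n = cell.weight_hh.chunk(3)
bi_r, bi_z, bi_n = cell.bias_ih.chunk(3)
bh_r, bh_z, bh_n = cell.bias_hh.chunk(3)

r_t = torch.sigmoid(x @ Wi_r.T + bi_r + h_prev @ Wh_r.T + bh_r)
z_t = torch.sigmoid(x @ Wi_z.T + bi_z + h_prev @ Wh_z.T + bh_z)
# Candidate: reset gate scales the recurrent term (incl. its bias, per PyTorch)
n_t = torch.tanh(x @ Wi_n.T + bi_n + r_t * (h_prev @ Wh_n.T + bh_n))
# PyTorch's convention: z keeps the OLD state
h_t = (1 - z_t) * n_t + z_t * h_prev

h_ref = cell(x, h_prev)
print(torch.allclose(h_t, h_ref, atol=1e-6))  # True
```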

3. PyTorch LSTM/GRU

LSTM

import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True,
    dropout=0.1,
    bidirectional=False
)

# Forward pass
# output: hidden states at every time step
# (h_n, c_n): final (hidden, cell) states
output, (h_n, c_n) = lstm(x)

GRU

gru = nn.GRU(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    batch_first=True
)

# Forward pass (no cell state)
output, h_n = gru(x)
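A quick shape check for both layers (batch and sequence lengths here are arbitrary); with batch_first=True the input is (batch, seq, input_size):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

x = torch.randn(3, 5, 10)        # (batch=3, seq=5, features=10)

out, (h_n, c_n) = lstm(x)
print(out.shape)                 # torch.Size([3, 5, 20]) - every time step
print(h_n.shape)                 # torch.Size([2, 3, 20]) - (layers, batch, hidden)
print(c_n.shape)                 # torch.Size([2, 3, 20])

out_g, h_g = gru(x)              # GRU: no cell state
print(out_g.shape)               # torch.Size([3, 5, 20])
```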

4. LSTM classifier

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )
        # ์–‘๋ฐฉํ–ฅ์ด๋ฏ€๋กœ hidden_dim * 2
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq) - token indices
        embedded = self.embedding(x)

        # LSTM
        output, (h_n, c_n) = self.lstm(embedded)

        # Combine the final hidden states of both directions
        # h_n: (num_layers*2, batch, hidden)
        forward_last = h_n[-2]   # last layer, forward direction
        backward_last = h_n[-1]  # last layer, backward direction
        combined = torch.cat([forward_last, backward_last], dim=1)

        return self.fc(combined)
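The h_n[-2] / h_n[-1] indexing deserves a sanity check: for a bidirectional LSTM, h_n is ordered layer by layer, forward then backward, and the backward direction finishes at time step 0. A self-contained sketch (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(8, 16, num_layers=2, batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)

output, (h_n, _) = lstm(x)
print(h_n.shape)   # torch.Size([4, 4, 16]) = (num_layers*2, batch, hidden)

# Top layer, forward direction: its final state is the LAST output step
print(torch.allclose(h_n[-2], output[:, -1, :16]))   # True
# Top layer, backward direction: its final state is the FIRST output step
print(torch.allclose(h_n[-1], output[:, 0, 16:]))    # True
```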

5. ์‹œํ€€์Šค ์ƒ์„ฑ (์–ธ์–ด ๋ชจ๋ธ)

import torch.nn.functional as F

class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, hidden)
        logits = self.fc(output)
        return logits, hidden

    def generate(self, start_token, max_len, temperature=1.0):
        self.eval()
        tokens = [start_token]
        hidden = None

        with torch.no_grad():
            for _ in range(max_len):
                x = torch.tensor([[tokens[-1]]])
                logits, hidden = self(x, hidden)

                # Temperature sampling
                probs = F.softmax(logits[0, -1] / temperature, dim=0)
                next_token = torch.multinomial(probs, 1).item()
                tokens.append(next_token)

        return tokens
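The temperature division inside generate() is worth isolating: lower temperatures sharpen the softmax toward greedy decoding, higher ones flatten it toward uniform sampling. A small sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])

# The top token's probability shrinks as temperature rises
for T in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / T, dim=0)
    print(T, round(probs.max().item(), 3))
```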

6. LSTM vs GRU comparison

Item             LSTM                   GRU
Gates            3 (f, i, o)            2 (r, z)
State            cell + hidden          hidden only
Parameters       more                   fewer
Training speed   slower                 faster
Performance      complex patterns       similar or slightly lower

์„ ํƒ ๊ฐ€์ด๋“œ

  • LSTM: ๊ธด ์‹œํ€€์Šค, ๋ณต์žกํ•œ ์˜์กด์„ฑ
  • GRU: ๋น ๋ฅธ ํ•™์Šต, ์ œํ•œ๋œ ์ž์›
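The parameter-count row in the table is easy to verify: with identical sizes, a GRU layer has exactly 3/4 of an LSTM layer's parameters (3 gate blocks vs. 4). A quick check with arbitrary sizes:

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20)
gru = nn.GRU(input_size=10, hidden_size=20)

n_lstm = sum(p.numel() for p in lstm.parameters())
n_gru = sum(p.numel() for p in gru.parameters())
print(n_lstm, n_gru)    # 2560 1920
print(n_gru / n_lstm)   # 0.75
```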

7. Practical tips

Initialization

# Initializing the hidden state explicitly (optional: PyTorch
# defaults to zeros when the initial state is omitted)
def init_hidden(batch_size, hidden_size, num_layers, bidirectional, device):
    num_directions = 2 if bidirectional else 1
    h = torch.zeros(num_layers * num_directions, batch_size, hidden_size)
    c = torch.zeros(num_layers * num_directions, batch_size, hidden_size)
    return (h.to(device), c.to(device))
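Usage sketch: an explicit zero state gives exactly the same result as omitting the argument, since PyTorch zero-initializes by default (sizes below are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(10, 20, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 10)

h0 = torch.zeros(2, 3, 20)   # (num_layers, batch, hidden)
c0 = torch.zeros(2, 3, 20)

out_explicit, _ = lstm(x, (h0, c0))
out_default, _ = lstm(x)     # defaults to the same zero state
print(torch.allclose(out_explicit, out_default))  # True
```

An explicit state is mainly useful when carrying hidden state across batches (truncated BPTT) or seeding generation.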

Dropout pattern

class LSTMWithDropout(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, dropout=0.5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        output, (h_n, _) = self.lstm(x)
        # ๋งˆ์ง€๋ง‰ ์€๋‹‰ ์ƒํƒœ์— dropout
        dropped = self.dropout(h_n[-1])
        return self.fc(dropped)
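One caveat with this pattern: nn.Dropout only fires in train() mode; after model.eval() it is the identity, so inference stays deterministic. A tiny check:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(0.5)
x = torch.ones(2, 10)

drop.train()
y = drop(x)    # ~half the entries zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
print(torch.equal(drop(x), x))   # True: dropout disabled at eval time
```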

Summary

Key concepts

  1. LSTM: the cell state preserves long-term memory; 3 gates
  2. GRU: a simplified LSTM; 2 gates
  3. Gates: control information flow (sigmoid × value)

Key code

# LSTM
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
output, (h_n, c_n) = lstm(x)

# GRU
gru = nn.GRU(input_size, hidden_size, batch_first=True)
output, h_n = gru(x)

๋‹ค์Œ ๋‹จ๊ณ„

16_Attention_Transformer.md์—์„œ Seq2Seq์™€ Attention์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
