03_backprop.py

  1"""
  203. ์—ญ์ „ํŒŒ (Backpropagation) - PyTorch ๋ฒ„์ „
  3
  4PyTorch์˜ autograd๊ฐ€ ์—ญ์ „ํŒŒ๋ฅผ ์ž๋™์œผ๋กœ ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
  5NumPy ๋ฒ„์ „(examples/numpy/03_backprop_scratch.py)๊ณผ ๋น„๊ตํ•ด ๋ณด์„ธ์š”.
  6
  7ํ•ต์‹ฌ: loss.backward() ํ•œ ์ค„์ด ๋ชจ๋“  ๊ธฐ์šธ๊ธฐ๋ฅผ ์ž๋™ ๊ณ„์‚ฐ!
  8"""
  9
 10import torch
 11import torch.nn as nn
 12import torch.nn.functional as F
 13import matplotlib.pyplot as plt
 14
 15print("=" * 60)
 16print("PyTorch ์—ญ์ „ํŒŒ (Backpropagation)")
 17print("=" * 60)
 18
 19
# ============================================
# 1. Automatic differentiation refresher
# ============================================
print("\n[1] Automatic differentiation refresher")
print("-" * 40)

# Track gradients with requires_grad=True
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

# Forward pass
y = w * x + b
print(f"y = w*x + b = {w.item()}*{x.item()} + {b.item()} = {y.item()}")

# Backward pass
y.backward()

print(f"dy/dw = x = {w.grad.item()}")
print(f"dy/dx = w = {x.grad.item()}")
print(f"dy/db = 1 = {b.grad.item()}")
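
# An aside (a sketch, not part of the original flow): the functional API
# torch.autograd.grad returns the same derivatives without mutating .grad.
x2 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
dy_dx2, dy_dw2 = torch.autograd.grad(w2 * x2 + 1.0, (x2, w2))
print(f"autograd.grad: dy/dx = {dy_dx2.item()}, dy/dw = {dy_dw2.item()}")
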
# ============================================
# 2. Backprop through a single neuron
# ============================================
print("\n[2] Backprop through a single neuron")
print("-" * 40)

# Input and target
x = torch.tensor([2.0], requires_grad=True)
target = torch.tensor([1.0])

# Weight and bias
w = torch.tensor([0.5], requires_grad=True)
b = torch.tensor([0.1], requires_grad=True)

# Forward pass
z = w * x + b
a = torch.sigmoid(z)
loss = (a - target) ** 2

print(f"Input: x={x.item()}, target={target.item()}")
print(f"Parameters: w={w.item()}, b={b.item()}")
print(f"Prediction: a={a.item():.4f}")
print(f"Loss: {loss.item():.4f}")

# Backward pass (automatic!)
loss.backward()

print(f"\nAutomatically computed gradients:")
print(f"  dL/dw = {w.grad.item():.4f}")
print(f"  dL/db = {b.grad.item():.4f}")
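
# Hand-checking the chain rule against autograd (a sketch added here):
# L = (a - t)^2, a = sigmoid(z), z = w*x + b, and sigmoid'(z) = a*(1 - a), so
#   dL/dw = 2*(a - t) * a*(1 - a) * x
with torch.no_grad():
    dL_dw_manual = 2 * (a - target) * a * (1 - a) * x
print(f"  dL/dw (manual chain rule) = {dL_dw_manual.item():.4f}")
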
# ============================================
# 3. Backprop through a 2-layer MLP
# ============================================
print("\n[3] Backprop through a 2-layer MLP")
print("-" * 40)

class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

# Create the model
torch.manual_seed(42)
model = SimpleMLP(2, 8, 1)
print(model)

# Inspect the parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameter count: {total_params}")

for name, param in model.named_parameters():
    print(f"  {name}: shape={param.shape}")
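
# A quick worked check of that count (added as a sanity test):
# fc1 holds 2*8 weights + 8 biases = 24, fc2 holds 8*1 weights + 1 bias = 9,
# giving 33 parameters in total.
assert total_params == (2 * 8 + 8) + (8 * 1 + 1) == 33
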
# ============================================
# 4. Backprop in action: the XOR problem
# ============================================
print("\n[4] Training on XOR")
print("-" * 40)

# Data
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# Model, loss function, optimizer
torch.manual_seed(42)
mlp = SimpleMLP(2, 8, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(mlp.parameters(), lr=1.0)

# Training loop
losses = []
for epoch in range(2000):
    # Forward pass
    y_pred = mlp(X)
    loss = criterion(y_pred, y)
    losses.append(loss.item())

    # Backward pass (the key 3 lines!)
    optimizer.zero_grad()  # reset the gradients
    loss.backward()        # backprop (automatic gradient computation)
    optimizer.step()       # update the weights

    if (epoch + 1) % 400 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.6f}")

# Check the results
print("\nTraining results:")
mlp.eval()
with torch.no_grad():
    y_final = mlp(X)
    for i in range(4):
        print(f"  {X[i].tolist()} → {y_final[i, 0]:.4f} (target: {y[i, 0]})")

# Loss curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('XOR Training Loss (PyTorch Backprop)')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.savefig('pytorch_xor_loss.png', dpi=100)
plt.close()
print("\nLoss curve saved: pytorch_xor_loss.png")
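
# A sketch of what optimizer.step() does for vanilla SGD: each parameter
# moves one step against its gradient. (Illustration only; 'manual_mlp'
# and 'lr' are names introduced here, not part of the original script.)
manual_mlp = SimpleMLP(2, 8, 1)
lr = 1.0
manual_loss = criterion(manual_mlp(X), y)
manual_loss.backward()
with torch.no_grad():
    for p in manual_mlp.parameters():
        p -= lr * p.grad   # the update torch.optim.SGD(lr=1.0) would apply
        p.grad = None      # reset, playing the role of zero_grad()
print(f"One hand-rolled SGD step from loss {manual_loss.item():.4f}")
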
# ============================================
# 5. Inspecting gradient flow
# ============================================
print("\n[5] Inspecting gradient flow")
print("-" * 40)

# Check gradients on a fresh model
torch.manual_seed(0)
test_model = SimpleMLP(2, 4, 1)

# Forward pass
x_test = torch.tensor([[1.0, 0.0]])
y_test = torch.tensor([[1.0]])

y_pred = test_model(x_test)
loss = criterion(y_pred, y_test)

# Gradients before backward
print("Before backward:")
for name, param in test_model.named_parameters():
    print(f"  {name}.grad: {param.grad}")

# Backward pass
loss.backward()

# Gradients after backward
print("\nAfter backward:")
for name, param in test_model.named_parameters():
    grad_norm = param.grad.norm().item()
    print(f"  {name}.grad norm: {grad_norm:.6f}")
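
# Gradient flow can also be watched as it happens with tensor hooks
# (a minimal sketch using torch.Tensor.register_hook; 'probe' and 'hidden'
# are names introduced here):
probe = torch.tensor([[1.0, 0.0]])
hidden = F.relu(test_model.fc1(probe))
hidden.register_hook(lambda g: print(f"  grad reaching hidden layer: {g}"))
torch.sigmoid(test_model.fc2(hidden)).sum().backward()
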
# ============================================
# 6. Inspecting the computational graph
# ============================================
print("\n[6] Computational graph")
print("-" * 40)

# A small computation
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a + b
d = a * b
e = c * d

print(f"a = {a.item()}, b = {b.item()}")
print(f"c = a + b = {c.item()}")
print(f"d = a * b = {d.item()}")
print(f"e = c * d = {e.item()}")

# Backward pass
e.backward()

print(f"\nde/da = {a.grad.item()}")  # d(c*d)/da = d + c*b = 6 + 5*3 = 21
print(f"de/db = {b.grad.item()}")  # d(c*d)/db = d + c*a = 6 + 5*2 = 16

# Manual verification
print("\nManual check:")
print("e = (a+b) * (a*b)")
print("de/da = (a*b) + (a+b)*b = d + c*b")
print(f"     = {d.item()} + {c.item()}*{b.item()} = {d.item() + c.item()*b.item()}")
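
# The graph itself is visible through grad_fn (a quick peek): every
# non-leaf tensor records the operation that produced it.
print(f"c.grad_fn = {c.grad_fn}")  # AddBackward0
print(f"d.grad_fn = {d.grad_fn}")  # MulBackward0
print(f"e.grad_fn = {e.grad_fn}")  # MulBackward0
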
# ============================================
# 7. retain_graph and gradient accumulation
# ============================================
print("\n[7] Gradient accumulation")
print("-" * 40)

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2

# First backward
y.backward(retain_graph=True)
print(f"First backward: dy/dx = {x.grad.item()}")

# Second backward (the gradients accumulate!)
y.backward(retain_graph=True)
print(f"Second backward: dy/dx = {x.grad.item()} (accumulated!)")

# Reset the gradient and run backward again
x.grad.zero_()
y.backward()
print(f"After x.grad.zero_(): dy/dx = {x.grad.item()}")
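
# Accumulation is a feature, not only a pitfall: it enables gradient
# accumulation over micro-batches. A minimal sketch reusing the XOR model
# and optimizer from section 4 ('accum_steps' and 'xor_y' are new names;
# the labels are rebuilt because y was reassigned above):
accum_steps = 2
xor_y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])
optimizer.zero_grad()
for step in range(accum_steps):
    micro_loss = criterion(mlp(X[step * 2:(step + 1) * 2]),
                           xor_y[step * 2:(step + 1) * 2])
    (micro_loss / accum_steps).backward()  # gradients sum across the calls
optimizer.step()  # one update from the combined gradient
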
# ============================================
# 8. NumPy vs PyTorch comparison
# ============================================
print("\n" + "=" * 60)
print("Backpropagation: NumPy vs PyTorch")
print("=" * 60)

comparison = """
| Step        | NumPy (manual)                 | PyTorch (automatic)        |
|-------------|--------------------------------|----------------------------|
| Forward     | z1 = X @ W1 + b1               | y_pred = model(X)          |
|             | a1 = relu(z1)                  |                            |
|             | z2 = a1 @ W2 + b2              |                            |
|             | a2 = sigmoid(z2)               |                            |
| Loss        | loss = mean((a2 - y)**2)       | loss = criterion(y_pred, y)|
| Backward    | dL_da2 = 2*(a2-y)/m            | loss.backward()            |
|             | dL_dz2 = dL_da2 * σ'(z2)       | (automatic!)               |
|             | dW2 = a1.T @ dL_dz2            |                            |
|             | dL_da1 = dL_dz2 @ W2.T         |                            |
|             | dL_dz1 = dL_da1 * relu'(z1)    |                            |
|             | dW1 = X.T @ dL_dz1             |                            |
| Update      | W1 -= lr * dW1                 | optimizer.step()           |
|             | W2 -= lr * dW2                 |                            |

What the NumPy implementation teaches:
1. First-hand experience of how the chain rule operates
2. Why the matrix transposes (.T) are needed
3. The role the activation-function derivatives play
4. The mathematical meaning of batch processing

What PyTorch buys you:
1. Concise code (backprop in 3 lines)
2. No hand-derived calculus mistakes (automatic differentiation)
3. The same workflow scales to arbitrarily complex models
4. GPU acceleration with no extra work
"""
print(comparison)
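
# One way to convince yourself the automatic gradients match the math:
# torch.autograd.gradcheck compares analytic gradients against finite
# differences (it expects double precision; a minimal sketch):
t0 = torch.tensor([2.0], dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(lambda t: torch.sigmoid(t * 0.5 + 0.1), (t0,))
print(f"gradcheck passed: {ok}")
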
# ============================================
# Summary
# ============================================
print("=" * 60)
print("Backpropagation essentials")
print("=" * 60)

summary = """
PyTorch backprop in 3 lines:
    optimizer.zero_grad()  # reset the gradients (required!)
    loss.backward()        # backprop (every gradient computed automatically)
    optimizer.step()       # W = W - lr * grad

Caveats:
1. Without zero_grad(), gradients accumulate across iterations
2. backward() frees the graph by default (keep it with retain_graph=True)
3. Wrap inference in torch.no_grad() to disable gradient tracking

What implementing this in NumPy teaches you:
- How the chain rule is actually applied
- What backward() does internally
- Deeper debugging skills
"""
print(summary)
print("=" * 60)