1"""
203. ์ญ์ ํ (Backpropagation) - PyTorch ๋ฒ์
3
4PyTorch์ autograd๊ฐ ์ญ์ ํ๋ฅผ ์๋์ผ๋ก ์ฒ๋ฆฌํฉ๋๋ค.
5NumPy ๋ฒ์ (examples/numpy/03_backprop_scratch.py)๊ณผ ๋น๊ตํด ๋ณด์ธ์.
6
7ํต์ฌ: loss.backward() ํ ์ค์ด ๋ชจ๋ ๊ธฐ์ธ๊ธฐ๋ฅผ ์๋ ๊ณ์ฐ!
8"""
9
10import torch
11import torch.nn as nn
12import torch.nn.functional as F
13import matplotlib.pyplot as plt
14
15print("=" * 60)
16print("PyTorch ์ญ์ ํ (Backpropagation)")
17print("=" * 60)
18
19
20# ============================================
21# 1. ์๋ ๋ฏธ๋ถ ๋ณต์ต
22# ============================================
23print("\n[1] ์๋ ๋ฏธ๋ถ ๋ณต์ต")
24print("-" * 40)
25
26# requires_grad=True๋ก ๊ธฐ์ธ๊ธฐ ์ถ์
27x = torch.tensor(2.0, requires_grad=True)
28w = torch.tensor(3.0, requires_grad=True)
29b = torch.tensor(1.0, requires_grad=True)
30
31# ์์ ํ
32y = w * x + b
33print(f"y = w*x + b = {w.item()}*{x.item()} + {b.item()} = {y.item()}")
34
35# ์ญ์ ํ
36y.backward()
37
38print(f"dy/dw = x = {w.grad.item()}")
39print(f"dy/dx = w = {x.grad.item()}")
40print(f"dy/db = 1 = {b.grad.item()}")


# ============================================
# 2. Single-neuron backpropagation
# ============================================
print("\n[2] Single-neuron backpropagation")
print("-" * 40)

# Input and target
x = torch.tensor([2.0], requires_grad=True)
target = torch.tensor([1.0])

# Weight and bias
w = torch.tensor([0.5], requires_grad=True)
b = torch.tensor([0.1], requires_grad=True)

# Forward pass
z = w * x + b
a = torch.sigmoid(z)
loss = (a - target) ** 2

print(f"Input: x={x.item()}, target={target.item()}")
print(f"Weights: w={w.item()}, b={b.item()}")
print(f"Prediction: a={a.item():.4f}")
print(f"Loss: {loss.item():.4f}")

# Backward pass (automatic!)
loss.backward()

print(f"\nAutomatically computed gradients:")
print(f"  dL/dw = {w.grad.item():.4f}")
print(f"  dL/db = {b.grad.item():.4f}")
73
74
75# ============================================
76# 3. 2์ธต MLP ์ญ์ ํ
77# ============================================
78print("\n[3] 2์ธต MLP ์ญ์ ํ")
79print("-" * 40)
80
81class SimpleMLP(nn.Module):
82 def __init__(self, input_dim, hidden_dim, output_dim):
83 super().__init__()
84 self.fc1 = nn.Linear(input_dim, hidden_dim)
85 self.fc2 = nn.Linear(hidden_dim, output_dim)
86
87 def forward(self, x):
88 x = F.relu(self.fc1(x))
89 x = torch.sigmoid(self.fc2(x))
90 return x
91
92# ๋ชจ๋ธ ์์ฑ
93torch.manual_seed(42)
94model = SimpleMLP(2, 8, 1)
95print(model)
96
97# ํ๋ผ๋ฏธํฐ ํ์ธ
98total_params = sum(p.numel() for p in model.parameters())
99print(f"\n์ด ํ๋ผ๋ฏธํฐ ์: {total_params}")
100
101for name, param in model.named_parameters():
102 print(f" {name}: shape={param.shape}")
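
# Sanity check (a sketch): nn.Linear(in_f, out_f) holds in_f*out_f weights plus
# out_f biases, so SimpleMLP(2, 8, 1) should have 2*8 + 8 + 8*1 + 1 = 33.
expected_params = 2 * 8 + 8 + 8 * 1 + 1
assert total_params == expected_params, (total_params, expected_params)
print(f"Matches the hand count: {expected_params}")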


# ============================================
# 4. Checking backpropagation on the XOR problem
# ============================================
print("\n[4] Training on the XOR problem")
print("-" * 40)

# Data
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# Model, loss function, optimizer
torch.manual_seed(42)
mlp = SimpleMLP(2, 8, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(mlp.parameters(), lr=1.0)

# Training
losses = []
for epoch in range(2000):
    # Forward pass
    y_pred = mlp(X)
    loss = criterion(y_pred, y)
    losses.append(loss.item())

    # Backward pass (the key 3 lines!)
    optimizer.zero_grad()  # reset gradients
    loss.backward()        # backprop (automatic gradient computation)
    optimizer.step()       # update weights

    if (epoch + 1) % 400 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.6f}")

# Check the results
print("\nTraining results:")
mlp.eval()
with torch.no_grad():
    y_final = mlp(X)
    for i in range(4):
        print(f"  {X[i].tolist()} → {y_final[i, 0]:.4f} (target: {y[i, 0]})")

# Loss curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('XOR Training Loss (PyTorch Backprop)')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.savefig('pytorch_xor_loss.png', dpi=100)
plt.close()
print("\nLoss curve saved: pytorch_xor_loss.png")


# ============================================
# 5. Inspecting gradient flow
# ============================================
print("\n[5] Checking gradient flow")
print("-" * 40)

# Check gradients on a fresh model
torch.manual_seed(0)
test_model = SimpleMLP(2, 4, 1)

# Forward pass
x_test = torch.tensor([[1.0, 0.0]])
y_test = torch.tensor([[1.0]])

y_pred = test_model(x_test)
loss = criterion(y_pred, y_test)

# Gradients before backprop
print("Before backward:")
for name, param in test_model.named_parameters():
    print(f"  {name}.grad: {param.grad}")

# Backward pass
loss.backward()

# Gradients after backprop
print("\nAfter backward:")
for name, param in test_model.named_parameters():
    grad_norm = param.grad.norm().item()
    print(f"  {name}.grad norm: {grad_norm:.6f}")
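
# A common practical use of these norms (a sketch): clip gradients before the
# update step to guard against exploding gradients. clip_grad_norm_ returns
# the total norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(test_model.parameters(), max_norm=1.0)
print(f"\nTotal gradient norm before clipping: {total_norm:.6f}")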


# ============================================
# 6. Inspecting the computation graph
# ============================================
print("\n[6] Computation graph")
print("-" * 40)

# A simple computation
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a + b
d = a * b
e = c * d

print(f"a = {a.item()}, b = {b.item()}")
print(f"c = a + b = {c.item()}")
print(f"d = a * b = {d.item()}")
print(f"e = c * d = {e.item()}")

# Backward pass
e.backward()

print(f"\nde/da = {a.grad.item()}")  # d(c*d)/da = d + c*b = 6 + 5*3 = 21
print(f"de/db = {b.grad.item()}")  # d(c*d)/db = d + c*a = 6 + 5*2 = 16

# Manual verification
print("\nManual check:")
print("e = (a+b) * (a*b)")
print("de/da = (a*b) + (a+b)*b = d + c*b")
print(f"  = {d.item()} + {c.item()}*{b.item()} = {d.item() + c.item()*b.item()}")


# ============================================
# 7. retain_graph and gradient accumulation
# ============================================
print("\n[7] Gradient accumulation")
print("-" * 40)

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2

# First backward
y.backward(retain_graph=True)
print(f"First backward: dy/dx = {x.grad.item()}")

# Second backward (gradients accumulate!)
y.backward(retain_graph=True)
print(f"Second backward: dy/dx = {x.grad.item()} (accumulated!)")

# Reset the gradient, then backward again
x.grad.zero_()
y.backward()
print(f"After x.grad.zero_(): dy/dx = {x.grad.item()}")


# ============================================
# 8. NumPy vs PyTorch comparison
# ============================================
print("\n" + "=" * 60)
print("NumPy vs PyTorch backpropagation")
print("=" * 60)

comparison = """
| Step     | NumPy (manual)                 | PyTorch (automatic)          |
|----------|--------------------------------|------------------------------|
| Forward  | z1 = X @ W1 + b1               | y_pred = model(X)            |
|          | a1 = relu(z1)                  |                              |
|          | z2 = a1 @ W2 + b2              |                              |
|          | a2 = sigmoid(z2)               |                              |
| Loss     | loss = mean((a2 - y)**2)       | loss = criterion(y_pred, y)  |
| Backward | dL_da2 = 2*(a2-y)/m            | loss.backward()              |
|          | dL_dz2 = dL_da2 * σ'(z2)       | (automatic!)                 |
|          | dW2 = a1.T @ dL_dz2            |                              |
|          | dL_da1 = dL_dz2 @ W2.T         |                              |
|          | dL_dz1 = dL_da1 * relu'(z1)    |                              |
|          | dW1 = X.T @ dL_dz1             |                              |
| Update   | W1 -= lr * dW1                 | optimizer.step()             |
|          | W2 -= lr * dW2                 |                              |

What the NumPy implementation teaches:
1. First-hand experience of how the chain rule operates
2. Why the matrix transposes (T) are needed
3. The role of activation-function derivatives
4. The mathematical meaning of batch processing

What PyTorch offers:
1. Concise code (backprop done in 3 lines)
2. No hand-computation errors (automatic differentiation)
3. The same pattern scales to complex models
4. GPU acceleration supported out of the box
"""
print(comparison)
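
# What those three lines expand to, roughly (a sketch for plain SGD only,
# ignoring momentum, weight decay, and other optimizer features):
lr = 1.0
targets = torch.tensor([[0.0], [1.0], [1.0], [0.0]])  # XOR labels again
for p in mlp.parameters():
    if p.grad is not None:
        p.grad.zero_()                     # optimizer.zero_grad()
step_loss = criterion(mlp(X), targets)
step_loss.backward()                       # autograd fills every p.grad
with torch.no_grad():
    for p in mlp.parameters():
        p -= lr * p.grad                   # optimizer.step()
print(f"Manual SGD step applied, loss was {step_loss.item():.6f}")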


# ============================================
# Summary
# ============================================
print("=" * 60)
print("Backpropagation: key takeaways")
print("=" * 60)

summary = """
The 3 lines of PyTorch backprop:
    optimizer.zero_grad()  # reset gradients (required!)
    loss.backward()        # backprop (all gradients computed automatically)
    optimizer.step()       # W = W - lr * grad

Caveats:
1. Without zero_grad(), gradients accumulate
2. backward() frees the graph by default (keep it with retain_graph=True)
3. Disable gradient tracking at inference time with torch.no_grad()

Implementing it in NumPy teaches you:
- How the chain rule is actually applied
- What backward() does internally
- Deeper debugging skills
"""
print(summary)
print("=" * 60)