1"""
202. ์ ๊ฒฝ๋ง ๊ธฐ์ด - PyTorch ๋ฒ์
3
4nn.Module์ ์ฌ์ฉํ MLP ๊ตฌํ๊ณผ XOR ๋ฌธ์ ํด๊ฒฐ.
5NumPy ๋ฒ์ (examples/numpy/02_neural_network_scratch.py)๊ณผ ๋น๊ตํด ๋ณด์ธ์.
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11import numpy as np
12import matplotlib.pyplot as plt
13
14print("=" * 60)
15print("PyTorch ์ ๊ฒฝ๋ง ๊ธฐ์ด")
16print("=" * 60)
17
18
19# ============================================
20# 1. ํ์ฑํ ํจ์
21# ============================================
22print("\n[1] ํ์ฑํ ํจ์")
23print("-" * 40)
24
25x = torch.linspace(-5, 5, 100)
26
27# ํ์ฑํ ํจ์ ์ ์ฉ
28sigmoid_out = torch.sigmoid(x)
29tanh_out = torch.tanh(x)
30relu_out = F.relu(x)
31leaky_relu_out = F.leaky_relu(x, 0.1)
32
33# ์๊ฐํ
34fig, axes = plt.subplots(2, 2, figsize=(12, 8))
35
36axes[0, 0].plot(x.numpy(), sigmoid_out.numpy())
37axes[0, 0].set_title('Sigmoid')
38axes[0, 0].grid(True, alpha=0.3)
39axes[0, 0].axhline(y=0, color='k', linewidth=0.5)
40axes[0, 0].axvline(x=0, color='k', linewidth=0.5)
41
42axes[0, 1].plot(x.numpy(), tanh_out.numpy())
43axes[0, 1].set_title('Tanh')
44axes[0, 1].grid(True, alpha=0.3)
45axes[0, 1].axhline(y=0, color='k', linewidth=0.5)
46axes[0, 1].axvline(x=0, color='k', linewidth=0.5)
47
48axes[1, 0].plot(x.numpy(), relu_out.numpy())
49axes[1, 0].set_title('ReLU')
50axes[1, 0].grid(True, alpha=0.3)
51axes[1, 0].axhline(y=0, color='k', linewidth=0.5)
52axes[1, 0].axvline(x=0, color='k', linewidth=0.5)
53
54axes[1, 1].plot(x.numpy(), leaky_relu_out.numpy())
55axes[1, 1].set_title('Leaky ReLU (ฮฑ=0.1)')
56axes[1, 1].grid(True, alpha=0.3)
57axes[1, 1].axhline(y=0, color='k', linewidth=0.5)
58axes[1, 1].axvline(x=0, color='k', linewidth=0.5)
59
60plt.tight_layout()
61plt.savefig('activation_functions.png', dpi=100)
62plt.close()
63print("ํ์ฑํ ํจ์ ๊ทธ๋ํ ์ ์ฅ: activation_functions.png")
64
65
66# ============================================
67# 2. nn.Module๋ก MLP ์ ์
68# ============================================
69print("\n[2] nn.Module MLP")
70print("-" * 40)
71
72class MLP(nn.Module):
73 def __init__(self, input_dim, hidden_dim, output_dim):
74 super().__init__()
75 self.fc1 = nn.Linear(input_dim, hidden_dim)
76 self.fc2 = nn.Linear(hidden_dim, output_dim)
77
78 def forward(self, x):
79 x = F.relu(self.fc1(x))
80 x = self.fc2(x)
81 return x
82
model = MLP(input_dim=10, hidden_dim=32, output_dim=3)
print(model)

# Count parameters by accumulating over every tensor the module owns.
total_params = 0
for p in model.parameters():
    total_params += p.numel()
print(f"Total parameters: {total_params}")

# Per-tensor breakdown (weights and biases of each Linear layer).
for name, param in model.named_parameters():
    print(f" {name}: {param.shape}")
92
93
94# ============================================
95# 3. nn.Sequential๋ก ๊ฐ๋จํ ์ ์
96# ============================================
97print("\n[3] nn.Sequential")
98print("-" * 40)
99
100model_seq = nn.Sequential(
101 nn.Linear(10, 32),
102 nn.ReLU(),
103 nn.Linear(32, 16),
104 nn.ReLU(),
105 nn.Linear(16, 3)
106)
107print(model_seq)
108
109
110# ============================================
111# 4. XOR ๋ฌธ์ ํด๊ฒฐ
112# ============================================
113print("\n[4] XOR ๋ฌธ์ ํด๊ฒฐ")
114print("-" * 40)
115
116# ๋ฐ์ดํฐ
117X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
118y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
119
120print("XOR ๋ฐ์ดํฐ:")
121print(" (0,0) โ 0")
122print(" (0,1) โ 1")
123print(" (1,0) โ 1")
124print(" (1,1) โ 0")
125
126# ๋ชจ๋ธ ์ ์
class XORNet(nn.Module):
    """Tiny MLP for XOR: 2 -> 8 (ReLU) -> 1 (sigmoid probability)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        # Sigmoid squashes the logit into (0, 1), matching BCELoss.
        return torch.sigmoid(self.fc2(hidden))
137
xor_model = XORNet()

# Binary cross-entropy on sigmoid outputs, optimized with Adam.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(xor_model.parameters(), lr=0.1)

# Full-batch training: the dataset is only 4 points, so no mini-batching.
losses = []
for epoch in range(1000):
    optimizer.zero_grad()       # clear gradients left from the previous step
    pred = xor_model(X)         # forward pass over all 4 XOR samples
    loss = criterion(pred, y)
    loss.backward()             # backpropagate
    optimizer.step()            # update weights

    losses.append(loss.item())
    if (epoch + 1) % 200 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.6f}")
160
# Inspect the trained predictions against the XOR targets.
print("\nํ์ต ๊ฒฐ๊ณผ:")
xor_model.eval()
with torch.no_grad():
    predictions = xor_model(X)
    for sample, target, prob in zip(X, y, predictions):
        print(f" {sample.numpy()} โ {prob.item():.4f} (์ ๋ต: {target.item()})")

# Training-loss curve.
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('XOR Training Loss')
plt.grid(True, alpha=0.3)
plt.savefig('xor_loss.png', dpi=100)
plt.close()
print("์์ค ๊ทธ๋ํ ์ ์ฅ: xor_loss.png")
179
180
181# ============================================
182# 5. ๊ฐ์ค์น ์ด๊ธฐํ
183# ============================================
184print("\n[5] ๊ฐ์ค์น ์ด๊ธฐํ")
185print("-" * 40)
186
187def init_weights(m):
188 if isinstance(m, nn.Linear):
189 nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
190 nn.init.zeros_(m.bias)
191 print(f" Initialized: {m}")
192
model_init = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 10)
)

def _fc1_weight_mean():
    # Mean of the first Linear layer's weights; shows apply() re-drew them.
    return model_init[0].weight.mean().item()

print("๊ฐ์ค์น ์ด๊ธฐํ ์ :")
print(f" fc1 weight mean: {_fc1_weight_mean():.6f}")

print("\n์ด๊ธฐํ ์ ์ฉ:")
model_init.apply(init_weights)  # recursively visits every submodule

print("\n๊ฐ์ค์น ์ด๊ธฐํ ํ:")
print(f" fc1 weight mean: {_fc1_weight_mean():.6f}")
207
208
209# ============================================
210# 6. ์์ ํ ๋จ๊ณ๋ณ ํ์ธ
211# ============================================
212print("\n[6] ์์ ํ ๋จ๊ณ๋ณ ํ์ธ")
213print("-" * 40)
214
215class VerboseMLP(nn.Module):
216 def __init__(self):
217 super().__init__()
218 self.fc1 = nn.Linear(3, 4)
219 self.fc2 = nn.Linear(4, 2)
220
221 def forward(self, x):
222 print(f" ์
๋ ฅ: {x.shape}")
223
224 z1 = self.fc1(x)
225 print(f" fc1 ํ: {z1.shape}")
226
227 a1 = F.relu(z1)
228 print(f" ReLU ํ: {a1.shape}")
229
230 z2 = self.fc2(a1)
231 print(f" fc2 ํ (์ถ๋ ฅ): {z2.shape}")
232
233 return z2
234
verbose_model = VerboseMLP()
sample_input = torch.randn(2, 3)  # batch of 2 samples, each with 3 input features
print("์์ ํ ๊ณผ์ :")
output = verbose_model(sample_input)
239
240
241# ============================================
242# 7. ๋ชจ๋ธ ์ ์ฅ ๋ฐ ๋ก๋
243# ============================================
244print("\n[7] ๋ชจ๋ธ ์ ์ฅ/๋ก๋")
245print("-" * 40)
246
247# ์ ์ฅ
248torch.save(xor_model.state_dict(), 'xor_model.pth')
249print("๋ชจ๋ธ ์ ์ฅ: xor_model.pth")
250
251# ์ ๋ชจ๋ธ์ ๋ก๋
252new_model = XORNet()
253new_model.load_state_dict(torch.load('xor_model.pth', weights_only=True))
254new_model.eval()
255print("๋ชจ๋ธ ๋ก๋ ์๋ฃ")
256
257# ๊ฒ์ฆ
258with torch.no_grad():
259 new_pred = new_model(X)
260 print("๋ก๋๋ ๋ชจ๋ธ ์์ธก:")
261 for i in range(4):
262 print(f" {X[i].numpy()} โ {new_pred[i].item():.4f}")
263
264
265print("\n" + "=" * 60)
266print("PyTorch ์ ๊ฒฝ๋ง ๊ธฐ์ด ์๋ฃ!")
267print("NumPy ๋ฒ์ ๊ณผ ๋น๊ต: examples/numpy/02_neural_network_scratch.py")
268print("=" * 60)