1"""
212. ๋ชจ๋ธ ์ ์ฅ ๋ฐ ๋ฐฐํฌ
3
4PyTorch ๋ชจ๋ธ ์ ์ฅ, TorchScript, ONNX ๋ณํ์ ๊ตฌํํฉ๋๋ค.
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10import os
11import tempfile
12
# Banner for the whole demo script.
print("=" * 60, "PyTorch 모델 저장 및 배포", "=" * 60, sep="\n")


# ============================================
# 1. Sample model
# ============================================
print("\n[1] 샘플 모델", "-" * 40, sep="\n")
23
class SimpleClassifier(nn.Module):
    """Minimal MLP classifier: flatten -> linear -> batchnorm -> ReLU -> linear.

    The constructor arguments are recorded in ``self.config`` so the exact
    architecture can be rebuilt later from a saved checkpoint.
    """

    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super().__init__()
        # Plain dict on purpose: it pickles cleanly inside checkpoints.
        self.config = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'num_classes': num_classes,
        }
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Collapse everything after the batch dimension (e.g. 1x28x28 -> 784).
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.bn1(self.fc1(flat)))
        return self.fc2(hidden)
41
# Instantiate with defaults (784 -> 256 -> 10) and report its footprint.
model = SimpleClassifier()
n_params = sum(p.numel() for p in model.parameters())
print(f"모델 구조:\n{model}")
print(f"파라미터 수: {n_params:,}")
45
46
# ============================================
# 2. Saving / loading a state_dict
# ============================================
print("\n[2] state_dict 저장")
print("-" * 40)

# Every artifact in this script goes into one throwaway temp directory.
save_dir = tempfile.mkdtemp()

# Persist only the learnable parameters (the recommended approach).
weights_path = os.path.join(save_dir, 'model_weights.pth')
torch.save(model.state_dict(), weights_path)
print(f"저장: {weights_path}")
print(f"파일 크기: {os.path.getsize(weights_path) / 1024:.2f} KB")

# Rebuild an identical architecture, then load the weights back in.
loaded_model = SimpleClassifier()
state = torch.load(weights_path, weights_only=True)
loaded_model.load_state_dict(state)
loaded_model.eval()

# Verify: both models must agree on the same random batch.
x = torch.randn(2, 1, 28, 28)
model.eval()
with torch.no_grad():
    original_out = model(x)
    loaded_out = loaded_model(x)
    diff = (original_out - loaded_out).abs().max().item()
    print(f"출력 차이: {diff:.10f}")
75
76
# ============================================
# 3. Full training checkpoints
# ============================================
print("\n[3] 체크포인트 저장")
print("-" * 40)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Pretend we are mid-training with these stats.
epoch = 10
loss = 0.123
best_acc = 0.95

# Bundle everything needed to resume training into one dict.
checkpoint = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    loss=loss,
    best_acc=best_acc,
    model_config=model.config,
)

checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
torch.save(checkpoint, checkpoint_path)
print(f"체크포인트 저장: {checkpoint_path}")

# NOTE: weights_only=False unpickles arbitrary objects -- only load
# checkpoints that come from a trusted source.
loaded_checkpoint = torch.load(checkpoint_path, weights_only=False)
print(f"로드된 epoch: {loaded_checkpoint['epoch']}")
print(f"로드된 best_acc: {loaded_checkpoint['best_acc']}")
print(f"모델 설정: {loaded_checkpoint['model_config']}")
109
110
# ============================================
# 4. TorchScript via tracing
# ============================================
print("\n[4] TorchScript - Tracing")
print("-" * 40)

model.eval()
example_input = torch.randn(1, 1, 28, 28)

# Record the ops executed for one example input into a static graph.
traced_model = torch.jit.trace(model, example_input)

traced_path = os.path.join(save_dir, 'model_traced.pt')
traced_model.save(traced_path)
print(f"TorchScript 저장: {traced_path}")
print(f"파일 크기: {os.path.getsize(traced_path) / 1024:.2f} KB")

# Round-trip: reload the archive and compare against the eager model.
loaded_traced = torch.jit.load(traced_path)
with torch.no_grad():
    original_out = model(example_input)
    traced_out = loaded_traced(example_input)
    diff = (traced_out - original_out).abs().max().item()
    print(f"출력 차이: {diff:.10f}")
136
137
# ============================================
# 5. TorchScript via scripting
# ============================================
print("\n[5] TorchScript - Scripting", "-" * 40, sep="\n")
143
class ConditionalModel(nn.Module):
    """Tiny model with data-independent control flow in forward().

    Tracing would freeze the branch for one input; torch.jit.script
    preserves it, which is why this class demonstrates scripting.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x, use_relu: bool = True):
        out = self.fc(x)
        if use_relu:
            out = F.relu(out)
        return out
155
# Compile the conditional model; script() keeps the if-branch intact.
cond_model = ConditionalModel()
scripted_model = torch.jit.script(cond_model)

scripted_path = os.path.join(save_dir, 'model_scripted.pt')
scripted_model.save(scripted_path)
print(f"Scripted 모델 저장: {scripted_path}")

# Exercise both branches of the conditional.
xb = torch.randn(2, 10)
out_relu = scripted_model(xb, True)
out_no_relu = scripted_model(xb, False)
print(f"ReLU 적용: min={out_relu.min():.4f}")
print(f"ReLU 미적용: min={out_no_relu.min():.4f}")
169
170
171# ============================================
172# 6. ONNX ๋ณํ
173# ============================================
174print("\n[6] ONNX ๋ณํ")
175print("-" * 40)
176
177try:
178 import onnx
179
180 model.eval()
181 dummy_input = torch.randn(1, 1, 28, 28)
182
183 onnx_path = os.path.join(save_dir, 'model.onnx')
184
185 torch.onnx.export(
186 model,
187 dummy_input,
188 onnx_path,
189 input_names=['input'],
190 output_names=['output'],
191 dynamic_axes={
192 'input': {0: 'batch_size'},
193 'output': {0: 'batch_size'}
194 },
195 opset_version=11
196 )
197
198 print(f"ONNX ์ ์ฅ: {onnx_path}")
199 print(f"ํ์ผ ํฌ๊ธฐ: {os.path.getsize(onnx_path) / 1024:.2f} KB")
200
201 # ๊ฒ์ฆ
202 onnx_model = onnx.load(onnx_path)
203 onnx.checker.check_model(onnx_model)
204 print("ONNX ๋ชจ๋ธ ๊ฒ์ฆ ํต๊ณผ")
205
206except ImportError:
207 print("onnx ๋ฏธ์ค์น - ์คํต")
208
209
210# ============================================
211# 7. ONNX Runtime ์ถ๋ก
212# ============================================
213print("\n[7] ONNX Runtime ์ถ๋ก ")
214print("-" * 40)
215
216try:
217 import onnxruntime as ort
218 import numpy as np
219
220 session = ort.InferenceSession(onnx_path)
221
222 input_name = session.get_inputs()[0].name
223 output_name = session.get_outputs()[0].name
224
225 # ์ถ๋ก
226 input_data = np.random.randn(2, 1, 28, 28).astype(np.float32)
227 result = session.run([output_name], {input_name: input_data})
228
229 print(f"ONNX Runtime ์ถ๋ ฅ: {result[0].shape}")
230
231 # PyTorch ๊ฒฐ๊ณผ์ ๋น๊ต
232 model.eval()
233 with torch.no_grad():
234 torch_out = model(torch.from_numpy(input_data))
235 diff = np.abs(result[0] - torch_out.numpy()).max()
236 print(f"PyTorch vs ONNX ์ฐจ์ด: {diff:.6f}")
237
238except ImportError:
239 print("onnxruntime ๋ฏธ์ค์น - ์คํต")
240
241
242# ============================================
243# 8. ์์ํ
244# ============================================
245print("\n[8] ์์ํ (Quantization)")
246print("-" * 40)
247
248# ๋์ ์์ํ
249quantized_model = torch.quantization.quantize_dynamic(
250 model, {nn.Linear}, dtype=torch.qint8
251)
252
253# ํฌ๊ธฐ ๋น๊ต
254original_size = sum(p.numel() * p.element_size() for p in model.parameters())
255quantized_size = sum(
256 p.numel() * p.element_size() for p in quantized_model.parameters()
257 if p.dtype != torch.qint8
258)
259
260print(f"์๋ณธ ๋ชจ๋ธ ํฌ๊ธฐ: {original_size / 1024:.2f} KB")
261print(f"์์ํ ๋ชจ๋ธ (์ผ๋ถ ์ธต): ์ฝ {original_size / 1024 * 0.25:.2f} KB (์ถ์ )")
262
263# ์ถ๋ก ๋น๊ต
264x = torch.randn(100, 1, 28, 28)
265
266model.eval()
267quantized_model.eval()
268
269import time
270
271# ์๋ณธ ๋ชจ๋ธ
272start = time.time()
273for _ in range(10):
274 with torch.no_grad():
275 _ = model(x)
276original_time = time.time() - start
277
278# ์์ํ ๋ชจ๋ธ
279start = time.time()
280for _ in range(10):
281 with torch.no_grad():
282 _ = quantized_model(x)
283quantized_time = time.time() - start
284
285print(f"์๋ณธ ์ถ๋ก ์๊ฐ: {original_time*1000:.2f} ms")
286print(f"์์ํ ์ถ๋ก ์๊ฐ: {quantized_time*1000:.2f} ms")
287
288
289# ============================================
290# 9. ์ถ๋ก ์ต์ ํ
291# ============================================
292print("\n[9] ์ถ๋ก ์ต์ ํ")
293print("-" * 40)
294
295model.eval()
296x = torch.randn(100, 1, 28, 28)
297
298# no_grad
299start = time.time()
300for _ in range(100):
301 with torch.no_grad():
302 _ = model(x)
303no_grad_time = time.time() - start
304
305# inference_mode (๋ ๋น ๋ฆ)
306start = time.time()
307for _ in range(100):
308 with torch.inference_mode():
309 _ = model(x)
310inference_time = time.time() - start
311
312print(f"no_grad ์๊ฐ: {no_grad_time*1000:.2f} ms")
313print(f"inference_mode ์๊ฐ: {inference_time*1000:.2f} ms")
314print(f"๊ฐ์ : {(no_grad_time - inference_time) / no_grad_time * 100:.1f}%")
315
316
317# ============================================
318# 10. ๋ชจ๋ฐ์ผ ์ต์ ํ
319# ============================================
320print("\n[10] ๋ชจ๋ฐ์ผ ์ต์ ํ")
321print("-" * 40)
322
323try:
324 # ๋ชจ๋ฐ์ผ์ฉ ์ต์ ํ
325 traced_model = torch.jit.trace(model.eval(), example_input)
326 optimized_model = torch.utils.mobile_optimizer.optimize_for_mobile(traced_model)
327
328 mobile_path = os.path.join(save_dir, 'model_mobile.ptl')
329 optimized_model._save_for_lite_interpreter(mobile_path)
330
331 print(f"๋ชจ๋ฐ์ผ ๋ชจ๋ธ ์ ์ฅ: {mobile_path}")
332 print(f"ํ์ผ ํฌ๊ธฐ: {os.path.getsize(mobile_path) / 1024:.2f} KB")
333except Exception as e:
334 print(f"๋ชจ๋ฐ์ผ ์ต์ ํ ์คํต: {e}")
335
336
337# ============================================
338# 11. ์ ์ฅ๋ ํ์ผ ๋ชฉ๋ก
339# ============================================
340print("\n[11] ์ ์ฅ๋ ํ์ผ ๋ชฉ๋ก")
341print("-" * 40)
342
343print(f"์ ์ฅ ๋๋ ํ ๋ฆฌ: {save_dir}")
344for f in os.listdir(save_dir):
345 path = os.path.join(save_dir, f)
346 size = os.path.getsize(path) / 1024
347 print(f" {f}: {size:.2f} KB")
348
349
350# ============================================
351# ์ ๋ฆฌ
352# ============================================
353print("\n" + "=" * 60)
354print("๋ชจ๋ธ ์ ์ฅ ๋ฐ ๋ฐฐํฌ ์ ๋ฆฌ")
355print("=" * 60)
356
357summary = """
358์ ์ฅ ๋ฐฉ๋ฒ:
359
3601. state_dict (๊ถ์ฅ)
361 torch.save(model.state_dict(), 'model.pth')
362 model.load_state_dict(torch.load('model.pth'))
363
3642. ์ฒดํฌํฌ์ธํธ
365 checkpoint = {'model': model.state_dict(), 'optimizer': ...}
366 torch.save(checkpoint, 'checkpoint.pth')
367
3683. TorchScript
369 traced = torch.jit.trace(model, example_input)
370 traced.save('model.pt')
371
3724. ONNX
373 torch.onnx.export(model, input, 'model.onnx')
374
375์ถ๋ก ์ต์ ํ:
376 - model.eval()
377 - torch.inference_mode()
378 - ์์ํ (quantize_dynamic)
379
380๋ฐฐํฌ ์ต์
:
381 - FastAPI/Flask: ์น API
382 - ONNX Runtime: ๋ฒ์ฉ ์ถ๋ก
383 - TorchScript: C++ ๋ฐฐํฌ
384 - PyTorch Mobile: ๋ชจ๋ฐ์ผ ์ฑ
385"""
386print(summary)
387print("=" * 60)
388
389# ์์ ํ์ผ ์ ๋ฆฌ ์๋ด
390print(f"\n์์ ํ์ผ ์์น: {save_dir}")
391print("(์๋ ์ญ์ ๋์ง ์์ - ํ์์ ์๋ ์ญ์ )")