# 12_model_save_deploy.py
  1"""
  212. ๋ชจ๋ธ ์ €์žฅ ๋ฐ ๋ฐฐํฌ
  3
  4PyTorch ๋ชจ๋ธ ์ €์žฅ, TorchScript, ONNX ๋ณ€ํ™˜์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10import os
 11import tempfile
 12
 13print("=" * 60)
 14print("PyTorch ๋ชจ๋ธ ์ €์žฅ ๋ฐ ๋ฐฐํฌ")
 15print("=" * 60)
 16
 17
 18# ============================================
 19# 1. ์ƒ˜ํ”Œ ๋ชจ๋ธ
 20# ============================================
 21print("\n[1] ์ƒ˜ํ”Œ ๋ชจ๋ธ")
 22print("-" * 40)
 23
 24class SimpleClassifier(nn.Module):
 25    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
 26        super().__init__()
 27        self.config = {
 28            'input_size': input_size,
 29            'hidden_size': hidden_size,
 30            'num_classes': num_classes
 31        }
 32        self.fc1 = nn.Linear(input_size, hidden_size)
 33        self.bn1 = nn.BatchNorm1d(hidden_size)
 34        self.fc2 = nn.Linear(hidden_size, num_classes)
 35
 36    def forward(self, x):
 37        x = x.view(x.size(0), -1)
 38        x = F.relu(self.bn1(self.fc1(x)))
 39        x = self.fc2(x)
 40        return x
 41
 42model = SimpleClassifier()
 43print(f"๋ชจ๋ธ ๊ตฌ์กฐ:\n{model}")
 44print(f"ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {sum(p.numel() for p in model.parameters()):,}")
 45
 46
 47# ============================================
 48# 2. state_dict ์ €์žฅ
 49# ============================================
 50print("\n[2] state_dict ์ €์žฅ")
 51print("-" * 40)
 52
 53# ์ž„์‹œ ๋””๋ ‰ํ† ๋ฆฌ ์‚ฌ์šฉ
 54save_dir = tempfile.mkdtemp()
 55
 56# ์ €์žฅ
 57weights_path = os.path.join(save_dir, 'model_weights.pth')
 58torch.save(model.state_dict(), weights_path)
 59print(f"์ €์žฅ: {weights_path}")
 60print(f"ํŒŒ์ผ ํฌ๊ธฐ: {os.path.getsize(weights_path) / 1024:.2f} KB")
 61
 62# ๋กœ๋“œ
 63loaded_model = SimpleClassifier()
 64loaded_model.load_state_dict(torch.load(weights_path, weights_only=True))
 65loaded_model.eval()
 66
 67# ๊ฒ€์ฆ
 68x = torch.randn(2, 1, 28, 28)
 69model.eval()
 70with torch.no_grad():
 71    original_out = model(x)
 72    loaded_out = loaded_model(x)
 73    diff = (original_out - loaded_out).abs().max().item()
 74    print(f"์ถœ๋ ฅ ์ฐจ์ด: {diff:.10f}")
 75
 76
 77# ============================================
 78# 3. ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ
 79# ============================================
 80print("\n[3] ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ")
 81print("-" * 40)
 82
 83optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 84
 85# ๊ฐ€์งœ ํ•™์Šต ์ƒํƒœ
 86epoch = 10
 87loss = 0.123
 88best_acc = 0.95
 89
 90# ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ
 91checkpoint = {
 92    'epoch': epoch,
 93    'model_state_dict': model.state_dict(),
 94    'optimizer_state_dict': optimizer.state_dict(),
 95    'loss': loss,
 96    'best_acc': best_acc,
 97    'model_config': model.config
 98}
 99
100checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
101torch.save(checkpoint, checkpoint_path)
102print(f"์ฒดํฌํฌ์ธํŠธ ์ €์žฅ: {checkpoint_path}")
103
104# ์ฒดํฌํฌ์ธํŠธ ๋กœ๋“œ
105loaded_checkpoint = torch.load(checkpoint_path, weights_only=False)
106print(f"๋กœ๋“œ๋œ epoch: {loaded_checkpoint['epoch']}")
107print(f"๋กœ๋“œ๋œ best_acc: {loaded_checkpoint['best_acc']}")
108print(f"๋ชจ๋ธ ์„ค์ •: {loaded_checkpoint['model_config']}")
109
# ============================================
# 4. TorchScript - Tracing
# ============================================
print("\n[4] TorchScript - Tracing")
print("-" * 40)

model.eval()
example_input = torch.randn(1, 1, 28, 28)

# Tracing records the operations executed for this example input.
traced_model = torch.jit.trace(model, example_input)

traced_path = os.path.join(save_dir, 'model_traced.pt')
traced_model.save(traced_path)
print(f"TorchScript 저장: {traced_path}")
print(f"파일 크기: {os.path.getsize(traced_path) / 1024:.2f} KB")

# Load back and verify against the eager model.
loaded_traced = torch.jit.load(traced_path)
with torch.no_grad():
    traced_out = loaded_traced(example_input)
    original_out = model(example_input)
    diff = (traced_out - original_out).abs().max().item()
    print(f"출력 차이: {diff:.10f}")
# ============================================
# 5. TorchScript - Scripting
# ============================================
print("\n[5] TorchScript - Scripting")
print("-" * 40)
class ConditionalModel(nn.Module):
    """Model with data-independent control flow (an ``if`` on a bool flag).

    Tracing would bake in a single branch; ``torch.jit.script`` preserves
    the conditional, which is why this model is used for the scripting demo.
    The ``use_relu: bool`` annotation is required for scripting.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x, use_relu: bool = True):
        x = self.fc(x)
        if use_relu:
            x = F.relu(x)
        return x
cond_model = ConditionalModel()
scripted_model = torch.jit.script(cond_model)

scripted_path = os.path.join(save_dir, 'model_scripted.pt')
scripted_model.save(scripted_path)
print(f"Scripted 모델 저장: {scripted_path}")

# Exercise both branches of the scripted conditional.
x = torch.randn(2, 10)
out_relu = scripted_model(x, True)
out_no_relu = scripted_model(x, False)
print(f"ReLU 적용: min={out_relu.min():.4f}")
print(f"ReLU 미적용: min={out_no_relu.min():.4f}")
# ============================================
# 6. ONNX conversion
# ============================================
print("\n[6] ONNX 변환")
print("-" * 40)

try:
    import onnx

    model.eval()
    dummy_input = torch.randn(1, 1, 28, 28)

    onnx_path = os.path.join(save_dir, 'model.onnx')

    # dynamic_axes lets the exported graph accept any batch size.
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        },
        opset_version=11
    )

    print(f"ONNX 저장: {onnx_path}")
    print(f"파일 크기: {os.path.getsize(onnx_path) / 1024:.2f} KB")

    # Structural validation of the exported graph.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX 모델 검증 통과")

except ImportError:
    print("onnx 미설치 - 스킵")
# ============================================
# 7. ONNX Runtime inference
# ============================================
print("\n[7] ONNX Runtime 추론")
print("-" * 40)

try:
    import onnxruntime as ort
    import numpy as np

    # NameError is handled too: onnx_path is undefined when section 6
    # was skipped because the onnx package is missing.
    session = ort.InferenceSession(onnx_path)

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Inference with batch size 2 (allowed by the dynamic batch axis).
    input_data = np.random.randn(2, 1, 28, 28).astype(np.float32)
    result = session.run([output_name], {input_name: input_data})

    print(f"ONNX Runtime 출력: {result[0].shape}")

    # Compare against the PyTorch result on the same input.
    model.eval()
    with torch.no_grad():
        torch_out = model(torch.from_numpy(input_data))
        diff = np.abs(result[0] - torch_out.numpy()).max()
        print(f"PyTorch vs ONNX 차이: {diff:.6f}")

except (ImportError, NameError):
    print("onnxruntime 미설치 - 스킵")
# ============================================
# 8. Quantization
# ============================================
print("\n[8] 양자화 (Quantization)")
print("-" * 40)

import io
import time

# Dynamic quantization: Linear weights are stored as int8 and
# activations are quantized on the fly at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

def _state_dict_kb(m):
    """Return the serialized state_dict size of *m* in KB."""
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1024

# Compare serialized sizes. Quantized Linear layers hold packed int8
# weights that .parameters() does not expose, so measuring the saved
# state_dict is the reliable way to compare model sizes.
print(f"원본 모델 크기: {_state_dict_kb(model):.2f} KB")
print(f"양자화 모델 크기: {_state_dict_kb(quantized_model):.2f} KB")

# Inference speed comparison.
x = torch.randn(100, 1, 28, 28)

model.eval()
quantized_model.eval()

start = time.time()
for _ in range(10):
    with torch.no_grad():
        _ = model(x)
original_time = time.time() - start

start = time.time()
for _ in range(10):
    with torch.no_grad():
        _ = quantized_model(x)
quantized_time = time.time() - start

print(f"원본 추론 시간: {original_time*1000:.2f} ms")
print(f"양자화 추론 시간: {quantized_time*1000:.2f} ms")
# ============================================
# 9. Inference optimization
# ============================================
print("\n[9] 추론 최적화")
print("-" * 40)

model.eval()
x = torch.randn(100, 1, 28, 28)

# Baseline: no_grad disables gradient tracking.
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = model(x)
no_grad_time = time.time() - start

# inference_mode additionally skips view/version-counter bookkeeping,
# so it is usually a bit faster than no_grad.
start = time.time()
for _ in range(100):
    with torch.inference_mode():
        _ = model(x)
inference_time = time.time() - start

print(f"no_grad 시간: {no_grad_time*1000:.2f} ms")
print(f"inference_mode 시간: {inference_time*1000:.2f} ms")
print(f"개선: {(no_grad_time - inference_time) / no_grad_time * 100:.1f}%")
# ============================================
# 10. Mobile optimization
# ============================================
print("\n[10] 모바일 최적화")
print("-" * 40)

try:
    # Fuse/fold ops for the PyTorch Mobile (lite) interpreter.
    traced_model = torch.jit.trace(model.eval(), example_input)
    optimized_model = torch.utils.mobile_optimizer.optimize_for_mobile(traced_model)

    mobile_path = os.path.join(save_dir, 'model_mobile.ptl')
    optimized_model._save_for_lite_interpreter(mobile_path)

    print(f"모바일 모델 저장: {mobile_path}")
    print(f"파일 크기: {os.path.getsize(mobile_path) / 1024:.2f} KB")
except Exception as e:
    # Mobile tooling availability varies across torch builds - best effort.
    print(f"모바일 최적화 스킵: {e}")
# ============================================
# 11. Saved file listing
# ============================================
print("\n[11] 저장된 파일 목록")
print("-" * 40)

print(f"저장 디렉토리: {save_dir}")
for f in os.listdir(save_dir):
    path = os.path.join(save_dir, f)
    size = os.path.getsize(path) / 1024
    print(f"  {f}: {size:.2f} KB")
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("모델 저장 및 배포 정리")
print("=" * 60)

summary = """
저장 방법:

1. state_dict (권장)
   torch.save(model.state_dict(), 'model.pth')
   model.load_state_dict(torch.load('model.pth'))

2. 체크포인트
   checkpoint = {'model': model.state_dict(), 'optimizer': ...}
   torch.save(checkpoint, 'checkpoint.pth')

3. TorchScript
   traced = torch.jit.trace(model, example_input)
   traced.save('model.pt')

4. ONNX
   torch.onnx.export(model, input, 'model.onnx')

추론 최적화:
   - model.eval()
   - torch.inference_mode()
   - 양자화 (quantize_dynamic)

배포 옵션:
   - FastAPI/Flask: 웹 API
   - ONNX Runtime: 범용 추론
   - TorchScript: C++ 배포
   - PyTorch Mobile: 모바일 앱
"""
print(summary)
print("=" * 60)

# Note: directories from tempfile.mkdtemp() are NOT removed automatically.
print(f"\n임시 파일 위치: {save_dir}")
print("(자동 삭제되지 않음 - 필요시 수동 삭제)")