41. Model Saving and Deployment

Learning Objectives

  • PyTorch model saving methods
  • ONNX conversion
  • Using TorchScript
  • Inference optimization

1. PyTorch Model Saving

# Save
torch.save(model.state_dict(), 'model_weights.pth')

# Load
model = MyModel()  # must have the same architecture as the saved model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
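To confirm that a state_dict round trip restores the exact weights, a small self-check can help. The sketch below is only illustrative: `TinyNet` is a hypothetical stand-in for `MyModel`, the file lives in a temporary directory, and `map_location="cpu"` is an assumed choice so the file loads on machines without a GPU.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MyModel; any nn.Module behaves the same way."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
original = TinyNet().eval()
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pth")
    torch.save(original.state_dict(), path)

    # A fresh instance starts with different random weights.
    restored = TinyNet()
    # map_location="cpu" lets a file saved on a GPU machine load anywhere.
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)
    restored.eval()

# Identical weights must give identical outputs.
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))
print("outputs match:", same_output)
```

The saved file holds only a mapping from parameter names to tensors, which is why the class definition has to be available again at load time.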

Saving Full Model

# Save
torch.save(model, 'model_full.pth')

# Load (the model class must be importable; PyTorch 2.6+ also needs weights_only=False)
model = torch.load('model_full.pth', weights_only=False)
model.eval()

Saving Checkpoint

# Save
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'best_acc': best_acc
}
torch.save(checkpoint, 'checkpoint.pth')

# Load
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
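A checkpoint is typically used to resume training where it stopped. The following is a minimal sketch of that flow, assuming a toy `nn.Linear` model, an SGD optimizer, and a hypothetical `train_one_epoch` helper that stands in for a real training loop.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)


def train_one_epoch(model, optimizer):
    """One dummy optimization step on random data (placeholder for a real loop)."""
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")

    # First run: train epochs 0 and 1, then write a checkpoint.
    for epoch in range(2):
        loss = train_one_epoch(model, optimizer)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, path)

    # Second run: rebuild the objects, restore their state, and continue.
    resumed_model = nn.Linear(4, 2)
    resumed_optimizer = torch.optim.SGD(
        resumed_model.parameters(), lr=0.1, momentum=0.9
    )
    checkpoint = torch.load(path, map_location="cpu")
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1  # the saved epoch is already finished

weights_restored = torch.equal(model.weight, resumed_model.weight)

resumed_model.train()  # back to training mode before continuing
for epoch in range(start_epoch, 4):
    train_one_epoch(resumed_model, resumed_optimizer)
```

Restoring the optimizer state matters for optimizers with internal buffers (momentum here); without it, training restarts with empty buffers even though the weights are correct.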

2. TorchScript

Concept

TorchScript lets a model run without a Python dependency:
- Loadable from C++
- Mobile deployment
- Server optimization

Tracing

# Trace with an example input
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Save
traced_model.save('model_traced.pt')

# Load
loaded_model = torch.jit.load('model_traced.pt')
output = loaded_model(example_input)

Scripting

# Model with control flow
class MyModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x

model = MyModel()
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

Comparison

Method | Advantages                    | Disadvantages
Trace  | Simple, works for most cases  | Cannot handle dynamic control flow
Script | Supports dynamic control flow | Some Python features restricted
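The difference in the comparison above can be observed directly. In this sketch, `Gate` is a hypothetical module with the same data-dependent branch as the scripting example; the warning filter is only there to silence the TracerWarning that tracing emits for the branch.

```python
import warnings

import torch
import torch.nn as nn


class Gate(nn.Module):
    """Doubles the input only when its sum is positive."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


model = Gate().eval()
positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing records only the branch taken for the example input (here: x * 2).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the branch
    traced = torch.jit.trace(model, positive)

# Scripting compiles the source code, so the `if` is preserved.
scripted = torch.jit.script(model)

print("eager   :", model(negative))     # input returned unchanged
print("traced  :", traced(negative))    # doubled, because the branch was frozen
print("scripted:", scripted(negative))  # input returned unchanged
```

The traced module silently gives a different result from the eager model on the negative input, which is why models with data-dependent branches are usually scripted instead.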

3. ONNX Conversion

Conversion

import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    opset_version=11
)

ONNX Runtime Inference

import onnxruntime as ort
import numpy as np

# Create the session
session = ort.InferenceSession("model.onnx")

# Run inference
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = session.run([output_name], {input_name: input_data})

Validation

import onnx

# Load and validate the model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed")

4. Inference Optimization

eval Mode

model.eval()  # Disables Dropout; BatchNorm switches to its running statistics

no_grad

with torch.no_grad():
    output = model(input)

Inference Mode (PyTorch 1.9+)

with torch.inference_mode():
    output = model(input)
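The effect of the two context managers can be checked on the output tensor itself. This is a small sketch using a toy `nn.Linear` model as a stand-in for a real network.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
x = torch.randn(1, 4)

out_default = model(x)          # autograd graph is recorded
with torch.no_grad():
    out_no_grad = model(x)      # no graph is recorded
with torch.inference_mode():
    out_inference = model(x)    # no graph, and the result is an inference tensor

# Only the default call produces a tensor that tracks gradients.
print(out_default.requires_grad, out_no_grad.requires_grad, out_inference.requires_grad)
# Only inference_mode marks its output as an inference tensor.
print(out_no_grad.is_inference(), out_inference.is_inference())
```

Inference tensors cannot later take part in autograd, so `torch.no_grad()` remains the safer choice when an output might be reused in a computation that needs gradients.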

Quantization

# Dynamic quantization (simple)
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Static quantization (more optimized)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)
# Run calibration data through model_prepared here
model_quantized = torch.quantization.convert(model_prepared)
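To get a feel for what dynamic quantization buys, one can compare the serialized size of a model before and after. The sketch below assumes a small hypothetical MLP and a CPU build of PyTorch with a quantized backend available; `serialized_size` is an illustrative helper, not a library function.

```python
import io

import torch
import torch.nn as nn


def serialized_size(module):
    """Size in bytes of the module's state_dict when written with torch.save."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

# nn.Linear weights become int8; activations are quantized on the fly.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(4, 256)
with torch.no_grad():
    fp32_out = model(x)
    int8_out = quantized_model(x)

fp32_bytes = serialized_size(model)
int8_bytes = serialized_size(quantized_model)
print(f"fp32: {fp32_bytes} bytes, int8: {int8_bytes} bytes")
print("max abs difference:", (fp32_out - int8_out).abs().max().item())
```

The printed difference gives a rough idea of the accuracy cost; a real evaluation should measure task accuracy on a validation set.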

5. Deployment Options

Flask API

from flask import Flask, request, jsonify
import torch

app = Flask(__name__)
model = torch.load('model.pth', weights_only=False)  # full pickled model; PyTorch 2.6+ defaults to weights_only=True
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json['data']
    tensor = torch.tensor(data).float()

    with torch.no_grad():
        output = model(tensor)
        pred = output.argmax(dim=1).tolist()

    return jsonify({'prediction': pred})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

FastAPI

from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()
model = torch.jit.load('model_traced.pt')
model.eval()

class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    tensor = torch.tensor(input_data.data).float()

    with torch.inference_mode():
        output = model(tensor)
        pred = output.argmax(dim=1).tolist()

    return {"prediction": pred}

Docker

FROM python:3.10-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt

COPY model_traced.pt .
COPY app.py .

EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

6. Mobile Deployment

PyTorch Mobile

from torch.utils.mobile_optimizer import optimize_for_mobile

# Optimize for mobile
traced_model = torch.jit.trace(model, example_input)
optimized_model = optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter("model_mobile.ptl")

Android/iOS

// Android (Kotlin)
val module = LiteModuleLoader.load(assetFilePath(this, "model_mobile.ptl"))
val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1, 3, 224, 224))
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

7. Cloud Deployment

AWS SageMaker

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data='s3://bucket/model.tar.gz',
    role=role,
    framework_version='2.0',
    py_version='py310',
    entry_point='inference.py'
)

predictor = model.deploy(
    instance_type='ml.m5.large',
    initial_instance_count=1
)
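The `entry_point='inference.py'` script referenced above is where the serving logic lives. A minimal sketch of such a file follows, using the handler names the SageMaker PyTorch serving container looks for (`model_fn`, `input_fn`, `predict_fn`, `output_fn`). The file name `model_traced.pt` inside `model.tar.gz` and the JSON payload shape `{"data": [...]}` are assumptions carried over from the earlier examples.

```python
import json
import os

import torch


def model_fn(model_dir):
    """Load the model once when the container starts."""
    # "model_traced.pt" is an assumed file name inside model.tar.gz.
    path = os.path.join(model_dir, "model_traced.pt")
    model = torch.jit.load(path, map_location="cpu")
    model.eval()
    return model


def input_fn(request_body, request_content_type):
    """Turn the raw request body into a float tensor."""
    if request_content_type != "application/json":
        raise ValueError(f"unsupported content type: {request_content_type}")
    payload = json.loads(request_body)
    return torch.tensor(payload["data"]).float()


def predict_fn(input_data, model):
    """Run the forward pass and return the predicted class per row."""
    with torch.no_grad():
        output = model(input_data)
    return output.argmax(dim=1)


def output_fn(prediction, accept):
    """Serialize the prediction for the response body."""
    return json.dumps({"prediction": prediction.tolist()})
```

Splitting the logic this way keeps model loading out of the per-request path, so only the last three functions run for each request.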

Hugging Face Hub

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="model.pt",
    path_in_repo="pytorch_model.bin",
    repo_id="username/model-name",
    repo_type="model"
)

8. Best Practices

Pre-Save Checklist

# 1. Switch to eval mode
model.eval()

# 2. GPU → CPU (for portability)
model.cpu()

# 3. Validate
with torch.no_grad():
    test_output = model(test_input.cpu())
    assert test_output.shape == expected_shape

Version Control

from datetime import datetime

save_dict = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 784,
        'hidden_size': 256,
        'num_classes': 10
    },
    'pytorch_version': str(torch.__version__),  # plain str keeps the file loadable with weights_only=True
    'training_date': datetime.now().isoformat()
}
torch.save(save_dict, 'model_v1.0.pth')
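Storing the configuration next to the weights pays off at load time, because the model can be rebuilt without hard-coded sizes. The sketch below assumes a hypothetical `MLP` class whose constructor arguments match the keys of `model_config`, and uses an in-memory buffer in place of `model_v1.0.pth`.

```python
import io

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical model whose constructor arguments mirror 'model_config'."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        return self.net(x)


config = {"input_size": 784, "hidden_size": 256, "num_classes": 10}
buffer = io.BytesIO()  # in-memory stand-in for 'model_v1.0.pth'
torch.save({
    "model_state_dict": MLP(**config).state_dict(),
    "model_config": config,
    "pytorch_version": str(torch.__version__),
}, buffer)

# Loading side: every size comes from the file itself.
buffer.seek(0)
saved = torch.load(buffer, map_location="cpu")
model = MLP(**saved["model_config"])
model.load_state_dict(saved["model_state_dict"])
model.eval()

if saved["pytorch_version"] != str(torch.__version__):
    print("warning: saved with PyTorch", saved["pytorch_version"])
print(model(torch.randn(1, 784)).shape)
```

The version check only prints a warning here; whether a mismatch should block loading is a project-level decision.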

Summary

Choosing Save Method

Use Case             | Method
Resume training      | Checkpoint (state_dict + optimizer)
Python deployment    | state_dict
C++ deployment       | TorchScript
Universal deployment | ONNX
Mobile               | PyTorch Mobile

Core Code

# Save
torch.save(model.state_dict(), 'model.pth')

# TorchScript
traced = torch.jit.trace(model.eval(), example_input)
traced.save('model.pt')

# ONNX
torch.onnx.export(model, example_input, 'model.onnx')

Next Steps

Proceed with practical projects in 39_Practical_Image_Classification.md.