41. Model Saving and Deployment
Learning Objectives¶
- PyTorch model saving methods
- ONNX conversion
- Using TorchScript
- Inference optimization
1. PyTorch Model Saving¶
Saving state_dict (Recommended)¶
# Save
torch.save(model.state_dict(), 'model_weights.pth')
# Load
model = MyModel()  # The same architecture is required
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
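As a minimal sketch of the round trip above, the following saves a `state_dict` to a temporary directory, rebuilds the architecture, and checks that the restored model produces the same output. `TinyNet` is a hypothetical stand-in for `MyModel`, and `map_location="cpu"` is an added assumption that makes a file saved on a GPU loadable on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MyModel."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet().eval()
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pth")
    torch.save(model.state_dict(), path)

    # Rebuild the architecture first, then load the weights into it.
    restored = TinyNet()
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)
    restored.eval()

# Identical weights should give identical outputs on the same input.
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)
```

Comparing outputs on a fixed input is a cheap way to catch a mismatched architecture or a stale weights file before deployment.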
Saving Full Model¶
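# Save
torch.save(model, 'model_full.pth')
# Load: the whole object is unpickled, so the model class must be importable
# and the file must come from a trusted source.
# PyTorch 2.6+ defaults to weights_only=True, which rejects full-model pickles.
model = torch.load('model_full.pth', weights_only=False)
model.eval()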
Saving Checkpoint¶
# Save
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'best_acc': best_acc
}
torch.save(checkpoint, 'checkpoint.pth')
# Load
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
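The checkpoint pattern can be wrapped in two small helpers. This is a sketch under stated assumptions: `save_checkpoint` and `load_checkpoint` are hypothetical names, the model is a single `nn.Linear`, and SGD with momentum is used only so that the optimizer has internal state worth restoring.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    """Hypothetical helper: store everything needed to resume training."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore state and return the next epoch to run."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1


# Take one training step so the optimizer holds momentum buffers.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    save_checkpoint(path, model, optimizer, epoch=0)

    # A fresh model and optimizer, as after a process restart.
    new_model = nn.Linear(4, 1)
    new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
    start_epoch = load_checkpoint(path, new_model, new_optimizer)

weights_match = torch.equal(model.weight, new_model.weight)
has_momentum = len(new_optimizer.state) > 0
print(start_epoch, weights_match, has_momentum)
```

Restoring the optimizer state matters for optimizers with momentum or adaptive learning rates; without it, training resumes with the right weights but a reset optimizer.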
2. TorchScript¶
Concept¶
TorchScript compiles a model into a serialized program that runs without a Python interpreter, which enables:
- Loading from C++
- Mobile deployment
- Server-side optimization
Tracing¶
# Trace with an example input
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
# Save
traced_model.save('model_traced.pt')
# Load
loaded_model = torch.jit.load('model_traced.pt')
output = loaded_model(example_input)
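Before shipping a traced file, it can help to confirm that the reloaded module agrees with the original model on an input other than the one used for tracing. The sketch below assumes a small hypothetical CNN and a temporary file path; the batch size of 5 is chosen only to differ from the tracing input.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Small hypothetical CNN used only for this check.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
).eval()

example_input = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, example_input)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_traced.pt")
    traced.save(path)
    # No Python class definition is needed to load the file.
    loaded = torch.jit.load(path, map_location="cpu")

# Compare eager and traced outputs on a different batch size.
batch = torch.randn(5, 3, 32, 32)
with torch.no_grad():
    eager_out = model(batch)
    traced_out = loaded(batch)
max_diff = (eager_out - traced_out).abs().max().item()
print(traced_out.shape, max_diff)
```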
Scripting¶
# Model with control flow
class MyModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x

model = MyModel()
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
Comparison¶
| Method | Advantages | Disadvantages |
|---|---|---|
| Trace | Simple; works when the forward pass has no data-dependent control flow | Records only the path taken by the example input, so dynamic branches and loops are lost |
| Script | Supports dynamic control flow | Some Python features restricted |
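The difference in the table can be observed directly. In this sketch, a hypothetical `Gate` module doubles its input only when the sum is positive. Tracing with a positive input records the doubling branch permanently, while scripting keeps the `if`.

```python
import warnings

import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module with data-dependent control flow."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x


model = Gate().eval()
positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns that the branch condition is being frozen.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, positive)
scripted = torch.jit.script(model)

traced_out = traced(negative)      # the x * 2 branch was baked in
scripted_out = scripted(negative)  # the branch is evaluated at run time
print(traced_out)    # tensor([-2., -2., -2.])
print(scripted_out)  # tensor([-1., -1., -1.])
```

The traced module silently returns a wrong result for the negative input, which is why the tracer warning should not be ignored outside of a demonstration like this one.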
3. ONNX Conversion¶
Conversion¶
import torch.onnx
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    opset_version=11
)
ONNX Runtime Inference¶
import onnxruntime as ort
import numpy as np
# Create a session
session = ort.InferenceSession("model.onnx")
# Run inference
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = session.run([output_name], {input_name: input_data})
Validation¶
import onnx
# Load and validate the model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed")
4. Inference Optimization¶
eval Mode¶
model.eval()  # Disables Dropout; BatchNorm switches to its running statistics
no_grad¶
with torch.no_grad():
    output = model(input)
Inference Mode (PyTorch 1.9+)¶
with torch.inference_mode():
    output = model(input)
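The two context managers look alike but produce different kinds of tensors. A minimal sketch, assuming a single `nn.Linear` as the model:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
x = torch.randn(3, 4)

with torch.no_grad():
    out_no_grad = model(x)

with torch.inference_mode():
    out_inference = model(x)

# Neither output tracks gradients, but only the second is an inference tensor.
print(out_no_grad.requires_grad, out_no_grad.is_inference())      # False False
print(out_inference.requires_grad, out_inference.is_inference())  # False True
```

Inference tensors skip some autograd bookkeeping, which is where the extra speed comes from, but they cannot later take part in operations recorded for backpropagation. `torch.no_grad()` remains the safer choice when outputs might be reused in training code.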
Quantization¶
# Dynamic quantization (simple)
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
# Static quantization (more optimized)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)
# Run representative calibration data through model_prepared here
# so the inserted observers can record activation ranges.
model_quantized = torch.quantization.convert(model_prepared)
5. Deployment Options¶
Flask API¶
from flask import Flask, request, jsonify
import torch

app = Flask(__name__)
# Loads a fully pickled model; PyTorch 2.6+ needs weights_only=False for this.
# Only load files from a trusted source.
model = torch.load('model.pth', weights_only=False)
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json['data']
    tensor = torch.tensor(data).float()
    with torch.no_grad():
        output = model(tensor)
    pred = output.argmax(dim=1).tolist()
    return jsonify({'prediction': pred})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
FastAPI (Recommended)¶
from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()
model = torch.jit.load('model_traced.pt')
model.eval()

class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    tensor = torch.tensor(input_data.data).float()
    with torch.inference_mode():
        output = model(tensor)
    pred = output.argmax(dim=1).tolist()
    return {"prediction": pred}
Docker¶
FROM python:3.10-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY model_traced.pt .
COPY app.py .
EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
6. Mobile Deployment¶
PyTorch Mobile¶
# Optimize for mobile
from torch.utils.mobile_optimizer import optimize_for_mobile

traced_model = torch.jit.trace(model, example_input)
optimized_model = optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter("model_mobile.ptl")
Android/iOS¶
// Android (Kotlin)
val module = LiteModuleLoader.load(assetFilePath(this, "model_mobile.ptl"))
val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1, 3, 224, 224))
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
7. Cloud Deployment¶
AWS SageMaker¶
from sagemaker.pytorch import PyTorchModel
model = PyTorchModel(
    model_data='s3://bucket/model.tar.gz',
    role=role,
    framework_version='2.0',
    py_version='py310',
    entry_point='inference.py'
)
predictor = model.deploy(
    instance_type='ml.m5.large',
    initial_instance_count=1
)
Hugging Face Hub¶
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
    path_or_fileobj="model.pt",
    path_in_repo="pytorch_model.bin",
    repo_id="username/model-name",
    repo_type="model"
)
8. Best Practices¶
Pre-Save Checklist¶
# 1. Switch to eval mode
model.eval()
# 2. Move from GPU to CPU (portability)
model.cpu()
# 3. Validate
with torch.no_grad():
    test_output = model(test_input.cpu())
    assert test_output.shape == expected_shape
Version Control¶
from datetime import datetime

save_dict = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 784,
        'hidden_size': 256,
        'num_classes': 10
    },
    # Store a plain string so the file contains only basic Python types
    'pytorch_version': str(torch.__version__),
    'training_date': datetime.now().isoformat()
}
torch.save(save_dict, 'model_v1.0.pth')
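Storing the configuration next to the weights pays off at load time, because the model can be rebuilt without hard-coding its dimensions. The sketch below assumes a hypothetical `MLP` class whose constructor arguments match the keys of `model_config`, and writes to a temporary directory rather than a real artifact path.

```python
import os
import tempfile
from datetime import datetime

import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical model whose constructor arguments mirror model_config."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        return self.net(x)


config = {"input_size": 784, "hidden_size": 256, "num_classes": 10}
model = MLP(**config).eval()

save_dict = {
    "model_state_dict": model.state_dict(),
    "model_config": config,
    "pytorch_version": str(torch.__version__),
    "training_date": datetime.now().isoformat(),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_v1.0.pth")
    torch.save(save_dict, path)
    loaded = torch.load(path, map_location="cpu")

# Rebuild the architecture from the stored config, then load the weights.
restored = MLP(**loaded["model_config"])
restored.load_state_dict(loaded["model_state_dict"])
restored.eval()

x = torch.randn(2, 784)
with torch.no_grad():
    logits = restored(x)
    outputs_match = torch.equal(model(x), logits)
print(logits.shape, loaded["pytorch_version"], outputs_match)
```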
Summary¶
Choosing Save Method¶
| Use Case | Method |
|---|---|
| Resume training | Checkpoint (state_dict + optimizer) |
| Python deployment | state_dict |
| C++ deployment | TorchScript |
| Universal deployment | ONNX |
| Mobile | PyTorch Mobile |
Core Code¶
# Save
torch.save(model.state_dict(), 'model.pth')
# TorchScript
traced = torch.jit.trace(model.eval(), example_input)
traced.save('model.pt')
# ONNX
torch.onnx.export(model, example_input, 'model.onnx')
Next Steps¶
Proceed with practical projects in 39_Practical_Image_Classification.md.