41. 모델 저장 및 배포
이전: 실전 텍스트 분류 프로젝트 | 다음: 강화학습 소개
41. 모델 저장 및 배포¶
학습 목표¶
- PyTorch 모델 저장 방법
- ONNX 변환
- TorchScript 사용
- 추론 최적화
1. PyTorch 모델 저장¶
state_dict 저장 (권장)¶
# 저장
torch.save(model.state_dict(), 'model_weights.pth')
# 로드
model = MyModel() # 같은 구조 필요
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
전체 모델 저장¶
# 저장
torch.save(model, 'model_full.pth')
# 로드
model = torch.load('model_full.pth')
model.eval()
체크포인트 저장¶
# 저장
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'best_acc': best_acc
}
torch.save(checkpoint, 'checkpoint.pth')
# 로드
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
2. TorchScript¶
개념¶
Python 의존성 없이 모델 실행
- C++에서 로드 가능
- 모바일 배포
- 서버 최적화
Tracing¶
# 예시 입력으로 추적
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
# 저장
traced_model.save('model_traced.pt')
# 로드
loaded_model = torch.jit.load('model_traced.pt')
output = loaded_model(example_input)
Scripting¶
# 제어 흐름 있는 모델
class MyModel(nn.Module):
def forward(self, x):
if x.sum() > 0:
return x * 2
return x
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
비교¶
| 방법 | 장점 | 단점 |
|---|---|---|
| Trace | 간단, 대부분 동작 | 동적 제어 흐름 불가 |
| Script | 동적 제어 흐름 지원 | 일부 Python 기능 제한 |
3. ONNX 변환¶
변환¶
import torch.onnx
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
},
opset_version=11
)
ONNX Runtime 추론¶
import onnxruntime as ort
import numpy as np
# 세션 생성
session = ort.InferenceSession("model.onnx")
# 추론
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = session.run([output_name], {input_name: input_data})
검증¶
import onnx
# 모델 로드 및 검증
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX 모델 검증 통과")
4. 추론 최적화¶
eval 모드¶
model.eval() # Dropout, BatchNorm 비활성화
no_grad¶
with torch.no_grad():
output = model(input)
추론 모드 (PyTorch 2.0+)¶
with torch.inference_mode():
output = model(input)
양자화 (Quantization)¶
# 동적 양자화 (간단)
quantized_model = torch.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=torch.qint8
)
# 정적 양자화 (더 최적화)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)
# 캘리브레이션 데이터로 실행
model_quantized = torch.quantization.convert(model_prepared)
5. 배포 옵션¶
Flask API¶
from flask import Flask, request, jsonify
import torch
app = Flask(__name__)
model = torch.load('model.pth')
model.eval()
@app.route('/predict', methods=['POST'])
def predict():
data = request.json['data']
tensor = torch.tensor(data).float()
with torch.no_grad():
output = model(tensor)
pred = output.argmax(dim=1).tolist()
return jsonify({'prediction': pred})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
FastAPI (권장)¶
from fastapi import FastAPI
from pydantic import BaseModel
import torch
app = FastAPI()
model = torch.jit.load('model_traced.pt')
model.eval()
class InputData(BaseModel):
data: list
@app.post("/predict")
async def predict(input_data: InputData):
tensor = torch.tensor(input_data.data).float()
with torch.inference_mode():
output = model(tensor)
pred = output.argmax(dim=1).tolist()
return {"prediction": pred}
Docker¶
FROM python:3.10-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY model_traced.pt .
COPY app.py .
EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
6. 모바일 배포¶
PyTorch Mobile¶
# 모바일용 최적화
traced_model = torch.jit.trace(model, example_input)
optimized_model = torch.utils.mobile_optimizer.optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter("model_mobile.ptl")
Android/iOS¶
// Android (Kotlin)
val module = LiteModuleLoader.load(assetFilePath(this, "model_mobile.ptl"))
val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1, 3, 224, 224))
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
7. 클라우드 배포¶
AWS SageMaker¶
from sagemaker.pytorch import PyTorchModel
model = PyTorchModel(
model_data='s3://bucket/model.tar.gz',
role=role,
framework_version='2.0',
py_version='py310',
entry_point='inference.py'
)
predictor = model.deploy(
instance_type='ml.m5.large',
initial_instance_count=1
)
Hugging Face Hub¶
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj="model.pt",
path_in_repo="pytorch_model.bin",
repo_id="username/model-name",
repo_type="model"
)
8. 베스트 프랙티스¶
저장 전 체크리스트¶
# 1. eval 모드
model.eval()
# 2. GPU → CPU (범용성)
model.cpu()
# 3. 검증
with torch.no_grad():
test_output = model(test_input.cpu())
assert test_output.shape == expected_shape
버전 관리¶
save_dict = {
'model_state_dict': model.state_dict(),
'model_config': {
'input_size': 784,
'hidden_size': 256,
'num_classes': 10
},
'pytorch_version': torch.__version__,
'training_date': datetime.now().isoformat()
}
torch.save(save_dict, 'model_v1.0.pth')
정리¶
저장 방법 선택¶
| 용도 | 방법 |
|---|---|
| 학습 재개 | 체크포인트 (state_dict + optimizer) |
| Python 배포 | state_dict |
| C++ 배포 | TorchScript |
| 범용 배포 | ONNX |
| 모바일 | PyTorch Mobile |
핵심 코드¶
# 저장
torch.save(model.state_dict(), 'model.pth')
# TorchScript
traced = torch.jit.trace(model.eval(), example_input)
traced.save('model.pt')
# ONNX
torch.onnx.export(model, example_input, 'model.onnx')
다음 단계¶
39_Practical_Image_Classification.md에서 실전 프로젝트를 진행합니다.