41. 모델 저장 및 배포

이전: 실전 텍스트 분류 프로젝트 | 다음: 강화학습 소개


41. 모델 저장 및 배포

학습 목표

  • PyTorch 모델 저장 방법
  • ONNX 변환
  • TorchScript 사용
  • 추론 최적화

1. PyTorch 모델 저장

state_dict 저장 (권장)

# 저장
torch.save(model.state_dict(), 'model_weights.pth')

# 로드
model = MyModel()  # 같은 구조 필요
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

전체 모델 저장

# 저장
torch.save(model, 'model_full.pth')

# 로드
model = torch.load('model_full.pth')
model.eval()

체크포인트 저장

# 저장
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'best_acc': best_acc
}
torch.save(checkpoint, 'checkpoint.pth')

# 로드
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

2. TorchScript

개념

Python 의존성 없이 모델 실행
- C++에서 로드 가능
- 모바일 배포
- 서버 최적화

Tracing

# 예시 입력으로 추적
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# 저장
traced_model.save('model_traced.pt')

# 로드
loaded_model = torch.jit.load('model_traced.pt')
output = loaded_model(example_input)

Scripting

# 제어 흐름 있는 모델
class MyModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x

scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

비교

방법 장점 단점
Trace 간단, 대부분 동작 동적 제어 흐름 불가
Script 동적 제어 흐름 지원 일부 Python 기능 제한

3. ONNX 변환

변환

import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    opset_version=11
)

ONNX Runtime 추론

import onnxruntime as ort
import numpy as np

# 세션 생성
session = ort.InferenceSession("model.onnx")

# 추론
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = session.run([output_name], {input_name: input_data})

검증

import onnx

# 모델 로드 및 검증
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX 모델 검증 통과")

4. 추론 최적화

eval 모드

model.eval()  # Dropout, BatchNorm 비활성화

no_grad

with torch.no_grad():
    output = model(input)

추론 모드 (PyTorch 2.0+)

with torch.inference_mode():
    output = model(input)

양자화 (Quantization)

# 동적 양자화 (간단)
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# 정적 양자화 (더 최적화)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)
# 캘리브레이션 데이터로 실행
model_quantized = torch.quantization.convert(model_prepared)

5. 배포 옵션

Flask API

from flask import Flask, request, jsonify
import torch

app = Flask(__name__)
model = torch.load('model.pth')
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json['data']
    tensor = torch.tensor(data).float()

    with torch.no_grad():
        output = model(tensor)
        pred = output.argmax(dim=1).tolist()

    return jsonify({'prediction': pred})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

FastAPI (권장)

from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()
model = torch.jit.load('model_traced.pt')
model.eval()

class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    tensor = torch.tensor(input_data.data).float()

    with torch.inference_mode():
        output = model(tensor)
        pred = output.argmax(dim=1).tolist()

    return {"prediction": pred}

Docker

FROM python:3.10-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt

COPY model_traced.pt .
COPY app.py .

EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

6. 모바일 배포

PyTorch Mobile

# 모바일용 최적화
traced_model = torch.jit.trace(model, example_input)
optimized_model = torch.utils.mobile_optimizer.optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter("model_mobile.ptl")

Android/iOS

// Android (Kotlin)
val module = LiteModuleLoader.load(assetFilePath(this, "model_mobile.ptl"))
val inputTensor = Tensor.fromBlob(inputArray, longArrayOf(1, 3, 224, 224))
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

7. 클라우드 배포

AWS SageMaker

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data='s3://bucket/model.tar.gz',
    role=role,
    framework_version='2.0',
    py_version='py310',
    entry_point='inference.py'
)

predictor = model.deploy(
    instance_type='ml.m5.large',
    initial_instance_count=1
)

Hugging Face Hub

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="model.pt",
    path_in_repo="pytorch_model.bin",
    repo_id="username/model-name",
    repo_type="model"
)

8. 베스트 프랙티스

저장 전 체크리스트

# 1. eval 모드
model.eval()

# 2. GPU → CPU (범용성)
model.cpu()

# 3. 검증
with torch.no_grad():
    test_output = model(test_input.cpu())
    assert test_output.shape == expected_shape

버전 관리

save_dict = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 784,
        'hidden_size': 256,
        'num_classes': 10
    },
    'pytorch_version': torch.__version__,
    'training_date': datetime.now().isoformat()
}
torch.save(save_dict, 'model_v1.0.pth')

정리

저장 방법 선택

용도 방법
학습 재개 체크포인트 (state_dict + optimizer)
Python 배포 state_dict
C++ 배포 TorchScript
범용 배포 ONNX
모바일 PyTorch Mobile

핵심 코드

# 저장
torch.save(model.state_dict(), 'model.pth')

# TorchScript
traced = torch.jit.trace(model.eval(), example_input)
traced.save('model.pt')

# ONNX
torch.onnx.export(model, example_input, 'model.onnx')

다음 단계

39_Practical_Image_Classification.md에서 실전 프로젝트를 진행합니다.

to navigate between lessons