TorchServe & Triton Inference Server

1. TorchServe Overview

TorchServe is the official tool for serving PyTorch models in production environments.

1.1 TorchServe Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       TorchServe Architecture                       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                     β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                                   β”‚
β”‚   β”‚   Client    β”‚                                                   β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜                                                   β”‚
β”‚          β”‚                                                          β”‚
β”‚          β–Ό                                                          β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”      β”‚
β”‚   β”‚                    Frontend                              β”‚      β”‚
β”‚   β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                     β”‚      β”‚
β”‚   β”‚  β”‚ REST API     β”‚  β”‚ gRPC API     β”‚                     β”‚      β”‚
β”‚   β”‚  β”‚ :8080        β”‚  β”‚ :7070        β”‚                     β”‚      β”‚
β”‚   β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                     β”‚      β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜      β”‚
β”‚          β”‚                                                          β”‚
β”‚          β–Ό                                                          β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”      β”‚
β”‚   β”‚                    Backend                               β”‚      β”‚
β”‚   β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”           β”‚      β”‚
β”‚   β”‚  β”‚              Model Store                  β”‚           β”‚      β”‚
β”‚   β”‚  β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚           β”‚      β”‚
β”‚   β”‚  β”‚  β”‚ Model A β”‚ β”‚ Model B β”‚ β”‚ Model C β”‚    β”‚           β”‚      β”‚
β”‚   β”‚  β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚           β”‚      β”‚
β”‚   β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜           β”‚      β”‚
β”‚   β”‚                                                          β”‚      β”‚
β”‚   β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                     β”‚      β”‚
β”‚   β”‚  β”‚ Worker 1     β”‚  β”‚ Worker 2     β”‚  ...                β”‚      β”‚
β”‚   β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                     β”‚      β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜      β”‚
β”‚                                                                     β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
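
The frontend exposes the inference API (default port 8080) alongside the management (8081) and metrics (8082) APIs configured in section 3.3. Once a server is running (section 3.2), a quick liveness check against the inference port is the ping endpoint; a minimal sketch using requests:

import requests

# TorchServe health check on the inference port
print(requests.get("http://localhost:8080/ping").json())
# Typically returns {"status": "Healthy"}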

1.2 Installation

# Install TorchServe
pip install torchserve torch-model-archiver torch-workflow-archiver

# Check the version
torchserve --version

2. Writing Handlers

2.1 Basic Handler

"""
custom_handler.py - TorchServe custom handler
"""

import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler
import json
import logging

logger = logging.getLogger(__name__)

class CustomHandler(BaseHandler):
    """μ»€μŠ€ν…€ TorchServe ν•Έλ“€λŸ¬"""

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, context):
        """λͺ¨λΈ μ΄ˆκΈ°ν™”"""
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # λ””λ°”μ΄μŠ€ μ„€μ •
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # λͺ¨λΈ λ‘œλ“œ
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = f"{model_dir}/{serialized_file}"

        self.model = torch.jit.load(model_pt_path, map_location=self.device)
        self.model.eval()

        # Load extra configuration (if present)
        self.class_names = self._load_class_names(model_dir)

        self.initialized = True
        logger.info("Model initialized successfully")

    def _load_class_names(self, model_dir):
        """클래슀 이름 λ‘œλ“œ"""
        try:
            with open(f"{model_dir}/index_to_name.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def preprocess(self, data):
        """Preprocess the input"""
        inputs = []

        for row in data:
            # TorchServe delivers each request as a dict keyed by
            # "data" or "body"
            if isinstance(row, dict):
                features = row.get("data") or row.get("body")
            else:
                features = row

            # Raw request bodies arrive as bytes; decode JSON payloads
            if isinstance(features, (bytes, bytearray)):
                features = json.loads(features.decode("utf-8"))

            tensor = torch.tensor(features, dtype=torch.float32)
            inputs.append(tensor)

        # Stack into a single batch tensor
        return torch.stack(inputs).to(self.device)

    def inference(self, data):
        """μΆ”λ‘  μˆ˜ν–‰"""
        with torch.no_grad():
            outputs = self.model(data)
            probabilities = F.softmax(outputs, dim=1)
        return probabilities

    def postprocess(self, data):
        """좜λ ₯ ν›„μ²˜λ¦¬"""
        results = []

        for prob in data:
            prob_list = prob.cpu().numpy().tolist()
            prediction = int(torch.argmax(prob).item())

            result = {
                "prediction": prediction,
                "probabilities": prob_list
            }

            # Attach the class name
            if self.class_names:
                result["class_name"] = self.class_names.get(str(prediction))

            results.append(result)

        return results
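
A handler can be smoke-tested locally before packaging it into an archive. A minimal sketch that bypasses initialize() by setting the attributes it would normally create (the class-name mapping here is hypothetical):

import json
import torch

handler = CustomHandler()
handler.device = torch.device("cpu")
handler.class_names = {"0": "negative", "1": "positive"}

# Simulate a TorchServe request batch: one request with a JSON body
batch = [{"body": json.dumps([1.0, 2.0, 3.0, 4.0]).encode("utf-8")}]
tensor = handler.preprocess(batch)
print(tensor.shape)  # torch.Size([1, 4])

# Push fake probabilities through postprocess to check the response format
fake_probs = torch.tensor([[0.2, 0.8]])
print(handler.postprocess(fake_probs))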

2.2 Image Classification Handler

"""
image_classifier_handler.py - image classification handler
"""

import torch
import torch.nn.functional as F
from torchvision import transforms
from ts.torch_handler.vision_handler import VisionHandler
from PIL import Image
import io
import base64

class ImageClassifierHandler(VisionHandler):
    """이미지 λΆ„λ₯˜ ν•Έλ“€λŸ¬"""

    def __init__(self):
        super().__init__()
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess(self, data):
        """이미지 μ „μ²˜λ¦¬"""
        images = []

        for row in data:
            image_data = row.get("data") or row.get("body")

            # Decode Base64 payloads
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)

            # Load the image
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            image = self.transform(image)
            images.append(image)

        return torch.stack(images).to(self.device)

    def postprocess(self, data):
        """κ²°κ³Ό ν›„μ²˜λ¦¬"""
        probabilities = F.softmax(data, dim=1)
        top_k = torch.topk(probabilities, 5)

        results = []
        for probs, indices in zip(top_k.values, top_k.indices):
            result = {
                "predictions": [
                    {
                        "class_id": int(idx),
                        "class_name": self.mapping.get(str(int(idx)), "unknown"),
                        "probability": float(prob)
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            results.append(result)

        return results
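
Assuming the handler above is registered under a model name such as image_classifier (hypothetical), a client can post raw image bytes directly; TorchServe hands them to preprocess() under the "data" or "body" key:

import requests

with open("cat.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8080/predictions/image_classifier",
        data=f.read()
    )
print(response.json())  # top-5 predictions from postprocess()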

3. λͺ¨λΈ μ•„μΉ΄μ΄λΈŒ 및 배포

3.1 λͺ¨λΈ μ•„μΉ΄μ΄λΈŒ 생성

# λͺ¨λΈμ„ TorchScript둜 μ €μž₯
python -c "
import torch
model = YourModel()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
scripted = torch.jit.script(model)
scripted.save('model.pt')
"

# Create the MAR file
torch-model-archiver \
    --model-name my_model \
    --version 1.0 \
    --serialized-file model.pt \
    --handler custom_handler.py \
    --extra-files "index_to_name.json,config.json" \
    --export-path model_store

# μƒμ„±λœ 파일: model_store/my_model.mar

3.2 Starting TorchServe

# Basic start
torchserve --start \
    --model-store model_store \
    --models my_model=my_model.mar

# With a configuration file
torchserve --start \
    --model-store model_store \
    --models my_model=my_model.mar \
    --ts-config config.properties

# Run with Docker
docker run -d \
    --name torchserve \
    -p 8080:8080 \
    -p 8081:8081 \
    -p 8082:8082 \
    -v $(pwd)/model_store:/home/model-server/model-store \
    pytorch/torchserve:latest \
    torchserve --start --model-store /home/model-server/model-store

3.3 Configuration File

# config.properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082

# Worker settings
default_workers_per_model=4
job_queue_size=1000

# Batching settings (batching is usually configured per model;
# see the registration example below)
max_batch_delay=100
batch_size=32

# GPU settings
number_of_gpu=1

# Model settings
model_store=/home/model-server/model-store
load_models=my_model.mar
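
In recent TorchServe releases, batch_size and max_batch_delay are per-model settings rather than global ones; the common pattern is to pass them when registering the model through the management API. A minimal sketch:

import requests

# Register a model with server-side batching enabled
requests.post(
    "http://localhost:8081/models",
    params={
        "url": "my_model.mar",
        "batch_size": 32,
        "max_batch_delay": 100,  # ms to wait while filling a batch
        "initial_workers": 2
    }
)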

3.4 Calling the APIs

"""
Calling the TorchServe APIs
"""

import requests
import json

# Prediction request
def predict(data):
    response = requests.post(
        "http://localhost:8080/predictions/my_model",
        json=data
    )
    return response.json()

# Single prediction
result = predict({"data": [1.0, 2.0, 3.0, 4.0]})
print(result)

# Batch prediction
batch_data = [
    {"data": [1.0, 2.0, 3.0, 4.0]},
    {"data": [5.0, 6.0, 7.0, 8.0]}
]
results = [predict(d) for d in batch_data]

# Management API
# List models
models = requests.get("http://localhost:8081/models").json()

# λͺ¨λΈ 상세 정보
model_info = requests.get("http://localhost:8081/models/my_model").json()

# Scale workers
requests.put(
    "http://localhost:8081/models/my_model",
    params={"min_worker": 2, "max_worker": 4}
)

# λͺ¨λΈ 등둝
requests.post(
    "http://localhost:8081/models",
    params={
        "url": "my_model_v2.mar",
        "initial_workers": 2
    }
)

4. Triton Inference Server

4.1 Triton Overview

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                   Triton Inference Server                           β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                     β”‚
β”‚   Supported frameworks:                                             β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”‚
β”‚   β”‚ PyTorch  β”‚ β”‚TensorFlowβ”‚ β”‚   ONNX   β”‚ β”‚ TensorRT β”‚              β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜              β”‚
β”‚   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                           β”‚
β”‚   β”‚  Python  β”‚ β”‚   DALI   β”‚ β”‚   vLLM   β”‚                           β”‚
β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                           β”‚
β”‚                                                                     β”‚
β”‚   Key features:                                                     β”‚
β”‚   - Dynamic batching                                                β”‚
β”‚   - Model ensembles                                                 β”‚
β”‚   - Multi-model serving                                             β”‚
β”‚   - GPU scheduling                                                  β”‚
β”‚   - Model versioning                                                β”‚
β”‚                                                                     β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

4.2 λͺ¨λΈ μ €μž₯μ†Œ ꡬ쑰

model_repository/
β”œβ”€β”€ model_a/
β”‚   β”œβ”€β”€ config.pbtxt
β”‚   β”œβ”€β”€ 1/
β”‚   β”‚   └── model.onnx
β”‚   └── 2/
β”‚       └── model.onnx
β”œβ”€β”€ model_b/
β”‚   β”œβ”€β”€ config.pbtxt
β”‚   └── 1/
β”‚       └── model.pt
└── ensemble_model/
    └── config.pbtxt
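
Once the server is up (see section 4.4), you can verify what Triton discovered in this repository. A minimal sketch using the HTTP client (tritonclient):

import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# List every model in the repository with its loaded version and state
for model in client.get_model_repository_index():
    print(model["name"], model.get("version"), model.get("state"))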

4.3 λͺ¨λΈ μ„€μ •

# config.pbtxt - ONNX λͺ¨λΈ
name: "churn_predictor"
platform: "onnxruntime_onnx"
max_batch_size: 32

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [4]
  }
]

output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [2]
  }
]

instance_group [
  {
    count: 2
    kind: KIND_GPU
    gpus: [0]
  }
]

dynamic_batching {
  preferred_batch_size: [8, 16, 32]
  max_queue_delay_microseconds: 100
}

# Version policy
version_policy: {
  latest: {
    num_versions: 2
  }
}

# config.pbtxt - PyTorch model
# (The libtorch backend expects tensor names of the form input__<n> /
# output__<n>, since TorchScript does not preserve input names.)
name: "image_classifier"
platform: "pytorch_libtorch"
max_batch_size: 64

input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [3, 224, 224]
  }
]

output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [1000]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]

4.4 Running Triton

# Run with Docker
docker run --gpus all -d \
    --name triton \
    -p 8000:8000 \
    -p 8001:8001 \
    -p 8002:8002 \
    -v $(pwd)/model_repository:/models \
    nvcr.io/nvidia/tritonserver:23.10-py3 \
    tritonserver --model-repository=/models

# Health check
curl -v localhost:8000/v2/health/ready

# Model metadata
curl localhost:8000/v2/models/churn_predictor

4.5 Python Client

"""
Triton Python client
"""

import numpy as np
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient

# HTTP client
def triton_http_inference(model_name: str, input_data: np.ndarray):
    """Inference over HTTP"""
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # Configure inputs
    inputs = [
        httpclient.InferInput("input", input_data.shape, "FP32")
    ]
    inputs[0].set_data_from_numpy(input_data)

    # Configure outputs
    outputs = [
        httpclient.InferRequestedOutput("output")
    ]

    # Run inference
    response = client.infer(
        model_name=model_name,
        inputs=inputs,
        outputs=outputs
    )

    return response.as_numpy("output")

# gRPC client (typically lower latency)
def triton_grpc_inference(model_name: str, input_data: np.ndarray):
    """Inference over gRPC"""
    client = grpcclient.InferenceServerClient(url="localhost:8001")

    inputs = [
        grpcclient.InferInput("input", input_data.shape, "FP32")
    ]
    inputs[0].set_data_from_numpy(input_data)

    outputs = [
        grpcclient.InferRequestedOutput("output")
    ]

    response = client.infer(
        model_name=model_name,
        inputs=inputs,
        outputs=outputs
    )

    return response.as_numpy("output")

# Usage example
data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
result = triton_http_inference("churn_predictor", data)
print(f"Predictions: {result}")

5. λͺ¨λΈ μ΅œμ ν™”

5.1 ONNX λ³€ν™˜

"""
PyTorch λͺ¨λΈμ„ ONNX둜 λ³€ν™˜
"""

import torch
import torch.onnx

# λͺ¨λΈ λ‘œλ“œ
model = YourModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()

# Dummy input
dummy_input = torch.randn(1, 4)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }
)

# ONNX λͺ¨λΈ 검증
import onnx
import onnxruntime as ort

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Inference test
session = ort.InferenceSession("model.onnx")
result = session.run(
    None,
    {"input": dummy_input.numpy()}
)
print(f"ONNX output: {result}")
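
To see what the conversion buys, compare per-call latency between eager PyTorch and ONNX Runtime. A rough sketch reusing model, dummy_input, and session from above (illustrative only; serious benchmarking needs more iterations, percentiles, and a pinned device):

import time

def bench(fn, n=1000):
    fn()  # warmup
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n * 1e3  # ms per call

with torch.no_grad():
    pt_ms = bench(lambda: model(dummy_input))
ort_ms = bench(lambda: session.run(None, {"input": dummy_input.numpy()}))
print(f"PyTorch: {pt_ms:.3f} ms/call, ONNX Runtime: {ort_ms:.3f} ms/call")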

5.2 TensorRT Optimization

"""
TensorRT optimization
"""

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

def build_engine(onnx_path: str, engine_path: str):
    """ONNXμ—μ„œ TensorRT μ—”μ§„ λΉŒλ“œ"""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    # Parse the ONNX model
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # λΉŒλ” μ„€μ •
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

    # FP16 optimization (if the platform supports it)
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Dynamic batch optimization profile
    profile = builder.create_optimization_profile()
    profile.set_shape("input", (1, 4), (16, 4), (64, 4))  # min, opt, max
    config.add_optimization_profile(profile)

    # Build the serialized engine
    engine = builder.build_serialized_network(network, config)

    # Save to disk
    with open(engine_path, "wb") as f:
        f.write(engine)

    return engine

# Build the engine
build_engine("model.onnx", "model.engine")

5.3 μ–‘μžν™”

"""
λͺ¨λΈ μ–‘μžν™”
"""

import torch
from torch.quantization import quantize_dynamic, quantize_static

# Dynamic quantization (simplest)
model_int8 = quantize_dynamic(
    model,
    {torch.nn.Linear},  # layer types to quantize
    dtype=torch.qint8
)

# Static quantization (better performance)
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
model_prepared = torch.quantization.prepare(model)

# Calibration (run representative data through the prepared model)
with torch.no_grad():
    for data in calibration_data:
        model_prepared(data)

model_quantized = torch.quantization.convert(model_prepared)

# Compare file sizes
import os
torch.save(model.state_dict(), "model_fp32.pth")
torch.save(model_quantized.state_dict(), "model_int8.pth")

print(f"FP32 size: {os.path.getsize('model_fp32.pth') / 1e6:.2f} MB")
print(f"INT8 size: {os.path.getsize('model_int8.pth') / 1e6:.2f} MB")
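
Beyond file size, verify that quantization has not shifted the outputs too far. A minimal sketch with a random sample (assumes the 4-feature input used above):

x = torch.randn(1, 4)
with torch.no_grad():
    out_fp32 = model(x)
    out_int8 = model_int8(x)

# Maximum absolute deviation introduced by INT8 quantization
print(f"Max abs diff: {torch.max(torch.abs(out_fp32 - out_int8)).item():.6f}")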

6. Multi-Model Serving

6.1 Triton Ensembles

# ensemble_model/config.pbtxt
name: "ensemble_pipeline"
platform: "ensemble"
max_batch_size: 32

input [
  {
    name: "raw_input"
    data_type: TYPE_FP32
    dims: [4]
  }
]

output [
  {
    name: "final_output"
    data_type: TYPE_FP32
    dims: [2]
  }
]

ensemble_scheduling {
  step [
    {
      model_name: "preprocessor"
      model_version: -1
      input_map {
        key: "raw_input"
        value: "raw_input"
      }
      output_map {
        key: "processed_output"
        value: "processed"
      }
    },
    {
      model_name: "classifier"
      model_version: -1
      input_map {
        key: "input"
        value: "processed"
      }
      output_map {
        key: "output"
        value: "final_output"
      }
    }
  ]
}
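
When calling the ensemble, the client addresses the ensemble's own tensor names ("raw_input" / "final_output"), not those of the composing models; a minimal sketch:

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

inp = httpclient.InferInput("raw_input", [1, 4], "FP32")
inp.set_data_from_numpy(np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32))

result = client.infer(
    "ensemble_pipeline",
    inputs=[inp],
    outputs=[httpclient.InferRequestedOutput("final_output")]
)
print(result.as_numpy("final_output"))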

6.2 A/B Testing

"""
TorchServe A/B testing setup
"""

import random
import requests

class ABTestRouter:
    """A/B ν…ŒμŠ€νŠΈ λΌμš°ν„°"""

    def __init__(self, model_a: str, model_b: str, traffic_split: float = 0.5):
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_split = traffic_split
        self.base_url = "http://localhost:8080/predictions"

    def predict(self, data: dict) -> dict:
        """A/B λΆ„λ°°λœ 예츑"""
        # νŠΈλž˜ν”½ λΆ„λ°°
        if random.random() < self.traffic_split:
            model = self.model_a
            variant = "A"
        else:
            model = self.model_b
            variant = "B"

        # Call the selected model
        response = requests.post(
            f"{self.base_url}/{model}",
            json=data
        )

        result = response.json()
        result["variant"] = variant
        result["model"] = model

        return result

# Usage: route 80% of traffic to model_v1, 20% to model_v2
router = ABTestRouter("model_v1", "model_v2", traffic_split=0.8)
result = router.predict({"data": [1.0, 2.0, 3.0, 4.0]})

Exercises

Exercise 1: TorchServe Deployment

Deploy a PyTorch image classification model with TorchServe.

Exercise 2: ONNX Conversion

Convert a PyTorch model to ONNX and serve it with Triton.

Exercise 3: Performance Optimization

Optimize a model with TensorRT and compare inference speed before and after.


Summary

Tool          Strengths                          Best suited for
TorchServe    PyTorch-native, simple             PyTorch models
Triton        Multi-framework, high performance  Complex requirements
ONNX Runtime  General-purpose, cross-platform    Lightweight deployment
TensorRT      GPU optimization                   Maximum performance
