TorchServe & Triton Inference Server
TorchServe & Triton Inference Server¶
1. TorchServe Overview¶
TorchServe is the official tool for serving PyTorch models in production environments.
1.1 TorchServe Architecture¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β TorchServe μν€ν
μ² β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β βββββββββββββββ β
β β Client β β
β ββββββββ¬βββββββ β
β β β
β βΌ β
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
β β Frontend β β
β β ββββββββββββββββ ββββββββββββββββ β β
β β β REST API β β gRPC API β β β
β β β :8080 β β :7070 β β β
β β ββββββββββββββββ ββββββββββββββββ β β
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
β β β
β βΌ β
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
β β Backend β β
β β ββββββββββββββββββββββββββββββββββββββββββββ β β
β β β Model Store β β β
β β β βββββββββββ βββββββββββ βββββββββββ β β β
β β β β Model A β β Model B β β Model C β β β β
β β β βββββββββββ βββββββββββ βββββββββββ β β β
β β ββββββββββββββββββββββββββββββββββββββββββββ β β
β β β β
β β ββββββββββββββββ ββββββββββββββββ β β
β β β Worker 1 β β Worker 2 β ... β β
β β ββββββββββββββββ ββββββββββββββββ β β
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
1.2 Installation¶
# Install TorchServe
pip install torchserve torch-model-archiver torch-workflow-archiver
# Check the version
torchserve --version
2. Writing Handlers¶
2.1 Basic Handler¶
"""
custom_handler.py - TorchServe 컀μ€ν
νΈλ€λ¬
"""
import torch
import torch.nn.functional as F
from ts.torch_handler.base_handler import BaseHandler
import json
import logging
logger = logging.getLogger(__name__)
class CustomHandler(BaseHandler):
    """Custom TorchServe handler for a TorchScript classification model.

    Loads a TorchScript model from the model archive, returns softmax
    class probabilities, and optionally maps predicted indices to class
    names via an ``index_to_name.json`` file shipped in the archive.
    """

    def __init__(self):
        super().__init__()
        # Set to True once initialize() has successfully loaded the model.
        self.initialized = False

    def initialize(self, context):
        """Load the model and the optional class-name mapping.

        Args:
            context: TorchServe context carrying the manifest and system
                properties (including ``model_dir``).
        """
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Prefer GPU when one is available.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Load the serialized TorchScript model from the archive.
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = f"{model_dir}/{serialized_file}"
        self.model = torch.jit.load(model_pt_path, map_location=self.device)
        self.model.eval()

        # Optional index -> class-name mapping (absent files are tolerated).
        self.class_names = self._load_class_names(model_dir)

        self.initialized = True
        logger.info("Model initialized successfully")

    def _load_class_names(self, model_dir):
        """Return the index-to-name mapping, or None if the file is missing."""
        try:
            with open(f"{model_dir}/index_to_name.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def preprocess(self, data):
        """Convert incoming request rows into one batched float32 tensor.

        Each row may be a dict carrying a "data"/"body" key, raw bytes of
        a JSON payload, or an already-decoded feature sequence.
        """
        inputs = []
        for row in data:
            # Unwrap TorchServe's request envelope when present.
            if isinstance(row, dict):
                features = row.get("data") or row.get("body")
            else:
                # BUGFIX: the original called row.get("body") here, which
                # raises AttributeError on non-dict rows; the row itself
                # is already the payload in that case.
                features = row
            if isinstance(features, (bytes, bytearray)):
                features = json.loads(features.decode("utf-8"))
            tensor = torch.tensor(features, dtype=torch.float32)
            inputs.append(tensor)
        # Stack into a single batch on the target device.
        return torch.stack(inputs).to(self.device)

    def inference(self, data):
        """Run the model and return per-class probabilities."""
        with torch.no_grad():
            outputs = self.model(data)
            probabilities = F.softmax(outputs, dim=1)
        return probabilities

    def postprocess(self, data):
        """Build one JSON-serializable result dict per batch item."""
        results = []
        for prob in data:
            prob_list = prob.cpu().numpy().tolist()
            prediction = int(torch.argmax(prob).item())
            result = {
                "prediction": prediction,
                "probabilities": prob_list
            }
            # Attach a human-readable label when a mapping was loaded.
            if self.class_names:
                result["class_name"] = self.class_names.get(str(prediction))
            results.append(result)
        return results
2.2 Image Classification Handler¶
"""
image_classifier_handler.py - μ΄λ―Έμ§ λΆλ₯ νΈλ€λ¬
"""
import torch
import torch.nn.functional as F
from torchvision import transforms
from ts.torch_handler.vision_handler import VisionHandler
from PIL import Image
import io
import base64
class ImageClassifierHandler(VisionHandler):
    """Image-classification handler returning top-k class predictions."""

    def __init__(self):
        super().__init__()
        # Standard ImageNet-style preprocessing pipeline.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess(self, data):
        """Decode request images (raw bytes or base64 strings) into a batch tensor."""
        images = []
        for row in data:
            image_data = row.get("data") or row.get("body")
            # Base64-encoded payloads arrive as str; decode to bytes first.
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            # Force RGB so grayscale/RGBA inputs match the 3-channel model.
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            image = self.transform(image)
            images.append(image)
        return torch.stack(images).to(self.device)

    def postprocess(self, data):
        """Return the top-k classes with probabilities for each image."""
        probabilities = F.softmax(data, dim=1)
        # BUGFIX: torch.topk raises when k exceeds the number of classes,
        # so clamp k instead of hard-coding 5.
        k = min(5, probabilities.size(1))
        top_k = torch.topk(probabilities, k)
        # self.mapping is presumably populated by VisionHandler from
        # index_to_name.json — fall back to {} when it was never set.
        mapping = self.mapping or {}
        results = []
        for probs, indices in zip(top_k.values, top_k.indices):
            result = {
                "predictions": [
                    {
                        "class_id": int(idx),
                        "class_name": mapping.get(str(int(idx)), "unknown"),
                        "probability": float(prob)
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            results.append(result)
        return results
3. Model Archiving and Deployment¶
3.1 Creating a Model Archive¶
# Save the model as TorchScript
python -c "
import torch
model = YourModel()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
scripted = torch.jit.script(model)
scripted.save('model.pt')
"
# Create the MAR file
torch-model-archiver \
--model-name my_model \
--version 1.0 \
--serialized-file model.pt \
--handler custom_handler.py \
--extra-files "index_to_name.json,config.json" \
--export-path model_store
# Generated file: model_store/my_model.mar
3.2 Starting TorchServe¶
# Basic start
torchserve --start \
--model-store model_store \
--models my_model=my_model.mar
# With a configuration file
torchserve --start \
--model-store model_store \
--models my_model=my_model.mar \
--ts-config config.properties
# Run with Docker
docker run -d \
--name torchserve \
-p 8080:8080 \
-p 8081:8081 \
-p 8082:8082 \
-v $(pwd)/model_store:/home/model-server/model-store \
pytorch/torchserve:latest \
torchserve --start --model-store /home/model-server/model-store
3.3 Configuration File¶
# config.properties
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
# Worker settings
default_workers_per_model=4
job_queue_size=1000
# Batching settings
max_batch_delay=100
batch_size=32
# GPU settings
number_of_gpu=1
# Model settings
model_store=/home/model-server/model-store
load_models=my_model.mar
3.4 API Calls¶
"""
TorchServe API νΈμΆ
"""
import requests
import json
# Prediction request
def predict(data, timeout=10.0):
    """POST a prediction request to the TorchServe inference endpoint.

    Args:
        data: JSON-serializable request payload.
        timeout: Seconds to wait for a response. Added so a hung server
            cannot block the caller forever (requests without a timeout
            wait indefinitely); default keeps the call backward-compatible.

    Returns:
        The decoded JSON response from the model endpoint.
    """
    response = requests.post(
        "http://localhost:8080/predictions/my_model",
        json=data,
        timeout=timeout,
    )
    return response.json()
# Single prediction.
result = predict({"data": [1.0, 2.0, 3.0, 4.0]})
print(result)
# Batch prediction (one request per sample, issued from the client).
batch_data = [
    {"data": [1.0, 2.0, 3.0, 4.0]},
    {"data": [5.0, 6.0, 7.0, 8.0]}
]
results = [predict(d) for d in batch_data]
# Management API (port 8081)
# List registered models.
models = requests.get("http://localhost:8081/models").json()
# Detailed info for a single model.
model_info = requests.get("http://localhost:8081/models/my_model").json()
# Scale the number of workers for the model.
requests.put(
    "http://localhost:8081/models/my_model",
    params={"min_worker": 2, "max_worker": 4}
)
# Register a new model version.
requests.post(
    "http://localhost:8081/models",
    params={
        "url": "my_model_v2.mar",
        "initial_workers": 2
    }
)
4. Triton Inference Server¶
4.1 Triton Overview¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β Triton Inference Server β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β μ§μ νλ μμν¬: β
β ββββββββββββ ββββββββββββ ββββββββββββ ββββββββββββ β
β β PyTorch β βTensorFlowβ β ONNX β β TensorRT β β
β ββββββββββββ ββββββββββββ ββββββββββββ ββββββββββββ β
β ββββββββββββ ββββββββββββ ββββββββββββ β
β β Python β β DALI β β vLLM β β
β ββββββββββββ ββββββββββββ ββββββββββββ β
β β
β μ£Όμ κΈ°λ₯: β
β - λμ λ°°μΉ (Dynamic Batching) β
β - λͺ¨λΈ μμλΈ β
β - λ©ν° λͺ¨λΈ μλΉ β
β - GPU μ€μΌμ€λ§ β
β - λͺ¨λΈ λ²μ κ΄λ¦¬ β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
4.2 Model Repository Structure¶
model_repository/
βββ model_a/
β βββ config.pbtxt
β βββ 1/
β β βββ model.onnx
β βββ 2/
β βββ model.onnx
βββ model_b/
β βββ config.pbtxt
β βββ 1/
β βββ model.pt
βββ ensemble_model/
βββ config.pbtxt
4.3 Model Configuration¶
# config.pbtxt - ONNX model
name: "churn_predictor"
platform: "onnxruntime_onnx"
max_batch_size: 32
input [
{
name: "input"
data_type: TYPE_FP32
dims: [4]
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [2]
}
]
instance_group [
{
count: 2
kind: KIND_GPU
gpus: [0]
}
]
dynamic_batching {
preferred_batch_size: [8, 16, 32]
max_queue_delay_microseconds: 100
}
# Version policy
version_policy: {
latest: {
num_versions: 2
}
}
# config.pbtxt - PyTorch model
name: "image_classifier"
platform: "pytorch_libtorch"
max_batch_size: 64
input [
{
name: "input__0"
data_type: TYPE_FP32
dims: [3, 224, 224]
}
]
output [
{
name: "output__0"
data_type: TYPE_FP32
dims: [1000]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]
4.4 Running Triton¶
# Run with Docker
docker run --gpus all -d \
--name triton \
-p 8000:8000 \
-p 8001:8001 \
-p 8002:8002 \
-v $(pwd)/model_repository:/models \
nvcr.io/nvidia/tritonserver:23.10-py3 \
tritonserver --model-repository=/models
# Health check
curl -v localhost:8000/v2/health/ready
# Model metadata
curl localhost:8000/v2/models/churn_predictor
4.5 Python Client¶
"""
Triton Python ν΄λΌμ΄μΈνΈ
"""
import numpy as np
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
# HTTP client
def triton_http_inference(model_name: str, input_data: np.ndarray):
    """Run inference over Triton's HTTP endpoint (port 8000).

    Args:
        model_name: Name of the model in the Triton model repository.
        input_data: FP32 input array; shape must match the model config.

    Returns:
        The "output" tensor as a numpy array.
    """
    client = httpclient.InferenceServerClient(url="localhost:8000")
    try:
        # Describe the input tensor and attach the payload.
        inputs = [
            httpclient.InferInput("input", input_data.shape, "FP32")
        ]
        inputs[0].set_data_from_numpy(input_data)
        # Request only the tensor we need back.
        outputs = [
            httpclient.InferRequestedOutput("output")
        ]
        response = client.infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs
        )
        return response.as_numpy("output")
    finally:
        # BUGFIX: close the client so its connection pool is released
        # instead of leaking one connection per call.
        client.close()
# gRPC client (faster)
def triton_grpc_inference(model_name: str, input_data: np.ndarray):
    """Run inference over Triton's gRPC endpoint (port 8001).

    Args:
        model_name: Name of the model in the Triton model repository.
        input_data: FP32 input array; shape must match the model config.

    Returns:
        The "output" tensor as a numpy array.
    """
    client = grpcclient.InferenceServerClient(url="localhost:8001")
    try:
        inputs = [
            grpcclient.InferInput("input", input_data.shape, "FP32")
        ]
        inputs[0].set_data_from_numpy(input_data)
        outputs = [
            grpcclient.InferRequestedOutput("output")
        ]
        response = client.infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs
        )
        return response.as_numpy("output")
    finally:
        # BUGFIX: close the client so the gRPC channel is shut down
        # instead of leaking one channel per call.
        client.close()
# Example usage: one 4-feature sample against the churn model.
data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
result = triton_http_inference("churn_predictor", data)
print(f"Predictions: {result}")
5. Model Optimization¶
5.1 ONNX Conversion¶
"""
PyTorch λͺ¨λΈμ ONNXλ‘ λ³ν
"""
import torch
import torch.onnx
# λͺ¨λΈ λ‘λ
model = YourModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()
# λλ―Έ μ
λ ₯
dummy_input = torch.randn(1, 4)
# ONNXλ‘ λ΄λ³΄λ΄κΈ°
torch.onnx.export(
model,
dummy_input,
"model.onnx",
export_params=True,
opset_version=13,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"}
}
)
# ONNX λͺ¨λΈ κ²μ¦
import onnx
import onnxruntime as ort
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
# μΆλ‘ ν
μ€νΈ
session = ort.InferenceSession("model.onnx")
result = session.run(
None,
{"input": dummy_input.numpy()}
)
print(f"ONNX output: {result}")
5.2 TensorRT Optimization¶
"""
TensorRT μ΅μ ν
"""
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
def build_engine(onnx_path: str, engine_path: str):
    """Build and serialize a TensorRT engine from an ONNX model.

    Args:
        onnx_path: Path to the input ONNX file.
        engine_path: Destination path for the serialized engine.

    Returns:
        The serialized engine bytes, or None when parsing/building fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    # Parse the ONNX graph; dump all parser errors on failure.
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    # Builder configuration: 1 GB workspace.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

    # Enable FP16 when the platform has fast native support.
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Dynamic-batch optimization profile: min, opt, max shapes for "input".
    profile = builder.create_optimization_profile()
    profile.set_shape("input", (1, 4), (16, 4), (64, 4))
    config.add_optimization_profile(profile)

    # Build the serialized engine.
    engine = builder.build_serialized_network(network, config)
    # BUGFIX: build_serialized_network returns None on failure; the
    # original wrote None to disk, producing a corrupt engine file.
    if engine is None:
        print("TensorRT engine build failed")
        return None

    with open(engine_path, "wb") as f:
        f.write(engine)
    return engine


# Build the engine from the exported ONNX model.
build_engine("model.onnx", "model.engine")
5.3 Quantization¶
"""
λͺ¨λΈ μμν
"""
import torch
from torch.quantization import quantize_dynamic, quantize_static
# λμ μμν (κ°μ₯ κ°λ¨)
model_int8 = quantize_dynamic(
model,
{torch.nn.Linear}, # μμνν λ μ΄μ΄ νμ
dtype=torch.qint8
)
# μ μ μμν (λ μ’μ μ±λ₯)
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
model_prepared = torch.quantization.prepare(model)
# μΊλ¦¬λΈλ μ΄μ
(λν λ°μ΄ν°λ‘)
with torch.no_grad():
for data in calibration_data:
model_prepared(data)
model_quantized = torch.quantization.convert(model_prepared)
# ν¬κΈ° λΉκ΅
import os
torch.save(model.state_dict(), "model_fp32.pth")
torch.save(model_quantized.state_dict(), "model_int8.pth")
print(f"FP32 size: {os.path.getsize('model_fp32.pth') / 1e6:.2f} MB")
print(f"INT8 size: {os.path.getsize('model_int8.pth') / 1e6:.2f} MB")
6. Multi-Model Serving¶
6.1 Triton Ensemble¶
# ensemble_model/config.pbtxt
name: "ensemble_pipeline"
platform: "ensemble"
max_batch_size: 32
input [
{
name: "raw_input"
data_type: TYPE_FP32
dims: [4]
}
]
output [
{
name: "final_output"
data_type: TYPE_FP32
dims: [2]
}
]
ensemble_scheduling {
step [
{
model_name: "preprocessor"
model_version: -1
input_map {
key: "raw_input"
value: "raw_input"
}
output_map {
key: "processed_output"
value: "processed"
}
},
{
model_name: "classifier"
model_version: -1
input_map {
key: "processed"
value: "input"
}
output_map {
key: "output"
value: "final_output"
}
}
]
}
6.2 A/B Testing¶
"""
TorchServe A/B ν
μ€νΈ μ€μ
"""
import random
import requests
class ABTestRouter:
    """Routes prediction traffic between two TorchServe models for A/B testing.

    Args:
        model_a: Name of the variant-A model endpoint.
        model_b: Name of the variant-B model endpoint.
        traffic_split: Probability in [0, 1] of routing a request to model A.

    Raises:
        ValueError: If traffic_split is outside [0, 1].
    """

    def __init__(self, model_a: str, model_b: str, traffic_split: float = 0.5):
        # Reject impossible split probabilities up front.
        if not 0.0 <= traffic_split <= 1.0:
            raise ValueError("traffic_split must be between 0 and 1")
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_split = traffic_split
        self.base_url = "http://localhost:8080/predictions"

    def predict(self, data: dict, timeout: float = 10.0) -> dict:
        """Send one prediction, randomly routed between the two variants.

        Args:
            data: JSON payload forwarded to the chosen model.
            timeout: Request timeout in seconds (added so a hung server
                cannot block the caller indefinitely).

        Returns:
            The model's JSON response, annotated with the chosen
            "variant" ("A"/"B") and "model" name for later attribution.
        """
        # Bernoulli split of traffic between the two variants.
        if random.random() < self.traffic_split:
            model, variant = self.model_a, "A"
        else:
            model, variant = self.model_b, "B"

        response = requests.post(
            f"{self.base_url}/{model}",
            json=data,
            timeout=timeout,
        )
        result = response.json()
        # Tag the result so downstream metrics can attribute outcomes.
        result["variant"] = variant
        result["model"] = model
        return result
# Usage: send 80% of traffic to model_v1 and 20% to model_v2.
router = ABTestRouter("model_v1", "model_v2", traffic_split=0.8)
result = router.predict({"data": [1.0, 2.0, 3.0, 4.0]})
Exercises¶
Exercise 1: TorchServe Deployment¶
Deploy a PyTorch image-classification model with TorchServe.
Exercise 2: ONNX Conversion¶
Convert a PyTorch model to ONNX and serve it with Triton.
Exercise 3: Performance Optimization¶
Optimize the model with TensorRT and compare inference speed.
Summary¶
| Tool | Strengths | Best Fit |
|---|---|---|
| TorchServe | PyTorch-native, simple | PyTorch models |
| Triton | Multi-framework, high performance | Complex requirements |
| ONNX Runtime | General-purpose, cross-platform | Lightweight deployment |
| TensorRT | GPU optimization | Maximum performance |