1"""
2TorchServe Custom Handler Example
3=================================
4
5TorchServe์์ ์ฌ์ฉํ ์ปค์คํ
ํธ๋ค๋ฌ ์์ ์
๋๋ค.
6
7์ฌ์ฉ ๋ฐฉ๋ฒ:
8 1. ๋ชจ๋ธ ์์นด์ด๋ธ ์์ฑ:
9 torch-model-archiver --model-name mymodel \\
10 --version 1.0 \\
11 --serialized-file model.pt \\
12 --handler torchserve_handler.py \\
13 --export-path model_store
14
15 2. TorchServe ์์:
16 torchserve --start --model-store model_store --models mymodel=mymodel.mar
17
18 3. ์์ธก ์์ฒญ:
19 curl -X POST http://localhost:8080/predictions/mymodel \\
20 -H "Content-Type: application/json" \\
21 -d '{"data": [1.0, 2.0, 3.0, 4.0]}'
22"""
23
24import torch
25import torch.nn.functional as F
26from ts.torch_handler.base_handler import BaseHandler
27import json
28import logging
29import os
30import time
31
32logger = logging.getLogger(__name__)
33
34
class ChurnPredictionHandler(BaseHandler):
    """
    Customer-churn prediction model handler.

    This handler performs:
    1. Model initialization and loading
    2. Input data preprocessing
    3. Inference
    4. Result postprocessing
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.model = None
        self.device = None
        self.class_names = None
        self.feature_names = None

    def initialize(self, context):
        """
        Initialize and load the model.

        Args:
            context: TorchServe context object providing the manifest
                and system properties (model_dir, gpu_id, ...).

        Raises:
            Exception: re-raised if the serialized model cannot be loaded.
        """
        logger.info("Initializing model...")

        # Extract deployment info from the TorchServe context
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Device selection: use the assigned GPU only when CUDA is available
        # AND TorchServe actually assigned a gpu_id to this worker.
        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.device = torch.device(f"cuda:{properties.get('gpu_id')}")
            logger.info(f"Using GPU: {properties.get('gpu_id')}")
        else:
            self.device = torch.device("cpu")
            logger.info("Using CPU")

        # Load the TorchScript model referenced by the archive manifest
        serialized_file = self.manifest["model"]["serializedFile"]
        model_path = os.path.join(model_dir, serialized_file)

        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            logger.info(f"Model loaded from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        # Load optional side-car configuration files
        self._load_config(model_dir)

        self.initialized = True
        logger.info("Model initialization complete")

    def _load_config(self, model_dir):
        """Load optional configuration files (class names, feature names)."""
        # Class names: mapping from class index (as string) to label
        class_file = os.path.join(model_dir, "index_to_name.json")
        if os.path.exists(class_file):
            with open(class_file, encoding="utf-8") as f:
                self.class_names = json.load(f)
            logger.info(f"Loaded class names: {self.class_names}")
        else:
            # Sensible default for the binary churn task
            self.class_names = {"0": "not_churned", "1": "churned"}

        # Feature names (informational; not required for inference)
        feature_file = os.path.join(model_dir, "feature_names.json")
        if os.path.exists(feature_file):
            with open(feature_file, encoding="utf-8") as f:
                self.feature_names = json.load(f)
            logger.info(f"Loaded feature names: {self.feature_names}")

    def preprocess(self, data):
        """
        Preprocess the input request data.

        Args:
            data: list of request rows; each row is typically a dict with
                a "data" or "body" key whose value may be bytes, a JSON
                string, a dict, or a plain list of floats.

        Returns:
            torch.Tensor: preprocessed input batch on ``self.device``.
        """
        logger.info(f"Preprocessing {len(data)} samples")
        inputs = []

        for row in data:
            # Parse the request row
            if isinstance(row, dict):
                features = row.get("data") or row.get("body")
            else:
                # BUGFIX: the row itself is the payload here; the previous
                # code called row.get("body") on a non-dict, which raised
                # AttributeError for any non-dict row.
                features = row

            # Decode raw bytes payloads
            if isinstance(features, (bytes, bytearray)):
                features = json.loads(features.decode("utf-8"))

            # Parse JSON string payloads
            if isinstance(features, str):
                features = json.loads(features)

            # For dict payloads, extract the feature values
            if isinstance(features, dict):
                if "data" in features:
                    features = features["data"]
                else:
                    features = list(features.values())

            # Convert to a tensor
            tensor = torch.tensor(features, dtype=torch.float32)
            inputs.append(tensor)

        # Stack into a batch and move to the target device
        batch = torch.stack(inputs).to(self.device)
        logger.info(f"Input batch shape: {batch.shape}")

        return batch

    def inference(self, data):
        """
        Run model inference.

        Args:
            data: preprocessed input tensor (batch-first).

        Returns:
            torch.Tensor: class probabilities (softmax for multi-class
            outputs, sigmoid for single-logit outputs).
        """
        logger.info("Running inference...")
        start_time = time.time()

        with torch.no_grad():
            outputs = self.model(data)

        # Convert raw outputs to probabilities (classification models)
        if outputs.dim() > 1 and outputs.shape[1] > 1:
            probabilities = F.softmax(outputs, dim=1)
        else:
            probabilities = torch.sigmoid(outputs)

        inference_time = time.time() - start_time
        logger.info(f"Inference completed in {inference_time:.4f}s")

        return probabilities

    def postprocess(self, data):
        """
        Postprocess the model output.

        Args:
            data: probability tensor, one row per sample.

        Returns:
            list: JSON-serializable result dicts with keys "prediction",
            "probabilities", "confidence", and (when class names are
            available) "class_name".
        """
        logger.info("Postprocessing results...")
        results = []

        for prob in data:
            prob_list = prob.cpu().numpy().tolist()

            # Binary classification: single sigmoid probability
            if len(prob_list) == 1:
                prediction = 1 if prob_list[0] > 0.5 else 0
                probabilities = [1 - prob_list[0], prob_list[0]]
            # Multi-class classification
            else:
                prediction = int(torch.argmax(prob).item())
                probabilities = prob_list

            result = {
                "prediction": prediction,
                "probabilities": probabilities,
                "confidence": max(probabilities)
            }

            # Attach a human-readable class name when available
            if self.class_names:
                result["class_name"] = self.class_names.get(
                    str(prediction),
                    f"class_{prediction}"
                )

            results.append(result)

        logger.info(f"Processed {len(results)} results")
        return results

    def handle(self, data, context):
        """
        Handle a full request (preprocess -> inference -> postprocess).

        This is the main method invoked by TorchServe.
        """
        if not self.initialized:
            self.initialize(context)

        if data is None:
            return None

        # Preprocess
        model_input = self.preprocess(data)

        # Inference
        model_output = self.inference(model_input)

        # Postprocess
        return self.postprocess(model_output)
247
248
249# ํธ๋ค๋ฌ ์ธ์คํด์ค (TorchServe๊ฐ ๋ก๋)
250_service = ChurnPredictionHandler()
251
252
def handle(data, context):
    """Module-level entry point invoked by TorchServe; delegates to the shared handler instance."""
    response = _service.handle(data, context)
    return response
256
257
# ============================================================
# Local test code
# ============================================================

if __name__ == "__main__":
    import torch.nn as nn

    # Simple test model: 2-layer MLP classifier
    class SimpleModel(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, num_classes)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    # Create and save a throwaway model
    print("Creating test model...")
    model = SimpleModel(4, 10, 2)
    model.eval()

    # Save as TorchScript so the handler's torch.jit.load works
    scripted = torch.jit.script(model)
    scripted.save("test_model.pt")
    print("Model saved: test_model.pt")

    # Exercise the handler end-to-end
    print("\nTesting handler...")

    # Minimal mock of the TorchServe context object
    class MockContext:
        manifest = {"model": {"serializedFile": "test_model.pt"}}
        system_properties = {"model_dir": ".", "gpu_id": None}

    handler = ChurnPredictionHandler()
    handler.initialize(MockContext())

    # Test request payload
    test_data = [
        {"data": [1.0, 2.0, 3.0, 4.0]},
        {"data": [5.0, 6.0, 7.0, 8.0]}
    ]

    results = handler.handle(test_data, MockContext())

    print("\nResults:")
    for i, result in enumerate(results):
        print(f"  Sample {i+1}:")
        print(f"    prediction: {result['prediction']}")
        print(f"    probabilities: {result['probabilities']}")
        print(f"    confidence: {result['confidence']:.4f}")

    # Cleanup (os is already imported at module level)
    os.remove("test_model.pt")
    print("\nTest complete!")