torchserve_handler.py

Download
python 318 lines 9.0 KB
"""
TorchServe Custom Handler Example
=================================

Example custom handler for TorchServe.

Usage:
    1. Create the model archive:
       torch-model-archiver --model-name mymodel \\
           --version 1.0 \\
           --serialized-file model.pt \\
           --handler torchserve_handler.py \\
           --export-path model_store

    2. Start TorchServe:
       torchserve --start --model-store model_store --models mymodel=mymodel.mar

    3. Send a prediction request:
       curl -X POST http://localhost:8080/predictions/mymodel \\
           -H "Content-Type: application/json" \\
           -d '{"data": [1.0, 2.0, 3.0, 4.0]}'
"""
 23
 24import torch
 25import torch.nn.functional as F
 26from ts.torch_handler.base_handler import BaseHandler
 27import json
 28import logging
 29import os
 30import time
 31
 32logger = logging.getLogger(__name__)
 33
 34
 35class ChurnPredictionHandler(BaseHandler):
 36    """
 37    ๊ณ ๊ฐ ์ดํƒˆ ์˜ˆ์ธก ๋ชจ๋ธ ํ•ธ๋“ค๋Ÿฌ
 38
 39    ์ด ํ•ธ๋“ค๋Ÿฌ๋Š” ๋‹ค์Œ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค:
 40    1. ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ๋กœ๋“œ
 41    2. ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ
 42    3. ์ถ”๋ก  ์ˆ˜ํ–‰
 43    4. ๊ฒฐ๊ณผ ํ›„์ฒ˜๋ฆฌ
 44    """
 45
 46    def __init__(self):
 47        super().__init__()
 48        self.initialized = False
 49        self.model = None
 50        self.device = None
 51        self.class_names = None
 52        self.feature_names = None
 53
 54    def initialize(self, context):
 55        """
 56        ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
 57
 58        Args:
 59            context: TorchServe ์ปจํ…์ŠคํŠธ ๊ฐ์ฒด
 60        """
 61        logger.info("Initializing model...")
 62
 63        # ์ปจํ…์ŠคํŠธ์—์„œ ์ •๋ณด ์ถ”์ถœ
 64        self.manifest = context.manifest
 65        properties = context.system_properties
 66        model_dir = properties.get("model_dir")
 67
 68        # ๋””๋ฐ”์ด์Šค ์„ค์ •
 69        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
 70            self.device = torch.device(f"cuda:{properties.get('gpu_id')}")
 71            logger.info(f"Using GPU: {properties.get('gpu_id')}")
 72        else:
 73            self.device = torch.device("cpu")
 74            logger.info("Using CPU")
 75
 76        # ๋ชจ๋ธ ๋กœ๋“œ
 77        serialized_file = self.manifest["model"]["serializedFile"]
 78        model_path = os.path.join(model_dir, serialized_file)
 79
 80        try:
 81            self.model = torch.jit.load(model_path, map_location=self.device)
 82            self.model.eval()
 83            logger.info(f"Model loaded from {model_path}")
 84        except Exception as e:
 85            logger.error(f"Failed to load model: {e}")
 86            raise
 87
 88        # ์ถ”๊ฐ€ ์„ค์ • ํŒŒ์ผ ๋กœ๋“œ
 89        self._load_config(model_dir)
 90
 91        self.initialized = True
 92        logger.info("Model initialization complete")
 93
 94    def _load_config(self, model_dir):
 95        """์„ค์ • ํŒŒ์ผ ๋กœ๋“œ"""
 96        # ํด๋ž˜์Šค ์ด๋ฆ„
 97        class_file = os.path.join(model_dir, "index_to_name.json")
 98        if os.path.exists(class_file):
 99            with open(class_file) as f:
100                self.class_names = json.load(f)
101            logger.info(f"Loaded class names: {self.class_names}")
102        else:
103            self.class_names = {"0": "not_churned", "1": "churned"}
104
105        # ํ”ผ์ฒ˜ ์ด๋ฆ„
106        feature_file = os.path.join(model_dir, "feature_names.json")
107        if os.path.exists(feature_file):
108            with open(feature_file) as f:
109                self.feature_names = json.load(f)
110            logger.info(f"Loaded feature names: {self.feature_names}")
111
112    def preprocess(self, data):
113        """
114        ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ
115
116        Args:
117            data: ์š”์ฒญ ๋ฐ์ดํ„ฐ ๋ฆฌ์ŠคํŠธ
118
119        Returns:
120            torch.Tensor: ์ „์ฒ˜๋ฆฌ๋œ ์ž…๋ ฅ ํ…์„œ
121        """
122        logger.info(f"Preprocessing {len(data)} samples")
123        inputs = []
124
125        for row in data:
126            # ์š”์ฒญ ๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ
127            if isinstance(row, dict):
128                features = row.get("data") or row.get("body")
129            else:
130                features = row.get("body")
131
132            # ๋ฐ”์ดํŠธ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
133            if isinstance(features, (bytes, bytearray)):
134                features = json.loads(features.decode("utf-8"))
135
136            # JSON ๋ฌธ์ž์—ด ์ฒ˜๋ฆฌ
137            if isinstance(features, str):
138                features = json.loads(features)
139
140            # dict์ธ ๊ฒฝ์šฐ ๊ฐ’๋งŒ ์ถ”์ถœ
141            if isinstance(features, dict):
142                if "data" in features:
143                    features = features["data"]
144                else:
145                    features = list(features.values())
146
147            # ํ…์„œ๋กœ ๋ณ€ํ™˜
148            tensor = torch.tensor(features, dtype=torch.float32)
149            inputs.append(tensor)
150
151        # ๋ฐฐ์น˜๋กœ ๋ฌถ๊ธฐ
152        batch = torch.stack(inputs).to(self.device)
153        logger.info(f"Input batch shape: {batch.shape}")
154
155        return batch
156
157    def inference(self, data):
158        """
159        ๋ชจ๋ธ ์ถ”๋ก 
160
161        Args:
162            data: ์ „์ฒ˜๋ฆฌ๋œ ์ž…๋ ฅ ํ…์„œ
163
164        Returns:
165            torch.Tensor: ๋ชจ๋ธ ์ถœ๋ ฅ
166        """
167        logger.info("Running inference...")
168        start_time = time.time()
169
170        with torch.no_grad():
171            outputs = self.model(data)
172
173            # ํ™•๋ฅ ๋กœ ๋ณ€ํ™˜ (๋ถ„๋ฅ˜ ๋ชจ๋ธ์ธ ๊ฒฝ์šฐ)
174            if outputs.dim() > 1 and outputs.shape[1] > 1:
175                probabilities = F.softmax(outputs, dim=1)
176            else:
177                probabilities = torch.sigmoid(outputs)
178
179        inference_time = time.time() - start_time
180        logger.info(f"Inference completed in {inference_time:.4f}s")
181
182        return probabilities
183
184    def postprocess(self, data):
185        """
186        ์ถœ๋ ฅ ํ›„์ฒ˜๋ฆฌ
187
188        Args:
189            data: ๋ชจ๋ธ ์ถœ๋ ฅ ํ…์„œ
190
191        Returns:
192            list: JSON ์ง๋ ฌํ™” ๊ฐ€๋Šฅํ•œ ๊ฒฐ๊ณผ ๋ฆฌ์ŠคํŠธ
193        """
194        logger.info("Postprocessing results...")
195        results = []
196
197        for prob in data:
198            prob_list = prob.cpu().numpy().tolist()
199
200            # ์ด์ง„ ๋ถ„๋ฅ˜
201            if len(prob_list) == 1:
202                prediction = 1 if prob_list[0] > 0.5 else 0
203                probabilities = [1 - prob_list[0], prob_list[0]]
204            # ๋‹ค์ค‘ ํด๋ž˜์Šค
205            else:
206                prediction = int(torch.argmax(prob).item())
207                probabilities = prob_list
208
209            result = {
210                "prediction": prediction,
211                "probabilities": probabilities,
212                "confidence": max(probabilities)
213            }
214
215            # ํด๋ž˜์Šค ์ด๋ฆ„ ์ถ”๊ฐ€
216            if self.class_names:
217                result["class_name"] = self.class_names.get(
218                    str(prediction),
219                    f"class_{prediction}"
220                )
221
222            results.append(result)
223
224        logger.info(f"Processed {len(results)} results")
225        return results
226
227    def handle(self, data, context):
228        """
229        ์ „์ฒด ์š”์ฒญ ์ฒ˜๋ฆฌ (preprocess -> inference -> postprocess)
230
231        TorchServe๊ฐ€ ํ˜ธ์ถœํ•˜๋Š” ๋ฉ”์ธ ๋ฉ”์„œ๋“œ
232        """
233        if not self.initialized:
234            self.initialize(context)
235
236        if data is None:
237            return None
238
239        # ์ „์ฒ˜๋ฆฌ
240        model_input = self.preprocess(data)
241
242        # ์ถ”๋ก 
243        model_output = self.inference(model_input)
244
245        # ํ›„์ฒ˜๋ฆฌ
246        return self.postprocess(model_output)
247
248
# Module-level handler instance (loaded once by TorchServe).
_service = ChurnPredictionHandler()
251
252
def handle(data, context):
    """Module-level entry point invoked by TorchServe; delegates to the shared handler."""
    response = _service.handle(data, context)
    return response
256
257
# ============================================================
# Local smoke-test code
# ============================================================

if __name__ == "__main__":
    import torch.nn as nn

    # Minimal 2-layer classifier used only for this local smoke test.
    class SimpleModel(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super().__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, num_classes)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    # Build and save a test model.
    print("ํ…Œ์ŠคํŠธ ๋ชจ๋ธ ์ƒ์„ฑ...")
    model = SimpleModel(4, 10, 2)
    model.eval()

    # Save as TorchScript so the handler's torch.jit.load can read it.
    scripted = torch.jit.script(model)
    scripted.save("test_model.pt")
    print("๋ชจ๋ธ ์ €์žฅ: test_model.pt")

    # Exercise the handler end to end.
    print("\nํ•ธ๋“ค๋Ÿฌ ํ…Œ์ŠคํŠธ...")

    # Stand-in for the TorchServe context object.
    class MockContext:
        manifest = {"model": {"serializedFile": "test_model.pt"}}
        system_properties = {"model_dir": ".", "gpu_id": None}

    # BUGFIX: run the handler inside try/finally so the temporary model
    # file is removed even when initialization or inference raises; the
    # redundant local `import os` was dropped (os is imported at module top).
    try:
        handler = ChurnPredictionHandler()
        handler.initialize(MockContext())

        test_data = [
            {"data": [1.0, 2.0, 3.0, 4.0]},
            {"data": [5.0, 6.0, 7.0, 8.0]}
        ]

        results = handler.handle(test_data, MockContext())

        print("\n๊ฒฐ๊ณผ:")
        for i, result in enumerate(results):
            print(f"  ์ƒ˜ํ”Œ {i+1}:")
            print(f"    ์˜ˆ์ธก: {result['prediction']}")
            print(f"    ํ™•๋ฅ : {result['probabilities']}")
            print(f"    ์‹ ๋ขฐ๋„: {result['confidence']:.4f}")
    finally:
        # Clean up the temporary model file.
        os.remove("test_model.pt")

    print("\nํ…Œ์ŠคํŠธ ์™„๋ฃŒ!")