Krishna Indukuri committed on
Commit e031746 · verified · 1 parent: a026ca0

Upload handler.py

Files changed (1)
  1. handler.py +155 -0
handler.py ADDED
import os
import json
import base64
from io import BytesIO

import torch
from PIL import Image

from custom_st import Transformer


class ModelHandler:
    """
    Custom handler for the embedding model, built on the Transformer
    class from custom_st.py.
    """

    def __init__(self):
        self.initialized = False
        self.model = None
        self.processor = None
        self.device = None
        self.default_task = "retrieval"  # Default task; can be overridden in initialize
        self.max_seq_length = 8192  # Default max sequence length

    def initialize(self, context):
        """
        Initialize the model and processor.
        """
        self.initialized = True

        # Get the model directory from the serving context
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load overrides from config.json if it exists
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            self.default_task = config.get("default_task", self.default_task)
            self.max_seq_length = config.get("max_seq_length", self.max_seq_length)

        # Initialize the model
        self.model = Transformer(
            model_name_or_path=model_dir,
            max_seq_length=self.max_seq_length,
            model_args={"default_task": self.default_task},
        )
        self.model.model.to(self.device)
        self.model.model.eval()

        # Reuse the processor created by the Transformer wrapper
        self.processor = self.model.processor

    def preprocess(self, data):
        """
        Convert incoming request bodies into model features.
        """
        inputs = []

        # Extract and decode each request body
        for row in data:
            body = row.get("body", {})
            if isinstance(body, (bytes, bytearray)):
                body = json.loads(body.decode("utf-8"))
            elif isinstance(body, str):
                body = json.loads(body)

            if not body:
                # Empty body: skip this row rather than aborting the whole batch
                continue

            # Handle the supported input formats: "inputs" (string or list),
            # "text" (string), and "image" (data URL, URL, or file path)
            if "inputs" in body:
                raw_inputs = body["inputs"]
                if isinstance(raw_inputs, str):
                    inputs.append(raw_inputs)
                elif isinstance(raw_inputs, list):
                    inputs.extend(raw_inputs)
            elif "text" in body:
                inputs.append(body["text"])
            elif "image" in body:
                image_data = body["image"]
                if isinstance(image_data, str) and image_data.startswith("data:image"):
                    # Decode a base64 data URL into a PIL image
                    image_data = image_data.split(",")[1]
                    image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
                    inputs.append(image)
                else:
                    inputs.append(image_data)  # URL or file path

        # Use the model's tokenize method to process the collected inputs
        if inputs:
            return self.model.tokenize(inputs)

        return []

    def inference(self, features):
        """
        Run inference on the processed features.
        """
        if not features:
            return {"embeddings": []}

        # Move all tensors to the target device
        for key, value in features.items():
            if isinstance(value, torch.Tensor):
                features[key] = value.to(self.device)

        with torch.no_grad():
            outputs = self.model.forward(features, task=self.default_task)

        embeddings = outputs.get("sentence_embedding", None)

        if embeddings is not None:
            # Convert to a plain list for JSON serialization
            return {"embeddings": embeddings.cpu().numpy().tolist()}
        return {"error": "No embeddings were generated"}

    def postprocess(self, inference_output):
        """
        Wrap the model output in the list format TorchServe expects.
        """
        return [inference_output]

    def handle(self, data, context):
        """
        Main handler function.
        """
        if not self.initialized:
            self.initialize(context)

        if not data:
            return [{"embeddings": []}]

        try:
            processed_data = self.preprocess(data)
            if not processed_data:
                return [{"embeddings": []}]

            inference_result = self.inference(processed_data)
            return self.postprocess(inference_result)
        except Exception as e:
            raise RuntimeError(f"Error processing request: {e}") from e


# Module-level service instance used by the TorchServe entry point
_service = ModelHandler()


def handle(data, context):
    """
    TorchServe handler entry point.
    """
    return _service.handle(data, context)
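
For reference, a minimal sketch of how this handler could be exercised outside of a running TorchServe instance. The MockContext class, the cat.jpg image, and the /path/to/model directory are hypothetical stand-ins for what TorchServe normally supplies; the request bodies match the "inputs" and "image" formats that preprocess accepts.

import base64
import json

from handler import handle  # the module defined above


class MockContext:
    """Hypothetical stand-in for the TorchServe context object."""

    def __init__(self, model_dir):
        self.system_properties = {"model_dir": model_dir}


# A text request using the "inputs" format
text_request = [{"body": json.dumps({"inputs": ["A photo of a cat"]})}]

# An image request using a base64 data URL (cat.jpg is hypothetical)
with open("cat.jpg", "rb") as f:
    data_url = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()
image_request = [{"body": json.dumps({"image": data_url})}]

context = MockContext("/path/to/model")  # hypothetical model directory
print(handle(text_request, context))   # -> [{"embeddings": [[...]]}]
print(handle(image_request, context))  # -> [{"embeddings": [[...]]}]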