manpk-ai committed
Commit 28328d0 · 1 Parent(s): 1a193cc

create handler

Files changed (1)
  1. handler.py +171 -0
handler.py ADDED
@@ -0,0 +1,171 @@
+ from typing import Dict, List, Any
+ from transformers import (
+     AutoTokenizer,
+     AutoModel,
+     AutoImageProcessor,
+ )
+ import torch
+ from PIL import Image
+ import base64
+ import io
+
+ # Pick dtype and device: bfloat16 on GPU, float32 on CPU (half precision is
+ # poorly supported for generation on CPU)
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         print(f"Initializing model on device: {device}")
+         print(f"Using dtype: {dtype}")
+
+         # Load the tokenizer and image processor; the model itself is loaded via AutoModel, as in local inference
+         self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+         self.image_processor = AutoImageProcessor.from_pretrained(path, trust_remote_code=True)
+
+         # Load model with explicit device mapping
+         if device == "cuda":
+             self.model = AutoModel.from_pretrained(
+                 path,
+                 torch_dtype=dtype,
+                 trust_remote_code=True,
+                 device_map="auto"  # automatically map to available GPUs
+             )
+         else:
+             self.model = AutoModel.from_pretrained(
+                 path,
+                 torch_dtype=dtype,
+                 trust_remote_code=True
+             )
+             self.model = self.model.to(device)
+
+         print(f"Model loaded successfully on device: {self.model.device}")
+         print(f"Model dtype: {next(self.model.parameters()).dtype}")
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`str` or :obj:`list`): messages in chat format, or a plain text prompt
+             parameters (:obj:`dict`): generation parameters
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
+         """
+         print("Call inside handler")
+         # Get inputs and generation parameters from the request payload
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", {})
+         print("parameters", parameters)
+
+         # Remove parameters that might cause issues with model.generate
+         parameters.pop("details", None)
+         parameters.pop("stop", None)
+         parameters.pop("return_full_text", None)
+         if "do_sample" in parameters:
+             # sampling is forced on whenever the caller sets do_sample at all
+             parameters["do_sample"] = True
+
+         # Set default generation parameters
+         max_new_tokens = parameters.pop("max_new_tokens", 512)
+         temperature = parameters.pop("temperature", 0)
+
+         try:
+             # Handle different input formats
+             if isinstance(inputs, str):
+                 # If it's a string, treat it as a simple text prompt
+                 input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
+                 generated_ids = self.model.generate(
+                     input_ids,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     **parameters
+                 )
+                 prompt_len = input_ids.shape[1]
+                 generated_ids = generated_ids[:, prompt_len:]
+                 output_text = self.tokenizer.batch_decode(
+                     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                 )
+                 return [{"generated_text": output_text[0]}]
+
+             elif isinstance(inputs, list):
+                 # Handle chat format with images
+                 messages = inputs
+
+                 # Apply chat template
+                 input_ids = self.tokenizer.apply_chat_template(
+                     messages, tokenize=True, add_generation_prompt=True
+                 )
+                 input_text = self.tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
+                 print(input_text)
+
+                 input_ids = torch.tensor([input_ids]).to(self.model.device)
+
+                 # Process ALL images if present
+                 pixel_values_list = []
+                 grid_thws_list = []
+
+                 # Look for images in the messages
+                 for message in messages:
+                     if isinstance(message.get("content"), list):
+                         for content_item in message["content"]:
+                             if content_item.get("type") == "image_url":
+                                 image_data = content_item.get("image_url").get("url", "")
+                                 if image_data.startswith("data:image"):
+                                     # Decode base64 image
+                                     image_data = image_data.split(",")[1]
+                                     image_bytes = base64.b64decode(image_data)
+                                     image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
+
+                                     # Process each image individually
+                                     info = self.image_processor.preprocess(images=[image])
+                                     pixel_values = torch.tensor(info['pixel_values']).to(dtype=dtype, device=self.model.device)
+                                     grid_thws = torch.tensor(info['image_grid_thw']).to(self.model.device)
+
+                                     pixel_values_list.append(pixel_values)
+                                     grid_thws_list.append(grid_thws)
+
+                 # Generate response
+                 if pixel_values_list and grid_thws_list:
+                     # Multi-modal generation with images
+                     # Concatenate all pixel_values and grid_thws for batch processing
+                     all_pixel_values = torch.cat(pixel_values_list, dim=0)
+                     all_grid_thws = torch.cat(grid_thws_list, dim=0)
+
+                     print(f"Processing {len(pixel_values_list)} images")
+                     print(f"pixel_values shape: {all_pixel_values.shape}")
+                     print(f"grid_thws shape: {all_grid_thws.shape}")
+                     print("grid_thws", all_grid_thws)
+
+                     # Ensure all tensors are on the same device as the model
+                     all_pixel_values = all_pixel_values.to(self.model.device)
+                     all_grid_thws = all_grid_thws.to(self.model.device)
+
+                     with torch.no_grad():
+                         generated_ids = self.model.generate(
+                             input_ids,
+                             pixel_values=all_pixel_values,
+                             grid_thws=all_grid_thws,
+                             max_new_tokens=max_new_tokens,
+                             temperature=temperature,
+                             **parameters
+                         )
+                 else:
+                     # Text-only generation (also run without gradients)
+                     with torch.no_grad():
+                         generated_ids = self.model.generate(
+                             input_ids,
+                             max_new_tokens=max_new_tokens,
+                             temperature=temperature,
+                             **parameters
+                         )
+
+                 prompt_len = input_ids.shape[1]
+                 generated_ids = generated_ids[:, prompt_len:]
+                 output_text = self.tokenizer.batch_decode(
+                     generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                 )
+                 print("##Model Response##", output_text)
+                 return [{"generated_text": output_text[0]}]
+
+             else:
+                 raise ValueError(f"Unsupported input type: {type(inputs)}")
+
+         except Exception as e:
+             print(f"Error during inference: {str(e)}")
+             return [{"error": str(e)}]
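For reference, below is a minimal sketch of how this handler could be exercised locally once the model files are available. The model path ".", the prompts, and the truncated base64 data URL are placeholder assumptions, not part of this commit; the chat payload simply mirrors the image_url structure that __call__ parses above.

# Hypothetical smoke test for handler.py; the path and payloads are placeholders.
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # assumes the model repo is checked out locally

# Plain-text prompt: exercises the string branch of __call__
print(handler({"inputs": "Describe what this endpoint does.",
               "parameters": {"max_new_tokens": 64}}))

# Chat-format messages with one base64-encoded image: exercises the list branch
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }
]
print(handler({"inputs": messages, "parameters": {"max_new_tokens": 128}}))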