tejoess committed
Commit 3a507ec · 1 Parent(s): 378fa97

Add custom handler for Inference Endpoint deployment

Files changed (2)
  1. handler.py +81 -0
  2. requirements.txt +6 -0
handler.py ADDED
@@ -0,0 +1,81 @@
+ # handler.py
+ from typing import Dict, Any, List
+ import torch
+ import PIL.Image
+ from io import BytesIO
+ import base64
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import logging
+
+ # Configure logging for debugging purposes
+ logging.basicConfig(level=logging.INFO)
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         logging.info("Initializing EndpointHandler for Moondream2")
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logging.info(f"Using device: {self.device}")
+
+         # Load the model with trust_remote_code enabled.
+         # 'path' points to the location of the model files inside the container.
+         # float16 is only used on GPU, since many float16 ops are unsupported on CPU.
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path,
+             trust_remote_code=True,
+             torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+
+         # Move the model to the target device and switch to inference mode
+         self.model.to(self.device)
+         self.model.eval()
+
+         logging.info("Moondream2 model loaded successfully.")
+
+     def preprocess_image(self, encoded_image: str) -> PIL.Image.Image:
+         """Decode and preprocess the base64-encoded image."""
+         try:
+             image_data = base64.b64decode(encoded_image)
+             return PIL.Image.open(BytesIO(image_data)).convert("RGB")
+         except Exception as e:
+             logging.error(f"Error decoding image: {e}")
+             raise ValueError(f"Failed to decode image data: {e}")
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Handles the API call. The `data` argument is a dictionary containing the payload.
+         Expects a JSON payload like:
+         {
+             "inputs": {
+                 "prompt": "What's in this picture?",
+                 "image": "base64_encoded_image_string"
+             }
+         }
+         """
+         logging.info("Received request payload")
+         inputs = data.get("inputs", {})
+         prompt = inputs.get("prompt", "")
+         encoded_image = inputs.get("image", "")
+
+         if not prompt or not encoded_image:
+             raise ValueError("Prompt and base64-encoded image must be provided in the 'inputs' field.")
+
+         image = self.preprocess_image(encoded_image)
+
+         # Encode the image once; the embedding is reused for generation
+         enc_image = self.model.encode_image(image)
+
+         logging.info(f"Running inference with prompt: {prompt}")
+         with torch.no_grad():
+             # Moondream2's answer_question() builds the "Question: ...\n\nAnswer:"
+             # prompt, runs generation, and returns the decoded answer with the
+             # prompt already stripped, so no manual post-processing is needed.
+             generated_answer = self.model.answer_question(
+                 enc_image,
+                 prompt,
+                 self.tokenizer,
+                 # Add other generation parameters here if needed
+             )
+
+         logging.info(f"Inference complete. Generated answer: {generated_answer}")
+
+         return [{"generated_text": generated_answer}]
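
The handler can be sanity-checked locally before deployment by calling it with the same payload shape the endpoint receives. A minimal sketch follows; the model path "vikhyatk/moondream2" and the image file "test.jpg" are assumptions for illustration, not part of this commit.

# test_handler.py — hypothetical local smoke test for the handler above
import base64

from handler import EndpointHandler

# Instantiate the handler the way the Inference Endpoint runtime does;
# "vikhyatk/moondream2" is an assumed model path (a local copy also works)
handler = EndpointHandler(path="vikhyatk/moondream2")

# Base64-encode a sample image to match the expected payload format
with open("test.jpg", "rb") as f:  # "test.jpg" is a placeholder image
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"prompt": "What's in this picture?", "image": encoded}}
print(handler(payload))  # -> [{"generated_text": "..."}]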
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ # requirements.txt
+ transformers
+ torch
+ accelerate
+ timm
+ einops
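
Once deployed, the endpoint can be queried with a JSON payload matching the handler's docstring. A sketch of a client call, assuming the `requests` library; the endpoint URL, token, and image file below are placeholders to be replaced with the deployment's actual values.

# query_endpoint.py — hypothetical client for the deployed endpoint
import base64

import requests

API_URL = "https://<endpoint-name>.endpoints.huggingface.cloud"  # placeholder URL
HEADERS = {"Authorization": "Bearer <hf_token>"}  # placeholder access token

with open("test.jpg", "rb") as f:  # placeholder image file
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"prompt": "What's in this picture?", "image": encoded}}
response = requests.post(API_URL, headers=HEADERS, json=payload)
print(response.json())  # e.g. [{"generated_text": "..."}]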