# handler.py
from typing import Dict, Any, List
import torch
import PIL.Image
from io import BytesIO
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Configure logging for debugging purposes
logging.basicConfig(level=logging.INFO)


class EndpointHandler:
    def __init__(self, path=""):
        logging.info("Initializing EndpointHandler for Moondream2")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Load the model with trust_remote_code enabled.
        # 'path' points to the location of the model files inside the container.
        # float16 is only safe on GPU; fall back to float32 on CPU.
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=dtype,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

        # Move the model to the target device and switch to inference mode.
        self.model.to(self.device)
        self.model.eval()
        logging.info("Moondream2 model loaded successfully.")

    def preprocess_image(self, encoded_image: str) -> PIL.Image.Image:
        """Decode and preprocess the base64-encoded image."""
        try:
            image_data = base64.b64decode(encoded_image)
            return PIL.Image.open(BytesIO(image_data)).convert("RGB")
        except Exception as e:
            logging.error(f"Error decoding image: {e}")
            raise ValueError(f"Failed to decode image data: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handles the API call. The `data` argument is a dictionary containing the payload.

        Expects a JSON payload like:
        {
            "inputs": {
                "prompt": "What's in this picture?",
                "image": "base64_encoded_image_string"
            }
        }
        """
        logging.info("Received request payload")
        inputs = data.get("inputs", {})
        prompt = inputs.get("prompt", "")
        encoded_image = inputs.get("image", "")

        if not prompt or not encoded_image:
            raise ValueError("Prompt and base64 encoded image must be provided in the 'inputs' field.")

        image = self.preprocess_image(encoded_image)

        logging.info(f"Running inference with prompt: {prompt}")
        with torch.no_grad():
            # Encode the image once, then ask the question.
            # answer_question() is Moondream2's documented inference API: it builds
            # the "Question: ...\n\nAnswer:" prompt internally and returns the
            # cleaned answer string, so no extra decoding or post-processing is needed.
            enc_image = self.model.encode_image(image)
            generated_answer = self.model.answer_question(
                enc_image,
                prompt,
                self.tokenizer,
            )

        logging.info(f"Inference complete. Generated answer: {generated_answer}")

        return [{"generated_text": generated_answer}]
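

# The block below is an optional local smoke test, not part of the endpoint
# contract. It is a minimal sketch that assumes a local copy of the model
# (default "./moondream2") and a test image ("test.jpg") exist; both paths
# are placeholders, so adjust them before running. It builds the same JSON
# payload the endpoint expects and calls the handler directly.
if __name__ == "__main__":
    import json
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else "./moondream2"  # assumed local model directory
    image_path = sys.argv[2] if len(sys.argv) > 2 else "test.jpg"      # assumed test image

    # Base64-encode the test image, mirroring what a client would send.
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path=model_path)
    payload = {"inputs": {"prompt": "What's in this picture?", "image": encoded}}
    print(json.dumps(handler(payload), indent=2))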