import base64
import io
import logging
from typing import Any, Dict

import requests
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor
class EndpointHandler:
    """Custom inference handler that embeds a text/image pair with CLIP.

    Each request returns the ``text_embeds`` and ``image_embeds`` produced by
    ``CLIPModel`` together with the cosine similarity between them.
    """

    # Seconds to wait when downloading an image from ``image_url``. Without a
    # timeout, an unresponsive host would block the request indefinitely.
    IMAGE_FETCH_TIMEOUT_S = 10.0

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Load the CLIP processor and model.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = CLIPProcessor.from_pretrained(path)
        self.model = CLIPModel.from_pretrained(path)
        # Inference only: with eval mode and frozen parameters the forward
        # pass builds no autograd graph, saving memory and time per request.
        self.model.eval()
        self.model.requires_grad_(False)

    def __call__(self, data: Dict[str, Any]) -> Dict:
        """Embed the request's text and image and compare them.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when present, otherwise from ``data`` itself. Required:
                ``"text"``, plus either ``"image_url"`` (fetched over HTTP)
                or ``"image"`` (base64, optionally as a ``data:`` URI).
                The payload is not modified.

        Returns:
            A dict with ``"text_embedding"`` and ``"image_embedding"`` (lists
            of floats for the first text and first image) and
            ``"embedding_similarity"`` (their cosine similarity as a float).

        Raises:
            KeyError: If ``"text"`` is missing, or if neither ``"image_url"``
                nor ``"image"`` is present.
            requests.HTTPError: If fetching ``image_url`` returns an error
                status.
        """
        self._logger.debug(
            "this shows the custom endpoint handler is being called"
        )
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        text = inputs["text"]
        image = self._load_image(inputs)

        processed_inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        outputs = self.model(**processed_inputs)

        text_embeds = outputs.text_embeds.detach().cpu().numpy()
        image_embeds = outputs.image_embeds.detach().cpu().numpy()
        # Only the first text and the first image are compared and returned;
        # slicing avoids computing the full pairwise matrix for element [0, 0].
        embedding_similarity = float(
            cosine_similarity(text_embeds[:1], image_embeds[:1])[0, 0]
        )
        return {
            "text_embedding": text_embeds[0].tolist(),
            "image_embedding": image_embeds[0].tolist(),
            "embedding_similarity": embedding_similarity,
        }

    def _load_image(self, inputs):
        """Open the request image from ``image_url`` or base64 ``image``."""
        if "image_url" in inputs:
            # Download the full body (not ``response.raw``) so requests applies
            # any Content-Encoding. The ``with`` closes the connection.
            with requests.get(
                inputs["image_url"], timeout=self.IMAGE_FETCH_TIMEOUT_S
            ) as response:
                response.raise_for_status()
                payload = response.content
        else:
            encoded = inputs["image"]
            # Accept "data:<mime>;base64,<payload>" URIs. Plain base64 can
            # never start with "data:" because ':' is not in its alphabet.
            if isinstance(encoded, str) and encoded.startswith("data:"):
                encoded = encoded.partition(",")[2]
            payload = base64.b64decode(encoded)
        return Image.open(io.BytesIO(payload))