File size: 1,444 Bytes
f1f49b4 0e2d235 d7911ee 0e2d235 f1f49b4 d7911ee f1f49b4 0e2d235 f1f49b4 d7911ee f1f49b4 0e2d235 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import base64
import io
from typing import Dict, Any

import requests
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
class EndpointHandler:
    """Custom endpoint handler that embeds a text and an image with CLIP.

    Each call returns the text embedding, the image embedding and the cosine
    similarity between the two.
    """

    # Seconds to wait on the remote host when downloading ``image_url``.
    # Without a timeout a stalled host would block the worker indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, path=""):
        """Load the CLIP processor and model from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = CLIPProcessor.from_pretrained(path)
        self.model = CLIPModel.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the request's text and image and compare them.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself:

                * ``"text"`` (required): text handed to the processor.
                * ``"image_url"``: URL to download the image from. Takes
                  precedence over ``"image"`` when both are present.
                * ``"image"``: base64-encoded image bytes, used when
                  ``"image_url"`` is absent.

                The payload is only read, never modified.

        Returns:
            A dict with ``"text_embedding"`` and ``"image_embedding"`` (the
            first row of each embedding, as lists) and
            ``"embedding_similarity"`` (their cosine similarity as a float).

        Raises:
            KeyError: If ``"text"`` is missing, or if neither ``"image_url"``
                nor ``"image"`` is present.
            requests.HTTPError: If the image download returns an error status.
        """
        print("this shows the custom endpoint handler is being called")

        # Plain lookups instead of pop(): the caller's dict must stay intact
        # (it may be logged, retried or reused after this call).
        inputs = data.get("inputs", data)
        text = inputs["text"]
        image = self._load_image(inputs)

        processed_inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.model(**processed_inputs)

        text_embeds = outputs.text_embeds
        image_embeds = outputs.image_embeds

        # Tensors produced under no_grad carry no graph, so they convert to
        # NumPy directly.
        embedding_similarity = cosine_similarity(
            text_embeds.numpy(), image_embeds.numpy()
        )[0][0].item()

        return {
            "text_embedding": text_embeds[0].tolist(),
            "image_embedding": image_embeds[0].tolist(),
            "embedding_similarity": embedding_similarity,
        }

    def _load_image(self, inputs: Dict[str, Any]):
        """Return the request's image as a PIL image, from URL or base64."""
        if "image_url" in inputs:
            response = requests.get(
                inputs["image_url"], timeout=self.REQUEST_TIMEOUT
            )
            # Fail on 4xx/5xx instead of feeding an error page to PIL.
            response.raise_for_status()
            # ``content`` is the fully read, content-decoded body; the raw
            # socket stream is neither decoded nor released.
            image_bytes = response.content
        else:
            image_bytes = base64.b64decode(inputs["image"])
        return Image.open(io.BytesIO(image_bytes))
|