roemmele commited on
Commit
f1f49b4
·
1 Parent(s): b025b7f

Debugging handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -8
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict
2
  import requests
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
@@ -10,16 +10,18 @@ class EndpointHandler:
10
  self.processor = CLIPProcessor.from_pretrained(path)
11
  self.model = CLIPModel.from_pretrained(path)
12
 
13
- def __call__(self, data: Dict) -> Dict:
14
- text = data.pop("text")
 
 
15
  if "image_url" in data:
16
- image_url = data.pop("image_url")
17
  image = Image.open(requests.get(image_url, stream=True).raw)
18
  else:
19
- image = data.pop("image")
20
- inputs = self.processor(text=text, images=image,
21
- return_tensors="pt", padding=True, truncation=True)
22
- outputs = self.model(**inputs)
23
  embedding_similarity = cosine_similarity(outputs.text_embeds.detach().numpy(),
24
  outputs.image_embeds.detach().numpy())[0][0].item()
25
  return {"text_embedding": outputs.text_embeds[0].tolist(),
 
1
+ from typing import Dict, Any
2
  import requests
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
 
10
  self.processor = CLIPProcessor.from_pretrained(path)
11
  self.model = CLIPModel.from_pretrained(path)
12
 
13
+ def __call__(self, data: Dict[str, Any]) -> Dict:
14
+ print("this shows the custom endpoint handler is being called")
15
+ inputs = data.pop("inputs", data)
16
+ text = inputs.pop("text")
17
  if "image_url" in data:
18
+ image_url = inputs.pop("image_url")
19
  image = Image.open(requests.get(image_url, stream=True).raw)
20
  else:
21
+ image = inputs.pop("image")
22
+ processed_inputs = self.processor(text=text, images=image,
23
+ return_tensors="pt", padding=True, truncation=True)
24
+ outputs = self.model(**processed_inputs)
25
  embedding_similarity = cosine_similarity(outputs.text_embeds.detach().numpy(),
26
  outputs.image_embeds.detach().numpy())[0][0].item()
27
  return {"text_embedding": outputs.text_embeds[0].tolist(),