finhdev committed
Commit 35037e4 · verified · 1 Parent(s): 6325232

Update handler.py

Files changed (1)
  1. handler.py +46 -11
handler.py CHANGED
@@ -1,21 +1,56 @@
+# handler.py – place in repo root
 import io, base64, torch
 from PIL import Image
-from transformers import CLIPProcessor, CLIPModel
+
+import open_clip
+from mobileclip.modules.common.mobileone import reparameterize_model
 
 class EndpointHandler:
-    def __init__(self, path=""):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = CLIPModel.from_pretrained(path).to(device)
-        self.processor = CLIPProcessor.from_pretrained(path)
-        self.device = device
+    """
+    Zero-shot image classifier for MobileCLIP-B using OpenCLIP.
+    Expects JSON:
+        {
+            "image": "<base64-encoded PNG/JPEG>",
+            "candidate_labels": ["cat", "dog", ...]
+        }
+    """
+    def __init__(self, path: str = ""):
+        # Hugging Face Endpoints clones the repo into `path`.
+        # The weights file is mobileclip_b.pt (already in the repo).
+        weights = f"{path}/mobileclip_b.pt"
+        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+            "MobileCLIP-B", pretrained=weights
+        )
+
+        # Re-parameterize once for faster inference (as per MobileCLIP docs)
+        self.model = reparameterize_model(self.model)
+        self.model.eval()
+
+        # OpenCLIP tokenizer (same as CLIP)
+        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
 
     def __call__(self, data):
-        # Expect JSON {"image": "<base64 PNG/JPEG>", "candidate_labels": ["cat","dog"]}
+        # Decode input
         img_b64 = data["image"]
         labels = data.get("candidate_labels", [])
         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
 
-        inputs = self.processor(text=labels, images=image,
-                                return_tensors="pt", padding=True).to(self.device)
-        probs = self.model(**inputs).logits_per_image.softmax(dim=-1)[0].tolist()
-        return [{"label": l, "score": float(p)} for l, p in zip(labels, probs)]
+        # Preprocess
+        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
+        text_tokens = self.tokenizer(labels).to(self.device)
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            img_feat = self.model.encode_image(image_tensor)
+            txt_feat = self.model.encode_text(text_tokens)
+            img_feat /= img_feat.norm(dim=-1, keepdim=True)
+            txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
+            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
+
+        return [
+            {"label": l, "score": float(p)} for l, p in sorted(
+                zip(labels, probs), key=lambda x: x[1], reverse=True
+            )
+        ]
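
For reference, a minimal local smoke test of the new handler could look like the sketch below. It assumes the repo root contains handler.py and mobileclip_b.pt and that torch, Pillow, open_clip and the mobileclip package are installed; the file name example.jpg, the script name smoke_test.py, and the label list are placeholders, not part of this commit. The payload shape follows the docstring in handler.py.

# smoke_test.py – hypothetical local check for handler.py (not part of this commit)
import base64
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # "." = repo root containing mobileclip_b.pt

# example.jpg is a placeholder; use any local test image
with open("example.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image": img_b64,
    "candidate_labels": ["cat", "dog", "car"],
}

# Prints a list of {"label", "score"} dicts sorted by descending score
print(handler(payload))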