finhdev committed on
Commit
147df04
·
verified ·
1 Parent(s): d46ebda

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -37
handler.py CHANGED
@@ -1,68 +1,63 @@
1
- import io, base64, torch, open_clip
 
 
2
  from PIL import Image
 
3
 
4
  class EndpointHandler:
5
  """
6
- MobileCLIPB ('datacompdr') zero‑shot classifier with a per‑process
7
- text‑embedding cache.
8
-
9
- Client JSON must look like:
10
- {
11
- "inputs": {
12
- "image": "<base64 PNG/JPEG>",
13
- "candidate_labels": ["a photo of a cat", ...]
14
- }
15
  }
 
16
  """
17
 
18
- # ---------- init (runs once per container) ----------
19
- def __init__(self, path=""):
20
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
21
- "mobileclip_b", pretrained="datacompdr"
22
  )
23
  self.model.eval()
24
 
25
- self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
26
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
  self.model.to(self.device)
28
 
29
- self.cache: dict[str, torch.Tensor] = {} # prompt → embedding
30
-
31
- # ----------------- inference ------------------------
32
  def __call__(self, data):
 
33
  payload = data.get("inputs", data)
 
34
  img_b64 = payload["image"]
35
  labels = payload.get("candidate_labels", [])
36
  if not labels:
37
  return {"error": "candidate_labels list is empty"}
38
 
39
- # Image tensor
40
- img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
41
- img_t = self.preprocess(img).unsqueeze(0).to(self.device)
42
-
43
- # Text embeddings (cached)
44
- new = [l for l in labels if l not in self.cache]
45
- if new:
46
- tok = self.tokenizer(new).to(self.device)
47
- with torch.no_grad():
48
- emb = self.model.encode_text(tok)
49
- emb = emb / emb.norm(dim=-1, keepdim=True)
50
- for l, e in zip(new, emb):
51
- self.cache[l] = e
52
- txt_t = torch.stack([self.cache[l] for l in labels])
53
-
54
- # Forward
55
  with torch.no_grad(), torch.cuda.amp.autocast():
56
- img_f = self.model.encode_image(img_t)
57
- img_f = img_f / img_f.norm(dim=-1, keepdim=True)
58
- probs = (100 * img_f @ txt_t.T).softmax(dim=-1)[0].tolist()
 
 
59
 
 
60
  return [
61
  {"label": l, "score": float(p)}
62
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
63
  ]
64
 
65
-
66
  # import io, base64, torch
67
  # from PIL import Image
68
  # import open_clip
 
1
+
2
+ # handler.py (repo root)
3
+ import io, base64, torch
4
  from PIL import Image
5
+ import open_clip
6
 
7
  class EndpointHandler:
8
  """
9
+ Zero‑shot classifier for MobileCLIP‑B (OpenCLIP).
10
+
11
+ Expected client JSON *to the endpoint*:
12
+ {
13
+ "inputs": {
14
+ "image": "<base64 PNG/JPEG>",
15
+ "candidate_labels": ["cat", "dog", ...]
 
 
16
  }
17
+ }
18
  """
19
 
20
+ def __init__(self, path: str = ""):
21
+ weights = f"{path}/mobileclip_b.pt"
22
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
23
+ "MobileCLIP-B", pretrained=weights
24
  )
25
  self.model.eval()
26
 
27
+ self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
  self.model.to(self.device)
30
 
 
 
 
31
  def __call__(self, data):
32
+ # ── unwrap Hugging Face's `inputs` envelope ───────────
33
  payload = data.get("inputs", data)
34
+
35
  img_b64 = payload["image"]
36
  labels = payload.get("candidate_labels", [])
37
  if not labels:
38
  return {"error": "candidate_labels list is empty"}
39
 
40
+ # Decode & preprocess image
41
+ image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
42
+ img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
43
+
44
+ # Tokenise labels
45
+ text_tokens = self.tokenizer(labels).to(self.device)
46
+
47
+ # Forward pass
 
 
 
 
 
 
 
 
48
  with torch.no_grad(), torch.cuda.amp.autocast():
49
+ img_feat = self.model.encode_image(img_tensor)
50
+ txt_feat = self.model.encode_text(text_tokens)
51
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
52
+ txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
53
+ probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
54
 
55
+ # Sorted output
56
  return [
57
  {"label": l, "score": float(p)}
58
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
59
  ]
60
 
 
61
  # import io, base64, torch
62
  # from PIL import Image
63
  # import open_clip