finhdev committed
Commit 88b442b · verified · 1 Parent(s): 440d3d5

Update handler.py

Files changed (1)
  1. handler.py +21 -20
handler.py CHANGED
@@ -5,8 +5,10 @@ from PIL import Image
 
 class EndpointHandler:
     """
-    MobileCLIP‑B zero‑shot (OpenCLIP, pretrained = 'datacompdr')
-    Expects JSON:
+    MobileCLIP‑B (pretrained='datacompdr') zero‑shot classifier with
+    per‑container text‑embedding cache.
+
+    Client JSON:
     {
         "inputs": {
             "image": "<base64 PNG/JPEG>",
@@ -15,52 +17,51 @@ class EndpointHandler:
     }
     """
 
-    # ---------- initialisation (once per container) ----------
     def __init__(self, path=""):
-        # Use the same checkpoint as your local workflow
-        # • No need for the local mobileclip_b.pt file
+        # --- model & transforms ---------------------------------
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
             "mobileclip_b", pretrained="datacompdr"
         )
         self.model.eval()
 
+        # --- tokenizer & device --------------------------------
         self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
-        # Cache: {prompt -> 1×512 tensor}
-        self.label_cache: dict[str, torch.Tensor] = {}
+        # --- text‑embedding cache ------------------------------
+        self.cache: dict[str, torch.Tensor] = {}
 
-    # -------------------- inference --------------------------
     def __call__(self, data):
+        # 1. unwrap HF 'inputs'
         payload = data.get("inputs", data)
         img_b64 = payload["image"]
         labels = payload.get("candidate_labels", [])
         if not labels:
             return {"error": "candidate_labels list is empty"}
 
-        # image tensor
-        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
-        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
+        # 2. image -> tensor
+        img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
+        img_t = self.preprocess(img).unsqueeze(0).to(self.device)
 
-        # text cached embeddings
-        missing = [l for l in labels if l not in self.label_cache]
+        # 3. text -> cached embeddings
+        missing = [l for l in labels if l not in self.cache]
         if missing:
             tok = self.tokenizer(missing).to(self.device)
             with torch.no_grad():
                 emb = self.model.encode_text(tok)
                 emb = emb / emb.norm(dim=-1, keepdim=True)
             for l, e in zip(missing, emb):
-                self.label_cache[l] = e
-        txt_feat = torch.stack([self.label_cache[l] for l in labels])
+                self.cache[l] = e
+        txt_t = torch.stack([self.cache[l] for l in labels])
 
-        # forward
+        # 4. forward
        with torch.no_grad(), torch.cuda.amp.autocast():
-            img_feat = self.model.encode_image(img_tensor)
-            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
-            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
+            img_f = self.model.encode_image(img_t)
+            img_f = img_f / img_f.norm(dim=-1, keepdim=True)
+            probs = (100 * img_f @ txt_t.T).softmax(dim=-1)[0].tolist()
 
-        # sorted result
+        # 5. sorted response
         return [
             {"label": l, "score": float(p)}
             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
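
Since the handler's contract is easiest to verify end to end, here is a minimal local smoke test. It is not part of the commit; it assumes handler.py sits in the working directory, that torch, open_clip_torch, and Pillow are installed, and that the datacompdr weights are downloadable. It builds a throwaway PNG, base64-encodes it, and calls the handler with the documented "inputs" payload. On a CPU-only container, torch.cuda.amp.autocast() typically just warns and disables itself, so the same test should run with or without a GPU.

    # Hypothetical local smoke test -- not part of this commit.
    import base64, io

    from PIL import Image

    from handler import EndpointHandler  # the class defined above

    handler = EndpointHandler()

    # Build a throwaway 224x224 red PNG and base64-encode it, matching
    # the "<base64 PNG/JPEG>" field described in the docstring.
    buf = io.BytesIO()
    Image.new("RGB", (224, 224), "red").save(buf, format="PNG")
    img_b64 = base64.b64encode(buf.getvalue()).decode()

    result = handler({
        "inputs": {
            "image": img_b64,
            "candidate_labels": ["a red square", "a photo of a dog"],
        }
    })
    print(result)  # [{"label": ..., "score": ...}, ...] sorted by descending score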
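For a deployed Inference Endpoint, a client request would look roughly like the sketch below. ENDPOINT_URL and HF_TOKEN are placeholders, not values from this repo; substitute your endpoint's URL and an access token authorized to call it.

    # Hypothetical client call; ENDPOINT_URL and HF_TOKEN are placeholders.
    import base64

    import requests

    ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
    HF_TOKEN = "hf_..."

    with open("cat.jpg", "rb") as f:  # any local PNG/JPEG
        img_b64 = base64.b64encode(f.read()).decode()

    resp = requests.post(
        ENDPOINT_URL,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
        json={"inputs": {"image": img_b64, "candidate_labels": ["cat", "dog"]}},
    )
    resp.raise_for_status()
    print(resp.json())  # e.g. [{"label": "cat", "score": ...}, {"label": "dog", "score": ...}]

Because the handler caches text embeddings per container, repeated calls with the same candidate_labels skip encode_text entirely; only the image branch runs on subsequent requests.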