finhdev committed on
Commit
37da151
·
verified ·
1 Parent(s): b897d37

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -16
handler.py CHANGED
@@ -5,10 +5,8 @@ from PIL import Image
5
 
6
  class EndpointHandler:
7
  """
8
- MobileCLIP‑B (pretrained='datacompdr') zero‑shot classifier with
9
- per‑container text‑embedding cache.
10
-
11
- Client JSON:
12
  {
13
  "inputs": {
14
  "image": "<base64 PNG/JPEG>",
@@ -18,50 +16,48 @@ class EndpointHandler:
18
  """
19
 
20
  def __init__(self, path=""):
21
- # --- model & transforms ---------------------------------
22
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
23
  "mobileclip_b", pretrained="datacompdr"
24
  )
25
  self.model.eval()
26
 
27
- # --- tokenizer & device --------------------------------
28
  self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
29
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
30
  self.model.to(self.device)
31
 
32
- # --- text‑embedding cache ------------------------------
33
  self.cache: dict[str, torch.Tensor] = {}
34
 
35
  def __call__(self, data):
36
- # 1. unwrap HF 'inputs'
37
  payload = data.get("inputs", data)
38
  img_b64 = payload["image"]
39
  labels = payload.get("candidate_labels", [])
40
  if not labels:
41
  return {"error": "candidate_labels list is empty"}
42
 
43
- # 2. image -> tensor
44
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
45
  img_t = self.preprocess(img).unsqueeze(0).to(self.device)
46
 
47
- # 3. text -> cached embeddings
48
- missing = [l for l in labels if l not in self.cache]
49
- if missing:
50
- tok = self.tokenizer(missing).to(self.device)
51
  with torch.no_grad():
52
  emb = self.model.encode_text(tok)
53
  emb = emb / emb.norm(dim=-1, keepdim=True)
54
- for l, e in zip(missing, emb):
55
  self.cache[l] = e
56
  txt_t = torch.stack([self.cache[l] for l in labels])
57
 
58
- # 4. forward
59
  with torch.no_grad(), torch.cuda.amp.autocast():
60
  img_f = self.model.encode_image(img_t)
61
  img_f = img_f / img_f.norm(dim=-1, keepdim=True)
62
  probs = (100 * img_f @ txt_t.T).softmax(dim=-1)[0].tolist()
63
 
64
- # 5. sorted response
65
  return [
66
  {"label": l, "score": float(p)}
67
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
 
5
 
6
  class EndpointHandler:
7
  """
8
+ MobileCLIP‑B (pretrained='datacompdr') zero‑shot classifier.
9
+ Expects JSON:
 
 
10
  {
11
  "inputs": {
12
  "image": "<base64 PNG/JPEG>",
 
16
  """
17
 
18
  def __init__(self, path=""):
19
+ # Model + transforms
20
  self.model, _, self.preprocess = open_clip.create_model_and_transforms(
21
  "mobileclip_b", pretrained="datacompdr"
22
  )
23
  self.model.eval()
24
 
25
+ # Tokeniser & device
26
  self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
27
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
28
  self.model.to(self.device)
29
 
30
+ # Cache {prompt -> 1×512 tensor}
31
  self.cache: dict[str, torch.Tensor] = {}
32
 
33
  def __call__(self, data):
 
34
  payload = data.get("inputs", data)
35
  img_b64 = payload["image"]
36
  labels = payload.get("candidate_labels", [])
37
  if not labels:
38
  return {"error": "candidate_labels list is empty"}
39
 
40
+ # Image tensor
41
  img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
42
  img_t = self.preprocess(img).unsqueeze(0).to(self.device)
43
 
44
+ # Text embeddings with cache
45
+ new_labels = [l for l in labels if l not in self.cache]
46
+ if new_labels:
47
+ tok = self.tokenizer(new_labels).to(self.device)
48
  with torch.no_grad():
49
  emb = self.model.encode_text(tok)
50
  emb = emb / emb.norm(dim=-1, keepdim=True)
51
+ for l, e in zip(new_labels, emb):
52
  self.cache[l] = e
53
  txt_t = torch.stack([self.cache[l] for l in labels])
54
 
55
+ # Forward
56
  with torch.no_grad(), torch.cuda.amp.autocast():
57
  img_f = self.model.encode_image(img_t)
58
  img_f = img_f / img_f.norm(dim=-1, keepdim=True)
59
  probs = (100 * img_f @ txt_t.T).softmax(dim=-1)[0].tolist()
60
 
 
61
  return [
62
  {"label": l, "score": float(p)}
63
  for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)