finhdev committed
Commit e1369ab · verified · 1 Parent(s): aa10251

Update handler.py

Files changed (1)
  1. handler.py +21 -31
handler.py CHANGED
@@ -1,16 +1,17 @@
 # handler.py (repo root)
 
+# handler.py (repo root)
+
 import io, base64, torch
 from PIL import Image
 import open_clip
-from open_clip import fuse_conv_bn_sequential
 
 
 class EndpointHandler:
     """
-    Zero-shot classifier for MobileCLIP-B (OpenCLIP).
+    Zero-shot classifier for MobileCLIP-B (OpenCLIP) with a text-embedding cache.
 
-    Client JSON format:
+    Client JSON:
     {
       "inputs": {
         "image": "<base64 PNG/JPEG>",
@@ -19,68 +20,57 @@ class EndpointHandler:
     }
     """
 
-    # ----------------------------------------------------- #
-    #  INITIALISATION (once)                                 #
-    # ----------------------------------------------------- #
+    # ------------------------------------------------- #
+    #  INITIALISATION                                    #
+    # ------------------------------------------------- #
     def __init__(self, path: str = ""):
         weights = f"{path}/mobileclip_b.pt"
 
-        # Load model + transforms
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
             "MobileCLIP-B", pretrained=weights
         )
+        self.model.eval()
 
-        # Fuse Conv+BN for faster inference
-        self.model = fuse_conv_bn_sequential(self.model).eval()
-
-        # Tokeniser
         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
-
-        # Device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
-        # -------- text-embedding cache --------
-        # key: prompt string • value: torch.Tensor [512] on correct device
+        # cache: {prompt -> 1×512 tensor on device}
         self.label_cache: dict[str, torch.Tensor] = {}
 
-    # ----------------------------------------------------- #
-    #  INFERENCE (per request)                               #
-    # ----------------------------------------------------- #
+    # ------------------------------------------------- #
+    #  INFERENCE                                         #
+    # ------------------------------------------------- #
     def __call__(self, data):
-        # 1. Unwrap the HF "inputs" envelope
         payload = data.get("inputs", data)
 
-        img_b64 = payload["image"]
-        labels = payload.get("candidate_labels", [])
+        img_b64 = payload["image"]
+        labels = payload.get("candidate_labels", [])
         if not labels:
             return {"error": "candidate_labels list is empty"}
 
-        # 2. Decode & preprocess image
+        # --- image ----
         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
 
-        # 3. Text embeddings with cache
+        # --- text (with cache) ----
        missing = [l for l in labels if l not in self.label_cache]
         if missing:
             tokens = self.tokenizer(missing).to(self.device)
             with torch.no_grad():
                 emb = self.model.encode_text(tokens)
                 emb = emb / emb.norm(dim=-1, keepdim=True)
-            for lbl, vec in zip(missing, emb):
-                self.label_cache[lbl] = vec  # store on device
-
+            for l, e in zip(missing, emb):
+                self.label_cache[l] = e
         txt_feat = torch.stack([self.label_cache[l] for l in labels])
 
-        # 4. Forward pass for image
+        # --- forward & softmax ----
         with torch.no_grad(), torch.cuda.amp.autocast():
             img_feat = self.model.encode_image(img_tensor)
             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
+            probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
 
-        # 5. Similarity & softmax
-        probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].tolist()
-
-        # 6. Return sorted list
+        # --- sorted output ----
         return [
             {"label": l, "score": float(p)}
             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
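For reference, a minimal client-side sketch of the JSON contract the handler documents. This is not part of the commit: the endpoint URL, token, and image filename are placeholders, and the `requests` library is assumed to be installed. The payload shape (`inputs.image` as base64, `inputs.candidate_labels` as a list) and the sorted `{"label", "score"}` response come straight from the handler above.

# client_example.py -- hypothetical smoke test for the endpoint's JSON contract
import base64

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

# Base64-encode a local image, as the handler expects under inputs.image
with open("example.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "image": img_b64,
        "candidate_labels": ["a photo of a cat", "a photo of a dog"],
    }
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=30,
)

# The handler returns labels sorted by descending score, e.g.
# [{"label": "a photo of a cat", "score": 0.97},
#  {"label": "a photo of a dog", "score": 0.03}]
print(resp.json())

Because the handler caches text embeddings per prompt string, repeated requests with the same candidate_labels only pay the text-encoder cost once per label for the lifetime of the process.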