finhdev committed (verified)
Commit: d46ebda
Parent: 4bb64f6

Update handler.py

Files changed (1):
  handler.py  +6 -8
handler.py CHANGED
@@ -1,14 +1,12 @@
-# handler.py (repo root)
 import io, base64, torch, open_clip
 from PIL import Image
-# optional: from open_clip import fuse_conv_bn_sequential  # if you want re-param

 class EndpointHandler:
     """
-    MobileCLIP-B ('datacompdr') zero-shot classifier with per-process
+    MobileCLIP-B ('datacompdr') zero-shot classifier with a per-process
     text-embedding cache.

-    Expected client JSON:
+    Client JSON must look like:
     {
       "inputs": {
         "image": "<base64 PNG/JPEG>",
@@ -17,13 +15,11 @@ class EndpointHandler:
       }
     """

+    # ---------- init (runs once per container) ----------
     def __init__(self, path=""):
-        # Load the exact weights your local run uses
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
             "mobileclip_b", pretrained="datacompdr"
         )
-        # Optional: fuse conv+bn for speed
-        # self.model = fuse_conv_bn_sequential(self.model).eval()
         self.model.eval()

         self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
@@ -32,6 +28,7 @@ class EndpointHandler:

         self.cache: dict[str, torch.Tensor] = {}  # prompt → embedding

+    # ----------------- inference ------------------------
     def __call__(self, data):
         payload = data.get("inputs", data)
         img_b64 = payload["image"]
@@ -43,7 +40,7 @@
         img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
         img_t = self.preprocess(img).unsqueeze(0).to(self.device)

-        # Text embeddings with cache
+        # Text embeddings (cached)
         new = [l for l in labels if l not in self.cache]
         if new:
             tok = self.tokenizer(new).to(self.device)
@@ -65,6 +62,7 @@
             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
         ]

+
 # import io, base64, torch
 # from PIL import Image
 # import open_clip
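
The hunks above skip the middle of __call__ (old lines 50-64), where the image and the cached text embeddings are turned into probs. For context, a minimal sketch of the usual open_clip zero-shot scoring step (encode, L2-normalize, cosine similarity, softmax) follows; the helper name, tensor shapes, and the fixed logit scale are assumptions, not code from this commit.

import torch

# Hypothetical helper, not taken from this commit: the typical open_clip
# zero-shot scoring step that the diff does not show.
def zero_shot_probs(model: torch.nn.Module,
                    img_t: torch.Tensor,       # preprocessed image, shape (1, 3, H, W)
                    text_feats: torch.Tensor,  # stacked text embeddings, shape (N, D)
                    ) -> list[float]:
    with torch.no_grad():
        img_feat = model.encode_image(img_t)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        # 100.0 mirrors CLIP's exp(logit_scale); the real handler may use model.logit_scale
        logits = 100.0 * img_feat @ text_feats.T
        return logits.softmax(dim=-1).squeeze(0).tolist()  # one probability per label

In the handler itself, text_feats would be the stack of self.cache[l] entries in label order, which is what the per-prompt cache exists to provide.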
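A matching client call, following the payload shape in the docstring, could look like the sketch below. The endpoint URL, the token, and the "labels" key are placeholders: the docstring hunk only shows the "image" field, so the exact label field name is an assumption.

import base64, requests

# Hypothetical client; URL, token and the "labels" key are assumptions.
with open("example.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
    headers={
        "Authorization": "Bearer <hf_token>",               # placeholder token
        "Content-Type": "application/json",
    },
    json={"inputs": {"image": img_b64,
                     "labels": ["a photo of a cat", "a photo of a dog"]}},
)
print(resp.json())  # labels with their probabilities, sorted highest first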