finhdev committed
Commit 2fb4fd2 · verified · 1 Parent(s): 3bad150

Update handler.py

Files changed (1): handler.py +18 -14
handler.py CHANGED
@@ -1,32 +1,36 @@
 # handler.py (repo root)
-
 import io, base64, torch, open_clip
 from PIL import Image
-from mobileclip.modules.common.mobileone import reparameterize_model  # optional
+# optional: from open_clip import fuse_conv_bn_sequential  # if you want re-param
 
 class EndpointHandler:
     """
-    MobileCLIP-B ('datacompdr') · text-embedding cache.
-    Expects: {
-        "inputs": {
-            "image": "<base64>",
-            "candidate_labels": ["a photo of a cat", ...]
-        }
-    }
+    MobileCLIP-B ('datacompdr') zero-shot classifier with per-process
+    text-embedding cache.
+
+    Expected client JSON:
+    {
+        "inputs": {
+            "image": "<base64 PNG/JPEG>",
+            "candidate_labels": ["a photo of a cat", ...]
+        }
+    }
     """
 
     def __init__(self, path=""):
-        # -- Load MobileCLIP-B checkpoint identical to local run -------------
+        # Load the exact weights your local run uses
         self.model, _, self.preprocess = open_clip.create_model_and_transforms(
             "mobileclip_b", pretrained="datacompdr"
         )
-        self.model = reparameterize_model(self.model).eval()  # matches local pipeline
+        # Optional: fuse conv+bn for speed
+        # self.model = fuse_conv_bn_sequential(self.model).eval()
+        self.model.eval()
 
         self.tokenizer = open_clip.get_tokenizer("mobileclip_b")
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
-        self.cache: dict[str, torch.Tensor] = {}  # label → embedding
+        self.cache: dict[str, torch.Tensor] = {}  # prompt → embedding
 
     def __call__(self, data):
         payload = data.get("inputs", data)
@@ -35,11 +39,11 @@ class EndpointHandler:
         if not labels:
             return {"error": "candidate_labels list is empty"}
 
-        # -------- image preprocessing --------------------------------------
+        # Image tensor
         img = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
         img_t = self.preprocess(img).unsqueeze(0).to(self.device)
 
-        # -------- text embeddings with cache -------------------------------
+        # Text embeddings with cache
         new = [l for l in labels if l not in self.cache]
         if new:
             tok = self.tokenizer(new).to(self.device)
@@ -50,7 +54,7 @@ class EndpointHandler:
             self.cache[l] = e
         txt_t = torch.stack([self.cache[l] for l in labels])
 
-        # -------- forward & softmax ----------------------------------------
+        # Forward
         with torch.no_grad(), torch.cuda.amp.autocast():
             img_f = self.model.encode_image(img_t)
             img_f = img_f / img_f.norm(dim=-1, keepdim=True)
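
A note on the elided context: the diff skips four unchanged lines (old 46-49) between the second and third hunks. Judging from the surrounding context lines, they presumably encode and normalize the newly seen prompts and zip them back to their labels. A minimal sketch under that assumption, not the file's verbatim contents:

            # assumed shape of the elided lines: encode only the uncached prompts
            with torch.no_grad():
                emb = self.model.encode_text(tok)
                emb = emb / emb.norm(dim=-1, keepdim=True)  # unit-normalize before caching
            for l, e in zip(new, emb):
                self.cache[l] = e  # this is the context line where hunk 3 resumes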
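
Likewise, the last hunk stops at the image-feature normalization; the remainder of __call__ sits outside the diff context and is unchanged by this commit. Assuming the standard open_clip zero-shot recipe implied by the old "forward & softmax" comment, the tail would look roughly like this (the 100x logit scale and the return format are assumptions, not verbatim file contents):

        with torch.no_grad(), torch.cuda.amp.autocast():
            img_f = self.model.encode_image(img_t)
            img_f = img_f / img_f.norm(dim=-1, keepdim=True)
            txt_f = txt_t / txt_t.norm(dim=-1, keepdim=True)
            probs = (100.0 * img_f @ txt_f.T).softmax(dim=-1)  # scaled cosine similarity
        scores = probs.squeeze(0).tolist()
        return sorted(
            [{"label": l, "score": s} for l, s in zip(labels, scores)],
            key=lambda d: d["score"],
            reverse=True,
        )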
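
One caveat on the unchanged forward block: torch.cuda.amp.autocast() affects only CUDA ops, so on a CPU endpoint it warns and does nothing, and newer PyTorch releases deprecate that spelling in favor of torch.amp.autocast. A device-aware guard would be a small optional hardening; a sketch, not part of this commit:

        import contextlib
        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if self.device == "cuda"
            else contextlib.nullcontext()  # plain fp32 on CPU instances
        )
        with torch.no_grad(), autocast_ctx:
            img_f = self.model.encode_image(img_t)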
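
For completeness, a minimal client call matching the docstring's expected JSON could look like the following; the endpoint URL and token are placeholders for your own deployment, not values from this repo:

import base64, requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <hf_token>"}                 # placeholder

with open("cat.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "inputs": {
        "image": img_b64,
        "candidate_labels": ["a photo of a cat", "a photo of a dog"],
    }
}
print(requests.post(API_URL, headers=HEADERS, json=payload).json())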