import base64
import io
import json

import torch
from PIL import Image
import open_clip
from reparam import reparameterize_model


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load the model (happens only once at startup).
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()
        self.model = reparameterize_model(model)  # fuse branches for inference
        # The tokenizer is only needed at startup, since all text features
        # are pre-computed below.
        tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.model.to(self.device)
        if self.device == "cuda":
            self.model.to(torch.float16)
        # --- OPTIMIZATION: pre-compute text features from items.json ---
        # 2. Load the rich class definitions from the file.
        with open(f"{path}/items.json", "r", encoding="utf-8") as f:
            class_definitions = json.load(f)
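        # Illustrative items.json layout (an assumption about the deployed
        # file; only the 'id', 'name', and 'prompt' fields are read below):
        # [
        #   {"id": "mug-01", "name": "coffee mug", "prompt": "a photo of a coffee mug"},
        #   {"id": "shoe-02", "name": "running shoe", "prompt": "a photo of a running shoe"}
        # ]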
        # 3. Prepare the data for encoding and for the final response:
        #    - use the 'prompt' field for creating the embeddings,
        #    - keep 'name' and 'id' to structure the response later.
        prompts = [item["prompt"] for item in class_definitions]
        self.class_ids = [item["id"] for item in class_definitions]
        self.class_names = [item["name"] for item in class_definitions]

        # 4. Tokenize and encode all prompts at once.
        with torch.no_grad():
            text_tokens = tokenizer(prompts).to(self.device)
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)

    def __call__(self, data):
        # The payload only needs the image, since the labels are fixed at startup.
        payload = data.get("inputs", data)
        img_b64 = payload["image"]
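        # Expected request body (illustrative):
        #   {"inputs": {"image": "<base64-encoded JPEG or PNG bytes>"}}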

        # ---------------- decode image ----------------
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        # The preprocessor outputs float32, so match the model's dtype on GPU.
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)
        # ---------------- forward pass (very fast) ----------------
        with torch.no_grad():
            # 1. Encode only the image.
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

            # 2. Compute similarity against the pre-computed text features.
            probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]

        # 3. Combine the results with the stored class IDs and names, and
        #    convert the tensor of probabilities to a list of floats.
        results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
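        # Example response (illustrative values):
        #   [{"id": "mug-01", "label": "coffee mug", "score": 0.93}, ...]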
        # 4. Create a sorted list of dictionaries for a clean JSON response.
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True,
        )
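

# ---------------------------------------------------------------------------
# Previous version, kept for reference: labels were supplied per request via
# "candidate_labels" and text features were re-encoded on every call.
# ---------------------------------------------------------------------------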
# import contextlib, io, base64, torch
# from PIL import Image
# import open_clip
# from reparam import reparameterize_model
#
# class EndpointHandler:
#     def __init__(self, path: str = ""):
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#         # Fix 1: Load weights directly from the web, just like the local script.
#         # This guarantees the weights are identical.
#         model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained='datacompdr'
#         )
#         model.eval()
#         self.model = reparameterize_model(model)  # fuse branches
#         self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
#         self.model.to(self.device)
#         # Fix 2: Explicitly set the model to half precision if on CUDA.
#         # This matches the behavior of torch.set_default_dtype(torch.float16).
#         if self.device == "cuda":
#             self.model.to(torch.float16)
#
#     def __call__(self, data):
#         payload = data.get("inputs", data)
#         img_b64 = payload["image"]
#         labels = payload.get("candidate_labels", [])
#         if not labels:
#             return {"error": "candidate_labels list is empty"}
#         # ---------------- decode inputs ----------------
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
#         # The preprocessor might output float32, so ensure the tensor matches the model dtype.
#         if self.device == "cuda":
#             img_tensor = img_tensor.to(torch.float16)
#         text_tokens = self.tokenizer(labels).to(self.device)
#         # ---------------- forward pass -----------------
#         # No need for autocast if everything is already float16.
#         with torch.no_grad():
#             img_feat = self.model.encode_image(img_tensor)
#             txt_feat = self.model.encode_text(text_tokens)
#             img_feat /= img_feat.norm(dim=-1, keepdim=True)
#             txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
#             probs = (100 * img_feat @ txt_feat.T).softmax(dim=-1)[0].cpu().tolist()
#         return [
#             {"label": l, "score": float(p)}
#             for l, p in sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)
#         ]
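

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the endpoint contract).
# Assumes an items.json in the current directory and a local test image named
# "test.jpg" — both are placeholder names for whatever you actually deploy with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    # Encode a local image the same way a client would before POSTing it.
    with open("test.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    predictions = handler({"inputs": {"image": img_b64}})
    # Print the top five matches as formatted JSON.
    print(json.dumps(predictions[:5], indent=2))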