# testmobileclip/handler.py
import io, base64, torch, json
from PIL import Image
import open_clip
from reparam import reparameterize_model


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load the model (happens only once at startup)
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained='datacompdr'
        )
        model.eval()
        self.model = reparameterize_model(model)
        tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.model.to(self.device)
        if self.device == "cuda":
            self.model.to(torch.float16)

        # --- OPTIMIZATION: Pre-compute text features from your JSON ---
        # 2. Load your rich class definitions from the file
        with open(f"{path}/items.json", "r", encoding="utf-8") as f:
            class_definitions = json.load(f)

        # 3. Prepare the data for encoding and for the final response
        #    - Use the 'prompt' field for creating the embeddings
        #    - Keep 'name' and 'id' to structure the response later
        prompts = [item['prompt'] for item in class_definitions]
        self.class_ids = [item['id'] for item in class_definitions]
        self.class_names = [item['name'] for item in class_definitions]

        # 4. Tokenize and encode all prompts at once
        with torch.no_grad():
            text_tokens = tokenizer(prompts).to(self.device)
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
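
    # Example `items.json` layout this handler expects (entries are
    # illustrative; the real file ships alongside this script in the repo):
    #
    # [
    #   {"id": 1, "name": "cat", "prompt": "a photo of a cat"},
    #   {"id": 2, "name": "dog", "prompt": "a photo of a dog"}
    # ]
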
    def __call__(self, data):
        # The payload only needs the image now
        payload = data.get("inputs", data)
        img_b64 = payload["image"]

        # ---------------- decode image ----------------
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # ---------------- forward pass (very fast) -----------------
        with torch.no_grad():
            # 1. Encode only the image
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)
            # 2. Compute similarity against the pre-computed text features
            probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]

        # 3. Combine the results with your stored class IDs and names
        #    and convert the tensor of probabilities to a list of floats
        results = zip(self.class_ids, self.class_names, probs.cpu().tolist())

        # 4. Create a sorted list of dictionaries for a clean JSON response
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True
        )
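

# Minimal local smoke test (illustrative; assumes an `items.json` and a sample
# image `test.jpg` in the working directory). The Inference Endpoints toolkit
# only imports `EndpointHandler`, so this block never runs in production.
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        sample_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path=".")
    predictions = handler({"inputs": {"image": sample_b64}})
    print(json.dumps(predictions[:5], indent=2))
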
# """
# MobileCLIP‑B Zero‑Shot Image Classifier (Hugging Face Inference Endpoint)
# ===========================================================================
# * One container instance is created per replica; the `EndpointHandler`
# object below is instantiated exactly **once** at start‑up.
# * At request time (`__call__`) we receive a base‑64‑encoded image, run a
# **single forward pass**, and return class probabilities.
# Design choices
# --------------
# 1. **Model & transform come from OpenCLIP**
# This guarantees we apply **identical preprocessing** to what the model
# was trained with (224 × 224 crop + mean/std normalisation).
# 2. **Re‑parameterisation for inference**
# MobileCLIP uses MobileOne blocks that have extra convolution branches
# for training; `reparameterize_model` fuses them so inference is fast
# and deterministic.
# 3. **Text embeddings are cached**
# The class “prompts” (e.g. `"a photo of a cat"`) are encoded **once at
# start‑up**. Each request therefore encodes *only* the image and
# performs a single matrix multiplication.
# 4. **Mixed precision on GPU**
# If the container has CUDA, we cast the model **and** inputs to
# `float16`. That halves memory and roughly doubles throughput on most
# modern GPUs. On CPU we stay in `float32` for numerical stability.
# """
# import io, base64, json
# from pathlib import Path
# from typing import Any, Dict, List
#
# import torch
# from PIL import Image
# import open_clip
#
# from reparam import reparameterize_model      # local copy (~60 LoC) of Apple’s helper
#
#
# class EndpointHandler:
#     """
#     Hugging Face entry‑point. The toolkit will instantiate this class
#     once and call it for every HTTP request.
#
#     Parameters
#     ----------
#     path : str, optional
#         Root directory of the repository. HF mounts the code under
#         `/repository`; we use this path to locate `items.json`.
#     """
#
#     # ------------------------------------------------------------------ #
#     # INITIALISATION (runs **once**)                                      #
#     # ------------------------------------------------------------------ #
#     def __init__(self, path: str = "") -> None:
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"
#
#         # 1️⃣ Load MobileCLIP‑B weights & transforms -------------------
#         #    `pretrained="datacompdr"` makes OpenCLIP download the
#         #    official checkpoint from the Hub (cached in the image layer).
#         model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained="datacompdr"
#         )
#         model.eval()                           # disable dropout / BN updates
#         model = reparameterize_model(model)    # fuse MobileOne branches
#         model.to(self.device)
#         if self.device == "cuda":
#             model = model.to(torch.float16)    # FP16 for throughput
#         self.model = model                     # hold a reference
#
#         # 2️⃣ Build the tokenizer once --------------------------------
#         tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
#
#         # 3️⃣ Load class metadata -------------------------------------
#         #    Expect JSON file: [{"id": 3, "name": "cat", "prompt": "cat"}, …]
#         items_path = Path(path) / "items.json"
#         with items_path.open("r", encoding="utf-8") as f:
#             class_defs: List[Dict[str, Any]] = json.load(f)
#
#         # Extract the bits we need later
#         prompts = [item["prompt"] for item in class_defs]
#         self.class_ids: List[int] = [item["id"] for item in class_defs]
#         self.class_names: List[str] = [item["name"] for item in class_defs]
#
#         # 4️⃣ Encode all prompts once ---------------------------------
#         with torch.no_grad():
#             text_tokens = tokenizer(prompts).to(self.device)
#             text_feats = self.model.encode_text(text_tokens)
#             text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
#         self.text_features = text_feats        # [num_classes, 512]
#
#     # ------------------------------------------------------------------ #
#     # INFERENCE CALL                                                      #
#     # ------------------------------------------------------------------ #
#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
#         """
#         Parameters
#         ----------
#         data : dict
#             Either the raw payload `{"image": "<base64>"}` **or** the
#             Hugging Face convention `{"inputs": {...}}`.
#
#         Returns
#         -------
#         list of dict
#             Sorted list of `{"id": int, "label": str, "score": float}`.
#             Scores are the softmax probabilities over the *provided*
#             class list (they sum to 1.0).
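#
#         Examples
#         --------
#         Illustrative call for a two‑class `items.json` (labels and scores
#         depend entirely on your class file):
#
#         >>> handler({"inputs": {"image": "<base64‑encoded JPEG/PNG>"}})
#         [{"id": 2, "label": "dog", "score": 0.93},
#          {"id": 1, "label": "cat", "score": 0.07}]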
# """
#         # 1️⃣ Unpack the request payload ------------------------------
#         payload: Dict[str, Any] = data.get("inputs", data)
#         img_b64: str = payload["image"]
#
#         # 2️⃣ Decode + preprocess -------------------------------------
#         image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)   # [1, 3, 224, 224]
#         if self.device == "cuda":
#             img_tensor = img_tensor.to(torch.float16)
#
#         # 3️⃣ Forward pass (image only) -------------------------------
#         with torch.no_grad():                                              # no autograd graph
#             img_feat = self.model.encode_image(img_tensor)                 # [1, 512]
#             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)      # L2‑normalise
#             # cosine similarity → logits → softmax probabilities
#             probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]   # [num_classes]
#
#         # 4️⃣ Assemble JSON‑serialisable response ---------------------
#         results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
#         return sorted(
#             [{"id": cid, "label": name, "score": float(p)} for cid, name, p in results],
#             key=lambda x: x["score"],
#             reverse=True,
#         )
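#
#
# Example client call against a deployed endpoint (illustrative: replace the
# placeholder URL and token with your own; the payload shape is the one the
# handler above expects):
#
#   import base64, requests
#
#   with open("test.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "https://<your-endpoint>.endpoints.huggingface.cloud",
#       headers={"Authorization": "Bearer <HF_TOKEN>",
#                "Content-Type": "application/json"},
#       json={"inputs": {"image": b64}},
#   )
#   print(resp.json()[:3])   # top-3 predictions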