import io, os, base64, torch, json
from PIL import Image
import open_clip
from reparam import reparameterize_model

class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load the model (happens only once at startup)
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained='datacompdr'
        )
        model.eval()
        self.model = reparameterize_model(model)
        tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.model.to(self.device)

        if self.device == "cuda":
            self.model.to(torch.float16)

        # --- OPTIMIZATION: Pre-compute text features from your JSON ---
        
        # 2. Load your rich class definitions from the file
        with open(f"{path}/items.json", "r", encoding="utf-8") as f:
            class_definitions = json.load(f)

        # 3. Prepare the data for encoding and for the final response
        #    - Use the 'prompt' field for creating the embeddings
        #    - Keep 'name' and 'id' to structure the response later
        prompts = [item['prompt'] for item in class_definitions]
        self.class_ids = [item['id'] for item in class_definitions]
        self.class_names = [item['name'] for item in class_definitions]
        
        # 4. Tokenize and encode all prompts at once
        with torch.no_grad():
            text_tokens = tokenizer(prompts).to(self.device)
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
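            # self.text_features has shape [num_classes, 512]: one L2-normalised embedding per prompt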

    def __call__(self, data):
        # The payload only needs the image now
        payload = data.get("inputs", data)
        img_b64 = payload["image"]

        # ---------------- decode image ----------------
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # ---------------- forward pass (very fast) -----------------
        with torch.no_grad():
            # 1. Encode only the image
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

            # 2. Compute similarity against the pre-computed text features
            probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]

        # 3. Combine the results with your stored class IDs and names
        #    and convert the tensor of probabilities to a list of floats
        results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
        
        # 4. Create a sorted list of dictionaries for a clean JSON response
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True
        )
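

# ---------------------------------------------------------------------------
# Local smoke test: an illustrative sketch, not part of the endpoint contract.
# It assumes an `items.json` next to this file; `example.jpg` is a placeholder
# name for whatever test image you have on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("example.jpg", "rb") as fh:  # placeholder test image
        img_b64 = base64.b64encode(fh.read()).decode("utf-8")
    # Same payload shape the endpoint receives: {"inputs": {"image": "<base64>"}}
    predictions = handler({"inputs": {"image": img_b64}})
    print(json.dumps(predictions[:5], indent=2))  # top-5 classes by score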



# """
# MobileCLIP‑B Zero‑Shot Image Classifier  (Hugging Face Inference Endpoint)
# ===========================================================================

# * One container instance is created per replica; the `EndpointHandler`
#   object below is instantiated exactly **once** at start‑up.

# * At request time (`__call__`) we receive a base‑64‑encoded image, run a
#   **single forward pass**, and return class probabilities.

# Design choices
# --------------

# 1. **Model & transform come from OpenCLIP**  
#    This guarantees we apply **identical preprocessing** to what the model
#    was trained with (224 × 224 crop + mean/std normalisation).

# 2. **Re‑parameterisation for inference**  
#    MobileCLIP uses MobileOne blocks that have extra convolution branches
#    for training; `reparameterize_model` fuses them so inference is fast
#    and deterministic.

# 3. **Text embeddings are cached**  
#    The class “prompts” (e.g. `"a photo of a cat"`) are encoded **once at
#    start‑up**.  Each request therefore encodes *only* the image and
#    performs a single matrix multiplication.

# 4. **Mixed precision on GPU**  
#    If the container has CUDA, we cast the model **and** inputs to
#    `float16`.  That halves memory and roughly doubles throughput on most
#    modern GPUs.  On CPU we stay in `float32` for numerical stability.
# """

# import contextlib, io, base64, json
# from pathlib import Path
# from typing import Any, Dict, List

# import torch
# from PIL import Image
# import open_clip

# from reparam import reparameterize_model   # local copy (~60 LoC) of Apple’s helper


# class EndpointHandler:
#     """
#     Hugging Face entry‑point.  The toolkit will instantiate this class
#     once and call it for every HTTP request.

#     Parameters
#     ----------
#     path : str, optional
#         Root directory of the repository.  HF mounts the code under
#         `/repository`; we use this path to locate `items.json`.
#     """

#     # ------------------------------------------------------------------ #
#     #                 INITIALISATION  (runs **once**)                     #
#     # ------------------------------------------------------------------ #
#     def __init__(self, path: str = "") -> None:
#         self.device = "cuda" if torch.cuda.is_available() else "cpu"

#         # 1️⃣  Load MobileCLIP‑B weights & transforms -------------------
#         #    `pretrained="datacompdr"` makes OpenCLIP download the
#         #    official checkpoint from the Hub (cached in the image layer).
#         model, _, self.preprocess = open_clip.create_model_and_transforms(
#             "MobileCLIP-B", pretrained="datacompdr"
#         )
#         model.eval()                       # disable dropout / BN updates
#         model = reparameterize_model(model)  # fuse MobileOne branches
#         model.to(self.device)
#         if self.device == "cuda":
#             model = model.to(torch.float16)  # FP16 for throughput
#         self.model = model                  # hold a reference

#         # 2️⃣  Build the tokenizer once --------------------------------
#         tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

#         # 3️⃣  Load class metadata -------------------------------------
#         #     Expect JSON file: [{"id": 3, "name": "cat", "prompt": "cat"}, …]
#         items_path = Path(path) / "items.json"
#         with items_path.open("r", encoding="utf-8") as f:
#             class_defs: List[Dict[str, Any]] = json.load(f)

#         #     Extract the bits we need later
#         prompts                 = [item["prompt"] for item in class_defs]
#         self.class_ids:   List[int]   = [item["id"]   for item in class_defs]
#         self.class_names: List[str]   = [item["name"] for item in class_defs]

#         # 4️⃣  Encode all prompts once ---------------------------------
#         with torch.no_grad():
#             text_tokens  = tokenizer(prompts).to(self.device)
#             text_feats   = self.model.encode_text(text_tokens)
#             text_feats   = text_feats / text_feats.norm(dim=-1, keepdim=True)
#         self.text_features = text_feats           # [num_classes, 512]

#     # ------------------------------------------------------------------ #
#     #                          INFERENCE CALL                            #
#     # ------------------------------------------------------------------ #
#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
#         """
#         Parameters
#         ----------
#         data : dict
#             Either the raw payload `{"image": "<base64>"}` **or** the
#             Hugging Face convention `{"inputs": {...}}`.

#         Returns
#         -------
#         list of dict
#             Sorted list of `{"id": int, "label": str, "score": float}`.
#             Scores are the softmax probabilities over the *provided*
#             class list (they sum to 1.0).
#         """
#         # 1️⃣  Unpack the request payload ------------------------------
#         payload: Dict[str, Any] = data.get("inputs", data)
#         img_b64: str = payload["image"]

#         # 2️⃣  Decode + preprocess -------------------------------------
#         image      = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
#         img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)  # [1, 3, 224, 224]
#         if self.device == "cuda":
#             img_tensor = img_tensor.to(torch.float16)

#         # 3️⃣  Forward pass (image only) -------------------------------
#         with torch.no_grad():                    # no autograd graph
#             img_feat = self.model.encode_image(img_tensor)            # [1, 512]
#             img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) # L2‑normalise

#             # cosine similarity → logits → softmax probabilities
#             probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]  # [num_classes]

#         # 4️⃣  Assemble JSON‑serialisable response ---------------------
#         results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
#         return sorted(
#             [{"id": cid, "label": name, "score": float(p)} for cid, name, p in results],
#             key=lambda x: x["score"],
#             reverse=True,
#         )
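

# ---------------------------------------------------------------------------
# Example client request (sketch).  The endpoint URL and token below are
# placeholders for your own Inference Endpoint values; the payload shape is
# the same {"inputs": {"image": "<base64>"}} that `__call__` unpacks above.
# ---------------------------------------------------------------------------
# import base64, requests
#
# with open("example.jpg", "rb") as fh:                       # placeholder image
#     img_b64 = base64.b64encode(fh.read()).decode("utf-8")
#
# response = requests.post(
#     "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
#     headers={
#         "Authorization": "Bearer <HF_TOKEN>",               # placeholder token
#         "Content-Type": "application/json",
#     },
#     json={"inputs": {"image": img_b64}},
# )
# print(response.json()[:5])                                  # top-5 predictions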