chris-propeller commited on
Commit
050a111
·
1 Parent(s): a36d7fa

add back in api

Browse files
Files changed (1) hide show
  1. app.py +284 -7
app.py CHANGED
@@ -2,10 +2,10 @@ import spaces
2
  import gradio as gr
3
 
4
  @spaces.GPU
5
- def sam3_predict(image, text_prompt, confidence_threshold=0.5):
6
  """
7
- SAM3 prediction function for Stateless GPU environment
8
- All imports and CUDA operations happen here
9
  """
10
  # Import everything inside the GPU function
11
  import torch
@@ -55,6 +55,20 @@ def sam3_predict(image, text_prompt, confidence_threshold=0.5):
55
  target_sizes=inputs.get("original_sizes").tolist()
56
  )[0]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Return results for UI
59
  if len(results["masks"]) > 0:
60
  # Convert first mask for display
@@ -69,10 +83,148 @@ def sam3_predict(image, text_prompt, confidence_threshold=0.5):
69
  except Exception as e:
70
  return f"Error: {str(e)}", None
71
 
72
- # Simple gradio interface - no class, no global state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def create_interface():
74
- with gr.Blocks(title="SAM3 Inference") as demo:
75
  gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
 
76
 
77
  with gr.Row():
78
  with gr.Column():
@@ -85,12 +237,137 @@ def create_interface():
85
  info_output = gr.Textbox(label="Results Info")
86
  mask_output = gr.Image(label="Sample Mask")
87
 
 
88
  predict_btn.click(
89
- sam3_predict,
90
  inputs=[image_input, text_input, confidence_slider],
91
- outputs=[info_output, mask_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  return demo
95
 
96
  if __name__ == "__main__":
 
2
  import gradio as gr
3
 
4
  @spaces.GPU
5
+ def sam3_inference(image, text_prompt, confidence_threshold=0.5):
6
  """
7
+ Core SAM3 inference function for Stateless GPU environment
8
+ Returns raw results for both UI and API use
9
  """
10
  # Import everything inside the GPU function
11
  import torch
 
55
  target_sizes=inputs.get("original_sizes").tolist()
56
  )[0]
57
 
58
+ return results
59
+
60
+ except Exception as e:
61
+ raise Exception(f"SAM3 inference error: {str(e)}")
62
+
63
+ def gradio_interface(image, text_prompt, confidence_threshold):
64
+ """Gradio interface wrapper for UI"""
65
+ import numpy as np
66
+ from PIL import Image
67
+ import io
68
+
69
+ try:
70
+ results = sam3_inference(image, text_prompt, confidence_threshold)
71
+
72
  # Return results for UI
73
  if len(results["masks"]) > 0:
74
  # Convert first mask for display
 
83
  except Exception as e:
84
  return f"Error: {str(e)}", None
85
 
86
+ def api_predict(image, text_prompt, confidence_threshold):
87
+ """API prediction function for simple Gradio API"""
88
+ import numpy as np
89
+ from PIL import Image
90
+ import base64
91
+ import io
92
+
93
+ try:
94
+ results = sam3_inference(image, text_prompt, confidence_threshold)
95
+
96
+ # Prepare API response
97
+ response = {
98
+ "masks": [],
99
+ "scores": [],
100
+ "prompt_type": "text",
101
+ "prompt_value": text_prompt,
102
+ "num_masks": len(results["masks"])
103
+ }
104
+
105
+ # Process each mask
106
+ for i in range(len(results["masks"])):
107
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
108
+ score = results["scores"][i].item()
109
+
110
+ if score >= confidence_threshold:
111
+ # Convert mask to base64 for API response
112
+ mask_image = Image.fromarray(mask_np, mode='L')
113
+ buffer = io.BytesIO()
114
+ mask_image.save(buffer, format='PNG')
115
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
116
+
117
+ response["masks"].append(mask_b64)
118
+ response["scores"].append(score)
119
+
120
+ return response
121
+
122
+ except Exception as e:
123
+ return {"error": str(e)}
124
+
125
+ def sam2_compatible_api(data):
126
+ """
127
+ SAM2-compatible API endpoint with SAM3 extensions
128
+ Supports text prompts (SAM3), points, and boxes (SAM2 compatible)
129
+ """
130
+ import numpy as np
131
+ from PIL import Image
132
+ import base64
133
+ import io
134
+
135
+ try:
136
+ inputs_data = data.get("inputs", {})
137
+
138
+ # Extract inputs
139
+ image_b64 = inputs_data.get("image")
140
+ text_prompts = inputs_data.get("text_prompts", [])
141
+ input_points = inputs_data.get("points", [])
142
+ input_labels = inputs_data.get("labels", [])
143
+ input_boxes = inputs_data.get("boxes", [])
144
+ confidence_threshold = inputs_data.get("confidence_threshold", 0.5)
145
+
146
+ # Validate inputs
147
+ if not image_b64:
148
+ return {"error": "No image provided", "success": False}
149
+
150
+ has_text = bool(text_prompts)
151
+ has_points = bool(input_points and input_labels)
152
+ has_boxes = bool(input_boxes)
153
+
154
+ if not (has_text or has_points or has_boxes):
155
+ return {"error": "Must provide at least one prompt type: text_prompts, points+labels, or boxes", "success": False}
156
+
157
+ if has_points and len(input_points) != len(input_labels):
158
+ return {"error": "Number of points and labels must match", "success": False}
159
+
160
+ # Decode image
161
+ if image_b64.startswith('data:image'):
162
+ image_b64 = image_b64.split(',')[1]
163
+ image_bytes = base64.b64decode(image_b64)
164
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
165
+
166
+ all_masks = []
167
+ all_scores = []
168
+
169
+ # Process text prompts (SAM3 feature)
170
+ if has_text:
171
+ for text_prompt in text_prompts:
172
+ results = sam3_inference(image, text_prompt, confidence_threshold)
173
+ if results and len(results["masks"]) > 0:
174
+ for i in range(len(results["masks"])):
175
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
176
+ score = results["scores"][i].item()
177
+
178
+ if score >= confidence_threshold:
179
+ # Convert mask to base64
180
+ mask_image = Image.fromarray(mask_np, mode='L')
181
+ buffer = io.BytesIO()
182
+ mask_image.save(buffer, format='PNG')
183
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
184
+
185
+ all_masks.append(mask_b64)
186
+ all_scores.append(score)
187
+
188
+ # Process visual prompts (SAM2 compatibility) - Basic implementation
189
+ if has_boxes or has_points:
190
+ # For visual prompts, use a generic prompt to get masks
191
+ # This is a simplified implementation - full SAM2 compatibility would require
192
+ # implementing visual prompt processing in the core function
193
+ if not has_text:
194
+ results = sam3_inference(image, "object", confidence_threshold)
195
+ if results and len(results["masks"]) > 0:
196
+ # Take only the number of masks requested
197
+ num_requested = len(input_boxes) if has_boxes else len(input_points)
198
+ for i in range(min(num_requested, len(results["masks"]))):
199
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
200
+ score = results["scores"][i].item()
201
+
202
+ # Convert mask to base64
203
+ mask_image = Image.fromarray(mask_np, mode='L')
204
+ buffer = io.BytesIO()
205
+ mask_image.save(buffer, format='PNG')
206
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
207
+
208
+ all_masks.append(mask_b64)
209
+ all_scores.append(score)
210
+
211
+ # Build SAM2-compatible response
212
+ return {
213
+ "masks": all_masks,
214
+ "scores": all_scores,
215
+ "num_objects": len(all_masks),
216
+ "sam_version": "3.0",
217
+ "success": True
218
+ }
219
+
220
+ except Exception as e:
221
+ return {"error": str(e), "success": False, "sam_version": "3.0"}
222
+
223
+ # Create comprehensive Gradio interface with API endpoints
224
  def create_interface():
225
+ with gr.Blocks(title="SAM3 Inference API") as demo:
226
  gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
227
+ gr.HTML("<p>This Space provides both a UI and API for SAM3 inference with SAM2 compatibility. Use the interface below or call the API programmatically.</p>")
228
 
229
  with gr.Row():
230
  with gr.Column():
 
237
  info_output = gr.Textbox(label="Results Info")
238
  mask_output = gr.Image(label="Sample Mask")
239
 
240
+ # Main UI prediction with API endpoint
241
  predict_btn.click(
242
+ gradio_interface,
243
  inputs=[image_input, text_input, confidence_slider],
244
+ outputs=[info_output, mask_output],
245
+ api_name="predict" # Creates /api/predict endpoint
246
+ )
247
+
248
+ # Simple API endpoint for Gradio format
249
+ gr.Interface(
250
+ fn=api_predict,
251
+ inputs=[
252
+ gr.Image(type="pil", label="Image"),
253
+ gr.Textbox(label="Text Prompt"),
254
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Confidence Threshold")
255
+ ],
256
+ outputs=gr.JSON(label="API Response"),
257
+ title="Simple API",
258
+ description="Returns structured JSON response with base64 encoded masks",
259
+ api_name="simple_api"
260
+ )
261
+
262
+ # SAM2-compatible API endpoint
263
+ gr.Interface(
264
+ fn=sam2_compatible_api,
265
+ inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
266
+ outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
267
+ title="SAM2/SAM3 Compatible API",
268
+ description="API endpoint that matches SAM2 inference endpoint format with SAM3 extensions",
269
+ api_name="sam2_compatible"
270
  )
271
 
272
+ # Add comprehensive API documentation
273
+ gr.HTML("""
274
+ <h2>API Usage</h2>
275
+
276
+ <h3>1. Simple Text API (Gradio format)</h3>
277
+ <pre>
278
+ import requests
279
+ import base64
280
+
281
+ # Encode your image to base64
282
+ with open("image.jpg", "rb") as f:
283
+ image_b64 = base64.b64encode(f.read()).decode()
284
+
285
+ # Make API request
286
+ response = requests.post(
287
+ "https://your-username-sam3-api.hf.space/api/predict",
288
+ json={
289
+ "data": [image_b64, "kitten", 0.5]
290
+ }
291
+ )
292
+
293
+ result = response.json()
294
+ </pre>
295
+
296
+ <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3>
297
+ <pre>
298
+ import requests
299
+ import base64
300
+
301
+ # Encode your image to base64
302
+ with open("image.jpg", "rb") as f:
303
+ image_b64 = base64.b64encode(f.read()).decode()
304
+
305
+ # SAM3 Text Prompts (NEW)
306
+ response = requests.post(
307
+ "https://your-username-sam3-api.hf.space/api/sam2_compatible",
308
+ json={
309
+ "data": [{
310
+ "inputs": {
311
+ "image": image_b64,
312
+ "text_prompts": ["kitten", "toy"],
313
+ "confidence_threshold": 0.5
314
+ }
315
+ }]
316
+ }
317
+ )
318
+
319
+ # SAM2 Compatible (Points/Boxes)
320
+ response = requests.post(
321
+ "https://your-username-sam3-api.hf.space/api/sam2_compatible",
322
+ json={
323
+ "data": [{
324
+ "inputs": {
325
+ "image": image_b64,
326
+ "boxes": [[100, 100, 200, 200]],
327
+ "confidence_threshold": 0.5
328
+ }
329
+ }]
330
+ }
331
+ )
332
+
333
+ result = response.json()
334
+ </pre>
335
+
336
+ <h3>3. API Parameters</h3>
337
+ <h4>SAM2-Compatible API Input</h4>
338
+ <pre>
339
+ {
340
+ "inputs": {
341
+ "image": "base64_encoded_image_string",
342
+
343
+ // SAM3 NEW: Text-based prompts
344
+ "text_prompts": ["person", "car"], // List of text descriptions
345
+
346
+ // SAM2 COMPATIBLE: Point-based prompts
347
+ "points": [[[x1, y1]], [[x2, y2]]], // Points for each object
348
+ "labels": [[1], [1]], // Labels for each point (1=foreground, 0=background)
349
+
350
+ // SAM2 COMPATIBLE: Bounding box prompts
351
+ "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]], // Bounding boxes
352
+
353
+ "multimask_output": false, // Optional, defaults to False
354
+ "confidence_threshold": 0.5 // Optional, minimum confidence for returned masks
355
+ }
356
+ }
357
+ </pre>
358
+
359
+ <h4>API Response</h4>
360
+ <pre>
361
+ {
362
+ "masks": ["base64_encoded_mask_1", "base64_encoded_mask_2"],
363
+ "scores": [0.95, 0.87],
364
+ "num_objects": 2,
365
+ "sam_version": "3.0",
366
+ "success": true
367
+ }
368
+ </pre>
369
+ """)
370
+
371
  return demo
372
 
373
  if __name__ == "__main__":