chris-propeller committed on
Commit a36d7fa · 1 Parent(s): 334daaa
Files changed (3)
  1. app-bak.py +342 -0
  2. app.py +51 -295
  3. test_minimal.py +15 -0
app-bak.py ADDED
@@ -0,0 +1,342 @@
+ import spaces
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import base64
+ import io
+ from typing import Dict, Any
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ @spaces.GPU
+ def sam3_inference(image, text_prompt, confidence_threshold=0.5):
+     """
+     Standalone GPU function with model initialization for Spaces Stateless GPU
+     All CUDA operations and related imports must happen inside this decorated function
+     """
+     try:
+         # Import torch and transformers inside GPU function to avoid main process CUDA init
+         import torch
+         from transformers import Sam3Model, Sam3Processor
+
+         # Initialize model and processor inside GPU function (required for Stateless GPU)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model = Sam3Model.from_pretrained(
+             "facebook/sam3",
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+         ).to(device)
+         processor = Sam3Processor.from_pretrained("facebook/sam3")
+         print(f"Model loaded on device: {device}")
+
+         # Handle base64 input (for API)
+         if isinstance(image, str):
+             if image.startswith('data:image'):
+                 image = image.split(',')[1]
+             image_bytes = base64.b64decode(image)
+             image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+         # Process with SAM3
+         inputs = processor(
+             images=image,
+             text=text_prompt.strip(),
+             return_tensors="pt"
+         ).to(device)
+
+         # Convert dtype to match model
+         for key in inputs:
+             if inputs[key].dtype == torch.float32:
+                 inputs[key] = inputs[key].to(model.dtype)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Use proper SAM3 post-processing
+         results = processor.post_process_instance_segmentation(
+             outputs,
+             threshold=confidence_threshold,
+             mask_threshold=0.5,
+             target_sizes=inputs.get("original_sizes").tolist()
+         )[0]
+
+         return results
+
+     except Exception as e:
+         raise Exception(f"SAM3 inference error: {str(e)}")
+
+ class SAM3Handler:
+     """SAM3 handler for both UI and API access"""
+
+     def __init__(self):
+         print("SAM3 handler initialized (models will be loaded lazily)")
+
+     def predict(self, image, text_prompt, confidence_threshold=0.5):
+         """
+         Main prediction function for both UI and API
+
+         Args:
+             image: PIL Image or base64 string
+             text_prompt: String describing what to segment
+             confidence_threshold: Minimum confidence for masks
+
+         Returns:
+             Dict with masks, scores, and metadata
+         """
+         try:
+             # Call the standalone GPU function
+             results = sam3_inference(image, text_prompt, confidence_threshold)
+
+             # Prepare response
+             response = {
+                 "masks": [],
+                 "scores": [],
+                 "prompt_type": "text",
+                 "prompt_value": text_prompt,
+                 "num_masks": len(results["masks"])
+             }
+
+             # Process each mask
+             for i in range(len(results["masks"])):
+                 mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
+                 score = results["scores"][i].item()
+
+                 if score >= confidence_threshold:
+                     # Convert mask to base64 for API response
+                     mask_image = Image.fromarray(mask_np, mode='L')
+                     buffer = io.BytesIO()
+                     mask_image.save(buffer, format='PNG')
+                     mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+                     response["masks"].append(mask_b64)
+                     response["scores"].append(score)
+
+             return response
+
+         except Exception as e:
+             return {"error": str(e)}
+
+ # Initialize the handler
+ handler = SAM3Handler()
+
+ def gradio_interface(image, text_prompt, confidence_threshold):
+     """Gradio interface wrapper"""
+     result = handler.predict(image, text_prompt, confidence_threshold)
+
+     if "error" in result:
+         return f"Error: {result['error']}", None
+
+     # For UI, show the first mask as an example
+     if result["masks"]:
+         first_mask_b64 = result["masks"][0]
+         first_score = result["scores"][0]
+
+         # Decode first mask for display
+         mask_bytes = base64.b64decode(first_mask_b64)
+         mask_image = Image.open(io.BytesIO(mask_bytes))
+
+         info = f"Found {result['num_masks']} masks. First mask score: {first_score:.3f}"
+         return info, mask_image
+     else:
+         return "No masks found above confidence threshold", None
+
+ def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     API function matching SAM2 inference endpoint format
+
+     Expected input format (matching SAM2 + SAM3 extensions):
+     {
+         "inputs": {
+             "image": "base64_encoded_image_string",
+
+             # SAM3 NEW: Text-based prompts
+             "text_prompts": ["person", "car"],  # List of text descriptions
+
+             # SAM2 compatible: Point-based prompts
+             "points": [[[x1, y1]], [[x2, y2]]],  # Points for each object
+             "labels": [[1], [1]],  # Labels for each point (1=foreground, 0=background)
+
+             # SAM2 compatible: Bounding box prompts
+             "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]],  # Bounding boxes
+
+             "multimask_output": false,  # Optional, defaults to False
+             "confidence_threshold": 0.5  # Optional, minimum confidence for returned masks
+         }
+     }
+
+     Returns (matching SAM2 format):
+     {
+         "masks": [base64_encoded_mask_1, base64_encoded_mask_2, ...],
+         "scores": [score1, score2, ...],
+         "num_objects": int,
+         "sam_version": "3.0",
+         "success": true
+     }
+     """
+     try:
+         inputs_data = data.get("inputs", {})
+
+         # Extract inputs
+         image_b64 = inputs_data.get("image")
+         text_prompts = inputs_data.get("text_prompts", [])
+         input_points = inputs_data.get("points", [])
+         input_labels = inputs_data.get("labels", [])
+         input_boxes = inputs_data.get("boxes", [])
+         multimask_output = inputs_data.get("multimask_output", False)
+         confidence_threshold = inputs_data.get("confidence_threshold", 0.5)
+
+         # Validate inputs
+         if not image_b64:
+             return {"error": "No image provided", "success": False}
+
+         has_text = bool(text_prompts)
+         has_points = bool(input_points and input_labels)
+         has_boxes = bool(input_boxes)
+
+         if not (has_text or has_points or has_boxes):
+             return {"error": "Must provide at least one prompt type: text_prompts, points+labels, or boxes", "success": False}
+
+         if has_points and len(input_points) != len(input_labels):
+             return {"error": "Number of points and labels must match", "success": False}
+
+         # Decode image
+         if image_b64.startswith('data:image'):
+             image_b64 = image_b64.split(',')[1]
+         image_bytes = base64.b64decode(image_b64)
+         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+         all_masks = []
+         all_scores = []
+
+         # Process text prompts (SAM3 feature)
+         if has_text:
+             for text_prompt in text_prompts:
+                 result = handler.predict(image, text_prompt, confidence_threshold)
+                 if "error" not in result:
+                     all_masks.extend(result["masks"])
+                     all_scores.extend(result["scores"])
+
+         # Process visual prompts (SAM2 compatibility) - basic implementation
+         # Note: this is a simplified version. Full SAM2 compatibility would require
+         # implementing the visual prompt processing in the handler
+         if has_boxes or has_points:
+             # For now, fall back to a generic prompt if no text provided
+             if not has_text:
+                 result = handler.predict(image, "object", confidence_threshold)
+                 if "error" not in result and result["masks"]:
+                     # Take only the number of masks requested
+                     num_requested = len(input_boxes) if has_boxes else len(input_points)
+                     all_masks.extend(result["masks"][:num_requested])
+                     all_scores.extend(result["scores"][:num_requested])
+
+         # Build SAM2-compatible response
+         return {
+             "masks": all_masks,
+             "scores": all_scores,
+             "num_objects": len(all_masks),
+             "sam_version": "3.0",
+             "success": True
+         }
+
+     except Exception as e:
+         return {"error": str(e), "success": False, "sam_version": "3.0"}
+
+ # Create Gradio interface
+ with gr.Blocks(title="SAM3 Inference API") as demo:
+     gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
+     gr.HTML("<p>This Space provides both a UI and API for SAM3 inference. Use the interface below or call the API programmatically.</p>")
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Input Image")
+             text_input = gr.Textbox(label="Text Prompt", placeholder="Enter what you want to segment (e.g., 'cat', 'person', 'car')")
+             confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
+             predict_btn = gr.Button("Segment", variant="primary")
+
+         with gr.Column():
+             info_output = gr.Textbox(label="Results Info")
+             mask_output = gr.Image(label="Sample Mask")
+
+     # API endpoint - this creates /api/predict/
+     predict_btn.click(
+         gradio_interface,
+         inputs=[image_input, text_input, confidence_slider],
+         outputs=[info_output, mask_output],
+         api_name="predict"  # This creates the API endpoint
+     )
+
+     # SAM2-compatible API endpoint - this creates /api/sam2_compatible/
+     gr.Interface(
+         fn=api_predict,
+         inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
+         outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
+         title="SAM2/SAM3 Compatible API",
+         description="API endpoint that matches SAM2 inference endpoint format with SAM3 extensions",
+         api_name="sam2_compatible"
+     )
+
+     # Add API documentation
+     gr.HTML("""
+     <h2>API Usage</h2>
+
+     <h3>1. Simple Text API (Gradio format)</h3>
+     <pre>
+ import requests
+ import base64
+
+ # Encode your image to base64
+ with open("image.jpg", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode()
+
+ # Make API request
+ response = requests.post(
+     "https://your-username-sam3-api.hf.space/api/predict",
+     json={
+         "data": [image_b64, "kitten", 0.5]
+     }
+ )
+
+ result = response.json()
+     </pre>
+
+     <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3>
+     <pre>
+ import requests
+ import base64
+
+ # Encode your image to base64
+ with open("image.jpg", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode()
+
+ # SAM3 Text Prompts (NEW)
+ response = requests.post(
+     "https://your-username-sam3-api.hf.space/api/sam2_compatible",
+     json={
+         "data": [{
+             "inputs": {
+                 "image": image_b64,
+                 "text_prompts": ["kitten", "toy"],
+                 "confidence_threshold": 0.5
+             }
+         }]
+     }
+ )
+
+ # SAM2 Compatible (Points/Boxes)
+ response = requests.post(
+     "https://your-username-sam3-api.hf.space/api/sam2_compatible",
+     json={
+         "data": [{
+             "inputs": {
+                 "image": image_b64,
+                 "boxes": [[100, 100, 200, 200]],
+                 "confidence_threshold": 0.5
+             }
+         }]
+     }
+ )
+
+ result = response.json()
+     </pre>
+     """)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
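
Note that api_predict above returns each mask as a base64-encoded single-channel PNG (pixel values 0 or 255), so a client has to decode the strings back into arrays before using them. A minimal client-side sketch of that step, assuming result is a successful response dict in the documented format (the helper name decode_masks is illustrative, not part of this Space):

import base64
import io

import numpy as np
from PIL import Image

def decode_masks(result):
    """Decode the base64 PNG masks from an api_predict response into boolean arrays."""
    masks = []
    for mask_b64 in result["masks"]:
        mask_bytes = base64.b64decode(mask_b64)
        mask_image = Image.open(io.BytesIO(mask_bytes))  # mode 'L', values 0 or 255
        masks.append(np.array(mask_image) > 127)  # one boolean mask per detected object
    return masks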
app.py CHANGED
@@ -1,41 +1,37 @@
  import spaces
  import gradio as gr
- import numpy as np
- from PIL import Image
- import base64
- import io
- from typing import Dict, Any
- import warnings
- warnings.filterwarnings("ignore")

  @spaces.GPU
- def sam3_inference(image, text_prompt, confidence_threshold=0.5):
+ def sam3_predict(image, text_prompt, confidence_threshold=0.5):
      """
-     Standalone GPU function with model initialization for Spaces Stateless GPU
-     All CUDA operations and related imports must happen inside this decorated function
+     SAM3 prediction function for Stateless GPU environment
+     All imports and CUDA operations happen here
      """
+     # Import everything inside the GPU function
+     import torch
+     import numpy as np
+     from PIL import Image
+     import base64
+     import io
+     from transformers import Sam3Model, Sam3Processor
+
      try:
-         # Import torch and transformers inside GPU function to avoid main process CUDA init
-         import torch
-         from transformers import Sam3Model, Sam3Processor
-
-         # Initialize model and processor inside GPU function (required for Stateless GPU)
+         # Handle base64 input if needed
+         if isinstance(image, str):
+             if image.startswith('data:image'):
+                 image = image.split(',')[1]
+             image_bytes = base64.b64decode(image)
+             image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+         # Initialize model and processor
          device = "cuda" if torch.cuda.is_available() else "cpu"
          model = Sam3Model.from_pretrained(
              "facebook/sam3",
              torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
          ).to(device)
          processor = Sam3Processor.from_pretrained("facebook/sam3")
-         print(f"Model loaded on device: {device}")

-         # Handle base64 input (for API)
-         if isinstance(image, str):
-             if image.startswith('data:image'):
-                 image = image.split(',')[1]
-             image_bytes = base64.b64decode(image)
-             image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-
-         # Process with SAM3
+         # Process input
          inputs = processor(
              images=image,
              text=text_prompt.strip(),
@@ -47,10 +43,11 @@ def sam3_inference(image, text_prompt, confidence_threshold=0.5):
              if inputs[key].dtype == torch.float32:
                  inputs[key] = inputs[key].to(model.dtype)

+         # Run inference
          with torch.no_grad():
              outputs = model(**inputs)

-         # Use proper SAM3 post-processing
+         # Post-process
          results = processor.post_process_instance_segmentation(
              outputs,
              threshold=confidence_threshold,
@@ -58,285 +55,44 @@
              target_sizes=inputs.get("original_sizes").tolist()
          )[0]

-         return results
-
-     except Exception as e:
-         raise Exception(f"SAM3 inference error: {str(e)}")
-
  [173 deleted lines omitted (old lines 66-238): the SAM3Handler class, gradio_interface, and api_predict, all preserved verbatim in app-bak.py above]
+         # Return results for UI
+         if len(results["masks"]) > 0:
+             # Convert first mask for display
+             mask_np = results["masks"][0].cpu().numpy().astype(np.uint8) * 255
+             score = results["scores"][0].item()
+             mask_image = Image.fromarray(mask_np, mode='L')
+
+             return f"Found {len(results['masks'])} masks. Best score: {score:.3f}", mask_image
+         else:
+             return "No masks found above confidence threshold", None

      except Exception as e:
-         return {"error": str(e), "success": False, "sam_version": "3.0"}
-
  [98 deleted lines omitted (old lines 242-339): the gr.Blocks UI with both API endpoints and the HTML API documentation, all preserved verbatim in app-bak.py above]
+         return f"Error: {str(e)}", None
+
+ # Simple gradio interface - no class, no global state
+ def create_interface():
+     with gr.Blocks(title="SAM3 Inference") as demo:
+         gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="pil", label="Input Image")
+                 text_input = gr.Textbox(label="Text Prompt", placeholder="Enter what to segment (e.g., 'cat', 'person', 'car')")
+                 confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
+                 predict_btn = gr.Button("Segment", variant="primary")
+
+             with gr.Column():
+                 info_output = gr.Textbox(label="Results Info")
+                 mask_output = gr.Image(label="Sample Mask")
+
+         predict_btn.click(
+             sam3_predict,
+             inputs=[image_input, text_input, confidence_slider],
+             outputs=[info_output, mask_output]
+         )
+
+     return demo

  if __name__ == "__main__":
+     demo = create_interface()
      demo.launch(server_name="0.0.0.0", server_port=7860)
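
After this change the Space exposes only the UI click handler, and it is registered without an api_name, so programmatic access goes through the auto-generated endpoint rather than /api/predict. A minimal client sketch using gradio_client, addressing the handler by fn_index; the Space id is hypothetical and handle_file assumes gradio_client 1.x:

from gradio_client import Client, handle_file

client = Client("your-username/sam3-api")  # hypothetical Space id
info, mask_path = client.predict(
    handle_file("image.jpg"),  # input image
    "kitten",                  # text prompt
    0.5,                       # confidence threshold
    fn_index=0,                # the click handler has no api_name, so address it by index
)
print(info)  # e.g. "Found N masks. Best score: ..."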
test_minimal.py ADDED
@@ -0,0 +1,15 @@
+ import spaces
+
+ @spaces.GPU
+ def test_gpu():
+     import torch
+     from transformers import Sam3Model, Sam3Processor
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Test GPU function works! Device: {device}")
+     return f"GPU test successful on {device}"
+
+ if __name__ == "__main__":
+     print("Starting minimal test...")
+     result = test_gpu()
+     print(result)
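
For reference, running this script on a GPU worker should print the three lines below, reconstructed from the print statements above (assuming torch.cuda.is_available() returns True inside the @spaces.GPU function; on CPU the device string would be "cpu" instead):

Starting minimal test...
Test GPU function works! Device: cuda
GPU test successful on cuda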