import spaces
import gradio as gr
# @spaces.GPU requests a GPU slot per call on ZeroGPU hardware and is a no-op
# on dedicated GPU Spaces; without it, `import spaces` above is unused.
@spaces.GPU
def sam3_inference(image, text_prompt=None, boxes=None, box_labels=None, points=None, point_labels=None, confidence_threshold=0.5):
    """
    Core SAM3 inference function for the Stateless GPU environment.
    Supports text prompts, box prompts, and point prompts, individually or combined.
    Returns raw post-processed results for both UI and API use.
    """
    # Import everything inside the GPU function (required for the Stateless GPU environment)
    import torch
    import numpy as np
    from PIL import Image
    import base64
    import io
    from transformers import Sam3Model, Sam3Processor
    try:
        # Validate inputs
        if not text_prompt and not boxes and not points:
            raise ValueError("At least one of text_prompt, boxes, or points must be provided")
        if boxes and not box_labels:
            raise ValueError("box_labels must be provided when boxes are specified")
        if points and not point_labels:
            raise ValueError("point_labels must be provided when points are specified")

        # Handle base64 input if needed
        if isinstance(image, str):
            if image.startswith('data:image'):
                image = image.split(',')[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Initialize the model and processor
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")

        # Prepare processor inputs based on the prompt type
        processor_kwargs = {
            "images": image,
            "return_tensors": "pt"
        }

        # Add the text prompt if provided
        if text_prompt:
            processor_kwargs["text"] = text_prompt.strip()

        # Add box prompts if provided
        if boxes and box_labels:
            # Boxes are expected as [x1, y1, x2, y2] in pixel coordinates
            formatted_boxes = []
            formatted_labels = []
            for i, box in enumerate(boxes):
                if len(box) == 4:
                    formatted_boxes.append(box)
                    # Use the provided label (1 = positive, 0 = negative)
                    if i < len(box_labels):
                        formatted_labels.append(box_labels[i])
                    else:
                        raise ValueError(f"Missing label for box {i}")
            if formatted_boxes:
                # Wrap in an outer list to indicate a batch size of 1
                processor_kwargs["input_boxes"] = [formatted_boxes]
                processor_kwargs["input_boxes_labels"] = [formatted_labels]

        # Add point prompts if provided
        if points and point_labels:
            # Points are expected as [x, y]; the processor takes nested lists for batching
            formatted_points = []
            formatted_point_labels = []
            for i, point in enumerate(points):
                if len(point) == 2:
                    formatted_points.append(point)
                    # Use the provided label (1 = positive, 0 = negative)
                    if i < len(point_labels):
                        formatted_point_labels.append(point_labels[i])
                    else:
                        raise ValueError(f"Missing label for point {i}")
            if formatted_points:
                processor_kwargs["input_points"] = [formatted_points]
                processor_kwargs["input_points_labels"] = [formatted_point_labels]

        # Process the input
        inputs = processor(**processor_kwargs).to(device)

        # Cast floating-point inputs to match the model dtype (float16 on GPU)
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process into instance masks at the original image size
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
        return results

    except Exception as e:
        raise Exception(f"SAM3 inference error: {str(e)}")
def gradio_interface(image, text_prompt, confidence_threshold):
    """Gradio interface wrapper for the UI"""
    import numpy as np
    from PIL import Image

    try:
        results = sam3_inference(image, text_prompt=text_prompt, confidence_threshold=confidence_threshold)

        # Return results for the UI
        if len(results["masks"]) > 0:
            # Convert the first mask for display
            mask_np = results["masks"][0].cpu().numpy().astype(np.uint8) * 255
            score = results["scores"][0].item()
            mask_image = Image.fromarray(mask_np, mode='L')
            return f"Found {len(results['masks'])} masks. Best score: {score:.3f}", mask_image
        else:
            return "No masks found above confidence threshold", None
    except Exception as e:
        return f"Error: {str(e)}", None
def api_predict(image, text_prompt, confidence_threshold):
    """API prediction function for the simple Gradio API"""
    import numpy as np
    from PIL import Image
    import base64
    import io

    try:
        results = sam3_inference(image, text_prompt=text_prompt, confidence_threshold=confidence_threshold)

        # Prepare the API response
        response = {
            "masks": [],
            "scores": [],
            "prompt_type": "text",
            "prompt_value": text_prompt,
            "num_masks": 0
        }

        # Process each mask
        for i in range(len(results["masks"])):
            mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
            score = results["scores"][i].item()
            if score >= confidence_threshold:
                # Convert the mask to a base64-encoded PNG for the API response
                mask_image = Image.fromarray(mask_np, mode='L')
                buffer = io.BytesIO()
                mask_image.save(buffer, format='PNG')
                mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                response["masks"].append(mask_b64)
                response["scores"].append(score)

        # Count only the masks that passed the confidence filter
        response["num_masks"] = len(response["masks"])
        return response
    except Exception as e:
        return {"error": str(e)}
def _mask_to_polygons_original_size(binary_mask, epsilon=2.0):
    """
    Convert a binary mask to vector polygons (the mask is already at the original image size).

    Args:
        binary_mask: Binary mask array (0 or 1) at the original image size
        epsilon: Polygon simplification epsilon in pixels

    Returns:
        List of polygons, where each polygon is a list of [x, y] points in pixel coordinates
    """
    import cv2

    try:
        # Find external contours using OpenCV
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        polygons = []
        for contour in contours:
            if len(contour) < 3:  # Skip degenerate contours (fewer than 3 points)
                continue
            # Simplify the polygon using the Douglas-Peucker algorithm
            simplified = cv2.approxPolyDP(contour, epsilon, True)
            # Convert to a list of [x, y] points
            polygon_points = [[float(point[0][0]), float(point[0][1])] for point in simplified]
            # Only keep polygons with at least 3 points
            if len(polygon_points) >= 3:
                polygons.append(polygon_points)
        return polygons
    except Exception as e:
        # Return an empty list on error rather than failing the entire request
        print(f"Warning: Polygon extraction failed: {e}")
        return []
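# Usage sketch for the helper above (hypothetical values): a filled axis-aligned
# rectangle in a uint8 mask comes back as a single simplified polygon, e.g.
#   mask = np.zeros((512, 512), dtype=np.uint8); mask[100:200, 100:200] = 1
#   _mask_to_polygons_original_size(mask)  # -> [[[100.0, 100.0], ...]] (roughly the 4 corners)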
def sam2_compatible_api(data):
    """
    SAM2-compatible API endpoint with SAM3 extensions.
    Supports text prompts (SAM3) as well as points and boxes (SAM2 compatible).
    Includes a vectorize option for polygon extraction.
    """
    import numpy as np
    from PIL import Image
    import base64
    import io

    try:
        inputs_data = data.get("inputs", {})

        # Extract inputs
        image_b64 = inputs_data.get("image")
        text_prompts = inputs_data.get("text_prompts", [])
        input_points = inputs_data.get("points", [])
        input_point_labels = inputs_data.get("point_labels", [])
        input_boxes = inputs_data.get("boxes", [])
        input_box_labels = inputs_data.get("box_labels", [])
        confidence_threshold = inputs_data.get("confidence_threshold", 0.5)
        vectorize = inputs_data.get("vectorize", False)
        simplify_epsilon = inputs_data.get("simplify_epsilon", 2.0)

        # Validate inputs
        if not image_b64:
            return {"error": "No image provided", "success": False}
        has_text = bool(text_prompts)
        has_points = bool(input_points and input_point_labels)
        has_boxes = bool(input_boxes)
        if not (has_text or has_points or has_boxes):
            return {"error": "Must provide at least one prompt type: text_prompts, points+point_labels, or boxes", "success": False}
        if has_points and len(input_points) != len(input_point_labels):
            return {"error": "Number of points and point_labels must match", "success": False}
        if has_boxes and input_box_labels and len(input_boxes) != len(input_box_labels):
            return {"error": "Number of boxes and box_labels must match", "success": False}

        # Decode the image
        if image_b64.startswith('data:image'):
            image_b64 = image_b64.split(',')[1]
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        original_image_size = image.size  # (width, height), stored for response metadata

        all_masks = []
        all_scores = []
        all_polygons = []
        prompt_types = []

        # Record which prompt types are being used
        if has_text:
            prompt_types.append("text")
        if has_points or has_boxes:
            prompt_types.append("visual")

        # Process text prompts individually (SAM3 works best with one text prompt per call)
        if has_text:
            for text_prompt in text_prompts:
                if text_prompt.strip():  # Skip empty prompts
                    results = sam3_inference(
                        image=image,
                        text_prompt=text_prompt.strip(),
                        confidence_threshold=confidence_threshold
                    )
                    if results and len(results["masks"]) > 0:
                        for i in range(len(results["masks"])):
                            mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
                            score = results["scores"][i].item()
                            if score >= confidence_threshold:
                                # Convert the mask to a base64-encoded PNG
                                mask_image = Image.fromarray(mask_np, mode='L')
                                buffer = io.BytesIO()
                                mask_image.save(buffer, format='PNG')
                                mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                                all_masks.append(mask_b64)
                                all_scores.append(score)
                                # Extract polygons if vectorize is enabled
                                if vectorize:
                                    binary_mask = (mask_np > 0).astype(np.uint8)
                                    polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
                                    all_polygons.append(polygons)

        # Process visual prompts (boxes and/or points), which can be combined in a single call
        if has_boxes or has_points:
            combined_boxes = input_boxes if has_boxes else None
            combined_box_labels = input_box_labels if (has_boxes and input_box_labels) else ([1] * len(input_boxes) if has_boxes else None)
            combined_points = input_points if has_points else None
            combined_point_labels = input_point_labels if has_points else None
            results = sam3_inference(
                image=image,
                text_prompt=None,
                boxes=combined_boxes,
                box_labels=combined_box_labels,
                points=combined_points,
                point_labels=combined_point_labels,
                confidence_threshold=confidence_threshold
            )
            if results and len(results["masks"]) > 0:
                for i in range(len(results["masks"])):
                    mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
                    score = results["scores"][i].item()
                    if score >= confidence_threshold:
                        # Convert the mask to a base64-encoded PNG
                        mask_image = Image.fromarray(mask_np, mode='L')
                        buffer = io.BytesIO()
                        mask_image.save(buffer, format='PNG')
                        mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                        all_masks.append(mask_b64)
                        all_scores.append(score)
                        # Extract polygons if vectorize is enabled
                        if vectorize:
                            binary_mask = (mask_np > 0).astype(np.uint8)
                            polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
                            all_polygons.append(polygons)

        # Build the SAM2-compatible response
        response = {
            "masks": all_masks,
            "scores": all_scores,
            "num_objects": len(all_masks),
            "sam_version": "3.0",
            "prompt_types": prompt_types,
            "success": True
        }

        # Add polygon data if vectorize is enabled
        if vectorize:
            response.update({
                "polygons": all_polygons,
                "polygon_format": "pixel_coordinates",
                "original_image_size": original_image_size
            })
        return response

    except Exception as e:
        return {"error": str(e), "success": False, "sam_version": "3.0"}
# Create the Gradio interface with API endpoints
def create_interface():
    with gr.Blocks(title="SAM3 Inference API") as demo:
        gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
        gr.HTML("<p>This Space provides both a UI and an API for SAM3 inference with SAM2 compatibility. Use the interface below or call the API programmatically.</p>")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                text_input = gr.Textbox(label="Text Prompt", placeholder="Enter what to segment (e.g., 'cat', 'person', 'car')")
                confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
                predict_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                info_output = gr.Textbox(label="Results Info")
                mask_output = gr.Image(label="Sample Mask")

        # Main UI prediction with an API endpoint
        predict_btn.click(
            gradio_interface,
            inputs=[image_input, text_input, confidence_slider],
            outputs=[info_output, mask_output],
            api_name="predict"  # Creates the /api/predict endpoint
        )

        # Simple API endpoint in Gradio format
        gr.Interface(
            fn=api_predict,
            inputs=[
                gr.Image(type="pil", label="Image"),
                gr.Textbox(label="Text Prompt"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Confidence Threshold")
            ],
            outputs=gr.JSON(label="API Response"),
            title="Simple API",
            description="Returns a structured JSON response with base64-encoded masks",
            api_name="simple_api"
        )

        # SAM2-compatible API endpoint
        with gr.Row():
            gr.HTML("<h3>SAM2/SAM3 Compatible API</h3>")
        with gr.Row():
            api_input = gr.JSON(label="SAM2/SAM3 Compatible Input")
            api_output = gr.JSON(label="SAM2/SAM3 Compatible Output")
        with gr.Row():
            api_btn = gr.Button("Test API", variant="secondary")

        # Create the API endpoint
        api_btn.click(
            fn=sam2_compatible_api,
            inputs=api_input,
            outputs=api_output,
            api_name="sam2_compatible"
        )
        # Add comprehensive API documentation
        gr.HTML("""
        <h2>API Usage</h2>
        <h3>1. Simple Text API (Gradio format)</h3>
        <pre>
import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Make the API request
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/predict",
    json={
        "data": [image_b64, "kitten", 0.5]
    }
)
result = response.json()
        </pre>
| <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3> | |
| <pre> | |
| import requests | |
| import base64 | |
| # Encode your image to base64 | |
| with open("image.jpg", "rb") as f: | |
| image_b64 = base64.b64encode(f.read()).decode() | |
| # SAM3 Text Prompts Only | |
| response = requests.post( | |
| "https://your-username-sam3-api.hf.space/api/sam2_compatible", | |
| json={ | |
| "inputs": { | |
| "image": image_b64, | |
| "text_prompts": ["kitten", "toy"], | |
| "confidence_threshold": 0.5 | |
| } | |
| } | |
| ) | |
| # SAM2 Compatible (Points/Boxes Only) | |
| response = requests.post( | |
| "https://your-username-sam3-api.hf.space/api/sam2_compatible", | |
| json={ | |
| "inputs": { | |
| "image": image_b64, | |
| "boxes": [[100, 100, 200, 200]], | |
| "box_labels": [1], # 1=positive, 0=negative | |
| "confidence_threshold": 0.5 | |
| } | |
| } | |
| ) | |
| # SAM3 with Multiple Text Prompts (processed individually) | |
| response = requests.post( | |
| "https://your-username-sam3-api.hf.space/api/sam2_compatible", | |
| json={ | |
| "inputs": { | |
| "image": image_b64, | |
| "text_prompts": ["cat", "dog"], # Each prompt processed separately | |
| "confidence_threshold": 0.5 | |
| } | |
| } | |
| ) | |
| # SAM3 Combined Visual Prompts (boxes + points in single call) | |
| response = requests.post( | |
| "https://your-username-sam3-api.hf.space/api/sam2_compatible", | |
| json={ | |
| "inputs": { | |
| "image": image_b64, | |
| "boxes": [[50, 50, 150, 150]], # Bounding box | |
| "box_labels": [0], # 0=negative (exclude this area) | |
| "points": [[200, 200]], # Point prompt | |
| "point_labels": [1], # 1=positive point | |
| "confidence_threshold": 0.5 | |
| } | |
| } | |
| ) | |
| # SAM3 with Vectorize (returns both masks and polygons) | |
| response = requests.post( | |
| "https://your-username-sam3-api.hf.space/api/sam2_compatible", | |
| json={ | |
| "inputs": { | |
| "image": image_b64, | |
| "text_prompts": ["cat"], | |
| "confidence_threshold": 0.5, | |
| "vectorize": true, | |
| "simplify_epsilon": 2.0 | |
| } | |
| } | |
| ) | |
| result = response.json() | |
| </pre> | |
| <h3>3. API Parameters</h3> | |
| <h4>SAM2-Compatible API Input</h4> | |
| <pre> | |
| { | |
| "inputs": { | |
| "image": "base64_encoded_image_string", | |
| // SAM3 NEW: Text-based prompts (each processed individually for best results) | |
| "text_prompts": ["person", "car"], // List of text descriptions - each processed separately | |
| // SAM2 COMPATIBLE: Point-based prompts (can be combined with text/boxes) | |
| "points": [[x1, y1], [x2, y2]], // Individual points (not nested arrays) | |
| "point_labels": [1, 0], // Labels for each point (1=positive/foreground, 0=negative/background) | |
| // SAM2 COMPATIBLE: Bounding box prompts (can be combined with text/points) | |
| "boxes": [[x1, y1, x2, y2], [x3, y3, x4, y4]], // Bounding boxes | |
| "box_labels": [1, 0], // Labels for each box (1=positive, 0=negative/exclude) | |
| "multimask_output": false, // Optional, defaults to False | |
| "confidence_threshold": 0.5, // Optional, minimum confidence for returned masks | |
| "vectorize": false, // Optional, return vector polygons instead of/in addition to bitmaps | |
| "simplify_epsilon": 2.0 // Optional, polygon simplification factor | |
| } | |
| } | |
| </pre> | |
| <h4>API Response</h4> | |
| <pre> | |
| { | |
| "masks": ["base64_encoded_mask_1", "base64_encoded_mask_2"], | |
| "scores": [0.95, 0.87], | |
| "num_objects": 2, | |
| "sam_version": "3.0", | |
| "prompt_types": ["text", "visual"], // Types of prompts used in the request | |
| "success": true, | |
| // If vectorize=true, additional fields: | |
| "polygons": [[[x1,y1],[x2,y2],...], [[x1,y1],...]], // Array of polygon arrays for each object | |
| "polygon_format": "pixel_coordinates", | |
| "original_image_size": [width, height] | |
| } | |
| </pre> | |
| """) | |
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)
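# A minimal requirements.txt sketch for this Space, inferred from the imports
# above (the exact cv2 package choice is an assumption; pin versions as needed):
#   spaces
#   gradio
#   torch
#   transformers
#   numpy
#   pillow
#   opencv-python-headless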