import spaces
import gradio as gr
import numpy as np
from PIL import Image
import base64
import io
from typing import Dict, Any
import warnings

warnings.filterwarnings("ignore")


@spaces.GPU
def sam3_inference(image, text_prompt, confidence_threshold=0.5):
    """
    Standalone GPU function with model initialization for Spaces Stateless GPU.
    All CUDA operations and related imports must happen inside this decorated function.
    """
    try:
        # Import torch and transformers inside the GPU function to avoid
        # initializing CUDA in the main process
        import torch
        from transformers import Sam3Model, Sam3Processor

        # Initialize the model and processor inside the GPU function
        # (required for Stateless GPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
        print(f"Model loaded on device: {device}")

        # Handle base64 input (for API calls)
        if isinstance(image, str):
            if image.startswith('data:image'):
                image = image.split(',')[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Process with SAM3
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt"
        ).to(device)

        # Convert floating-point inputs to match the model dtype
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        # Use the processor's SAM3 post-processing
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        return results

    except Exception as e:
        raise Exception(f"SAM3 inference error: {str(e)}")


class SAM3Handler:
    """SAM3 handler for both UI and API access"""

    def __init__(self):
        print("SAM3 handler initialized (models will be loaded lazily)")

    def predict(self, image, text_prompt, confidence_threshold=0.5):
        """
        Main prediction function for both UI and API.

        Args:
            image: PIL Image or base64 string
            text_prompt: String describing what to segment
            confidence_threshold: Minimum confidence for masks

        Returns:
            Dict with masks, scores, and metadata
        """
        try:
            # Call the standalone GPU function
            results = sam3_inference(image, text_prompt, confidence_threshold)

            # Prepare the response
            response = {
                "masks": [],
                "scores": [],
                "prompt_type": "text",
                "prompt_value": text_prompt,
                "num_masks": 0
            }

            # Process each mask
            for i in range(len(results["masks"])):
                mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
                score = results["scores"][i].item()

                if score >= confidence_threshold:
                    # Convert the mask to base64 for the API response
                    mask_image = Image.fromarray(mask_np, mode='L')
                    buffer = io.BytesIO()
                    mask_image.save(buffer, format='PNG')
                    mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

                    response["masks"].append(mask_b64)
                    response["scores"].append(score)

            # Count only the masks that passed the threshold
            response["num_masks"] = len(response["masks"])

            return response

        except Exception as e:
            return {"error": str(e)}


# Initialize the handler
handler = SAM3Handler()


def gradio_interface(image, text_prompt, confidence_threshold):
    """Gradio interface wrapper"""
    result = handler.predict(image, text_prompt, confidence_threshold)

    if "error" in result:
        return f"Error: {result['error']}", None

    # For the UI, show the first mask as an example
    if result["masks"]:
        first_mask_b64 = result["masks"][0]
        first_score = result["scores"][0]

        # Decode the first mask for display
        mask_bytes = base64.b64decode(first_mask_b64)
        mask_image = Image.open(io.BytesIO(mask_bytes))

        info = f"Found {result['num_masks']} masks. First mask score: {first_score:.3f}"
        return info, mask_image
    else:
        return "No masks found above the confidence threshold", None


def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    API function matching the SAM2 inference endpoint format.

    Expected input format (SAM2-compatible, with SAM3 extensions):
    {
        "inputs": {
            "image": "base64_encoded_image_string",

            # SAM3 NEW: text-based prompts
            "text_prompts": ["person", "car"],            # List of text descriptions

            # SAM2-compatible: point-based prompts
            "points": [[[x1, y1]], [[x2, y2]]],            # Points for each object
            "labels": [[1], [1]],                          # Labels per point (1=foreground, 0=background)

            # SAM2-compatible: bounding-box prompts
            "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]],

            "multimask_output": false,                     # Optional, defaults to false
            "confidence_threshold": 0.5                    # Optional, minimum confidence for returned masks
        }
    }

    Returns (matching the SAM2 format):
    {
        "masks": [base64_encoded_mask_1, base64_encoded_mask_2, ...],
        "scores": [score1, score2, ...],
        "num_objects": int,
        "sam_version": "3.0",
        "success": true
    }
    """
    try:
        inputs_data = data.get("inputs", {})

        # Extract inputs
        image_b64 = inputs_data.get("image")
        text_prompts = inputs_data.get("text_prompts", [])
        input_points = inputs_data.get("points", [])
        input_labels = inputs_data.get("labels", [])
        input_boxes = inputs_data.get("boxes", [])
        # multimask_output is accepted for SAM2 compatibility but currently unused
        multimask_output = inputs_data.get("multimask_output", False)
        confidence_threshold = inputs_data.get("confidence_threshold", 0.5)

        # Validate inputs
        if not image_b64:
            return {"error": "No image provided", "success": False}

        has_text = bool(text_prompts)
        has_points = bool(input_points and input_labels)
        has_boxes = bool(input_boxes)

        if not (has_text or has_points or has_boxes):
            return {
                "error": "Must provide at least one prompt type: text_prompts, points+labels, or boxes",
                "success": False
            }

        if has_points and len(input_points) != len(input_labels):
            return {"error": "Number of points and labels must match", "success": False}

        # Decode the image
        if image_b64.startswith('data:image'):
            image_b64 = image_b64.split(',')[1]
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        all_masks = []
        all_scores = []

        # Process text prompts (SAM3 feature)
        if has_text:
            for text_prompt in text_prompts:
                result = handler.predict(image, text_prompt, confidence_threshold)
                if "error" not in result:
                    all_masks.extend(result["masks"])
                    all_scores.extend(result["scores"])

        # Process visual prompts (SAM2 compatibility) - basic implementation.
        # Note: this is a simplified version. Full SAM2 compatibility would require
        # implementing visual prompt processing in the handler.
        if has_boxes or has_points:
            # For now, fall back to a generic prompt if no text was provided
            if not has_text:
                result = handler.predict(image, "object", confidence_threshold)
                if "error" not in result and result["masks"]:
                    # Take only the number of masks requested
                    num_requested = len(input_boxes) if has_boxes else len(input_points)
                    all_masks.extend(result["masks"][:num_requested])
                    all_scores.extend(result["scores"][:num_requested])

        # Build a SAM2-compatible response
        return {
            "masks": all_masks,
            "scores": all_scores,
            "num_objects": len(all_masks),
            "sam_version": "3.0",
            "success": True
        }

    except Exception as e:
        return {"error": str(e), "success": False, "sam_version": "3.0"}


# Create the Gradio interface
with gr.Blocks(title="SAM3 Inference API") as demo:
    gr.HTML(
        "<p>This Space provides both a UI and an API for SAM3 inference. "
        "Use the interface below or call the API programmatically.</p>"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter what you want to segment (e.g., 'cat', 'person', 'car')"
            )
            confidence_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                label="Confidence Threshold"
            )
            predict_btn = gr.Button("Segment", variant="primary")

        with gr.Column():
            info_output = gr.Textbox(label="Results Info")
            mask_output = gr.Image(label="Sample Mask")

    # UI endpoint - this creates /api/predict/
    predict_btn.click(
        gradio_interface,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[info_output, mask_output],
        api_name="predict"  # This exposes the API endpoint
    )

    # SAM2-compatible API endpoint - this creates /api/sam2_compatible/
    gr.Interface(
        fn=api_predict,
        inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
        outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
        title="SAM2/SAM3 Compatible API",
        description="API endpoint that matches the SAM2 inference endpoint format, with SAM3 extensions",
        api_name="sam2_compatible"
    )

    # Add API documentation
    gr.HTML("""
    <h3>API Usage</h3>

    <h4>Text-prompt endpoint (/api/predict)</h4>
    <pre><code>import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Make the API request
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/predict",
    json={
        "data": [image_b64, "kitten", 0.5]
    }
)
result = response.json()</code></pre>
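    <p>Note: Gradio typically wraps endpoint outputs in a <code>"data"</code> list
    (the exact envelope varies by Gradio version), so the info string returned by
    this app can usually be read as in the sketch below:</p>
    <pre><code># Sketch only; the field layout may differ across Gradio versions
info_text = result["data"][0]  # e.g. "Found 2 masks. First mask score: 0.873"
print(info_text)</code></pre>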
    <h4>SAM2/SAM3-compatible endpoint (/api/sam2_compatible)</h4>
    <pre><code>import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# SAM3 text prompts (NEW)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "text_prompts": ["kitten", "toy"],
                "confidence_threshold": 0.5
            }
        }]
    }
)
# SAM2-compatible (points/boxes)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "boxes": [[100, 100, 200, 200]],
                "confidence_threshold": 0.5
            }
        }]
    }
)
result = response.json()</code></pre>
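    <p>Each returned mask is a base64-encoded PNG (see <code>api_predict</code>).
    A minimal decoding sketch, assuming the endpoint output is wrapped in Gradio's
    <code>"data"</code> list (envelope details vary by Gradio version):</p>
    <pre><code># Sketch: decode and save the first mask from the sam2_compatible response
payload = result["data"][0]          # the dict produced by api_predict
if payload.get("success"):
    mask_png = base64.b64decode(payload["masks"][0])
    with open("mask_0.png", "wb") as f:
        f.write(mask_png)
    print("scores:", payload["scores"])</code></pre>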
""")
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)
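

# --- Client-side usage sketch (not executed by this app) ---
# A minimal example using the gradio_client package, assuming a recent
# gradio_client (`pip install gradio_client`) and a Space deployed under the
# hypothetical name "your-username/sam3-api":
#
#     from gradio_client import Client, handle_file
#
#     client = Client("your-username/sam3-api")
#     info, mask = client.predict(
#         handle_file("image.jpg"),  # input image
#         "kitten",                  # text prompt
#         0.5,                       # confidence threshold
#         api_name="/predict",
#     )
#     print(info)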