import spaces
import gradio as gr
import numpy as np
from PIL import Image
import base64
import io
from typing import Dict, Any
import warnings

warnings.filterwarnings("ignore")
@spaces.GPU
def sam3_inference(image, text_prompt, confidence_threshold=0.5):
    """
    Standalone GPU function with model initialization, for Spaces Stateless GPU.
    All CUDA operations and related imports must happen inside this decorated function.
    """
    try:
        # Import torch and transformers inside the GPU function to avoid CUDA init in the main process
        import torch
        from transformers import Sam3Model, Sam3Processor

        # Initialize the model and processor inside the GPU function (required for Stateless GPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
        print(f"Model loaded on device: {device}")
        # Handle base64 input (for API calls)
        if isinstance(image, str):
            if image.startswith("data:image"):
                image = image.split(",")[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Process with SAM3
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt",
        ).to(device)

        # Cast floating-point inputs to match the model dtype (fp16 on GPU)
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)
        with torch.no_grad():
            outputs = model(**inputs)

        # Use proper SAM3 post-processing
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]
        return results
    except Exception as e:
        raise RuntimeError(f"SAM3 inference error: {e}") from e
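
# A minimal local-usage sketch for sam3_inference (illustrative; "cat.jpg" is a
# placeholder path, and the keys of `results` assume the post-processing call
# above returns "masks" and "scores" as used throughout this file):
#
#   from PIL import Image
#   results = sam3_inference(Image.open("cat.jpg"), "cat", confidence_threshold=0.5)
#   print(f"{len(results['masks'])} masks, top score {results['scores'][0]:.3f}")
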
class SAM3Handler:
    """SAM3 handler for both UI and API access."""

    def __init__(self):
        print("SAM3 handler initialized (models will be loaded lazily)")

    def predict(self, image, text_prompt, confidence_threshold=0.5):
        """
        Main prediction function for both UI and API.

        Args:
            image: PIL Image or base64 string
            text_prompt: String describing what to segment
            confidence_threshold: Minimum confidence for masks

        Returns:
            Dict with masks, scores, and metadata
        """
        try:
            # Call the standalone GPU function
            results = sam3_inference(image, text_prompt, confidence_threshold)

            # Prepare the response
            response = {
                "masks": [],
                "scores": [],
                "prompt_type": "text",
                "prompt_value": text_prompt,
            }

            # Process each mask
            for i in range(len(results["masks"])):
                mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
                score = results["scores"][i].item()
                if score >= confidence_threshold:
                    # Convert the mask to a base64 PNG for the API response
                    mask_image = Image.fromarray(mask_np, mode="L")
                    buffer = io.BytesIO()
                    mask_image.save(buffer, format="PNG")
                    mask_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
                    response["masks"].append(mask_b64)
                    response["scores"].append(score)

            # Count only the masks that passed the confidence threshold
            response["num_masks"] = len(response["masks"])
            return response
        except Exception as e:
            return {"error": str(e)}
# Initialize the handler
handler = SAM3Handler()
def gradio_interface(image, text_prompt, confidence_threshold):
    """Gradio interface wrapper."""
    result = handler.predict(image, text_prompt, confidence_threshold)
    if "error" in result:
        return f"Error: {result['error']}", None

    # For the UI, show the first mask as an example
    if result["masks"]:
        first_mask_b64 = result["masks"][0]
        first_score = result["scores"][0]

        # Decode the first mask for display
        mask_bytes = base64.b64decode(first_mask_b64)
        mask_image = Image.open(io.BytesIO(mask_bytes))

        info = f"Found {result['num_masks']} masks. First mask score: {first_score:.3f}"
        return info, mask_image
    else:
        return "No masks found above confidence threshold", None
def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    API function matching the SAM2 inference endpoint format.

    Expected input format (SAM2-compatible, with SAM3 extensions):
    {
        "inputs": {
            "image": "base64_encoded_image_string",
            # SAM3 NEW: text-based prompts
            "text_prompts": ["person", "car"],             # List of text descriptions
            # SAM2-compatible: point-based prompts
            "points": [[[x1, y1]], [[x2, y2]]],            # Points for each object
            "labels": [[1], [1]],                          # Labels per point (1=foreground, 0=background)
            # SAM2-compatible: bounding-box prompts
            "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]], # Bounding boxes
            "multimask_output": false,                     # Optional, defaults to false
            "confidence_threshold": 0.5                    # Optional, minimum confidence for returned masks
        }
    }

    Returns (matching the SAM2 format):
    {
        "masks": [base64_encoded_mask_1, base64_encoded_mask_2, ...],
        "scores": [score1, score2, ...],
        "num_objects": int,
        "sam_version": "3.0",
        "success": true
    }
    """
    try:
        inputs_data = data.get("inputs", {})

        # Extract inputs
        image_b64 = inputs_data.get("image")
        text_prompts = inputs_data.get("text_prompts", [])
        input_points = inputs_data.get("points", [])
        input_labels = inputs_data.get("labels", [])
        input_boxes = inputs_data.get("boxes", [])
        multimask_output = inputs_data.get("multimask_output", False)
        confidence_threshold = inputs_data.get("confidence_threshold", 0.5)

        # Validate inputs
        if not image_b64:
            return {"error": "No image provided", "success": False}

        has_text = bool(text_prompts)
        has_points = bool(input_points and input_labels)
        has_boxes = bool(input_boxes)
        if not (has_text or has_points or has_boxes):
            return {
                "error": "Must provide at least one prompt type: text_prompts, points+labels, or boxes",
                "success": False,
            }
        if has_points and len(input_points) != len(input_labels):
            return {"error": "Number of points and labels must match", "success": False}

        # Decode the image
        if image_b64.startswith("data:image"):
            image_b64 = image_b64.split(",")[1]
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        all_masks = []
        all_scores = []

        # Process text prompts (SAM3 feature)
        if has_text:
            for text_prompt in text_prompts:
                result = handler.predict(image, text_prompt, confidence_threshold)
                if "error" not in result:
                    all_masks.extend(result["masks"])
                    all_scores.extend(result["scores"])

        # Process visual prompts (SAM2 compatibility) - basic implementation.
        # Note: this is a simplified version; full SAM2 compatibility would require
        # implementing visual prompt processing in the handler.
        if has_boxes or has_points:
            # For now, fall back to a generic text prompt if none was provided
            if not has_text:
                result = handler.predict(image, "object", confidence_threshold)
                if "error" not in result and result["masks"]:
                    # Return only as many masks as prompts were given
                    num_requested = len(input_boxes) if has_boxes else len(input_points)
                    all_masks.extend(result["masks"][:num_requested])
                    all_scores.extend(result["scores"][:num_requested])
        # Build the SAM2-compatible response
        return {
            "masks": all_masks,
            "scores": all_scores,
            "num_objects": len(all_masks),
            "sam_version": "3.0",
            "success": True,
        }
    except Exception as e:
        return {"error": str(e), "success": False, "sam_version": "3.0"}
# Create the Gradio interface
with gr.Blocks(title="SAM3 Inference API") as demo:
    gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
    gr.HTML("<p>This Space provides both a UI and an API for SAM3 inference. Use the interface below or call the API programmatically.</p>")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter what you want to segment (e.g., 'cat', 'person', 'car')",
            )
            confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
            predict_btn = gr.Button("Segment", variant="primary")
        with gr.Column():
            info_output = gr.Textbox(label="Results Info")
            mask_output = gr.Image(label="Sample Mask")

    # API endpoint - this creates /api/predict/
    predict_btn.click(
        gradio_interface,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[info_output, mask_output],
        api_name="predict",  # This creates the API endpoint
    )
    # SAM2-compatible API endpoint - this creates /api/sam2_compatible/
    gr.Interface(
        fn=api_predict,
        inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
        outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
        title="SAM2/SAM3 Compatible API",
        description="API endpoint matching the SAM2 inference endpoint format, with SAM3 extensions",
        api_name="sam2_compatible",
    )
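
    # Note: when called over HTTP, Gradio wraps the function arguments in a "data"
    # list, so the JSON body for this endpoint is {"data": [{"inputs": {...}}]}
    # (see the usage examples below).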
    # Add API documentation
    gr.HTML("""
    <h2>API Usage</h2>

    <h3>1. Simple Text API (Gradio format)</h3>
    <pre>
import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Make the API request
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/predict",
    json={
        "data": [image_b64, "kitten", 0.5]
    }
)
result = response.json()
    </pre>

    <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3>
    <pre>
import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# SAM3 text prompts (NEW)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "text_prompts": ["kitten", "toy"],
                "confidence_threshold": 0.5
            }
        }]
    }
)

# SAM2-compatible (points/boxes)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "boxes": [[100, 100, 200, 200]],
                "confidence_threshold": 0.5
            }
        }]
    }
)
result = response.json()
    </pre>
    """)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
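
# Client-side sketch using gradio_client (assumes this file runs as a public Space
# named "your-username/sam3-api"; on Gradio 4.x image inputs may need to be wrapped
# with gradio_client.handle_file("image.jpg") instead of passing a bare path):
#
#   from gradio_client import Client
#   client = Client("your-username/sam3-api")
#   info, mask = client.predict("image.jpg", "kitten", 0.5, api_name="/predict")
#   print(info)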