import spaces
import gradio as gr


@spaces.GPU
def sam3_predict(image, text_prompt, confidence_threshold=0.5):
    """Segment the concept described by ``text_prompt`` in ``image`` with SAM3.

    Written for a stateless GPU environment: every heavy import, the model
    load and all CUDA work happen inside this function, so nothing touches
    the GPU at module import time.

    Args:
        image: A PIL image, or a base64 string (optionally prefixed with a
            ``data:image...,`` data-URL header) that decodes to an image.
        text_prompt: Text naming what to segment (e.g. ``"cat"``).
        confidence_threshold: Score threshold passed to the processor's
            instance-segmentation post-processing.

    Returns:
        A ``(message, mask_image)`` tuple. ``message`` is a human-readable
        status string; ``mask_image`` is a mode-``L`` PIL image (0/255) of
        the highest-scoring mask, or ``None`` when nothing was found, the
        input was invalid, or an error occurred.
    """
    # Reject unusable UI input up front so the user gets a clear message
    # instead of an AttributeError surfaced from deep inside the pipeline.
    if image is None:
        return "Error: no input image provided", None
    if not isinstance(text_prompt, str) or not text_prompt.strip():
        return "Error: text prompt is empty", None

    # Import everything inside the GPU function.
    import base64
    import io

    import numpy as np
    import torch
    from PIL import Image
    from transformers import Sam3Model, Sam3Processor

    try:
        # Handle base64 input if needed.
        if isinstance(image, str):
            if image.startswith("data:image"):
                # Drop the "data:image/...;base64," header, keep the payload.
                image = image.partition(",")[2]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Initialize model and processor (half precision only on GPU).
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")

        # Process input.
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt",
        ).to(device)

        # Cast float32 tensors to the model dtype; integer tensors (ids,
        # sizes) and any non-tensor entries are left untouched.
        for key in inputs:
            value = inputs[key]
            if torch.is_tensor(value) and value.dtype == torch.float32:
                inputs[key] = value.to(model.dtype)

        # Run inference.
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        masks = results["masks"]
        # len() rather than truthiness: the container may be a tensor, whose
        # truth value is ambiguous.
        if len(masks) == 0:
            return "No masks found above confidence threshold", None

        # Report and display the highest-scoring mask, so the "Best score"
        # in the message is accurate regardless of result ordering.
        scores = results["scores"]
        best_idx = int(torch.argmax(scores))
        score = scores[best_idx].item()
        mask_np = masks[best_idx].cpu().numpy().astype(np.uint8) * 255
        mask_image = Image.fromarray(mask_np, mode="L")
        return f"Found {len(masks)} masks. Best score: {score:.3f}", mask_image

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing.
        return f"Error: {str(e)}", None


# Simple gradio interface - no class, no global state
def create_interface():
    """Build and return the Gradio Blocks UI wired to ``sam3_predict``."""
    with gr.Blocks(title="SAM3 Inference") as demo:
        gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Enter what to segment (e.g., 'cat', 'person', 'car')",
                )
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.1,
                    label="Confidence Threshold",
                )
                predict_btn = gr.Button("Segment", variant="primary")

            with gr.Column():
                info_output = gr.Textbox(label="Results Info")
                mask_output = gr.Image(label="Sample Mask")

        predict_btn.click(
            sam3_predict,
            inputs=[image_input, text_input, confidence_slider],
            outputs=[info_output, mask_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)