Spaces: Running on L4
import spaces
import gradio as gr
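
# Only `spaces` and `gradio` are imported at module level; torch and the model
# libraries are imported inside the GPU function below, since all imports and
# CUDA work are meant to happen in the Stateless GPU context (see the docstring).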


@spaces.GPU  # request a GPU for this call on ZeroGPU / Stateless GPU hardware
def sam3_predict(image, text_prompt, confidence_threshold=0.5):
    """
    SAM3 prediction function for the Stateless GPU environment.
    All imports and CUDA operations happen inside this function.
    """
    # Import everything inside the GPU function
    import torch
    import numpy as np
    from PIL import Image
    import base64
    import io
    from transformers import Sam3Model, Sam3Processor

    try:
        # Handle base64 input if needed
        if isinstance(image, str):
            if image.startswith('data:image'):
                image = image.split(',')[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Initialize model and processor
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
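        # Note that the model and processor are rebuilt on every request; this keeps
        # the function stateless at the cost of extra latency, and from_pretrained
        # reuses the local Hugging Face cache after the first download.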

        # Process input
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt",
        ).to(device)

        # Convert floating-point inputs to match the model dtype
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]
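        # `results` is a dict with (at least) "masks" and "scores" entries, one
        # mask/score per detected instance; that is all the UI below relies on.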

        # Return results for the UI
        if len(results["masks"]) > 0:
            # Convert the first mask to an 8-bit grayscale image for display
            mask_np = results["masks"][0].cpu().numpy().astype(np.uint8) * 255
            score = results["scores"][0].item()
            mask_image = Image.fromarray(mask_np, mode='L')
            return f"Found {len(results['masks'])} masks. Best score: {score:.3f}", mask_image
        else:
            return "No masks found above confidence threshold", None

    except Exception as e:
        return f"Error: {str(e)}", None


# Simple Gradio interface - no class, no global state
def create_interface():
    with gr.Blocks(title="SAM3 Inference") as demo:
        gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Enter what to segment (e.g., 'cat', 'person', 'car')",
                )
                confidence_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                    label="Confidence Threshold",
                )
                predict_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                info_output = gr.Textbox(label="Results Info")
                mask_output = gr.Image(label="Sample Mask")

        predict_btn.click(
            sam3_predict,
            inputs=[image_input, text_input, confidence_slider],
            outputs=[info_output, mask_output],
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)
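
Once the Space is up, the same endpoint can be called programmatically. The sketch below is a client-side example, not part of app.py; it assumes a recent gradio_client (one that provides handle_file), a hypothetical Space ID "your-username/sam3-inference", and the default endpoint name "/sam3_predict" derived from the handler function. The Space's "Use via API" page shows the exact endpoint name and argument order.

# client_example.py (run outside the Space)
from gradio_client import Client, handle_file

client = Client("your-username/sam3-inference")  # hypothetical Space ID

info, mask_path = client.predict(
    handle_file("cat.jpg"),   # input image (local path or URL)
    "cat",                    # text prompt
    0.5,                      # confidence threshold
    api_name="/sam3_predict",
)
print(info)       # text summary returned by the Space
print(mask_path)  # local path to the downloaded mask image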