# sam3-test / app.py
# chris-propeller · "stateless" · a36d7fa · 3.44 kB
import spaces
import gradio as gr
@spaces.GPU
def sam3_predict(image, text_prompt, confidence_threshold=0.5):
    """Segment the concept named by ``text_prompt`` in ``image`` with SAM3.

    SAM3 prediction function for Stateless GPU environment: all imports and
    CUDA operations happen here, and the model/processor are loaded on every
    call (nothing is kept in module globals).

    Args:
        image: A PIL image, or a base64 string (optionally a
            ``data:image/...;base64,`` data URL) that is decoded to an RGB
            PIL image.
        text_prompt: Text describing what to segment (e.g. ``"cat"``).
            Surrounding whitespace is stripped.
        confidence_threshold: Score threshold passed to
            ``post_process_instance_segmentation`` as ``threshold``.

    Returns:
        A ``(message, mask_image)`` tuple. ``message`` is a human-readable
        summary. ``mask_image`` is the highest-scoring mask as a grayscale
        PIL image (mask values scaled by 255). It is ``None`` when no mask
        passes the threshold or an error occurred. Errors are reported in
        ``message`` rather than raised, so the UI can display them.
    """
    # Import everything inside the GPU function
    import base64
    import io
    import logging

    import numpy as np
    import torch
    from PIL import Image
    from transformers import Sam3Model, Sam3Processor

    logger = logging.getLogger(__name__)

    # Reject obviously missing inputs before paying for a model load.
    if image is None:
        return "Error: please provide an input image", None
    prompt = (text_prompt or "").strip()
    if not prompt:
        return "Error: please provide a text prompt", None

    try:
        # Handle base64 input if needed (API callers may send a data URL).
        if isinstance(image, str):
            if image.startswith("data:image"):
                image = image.split(",", 1)[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Initialize model and processor; half precision only when on GPU.
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")

        # Process input
        inputs = processor(
            images=image,
            text=prompt,
            return_tensors="pt",
        ).to(device)

        # Cast float32 tensors to the model's dtype so a float16 model gets
        # float16 inputs. Integer tensors and any non-tensor entries are left
        # untouched (the original loop would crash on a non-tensor).
        for key in list(inputs.keys()):
            value = inputs[key]
            if torch.is_tensor(value) and value.dtype == torch.float32:
                inputs[key] = value.to(model.dtype)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process back to the original image size.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        masks = results["masks"]
        if len(masks) == 0:
            return "No masks found above confidence threshold", None

        # Choose the best mask explicitly instead of assuming index 0 is the
        # highest-scoring instance.
        scores = torch.as_tensor(results["scores"])
        best = int(scores.argmax())
        score = scores[best].item()

        mask_np = masks[best].cpu().numpy().astype(np.uint8) * 255
        # A 2-D uint8 array is converted to mode "L" automatically; the
        # explicit ``mode`` argument of Image.fromarray is deprecated.
        mask_image = Image.fromarray(mask_np)
        return f"Found {len(masks)} masks. Best score: {score:.3f}", mask_image
    except Exception as e:
        # UI boundary: keep the full traceback in the server logs, but show
        # the user a short message instead of failing the request.
        logger.exception("SAM3 prediction failed")
        return f"Error: {str(e)}", None
# Simple gradio interface - no class, no global state
def create_interface():
    """Build (but do not launch) the SAM3 segmentation UI.

    The layout is a single row with two columns:
    - Left: image upload, text prompt, confidence slider and a "Segment" button.
    - Right: a results textbox and an image showing the returned mask.

    Clicking the button calls ``sam3_predict`` with the three inputs. Its
    ``(message, mask)`` return value is routed to the two outputs.

    Returns:
        The ``gr.Blocks`` app.
    """
    with gr.Blocks(title="SAM3 Inference") as app:
        gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")

        with gr.Row():
            # User-supplied inputs.
            with gr.Column():
                picture = gr.Image(type="pil", label="Input Image")
                prompt_box = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Enter what to segment (e.g., 'cat', 'person', 'car')",
                )
                threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.1,
                    label="Confidence Threshold",
                )
                run_button = gr.Button("Segment", variant="primary")

            # Prediction results.
            with gr.Column():
                summary_box = gr.Textbox(label="Results Info")
                mask_view = gr.Image(label="Sample Mask")

        # The list order must match sam3_predict's positional parameters
        # (image, text_prompt, confidence_threshold) and its 2-tuple result.
        run_button.click(
            sam3_predict,
            inputs=[picture, prompt_box, threshold],
            outputs=[summary_box, mask_view],
        )

    return app
if __name__ == "__main__":
    # Build the UI, then serve it on all network interfaces (0.0.0.0) at
    # port 7860 rather than only on localhost.
    demo = create_interface()
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )