# sam3-test / app-bak.py
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import base64
import io
from typing import Dict, Any
import warnings
warnings.filterwarnings("ignore")


@spaces.GPU
def sam3_inference(image, text_prompt, confidence_threshold=0.5):
    """
    Standalone GPU function with model initialization for Spaces Stateless GPU.
    All CUDA operations and related imports must happen inside this decorated
    function.
    """
    try:
        # Import torch and transformers inside the GPU function to avoid
        # initializing CUDA in the main process.
        import torch
        from transformers import Sam3Model, Sam3Processor

        # Initialize the model and processor inside the GPU function
        # (required for Stateless GPU).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = Sam3Model.from_pretrained(
            "facebook/sam3",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)
        processor = Sam3Processor.from_pretrained("facebook/sam3")
        print(f"Model loaded on device: {device}")

        # Handle base64 input (for API callers).
        if isinstance(image, str):
            if image.startswith("data:image"):
                image = image.split(",")[1]
            image_bytes = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Process with SAM3.
        inputs = processor(
            images=image,
            text=text_prompt.strip(),
            return_tensors="pt",
        ).to(device)

        # Convert floating-point inputs to match the model dtype (fp16 on GPU).
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        # Use the SAM3 instance-segmentation post-processing.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=confidence_threshold,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        # Move tensors to CPU before returning: with Stateless GPU the caller
        # runs outside this function, where CUDA is not available.
        results["masks"] = results["masks"].cpu()
        results["scores"] = results["scores"].cpu()
        return results
    except Exception as e:
        raise RuntimeError(f"SAM3 inference error: {e}") from e
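
# Usage sketch (illustrative, kept as a comment so nothing runs at import):
# calling the GPU function directly. "cat.jpg" is a hypothetical local file;
# on Spaces the @spaces.GPU decorator schedules the call onto a GPU worker.
#
#   img = Image.open("cat.jpg").convert("RGB")
#   results = sam3_inference(img, "cat", confidence_threshold=0.5)
#   print(len(results["masks"]), results["scores"].tolist())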


class SAM3Handler:
    """SAM3 handler shared by both the UI and the API."""

    def __init__(self):
        print("SAM3 handler initialized (models will be loaded lazily)")

    def predict(self, image, text_prompt, confidence_threshold=0.5):
        """
        Main prediction function for both the UI and the API.

        Args:
            image: PIL Image or base64 string
            text_prompt: string describing what to segment
            confidence_threshold: minimum confidence for returned masks

        Returns:
            Dict with masks, scores, and metadata
        """
        try:
            # Call the standalone GPU function.
            results = sam3_inference(image, text_prompt, confidence_threshold)

            # Prepare the response.
            response = {
                "masks": [],
                "scores": [],
                "prompt_type": "text",
                "prompt_value": text_prompt,
            }

            # Process each mask.
            for i in range(len(results["masks"])):
                mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
                score = results["scores"][i].item()
                if score >= confidence_threshold:
                    # Convert the mask to base64 for the API response.
                    mask_image = Image.fromarray(mask_np, mode="L")
                    buffer = io.BytesIO()
                    mask_image.save(buffer, format="PNG")
                    mask_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
                    response["masks"].append(mask_b64)
                    response["scores"].append(score)

            # Count masks after threshold filtering so num_masks matches the
            # returned lists.
            response["num_masks"] = len(response["masks"])
            return response
        except Exception as e:
            return {"error": str(e)}
# Initialize the handler
handler = SAM3Handler()
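
# Usage sketch (illustrative): the handler accepts either a PIL image or a
# base64 string. The file name below is hypothetical.
#
#   out = handler.predict(Image.open("street.jpg").convert("RGB"), "car", 0.5)
#   if "error" not in out:
#       print(out["num_masks"], out["scores"])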


def gradio_interface(image, text_prompt, confidence_threshold):
    """Gradio interface wrapper."""
    result = handler.predict(image, text_prompt, confidence_threshold)

    if "error" in result:
        return f"Error: {result['error']}", None

    # For the UI, show the first mask as an example.
    if result["masks"]:
        first_mask_b64 = result["masks"][0]
        first_score = result["scores"][0]

        # Decode the first mask for display.
        mask_bytes = base64.b64decode(first_mask_b64)
        mask_image = Image.open(io.BytesIO(mask_bytes))

        info = f"Found {result['num_masks']} masks. First mask score: {first_score:.3f}"
        return info, mask_image
    else:
        return "No masks found above the confidence threshold", None


def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    API function matching the SAM2 inference endpoint format.

    Expected input format (SAM2 fields plus SAM3 extensions):
    {
        "inputs": {
            "image": "base64_encoded_image_string",
            # SAM3 NEW: text-based prompts
            "text_prompts": ["person", "car"],              # list of text descriptions
            # SAM2 compatible: point-based prompts
            "points": [[[x1, y1]], [[x2, y2]]],             # points for each object
            "labels": [[1], [1]],                           # labels per point (1=foreground, 0=background)
            # SAM2 compatible: bounding-box prompts
            "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]],  # bounding boxes
            "multimask_output": false,                      # optional, defaults to false
            "confidence_threshold": 0.5                     # optional, minimum confidence for returned masks
        }
    }

    Returns (matching the SAM2 format):
    {
        "masks": [base64_encoded_mask_1, base64_encoded_mask_2, ...],
        "scores": [score1, score2, ...],
        "num_objects": int,
        "sam_version": "3.0",
        "success": true
    }
    """
    try:
        inputs_data = data.get("inputs", {})

        # Extract inputs.
        image_b64 = inputs_data.get("image")
        text_prompts = inputs_data.get("text_prompts", [])
        input_points = inputs_data.get("points", [])
        input_labels = inputs_data.get("labels", [])
        input_boxes = inputs_data.get("boxes", [])
        # Accepted for SAM2 request compatibility; currently unused here.
        multimask_output = inputs_data.get("multimask_output", False)
        confidence_threshold = inputs_data.get("confidence_threshold", 0.5)

        # Validate inputs.
        if not image_b64:
            return {"error": "No image provided", "success": False}

        has_text = bool(text_prompts)
        has_points = bool(input_points and input_labels)
        has_boxes = bool(input_boxes)

        if not (has_text or has_points or has_boxes):
            return {
                "error": "Must provide at least one prompt type: text_prompts, points+labels, or boxes",
                "success": False,
            }
        if has_points and len(input_points) != len(input_labels):
            return {"error": "Number of points and labels must match", "success": False}

        # Decode the image.
        if image_b64.startswith("data:image"):
            image_b64 = image_b64.split(",")[1]
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        all_masks = []
        all_scores = []

        # Process text prompts (SAM3 feature).
        if has_text:
            for text_prompt in text_prompts:
                result = handler.predict(image, text_prompt, confidence_threshold)
                if "error" not in result:
                    all_masks.extend(result["masks"])
                    all_scores.extend(result["scores"])

        # Process visual prompts (SAM2 compatibility) - basic implementation.
        # Note: this is a simplified version; full SAM2 compatibility would
        # require implementing visual-prompt processing in the handler.
        if has_boxes or has_points:
            # For now, fall back to a generic prompt if no text was provided.
            if not has_text:
                result = handler.predict(image, "object", confidence_threshold)
                if "error" not in result and result["masks"]:
                    # Return only as many masks as objects were requested.
                    num_requested = len(input_boxes) if has_boxes else len(input_points)
                    all_masks.extend(result["masks"][:num_requested])
                    all_scores.extend(result["scores"][:num_requested])

        # Build the SAM2-compatible response.
        return {
            "masks": all_masks,
            "scores": all_scores,
            "num_objects": len(all_masks),
            "sam_version": "3.0",
            "success": True,
        }
    except Exception as e:
        return {"error": str(e), "success": False, "sam_version": "3.0"}


# Create the Gradio interface.
with gr.Blocks(title="SAM3 Inference API") as demo:
    gr.HTML("<h1>SAM3 Promptable Concept Segmentation</h1>")
    gr.HTML("<p>This Space provides both a UI and an API for SAM3 inference. Use the interface below or call the API programmatically.</p>")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter what you want to segment (e.g., 'cat', 'person', 'car')",
            )
            confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Confidence Threshold")
            predict_btn = gr.Button("Segment", variant="primary")
        with gr.Column():
            info_output = gr.Textbox(label="Results Info")
            mask_output = gr.Image(label="Sample Mask")

    # API endpoint - this creates /api/predict/.
    predict_btn.click(
        gradio_interface,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[info_output, mask_output],
        api_name="predict",  # exposes this handler as an API endpoint
    )

    # SAM2-compatible API endpoint - this creates /api/sam2_compatible/.
    gr.Interface(
        fn=api_predict,
        inputs=gr.JSON(label="SAM2/SAM3 Compatible Input"),
        outputs=gr.JSON(label="SAM2/SAM3 Compatible Output"),
        title="SAM2/SAM3 Compatible API",
        description="API endpoint matching the SAM2 inference endpoint format, with SAM3 extensions",
        api_name="sam2_compatible",
    )
    # Add API documentation.
    gr.HTML("""
    <h2>API Usage</h2>

    <h3>1. Simple Text API (Gradio format)</h3>
    <pre>
import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# Make the API request
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/predict",
    json={
        "data": [image_b64, "kitten", 0.5]
    }
)
result = response.json()
    </pre>

    <h3>2. SAM2/SAM3 Compatible API (Inference Endpoint format)</h3>
    <pre>
import requests
import base64

# Encode your image to base64
with open("image.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# SAM3 text prompts (NEW)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "text_prompts": ["kitten", "toy"],
                "confidence_threshold": 0.5
            }
        }]
    }
)

# SAM2 compatible (points/boxes)
response = requests.post(
    "https://your-username-sam3-api.hf.space/api/sam2_compatible",
    json={
        "data": [{
            "inputs": {
                "image": image_b64,
                "boxes": [[100, 100, 200, 200]],
                "confidence_threshold": 0.5
            }
        }]
    }
)
result = response.json()
    </pre>
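
    <h3>3. Python client (gradio_client) - sketch</h3>
    <p>A minimal sketch assuming the <code>gradio_client</code> package
    (<code>pip install gradio_client</code>); the Space id below is a placeholder.</p>
    <pre>
from gradio_client import Client, handle_file

client = Client("your-username/sam3-api")  # placeholder Space id
info, mask_path = client.predict(
    handle_file("image.jpg"),  # input image
    "kitten",                  # text prompt
    0.5,                       # confidence threshold
    api_name="/predict",
)
print(info)
    </pre>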
""")


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)