# *************************************************************************
# Grasp Any Region (GAR) - Gradio Demo
# Region-level Multimodal Understanding for Vision-Language Models
# *************************************************************************

# 🚨 CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces

# Now import CUDA-related packages
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import (
    AutoModel,
    AutoProcessor,
    GenerationConfig,
    SamModel,
    SamProcessor,
)
import cv2
import sys
import os

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

try:
    from evaluation.eval_dataset import SingleRegionCaptionDataset
except ImportError:
    print("Warning: Could not import SingleRegionCaptionDataset. Using simplified version.")
    SingleRegionCaptionDataset = None

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global model variables (loaded once)
gar_model = None
gar_processor = None
sam_model = None
sam_processor = None
def load_models():
    """Load models once at startup"""
    global gar_model, gar_processor, sam_model, sam_processor

    if gar_model is None:
        print("Loading GAR model...")
        model_path = "HaochenWang/GAR-1B"
        gar_model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
        gar_processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print("GAR model loaded successfully!")

    if sam_model is None:
        print("Loading SAM model...")
        sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
        sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        print("SAM model loaded successfully!")
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def generate_mask_from_points(image, points_str):
    """Generate mask using SAM from point coordinates"""
    try:
        load_models()

        if not points_str or points_str.strip() == "":
            return None, "Please provide points in format: x1,y1;x2,y2"

        # Parse points
        points = []
        labels = []
        for point in points_str.split(';'):
            point = point.strip()
            if point:
                x, y = map(float, point.split(','))
                points.append([x, y])
                labels.append(1)  # Foreground point

        if not points:
            return None, "No valid points provided"

        # Apply SAM
        inputs = sam_processor(
            image,
            input_points=[points],
            input_labels=[labels],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0]
        scores = outputs.iou_scores[0, 0]
        mask_selection_index = scores.argmax()
        mask_np = masks[mask_selection_index].numpy()

        # Visualize mask
        mask_img = (mask_np * 255).astype(np.uint8)
        return Image.fromarray(mask_img), "Mask generated successfully!"

    except Exception as e:
        return None, f"Error generating mask: {str(e)}"
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def generate_mask_from_box(image, box_str):
    """Generate mask using SAM from bounding box"""
    try:
        load_models()

        if not box_str or box_str.strip() == "":
            return None, "Please provide box in format: x1,y1,x2,y2"

        # Parse box
        box = list(map(float, box_str.split(',')))
        if len(box) != 4:
            return None, "Box must have 4 coordinates: x1,y1,x2,y2"

        # Apply SAM
        inputs = sam_processor(
            image,
            input_boxes=[[box]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0]
        scores = outputs.iou_scores[0, 0]
        mask_selection_index = scores.argmax()
        mask_np = masks[mask_selection_index].numpy()

        # Visualize mask
        mask_img = (mask_np * 255).astype(np.uint8)
        return Image.fromarray(mask_img), "Mask generated successfully!"

    except Exception as e:
        return None, f"Error generating mask: {str(e)}"
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def describe_region(image, mask):
    """Generate description for a region defined by a mask"""
    try:
        load_models()

        if image is None:
            return "Please provide an image"
        if mask is None:
            return "Please provide a mask (upload or generate using SAM)"

        # Convert mask to numpy
        if isinstance(mask, Image.Image):
            mask_np = np.array(mask.convert("L"))
        else:
            mask_np = np.array(mask)

        # Ensure mask is binary
        mask_np = (mask_np > 127).astype(np.uint8)

        # Prepare data
        prompt_number = gar_model.config.prompt_numbers
        prompt_tokens = [f"<Prompt{i_p}>" for i_p in range(prompt_number)] + ["<NO_Prompt>"]

        if SingleRegionCaptionDataset is not None:
            dataset = SingleRegionCaptionDataset(
                image=image,
                mask=mask_np,
                processor=gar_processor,
                prompt_number=prompt_number,
                visual_prompt_tokens=prompt_tokens,
                data_dtype=torch.bfloat16,
            )
            data_sample = dataset[0]
        else:
            # Simplified processing if dataset class not available
            # This is a fallback - the actual implementation requires SingleRegionCaptionDataset
            return "Error: SingleRegionCaptionDataset not available. Please check installation."

        # Generate description
        with torch.no_grad():
            generate_ids = gar_model.generate(
                **data_sample,
                generation_config=GenerationConfig(
                    max_new_tokens=1024,
                    do_sample=False,
                    eos_token_id=gar_processor.tokenizer.eos_token_id,
                    pad_token_id=gar_processor.tokenizer.pad_token_id,
                ),
                return_dict=True,
            )

        output_caption = gar_processor.tokenizer.decode(
            generate_ids.sequences[0], skip_special_tokens=True
        ).strip()

        return output_caption

    except Exception as e:
        return f"Error generating description: {str(e)}"
def create_visualization(image, mask, points_str=None, box_str=None):
    """Create visualization with mask overlay"""
    try:
        if image is None or mask is None:
            return None

        img_np = np.array(image).astype(float) / 255.0

        if isinstance(mask, Image.Image):
            mask_np = np.array(mask.convert("L")) > 127
        else:
            mask_np = np.array(mask) > 127

        # Draw contour
        mask_uint8 = mask_np.astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        img_vis = img_np.copy()
        cv2.drawContours(img_vis, contours, -1, (1.0, 1.0, 0.0), thickness=3)

        # Draw points if provided
        if points_str:
            for point in points_str.split(';'):
                point = point.strip()
                if point:
                    x, y = map(float, point.split(','))
                    cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 0.0, 0.0), thickness=-1)
                    cv2.circle(img_vis, (int(x), int(y)), radius=8, color=(1.0, 1.0, 1.0), thickness=2)

        # Draw box if provided
        if box_str:
            coords = list(map(float, box_str.split(',')))
            if len(coords) == 4:
                x1, y1, x2, y2 = map(int, coords)
                cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 1.0, 1.0), thickness=3)
                cv2.rectangle(img_vis, (x1, y1), (x2, y2), color=(1.0, 0.0, 0.0), thickness=1)

        img_pil = Image.fromarray((img_vis * 255.0).astype(np.uint8))
        return img_pil

    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        return None
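
# Optional overlay variant (a sketch; `create_overlay` is an assumption and is not
# used by the UI below): instead of only tracing the contour, alpha-blend a
# translucent color over the masked pixels.
def create_overlay(image, mask, color=(255, 255, 0), alpha=0.4):
    """Return a PIL image with `color` blended over the masked region."""
    img_np = np.array(image.convert("RGB")).astype(np.float32)
    if isinstance(mask, Image.Image):
        mask_np = np.array(mask.convert("L")) > 127
    else:
        mask_np = np.array(mask) > 127
    blended = img_np.copy()
    blended[mask_np] = (1 - alpha) * img_np[mask_np] + alpha * np.array(color, dtype=np.float32)
    return Image.fromarray(blended.astype(np.uint8))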
# Create Gradio interface
with gr.Blocks(title="Grasp Any Region (GAR) Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎯 Grasp Any Region (GAR)

**Region-level Multimodal Understanding for Vision-Language Models**

This demo showcases GAR's ability to understand and describe specific regions in images:
- 🎨 **Single Region Understanding**: Describe specific areas using points, boxes, or masks
- 🔍 **SAM Integration**: Generate masks interactively using Segment Anything Model
- 💡 **Detailed Descriptions**: Get comprehensive descriptions of any region

Built on top of Perception-LM with the RoI-aligned feature replay technique.

📄 [Paper](https://arxiv.org/abs/2510.18876) | 💻 [GitHub](https://github.com/Haochen-Wang409/Grasp-Any-Region) | 🤗 [Model](https://huggingface.co/HaochenWang/GAR-1B)
""")
    with gr.Tabs():
        # Tab 1: Points-based segmentation
        with gr.Tab("🎯 Points → Describe"):
            gr.Markdown("### Click points on the image or enter coordinates to segment and describe a region")

            with gr.Row():
                with gr.Column():
                    img_points = gr.Image(label="Input Image", type="pil")
                    points_input = gr.Textbox(
                        label="Points (format: x1,y1;x2,y2;...)",
                        placeholder="e.g., 1172,812;1572,800",
                        value="1172,812;1572,800"
                    )
                    with gr.Row():
                        gen_mask_points_btn = gr.Button("Generate Mask", variant="primary")
                        describe_points_btn = gr.Button("Describe Region", variant="secondary")

                with gr.Column():
                    mask_points = gr.Image(label="Generated Mask", type="pil")
                    vis_points = gr.Image(label="Visualization")
                    desc_points = gr.Textbox(label="Region Description", lines=5)

            points_status = gr.Textbox(label="Status", visible=False)

            gen_mask_points_btn.click(
                fn=generate_mask_from_points,
                inputs=[img_points, points_input],
                outputs=[mask_points, points_status]
            )

            describe_points_btn.click(
                fn=describe_region,
                inputs=[img_points, mask_points],
                outputs=desc_points
            ).then(
                fn=create_visualization,
                inputs=[img_points, mask_points, points_input, gr.Textbox(visible=False)],
                outputs=vis_points
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_2.jpg", "1172,812;1572,800"],
                ],
                inputs=[img_points, points_input],
                label="Example Images"
            )
        # Tab 2: Box-based segmentation
        with gr.Tab("📦 Box → Describe"):
            gr.Markdown("### Draw a bounding box or enter coordinates to segment and describe a region")

            with gr.Row():
                with gr.Column():
                    img_box = gr.Image(label="Input Image", type="pil")
                    box_input = gr.Textbox(
                        label="Bounding Box (format: x1,y1,x2,y2)",
                        placeholder="e.g., 800,500,1800,1000",
                        value="800,500,1800,1000"
                    )
                    with gr.Row():
                        gen_mask_box_btn = gr.Button("Generate Mask", variant="primary")
                        describe_box_btn = gr.Button("Describe Region", variant="secondary")

                with gr.Column():
                    mask_box = gr.Image(label="Generated Mask", type="pil")
                    vis_box = gr.Image(label="Visualization")
                    desc_box = gr.Textbox(label="Region Description", lines=5)

            box_status = gr.Textbox(label="Status", visible=False)

            gen_mask_box_btn.click(
                fn=generate_mask_from_box,
                inputs=[img_box, box_input],
                outputs=[mask_box, box_status]
            )

            describe_box_btn.click(
                fn=describe_region,
                inputs=[img_box, mask_box],
                outputs=desc_box
            ).then(
                fn=create_visualization,
                inputs=[img_box, mask_box, gr.Textbox(visible=False), box_input],
                outputs=vis_box
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_2.jpg", "800,500,1800,1000"],
                ],
                inputs=[img_box, box_input],
                label="Example Images"
            )
        # Tab 3: Direct mask upload
        with gr.Tab("🎭 Mask → Describe"):
            gr.Markdown("### Upload a pre-made mask to describe a region")

            with gr.Row():
                with gr.Column():
                    img_mask = gr.Image(label="Input Image", type="pil")
                    mask_upload = gr.Image(label="Upload Mask", type="pil")
                    describe_mask_btn = gr.Button("Describe Region", variant="primary")

                with gr.Column():
                    vis_mask = gr.Image(label="Visualization")
                    desc_mask = gr.Textbox(label="Region Description", lines=5)

            describe_mask_btn.click(
                fn=describe_region,
                inputs=[img_mask, mask_upload],
                outputs=desc_mask
            ).then(
                fn=create_visualization,
                inputs=[img_mask, mask_upload, gr.Textbox(visible=False), gr.Textbox(visible=False)],
                outputs=vis_mask
            )

            gr.Examples(
                examples=[
                    ["assets/demo_image_1.png", "assets/demo_mask_1.png"],
                ],
                inputs=[img_mask, mask_upload],
                label="Example Images"
            )
    gr.Markdown("""
---
### 📋 How to Use:
1. **Points → Describe**: Click or enter point coordinates, generate mask, then describe
2. **Box → Describe**: Draw or enter a bounding box, generate mask, then describe
3. **Mask → Describe**: Upload a pre-made mask directly and describe

### 🔧 Technical Details:
- **Model**: GAR-1B (1 billion parameters)
- **Base**: Facebook Perception-LM with RoI-aligned feature replay
- **Segmentation**: Segment Anything Model (SAM ViT-Huge)
- **Hardware**: Powered by ZeroGPU (NVIDIA H200, 70GB VRAM)

### 📚 Citation:
```bibtex
@article{wang2025grasp,
  title={Grasp Any Region: Prompting MLLM to Understand the Dense World},
  author={Haochen Wang et al.},
  journal={arXiv preprint arXiv:2510.18876},
  year={2025}
}
```
""")
# Load models on startup
try:
    load_models()
except Exception as e:
    print(f"Warning: Could not pre-load models: {e}")
    print("Models will be loaded on first use.")

if __name__ == "__main__":
    demo.launch()