# Spaces status: Runtime error
import logging
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageSegmentation

from utils import annotate_masks
from utils.sam import predict
# Load the DETR segmentation model and its preprocessor. They are used only by
# segment_image's non-SAM ("detr") path.
#
# "facebook/detr-resnet-50" is DETR's object-detection checkpoint. It ships no
# trained mask head, so loading it through AutoModelForImageSegmentation
# (DetrForSegmentation) leaves the segmentation head randomly initialised.
# The "-panoptic" checkpoint contains the trained mask head.
# NOTE(review): this loads at import time even though the default SAM path
# never touches it; consider lazy loading if start-up time/memory matters.
model_name = "facebook/detr-resnet-50-panoptic"
model = AutoModelForImageSegmentation.from_pretrained(model_name)
extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Function to handle segmentation
def _to_rgb_uint8(image):
    """Return *image* as an ``(H, W, 3)`` ``uint8`` NumPy array.

    PIL images (what ``gr.Image(type="pil")`` delivers) are first converted
    with ``convert("RGB")``, so PIL resolves palette ("P"), CMYK, "LA", 16-bit
    and similar modes correctly. Stacking palette indices as grey values, or
    dropping CMYK's 4th channel as if it were alpha, gives wrong colours.

    Raises:
        ValueError: if the image is empty or its shape cannot be read as RGB.
    """
    if hasattr(image, "convert"):  # PIL image; NumPy arrays have no .convert
        image = image.convert("RGB")
    rgb = np.array(image)
    if rgb.size == 0:
        raise ValueError("The image is empty!")
    if rgb.ndim == 2:  # single-channel array -> replicate to 3 channels
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.ndim == 3 and rgb.shape[2] == 1:
        rgb = np.repeat(rgb, 3, axis=-1)
    elif rgb.ndim == 3 and rgb.shape[2] == 4:  # 4-channel array taken as RGBA; drop alpha
        rgb = rgb[:, :, :3]
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(f"Unsupported image shape {rgb.shape}; expected (H, W, 3).")
    if rgb.dtype != np.uint8:
        if rgb.dtype == np.bool_:
            rgb = rgb.astype(np.float64)
        if np.issubdtype(rgb.dtype, np.floating):
            # Float images are taken to be in [0, 1].
            rgb = rgb * 255
        # Clip before casting: a bare astype(np.uint8) wraps out-of-range
        # values. Integer arrays are already on a 0..255 scale and must not be
        # multiplied by 255.
        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    return rgb


def _clamp_point(point, height, width):
    """Clamp a prompt point into the image so small images get a valid prompt.

    NOTE(review): assumes utils.sam.predict takes points in SAM's (x, y)
    order -- confirm against the wrapper.
    """
    x = min(max(int(point[0]), 0), width - 1)
    y = min(max(int(point[1]), 0), height - 1)
    return [x, y]


def _segment_with_detr(rgb):
    """Run the module-level DETR model on *rgb* and return a segment-id image."""
    inputs = extractor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.logits holds per-query class scores, not a per-pixel map, so an
    # argmax over it is not a mask. Let the processor merge the predicted query
    # masks into one panoptic map, resized to the input's (height, width).
    result = extractor.post_process_panoptic_segmentation(
        outputs, target_sizes=[rgb.shape[:2]]
    )[0]
    segment_ids = result["segmentation"].cpu().numpy()
    # Segment ids are small non-negative ints; pixels with no segment may be -1.
    return Image.fromarray(np.clip(segment_ids, 0, 255).astype(np.uint8))


def segment_image(image, method="sam", point=(300, 300)):
    """Segment *image* and return the result for display.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``) or an
            array-like image. ``None`` (e.g. the input was cleared while the
            interface runs live) returns ``None``.
        method: ``"sam"`` (default) prompts ``utils.sam.predict`` with one
            point; ``"detr"`` runs the module-level DETR model.
        point: SAM prompt point, clamped into the image bounds.

    Returns:
        For ``"sam"``, whatever ``annotate_masks`` returns for the image and
        predicted masks. For ``"detr"``, a mode-"L" PIL image of segment ids.
        ``None`` when *image* is ``None``.

    Raises:
        ValueError: for an unknown *method*, an empty image, or an image
            shape that cannot be interpreted as RGB.
    """
    if image is None:
        return None
    if method not in ("sam", "detr"):
        raise ValueError(f"Unknown segmentation method {method!r}; expected 'sam' or 'detr'.")

    rgb = _to_rgb_uint8(image)
    logging.getLogger(__name__).debug("Image type: %s, shape: %s", type(rgb), rgb.shape)

    if method == "sam":
        height, width = rgb.shape[:2]
        masks, _scores, _logits = predict(rgb, [_clamp_point(point, height, width)])
        return annotate_masks(rgb, masks)
    return _segment_with_detr(rgb)
# Gradio UI: a single PIL image in and a single PIL image out, routed through
# segment_image. live=True asks Gradio to re-run the function automatically
# whenever the input changes.
image_input = gr.Image(type="pil")
image_output = gr.Image(type="pil")
demo = gr.Interface(
    segment_image,
    image_input,
    image_output,
    title="Image Segmentation App",
    description="Upload an image and get the segmented output using a pre-trained model.",
    live=True,
)


def _main():
    """Start the Gradio server for the demo."""
    demo.launch()


# Only serve when executed as a script, not when imported.
if __name__ == "__main__":
    _main()