Spaces:
Runtime error
Runtime error
feat: ✨ color and checkbox events change annotator results dynamically feature added
0dbbe5d
verified
import functools
import os  # added for cache_examples
import threading
from pathlib import Path

import gradio as gr
import numpy as np
import supervision as sv
from gradio import ColorPicker
from PIL import Image
from torch import cuda, device
from ultralytics import YOLO
# Select the compute device: CUDA when torch reports a usable GPU, else CPU.
# Note that this rebinds the imported ``torch.device`` constructor to the
# resulting ``torch.device`` instance.
device = device("cuda" if cuda.is_available() else "cpu")
# HTML snippets rendered at the top of the Gradio page.
TITLE = """<h1 align="center">Supervision Annotator Playground 🚀</h1>"""
SUBTITLE = """<h2 align="center">Experiment with Supervision Annotators</h2>"""
# CSS lengths other than 0 need a unit; a bare "padding: 20" is invalid and
# is ignored by browsers, hence "20px".
BANNER = """
<div align="center">
<p>
<a align="center" href="https://supervision.roboflow.com/" target="_blank">
<img style="max-width: 50%; height: auto; margin: 0 auto; display: block; padding: 20px"
src="https://media.roboflow.com/open-source/supervision/rf-supervision-banner.png?updatedAt=1678995927529">
</a>
</p>
</div>
"""  # noqa: E501 title/docs
# Badge row: duplicate-Space, GitHub stars and Colab links.
DESC = """
<div style="text-align: center; display: flex; justify-content: center; align-items: center;">
<a href="https://huggingface.co/spaces/Roboflow/Annotators?duplicate=true">
<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">
</a>
<a href="https://github.com/roboflow/supervision">
<img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/roboflow/supervision"
style="margin-right: 10px;">
</a>
<a href="https://colab.research.google.com/github/roboflow/supervision/blob/main/demo.ipynb">
<img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"
style="margin-right: 10px;">
</a>
</div>
"""  # noqa: E501 title/docs
# Module-level defaults for the most recent detections and their labels.
# NOTE(review): the handlers visible here use same-named parameters/locals and
# never assign these globals, so they keep their initial empty values.
last_detections = sv.Detections.empty()
last_labels: list[str] = list()
@functools.lru_cache(maxsize=2)
def _get_yolo(model_name: str | Path) -> YOLO:
    """Return a cached YOLO instance for ``model_name`` (weights name or path).

    Every checkbox or color-picker change re-runs detection, so loading the
    weights once per model name avoids re-reading them from disk each time.
    """
    return YOLO(model=model_name)


# One cached model instance is shared by all concurrent Gradio handlers;
# Ultralytics advises against running one model from several threads at once,
# so inference on it is serialized.
_MODEL_LOCK = threading.Lock()


def load_model(img, model: str | Path = "yolov8s-seg.pt"):
    """Run YOLO on ``img`` and return its detections and text labels.

    Args:
        img: Image array passed straight to the YOLO model.
        model: Weights name or path, e.g. ``"yolov8s-seg.pt"``.

    Returns:
        ``(detections, labels)``: an ``sv.Detections`` built from the first
        result, and one ``"<class name> <confidence:.2f>"`` string per detection.
    """
    with _MODEL_LOCK:
        yolo = _get_yolo(model)
        result = yolo(img, verbose=False, imgsz=1280)[0]
    detections = sv.Detections.from_ultralytics(result)
    class_names = yolo.model.names
    labels = [
        f"{class_names[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]
    return detections, labels
def calculate_crop_dim(a, b):
    """Return ``(side, side)`` where ``side`` is the larger of ``a`` and ``b``.

    The pair is used as (width, height) of a square crop. Because both values
    are the larger input, a slice of that size never trims an image. On a tie
    ``b`` is returned.
    """
    side = a if a > b else b
    return side, side
def annotators(
    img,
    last_detections,
    annotators_list,
    last_labels,
    colorbb,
    colormask,
    colorellipse,
    colorbc,
    colorcir,
    colorlabel,
    colorhalo,
    colortri,
    colordot,
) -> np.ndarray:
    """Draw the selected supervision annotators on ``img``.

    Args:
        img: Image array in the channel order produced by ``annotator`` (the
            gradio input with its last axis reversed, i.e. BGR).
        last_detections: ``sv.Detections`` to draw.
        annotators_list: Names of the annotators to apply, e.g. ``["Mask"]``.
        last_labels: One text label per detection, used by "Label".
        colorbb, colormask, colorellipse, colorbc, colorcir, colorlabel,
        colorhalo, colortri, colordot: Hex color strings (e.g. ``"#A351FB"``)
            for the BoundingBox, Mask, Ellipse, BoxCorner, Circle, Label,
            Halo, Triangle and Dot annotators respectively.

    Returns:
        A new array with the last axis reversed back (BGR -> RGB) for display.
        With no detections a warning is shown and the unannotated image is
        returned the same way.
    """
    if len(last_detections) == 0:
        gr.Warning("Detection is empty please add image and annotate first")
        # Hand back the input instead of an array with no shape.
        return img[..., ::-1].copy()

    def hex_color(value):
        # Color pickers deliver hex strings such as "#A351FB".
        return sv.Color.from_hex(str(value))

    # Applied in this order, each drawing on top of the previous result.
    # Factories are only called for selected annotators.
    pipeline = (
        ("Blur", sv.BlurAnnotator),
        ("BoundingBox", lambda: sv.BoundingBoxAnnotator(color=hex_color(colorbb))),
        ("Mask", lambda: sv.MaskAnnotator(color=hex_color(colormask))),
        ("Ellipse", lambda: sv.EllipseAnnotator(color=hex_color(colorellipse))),
        ("BoxCorner", lambda: sv.BoxCornerAnnotator(color=hex_color(colorbc))),
        ("Circle", lambda: sv.CircleAnnotator(color=hex_color(colorcir))),
        (
            "Label",
            # Color and centered position in one constructor; building it
            # twice would discard the first configuration.
            lambda: sv.LabelAnnotator(
                color=hex_color(colorlabel), text_position=sv.Position.CENTER
            ),
        ),
        ("Pixelate", sv.PixelateAnnotator),
        ("Halo", lambda: sv.HaloAnnotator(color=hex_color(colorhalo))),
        ("HeatMap", sv.HeatMapAnnotator),
        ("Dot", lambda: sv.DotAnnotator(color=hex_color(colordot))),
        ("Triangle", lambda: sv.TriangleAnnotator(color=hex_color(colortri))),
    )
    selected = set(annotators_list or ())
    for name, build in pipeline:
        if name not in selected:
            continue
        tool = build()
        if name == "Label":
            img = tool.annotate(img, detections=last_detections, labels=last_labels)
        else:
            img = tool.annotate(img, detections=last_detections)

    # NOTE(review): calculate_crop_dim returns the larger side for both
    # dimensions, so this slice keeps every pixel and no square crop happens.
    # Left as-is so the output size does not change.
    img_height, img_width = img.shape[:2]
    width, height = calculate_crop_dim(img_width, img_height)
    cropped = img[:height, :width]
    # Reverse the channel axis back (BGR -> RGB) for the gradio output.
    return cropped[..., ::-1].copy()
def annotator(
    img,
    model,
    annotators_list,
    colorbb,
    colormask,
    colorellipse,
    colorbc,
    colorcir,
    colorlabel,
    colorhalo,
    colortri,
    colordot,
    progress=gr.Progress(track_tqdm=True),
) -> np.ndarray | None:
    """Detect objects in ``img`` and draw the selected annotators on it.

    Handler shared by the "Annotate it!" button, the annotator checkbox group
    and every color picker.

    Args:
        img: Image from ``gr.Image(type="numpy")``, or ``None`` when no image
            has been uploaded yet.
        model: YOLO weights name or path chosen in the model dropdown.
        annotators_list: Names of the annotators to draw.
        colorbb, colormask, colorellipse, colorbc, colorcir, colorlabel,
        colorhalo, colortri, colordot: Hex colors for the BoundingBox, Mask,
            Ellipse, BoxCorner, Circle, Label, Halo, Triangle and Dot
            annotators respectively.
        progress: Gradio progress tracker; ``track_tqdm=True`` forwards tqdm
            progress bars to the UI.

    Returns:
        The annotated image, or ``None`` (clears the output) when there is no
        input image.
    """
    if img is None:
        # Checkbox and color events fire even before an image is uploaded;
        # slicing None would raise TypeError.
        gr.Warning("Please upload an image first")
        return None
    # gr.Image(type="numpy") yields RGB; reverse the channel axis so the model
    # and the annotators receive BGR (annotators() reverses it back).
    bgr_img = img[..., ::-1].copy()
    detections, labels = load_model(bgr_img, model)
    return annotators(
        bgr_img,
        detections,
        annotators_list,
        labels,
        colorbb,
        colormask,
        colorellipse,
        colorbc,
        colorcir,
        colorlabel,
        colorhalo,
        colortri,
        colordot,
    )
# Soft base theme with a purple primary hue; primary buttons and selected
# checkboxes are tinted with shades of that hue.
_purple_overrides = dict(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    checkbox_label_background_fill_selected="*primary_600",
    checkbox_background_color_selected="*primary_400",
)
purple_theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple).set(
    **_purple_overrides
)
# ``theme`` is an alias for the same theme object.
theme = purple_theme
with gr.Blocks(theme=purple_theme) as app:
    gr.HTML(TITLE)
    gr.HTML(SUBTITLE)
    gr.HTML(BANNER)
    gr.HTML(DESC)

    models = gr.Dropdown(
        [
            "yolov8n-seg.pt",
            "yolov8s-seg.pt",
            "yolov8m-seg.pt",
            "yolov8l-seg.pt",
            "yolov8x-seg.pt",
        ],
        type="value",
        value="yolov8s-seg.pt",
        label="Select Model:",
    )
    annotators_list = gr.CheckboxGroup(
        choices=[
            "BoundingBox",
            "Mask",
            "Halo",
            "Ellipse",
            "BoxCorner",
            "Circle",
            "Label",
            "Blur",
            "Pixelate",
            "HeatMap",
            "Dot",
            "Triangle",
        ],
        value=["BoundingBox", "Mask"],
        label="Select Annotators:",
    )

    gr.Markdown("## Color Picker 🎨")
    with gr.Row(variant="panel"):
        with gr.Column():
            colorbb = gr.ColorPicker(value="#A351FB", label="BoundingBox")
            colormask = gr.ColorPicker(value="#A351FB", label="Mask")
            colorellipse = gr.ColorPicker(value="#A351FB", label="Ellipse")
        with gr.Column():
            colorbc = gr.ColorPicker(value="#A351FB", label="BoxCorner")
            colorcir = gr.ColorPicker(value="#A351FB", label="Circle")
            colorlabel = gr.ColorPicker(value="#A351FB", label="Label")
        with gr.Column():
            colorhalo = gr.ColorPicker(value="#A351FB", label="Halo")
            colordot = gr.ColorPicker(value="#A351FB", label="Dot")
            colortri = gr.ColorPicker(value="#A351FB", label="Triangle")

    with gr.Row():
        with gr.Column():
            with gr.Tab("Input image"):
                image_input = gr.Image(type="numpy", show_label=False)
        with gr.Column():
            with gr.Tab("Result image"):
                image_output = gr.Image(type="numpy", show_label=False)
    image_button = gr.Button(value="Annotate it!", variant="primary")

    # Inputs for every event that re-runs annotation, defined once so the
    # events cannot drift apart. Order must match annotator's parameters.
    annotate_inputs = [
        image_input,
        models,
        annotators_list,
        colorbb,
        colormask,
        colorellipse,
        colorbc,
        colorcir,
        colorlabel,
        colorhalo,
        colortri,
        colordot,
    ]

    image_button.click(annotator, inputs=annotate_inputs, outputs=image_output)

    gr.Markdown("## Image Examples 🖼️")
    # NOTE(review): with cache_examples=False clicking an example only fills
    # image_input; if caching were enabled, fn would receive only the image
    # while annotator needs all of annotate_inputs.
    gr.Examples(
        examples=[
            os.path.join(os.path.abspath(""), f"./assets/{name}.jpg")
            for name in ("city", "household", "industry", "retail", "aerodefence")
        ],
        inputs=image_input,
        outputs=image_output,
        fn=annotator,
        cache_examples=False,
    )

    # Re-annotate whenever the annotator selection changes.
    annotators_list.change(
        fn=annotator, inputs=annotate_inputs, outputs=image_output
    )

    def change_color(color: ColorPicker):
        """Re-run annotation whenever ``color`` changes."""
        color.change(fn=annotator, inputs=annotate_inputs, outputs=image_output)

    colors = [
        colorbb,
        colormask,
        colorellipse,
        colorbc,
        colorcir,
        colorlabel,
        colorhalo,
        colortri,
        colordot,
    ]
    for color in colors:
        change_color(color)
if __name__ == "__main__":
    startup_messages = (
        "Starting app...",
        "Dark theme is available at: http://localhost:7860/?__theme=dark",
    )
    for message in startup_messages:
        print(message)
    # To serve on the local network, also pass server_name="0.0.0.0".
    app.launch(debug=False)