Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import gradio as gr | |
| import spaces | |
| from PIL import Image | |
| from dataset import ImageDataset | |
| from documentation import DOC_CONTENT | |
| from labelizer import get_task_response | |
| from packager import create_dataset_zip | |
# Cache lifetime in seconds (24 hours). Passed as both values of the
# ``delete_cache`` tuple given to ``gr.Blocks`` below, so cached files are
# dropped after 24 hours.
CACHE_TTL = 24 * 60 * 60
def auto_label(
    image: Image.Image, imid: int, dataset: ImageDataset
) -> tuple[str, ImageDataset]:
    """Caption one image with the model and store the caption as its label.

    Args:
        image: PIL image handed to the captioning model.
        imid: Identifier of the image inside ``dataset``.
        dataset: Dataset whose label entry is updated.

    Returns:
        A ``(caption, dataset)`` pair, where the second item is the value
        returned by ``dataset.update_label``.
    """
    caption = get_task_response("<MORE_DETAILED_CAPTION>", image)
    return caption, dataset.update_label(imid, caption)
def label_changed(label: str, imid: int, dataset: ImageDataset) -> ImageDataset:
    """React to an edited label by writing it into the dataset.

    Args:
        label: Text currently held by the label box.
        imid: Identifier of the image the label belongs to.
        dataset: Dataset to update.

    Returns:
        The value returned by ``dataset.update_label`` for this image.
    """
    updated = dataset.update_label(imid, label)
    return updated
def update_single_label(
    dataset: ImageDataset, label_text: str, image_id: int
) -> ImageDataset:
    """Store ``label_text`` as the label of image ``image_id``.

    Returns:
        The value returned by ``dataset.update_label``.
    """
    refreshed = dataset.update_label(image_id, label_text)
    return refreshed
def uploaded(files: list, current_dataset: ImageDataset | None) -> ImageDataset:
    """Add freshly uploaded files to the dataset.

    Args:
        files: Uploaded files to add.
        current_dataset: Dataset to extend, or ``None`` when no dataset
            exists yet (a new ``ImageDataset`` is created in that case).

    Returns:
        The value returned by ``add_images`` on the target dataset.
    """
    target = ImageDataset() if current_dataset is None else current_dataset
    return target.add_images(files)
def labelize_all_images(
    dataset: ImageDataset, label: str, progress=gr.Progress(True)
) -> tuple[ImageDataset, str]:
    """Generate a caption for every image in the dataset.

    Args:
        dataset: Dataset whose ``images`` entries (dicts with ``"path"`` and
            ``"id"`` keys) are captioned.
        label: Current value of the progress label component; returned
            unchanged so the same component can be used as an output.
        progress: Gradio progress tracker wrapping the iteration.

    Returns:
        A ``(dataset, label)`` pair, where ``dataset`` is the value returned
        by ``dataset.update_all_labels`` for the generated captions.
    """
    labels_dict = {}
    for imdata in progress.tqdm(dataset.images):
        # Image.open keeps the underlying file open until the image is
        # closed; the context manager releases each handle as soon as the
        # caption is produced instead of leaking one per image.
        with Image.open(imdata["path"]) as image:  # type: ignore
            labels_dict[imdata["id"]] = get_task_response(  # type: ignore
                "<MORE_DETAILED_CAPTION>", image
            )
    return dataset.update_all_labels(labels_dict), label
def create_dataset_zipfile(dataset: ImageDataset, organize_in_folders: bool):
    """Build the dataset archive and expose it for download.

    Args:
        dataset: Dataset to package.
        organize_in_folders: Forwarded to ``create_dataset_zip``.

    Returns:
        A pair of (update making the file component visible with the
        archive as its value, path of the archive).
    """
    archive_path = create_dataset_zip(dataset, organize_in_folders)
    return gr.update(visible=True, value=archive_path), archive_path
def update_buttons_states(dataset: ImageDataset, labeling_in_progress=False):
    """Compute the state of every sidebar control.

    Args:
        dataset: Dataset whose image count drives the button states.
        labeling_in_progress: Whether a "label all" run is currently active.

    Returns:
        A 6-tuple, in order: upload button update, "label all" button
        update, progress label update, "remove all" button update, download
        button update, and the ``labeling_in_progress`` flag itself.
    """
    image_count = len(dataset.images)
    idle = not labeling_in_progress
    # Upload is only enabled while the dataset is empty; every other action
    # needs at least one image. Nothing is clickable while labeling runs.
    can_upload = image_count == 0 and idle
    can_act = image_count > 0 and idle

    upload_state = gr.update(interactive=can_upload)
    label_all_state = gr.update(interactive=can_act)
    progress_state = gr.update(visible=labeling_in_progress)
    remove_all_state = gr.update(interactive=can_act)
    download_state = gr.update(interactive=can_act)
    return (
        upload_state,
        label_all_state,
        progress_state,
        remove_all_state,
        download_state,
        labeling_in_progress,
    )
def start_labeling(dataset: ImageDataset):
    """Return the control states for the moment a labeling run begins.

    Delegates to ``update_buttons_states`` with ``labeling_in_progress=True``.
    """
    busy_states = update_buttons_states(dataset, labeling_in_progress=True)
    return busy_states
def finish_labeling(dataset: ImageDataset):
    """Return the control states for the moment a labeling run ends.

    Delegates to ``update_buttons_states`` with ``labeling_in_progress=False``.
    """
    idle_states = update_buttons_states(dataset, labeling_in_progress=False)
    return idle_states
with gr.Blocks(
    title="Labelizer", fill_width=True, delete_cache=(CACHE_TTL, CACHE_TTL)
) as demo:
    # Session state holding the current ImageDataset. It has no initial
    # value; `uploaded` creates the dataset on first upload.
    dataset = gr.State()

    with gr.Sidebar():
        gr.Markdown("# 🖼️ Image Labeling Tool")

        with gr.Group():
            gr.Markdown("Upload images and add labels to build your dataset.")
            upload_button = gr.UploadButton(
                "📁 Upload images",
                file_count="multiple",
            )
            label_all = gr.Button(
                "🏷️ Labelize all images",
                interactive=False,
            )
            # Flag mirrored into the UI so controls can be locked while a
            # "label all" run is active.
            is_labeling_in_progress = gr.State(
                False,
            )
            progressbar = gr.Label(
                "",
                visible=False,
                label="Preparing...",
            )
            remove_all = gr.Button(
                "🗑️ Remove all",
                interactive=False,
            )

        with gr.Group():
            # Path of the generated zip file to delete afterwards; see the
            # click event of download_button.
            to_delete = gr.State()
            # Triggers creation of the zip file.
            download_button = gr.Button(
                "💾 Create zip file to download",
                interactive=False,
                size="lg",
            )
            # The download section, hidden until a zip has been requested.
            download_file = gr.File(label="Generated datasets", visible=False)
            # Whether to organize the dataset in folders inside the zip.
            organize_files = gr.Checkbox(label="📂 Organize in folders", value=False)

    # Outputs written by update_buttons_states / start_labeling /
    # finish_labeling, in the order those functions return them.
    button_state_outputs = [
        upload_button,
        label_all,
        progressbar,
        remove_all,
        download_button,
        is_labeling_in_progress,
    ]

    # Without this registration render_grid is never invoked and the image
    # grid is never shown; gr.render re-runs it when either input changes.
    @gr.render(inputs=[dataset, is_labeling_in_progress])
    def render_grid(ds, is_labeling_in_progress):
        """Render the image grid with labels and controls.

        Args:
            ds: Current ImageDataset instance (may be unset before the first
                upload).
            is_labeling_in_progress: Whether labeling is currently in
                progress; label boxes and per-image buttons are disabled
                while it is.

        Returns:
            None - renders UI components directly. When there are no images
            the documentation text is shown instead of the grid.
        """
        if not ds or len(ds.images) == 0:
            gr.Markdown(DOC_CONTENT)
            return

        with gr.Row(equal_height=True):
            for im in ds.images:
                # NOTE(review): Gradio documents `preserved_by_key` as a list
                # of this component's own constructor parameter names, not
                # keys of child components - confirm this has the intended
                # effect.
                with (
                    gr.Column(
                        elem_classes="label-image-box",
                        preserved_by_key=[
                            f"image_{im['id']}",
                            f"text_{im['id']}",
                            f"button_{im['id']}",
                            f"button_clicked_{im['id']}",
                            f"label_changed_{im['id']}",
                        ],
                    ),
                ):
                    # Hidden component to store current image ID, so each
                    # listener receives the ID of its own image.
                    current_image_id = gr.State(value=im["id"])
                    image = gr.Image(
                        im["path"],
                        type="pil",
                        container=False,
                        sources=None,
                        buttons=["fullscreen"],
                        height=300,
                        key=f"image_{im['id']}",
                    )
                    label = gr.Text(
                        im["label"],
                        placeholder="Description...",
                        lines=5,
                        container=False,
                        interactive=not is_labeling_in_progress,
                        key=f"text_{im['id']}",
                    )
                    button = gr.Button(
                        "✨ Generate label",
                        interactive=not is_labeling_in_progress,
                        key=f"button_{im['id']}",
                    )
                    # Generate a caption for this image only.
                    button.click(
                        auto_label,
                        inputs=[image, current_image_id, dataset],
                        outputs=[label, dataset],
                        key=f"button_clicked_{im['id']}",
                    )
                    # Update dataset when label is changed
                    label.change(
                        label_changed,
                        inputs=[label, current_image_id, dataset],
                        outputs=[dataset],
                        key=f"label_changed_{im['id']}",
                    )

    def _remove_temp_zip(path):
        """Delete the temporary zip at ``path``; do nothing if it is unset or gone."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            # Best-effort cleanup: the file was already removed.
            pass

    # Remove everything
    remove_all.click(
        lambda: ImageDataset(),
        inputs=None,
        outputs=dataset,
    ).then(
        update_buttons_states,
        inputs=[dataset, is_labeling_in_progress],
        outputs=button_state_outputs,
    )

    # Label all images: lock the controls, caption every image, then unlock.
    # The progress label is passed in and returned unchanged by
    # labelize_all_images so it can serve as an output of that step.
    label_all.click(
        fn=start_labeling,
        inputs=[dataset],
        outputs=button_state_outputs,
    ).then(
        fn=labelize_all_images,
        inputs=[dataset, progressbar],
        outputs=[dataset, progressbar],
    ).then(
        fn=finish_labeling,
        inputs=[dataset],
        outputs=button_state_outputs,
    )

    # Upload images
    upload_button.upload(
        uploaded,
        inputs=[upload_button, dataset],
        outputs=dataset,
    ).then(
        update_buttons_states,
        inputs=[dataset, is_labeling_in_progress],
        outputs=button_state_outputs,
    )

    # Create the zip file and make the download file section ready to use.
    download_button.click(
        lambda: gr.update(visible=True),
        inputs=None,
        outputs=download_file,
    ).then(
        create_dataset_zipfile,
        inputs=[dataset, organize_files],
        outputs=[download_file, to_delete],
    ).success(
        # Delete the generated file from /tmp as it is now copied into the
        # Gradio cache. `.success` (rather than `.then`) skips this step when
        # zip creation failed, in which case `to_delete` would be unset or
        # still point at a previously deleted file.
        _remove_temp_zip,
        inputs=[to_delete],
    )
if __name__ == "__main__":
    # Stylesheet handed to launch(). The second rule is an empty hook for the
    # "label-image-box" class that render_grid assigns to each image column.
    CSS = """
.gr-group {
padding: .2rem;
}
.label-image-box {
}
"""
    queued_app = demo.queue()
    queued_app.launch(css=CSS)