auto-labelizer / app.py
Metal3d's picture
Auto cleanup temporary files
405761d
import os
import gradio as gr
import spaces
from PIL import Image
from dataset import ImageDataset
from documentation import DOC_CONTENT
from labelizer import get_task_response
from packager import create_dataset_zip
# Drop files after 24 hours.
# Value is expressed in seconds; it is passed as both elements of the
# ``delete_cache`` tuple when the ``gr.Blocks`` app is created.
CACHE_TTL = 24 * 60 * 60
@spaces.GPU
def auto_label(
    image: Image.Image, imid: int, dataset: ImageDataset
) -> tuple[str, ImageDataset]:
    """Caption one image and record the caption in the dataset.

    The caption is requested from ``get_task_response`` with the
    ``<MORE_DETAILED_CAPTION>`` task, then stored under ``imid`` through
    ``dataset.update_label``.

    Args:
        image: PIL image to describe.
        imid: Identifier of the image inside the dataset.
        dataset: Dataset that currently holds the image.

    Returns:
        A ``(caption, dataset)`` pair where ``dataset`` is whatever
        ``dataset.update_label`` returned.
    """
    caption = get_task_response("<MORE_DETAILED_CAPTION>", image)
    return caption, dataset.update_label(imid, caption)
def label_changed(label: str, imid: int, dataset: ImageDataset) -> ImageDataset:
    """React to an edit of an image's label text box.

    Args:
        label: Text currently present in the label field.
        imid: Identifier of the image inside the dataset.
        dataset: Dataset that currently holds the image.

    Returns:
        The result of ``dataset.update_label`` for this image and text.
    """
    updated_dataset = dataset.update_label(imid, label)
    return updated_dataset
def update_single_label(
    dataset: ImageDataset, label_text: str, image_id: int
) -> ImageDataset:
    """Store ``label_text`` as the label of image ``image_id``.

    Args:
        dataset: Dataset that holds the image.
        label_text: New label for the image.
        image_id: Identifier of the image inside the dataset.

    Returns:
        The result of ``dataset.update_label`` for this image and text.
    """
    refreshed = dataset.update_label(image_id, label_text)
    return refreshed
def uploaded(files: list, current_dataset: ImageDataset | None) -> ImageDataset:
    """Add freshly uploaded files to the dataset.

    Args:
        files: Files delivered by the upload event.
        current_dataset: Dataset held in the session state, or ``None``
            when nothing has been stored yet.

    Returns:
        The result of ``add_images`` called on the existing dataset, or on
        a brand-new ``ImageDataset`` when none existed.
    """
    target = ImageDataset() if current_dataset is None else current_dataset
    return target.add_images(files)
@spaces.GPU
def labelize_all_images(
    dataset: ImageDataset, label: str, progress=gr.Progress(True)
) -> tuple[ImageDataset, str]:
    """Generate a caption for every image of the dataset.

    Args:
        dataset: Dataset whose ``images`` entries expose ``"path"`` and
            ``"id"`` keys.
        label: Current value of the progress label component; it is handed
            back untouched as the second element of the result.
        progress: Gradio progress tracker used to wrap the iteration. It
            has to stay a default argument so that Gradio injects it.

    Returns:
        A ``(dataset, label)`` pair where ``dataset`` is the result of
        ``dataset.update_all_labels`` applied to the generated captions.
    """
    labels_dict = {}
    for imdata in progress.tqdm(dataset.images):
        # Open each file in a context manager so its handle is released as
        # soon as the caption is produced, instead of leaking one open file
        # per image until garbage collection.
        with Image.open(imdata["path"]) as img:  # type: ignore
            text = get_task_response("<MORE_DETAILED_CAPTION>", img)
        labels_dict[imdata["id"]] = text  # type: ignore
    return dataset.update_all_labels(labels_dict), label
def create_dataset_zipfile(dataset: ImageDataset, organize_in_folders: bool):
    """Build the dataset archive and expose it for download.

    Args:
        dataset: Dataset to package.
        organize_in_folders: Forwarded as-is to ``create_dataset_zip``.

    Returns:
        A pair made of a Gradio update (visible, valued with the archive
        path) and the archive path itself.
    """
    archive_path = create_dataset_zip(dataset, organize_in_folders)
    return gr.update(visible=True, value=archive_path), archive_path
def update_buttons_states(dataset: ImageDataset, labeling_in_progress=False):
    """Compute the state of every sidebar control.

    Args:
        dataset: Dataset whose number of images drives the states.
        labeling_in_progress: Whether a bulk labeling run is ongoing.

    Returns:
        A 6-tuple, in this order: upload button update, "label all" button
        update, progress bar update, "remove all" button update, download
        button update, and the ``labeling_in_progress`` flag itself.
    """
    has_images = len(dataset.images) > 0
    idle = not labeling_in_progress
    # Label / remove / download share the same rule: images present, no run.
    can_act = has_images and idle

    upload_state = gr.update(interactive=(not has_images) and idle)
    label_all_state = gr.update(interactive=can_act)
    progress_state = gr.update(visible=labeling_in_progress)
    remove_all_state = gr.update(interactive=can_act)
    download_state = gr.update(interactive=can_act)

    return (
        upload_state,
        label_all_state,
        progress_state,
        remove_all_state,
        download_state,
        labeling_in_progress,
    )
def start_labeling(dataset: ImageDataset):
    """Return the control states for a labeling run that is starting.

    Delegates to ``update_buttons_states`` with the in-progress flag set.
    """
    states = update_buttons_states(dataset, True)
    return states
def finish_labeling(dataset: ImageDataset):
    """Return the control states for a labeling run that has ended.

    Delegates to ``update_buttons_states`` with the in-progress flag cleared.
    """
    states = update_buttons_states(dataset, False)
    return states
# Application layout and event wiring. Cached files are expired through
# ``delete_cache`` using CACHE_TTL for both values of the tuple.
with gr.Blocks(
    title="Labelizer", fill_width=True, delete_cache=(CACHE_TTL, CACHE_TTL)
) as demo:
    # Per-session dataset; starts empty (None) until the first upload.
    dataset = gr.State()

    with gr.Sidebar():
        gr.Markdown("# 🖼️ Image Labeling Tool")

        with gr.Group():
            gr.Markdown("Upload images and add labels to build your dataset.")
            upload_button = gr.UploadButton(
                "📁 Upload images",
                file_count="multiple",
            )
            label_all = gr.Button(
                "🏷️ Labelize all images",
                interactive=False,
            )
            is_labeling_in_progress = gr.State(
                False,
            )
            progressbar = gr.Label(
                "",
                visible=False,
                label="Preparing...",
            )
            remove_all = gr.Button(
                "🗑️ Remove all",
                interactive=False,
            )

        with gr.Group():
            # Path of the original zip file to drop, see the click event of
            # download_button.
            to_delete = gr.State()
            # should create a zip file
            download_button = gr.Button(
                "💾 Create zip file to download",
                interactive=False,
                size="lg",
            )
            # the download section
            download_file = gr.File(label="Generated datasets", visible=False)
            # to organize dataset in folders or not
            organize_files = gr.Checkbox(label="📂 Organize in folders", value=False)

    @gr.render(inputs=[dataset, is_labeling_in_progress])
    def render_grid(ds, is_labeling_in_progress):
        """Render the image grid with labels and controls.

        Shows the documentation when there is no image; otherwise renders
        one column per image with its picture, an editable label and a
        "generate label" button.

        Args:
            ds: Current ImageDataset instance (or None before any upload)
            is_labeling_in_progress: Whether labeling is currently in
                progress; the per-image text box and button are made
                non-interactive while it is true

        Returns:
            None - renders UI components directly
        """
        if not ds or len(ds.images) == 0:
            gr.Markdown(DOC_CONTENT)
            return

        # One row holding a column per image.
        with gr.Row(equal_height=True):
            for im in ds.images:
                # NOTE(review): Gradio documents ``preserved_by_key`` as a
                # list of constructor parameter names, not component keys -
                # confirm that this list has the intended effect.
                with gr.Column(
                    elem_classes="label-image-box",
                    preserved_by_key=[
                        f"image_{im['id']}",
                        f"text_{im['id']}",
                        f"button_{im['id']}",
                        f"button_clicked_{im['id']}",
                        f"label_changed_{im['id']}",
                    ],
                ):
                    # Hidden component to store current image ID
                    current_image_id = gr.State(value=im["id"])
                    image = gr.Image(
                        im["path"],
                        type="pil",
                        container=False,
                        sources=None,
                        buttons=["fullscreen"],
                        height=300,
                        key=f"image_{im['id']}",
                    )
                    label = gr.Text(
                        im["label"],
                        placeholder="Description...",
                        lines=5,
                        container=False,
                        interactive=not is_labeling_in_progress,
                        key=f"text_{im['id']}",
                    )
                    button = gr.Button(
                        "✨ Generate label",
                        interactive=not is_labeling_in_progress,
                        key=f"button_{im['id']}",
                    )
                    button.click(
                        auto_label,
                        inputs=[image, current_image_id, dataset],
                        outputs=[label, dataset],
                        key=f"button_clicked_{im['id']}",
                    )
                    # Update dataset when label is changed
                    label.change(
                        label_changed,
                        inputs=[label, current_image_id, dataset],
                        outputs=[dataset],
                        key=f"label_changed_{im['id']}",
                    )

    # Components refreshed by update_buttons_states / start_labeling /
    # finish_labeling, in the order of the tuple those functions return.
    button_state_outputs = [
        upload_button,
        label_all,
        progressbar,
        remove_all,
        download_button,
        is_labeling_in_progress,
    ]

    # Remove everything
    remove_all.click(
        lambda: ImageDataset(),
        inputs=None,
        outputs=dataset,
    ).then(
        update_buttons_states,
        inputs=[dataset, is_labeling_in_progress],
        outputs=button_state_outputs,
    )

    # Label all images
    label_all.click(
        fn=start_labeling,
        inputs=[dataset],
        outputs=button_state_outputs,
    ).then(
        fn=labelize_all_images,
        inputs=[dataset, progressbar],
        outputs=[dataset, progressbar],
    ).then(
        fn=finish_labeling,
        inputs=[dataset],
        outputs=button_state_outputs,
    )

    # Upload images
    upload_button.upload(
        uploaded,
        inputs=[upload_button, dataset],
        outputs=dataset,
    ).then(
        update_buttons_states,
        inputs=[dataset, is_labeling_in_progress],
        outputs=button_state_outputs,
    )

    # create the zip file and set the download file section ready to use
    download_button.click(
        lambda: gr.update(visible=True),
        inputs=None,
        outputs=download_file,
    ).then(
        create_dataset_zipfile,
        inputs=[dataset, organize_files],
        outputs=[download_file, to_delete],
    ).success(
        # Delete the generated file from /tmp as it is now copied into the
        # gradio cache. Chained with ``success`` so it only runs when the
        # zip was actually created, and guarded so that an empty or stale
        # path does not raise.
        lambda path: os.remove(path) if path and os.path.exists(path) else None,
        inputs=[to_delete],
    )
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Custom stylesheet handed to ``launch`` below.
    CSS = """
.gr-group {
padding: .2rem;
}
.label-image-box {
}
"""
    # Enable the request queue, then start the app with the stylesheet.
    demo.queue().launch(css=CSS)