#!/usr/bin/python3
# -*- coding: utf-8 -*-
import gradio as gr
import os, sys
import json
import copy
import spaces
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image
from transformers import AutoImageProcessor, Dinov2Model
from segment_anything import SamPredictor, sam_model_registry
from diffusers import (
    StableDiffusionBlobNetPipeline,
    BlobNetModel,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
)
from huggingface_hub import snapshot_download

sys.path.append(os.getcwd() + '/examples/blobctrl')
from utils.utils import splat_features, viz_score_fn, BLOB_VIS_COLORS, vis_gt_ellipse_from_ellipse

weight_dtype = torch.float16
device = "cuda"
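# NOTE: the demo assumes a CUDA GPU. For local CPU-only debugging, a fallback
# such as the following (illustrative, not part of the original demo) could be used:
# device = "cuda" if torch.cuda.is_available() else "cpu"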
# download blobctrl models
BlobCtrl_path = "examples/blobctrl/models"
if not (os.path.exists(f"{BlobCtrl_path}/blobnet") and os.path.exists(f"{BlobCtrl_path}/unet_lora")):
    BlobCtrl_path = snapshot_download(
        repo_id="Yw22/BlobCtrl",
        local_dir=BlobCtrl_path,
        token=os.getenv("HF_TOKEN"),
    )
    print(f"BlobCtrl checkpoints downloaded to {BlobCtrl_path}")

# download stable-diffusion-v1-5
StableDiffusion_path = "examples/blobctrl/models/stable-diffusion-v1-5"
if not os.path.exists(StableDiffusion_path):
    StableDiffusion_path = snapshot_download(
        repo_id="sd-legacy/stable-diffusion-v1-5",
        local_dir=StableDiffusion_path,
        token=os.getenv("HF_TOKEN"),
    )
    print(f"StableDiffusion checkpoints downloaded to {StableDiffusion_path}")

# download dinov2-large
Dino_path = "examples/blobctrl/models/dinov2-large"
if not os.path.exists(Dino_path):
    Dino_path = snapshot_download(
        repo_id="facebook/dinov2-large",
        local_dir=Dino_path,
        token=os.getenv("HF_TOKEN"),
    )
    print(f"Dino checkpoints downloaded to {Dino_path}")

# download SAM model
SAM_path = "examples/blobctrl/models/sam"
if not os.path.exists(SAM_path):
    SAM_path = snapshot_download(
        repo_id="kunkaran/sam_vit_h_4b8939.pth",
        local_dir=SAM_path,
        token=os.getenv("HF_TOKEN"),
    )
    print(f"SAM checkpoints downloaded to {SAM_path}")
# Check if the SAM model file exists
sam_model_file = os.path.join(SAM_path, "sam_vit_h_4b8939.pth")
if os.path.exists(sam_model_file):
    print(f"SAM model file found at {sam_model_file}")
else:
    print(f"SAM model file not found at {sam_model_file}")

## load models and pipeline
blobnet_path = "./examples/blobctrl/models/blobnet"
unet_lora_path = "./examples/blobctrl/models/unet_lora"
stable_diffusion_model_path = "./examples/blobctrl/models/stable-diffusion-v1-5"
dinov2_path = "./examples/blobctrl/models/dinov2-large"
sam_path = "./examples/blobctrl/models/sam/sam_vit_h_4b8939.pth"
## unet
print(f"Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    stable_diffusion_model_path,
    subfolder="unet",
)
with torch.no_grad():
    initial_input_channels = unet.config.in_channels
    new_conv_in = torch.nn.Conv2d(
        initial_input_channels + 1,
        unet.conv_in.out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=unet.conv_in.bias is not None,
        dtype=unet.dtype,
        device=unet.device,
    )
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :initial_input_channels].copy_(unet.conv_in.weight)
    if unet.conv_in.bias is not None:
        new_conv_in.bias.copy_(unet.conv_in.bias)
    unet.conv_in = new_conv_in
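# The surgery above widens conv_in by one input channel (for the extra mask
# channel) while zero-initializing the new weights, so the patched UNet
# initially behaves exactly like the pretrained one. Illustrative sanity check:
# assert unet.conv_in.in_channels == initial_input_channels + 1
# assert torch.all(unet.conv_in.weight[:, initial_input_channels:] == 0)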
## blobnet
print(f"Loading BlobNet...")
blobnet = BlobNetModel.from_pretrained(blobnet_path, ignore_mismatched_sizes=True)

## sam
print(f"Loading SAM...")
mobile_sam = sam_model_registry['vit_h'](checkpoint=sam_path).to(device)
# mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]
rgba_colors = [(255, 0, 255, 255), (0, 255, 0, 255), (0, 0, 255, 255)]

## dinov2
print(f"Loading Dinov2...")
dinov2_processor = AutoImageProcessor.from_pretrained(dinov2_path)
dinov2 = Dinov2Model.from_pretrained(dinov2_path).to(device)

## stable diffusion with blobnet pipeline
print(f"Loading StableDiffusionBlobNetPipeline...")
pipeline = StableDiffusionBlobNetPipeline.from_pretrained(
    stable_diffusion_model_path,
    unet=unet,
    blobnet=blobnet,
    torch_dtype=weight_dtype,
    dinov2_processor=dinov2_processor,
    dinov2=dinov2,
)
print(f"Loading UNetLora...")
pipeline.load_lora_weights(
    unet_lora_path,
    adapter_name="default",
)
pipeline.set_adapters(["default"])
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
# pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
pipeline.set_progress_bar_config(leave=False)
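# With everything loaded, a minimal call mirrors the one in `run_function`
# below (illustrative sketch; paths, sizes, and the score placeholder are ours):
# fg = Image.open("object_on_white.png")       # object region, centered on white
# bg = Image.open("masked_background.png")     # background with the blob region blanked
# gs = torch.zeros(1, 2, 64, 64).to(device)    # blob score map at latent resolution (bnhw)
# images = pipeline(fg_image=fg, bg_image=bg, gs_score=gs,
#                   prompt=["a scene description"], num_inference_steps=50,
#                   guidance_scale=7.5, blobnet_conditioning_scale=1.0,
#                   width=512, height=512, return_sample=False).images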
## meta info
logo = r"""
<center><img src='./examples/blobctrl/assets/logo_512.png' alt='BlobCtrl logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center;">
<h1> BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing </h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://liyaowei-stu.github.io/project/BlobCtrl/'><img src='https://img.shields.io/badge/Project_Page-BlobCtrl-green' alt='Project Page'></a>
<a href='http://arxiv.org/abs/2503.13434'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
<a href='https://github.com/TencentARC/BlobCtrl'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
</div>
</br>
</div>
"""
descriptions = r"""
Official Gradio demo for <a href=''><b>BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing</b></a><br>
🦉 BlobCtrl enables precise, user-friendly element-level visual manipulation. <br>
"""
citation = r"""
If BlobCtrl is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BlobCtrl' target='_blank'>GitHub Repo</a>. Thanks!
---
📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2025blobctrl,
      title={BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing},
      author={Yaowei Li},
      year={2025},
      eprint={2503.13434},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
"""
# - - - - - examples - - - - - #
EXAMPLES = [
    [
        "examples/blobctrl/assets/results/demo/move_hat/input_image/input_image.png",
        "A frog sits on a rock in a pond, with a top hat beside it, surrounded by butterflies and vibrant flowers.",
        1.0,
        0.0,
        0.9,
        1248464818,
        0,
    ],
    [
        "examples/blobctrl/assets/results/demo/move_cup/input_image/input_image.png",
        "a rustic wooden table.",
        1.0,
        0.0,
        1.0,
        1248464818,
        1,
    ],
    [
        "examples/blobctrl/assets/results/demo/enlarge_deer/input_image/input_image.png",
        "A cute, young deer with large ears standing in a grassy field at sunrise, surrounded by trees.",
        1.6,
        0.0,
        1.0,
        1288911487,
        2,
    ],
    [
        "examples/blobctrl/assets/results/demo/shrink_dragon/input_image/input_image.png",
        "A detailed, handcrafted cardboard dragon with red wings and expressive eyes.",
        1.0,
        0.0,
        1.0,
        1248464818,
        3,
    ],
    [
        "examples/blobctrl/assets/results/demo/remove_shit/input_image/input_image.png",
        "The background consists of a textured, gray concrete surface with a red brick wall behind it. The bricks are arranged in a classic pattern, showcasing various shades of red and some weathering.",
        1.0,
        0.0,
        1.0,
        1248464818,
        4,
    ],
    [
        "examples/blobctrl/assets/results/demo/remove_cow/input_image/input_image.png",
        "A majestic mountain range with rugged peaks under a cloudy sky, and a grassy field in the foreground.",
        1.0,
        0.0,
        1.0,
        1248464818,
        5,
    ],
    [
        "examples/blobctrl/assets/results/demo/compose_rabbit/input_image/input_image.png",
        "A cute brown rabbit sitting on a wooden surface with a serene lake and mountains in the background.",
        1.0,
        0.0,
        1.0,
        1248464818,
        6,
    ],
    [
        "examples/blobctrl/assets/results/demo/compose_cake/input_image/input_image.png",
        "A slice of cake on a light blue background.",
        1.2,
        0.0,
        1.0,
        1248464818,
        7,
    ],
    [
        "examples/blobctrl/assets/results/demo/replace_knife/input_image/input_image.png",
        "A slice of cake on a light blue background, with a knife in the center.",
        1.2,
        0.0,
        1.0,
        1248464818,
        8,
    ]
]
#
OBJECT_IMAGE_GALLERY = [
    ["examples/blobctrl/assets/results/demo/move_hat/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/move_cup/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/enlarge_deer/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/shrink_dragon/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/remove_shit/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/remove_cow/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/compose_rabbit/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/compose_cake/object_image_gallery/validation_object_region_center.png"],
    ["examples/blobctrl/assets/results/demo/replace_knife/object_image_gallery/validation_object_region_center.png"],
]
ORI_RESULT_GALLERY = [
    ["examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_4.png"],
    ["examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_4.png"],
]
EDITABLE_BLOB = [
    "examples/blobctrl/assets/results/demo/move_hat/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/move_cup/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/enlarge_deer/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/shrink_dragon/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/remove_shit/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/remove_cow/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/compose_rabbit/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/compose_cake/editable_blob/editable_blob.png",
    "examples/blobctrl/assets/results/demo/replace_knife/editable_blob/editable_blob.png",
]
EDITED_RESULT_GALLERY = [
    ["examples/blobctrl/assets/results/demo/move_hat/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/move_cup/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/enlarge_deer/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/shrink_dragon/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/remove_shit/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/remove_cow/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/compose_rabbit/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/compose_cake/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/edited_result_gallery/edited_result_gallery_1.png"],
    ["examples/blobctrl/assets/results/demo/replace_knife/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/edited_result_gallery/edited_result_gallery_1.png"],
]
RESULTS_GALLERY = [
    ["examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_3.png"],
    ["examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_3.png"],
]
ELLIPSE_LISTS = [
    [[[[227.10665893554688, 118.85255432128906], [85.48122482299804, 103.65433502197266], 87.37393951416016], [1, 1, 1, 0], 0], [[[361.1066589355469, 367.85255432128906], [85.48122482299804, 103.65433502197266], 87.37393951416016], [1, 1, 1, 0], 1]],
    [[[[249.1703643798828, 149.63021850585938], [83.36424179077149, 115.79973449707032], 0.8257154226303101], [1, 1, 1, 0], 0], [[[245.1703643798828, 270.6302185058594], [83.36424179077149, 115.79973449707032], 0.8257154226303101], [1, 1, 1, 0], 1]],
    [[[[234.69358825683594, 255.60946655273438], [196.208619140625, 341.067111328125], 15.866915702819824], [1, 1, 1, 0], 0], [[[234.69358825683594, 255.60946655273438], [226.394560546875, 393.538974609375], 15.866915702819824], [1.2, 1, 1, 0], 2], [[[234.69358825683594, 255.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 2], [[[237.69358825683594, 237.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 1], [[[237.69358825683594, 233.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 1]],
    [[[[367.17742919921875, 201.1094512939453], [206.3889125696118, 377.8448820272314], 56.17562484741211], [1, 1, 1, 0], 0], [[[367.17742919921875, 201.1094512939453], [147.91468688964844, 297.0842980957031], 56.17562484741211], [0.8, 1, 1, 0], 2], [[[367.17742919921875, 201.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 2], [[[324.17742919921875, 235.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 1], [[[335.17742919921875, 225.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 1]],
    [[[[255.23663330078125, 315.4020080566406], [263.64675201416014, 295.38494384765625], 153.8949432373047], [1, 1, 1, 0], 0]],
    [[[[335.09979248046875, 236.41409301757812], [168.37833966064454, 345.3470615478516], 0.7639619708061218], [1, 1, 1, 0], 0]],
    [[[[256.0, 256.0], [1e-05, 1e-05], 0], [1, 1, 1, 0], 0], [[[271.6672, 275.3536], [136.85061800371966, 303.75044578074284], 177.008], [1, 1, 1, 0], 0], [[[271.6672, 275.3536], [150.53567980409164, 303.75044578074284], 177.008], [1.1, 1, 1.1, 0], 4], [[[271.6672, 275.3536], [158.06246379429624, 318.93796806977997], 177.008], [1.05, 1, 1.1, 0], 2], [[[271.6672, 275.3536], [165.96558698401105, 334.88486647326897], 177.008], [1.05, 1, 1.1, 0], 2], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 177.008], [1.1, 1, 1.1, 0], 4], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 7.00800000000001], [1.1, 1, 1.1, 10], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 3.0080000000000098], [1.1, 1, 1.1, -4], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 177.008], [1.1, 1, 1.1, -6], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 179.008], [1.1, 1, 1.1, 2], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 178.008], [1.1, 1, 1.1, -1], 5], [[[271.6672, 275.3536], [182.56214568241217, 368.3733531205959], 178.008], [1.1, 1.1, 1.1, -1], 3], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 178.008], [1.1, 0.95, 1.1, -1], 3], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 170.008], [1.1, 0.95, 1.1, -8], 5], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 172.008], [1.1, 0.95, 1.1, 2], 5]],
    [[[[256.0, 256.0], [1e-05, 1e-05], 0], [1, 1, 1, 0], 0], [[[256.0, 256.0], [144.81546878700496, 144.81546878700496], 0], [1, 1, 1, 0], 0], [[[256.0, 256.0], [123.09314846895421, 123.09314846895421], 0], [0.85, 1, 1, 0], 2], [[[256.0, 256.0], [110.7838336220588, 110.7838336220588], 0], [0.9, 1, 1, 0], 2], [[[88.0, 418.0], [110.7838336220588, 110.7838336220588], 0], [0.9, 1, 1, 0], 1]],
    [[[[164.6718292236328, 385.8408508300781], [41.45796089172364, 319.87034912109374], 142.05267333984375], [1, 1, 1, 0], 0]],
]
TRACKING_POINTS = [
    [[227, 118], [361, 367]],
    [[249, 150], [248, 269]],
    [[234, 255], [234, 255], [234, 255], [237, 237], [237, 233]],
    [[367, 201], [367, 201], [367, 201], [324, 235], [335, 225]],
    [[255, 315]],
    [[335, 236]],
    [[256, 256], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271]],
    [[256, 256], [256, 256], [256, 256], [256, 256], [88, 418]],
    [[164, 385]],
]
REMOVE_STATE = [
    False,
    False,
    False,
    False,
    True,
    True,
    False,
    False,
    False,
]
INPUT_IMAGE = [
    "examples/blobctrl/assets/results/demo/move_hat/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/move_cup/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/enlarge_deer/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/shrink_dragon/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/remove_shit/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/remove_cow/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/compose_rabbit/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/compose_cake/input_image/input_image.png",
    "examples/blobctrl/assets/results/demo/replace_knife/input_image/input_image.png",
]
## normal functions
def _get_ellipse(mask):
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours_copy = [contour_copy.tolist() for contour_copy in contours]
    concat_contours = np.concatenate(contours, axis=0)
    hull = cv2.convexHull(concat_contours)
    ellipse = cv2.fitEllipse(hull)
    return ellipse, contours_copy
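# Illustrative usage of `_get_ellipse` with a synthetic binary mask:
# mask = np.zeros((512, 512), dtype=np.uint8)
# cv2.circle(mask, (256, 256), 100, 255, -1)
# ellipse, contours = _get_ellipse(mask)  # ((xc, yc), (d1, d2), angle); here d1 ~ d2 ~ 200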
def ellipse_to_gaussian(x, y, a, b, theta):
    """
    Convert ellipse parameters to the mean and covariance matrix of a Gaussian distribution.

    Parameters:
        x (float): x-coordinate of the ellipse center.
        y (float): y-coordinate of the ellipse center.
        a (float): Length of the minor semi-axis of the ellipse.
        b (float): Length of the major semi-axis of the ellipse.
        theta (float): Rotation angle of the ellipse (in radians), counterclockwise angle of the major axis.

    Returns:
        mean (numpy.ndarray): Mean of the Gaussian distribution, an array of shape (2,) representing (x, y) coordinates.
        cov_matrix (numpy.ndarray): Covariance matrix of the Gaussian distribution, an array of shape (2, 2).
    """
    # Mean
    mean = np.array([x, y])
    # Diagonal elements of the covariance matrix
    # sigma_x = b / np.sqrt(2)
    # sigma_y = a / np.sqrt(2)
    # Not dividing by sqrt(2) is also acceptable. This conversion is mainly for specific statistical contexts,
    # to make the semi-axis length of the ellipse correspond to one standard deviation of the Gaussian distribution.
    # The purpose is to make the ellipse area contain about 68% of the probability mass of the Gaussian distribution
    # (in a one-dimensional Gaussian distribution, one standard deviation contains about 68% of the probability mass).
    sigma_x = b
    sigma_y = a
    # Covariance matrix (before rotation)
    cov_matrix = np.array([[sigma_x**2, 0],
                           [0, sigma_y**2]])
    # Rotation matrix
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    # Rotate the covariance matrix
    cov_matrix_rotated = R @ cov_matrix @ R.T
    cov_matrix_rotated[0, 1] *= -1  # Reverse the off-diagonal elements of the covariance matrix
    cov_matrix_rotated[1, 0] *= -1  # Reverse the off-diagonal elements of the covariance matrix
    # eigenvalues, eigenvectors = np.linalg.eig(cov_matrix_rotated)
    return mean, cov_matrix_rotated
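# Worked example: an axis-aligned ellipse centered at (10, 20) with semi-axes
# a=2 (minor) and b=4 (major) and theta=0 gives
#   mean = [10, 20],  cov = [[b**2, 0], [0, a**2]] = [[16, 0], [0, 4]]
# i.e. the major axis lies along x and the covariance stays diagonal.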
def normalize_gs(mean, cov_matrix_rotated, width, height):
    # Normalize mean
    normalized_mean = mean / np.array([width, height])
    # Maximum length (image diagonal) used to normalize the covariance matrix
    max_length = np.sqrt(width**2 + height**2)
    # Normalize covariance matrix
    normalized_cov_matrix = cov_matrix_rotated / (max_length ** 2)
    return normalized_mean, normalized_cov_matrix
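# Normalization example: for a 512x512 image, max_length = sqrt(512**2 + 512**2)
# ~ 724.1, so a covariance entry of 16 (a 4 px standard deviation squared)
# becomes 16 / 724.1**2 ~ 3.05e-5 after normalization.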
def normalize_ellipse(ellipse, width, height):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    max_length = np.sqrt(width**2 + height**2)
    normalized_xc, normalized_yc = xc / width, yc / height
    normalized_d1, normalized_d2 = d1 / max_length, d2 / max_length
    return normalized_xc, normalized_yc, normalized_d1, normalized_d2, angle_clockwise_short_axis

def composite_mask_and_image(mask, image, masked_color=[0, 0, 0]):
    if isinstance(mask, Image.Image):
        mask_np = np.array(mask)
    else:
        mask_np = mask
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image
    if mask_np.ndim == 2:
        mask_indicator = (mask_np > 0).astype(np.uint8)
    else:
        mask_indicator = (mask_np.sum(-1) > 255).astype(np.uint8)
    masked_image = image_np * (1 - mask_indicator[:, :, np.newaxis]) + masked_color * mask_indicator[:, :, np.newaxis]
    return Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
def is_point_in_ellipse(point, ellipse):
    # Unpack the ellipse parameters
    (xc, yc), (d1, d2), angle = ellipse
    # Convert the angle to radians
    theta = np.radians(angle)
    # Coordinates relative to the ellipse center
    x, y = point
    x_prime = x - xc
    y_prime = y - yc
    # Rotate into the ellipse frame
    x_rotated = x_prime * np.cos(theta) - y_prime * np.sin(theta)
    y_rotated = x_prime * np.sin(theta) + y_prime * np.cos(theta)
    # Evaluate the ellipse equation; d1 and d2 are the full axes, so divide by 2
    ellipse_equation = (x_rotated**2) / ((d1 / 2)**2) + (y_rotated**2) / ((d2 / 2)**2)
    # The point is inside the ellipse iff the value is <= 1
    return ellipse_equation <= 1
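# Illustrative check: the ellipse center is always inside, far-away points are not.
# ellipse = ((256, 256), (200, 100), 30)  # (center, (d1, d2), angle in degrees)
# assert is_point_in_ellipse((256, 256), ellipse)
# assert not is_point_in_ellipse((0, 0), ellipse)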
def calculate_ellipse_vertices(ellipse):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    # Convert angle from degrees to radians
    angle_rad = np.deg2rad(angle_clockwise_short_axis)
    # Calculate the rotation matrix
    rotation_matrix = np.array([
        [np.cos(angle_rad), -np.sin(angle_rad)],
        [np.sin(angle_rad), np.cos(angle_rad)]
    ])
    # Calculate the unrotated vertices
    half_d1 = d1 / 2
    half_d2 = d2 / 2
    vertices = np.array([
        [half_d1, 0],   # Rightmost point on the long axis
        [-half_d1, 0],  # Leftmost point on the long axis
        [0, half_d2],   # Topmost point on the short axis
        [0, -half_d2]   # Bottommost point on the short axis
    ])
    # Rotate the vertices
    rotated_vertices = np.dot(vertices, rotation_matrix.T)
    # Translate vertices to the original center
    final_vertices = rotated_vertices + np.array([xc, yc])
    return final_vertices
def move_ellipse(ellipse, tracking_points):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    last_xc, last_yc = tracking_points[-1]
    second_last_xc, second_last_yc = tracking_points[-2]
    vx = last_xc - second_last_xc
    vy = last_yc - second_last_yc
    xc += vx
    yc += vy
    return (xc, yc), (d1, d2), angle_clockwise_short_axis

def resize_blob_func(ellipse, resizing_factor, height, width, resize_type):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    too_big = False
    too_small = False
    min_blob_area = 1600
    exceed_threshold = 0.4
    while True:
        if resize_type == 0:
            resized_d1 = d1 * resizing_factor
            resized_d2 = d2 * resizing_factor
        elif resize_type == 1:
            resized_d1 = d1
            resized_d2 = d2 * resizing_factor
        elif resize_type == 2:
            resized_d1 = d1 * resizing_factor
            resized_d2 = d2
        resized_ellipse = (xc, yc), (resized_d1, resized_d2), angle_clockwise_short_axis
        resized_ellipse_vertices = calculate_ellipse_vertices(resized_ellipse)
        resized_ellipse_vertices = resized_ellipse_vertices / np.array([width, height])
        if resizing_factor != 1:
            # soften the threshold by which the ellipse may exceed the image range
            if np.all(resized_ellipse_vertices >= -exceed_threshold) and np.all(resized_ellipse_vertices <= 1 + exceed_threshold):
                # calculate the blob area
                blob_area = np.pi * (resized_d1 / 2) * (resized_d2 / 2)
                if blob_area >= min_blob_area:
                    break
                else:
                    too_small = True
                    resizing_factor += 0.1
                    if blob_area < 1e-6:
                        ## if the blob area is far too small, give up and break
                        break
            else:
                too_big = True
                resizing_factor -= 0.1
        else:
            break
    if too_big:
        gr.Warning(f"The blob is too big; the magnification was adaptively reduced to fit the image (the blob may exceed the image range by at most {exceed_threshold}).")
    if too_small:
        gr.Warning(f"The blob is too small; the magnification was adaptively enlarged (the minimum blob area is {min_blob_area} px).")
    return resized_ellipse, resizing_factor
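# Illustrative usage: grow an ellipse by 20% while keeping its aspect ratio
# (resize_type=0). The returned factor can differ from the requested one when
# the blob would leave the image or fall below the minimum area:
# resized, factor = resize_blob_func(((256, 256), (120, 80), 0), 1.2, 512, 512, 0)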
def rotate_blob_func(ellipse, rotation_degree):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    rotated_angle_clockwise_short_axis = (angle_clockwise_short_axis + rotation_degree) % 180
    rotated_ellipse = (xc, yc), (d1, d2), rotated_angle_clockwise_short_axis
    return rotated_ellipse, rotation_degree

def get_theta_anti_clockwise_long_axis(angle_clockwise_short_axis):
    angle_anti_clockwise_short_axis = (180 - angle_clockwise_short_axis) % 180
    angle_anti_clockwise_long_axis = (angle_anti_clockwise_short_axis + 90) % 180
    theta_anti_clockwise_long_axis = np.radians(angle_anti_clockwise_long_axis)
    return theta_anti_clockwise_long_axis
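# Angle-convention example: the code stores the fitEllipse angle as a clockwise
# rotation of the short axis. For angle_clockwise_short_axis = 30:
#   anti-clockwise short axis: (180 - 30) % 180 = 150
#   anti-clockwise long axis:  (150 + 90) % 180 = 60  ->  theta = np.radians(60)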
def get_gs_from_ellipse(ellipse):
    (xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
    theta_anti_clockwise_long_axis = get_theta_anti_clockwise_long_axis(angle_clockwise_short_axis)
    a = d1 / 2
    b = d2 / 2
    mean, cov_matrix = ellipse_to_gaussian(xc, yc, a, b, theta_anti_clockwise_long_axis)
    return mean, cov_matrix

def get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix):
    xs, ys = normalized_mean
    blob = {
        "xs": torch.tensor(xs).unsqueeze(0),
        "ys": torch.tensor(ys).unsqueeze(0),
        "covs": torch.tensor(normalized_cov_matrix).unsqueeze(0).unsqueeze(0),
        "sizes": torch.tensor([1.0]).unsqueeze(0),
    }
    return blob

def clear_ellipse_lists(ellipse_lists):
    ellipse_lists = []
    return ellipse_lists
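# Note: rebinding the parameter does not clear the caller's list in place;
# callers must use the returned empty list, as `generate_blob` does below.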
def get_blob_vis_img_from_blob_dict(blob, viz_size=64, score_size=64):
    blob_vis = splat_features(**blob,
                              interp_size=64,
                              viz_size=viz_size,
                              is_viz=True,
                              ret_layout=True,
                              score_size=score_size,
                              viz_score_fn=viz_score_fn,
                              viz_colors=BLOB_VIS_COLORS,
                              only_vis=True)["feature_img"]
    blob_vis_img = blob_vis[0].permute(1, 2, 0).contiguous().cpu().numpy()
    blob_vis_img = (blob_vis_img * 255).astype(np.uint8)
    blob_vis_img = Image.fromarray(blob_vis_img)
    return blob_vis_img

def get_blob_score_from_blob_dict(blob, score_size=64):
    blob_score = splat_features(**blob,
                                score_size=score_size,
                                return_d_score=True,
                                )[0]
    return blob_score

def get_object_region_from_mask(mask, original_image):
    if isinstance(mask, Image.Image):
        mask_np = np.array(mask)
    else:
        mask_np = mask
    if mask_np.ndim == 2:
        mask_indicator = (mask_np > 0).astype(np.uint8)
    else:
        mask_indicator = (mask_np.sum(-1) > 255).astype(np.uint8)
    x, y, w, h = cv2.boundingRect(mask_indicator)
    rect_mask = mask_indicator[y:y+h, x:x+w]
    tmp = original_image.copy()
    rect_region = tmp[y:y+h, x:x+w]
    rect_region_object_white_background = np.where(rect_mask[:, :, None] > 0, rect_region, 255)
    target_height, target_width = tmp.shape[:2]
    start_y = (target_height - h) // 2
    start_x = (target_width - w) // 2
    rect_region_object_white_background_center = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
    rect_region_object_white_background_center[start_y:start_y+h, start_x:start_x+w] = rect_region_object_white_background
    rect_region_object_white_background_center = Image.fromarray(rect_region_object_white_background_center).convert("RGB")
    return rect_region_object_white_background_center
def extract_contours(object_image):
    """
    Extract contours from an object image.

    :param object_image: Input object image, shape (h, w, 3), value range [0, 255]
    :return: Contour image
    """
    # Convert the image to grayscale
    gray_image = cv2.cvtColor(object_image, cv2.COLOR_BGR2GRAY)
    # Binarize the image, assuming the object is not pure white [255, 255, 255]
    _, binary_image = cv2.threshold(gray_image, 240, 255, cv2.THRESH_BINARY_INV)
    # Extract contours
    contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Create a blank image for drawing the contours
    contour_image = np.zeros_like(gray_image)
    # Draw the filled contours on the blank image
    cv2.drawContours(contour_image, contours, -1, (255), thickness=cv2.FILLED)
    return contour_image
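# Illustrative usage (assumes an object rendered on a near-white background):
# obj = np.full((512, 512, 3), 255, dtype=np.uint8)
# cv2.rectangle(obj, (100, 100), (200, 200), (0, 0, 255), -1)
# silhouette = extract_contours(obj)  # filled white silhouette on black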
def get_mask_from_ellipse(ellipse, height, width):
    ellipse_mask_np = np.zeros((height, width))
    ellipse_mask_np = cv2.ellipse(ellipse_mask_np, ellipse, 255, -1)
    ellipse_mask = Image.fromarray(ellipse_mask_np).convert("L")
    return ellipse_mask
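# Illustrative usage: rasterize an ellipse into a binary PIL mask.
# mask = get_mask_from_ellipse(((256, 256), (200, 100), 0.0), 512, 512)
# np.array(mask) is 255 inside the ellipse and 0 outside.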
# gradio functions
def run_function(
    original_image,
    scene_prompt,
    ori_result_gallery,
    object_image_gallery,
    edited_result_gallery,
    ellipse_lists,
    blobnet_control_strength,
    blobnet_control_guidance_start,
    blobnet_control_guidance_end,
    remove_blob_box,
    num_samples,
    seed,
    guidance_scale,
    num_inference_steps,
    ## for save
    editable_blob,
    resize_blob_slider_maintain_aspect_ratio,
    resize_blob_slider_along_long_axis,
    resize_blob_slider_along_short_axis,
    rotation_blob_slider,
    resize_init_blob_slider,
    resize_init_blob_slider_long_axis,
    resize_init_blob_slider_short_axis,
    tracking_points,
):
    if not object_image_gallery or not ori_result_gallery:
        gr.Warning("Please generate the blob first")
        return None
    if not edited_result_gallery:
        gr.Warning("Please click inside the blob region first.")
        return None
    generator = torch.Generator(device=device).manual_seed(seed)
    ## prepare img: object_region_center, edited_background_region
    gt_i_ellipse_img_path, masked_image_path, mask_image_path, ellipse_mask_path, ellipse_masked_image_path = ori_result_gallery
    object_white_background_center_path = object_image_gallery[0]
    validation_object_region_center = Image.open(object_white_background_center_path[0])
    ori_ellipse_mask = Image.open(ellipse_mask_path[0])
    width, height = validation_object_region_center.size
    latent_height, latent_width = height // 8, width // 8
    if not remove_blob_box:
        edited_ellipse_masked_image_path, edited_ellipse_mask_path = edited_result_gallery
        validation_edited_background_region = Image.open(edited_ellipse_masked_image_path[0])
        ## prepare gs_score
        final_ellipse, final_transform_param, final_blob_edited_type = ellipse_lists[-1]
        mean, cov_matrix = get_gs_from_ellipse(final_ellipse)
        normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
        blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
        validation_gs_score = get_blob_score_from_blob_dict(blob_dict, score_size=(latent_height, latent_width)).unsqueeze(0).to(device)  # bnhw
    else:
        img_tmp = original_image.copy()
        validation_edited_background_region = composite_mask_and_image(ori_ellipse_mask, img_tmp, masked_color=[255, 255, 255])
        ## prepare gs_score
        start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0]
        mean, cov_matrix = get_gs_from_ellipse(start_ellipse)
        normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
        blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
        validation_gs_score = get_blob_score_from_blob_dict(blob_dict, score_size=(latent_height, latent_width)).unsqueeze(0).to(device)  # bnhw
        validation_gs_score[:, 0] = 1.0
        validation_gs_score[:, 1] = 0.0
        final_ellipse = start_ellipse
        ## set blobnet control strength to 0.0
        blobnet_control_strength = 0.0
    with torch.autocast("cuda"):
        output = pipeline(
            fg_image=validation_object_region_center,
            bg_image=validation_edited_background_region,
            gs_score=validation_gs_score,
            generator=generator,
            prompt=[scene_prompt] * num_samples,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            blobnet_control_guidance_start=float(blobnet_control_guidance_start),
            blobnet_control_guidance_end=float(blobnet_control_guidance_end),
            blobnet_conditioning_scale=float(blobnet_control_strength),
            width=width,
            height=height,
            return_sample=False,
        )
    edited_images = output.images
    edit_image_plots = []
    for i in range(num_samples):
        edit_image_np = np.array(edited_images[i])
        edit_image_np_plot = cv2.ellipse(edit_image_np, final_ellipse, [0, 255, 0], 3)
        edit_image_plot = Image.fromarray(edit_image_np_plot).convert("RGB")
        edit_image_plots.append(edit_image_plot)
    results_gallery = [*edited_images, *edit_image_plots]
    ## save results
    # ori_save_path = "examples/blobctrl/assets/results/tmp/ori_result_gallery"
    # os.makedirs(ori_save_path, exist_ok=True)
    # for i in range(len(ori_result_gallery)):
    #     result = Image.open(ori_result_gallery[i][0])
    #     result.save(f"{ori_save_path}/ori_result_gallery_{i}.png")
    # object_save_path = "examples/blobctrl/assets/results/tmp/object_image_gallery"
    # os.makedirs(object_save_path, exist_ok=True)
    # validation_object_region_center.save(f"{object_save_path}/validation_object_region_center.png")
    # edited_save_path = "examples/blobctrl/assets/results/tmp/edited_result_gallery"
    # os.makedirs(edited_save_path, exist_ok=True)
    # for i in range(len(edited_result_gallery)):
    #     result = Image.open(edited_result_gallery[i][0])
    #     result.save(f"{edited_save_path}/edited_result_gallery_{i}.png")
    # results_save_path = "examples/blobctrl/assets/results/tmp/results_gallery"
    # os.makedirs(results_save_path, exist_ok=True)
    # for i in range(len(results_gallery)):
    #     results_gallery[i].save(f"{results_save_path}/results_gallery_{i}.png")
    # editable_blob_save_path = "examples/blobctrl/assets/results/tmp/editable_blob"
    # os.makedirs(editable_blob_save_path, exist_ok=True)
    # editable_blob_pil = Image.fromarray(editable_blob)
    # editable_blob_pil.save(f"{editable_blob_save_path}/editable_blob.png")
    # state_save_path = "examples/blobctrl/assets/results/tmp/state"
    # os.makedirs(state_save_path, exist_ok=True)
    # with open(f"{state_save_path}/state.json", "w") as f:
    #     json.dump({
    #         "blobnet_control_strength": blobnet_control_strength,
    #         "blobnet_control_guidance_start": blobnet_control_guidance_start,
    #         "blobnet_control_guidance_end": blobnet_control_guidance_end,
    #         "remove_blob_box": remove_blob_box,
    #         "num_samples": num_samples,
    #         "seed": seed,
    #         "guidance_scale": guidance_scale,
    #         "num_inference_steps": num_inference_steps,
    #         "ellipse_lists": ellipse_lists,
    #         "scene_prompt": scene_prompt,
    #         "resize_blob_slider_maintain_aspect_ratio": resize_blob_slider_maintain_aspect_ratio,
    #         "resize_blob_slider_along_long_axis": resize_blob_slider_along_long_axis,
    #         "resize_blob_slider_along_short_axis": resize_blob_slider_along_short_axis,
    #         "rotation_blob_slider": rotation_blob_slider,
    #         "resize_init_blob_slider": resize_init_blob_slider,
    #         "resize_init_blob_slider_long_axis": resize_init_blob_slider_long_axis,
    #         "resize_init_blob_slider_short_axis": resize_init_blob_slider_short_axis,
    #         "tracking_points": tracking_points,
    #     }, f)
    # input_image_save_path = "examples/blobctrl/assets/results/tmp/input_image"
    # os.makedirs(input_image_save_path, exist_ok=True)
    # Image.fromarray(original_image).save(f"{input_image_save_path}/input_image.png")
    torch.cuda.empty_cache()
    return results_gallery
def generate_blob(
    original_image,
    original_mask,
    selected_points,
    ellipse_lists,
    init_resize_factor=1.05,
):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if (original_mask is None) or (len(selected_points) == 0):
        raise gr.Error("Please click the region you want to edit in the input image to get a segmentation mask")
    else:
        original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
    ## get ellipse parameters from mask
    height, width = original_image.shape[:2]
    binary_mask = 255 * (original_mask.sum(-1) > 255).astype(np.uint8)
    ellipse, contours = _get_ellipse(binary_mask)
    ## slightly enlarge the ellipse so it covers the whole blob
    ellipse, resizing_factor = resize_blob_func(ellipse, init_resize_factor, height, width, 0)
    ## get gaussian parameters from ellipse
    mean, cov_matrix = get_gs_from_ellipse(ellipse)
    normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
    blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
    blob_vis_img = get_blob_vis_img_from_blob_dict(blob_dict, viz_size=(height, width))
    ## plot masked image
    masked_image = composite_mask_and_image(original_mask, original_image)
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("L")
    ## get object region
    object_white_background_center = get_object_region_from_mask(original_mask, original_image)
    ## plot ellipse
    gt_i_ellipse = vis_gt_ellipse_from_ellipse(torch.tensor(original_image).round().contiguous().cpu().numpy(),
                                               ellipse,
                                               color=[0, 255, 0])
    gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
    ellipse_mask = get_mask_from_ellipse(ellipse, height, width)
    ellipse_masked_image = composite_mask_and_image(ellipse_mask, original_image)
    ## return images
    ori_result_gallery = [gt_i_ellipse_img, masked_image, mask_image, ellipse_mask, ellipse_masked_image]
    object_image_gallery = [object_white_background_center]
    ## init ellipse_lists; blob_edited_type: 0: init, 1: move, 2: resize keeping aspect ratio, 3: resize along long axis, 4: resize along short axis, 5: rotation
    ## ellipse_init = (ellipse, (resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle), blob_edited_type)
    ellipse_init = (ellipse, (1, 1, 1, 0), 0)
    if len(ellipse_lists) == 0:
        ellipse_lists.append(ellipse_init)
    else:
        ellipse_lists = clear_ellipse_lists(ellipse_lists)
        ellipse_lists.append(ellipse_init)
    ## init parameters
    rotation_blob_slider = 0
    resize_blob_slider_maintain_aspect_ratio = 1
    resize_blob_slider_along_long_axis = 1
    resize_blob_slider_along_short_axis = 1
    resize_init_blob_slider = 1
    resize_init_blob_slider_long_axis = 1
    resize_init_blob_slider_short_axis = 1
    init_ellipse_parameter = None
    init_object_image = None
    tracking_points = []
    edited_result_gallery = None
    return blob_vis_img, ori_result_gallery, object_image_gallery, ellipse_lists, tracking_points, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image
# undo the selected point
def undo_seg_points(orig_img, sel_pix):
    # draw points
    output_mask = None
    if len(sel_pix) != 0:
        temp = orig_img.copy()
        sel_pix.pop()
        # online show seg mask
        if len(sel_pix) != 0:
            temp, output_mask = segmentation(temp, sel_pix)
        return temp.astype(np.uint8), output_mask
    else:
        gr.Warning("Nothing to Undo")
        return orig_img, output_mask
# once the user uploads an image, the original image is stored in `original_image`
def initialize_img(img):
    if max(img.shape[0], img.shape[1]) * 1.0 / min(img.shape[0], img.shape[1]) > 2.0:
        raise gr.Error('Image aspect ratio cannot be larger than 2.0')
    # Resize and center-crop to 512x512
    h, w = img.shape[:2]
    # First resize so that the shortest side is 512
    scale = 512 / min(h, w)
    new_h = int(h * scale)
    new_w = int(w * scale)
    img = cv2.resize(img, (new_w, new_h))
    # Then center-crop to 512x512
    h, w = img.shape[:2]
    start_y = (h - 512) // 2
    start_x = (w - 512) // 2
    img = img[start_y:start_y+512, start_x:start_x+512]
    original_image = img.copy()
    editable_blob = None
    selected_points = []
    tracking_points = []
    ellipse_lists = []
    ori_result_gallery = []
    object_image_gallery = []
    edited_result_gallery = []
    results_gallery = []
    blobnet_control_strength = 1.2
    blobnet_control_guidance_start = 0.0
    blobnet_control_guidance_end = 1.0
    resize_blob_slider_maintain_aspect_ratio = 1
    resize_blob_slider_along_long_axis = 1
    resize_blob_slider_along_short_axis = 1
    rotation_blob_slider = 0
    resize_init_blob_slider = 1
    resize_init_blob_slider_long_axis = 1
    resize_init_blob_slider_short_axis = 1
    init_ellipse_parameter = "[0.5, 0.5, 0.2, 0.2, 180]"
    init_object_image = None
    remove_blob_box = False
    return img, original_image, editable_blob, selected_points, tracking_points, ellipse_lists, ori_result_gallery, object_image_gallery, edited_result_gallery, results_gallery, blobnet_control_strength, blobnet_control_guidance_start, blobnet_control_guidance_end, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image, remove_blob_box
# user clicks the image to get points; show the points on the image
def segmentation(img, sel_pix):
    # online show seg mask
    points = []
    labels = []
    for p, l in sel_pix:
        points.append(p)
        labels.append(l)
    mobile_predictor.set_image(img.astype(np.uint8) if isinstance(img, np.ndarray) else np.array(img).astype(np.uint8))
    with torch.no_grad():
        masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
| print("=======img=========") | |
| print(img) | |
| print(img.shape) | |
| print("=======points and labels=========") | |
| print(points) | |
| print(labels) | |
| print("=======masks=========") | |
| print(masks) | |
| print(np.unique(masks)) | |
| print("================") | |
| print(mobile_predictor) | |
    output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
    for i in range(3):
        output_mask[masks[0] == True, i] = 0.0
    mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
    color_mask = np.random.random((1, 3)).tolist()[0]
    for i in range(3):
        mask_all[masks[0] == True, i] = color_mask[i]
    masked_img = img / 255 * 0.3 + mask_all * 0.7
    masked_img = masked_img * 255
    ## draw points
    for point, label in sel_pix:
        cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
    return masked_img, output_mask
def get_point(img, sel_pix, evt: gr.SelectData):
    sel_pix.append((evt.index, 1))  # default: foreground point
    # online show seg mask
    print(evt.index)
    masked_img, output_mask = segmentation(img, sel_pix)
    return masked_img.astype(np.uint8), output_mask
| def tracking_points_for_blob(original_image, | |
| tracking_points, | |
| ellipse_lists, | |
| height, | |
| width, | |
| edit_status=True): | |
| sel_pix_transparent_layer = np.zeros((height, width, 4)) | |
| sel_ell_transparent_layer = np.zeros((height, width, 4)) | |
| start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0] | |
| current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1] | |
| ## plot start point | |
| start_point = tracking_points[0] | |
| cv2.drawMarker(sel_pix_transparent_layer, start_point, rgba_colors[-1], markerType=markers[1], markerSize=20, thickness=5) | |
| ## plot tracking points | |
| if len(tracking_points) > 1: | |
| tracking_points_real = [] | |
| for point in tracking_points: | |
| if not tracking_points_real or point != tracking_points_real[-1]: | |
| tracking_points_real.append(point) | |
| for i in range(len(tracking_points_real)-1): | |
| start_point = tracking_points_real[i] | |
| end_point = tracking_points_real[i+1] | |
| vx = end_point[0] - start_point[0] | |
| vy = end_point[1] - start_point[1] | |
| arrow_length = np.sqrt(vx**2 + vy**2) | |
| ## draw arrow | |
| if i == len(tracking_points_real)-2: | |
| cv2.arrowedLine(sel_pix_transparent_layer, tuple(start_point), tuple(end_point), rgba_colors[-1], 2, tipLength=8 / arrow_length) | |
| else: | |
| cv2.line(sel_pix_transparent_layer, tuple(start_point), tuple(end_point), rgba_colors[-1], 2,) | |
| if edit_status: | |
| edited_ellipse = move_ellipse(current_ellipse, tracking_points_real) | |
| transform_param = current_transform_param | |
| ellipse_lists.append((edited_ellipse, transform_param, 1)) | |
| ## draw ellipse, current ellipse need to be rearanged, because the ellipse_lists may be changed | |
| current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1] | |
| cv2.ellipse(sel_ell_transparent_layer, current_ellipse, rgba_colors[-1], 2, -1) | |
    # get current ellipse
    current_mean, current_cov_matrix = get_gs_from_ellipse(current_ellipse)
    current_normalized_mean, current_normalized_cov_matrix = normalize_gs(current_mean, current_cov_matrix, width, height)
    current_blob_dict = get_blob_dict_from_norm_gs(current_normalized_mean, current_normalized_cov_matrix)
    transparent_background = get_blob_vis_img_from_blob_dict(current_blob_dict, viz_size=(height, width)).convert('RGBA')
    ## composite images
    sel_pix_transparent_layer = Image.fromarray(sel_pix_transparent_layer.astype(np.uint8))
    sel_ell_transparent_layer = Image.fromarray(sel_ell_transparent_layer.astype(np.uint8))
    transform_gs_img = Image.alpha_composite(transparent_background, sel_pix_transparent_layer)
    transform_gs_img = Image.alpha_composite(transform_gs_img, sel_ell_transparent_layer)
    ## get vis edited image and mask
    # Use anti-aliasing to get smoother ellipse edges
    original_ellipse_mask_np = np.zeros((height, width), dtype=np.float32)
    original_ellipse_mask_np = cv2.ellipse(original_ellipse_mask_np, start_ellipse, 1.0, -1, lineType=cv2.LINE_AA)
    original_ellipse_mask_np = (original_ellipse_mask_np * 255).astype(np.uint8)
    original_ellipse_mask = Image.fromarray(original_ellipse_mask_np).convert("L")
    edited_ellipse_mask_np = np.zeros((height, width), dtype=np.float32)
    edited_ellipse_mask_np = cv2.ellipse(edited_ellipse_mask_np, current_ellipse, 1.0, -1, lineType=cv2.LINE_AA)
    edited_ellipse_mask_np = (edited_ellipse_mask_np * 255).astype(np.uint8)
    edited_ellipse_mask = Image.fromarray(edited_ellipse_mask_np).convert("L")
    original_ellipse_masked_image = composite_mask_and_image(original_ellipse_mask, original_image, masked_color=[255, 255, 255])
    edited_ellipse_masked_image = composite_mask_and_image(edited_ellipse_mask, original_ellipse_masked_image, masked_color=[0, 0, 0])
    edited_result_gallery = [edited_ellipse_masked_image, edited_ellipse_mask]
    return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery
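## Callback for clicks on the editable blob image. The first click must land
## inside the initial blob and is snapped to the ellipse center; each further
## click extends the motion path that the blob is moved along.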
def add_tracking_points(original_image,
                        tracking_points,
                        ellipse_lists,
                        evt: gr.SelectData):  # SelectData is a subclass of EventData
    height, width = original_image.shape[:2]
    if len(ellipse_lists) == 0:
        gr.Warning("Please generate the blob first")
        return None, tracking_points, ellipse_lists, None
    ## get start ellipse
    start_ellipse, transform_param, blob_edited_type = ellipse_lists[0]
    ## check if the point is in the ellipse initially
    if not is_point_in_ellipse(evt.index, start_ellipse) and len(tracking_points) == 0:
| gr.Warning("Please click the region in the blob in the first time.") | |
        start_mean, start_cov_matrix = get_gs_from_ellipse(start_ellipse)
        start_normalized_mean, start_normalized_cov_matrix = normalize_gs(start_mean, start_cov_matrix, width, height)
        start_blob_dict = get_blob_dict_from_norm_gs(start_normalized_mean, start_normalized_cov_matrix)
        start_transparent_background = get_blob_vis_img_from_blob_dict(start_blob_dict, viz_size=(height, width)).convert('RGBA')
        return start_transparent_background, tracking_points, ellipse_lists, None
    if len(tracking_points) == 0:
        xc, yc = start_ellipse[0]
        tracking_points.append([int(xc), int(yc)])
    else:
        tracking_points.append(evt.index)
    tmp_img = original_image.copy()
    transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                       tracking_points,
                                                                                                       ellipse_lists,
                                                                                                       height,
                                                                                                       width,
                                                                                                       edit_status=True)
    return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery
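## Undo the most recent blob edit: drop the last tracking point and ellipse,
## redraw, and snap the resize/rotation sliders back to neutral.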
def undo_blob_points(original_image, tracking_points, ellipse_lists):
    height, width = original_image.shape[:2]
    if len(tracking_points) > 1:
        tmp_img = original_image.copy()
        tracking_points.pop()
        ellipse_lists.pop()
        transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                           tracking_points,
                                                                                                           ellipse_lists,
                                                                                                           height,
                                                                                                           width,
                                                                                                           edit_status=False)
        current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
        resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = 1, 1, 1, 0
        return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
    else:
        if len(tracking_points) == 1:
            tracking_points.pop()
        else:
            gr.Warning("Nothing to Undo")
        transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = reset_blob_points(original_image, tracking_points, ellipse_lists)
        return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
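## Reset the blob edit history, keeping only the initial blob.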
def reset_blob_points(original_image, tracking_points, ellipse_lists):
    edited_result_gallery = None
    height, width = original_image.shape[:2]
    tracking_points = []
    start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0]
    ellipse_lists = clear_ellipse_lists(ellipse_lists)
    ellipse_lists.append((start_ellipse, start_transform_param, start_blob_edited_type))
    current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[0]
    resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = current_transform_param
    current_mean, current_cov_matrix = get_gs_from_ellipse(current_ellipse)
    current_normalized_mean, current_normalized_cov_matrix = normalize_gs(current_mean, current_cov_matrix, width, height)
    current_blob_dict = get_blob_dict_from_norm_gs(current_normalized_mean, current_normalized_cov_matrix)
    transform_gs_img = get_blob_vis_img_from_blob_dict(current_blob_dict, viz_size=(height, width)).convert('RGBA')
    return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
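## Resize the current blob. resize_type selects the mode: 0 uniform (keep
## aspect ratio), 1 along the long axis only, 2 along the short axis only.
## Resizing is progressive: each slider release applies its factor to the
## current ellipse and the slider snaps back to 1, so releasing at 1.5 twice
## scales the blob by 1.5 * 1.5 = 2.25 overall.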
def resize_blob(editable_blob,
                original_image,
                tracking_points,
                ellipse_lists,
                resizing_factor,
                resize_type,
                edited_result_gallery,
                remove_blob_box):
    if remove_blob_box:
        gr.Warning("Please use initial blob resize in remove mode to ensure the initial blob surrounds the object")
        return editable_blob, ellipse_lists, edited_result_gallery, 1
    if len(ellipse_lists) == 0:
        gr.Warning("Please generate the blob first")
        return None, ellipse_lists, None, 1
    if len(tracking_points) == 0:
        gr.Warning("Please select the blob first")
        return editable_blob, ellipse_lists, None, 1
    height, width = original_image.shape[:2]
    # resize_type: 0: maintain aspect ratio, 1: along long axis, 2: along short axis
    current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
    if resize_type == 0:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 0)
        transform_param = (resizing_factor, current_transform_param[1], current_transform_param[2], current_transform_param[3])
        ellipse_lists.append((edited_ellipse, transform_param, 2))
    elif resize_type == 1:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 1)
        transform_param = (current_transform_param[0], resizing_factor, current_transform_param[2], current_transform_param[3])
        ellipse_lists.append((edited_ellipse, transform_param, 3))
    elif resize_type == 2:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 2)
        # only the short-axis factor changes in this branch
        transform_param = (current_transform_param[0], current_transform_param[1], resizing_factor, current_transform_param[3])
        ellipse_lists.append((edited_ellipse, transform_param, 4))
    ## reset resizing factor; resizing is progressive
    resizing_factor = 1
    if len(tracking_points) > 0:
        tracking_points.append(tracking_points[-1])
    else:
        xc, yc = edited_ellipse[0]
        tracking_points.append([int(xc), int(yc)])
    tmp_img = original_image.copy()
    transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                       tracking_points,
                                                                                                       ellipse_lists,
                                                                                                       height,
                                                                                                       width,
                                                                                                       edit_status=False)
    return transform_gs_img, ellipse_lists, edited_result_gallery, resizing_factor
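## Resize the initial blob (ellipse_lists[0]), which defines the background
## region, and rebuild the original-preview gallery to match.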
def resize_start_blob(editable_blob,
                      original_image,
                      tracking_points,
                      ellipse_lists,
                      ori_result_gallery,
                      resizing_factor,
                      resize_type):
    if len(ellipse_lists) == 0:
        gr.Warning("Please generate the blob first")
        return None, ellipse_lists, None, None, 1
    if len(tracking_points) == 0:
        gr.Warning("Please select the blob first")
        return editable_blob, ellipse_lists, None, None, 1
    height, width = original_image.shape[:2]
    ## resize start blob for background
    current_idx = 0
    current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
    if resize_type == 0:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 0)
    elif resize_type == 1:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 1)
    elif resize_type == 2:
        edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 2)
    transform_param = current_transform_param
    ellipse_lists[0] = (edited_ellipse, transform_param, 0)
    ## reset resizing factor; resizing is progressive
    resizing_factor = 1
    tmp_img = original_image.copy()
    transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                       tracking_points,
                                                                                                       ellipse_lists,
                                                                                                       height,
                                                                                                       width,
                                                                                                       edit_status=False)
    ## ori_result_gallery
    gt_i_ellipse_img_path, masked_image_path, mask_image_path, ellipse_mask_path, ellipse_masked_image_path = ori_result_gallery
    masked_image = Image.open(masked_image_path[0])
    mask_image = Image.open(mask_image_path[0])
    ## new ellipse mask
    current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
    new_ellipse_mask_img = get_mask_from_ellipse(current_ellipse, height, width)
    new_ellipse_masked_img = composite_mask_and_image(new_ellipse_mask_img, tmp_img)
    gt_i_ellipse = vis_gt_ellipse_from_ellipse(tmp_img.copy(),
                                               current_ellipse,
                                               color=[0, 255, 0])
    new_gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
    ori_result_gallery = [new_gt_i_ellipse_img, masked_image, mask_image, new_ellipse_mask_img, new_ellipse_masked_img]
    return transform_gs_img, ellipse_lists, edited_result_gallery, ori_result_gallery, resizing_factor
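## Rotate the current blob by the slider value, then snap the slider back to
## 0 (rotation, like resizing, is progressive).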
def rotate_blob(editable_blob,
                original_image,
                tracking_points,
                ellipse_lists,
                rotation_degree):
    if len(ellipse_lists) == 0:
        gr.Warning("Please generate the blob first")
        return None, ellipse_lists, None, 0
    if len(tracking_points) == 0:
        gr.Warning("Please select the blob first")
        return editable_blob, ellipse_lists, None, 0
    height, width = original_image.shape[:2]
    current_idx = -1
    current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
    edited_ellipse, rotation_degree = rotate_blob_func(current_ellipse, rotation_degree)
    transform_param = (current_transform_param[0], current_transform_param[1], current_transform_param[2], rotation_degree)
    ellipse_lists.append((edited_ellipse, transform_param, 5))
    rotation_degree = 0
    if len(tracking_points) > 0:
        tracking_points.append(tracking_points[-1])
    else:
        xc, yc = edited_ellipse[0]
        tracking_points.append([int(xc), int(yc)])
    tmp_img = original_image.copy()
    transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                       tracking_points,
                                                                                                       ellipse_lists,
                                                                                                       height,
                                                                                                       width,
                                                                                                       edit_status=False)
    return transform_gs_img, ellipse_lists, edited_result_gallery, rotation_degree
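## "Remove Blob" checkbox handler: when checked, inflate the initial blob by
## 1.2x so it safely surrounds the object to remove; when unchecked, restore
## the 1.0x size.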
def remove_blob_box_func(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, remove_blob_box):
    if remove_blob_box:
        return resize_start_blob(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, 1.2, 0)
    else:
        return resize_start_blob(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, 1.0, 0)
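## Manually place the initial blob for compositional generation. The textbox
## holds five normalized numbers "[xc, yc, d1, d2, angle]": the center as a
## fraction of width/height, the axis lengths as a fraction of the image
## diagonal, and the short-axis angle in degrees. For example, on a 512x512
## image, "[0.5, 0.5, 0.2, 0.2, 180]" is a circle of diameter
## 0.2 * sqrt(512^2 + 512^2) ~ 145 px centered in the image.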
def set_init_ellipse(original_image, original_mask, edited_result_gallery, ellipse_lists, tracking_points, editable_blob, ori_result_gallery, init_ellipse_parameter):
    ## if init_ellipse_parameter is not None, use the manual initial ellipse
    if init_ellipse_parameter is not None and init_ellipse_parameter != "":
        # Parse a string like '[0.5,0.5,0.2,0.2,180]' (json.loads is safer than eval for user input)
        params = json.loads(init_ellipse_parameter)
        normalized_xc, normalized_yc, normalized_d1, normalized_d2, angle_clockwise_short_axis = params
        height, width = original_image.shape[:2]
        max_length = np.sqrt(height**2 + width**2)
        ellipse_zero = ((width / 2, height / 2), (1e-5, 1e-5), 0)
        ellipse = ((normalized_xc * width, normalized_yc * height), (normalized_d1 * max_length, normalized_d2 * max_length), angle_clockwise_short_axis)
        original_mask = np.array(get_mask_from_ellipse(ellipse, height, width))
        original_mask = np.stack([original_mask, original_mask, original_mask], axis=-1)
        ellipse_init = (ellipse_zero, (1, 1, 1, 0), 0)
        ellipse_next = (ellipse, (1, 1, 1, 0), 0)
        if len(ellipse_lists) == 0:
            ellipse_lists.append(ellipse_init)
            ellipse_lists.append(ellipse_next)
        else:
            ellipse_lists = clear_ellipse_lists(ellipse_lists)
            ellipse_lists.append(ellipse_init)
            ellipse_lists.append(ellipse_next)
        tmp_img = original_image.copy()
        ## seed tracking points with the (x, y) centers of the zero and target ellipses
        tracking_points = [[int(ellipse_init[0][0][0]), int(ellipse_init[0][0][1])], [int(ellipse_next[0][0][0]), int(ellipse_next[0][0][1])]]
        transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
                                                                                                           tracking_points,
                                                                                                           ellipse_lists,
                                                                                                           height,
                                                                                                           width,
                                                                                                           edit_status=False)
        ## plot masked image
        masked_image = composite_mask_and_image(original_mask, original_image)
        mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("L")
        ## plot ellipse
        gt_i_ellipse = vis_gt_ellipse_from_ellipse(original_image.copy(),
                                                   ellipse,
                                                   color=[0, 255, 0])
        gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
        ellipse_mask = get_mask_from_ellipse(ellipse, height, width)
        ellipse_masked_image = composite_mask_and_image(ellipse_mask, original_image)
        ori_result_gallery = [gt_i_ellipse_img, masked_image, mask_image, ellipse_mask, ellipse_masked_image]
        return transform_gs_img, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, None
| gr.Warning("Please set the valid initial ellipse first") | |
| return editable_blob, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, "[0.5, 0.5, 0.2, 0.2, 180]" | |
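## Normalize a user-uploaded object image to 512x512: scale so the shorter
## side becomes 512, then center-crop. For example, a 1024x768 upload is
## resized to 683x512 and cropped to the central 512x512 window.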
def upload_object_image(object_image, edited_result_gallery, remove_blob_box):
    if not edited_result_gallery:
        raise gr.Error("Please generate the blob first")
    else:
        # Resize and center-crop to 512x512
        h, w = object_image.shape[:2]
        # First resize so the shortest side is 512
        scale = 512 / min(h, w)
        new_h = int(h * scale)
        new_w = int(w * scale)
        object_image = cv2.resize(object_image, (new_w, new_h))
        # Then crop the central 512x512 window
        h, w = object_image.shape[:2]
        start_y = (h - 512) // 2
        start_x = (w - 512) // 2
        object_image = object_image[start_y:start_y + 512, start_x:start_x + 512]
        object_image_gallery = [object_image]
        remove_blob_box = False
        return object_image_gallery, remove_blob_box
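## ----------------------------------------------------------------------------
## Gradio UI. The layout follows the four workflow steps: (1) upload and
## segment, (2) generate the blob, (3) edit the blob, (4) run generation.
## The components created here are wired to the callbacks above at the end of
## this block.
## ----------------------------------------------------------------------------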
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)
            gr.Markdown(descriptions)
    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    resize_blob_maintain_aspect_ratio_state = gr.State(value=0)
    resize_blob_along_long_axis_state = gr.State(value=1)
    resize_blob_along_short_axis_state = gr.State(value=2)
    selected_points = gr.State([])
    tracking_points = gr.State([])
    ellipse_lists = gr.State([])
    with gr.Row():
        with gr.Column():
            with gr.Column(elem_id="Input"):
                gr.Markdown("## **Step 1: Upload an image and click to segment the object**", show_label=False)
                with gr.Row():
                    input_image = gr.Image(type="numpy", label="input", scale=2, height=576, interactive=True)
                with gr.Row(elem_id="Seg"):
                    undo_seg_button = gr.Button('🔙 Undo Seg', elem_id="undo_btnSEG", scale=1)
                gr.Markdown("## **Step 2: Input the scene prompt and 🎩 generate the blob**", show_label=False)
                scene_prompt = gr.Textbox(label="Scene Prompt", value="Fill image using foreground and background.")
                generate_blob_button = gr.Button("🎩 Generate Blob", elem_id="btn")
                gr.Markdown("### 💡 Hint: Adjust the control strength and control timesteps range to balance appearance and flexibility", show_label=False)
                blobnet_control_strength = gr.Slider(label="🎚️ Control Strength:", minimum=0, maximum=2.5, value=1.6, step=0.01)
                with gr.Row():
                    blobnet_control_guidance_start = gr.Slider(label="Blobnet Control Timestep Start", minimum=0, maximum=1, step=0.01, value=0)
                    blobnet_control_guidance_end = gr.Slider(label="Blobnet Control Timestep End", minimum=0, maximum=1, step=0.01, value=0.9)
                gr.Markdown("### Click to adjust the diffusion sampling options 👇", show_label=False)
                with gr.Accordion("Diffusion Options", open=False, elem_id="accordion1"):
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=1248464818, scale=2
                    )
                    num_samples = gr.Slider(
                        label="Num samples", minimum=0, maximum=4, step=1, value=2
                    )
                    with gr.Group():
                        with gr.Row():
                            guidance_scale = gr.Slider(label="CFG scale", minimum=1, maximum=12, step=0.1, value=7.5)
                            num_inference_steps = gr.Slider(label="NFE", minimum=1, maximum=100, step=1, value=50)
        with gr.Column():
            gr.Markdown("### Click to expand more previews 👇", show_label=False)
            with gr.Row():
                with gr.Accordion("More Previews", open=False, elem_id="accordion2"):
                    with gr.Row():
                        with gr.Column():
                            with gr.Tab(elem_classes="feedback", label="Object Image"):
                                object_image_gallery = gr.Gallery(label='Object Image', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
                        with gr.Column():
                            with gr.Tab(elem_classes="feedback", label="Original Preview"):
                                ori_result_gallery = gr.Gallery(label='Original Preview', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
            gr.Markdown("## **Step 3: Edit the blob, such as move/resize/remove the blob**", show_label=False)
            with gr.Row():
                with gr.Column():
                    with gr.Tab(elem_classes="feedback", label="Editable Blob"):
                        editable_blob = gr.Image(label="Editable Blob", height=320, interactive=False, container=True)
                with gr.Column():
                    with gr.Tab(elem_classes="feedback", label="Edited Preview"):
                        edited_result_gallery = gr.Gallery(label='Edited Preview', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
            gr.Markdown("### Click to adjust the target blob size 👇", show_label=False)
            with gr.Row():
                with gr.Group():
                    resize_blob_slider_maintain_aspect_ratio = gr.Slider(label="Resize Blob (Maintain Aspect Ratio)", minimum=0.1, maximum=2, step=0.05, value=1)
                    with gr.Row():
                        undo_blob_button = gr.Button('🔙 Undo Blob', elem_id="undo_btnBlob", scale=1)
                        reset_blob_button = gr.Button('🔄 Reset Blob', elem_id="reset_btnBlob", scale=1)
            gr.Markdown("### Click to adjust the initial blob size to ensure it surrounds the object 👇", show_label=False)
            with gr.Group():
                with gr.Row():
                    resize_init_blob_slider = gr.Slider(label="Resize Initial Blob (Maintain Aspect Ratio)", minimum=0.1, maximum=2, step=0.05, value=1, scale=4)
                with gr.Row():
                    remove_blob_box = gr.Checkbox(label="Remove Blob", value=False, scale=1)
            gr.Markdown("### Click to achieve more edit types, such as single-sided resize, composition, etc. 👇", show_label=False)
| with gr.Accordion("More Edit Types", open=False, elem_id="accordion3"): | |
| gr.Markdown("### slide to achieve single-sided resize and rotation", show_label=False) | |
| with gr.Group(): | |
| with gr.Row(): | |
| resize_blob_slider_along_long_axis = gr.Slider(label="Resize Blob (Along Long Axis)", minimum=0, maximum=2, step=0.05, value=1) | |
| resize_blob_slider_along_short_axis = gr.Slider(label="Resize Blob (Along Short Axis)", minimum=0, maximum=2, step=0.05, value=1) | |
| with gr.Row(): | |
| rotation_blob_slider = gr.Slider(label="Rotate Blob (Clockwise)", minimum=-180, maximum=180, step=1, value=0) | |
| gr.Markdown("### slide to adjust the initial blob (single-sided)", show_label=False) | |
| with gr.Group(): | |
| with gr.Row(): | |
| resize_init_blob_slider_long_axis = gr.Slider(label="Resize Initial Blob (Long Axis)", minimum=0, maximum=2, step=0.01, value=1) | |
| resize_init_blob_slider_short_axis = gr.Slider(label="Resize Initial Blob (Short Axis)", minimum=0, maximum=2, step=0.01, value=1) | |
| gr.Markdown("### 🎨 Click to set the initial blob and upload object image for compositional generation👇", show_label=False) | |
| with gr.Accordion("Compositional Generation", open=False, elem_id="accordion5"): | |
| with gr.Row(): | |
| init_ellipse_parameter = gr.Textbox(label="Initial Ellipse", value="[0.5, 0.5, 0.2, 0.2, 180]", scale=4) | |
| init_ellipse_button = gr.Button("Set Initial Ellipse", elem_id="set_init_ellipse_btn", scale=1) | |
| with gr.Row(elem_id="Image"): | |
| with gr.Tab(elem_classes="feedback1", label="User-specified Object Image"): | |
| init_object_image = gr.Image(type="numpy", label="User-specified Object Image", height=320) | |
| gr.Markdown("## **Step 4: 🚀 Run Generation**", show_label=False) | |
| run_button = gr.Button("🚀 Run Generation",elem_id="btn") | |
| with gr.Row(): | |
| with gr.Tab(elem_classes="feedback", label="Results"): | |
| results_gallery = gr.Gallery(label='Results', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True) | |
| eg_index = gr.Textbox(label="Example Index", value="", visible=False) | |
    with gr.Row():
        examples_inputs = [
            input_image,
            scene_prompt,
            blobnet_control_strength,
            blobnet_control_guidance_start,
            blobnet_control_guidance_end,
            seed,
            eg_index,
        ]
        examples_outputs = [
            object_image_gallery,
            ori_result_gallery,
            editable_blob,
            edited_result_gallery,
            results_gallery,
            ellipse_lists,
            tracking_points,
            original_image,
            remove_blob_box,
        ]
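        ## Load a cached example by index: images are reloaded from disk and
        ## mutable state is deep-copied so repeated clicks on the same example
        ## do not mutate the shared cached lists.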
        def process_example(input_image,
                            scene_prompt,
                            blobnet_control_strength,
                            blobnet_control_guidance_start,
                            blobnet_control_guidance_end,
                            seed,
                            eg_index):
            eg_index = int(eg_index)
            # Force reload images from disk each time
            object_image_gallery = [Image.open(path).copy() for path in OBJECT_IMAGE_GALLERY[eg_index]]
            ori_result_gallery = [Image.open(path).copy() for path in ORI_RESULT_GALLERY[eg_index]]
            editable_blob = Image.open(EDITABLE_BLOB[eg_index]).copy()
            edited_result_gallery = [Image.open(path).copy() for path in EDITED_RESULT_GALLERY[eg_index]]
            results_gallery = list(RESULTS_GALLERY[eg_index])  # paths only
            # Deep copy mutable data structures
            ellipse_lists = copy.deepcopy(ELLIPSE_LISTS[eg_index])
            tracking_points = copy.deepcopy(TRACKING_POINTS[eg_index])
            # Force reload input image
            original_image = np.array(Image.open(INPUT_IMAGE[eg_index]).copy())
            remove_blob_box = REMOVE_STATE[eg_index]
            return object_image_gallery, ori_result_gallery, editable_blob, edited_result_gallery, results_gallery, ellipse_lists, tracking_points, original_image, remove_blob_box

        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=examples_inputs,
            outputs=examples_outputs,
            fn=process_example,
            examples_per_page=10,
            cache_examples=False,
            run_on_click=True,
        )
    with gr.Row():
        gr.Markdown(citation)
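    ## ------------------------------------------------------------------------
    ## Event wiring: connect the UI components to the callbacks defined above.
    ## initialize_img (defined earlier) resets every component listed in
    ## initial_output whenever a new image is uploaded.
    ## ------------------------------------------------------------------------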
    ## initial
    initial_output = [
        input_image,
        original_image,
        editable_blob,
        selected_points,
        tracking_points,
        ellipse_lists,
        ori_result_gallery,
        object_image_gallery,
        edited_result_gallery,
        results_gallery,
        blobnet_control_strength,
        blobnet_control_guidance_start,
        blobnet_control_guidance_end,
        resize_blob_slider_maintain_aspect_ratio,
        resize_blob_slider_along_long_axis,
        resize_blob_slider_along_short_axis,
        rotation_blob_slider,
        resize_init_blob_slider,
        resize_init_blob_slider_long_axis,
        resize_init_blob_slider_short_axis,
        init_ellipse_parameter,
        init_object_image,
        remove_blob_box,
    ]
    input_image.upload(
        initialize_img,
        [input_image],
        initial_output
    )
    ## select point
    input_image.select(
        get_point,
        [input_image, selected_points],
        [input_image, original_mask],
    )
    undo_seg_button.click(
        undo_seg_points,
        [original_image, selected_points],
        [input_image, original_mask]
    )
    ## blob image and tracking points: move
    editable_blob.select(
        add_tracking_points,
        [original_image, tracking_points, ellipse_lists],
        [editable_blob, tracking_points, ellipse_lists, edited_result_gallery]
    )
    ## undo, reset and save blob
    undo_blob_button.click(
        undo_blob_points,
        [original_image, tracking_points, ellipse_lists],
        [editable_blob, tracking_points, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider]
    )
    reset_blob_button.click(
        reset_blob_points,
        [original_image, tracking_points, ellipse_lists],
        [editable_blob, tracking_points, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider]
    )
    ## generate blob
    generate_blob_button.click(
        fn=generate_blob,
        inputs=[original_image, original_mask, selected_points, ellipse_lists],
        outputs=[editable_blob, ori_result_gallery, object_image_gallery, ellipse_lists, tracking_points, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image],
    )
    ## resize blob
    resize_blob_slider_maintain_aspect_ratio.release(
        resize_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_maintain_aspect_ratio, resize_blob_maintain_aspect_ratio_state, edited_result_gallery, remove_blob_box],
        [editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio]
    )
    resize_blob_slider_along_long_axis.release(
        resize_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_along_long_axis, resize_blob_along_long_axis_state, edited_result_gallery, remove_blob_box],
        [editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_along_long_axis]
    )
    resize_blob_slider_along_short_axis.release(
        resize_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_along_short_axis, resize_blob_along_short_axis_state, edited_result_gallery, remove_blob_box],
        [editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_along_short_axis]
    )
    ## rotate blob
    rotation_blob_slider.release(
        rotate_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, rotation_blob_slider],
        [editable_blob, ellipse_lists, edited_result_gallery, rotation_blob_slider]
    )
    remove_blob_box.change(
        remove_blob_box_func,
        [editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, remove_blob_box],
        [editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_blob_slider_maintain_aspect_ratio]
    )
    ## resize init blob
    resize_init_blob_slider.release(
        resize_start_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider, resize_blob_maintain_aspect_ratio_state],
        [editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider]
    )
    resize_init_blob_slider_long_axis.release(
        resize_start_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider_long_axis, resize_blob_along_long_axis_state],
        [editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider_long_axis]
    )
    resize_init_blob_slider_short_axis.release(
        resize_start_blob,
        [editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider_short_axis, resize_blob_along_short_axis_state],
        [editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider_short_axis]
    )
    ## set initial ellipse
    init_ellipse_button.click(
        set_init_ellipse,
        inputs=[original_image, original_mask, edited_result_gallery, ellipse_lists, tracking_points, editable_blob, ori_result_gallery, init_ellipse_parameter],
        outputs=[editable_blob, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, init_ellipse_parameter]
    )
    ## upload user-specified object image
    init_object_image.upload(
        upload_object_image,
        [init_object_image, edited_result_gallery, remove_blob_box],
        [object_image_gallery, remove_blob_box]
    )
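    ## Inputs for the final generation call; the components after the
    ## "for save" marker below are passed along so the run can be saved and
    ## reproduced later.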
    ## run BlobEdit
    ips = [
        original_image,
        scene_prompt,
        ori_result_gallery,
        object_image_gallery,
        edited_result_gallery,
        ellipse_lists,
        blobnet_control_strength,
        blobnet_control_guidance_start,
        blobnet_control_guidance_end,
        remove_blob_box,
        num_samples,
        seed,
        guidance_scale,
        num_inference_steps,
        # for save
        editable_blob,
        resize_blob_slider_maintain_aspect_ratio,
        resize_blob_slider_along_long_axis,
        resize_blob_slider_along_short_axis,
        rotation_blob_slider,
        resize_init_blob_slider,
        resize_init_blob_slider_long_axis,
        resize_init_blob_slider_short_axis,
        tracking_points,
    ]
    run_button.click(
        run_function,
        ips,
        [results_gallery]
    )

## If you hit a localhost access error, try launching with an explicit host and port:
# demo.launch(server_name="0.0.0.0", server_port=12346)
demo.launch()