diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..7442d2d39efdf3870535ae1e130b6743be9c33ec 100644 --- a/.gitattributes +++ b/.gitattributes @@ -31,5 +31,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.wasm filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cce2271ec487a5debb14ea9051db762035caa45a --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +.vscode/ +docs/ +debug_images/ +images/ +result/ +vots/ +vots.py diff --git a/README.md b/README.md index 7196705a2587a95a5500c163d4b8bdfce39cc1fe..ac470bb9a3137ce77162340c20255116d87a0ae9 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,74 @@ + + + + --- -title: ROSE -emoji: 👁 -colorFrom: green -colorTo: indigo +title: ROSE Awesome Space +emoji: 🚀 +colorFrom: blue +colorTo: pink sdk: gradio -sdk_version: 5.34.2 +sdk_version: 4.15.0 app_file: app.py pinned: false -license: apache-2.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +## Get Started + +1. Install ProPainter Dependencies + + You can follow the [Dependencies and Installation](https://github.com/Luo-Yihang/ProPainter-pr/tree/dev_yihang#dependencies-and-installation). + +2. Install Demo Dependencies +```shell +cd web-demos/hugging_face + +# install python dependencies +pip3 install -r requirements.txt + +# Run the demo +python app.py +``` + +## Usage Guidance +* Step 1: Upload your video and click the `Get video info` button. + ![Step 1](./assets/step1.png) + +* Step 2: + 1. *[Optional]* Specify the tracking period for the currently added mask by dragging the `Track start frame` or `Track end frame`. + 2. Click the image on the left to select the mask area. + 3. - Click `Add mask` if you are satisfied with the mask, or + - *[Optional]* Click `Clear clicks` if you want to reselect the mask area, or + - *[Optional]* Click `Remove mask` to remove all masks. + 4. *[Optional]* Go back to step 2.1 to add another mask. + ![Step 2](./assets/step2.png) + +* Step 3: + 1. Click the `Tracking` button to track the masks for the whole video. + 2. *[Optional]* Select the ProPainter parameters in the `ProPainter Parameters` dropdown. + 3. Then click `Inpainting` to get the inpainting results. + ![Step 3](./assets/step3.png) + +*You can always refer to the `Highlighted Text` box on the page for guidance on the next step!* + + +## Citation +If you find our repo useful for your research, please consider citing our paper: +```bibtex +@inproceedings{zhou2023propainter, + title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting}, + author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change}, + booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)}, + year={2023} +} +``` + + +## License + +This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
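For reference, the `Inpainting` button in the usage guide is backed by the `WanFunInpaintPipeline` call added in `app.py` later in this diff. The snippet below is a minimal sketch of that call, distilled from `inpaint_video` in `app.py`; it assumes `pipeline`, `frames_resized`, and `masks_resized` have already been prepared as in that function, and is illustrative rather than part of the change set.

```python
# Minimal sketch of the demo's inpainting call (mirrors inpaint_video in app.py).
# Assumes `pipeline` is the WanFunInpaintPipeline built in app.py and that
# `frames_resized` / `masks_resized` are lists of 480x720 uint8 RGB frames and masks.
import torch
from PIL import Image
from torchvision.transforms import functional as TF

video_tensor = torch.stack(
    [TF.to_tensor(Image.fromarray(f)) for f in frames_resized], dim=1
).unsqueeze(0).to("cuda", torch.float16)   # (1, 3, T, 480, 720), values in [0, 1]
mask_tensor = torch.stack(
    [TF.to_tensor(Image.fromarray(m * 255)) for m in masks_resized], dim=1
).unsqueeze(0).to("cuda", torch.float16)   # (1, 1, T, 480, 720), object masks in [0, 1]

with torch.no_grad():
    # app.py trims the clip beforehand so that the frame count T satisfies T % 16 == 1.
    output = pipeline(
        prompt="",
        video=video_tensor,
        mask_video=mask_tensor,
        num_frames=video_tensor.shape[2],
        num_inference_steps=50,
    ).videos                               # (1, 3, T, 480, 720), values in [0, 1]
```

The result is clamped to [0, 1], converted back to uint8 frames, and written out with `torchvision.io.write_video`, exactly as `inpaint_video` does.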
+ + +## Acknowledgements + +The project harnesses the capabilities from [Track Anything](https://github.com/gaomingqi/Track-Anything), [Segment Anything](https://github.com/facebookresearch/segment-anything) and [Cutie](https://github.com/hkchengrex/Cutie). Thanks for their awesome works. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..51252fe0a70ea77a52364eec919de42ae7a95576 --- /dev/null +++ b/app.py @@ -0,0 +1,679 @@ +import sys +sys.path.append("./") + +import os +import json +import time +import psutil +import argparse + +import cv2 +import torch +import torchvision +import numpy as np +import gradio as gr + +from tools.painter import mask_painter +from track_anything import TrackingAnything + +from utils.misc import get_device +from utils.download_util import load_file_from_url +from transformers import AutoTokenizer, AutoModel +from omegaconf import OmegaConf +from torchvision.transforms import functional as TF +from torchvision.utils import save_image +from einops import rearrange +from PIL import Image + +from rose.models import AutoencoderKLWan, CLIPModel, WanT5EncoderModel, WanTransformer3DModel +from rose.pipeline import WanFunInpaintPipeline +from diffusers import FlowMatchEulerDiscreteScheduler + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + return {k: v for k, v in kwargs.items() if k in valid_params} + +# pretrained_model_path = "./models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP" +pretrained_model_path = "alibaba-pai/Wan2.1-Fun-1.3B-InP" +transformer_path = "Kunbyte/ROSE" +# config_path = "configs/wan2.1/wan_civitai.yaml" +config_path = "./configs/wan2.1/wan_civitai.yaml" +config = OmegaConf.load(config_path) + +tokenizer_subpath = config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer') +tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_model_path}/{tokenizer_subpath}") + +text_encoder_subpath = config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder') +text_encoder = WanT5EncoderModel.from_pretrained( + f"{pretrained_model_path}/{text_encoder_subpath}", + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, +) + +transformer_subpath = config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer') +transformer3d = WanTransformer3DModel.from_pretrained( + f"{transformer_path}/{transformer_subpath}", + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), +) + +image_encoder_subpath = config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder') +clip_image_encoder = CLIPModel.from_pretrained(f"{pretrained_model_path}/{image_encoder_subpath}") + +vae_subpath = config['vae_kwargs'].get('vae_subpath', 'vae') +vae = AutoencoderKLWan.from_pretrained( + f"{pretrained_model_path}/{vae_subpath}", + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), +) + +noise_scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) +) + +# tokenizer = AutoTokenizer.from_pretrained( +# os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), +# ) +# text_encoder = WanT5EncoderModel.from_pretrained( +# os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), +# 
additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), +# low_cpu_mem_usage=True, +# ) +# clip_image_encoder = CLIPModel.from_pretrained( +# os.path.join(pretrained_model_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), +# ) +# vae = AutoencoderKLWan.from_pretrained( +# os.path.join(pretrained_model_path, config['vae_kwargs'].get('vae_subpath', 'vae')), +# additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), +# ) +# transformer3d = WanTransformer3DModel.from_pretrained( +# os.path.join(transformer_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), +# transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), +# ) +# noise_scheduler = FlowMatchEulerDiscreteScheduler( +# **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) +# ) + +pipeline = WanFunInpaintPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer3d, + scheduler=noise_scheduler, + clip_image_encoder=clip_image_encoder +).to("cuda", torch.float16) + + +def parse_augment(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default=None) + parser.add_argument('--sam_model_type', type=str, default="vit_h") + parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications") + parser.add_argument('--mask_save', default=False) + args = parser.parse_args() + + if not args.device: + args.device = str(get_device()) + + return args + +# convert points input to prompt state +def get_prompt(click_state, click_input): + inputs = json.loads(click_input) + points = click_state[0] + labels = click_state[1] + for input in inputs: + points.append(input[:2]) + labels.append(input[2]) + click_state[0] = points + click_state[1] = labels + prompt = { + "prompt_type":["click"], + "input_point":click_state[0], + "input_label":click_state[1], + "multimask_output":"True", + } + return prompt + +# extract frames from upload video +def get_frames_from_video(video_input, video_state): + """ + Args: + video_path:str + timestamp:float64 + Return + [[0:nearest_frame], [nearest_frame:], nearest_frame] + """ + video_path = video_input + frames = [] + user_name = time.time() + operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)] + try: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + current_memory_usage = psutil.virtual_memory().percent + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + if current_memory_usage > 90: + operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")] + print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.") + break + else: + break + except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: + print("read_frame_source:{} error. 
{}\n".format(video_path, str(e))) + image_size = (frames[0].shape[0],frames[0].shape[1]) + # initialize video_state + video_state = { + "user_name": user_name, + "video_name": os.path.split(video_path)[-1], + "origin_images": frames, + "painted_images": frames.copy(), + "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), + "logits": [None]*len(frames), + "select_frame_number": 0, + "fps": fps + } + video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True),\ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True, choices=[], value=[]), \ + gr.update(visible=True, value=operation_log), gr.update(visible=True, value=operation_log) + +# get the select frame from gradio slider +def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown): + + # images = video_state[1] + image_selection_slider -= 1 + video_state["select_frame_number"] = image_selection_slider + + # once select a new template frame, set the image in sam + + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + + operation_log = [("",""), ("Select tracking start frame {}. 
Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")] + + return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log + +# set the tracking end frame +def get_end_number(track_pause_number_slider, video_state, interactive_state): + interactive_state["track_end_number"] = track_pause_number_slider + operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")] + + return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log + +# use sam to get the mask +def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): + """ + Args: + template_frame: PIL.Image + point_prompt: flag for positive or negative button click + click_state: [[points], [labels]] + """ + if point_prompt == "Positive": + coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) + interactive_state["positive_click_times"] += 1 + else: + coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) + interactive_state["negative_click_times"] += 1 + + # prompt for sam model + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) + prompt = get_prompt(click_state=click_state, click_input=coordinate) + + mask, logit, painted_image = model.first_frame_click( + image=video_state["origin_images"][video_state["select_frame_number"]], + points=np.array(prompt["input_point"]), + labels=np.array(prompt["input_label"]), + multimask=prompt["multimask_output"], + ) + + video_state["masks"][video_state["select_frame_number"]] = mask + video_state["logits"][video_state["select_frame_number"]] = logit + video_state["painted_images"][video_state["select_frame_number"]] = painted_image + + operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None), + ("[Optional]", "Remove mask"), (": remove all added masks.\n", None), + ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None), + ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)] + return painted_image, video_state, interactive_state, operation_log, operation_log + +def add_multi_mask(video_state, interactive_state, mask_dropdown): + try: + mask = video_state["masks"][video_state["select_frame_number"]] + interactive_state["multi_mask"]["masks"].append(mask) + interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown) + operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")] + except: + operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")] + return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log + +def clear_click(video_state, click_state): + click_state = [[],[]] + template_frame = video_state["origin_images"][video_state["select_frame_number"]] + operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")] + return 
template_frame, click_state, operation_log, operation_log + +def remove_multi_mask(interactive_state, mask_dropdown): + interactive_state["multi_mask"]["mask_names"]= [] + interactive_state["multi_mask"]["masks"] = [] + + operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")] + return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log + +def show_mask(video_state, interactive_state, mask_dropdown): + mask_dropdown.sort() + select_frame = video_state["origin_images"][video_state["select_frame_number"]] + for i in range(len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + mask = interactive_state["multi_mask"]["masks"][mask_number] + select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) + + operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")] + return select_frame, operation_log, operation_log + +# tracking vos +def vos_tracking_video(video_state, interactive_state, mask_dropdown): + operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")] + model.cutie.clear_memory() + if interactive_state["track_end_number"]: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] + else: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:] + + if interactive_state["multi_mask"]["masks"]: + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) + for i in range(1,len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) + video_state["masks"][video_state["select_frame_number"]]= template_mask + else: + template_mask = video_state["masks"][video_state["select_frame_number"]] + + fps = float(video_state["fps"]) + # operation error + if len(np.unique(template_mask))==1: + template_mask[0][0]=1 + operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")] + # return video_output, video_state, interactive_state, operation_error + masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) + # clear GPU memory + model.cutie.clear_memory() + + if interactive_state["track_end_number"]: + video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks + video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits + video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images + else: + video_state["masks"][video_state["select_frame_number"]:] = masks + video_state["logits"][video_state["select_frame_number"]:] = logits + video_state["painted_images"][video_state["select_frame_number"]:] = painted_images + + video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video + interactive_state["inference_times"] += 
1 + + print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"], + interactive_state["positive_click_times"]+interactive_state["negative_click_times"], + interactive_state["positive_click_times"], + interactive_state["negative_click_times"])) + + #### shanggao code for mask save + if interactive_state["mask_save"]: + if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): + os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0])) + i = 0 + print("save mask") + for mask in video_state["masks"]: + np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask) + i+=1 + # save_mask(video_state["masks"], video_state["video_name"]) + #### shanggao code for mask save + return video_output, video_state, interactive_state, operation_log, operation_log + +def inpaint_video(video_state, *_): + operation_log = [("", ""), ("Inpainting finished!", "Normal")] + + # import pdb;pdb.set_trace() + frames = video_state["origin_images"] + masks = video_state["masks"] + # masks = masks * 255 + fps = int(video_state["fps"]) + + total_frames = len(frames) + target_frame_count = (total_frames - 1) // 16 * 16 + 1 + frames = frames[:target_frame_count] + masks = masks[:target_frame_count] + + frames_resized = [cv2.resize(frame, (720, 480), interpolation=cv2.INTER_CUBIC) for frame in frames] + masks_resized = [cv2.resize(mask, (720, 480), interpolation=cv2.INTER_CUBIC) for mask in masks] + + with torch.no_grad(): + video_tensor = torch.stack([TF.to_tensor(Image.fromarray(f)) for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.float16) + mask_tensor = torch.stack([TF.to_tensor(Image.fromarray(m*255)) for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.float16) + #video_tensor = torch.stack([torch.from_numpy(f).float() for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16) + #mask_tensor = torch.stack([torch.from_numpy(m).float() for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16) + + output = pipeline( + prompt="", + video=video_tensor, + mask_video=mask_tensor, + num_frames=video_tensor.shape[2], + num_inference_steps=50 + ).videos + + output = output.clamp(0, 1).cpu() + output_np = (output[0].permute(1, 2, 3, 0).numpy() * 255).astype(np.uint8) + + output_path = f"./result/inpaint/{video_state['video_name']}" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + torchvision.io.write_video(output_path, torch.from_numpy(output_np), fps=fps, video_codec="libx264") + + return output_path, operation_log, operation_log + + +# generate video after vos inference +def generate_video_from_frames(frames, output_path, fps=30): + """ + Generates a video from a list of frames. + + Args: + frames (list of numpy arrays): The frames to include in the video. + output_path (str): The path to save the generated video. + fps (int, optional): The frame rate of the output video. Defaults to 30. 
+ """ + frames = torch.from_numpy(np.asarray(frames)) + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + fps = int(fps) + # import pdb;pdb.set_trace() + torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264") + return output_path + +def restart(): + operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")] + return { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + }, { + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + }, [[],[]], None, None, None, \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \ + gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log) + + +# args, defined in track_anything.py +args = parse_augment() +pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/' +sam_checkpoint_url_dict = { + 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" +} +checkpoint_fodler = os.path.join('.', 'weights') + +sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler) +cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_fodler) +# propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_fodler) +# raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_fodler) +# flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_fodler) + +# initialize sam, cutie, propainter models +model = TrackingAnything(sam_checkpoint, cutie_checkpoint, args) + + +title = r"""

ROSE: Remove Objects with Side Effects in Videos

""" + +description = r""" +
+Official Gradio demo for Remove Objects with Side Effects in Videos.
+🔥 ROSE is a robust inpainting algorithm.
+🤗 Try to drop your video, add the masks, and get the inpainting results!
+""" + +css = """ +.gradio-container {width: 85% !important; margin: 0 auto !important;} +.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important} +button {border-radius: 8px !important;} +.add_button {background-color: #4CAF50 !important;} +.remove_button {background-color: #f44336 !important;} +.mask_button_group {gap: 10px !important;} +.video {height: 300px !important;} +.image {height: 300px !important;} +.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;} +.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;} +.margin_center {width: 50% !important; margin: auto !important;} +.jc_center {justify-content: center !important;} +body { + display: flex; + justify-content: center; + align-items: center; + min-height: 100vh; + margin: 0; +} +""" + +with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface: + click_state = gr.State([[],[]]) + + interactive_state = gr.State({ + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + } + ) + + video_state = gr.State( + { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + } + ) + + gr.Markdown(title) + gr.Markdown(description) + + with gr.Column(): + # input video + gr.Markdown("## Step1: Upload video") + with gr.Row(equal_height=True): + with gr.Column(scale=2): + video_input = gr.Video(elem_classes="video") + extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") + with gr.Column(scale=2): + run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")], + color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"}) + video_info = gr.Textbox(label="Video Info") + + + # add masks + step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False) + with gr.Row(equal_height=True): + with gr.Column(scale=2): + template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image") + image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False) + track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False) + with gr.Column(scale=2, elem_classes="jc_center"): + run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")], + color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"}) + with gr.Row(): + with gr.Column(scale=2, elem_classes="mask_button_group"): + clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False) + remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button") + Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button") + point_prompt = gr.Radio( + choices=["Positive", "Negative"], + value="Positive", + label="Point prompt", + interactive=True, + visible=False, + min_width=100, + scale=1) + 
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False) + + # output video + step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False) + with gr.Row(equal_height=True): + with gr.Column(scale=2): + tracking_video_output = gr.Video(visible=False, elem_classes="video") + tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center") + with gr.Column(scale=2): + inpaiting_video_output = gr.Video(visible=False, elem_classes="video") + inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center") + + # first step: get the video information + extract_frames_button.click( + fn=get_frames_from_video, + inputs=[ + video_input, video_state + ], + outputs=[video_state, video_info, template_frame, + image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame, + tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2] + ) + + # second step: select images from slider + image_selection_slider.release(fn=select_template, + inputs=[image_selection_slider, video_state, interactive_state], + outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image") + track_pause_number_slider.release(fn=get_end_number, + inputs=[track_pause_number_slider, video_state, interactive_state], + outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image") + + # click select image to get mask using sam + template_frame.select( + fn=sam_refine, + inputs=[video_state, point_prompt, click_state, interactive_state], + outputs=[template_frame, video_state, interactive_state, run_status, run_status2] + ) + + # add different mask + Add_mask_button.click( + fn=add_multi_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2] + ) + + remove_mask_button.click( + fn=remove_multi_mask, + inputs=[interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, run_status, run_status2] + ) + + # tracking video from select image and mask + tracking_video_predict_button.click( + fn=vos_tracking_video, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2] + ) + + # inpaint video from select image and mask + inpaint_video_predict_button.click( + fn=inpaint_video, + #inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown], + inputs=[video_state, mask_dropdown], + outputs=[inpaiting_video_output, run_status, run_status2] + ) + + # click to get mask + mask_dropdown.change( + fn=show_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[template_frame, run_status, run_status2] + ) + + # clear input + video_input.change( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + tracking_video_output, inpaiting_video_output, + template_frame, + tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click, + Add_mask_button, template_frame, tracking_video_predict_button, 
tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2 + ], + queue=False, + show_progress=False) + + video_input.clear( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + tracking_video_output, inpaiting_video_output, + template_frame, + tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click, + Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2 + ], + queue=False, + show_progress=False) + + # points clear + clear_button_click.click( + fn = clear_click, + inputs = [video_state, click_state,], + outputs = [template_frame,click_state, run_status, run_status2], + ) + + # set example + gr.Markdown("## Examples") + gr.Examples( + examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]], + inputs=[video_input], + ) + # gr.Markdown(article) + +# iface.queue(concurrency_count=1) +iface.queue() +iface.launch(debug=True) \ No newline at end of file diff --git a/assets/rose_logo.png b/assets/rose_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..8385233fcf98549a0e38e2b321bd079833237ce5 --- /dev/null +++ b/assets/rose_logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf454f99eaabcece90cb664c39a45d17b58f8201eba8b220fa155ac22a014c4a +size 71765 diff --git a/assets/step1.png b/assets/step1.png new file mode 100644 index 0000000000000000000000000000000000000000..b5b389402c6853ec138e101719ec06096b5bb7a1 --- /dev/null +++ b/assets/step1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c93010fa938c75ae671e3aa362205f3f2692783930f67a6623e0a438479e7326 +size 308524 diff --git a/assets/step2.png b/assets/step2.png new file mode 100644 index 0000000000000000000000000000000000000000..04af9bcbc28c46f26bcae95426b0ca32e14bed3e --- /dev/null +++ b/assets/step2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48bdd827f2581da65df1e731325163bd7e5095511ed6e753346f74b71156dc8d +size 457568 diff --git a/assets/step3.png b/assets/step3.png new file mode 100644 index 0000000000000000000000000000000000000000..b5f980a93e575887ee438c285ca41deae7e38f84 --- /dev/null +++ b/assets/step3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8273104c2e558cb0d1edfe91f5e4ca27483815f5e31aad153f199a351c87b12 +size 1143298 diff --git a/configs/wan2.1/wan_civitai.yaml b/configs/wan2.1/wan_civitai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a2747bdc9e86a6812c3ca09d2a8653aaea6cb54 --- /dev/null +++ b/configs/wan2.1/wan_civitai.yaml @@ -0,0 +1,39 @@ +format: civitai +pipeline: Wan +transformer_additional_kwargs: + transformer_subpath: ./ + dict_mapping: + in_dim: in_channels + dim: hidden_size + +vae_kwargs: + vae_subpath: Wan2.1_VAE.pth + temporal_compression_ratio: 4 + spatial_compression_ratio: 8 + +text_encoder_kwargs: + text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth + tokenizer_subpath: google/umt5-xxl + text_length: 512 + vocab: 256384 + dim: 4096 + dim_attn: 4096 + dim_ffn: 10240 + num_heads: 64 + num_layers: 24 + 
num_buckets: 32 + shared_pos: False + dropout: 0.0 + +scheduler_kwargs: + scheduler_subpath: null + num_train_timesteps: 1000 + shift: 5.0 + use_dynamic_shifting: false + base_shift: 0.5 + max_shift: 1.15 + base_image_seq_len: 256 + max_image_seq_len: 4096 + +image_encoder_kwargs: + image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \ No newline at end of file diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py new file mode 100644 index 0000000000000000000000000000000000000000..795663dcdf22fbdc8c3e010918884f40e3d2d40f --- /dev/null +++ b/inpainter/base_inpainter.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +import os +import sys +import cv2 +import numpy as np +import scipy.ndimage +from PIL import Image +from tqdm import tqdm + +import torch +import torchvision + +from model.modules.flow_comp_raft import RAFT_bi +from model.recurrent_flow_completion import RecurrentFlowCompleteNet +from model.propainter import InpaintGenerator +from core.utils import to_tensors + +import warnings +warnings.filterwarnings("ignore") + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def resize_frames(frames, size=None): + if size is not None: + out_size = size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + frames = [f.resize(process_size) for f in frames] + else: + out_size = frames[0].size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + if not out_size == process_size: + frames = [f.resize(process_size) for f in frames] + + return frames, process_size, out_size + + +def read_frame_from_videos(frame_root): + if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + video_name = os.path.basename(frame_root)[:-4] + vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB + frames = list(vframes.numpy()) + frames = [Image.fromarray(f) for f in frames] + fps = info['video_fps'] + else: + video_name = os.path.basename(frame_root) + frames = [] + fr_lst = sorted(os.listdir(frame_root)) + for fr in fr_lst: + frame = cv2.imread(os.path.join(frame_root, fr)) + frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) + fps = None + size = frames[0].size + + return frames, fps, size, video_name + + +def binary_mask(mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + return mask + + +def extrapolation(video_ori, scale): + """Prepares the data for video outpainting. + """ + nFrame = len(video_ori) + imgW, imgH = video_ori[0].size + + # Defines new FOV. + imgH_extr = int(scale[0] * imgH) + imgW_extr = int(scale[1] * imgW) + imgH_extr = imgH_extr - imgH_extr % 8 + imgW_extr = imgW_extr - imgW_extr % 8 + H_start = int((imgH_extr - imgH) / 2) + W_start = int((imgW_extr - imgW) / 2) + + # Extrapolates the FOV for video. + frames = [] + for v in video_ori: + frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8) + frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v + frames.append(Image.fromarray(frame)) + + # Generates the mask for missing region. 
+ masks_dilated = [] + flow_masks = [] + + dilate_h = 4 if H_start > 10 else 0 + dilate_w = 4 if W_start > 10 else 0 + mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8) + + mask[H_start+dilate_h: H_start+imgH-dilate_h, + W_start+dilate_w: W_start+imgW-dilate_w] = 0 + flow_masks.append(Image.fromarray(mask * 255)) + + mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0 + masks_dilated.append(Image.fromarray(mask * 255)) + + flow_masks = flow_masks * nFrame + masks_dilated = masks_dilated * nFrame + + return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr) + + +def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1): + ref_index = [] + if ref_num == -1: + for i in range(0, length, ref_stride): + if i not in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2)) + end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2)) + for i in range(start_idx, end_idx, ref_stride): + if i not in neighbor_ids: + if len(ref_index) > ref_num: + break + ref_index.append(i) + return ref_index + + +def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5): + masks_img = [] + masks_dilated = [] + flow_masks = [] + + for mp in masks: + masks_img.append(Image.fromarray(mp.astype('uint8'))) + + for mask_img in masks_img: + if size is not None: + mask_img = mask_img.resize(size, Image.NEAREST) + mask_img = np.array(mask_img.convert('L')) + + # Dilate 8 pixel so that all known pixel is trustworthy + if flow_mask_dilates > 0: + flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8) + else: + flow_mask_img = binary_mask(mask_img).astype(np.uint8) + + flow_masks.append(Image.fromarray(flow_mask_img * 255)) + + if mask_dilates > 0: + mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8) + else: + mask_img = binary_mask(mask_img).astype(np.uint8) + masks_dilated.append(Image.fromarray(mask_img * 255)) + + if len(masks_img) == 1: + flow_masks = flow_masks * length + masks_dilated = masks_dilated * length + + return flow_masks, masks_dilated + + +class ProInpainter: + def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True): + self.device = device + self.use_half = use_half + if self.device == torch.device('cpu'): + self.use_half = False + + ############################################## + # set up RAFT and flow competition model + ############################################## + self.fix_raft = RAFT_bi(raft_checkpoint, self.device) + + self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint) + for p in self.fix_flow_complete.parameters(): + p.requires_grad = False + self.fix_flow_complete.to(self.device) + self.fix_flow_complete.eval() + + ############################################## + # set up ProPainter model + ############################################## + self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device) + self.model.eval() + + if self.use_half: + self.fix_flow_complete = self.fix_flow_complete.half() + self.model = self.model.half() + + def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10): + """ + Perform Inpainting for video subsets + + Output: + inpainted_frames: numpy array, T, H, W, 3 + """ + + frames = [] + for i in range(len(npframes)): + frames.append(Image.fromarray(npframes[i].astype('uint8'), 
mode="RGB")) + del npframes + + size = frames[0].size + # The ouput size should be divided by 2 so that it can encoded by libx264 + size = (int(ratio*size[0])//2*2, int(ratio*size[1])//2*2) + + frames_len = len(frames) + frames, size, out_size = resize_frames(frames, size) + flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius) + w, h = size + + frames_inp = [np.array(f).astype(np.uint8) for f in frames] + frames = to_tensors()(frames).unsqueeze(0) * 2 - 1 + flow_masks = to_tensors()(flow_masks).unsqueeze(0) + masks_dilated = to_tensors()(masks_dilated).unsqueeze(0) + frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device) + + ############################################## + # ProPainter inference + ############################################## + video_length = frames.size(1) + with torch.no_grad(): + # ---- compute flow ---- + if frames.size(-1) <= 640: + short_clip_len = 12 + elif frames.size(-1) <= 720: + short_clip_len = 8 + elif frames.size(-1) <= 1280: + short_clip_len = 4 + else: + short_clip_len = 2 + + # use fp32 for RAFT + if frames.size(1) > short_clip_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_clip_len): + end_f = min(video_length, f + short_clip_len) + if f == 0: + flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter) + else: + flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + torch.cuda.empty_cache() + + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = self.fix_raft(frames, iters=raft_iter) + torch.cuda.empty_cache() + + if self.use_half: + frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half() + gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) + + # ---- complete flow ---- + flow_length = gt_flows_bi[0].size(1) + if flow_length > subvideo_length: + pred_flows_f, pred_flows_b = [], [] + pad_len = 5 + for f in range(0, flow_length, subvideo_length): + s_f = max(0, f - pad_len) + e_f = min(flow_length, f + subvideo_length + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(flow_length, f + subvideo_length) + pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + flow_masks[:, s_f:e_f+1]) + pred_flows_bi_sub = self.fix_flow_complete.combine_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + pred_flows_bi_sub, + flow_masks[:, s_f:e_f+1]) + + pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e]) + pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + pred_flows_f = torch.cat(pred_flows_f, dim=1) + pred_flows_b = torch.cat(pred_flows_b, dim=1) + pred_flows_bi = (pred_flows_f, pred_flows_b) + else: + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks) + pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks) + torch.cuda.empty_cache() + + # ---- image propagation ---- + masked_frames = frames * (1 - masks_dilated) + subvideo_length_img_prop = min(100, subvideo_length) # ensure a minimum of 100 frames for image propagation + if video_length > subvideo_length_img_prop: + updated_frames, updated_masks = [], [] + pad_len = 10 + 
for f in range(0, video_length, subvideo_length_img_prop): + s_f = max(0, f - pad_len) + e_f = min(video_length, f + subvideo_length_img_prop + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop) + + b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() + pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1]) + prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f], + pred_flows_bi_sub, + masks_dilated[:, s_f:e_f], + 'nearest') + updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \ + prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f] + updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) + + updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + updated_frames = torch.cat(updated_frames, dim=1) + updated_masks = torch.cat(updated_masks, dim=1) + else: + b, t, _, _, _ = masks_dilated.size() + prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest') + updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated + updated_masks = updated_local_masks.view(b, t, 1, h, w) + torch.cuda.empty_cache() + + ori_frames = frames_inp + comp_frames = [None] * video_length + + neighbor_stride = neighbor_length // 2 + if video_length > subvideo_length: + ref_num = subvideo_length // ref_stride + else: + ref_num = -1 + + # ---- feature propagation + transformer ---- + for f in tqdm(range(0, video_length, neighbor_stride)): + neighbor_ids = [ + i for i in range(max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num) + selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] + selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] + selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) + + with torch.no_grad(): + # 1.0 indicates mask + l_t = len(neighbor_ids) + + # pred_img = selected_imgs # results of image propagation + pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) + + pred_img = pred_img.view(-1, 3, h, w) + + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute( + 0, 2, 3, 1).numpy().astype(np.uint8) + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ + + ori_frames[idx] * (1 - binary_masks[i]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + comp_frames[idx] = comp_frames[idx].astype(np.uint8) + + torch.cuda.empty_cache() + + # need to return numpy array, T, H, W, 3 + comp_frames = [cv2.resize(f, out_size) for f in comp_frames] + + return comp_frames diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e5f9a8a4ea08061ca95a057897b8601de94ead7e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +progressbar2 +gdown +gitpython 
+git+https://github.com/cheind/py-thin-plate-spline +hickle +tensorboard +numpy +git+https://github.com/facebookresearch/segment-anything.git +gradio +opencv-python +matplotlib +pyyaml +av +openmim +tqdm +psutil +omegaconf \ No newline at end of file diff --git a/rose/__init__.py b/rose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rose/data/bucket_sampler.py b/rose/data/bucket_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..24b4160f3d2bbadca1d23e90c23e887ea6d15f70 --- /dev/null +++ b/rose/data/bucket_sampler.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import (Generic, Iterable, Iterator, List, Optional, Sequence, + Sized, TypeVar, Union) + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import BatchSampler, Dataset, Sampler + +ASPECT_RATIO_512 = { + '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], + '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] +} +ASPECT_RATIO_RANDOM_CROP_512 = { + '0.42': [320.0, 768.0], '0.5': [352.0, 704.0], + '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0], + '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], + '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0], + '2.0': [704.0, 352.0], '2.4': [768.0, 320.0] +} +ASPECT_RATIO_RANDOM_CROP_PROB = [ + 1, 2, + 4, 4, 4, 4, + 8, 8, 8, + 4, 4, 4, 4, + 2, 1 +] +ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB) + +def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return ratios[closest_ratio], float(closest_ratio) + +def get_image_size_without_loading(path): + with Image.open(path) as img: + return img.size # (width, height) + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. 
+ """ + + data_source: Sized + replacement: bool + + def __init__(self, data_source: Sized, replacement: bool = False, + num_samples: Optional[int] = None, generator=None) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + self._pos_start = 0 + + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + for _ in range(self.num_samples // n): + xx = torch.randperm(n, generator=generator).tolist() + if self._pos_start >= n: + self._pos_start = 0 + print("xx top 10", xx[:10], self._pos_start) + for idx in range(self._pos_start, n): + yield xx[idx] + self._pos_start = (self._pos_start + 1) % n + self._pos_start = 0 + yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + + def __len__(self) -> int: + return self.num_samples + +class AspectRatioBatchImageSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e) + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + video_folder: str = None, + train_data_format: str = "webvid", + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.video_folder = video_folder + self.train_data_format = train_data_format + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + video_dict = self.dataset[idx] + width, more = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + if self.train_data_format == "normal": + video_id, name = video_dict['file_path'], video_dict['text'] + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + else: + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + + # buckets for each aspect ratio + self.current_available_bucket_keys = list(aspect_ratios.keys()) + self.bucket = { + 'image':{ratio: [] for ratio in aspect_ratios}, + 'video':{ratio: [] for ratio in aspect_ratios} + } + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset[idx].get('type', 'image') + if content_type == 'image': + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['image'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + else: + try: + video_dict = self.dataset[idx] + width, height = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + video_id, name = video_dict['file_path'], video_dict['text'] + if self.train_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.train_folder, video_id) + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['video'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] \ No newline at end of file diff --git a/rose/data/dataset_image.py b/rose/data/dataset_image.py new file mode 100644 index 0000000000000000000000000000000000000000..098d49a4044f8daa351cd01b4cb1ec5415412e80 --- /dev/null +++ b/rose/data/dataset_image.py @@ -0,0 +1,76 @@ +import json +import os +import random + 
+import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.utils.data.dataset import Dataset + + +class CC15M(Dataset): + def __init__( + self, + json_path, + video_folder=None, + resolution=512, + enable_bucket=False, + ): + print(f"loading annotations from {json_path} ...") + self.dataset = json.load(open(json_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.enable_bucket = enable_bucket + self.video_folder = video_folder + + resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(resolution[0]), + transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, name = video_dict['file_path'], video_dict['text'] + + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + + pixel_values = Image.open(video_dir).convert("RGB") + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + except Exception as e: + print(e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + else: + pixel_values = np.array(pixel_values) + + sample = dict(pixel_values=pixel_values, text=name) + return sample + +if __name__ == "__main__": + dataset = CC15M( + csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json", + resolution=512, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/rose/data/dataset_image_video.py b/rose/data/dataset_image_video.py new file mode 100644 index 0000000000000000000000000000000000000000..0288432e4503791991389435677f3419a8f3ce11 --- /dev/null +++ b/rose/data/dataset_image_video.py @@ -0,0 +1,589 @@ +import csv +import io +import json +import math +import os +import random +from threading import Thread + +import albumentations +import cv2 +import gc +import numpy as np +import torch +import torchvision.transforms as transforms + +from func_timeout import func_timeout, FunctionTimedOut +from decord import VideoReader +from PIL import Image +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset +from contextlib import contextmanager + +VIDEO_READER_TIMEOUT = 20 + +def get_random_mask(shape, image_start_only=False): + f, c, h, w = shape + mask = torch.zeros((f, 1, h, w), dtype=torch.uint8) + + if not image_start_only: + if f != 1: + mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05]) + else: + mask_index = np.random.choice([0, 1], p = [0.2, 0.8]) + if mask_index == 0: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + 
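+            # mask type 0: mask the same randomly-sized rectangle in every frame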
mask[:, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 1: + mask[:, :, :, :] = 1 + elif mask_index == 2: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:, :, :, :] = 1 + elif mask_index == 3: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:-mask_frame_index, :, :, :] = 1 + elif mask_index == 4: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + + mask_frame_before = np.random.randint(0, f // 2) + mask_frame_after = np.random.randint(f // 2, f) + mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 5: + mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8) + elif mask_index == 6: + num_frames_to_mask = random.randint(1, max(f // 2, 1)) + frames_to_mask = random.sample(range(f), num_frames_to_mask) + + for i in frames_to_mask: + block_height = random.randint(1, h // 4) + block_width = random.randint(1, w // 4) + top_left_y = random.randint(0, h - block_height) + top_left_x = random.randint(0, w - block_width) + mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1 + elif mask_index == 7: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴 + b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴 + + for i in range(h): + for j in range(w): + if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1: + mask[:, :, i, j] = 1 + elif mask_index == 8: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() + for i in range(h): + for j in range(w): + if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2: + mask[:, :, i, j] = 1 + elif mask_index == 9: + for idx in range(f): + if np.random.rand() > 0.5: + mask[idx, :, :, :] = 1 + else: + raise ValueError(f"The mask_index {mask_index} is not define") + else: + if f != 1: + mask[1:, :, :, :] = 1 + else: + mask[:, :, :, :] = 1 + return mask + +class ImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
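+
+    Note: this sampler only groups samples by content type (image vs. video); aspect ratios are not used here.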
+ """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # buckets for each aspect ratio + self.bucket = {'image':[], 'video':[]} + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset.dataset[idx].get('type', 'image') + self.bucket[content_type].append(idx) + + # yield a batch of indices in the same aspect ratio group + if len(self.bucket['video']) == self.batch_size: + bucket = self.bucket['video'] + yield bucket[:] + del bucket[:] + elif len(self.bucket['image']) == self.batch_size: + bucket = self.bucket['image'] + yield bucket[:] + del bucket[:] + +@contextmanager +def VideoReader_contextmanager(*args, **kwargs): + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + +def get_video_reader_batch(video_reader, batch_index): + frames = video_reader.get_batch(batch_index).asnumpy() + return frames + +def resize_frame(frame, target_short_side): + h, w, _ = frame.shape + if h < w: + if target_short_side > h: + return frame + new_h = target_short_side + new_w = int(target_short_side * w / h) + else: + if target_short_side > w: + return frame + new_w = target_short_side + new_h = int(target_short_side * h / w) + + resized_frame = cv2.resize(frame, (new_w, new_h)) + return resized_frame + +class ImageVideoDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.0, + video_length_drop_end=1.0, + enable_inpaint=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. 
+ self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + if video_repeat > 0: + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + + if data_info.get('type', 'image')=='video': + video_id, text = data_info['file_path'], data_info['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
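+                    # frames are now float tensors in [0, 1]; drop the local reader handle
+                    # (VideoReader_contextmanager frees the underlying reader on exit)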
+ del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + return pixel_values, text, 'video' + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + if random.random() < self.text_drop_ratio: + text = '' + return image, text, 'image' + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, name, data_type = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + ref_pixel_values = sample["pixel_values"][0].unsqueeze(0) + if (mask == 1).all(): + ref_pixel_values = torch.ones_like(ref_pixel_values) * -1 + sample["ref_pixel_values"] = ref_pixel_values + + return sample + + +class ImageVideoControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.0, + video_length_drop_end=1.0, + enable_inpaint=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. 
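+        # Same image/video balancing as in ImageVideoDataset above.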
+ self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + if video_repeat > 0: + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if data_info.get('type', 'image')=='video': + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
+ del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + return pixel_values, control_pixel_values, text, "video" + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + + if random.random() < self.text_drop_ratio: + text = '' + + control_image_id = data_info['control_file_path'] + + if self.data_root is None: + control_image_id = control_image_id + else: + control_image_id = os.path.join(self.data_root, control_image_id) + + control_image = Image.open(control_image_id).convert('RGB') + if not self.enable_bucket: + control_image = self.image_transforms(control_image).unsqueeze(0) + else: + control_image = np.expand_dims(np.array(control_image), 0) + return image, control_image, text, 'image' + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, control_pixel_values, name, data_type = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = 
sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + ref_pixel_values = sample["pixel_values"][0].unsqueeze(0) + if (mask == 1).all(): + ref_pixel_values = torch.ones_like(ref_pixel_values) * -1 + sample["ref_pixel_values"] = ref_pixel_values + + return sample diff --git a/rose/data/dataset_video.py b/rose/data/dataset_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c78367d0973ceb1abdcd005947612d16e2480831 --- /dev/null +++ b/rose/data/dataset_video.py @@ -0,0 +1,262 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from threading import Thread + +import albumentations +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from PIL import Image +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +VIDEO_READER_TIMEOUT = 20 + +def get_random_mask(shape): + f, c, h, w = shape + + mask_index = np.random.randint(0, 4) + mask = torch.zeros((f, 1, h, w), dtype=torch.uint8) + if mask_index == 0: + mask[1:, :, :, :] = 1 + elif mask_index == 1: + mask_frame_index = 1 + mask[mask_frame_index:-mask_frame_index, :, :, :] = 1 + elif mask_index == 2: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + mask[:, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 3: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + + mask_frame_before = np.random.randint(0, f // 2) + mask_frame_after = np.random.randint(f // 2, f) + mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1 + else: + raise ValueError(f"The mask_index {mask_index} is not define") + return mask + + +@contextmanager +def VideoReader_contextmanager(*args, **kwargs): + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + + +def get_video_reader_batch(video_reader, batch_index): + frames = video_reader.get_batch(batch_index).asnumpy() + return frames + + +class WebVid10M(Dataset): + def __init__( + self, + csv_path, video_folder, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, is_image=False, + ): + print(f"loading annotations from {csv_path} ...") + with open(csv_path, 'r') as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + 
self.enable_inpaint = enable_inpaint + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if not self.enable_bucket: + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() + + if self.is_image: + pixel_values = pixel_values[0] + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + print("Error info:", e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + if self.enable_inpaint: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name) + else: + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +class VideoDataset(Dataset): + def __init__( + self, + json_path, video_folder=None, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False + ): + print(f"loading annotations from {json_path} ...") + self.dataset = json.load(open(json_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, name = video_dict['file_path'], video_dict['text'] + + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + video_length = len(video_reader) + + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + + try: + sample_args = 
(video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + print("Error info:", e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + if self.enable_inpaint: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name) + else: + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +if __name__ == "__main__": + if 1: + dataset = VideoDataset( + json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json", + sample_size=256, + sample_stride=4, sample_n_frames=16, + ) + + if 0: + dataset = WebVid10M( + csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv", + video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val", + sample_size=256, + sample_stride=4, sample_n_frames=16, + is_image=False, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/rose/dist/__init__.py b/rose/dist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dea4c21bb4f0398efed1d17da98e03520d2c4d13 --- /dev/null +++ b/rose/dist/__init__.py @@ -0,0 +1,43 @@ +import importlib.util + +from .fsdp import shard_model +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + get_world_group, init_distributed_environment, + initialize_model_parallel, set_multi_gpus_devices, + xFuserLongContextAttention) +from .wan_xfuser import usp_attn_forward + +# The pai_fuser is an internally developed acceleration package, which can be used on PAI. +if importlib.util.find_spec("pai_fuser") is not None: + from pai_fuser.core import parallel_magvit_vae + from pai_fuser.core.attention import wan_usp_sparse_attention_wrapper + from . 
import wan_xfuser + + # The simple_wrapper is used to solve the problem about conflicts between cython and torch.compile + def simple_wrapper(func): + def inner(*args, **kwargs): + return func(*args, **kwargs) + return inner + + wan_xfuser.usp_attn_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward)) + usp_attn_forward = simple_wrapper(wan_xfuser.usp_attn_forward) + print("Import PAI VAE Turbo and Sparse Attention") + + from pai_fuser.core.rope import ENABLE_KERNEL, usp_fast_rope_apply_qk + + if ENABLE_KERNEL: + import torch + from .wan_xfuser import rope_apply + + def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs): + if torch.is_grad_enabled(): + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + return q, k + else: + return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs) + + wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk + rope_apply_qk = adaptive_fast_usp_rope_apply_qk + print("Import PAI Fast rope") \ No newline at end of file diff --git a/rose/dist/fsdp.py b/rose/dist/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..555479f9b4e33d5d4bed54913ceab68adceac713 --- /dev/null +++ b/rose/dist/fsdp.py @@ -0,0 +1,43 @@ +# Copyied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +from functools import partial + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy +from torch.distributed.utils import _free_storage + + +def shard_model( + model, + device_id, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + process_group=None, + sharding_strategy=ShardingStrategy.FULL_SHARD, + sync_module_states=True, +): + model = FSDP( + module=model, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=partial( + lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype), + device_id=device_id, + sync_module_states=sync_module_states) + return model + +def free_model(model): + for m in model.modules(): + if isinstance(m, FSDP): + _free_storage(m._handle.flat_param.data) + del model + gc.collect() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/rose/dist/fuser.py b/rose/dist/fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..de12a2bdec6a309f3a83bed1b429d053d3569527 --- /dev/null +++ b/rose/dist/fuser.py @@ -0,0 +1,54 @@ +import importlib.util + +import torch +import torch.distributed as dist + +try: + # The pai_fuser is an internally developed acceleration package, which can be used on PAI. 
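+    # Prefer pai_fuser when it is installed; otherwise fall back to the open-source xfuser package.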
+ if importlib.util.find_spec("pai_fuser") is not None: + import pai_fuser + from pai_fuser.core.distributed import ( + get_sequence_parallel_rank, get_sequence_parallel_world_size, + get_sp_group, get_world_group, init_distributed_environment, + initialize_model_parallel) + from pai_fuser.core.long_ctx_attention import \ + xFuserLongContextAttention + print("Import PAI DiT Turbo") + else: + import xfuser + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, get_world_group, + init_distributed_environment, + initialize_model_parallel) + from xfuser.core.long_ctx_attention import xFuserLongContextAttention + print("Xfuser import sucessful") +except Exception as ex: + get_sequence_parallel_world_size = None + get_sequence_parallel_rank = None + xFuserLongContextAttention = None + get_sp_group = None + get_world_group = None + init_distributed_environment = None + initialize_model_parallel = None + +def set_multi_gpus_devices(ulysses_degree, ring_degree): + if ulysses_degree > 1 or ring_degree > 1: + if get_sp_group is None: + raise RuntimeError("xfuser is not installed.") + dist.init_process_group("nccl") + print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % ( + ulysses_degree, ring_degree, dist.get_rank(), + dist.get_world_size())) + assert dist.get_world_size() == ring_degree * ulysses_degree, \ + "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size() + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(), + ring_degree=ring_degree, + ulysses_degree=ulysses_degree) + # device = torch.device("cuda:%d" % dist.get_rank()) + device = torch.device(f"cuda:{get_world_group().local_rank}") + print('rank=%d device=%s' % (get_world_group().rank, str(device))) + else: + device = "cuda" + return device \ No newline at end of file diff --git a/rose/dist/wan_xfuser.py b/rose/dist/wan_xfuser.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc59c8de2529df114caab61e675f406c1d12bd3 --- /dev/null +++ b/rose/dist/wan_xfuser.py @@ -0,0 +1,111 @@ +import torch +import torch.cuda.amp as amp + +from .fuser import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + init_distributed_environment, initialize_model_parallel, + xFuserLongContextAttention) + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. 
+ """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * + s_per_rank), :, :] + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + + # append to collection + output.append(x_i) + return torch.stack(output) + +def rope_apply_qk(q, k, grid_sizes, freqs): + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + return q, k + +def usp_attn_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16, + t=0): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q, k = rope_apply_qk(q, k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. 
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
\ No newline at end of file
diff --git a/rose/models/__init__.py b/rose/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e7bd02faf879f2824d17138c93a8ac6d73b507
--- /dev/null
+++ b/rose/models/__init__.py
@@ -0,0 +1,6 @@
+from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+from .wan_image_encoder import CLIPModel
+from .wan_text_encoder import WanT5EncoderModel
+from .wan_transformer3d import WanTransformer3DModel
+from .wan_vae import AutoencoderKLWan
diff --git a/rose/models/cache_utils.py b/rose/models/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9209fa7909065a2d4c99c696a704e4bd26856ec
--- /dev/null
+++ b/rose/models/cache_utils.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+
+
+def get_teacache_coefficients(model_name):
+    if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower():
+        return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+    elif "wan2.1-t2v-14b" in model_name.lower():
+        return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+    elif "wan2.1-i2v-14b-480p" in model_name.lower():
+        return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+    elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower():
+        return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+    else:
+        print(f"The model {model_name} is not supported by TeaCache.")
+        return None
+
+
+class TeaCache():
+    """
+    Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
+    the fluctuating differences among model outputs across timesteps, thereby accelerating inference.
+    Please refer to:
+    1. https://github.com/ali-vilab/TeaCache.
+    2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
+    """
+    def __init__(
+        self,
+        coefficients: list[float],
+        num_steps: int,
+        rel_l1_thresh: float = 0.0,
+        num_skip_start_steps: int = 0,
+        offload: bool = True,
+    ):
+        if num_steps < 1:
+            raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
+        if rel_l1_thresh < 0:
+            raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
+        if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
+            raise ValueError(
+                "`num_skip_start_steps` must be greater than or equal to 0 and "
+                f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
+            )
+        self.coefficients = coefficients
+        self.num_steps = num_steps
+        self.rel_l1_thresh = rel_l1_thresh
+        self.num_skip_start_steps = num_skip_start_steps
+        self.offload = offload
+        self.rescale_func = np.poly1d(self.coefficients)
+
+        self.cnt = 0
+        self.should_calc = True
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        # Some pipelines concatenate the unconditional and text-guided branches in a single forward pass.
+        self.previous_residual = None
+        # Other pipelines run the unconditional and text-guided branches as separate forward passes.
+ self.previous_residual_cond = None + self.previous_residual_uncond = None + + @staticmethod + def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor: + rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean() + + return rel_l1_distance.cpu().item() + + def reset(self): + self.cnt = 0 + self.should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.previous_residual_cond = None + self.previous_residual_uncond = None \ No newline at end of file diff --git a/rose/models/diff_mask_predictor.py b/rose/models/diff_mask_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ecd526fbac9725762db412a28996e0ca626782 --- /dev/null +++ b/rose/models/diff_mask_predictor.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config + +class DiffMaskPredictor(nn.Module): + def __init__(self, input_dim=4608, patch_grid=(10, 15, 189), output_grid=(81, 480, 720), hidden_dim=256): + """ + Args: + input_dim (int): concatenated feature dimension, e.g. 1536 * num_selected_layers + patch_grid (tuple): (F_p, H_p, W_p) - patch token grid shape (e.g., from transformer block) + output_grid (tuple): (F, H, W) - final full resolution shape for mask + hidden_dim (int): intermediate conv/linear hidden dim + """ + super().__init__() + self.F_p, self.H_p, self.W_p = patch_grid + self.F, self.H, self.W = output_grid + + self.project = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, x): + """ + Args: + x (Tensor): shape [B, L, D_total], L = F_p H_p W_p + Returns: + Tensor: predicted diff mask, shape [B, 1, F, H, W] + """ + B, L, D = x.shape + assert L == self.F_p * self.H_p * self.W_p, \ + f"Input token length {L} doesn't match patch grid ({self.F_p}, {self.H_p}, {self.W_p})" + + x = self.project(x) # [B, L, 1] + x = x.view(B, 1, self.F_p, self.H_p, self.W_p) # [B, 1, F_p, H_p, W_p] + x = F.interpolate( + x, size=(self.F, self.H, self.W), + mode="trilinear", align_corners=False # upsample to match ground truth resolution + ) + return x # [B, 1, F, H, W] diff --git a/rose/models/wan_image_encoder.py b/rose/models/wan_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..68c442a79cef34081a9e67cdf0d9bf8d3ec67c9f --- /dev/null +++ b/rose/models/wan_image_encoder.py @@ -0,0 +1,553 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
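+# CLIP with an XLM-RoBERTa text tower and a ViT-H/14 vision tower, used here as the image encoder.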
+import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T + +from .wan_transformer3d import attention +from .wan_xlm_roberta import XLMRoberta +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin + + +__all__ = [ + 'XLMRobertaCLIP', + 'clip_xlm_roberta_vit_h_14', + 'CLIPModel', +] + + +def pos_interpolate(pos, seq_len): + if pos.size(1) == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.size(1))) + tar_grid = int(math.sqrt(seq_len)) + n = pos.size(1) - src_grid * src_grid + return torch.cat([ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( + 0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode='bicubic', + align_corners=False).flatten(2).transpose(1, 2) + ], + dim=1) + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + causal=False, + attn_dropout=0.0, + proj_dropout=0.0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = attention(q, k, v, dropout_p=p, causal=self.causal) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Module): + + def __init__(self, dim, mid_dim): + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(dim, mid_dim) + self.fc3 = nn.Linear(mid_dim, dim) + + def forward(self, x): + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, + dim, + mlp_ratio, + num_heads, + post_norm=False, + causal=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + norm_eps=1e-5): + assert activation in ['quick_gelu', 'gelu', 'swi_glu'] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, + proj_dropout) + self.norm2 = LayerNorm(dim, eps=norm_eps) + if activation == 'swi_glu': + self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) + else: + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Module): + + def 
__init__(self, + dim, + mlp_ratio, + num_heads, + activation='gelu', + proj_dropout=0.0, + norm_eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.to_q = nn.Linear(dim, dim) + self.to_kv = nn.Linear(dim, dim * 2) + self.proj = nn.Linear(dim, dim) + self.norm = LayerNorm(dim, eps=norm_eps) + self.mlp = nn.Sequential( + nn.Linear(dim, int(dim * mlp_ratio)), + QuickGELU() if activation == 'quick_gelu' else nn.GELU(), + nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) + + def forward(self, x): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = flash_attention(q, k, v, version=2) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + mlp_ratio=4, + out_dim=512, + num_heads=12, + num_layers=12, + pool_type='token', + pre_norm=True, + post_norm=False, + activation='quick_gelu', + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + if image_size % patch_size != 0: + print( + '[WARNING] image_size is not divisible by patch_size', + flush=True) + assert pool_type in ('token', 'token_fc', 'attn_pool') + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size)**2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, + dim, + kernel_size=patch_size, + stride=patch_size, + bias=not pre_norm) + if pool_type in ('token', 'token_fc'): + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter(gain * torch.randn( + 1, self.num_patches + + (1 if pool_type in ('token', 'token_fc') else 0), dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, + activation, attn_dropout, proj_dropout, norm_eps) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim, eps=norm_eps) + + # head + if pool_type == 'token': + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + elif pool_type == 'token_fc': + self.head = nn.Linear(dim, out_dim) + elif pool_type == 'attn_pool': + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, + proj_dropout, norm_eps) + + def forward(self, x, interpolation=False, use_31_block=False): + b = x.size(0) + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ('token', 'token_fc'): + x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) + if interpolation: + e = 
pos_interpolate(self.pos_embedding, x.size(1)) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + + def __init__(self, **kwargs): + self.out_dim = kwargs.pop('out_dim') + super().__init__(**kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.Sequential( + nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), + nn.Linear(mid_dim, self.out_dim, bias=False)) + + def forward(self, ids): + # xlm-roberta + x = super().forward(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Module): + + def __init__(self, + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + vision_pre_norm=True, + vision_post_norm=False, + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + norm_eps=1e-5): + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_ids): + """ + imgs: [B, 3, H, W] of torch.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of torch.long. + Encoded by data.CLIPTokenizer. 
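+        Returns:
+            A tuple (xi, xt): the vision-transformer features for ``imgs`` and
+            the pooled XLM-RoBERTa text embeddings for ``txt_ids``.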
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + +def _clip(pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding='eos', + dtype=torch.float32, + device='cpu', + **kwargs): + # init a model on device + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if 'siglip' in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = T.Compose([ + T.Resize((model.image_size, model.image_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=mean, std=std) + ]) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, + pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', + **kwargs): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool='token', + activation='gelu', + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) + + +class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + + def __init__(self): + super(CLIPModel, self).__init__() + # init model + self.model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False) + + def forward(self, videos): + # preprocess + size = (self.model.image_size,) * 2 + videos = torch.cat([ + F.interpolate( + u.transpose(0, 1), + size=size, + mode='bicubic', + align_corners=False) for u in videos + ]) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) + + # forward + with torch.cuda.amp.autocast(dtype=self.dtype): + out = self.model.visual(videos, use_31_block=True) + return out + + @classmethod + def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, transformer_additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." 
+ key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict) + + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/rose/models/wan_text_encoder.py b/rose/models/wan_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..49fc936bb9953387f9848d7d6340a1f6f1815695 --- /dev/null +++ b/rose/models/wan_text_encoder.py @@ -0,0 +1,376 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. 
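+        pos_bias: [1, N, L1, L2] or None, additive relative-position bias
+            produced by T5RelativeEmbedding.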
+ """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): 
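+    """
+    T5-style relative position bias. Pairwise offsets are mapped to buckets
+    (exact buckets for small offsets, log-spaced buckets up to `max_dist` for
+    larger ones; `bidirectional` keeps separate buckets for positive and
+    negative offsets) and looked up in an [num_buckets, num_heads] embedding,
+    producing a bias of shape [1, N, Lq, Lk] that is added to attention logits.
+
+    Illustrative example: T5RelativeEmbedding(32, 8, bidirectional=True)(4, 4)
+    returns a [1, 8, 4, 4] tensor.
+    """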
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + if torch.device(type="meta") != device: + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + else: + rel_pos = torch.arange(lk).unsqueeze(0) - \ + torch.arange(lq).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + +class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(WanT5EncoderModel, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + x = self.token_embedding(input_ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, attention_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return (x, ) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + if low_cpu_mem_usage: + try: + import re + + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta 
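+                # low_cpu_mem_usage path: instantiate the model on the meta
+                # device via accelerate, then stream the checkpoint weights in
+                # with load_model_dict_into_meta instead of a full state_dict copy.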
+ from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls(**filter_kwargs(cls, additional_kwargs)) + + param_device = "cpu" + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/rose/models/wan_transformer3d.py b/rose/models/wan_transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f5041d168a4d8b5b5a2dab9a60e2c586e5fb0f --- /dev/null +++ b/rose/models/wan_transformer3d.py @@ -0,0 +1,1203 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
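+# This module implements the Wan 2.1 3D diffusion transformer: flash-attention
+# wrappers with a scaled_dot_product_attention fallback, sinusoidal and rotary
+# (RoPE) position embeddings including the RIFLEx variant, and the
+# WanTransformer3DModel backbone with optional TeaCache step skipping and
+# sequence-parallel multi-GPU inference.
+#
+# Minimal usage sketch for the attention wrapper defined below (illustrative
+# only; shapes follow its docstring, with head_dim <= 256 and a half dtype):
+#   q = torch.randn(2, 128, 16, 64, device='cuda', dtype=torch.bfloat16)
+#   k = torch.randn(2, 256, 16, 64, device='cuda', dtype=torch.bfloat16)
+#   v = torch.randn(2, 256, 16, 64, device='cuda', dtype=torch.bfloat16)
+#   out = attention(q, k, v)   # -> [2, 128, 16, 64]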
+ +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import is_torch_version, logging +from torch import nn + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + xFuserLongContextAttention) +from ..dist.wan_xfuser import usp_attn_forward +from .cache_utils import TeaCache + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + + +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + q: [B, Lq, Nq, C1]. + k: [B, Lk, Nk, C1]. + v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. + q_lens: [B]. + k_lens: [B]. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + causal: bool. Whether to apply causal attention mask. + window_size: (left right). If not (-1, -1), apply sliding window local attention. + deterministic: bool. If True, slightly slower and uses more memory. + dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' + ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. 
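+        # q/k/v were flattened to [sum(lens), N, C] above; the cu_seqlens_q/k
+        # arguments give cumulative per-sample boundaries for the varlen kernel.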
+ x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + else: + assert FLASH_ATTN_2_AVAILABLE + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + + # output + return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + ) + attn_mask = None + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + out = out.transpose(1, 2).contiguous() + return out + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + +# modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py +@amp.autocast(enabled=False) +def get_1d_rotary_pos_embed_riflex( + pos: Union[np.ndarray, int], + dim: int, + theta: float = 10000.0, + use_real=False, + k: Optional[int] = None, + L_test: Optional[int] = None, + L_test_scale: Optional[int] = None, +): + """ + RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. 
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE + L_test (`int`, *optional*, defaults to None): the number of frames for inference + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + + freqs = 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim)) + + # === Riflex modification start === + # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)). + # Empirical observations show that a few videos may exhibit repetition in the tail frames. + # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period. + if k is not None: + freqs[k-1] = 0.9 * 2 * torch.pi / L_test + # === Riflex modification end === + if L_test_scale is not None: + freqs[k-1] = freqs[k-1] / L_test_scale + + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // 
num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, seq_lens, grid_sizes, freqs, dtype): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d) + k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d) + v = self.v(x.to(dtype)).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + x = attention( + q=rope_apply(q, grid_sizes, freqs).to(dtype), + k=rope_apply(k, grid_sizes, freqs).to(dtype), + v=v.to(dtype), + k_lens=seq_lens, + window_size=self.window_size) + x = x.to(dtype) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens, dtype): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + + # compute attention + x = attention( + q.to(dtype), + k.to(dtype), + v.to(dtype), + k_lens=context_lens + ) + x = x.to(dtype) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens, dtype): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d) + k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d) + v = self.v(context.to(dtype)).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d) + v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d) + + img_x = attention( + q.to(dtype), + k_img.to(dtype), + v_img.to(dtype), + k_lens=None + ) + img_x = img_x.to(dtype) + # compute attention + x = attention( + q.to(dtype), + k.to(dtype), + v.to(dtype), + k_lens=context_lens + ) + x = x.to(dtype) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), 
+ qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, + num_heads, + (-1, -1), + qk_norm, + eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + dtype=torch.float32 + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + e = (self.modulation + e).chunk(6, dim=1) + + # self-attention + temp_x = self.norm1(x) * (1 + e[1]) + e[0] + temp_x = temp_x.to(dtype) + + y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype) + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + # cross-attention + x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype) + + # ffn function + temp_x = self.norm2(x) * (1 + e[4]) + e[3] + temp_x = temp_x.to(dtype) + + y = self.ffn(temp_x) + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + + +class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. 
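+
+    The input video is patch-embedded with a 3D convolution, text embeddings
+    are projected to the hidden size, and the diffusion timestep modulates each
+    WanAttentionBlock through six AdaLN-style shift/scale/gate terms before the
+    Head maps tokens back to output patches.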
+ """ + + # ignore_for_config = [ + # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + # ] + # _no_split_modules = ['WanAttentionBlock'] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + in_channels=16, + hidden_size=2048, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ['t2v', 'i2v'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.d = d + self.freqs = torch.cat( + [ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + 
dim=1 + ) + + if model_type == 'i2v': + self.img_emb = MLPProj(1280, dim) + + self.teacache = None + self.gradient_checkpointing = False + self.sp_world_size = 1 + self.sp_world_rank = 0 + + def enable_teacache( + self, + coefficients, + num_steps: int, + rel_l1_thresh: float, + num_skip_start_steps: int = 0, + offload: bool = True + ): + self.teacache = TeaCache( + coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload + ) + + def disable_teacache(self): + self.teacache = None + + def enable_riflex( + self, + k = 6, + L_test = 66, + L_test_scale = 4.886, + ): + device = self.freqs.device + self.freqs = torch.cat( + [ + get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale), + rope_params(1024, 2 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)) + ], + dim=1 + ).to(device) + + def disable_riflex(self): + device = self.freqs.device + self.freqs = torch.cat( + [ + rope_params(1024, self.d - 4 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)), + rope_params(1024, 2 * (self.d // 6)) + ], + dim=1 + ).to(device) + + def enable_multi_gpus_inference(self,): + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + for block in self.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + cond_flag=True, + return_intermediate=False, # whether return mid layers' output + selected_layers=(5, 15, 25) # layer idx to output + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + cond_flag (`bool`, *optional*, defaults to True): + Flag to indicate whether to forward the condition input + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + dtype = x.dtype + if self.freqs.device != device and torch.device(type="meta") != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = 
self.time_projection(e).unflatten(1, (6, self.dim)) + # to bfloat16 for saving memeory + # assert e.dtype == torch.float32 and e0.dtype == torch.float32 + e0 = e0.to(dtype) + e = e.to(dtype) + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # Context Parallel + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + + # TeaCache + if self.teacache is not None: + if cond_flag: + modulated_inp = e0 + skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps + if self.teacache.cnt == 0 or self.teacache.cnt == self.teacache.num_steps - 1 or skip_flag: + should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + else: + if cond_flag: + rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp) + self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance) + if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.teacache.accumulated_rel_l1_distance = 0 + self.teacache.previous_modulated_input = modulated_inp + self.teacache.cnt += 1 + if self.teacache.cnt == self.teacache.num_steps: + self.teacache.reset() + self.teacache.should_calc = should_calc + else: + should_calc = self.teacache.should_calc + + intermediate_features = [] + + # TeaCache + if self.teacache is not None: + if not should_calc: + previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond + x = x + previous_residual.to(x.device) + else: + ori_x = x.clone().cpu() if self.teacache.offload else x.clone() + + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype + ) + x = block(x, **kwargs) + + if return_intermediate and idx in selected_layers: + intermediate_features.append(x.clone()) + + if cond_flag: + self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x + else: + for idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + e0, + seq_lens, + grid_sizes, + self.freqs, + context, + context_lens, + dtype, + **ckpt_kwargs, + ) + else: + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, 
+ grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + dtype=dtype + ) + x = block(x, **kwargs) + + if return_intermediate and idx in selected_layers: + intermediate_features.append(x.clone()) + + + if self.sp_world_size > 1: + x = get_sp_group().all_gather(x, dim=1) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + x = torch.stack(x) + + if return_intermediate: + mid_feat = intermediate_features + return x, mid_feat + else: + return x + + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) + + @classmethod + def from_pretrained( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if "dict_mapping" in transformer_additional_kwargs.keys(): + for key in transformer_additional_kwargs["dict_mapping"]: + transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key] + + if low_cpu_mem_usage: + try: + import re + + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch 
import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + print(model_files_safetensors) + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for _model_file_safetensors in model_files_safetensors: + _state_dict = load_file(_model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): + model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'] + model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." 
in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + model = model.to(torch_dtype) + return model \ No newline at end of file diff --git a/rose/models/wan_vae.py b/rose/models/wan_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f25b7ad15f42fd700f76eb53e511cef3233362d5 --- /dev/null +++ b/rose/models/wan_vae.py @@ -0,0 +1,705 @@ +# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import rearrange + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
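+        The input is upcast to float32 for the interpolation and cast back to
+        its original dtype afterwards.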
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... 
-> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, 
images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in 
self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class AutoencoderKLWan_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + scale = [item.to(x.device, x.dtype) for item in scale] + ## 对encode输入的x,按时间拆分为1、4、4、4.... 
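+        # The encoder input x is split along the time axis into chunks of
+        # 1, 4, 4, 4, ... frames: the first chunk is a single frame and each
+        # later chunk holds four frames, matching the 4x temporal compression,
+        # so every pass through the causal encoder yields one latent frame.
+        # e.g. t = 17 frames -> iter_ = 1 + (17 - 1) // 4 = 5 passes.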
+ for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + x = torch.cat([mu, log_var], dim = 1) + self.clear_cache() + return x + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + scale = [item.to(z.device, z.dtype) for item in scale] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(z_dim=None, **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. 
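+
+    With the defaults below this builds the causal 3D VAE used by Wan: base
+    width 96, three spatial downsampling stages (8x) and two temporal ones
+    (4x), with `z_dim` latent channels.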
+ """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + model = AutoencoderKLWan_(**cfg) + + return model + + +class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): + + @register_to_config + def __init__( + self, + latent_channels=16, + temporal_compression_ratio=4, + spacial_compression_ratio=8 + ): + super().__init__() + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=torch.float32) + self.std = torch.tensor(std, dtype=torch.float32) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + z_dim=latent_channels, + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = [ + self.model.encode(u.unsqueeze(0), self.scale).squeeze(0) + for u in x + ] + x = torch.stack(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, zs): + dec = [ + self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0) + for u in zs + ] + dec = torch.stack(dec) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + @classmethod + def from_pretrained(cls, pretrained_model_path, additional_kwargs={}): + def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + + model = cls(**filter_kwargs(cls, additional_kwargs)) + if pretrained_model_path.endswith(".safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(pretrained_model_path) + else: + state_dict = torch.load(pretrained_model_path, map_location="cpu") + tmp_state_dict = {} + for key in state_dict: + tmp_state_dict["model." + key] = state_dict[key] + state_dict = tmp_state_dict + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m, u) + return model \ No newline at end of file diff --git a/rose/models/wan_xlm_roberta.py b/rose/models/wan_xlm_roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..755baf394431bee95e1eac835b5dafe6ed37c5b9 --- /dev/null +++ b/rose/models/wan_xlm_roberta.py @@ -0,0 +1,170 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XLMRoberta', 'xlm_roberta_large'] + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + x: [B, L, C]. + """ + b, s, c, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, mask, p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps) + self.norm1 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), + nn.Dropout(dropout)) + self.norm2 = nn.LayerNorm(dim, eps=eps) + + def forward(self, x, mask): + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Module): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__(self, + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5): + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) + self.type_embedding = nn.Embedding(type_size, dim) + self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) + self.dropout = nn.Dropout(dropout) + + # blocks + self.blocks = nn.ModuleList([ + AttentionBlock(dim, num_heads, post_norm, dropout, eps) + for _ in range(num_layers) + ]) + + # norm layer + self.norm = nn.LayerNorm(dim, eps=eps) + + def forward(self, ids): + """ + ids: [B, L] of torch.LongTensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).long() + + # embeddings + x = self.token_embedding(ids) + \ + self.type_embedding(torch.zeros_like(ids)) + \ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = torch.where( + mask.view(b, 1, 1, s).gt(0), 0.0, + torch.finfo(x.dtype).min) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained=False, + return_tokenizer=False, + device='cpu', + **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. 
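+
+    Example (a minimal sketch; weights are randomly initialised here unless a
+    checkpoint is loaded separately):
+        >>> model = xlm_roberta_large()
+        >>> ids = torch.randint(5, 250002, (1, 8))  # [B, L] token ids; pad_id is 1
+        >>> feats = model(ids)                      # [1, 8, 1024] hidden states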
+ """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5) + cfg.update(**kwargs) + + # init a model on device + with torch.device(device): + model = XLMRoberta(**cfg) + return model \ No newline at end of file diff --git a/rose/pipeline/__init__.py b/rose/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b18563edf028a869bc978af354eeb799dafcc1eb --- /dev/null +++ b/rose/pipeline/__init__.py @@ -0,0 +1,6 @@ +from .pipeline_wan_fun import WanFunPipeline +from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline +from .pipeline_wan_fun_control import WanFunControlPipeline + +WanPipeline = WanFunPipeline +WanI2VPipeline = WanFunInpaintPipeline \ No newline at end of file diff --git a/rose/pipeline/pipeline_wan_fun.py b/rose/pipeline/pipeline_wan_fun.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6317dd4e1145c0a369f3ab03afa20c37db3a5d --- /dev/null +++ b/rose/pipeline/pipeline_wan_fun.py @@ -0,0 +1,558 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, WanTransformer3DModel) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
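+
+    Example (a minimal sketch, assuming `tokenizer`, `text_encoder`, `vae`,
+    `transformer` and `scheduler` have already been built from the Wan-Fun
+    checkpoints; the variable names are placeholders):
+
+        >>> pipe = WanFunPipeline(tokenizer=tokenizer, text_encoder=text_encoder,
+        ...                       vae=vae, transformer=transformer, scheduler=scheduler)
+        >>> out = pipe(prompt="a red fox running through snow",
+        ...            num_frames=49, height=480, width=720, num_inference_steps=50)
+        >>> video = out.videos  # decoded video (see `WanPipelineOutput`)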
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. 
torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spacial_compression_ratio, + width // self.vae.spacial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = negative_prompt_embeds + prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype): + noise_pred = self.transformer( + x=latent_model_input, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = 
self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/rose/pipeline/pipeline_wan_fun_control.py b/rose/pipeline/pipeline_wan_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..6146bea405bd7ceae40ba16183a135f03ea04cb2 --- /dev/null +++ b/rose/pipeline/pipeline_wan_fun_control.py @@ -0,0 +1,723 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
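+
+    Compared with `WanFunPipeline`, `__call__` additionally accepts a
+    `control_video` tensor shaped (B, C, F, H, W), an optional single-frame
+    `ref_image` tensor in the same layout whose frame is placed at the first
+    latent position, and an optional `clip_image` (PIL image) fed to the CLIP
+    image encoder.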
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spacial_compression_ratio, + width // self.vae.spacial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + control_video: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = negative_prompt_embeds + prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_latents = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + else: + control_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + control_latents = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + + ref_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes + ref_image_latentes_conv_in = ( + torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1) + else: + ref_image_latentes_conv_in = torch.zeros_like(latents) + ref_image_latentes_conv_in = ( + torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = ( + torch.cat([clip_context] * 2) if 
do_classifier_free_guidance else clip_context + ) + clip_context = torch.zeros_like(clip_context) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype): + noise_pred = self.transformer( + x=latent_model_input, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + y=control_latents, + clip_fea=clip_context, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/rose/pipeline/pipeline_wan_fun_inpaint.py b/rose/pipeline/pipeline_wan_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..7e164274be03157e3abc29852debdac5d6b9a763 --- /dev/null +++ b/rose/pipeline/pipeline_wan_fun_inpaint.py @@ -0,0 +1,729 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import 
MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer +from torchvision.utils import save_image + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
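The `resize_mask` helper above downsamples a pixel-space mask to the latent grid, resizing the first frame separately so it keeps its own latent slot while the remaining frames share the temporally compressed ones. A quick shape check, assuming the helper can be imported from this module (the import path is inferred from the file location and is an assumption) and reusing the illustrative 49-frame, 480x720 setup with a 13x60x90 latent grid:

```python
import torch
# Import path assumed from rose/pipeline/pipeline_wan_fun_inpaint.py
from rose.pipeline.pipeline_wan_fun_inpaint import resize_mask

mask = torch.randint(0, 2, (1, 1, 49, 480, 720)).float()  # pixel-space binary mask
latent = torch.zeros(1, 16, 13, 60, 90)                    # video latent it must match

resized = resize_mask(mask, latent, process_first_frame_only=True)
# Frame 0 maps to one latent frame; frames 1..48 are pooled into the remaining 12.
print(resized.shape)  # torch.Size([1, 1, 13, 60, 90])
```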
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spacial_compression_ratio, + width // self.vae.spacial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
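`prepare_extra_step_kwargs` keeps the pipeline scheduler-agnostic by only forwarding `eta` and `generator` when the scheduler's `step` signature actually accepts them. A standalone sketch of that introspection pattern; `fake_step` is a made-up stand-in for `scheduler.step`, not part of this codebase:

```python
import inspect

def optional_step_kwargs(step_fn, eta=0.0, generator=None):
    """Return only the kwargs that `step_fn` actually accepts
    (same idea as prepare_extra_step_kwargs)."""
    params = set(inspect.signature(step_fn).parameters)
    kwargs = {}
    if "eta" in params:        # DDIM-style schedulers only
        kwargs["eta"] = eta
    if "generator" in params:  # schedulers that draw extra noise
        kwargs["generator"] = generator
    return kwargs

# Made-up step function that takes a generator but no eta
def fake_step(model_output, timestep, sample, generator=None, return_dict=True):
    return (sample,)

print(optional_step_kwargs(fake_step))  # {'generator': None}
```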
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = negative_prompt_embeds + prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + # import pdb; pdb.set_trace() + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + # masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + masked_video = init_video + + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = 
TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + clip_context = torch.zeros_like(clip_context) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype): + noise_pred = self.transformer( + x=latent_model_input, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + clip_fea=clip_context, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/rose/utils/__init__.py b/rose/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rose/utils/discrete_sampler.py b/rose/utils/discrete_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149dbe7beb94dfea2e6fe0ca3b5acf9437be60f7 --- /dev/null +++ b/rose/utils/discrete_sampler.py @@ -0,0 +1,46 @@ +"""Modified from 
https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py +""" +import torch + +class DiscreteSampling: + def __init__(self, num_idx, uniform_sampling=False): + self.num_idx = num_idx + self.uniform_sampling = uniform_sampling + self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + + if self.is_distributed and self.uniform_sampling: + world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + self.group_num = world_size // i + break + assert self.group_num > 0 + assert world_size % self.group_num == 0 + # the number of rank in one group + self.group_width = world_size // self.group_num + self.sigma_interval = self.num_idx // self.group_num + print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % ( + self.rank, world_size, self.group_num, + self.group_width, self.sigma_interval)) + + def __call__(self, n_samples, generator=None, device=None): + if self.is_distributed and self.uniform_sampling: + group_index = self.rank // self.group_width + idx = torch.randint( + group_index * self.sigma_interval, + (group_index + 1) * self.sigma_interval, + (n_samples,), + generator=generator, device=device, + ) + print('proc[%d] idx=%s' % (self.rank, idx)) + else: + idx = torch.randint( + 0, self.num_idx, (n_samples,), + generator=generator, device=device, + ) + return idx \ No newline at end of file diff --git a/rose/utils/fp8_optimization.py b/rose/utils/fp8_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..1aa6d26fe9a0a0365401ab77ba4103fcc723fca9 --- /dev/null +++ b/rose/utils/fp8_optimization.py @@ -0,0 +1,56 @@ +"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper +""" +import torch +import torch.nn as nn + +def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): + weight_dtype = cls.weight.dtype + cls.to(origin_dtype) + + # Convert all inputs to the original dtype + inputs = [input.to(origin_dtype) for input in inputs] + out = cls.original_forward(*inputs, **kwargs) + + cls.to(weight_dtype) + return out + +def replace_parameters_by_name(module, name_keywords, device): + from torch import nn + for name, param in list(module.named_parameters(recurse=False)): + if any(keyword in name for keyword in name_keywords): + if isinstance(param, nn.Parameter): + tensor = param.data + delattr(module, name) + setattr(module, name, tensor.to(device=device)) + for child_name, child_module in module.named_children(): + replace_parameters_by_name(child_module, name_keywords, device) + +def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']): + for name, module in model.named_modules(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in name: + flag = True + if flag: + continue + for param_name, param in module.named_parameters(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in param_name: + flag = True + if flag: + continue + param.data = param.data.to(torch.float8_e4m3fn) + +def convert_weight_dtype_wrapper(module, origin_dtype): + for name, module in module.named_modules(): + if name == "" or "embed_tokens" in name: + continue + original_forward = module.forward + if hasattr(module, "weight") and module.weight is not None: + setattr(module, "original_forward", original_forward) 
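+            # Rebind forward() to autocast_model_forward: the float8-stored
+            # weights are cast up to `origin_dtype` just for the call and
+            # restored to their storage dtype afterwards.
+            # The default argument `m=module` binds the current module at
+            # definition time; a plain closure over `module` would leave every
+            # wrapper pointing at the last module visited by the loop.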
+ setattr( + module, + "forward", + lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) + ) diff --git a/rose/utils/lora_utils.py b/rose/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f801fa8f7a5a4024a4d4b003513813ed38f6649c --- /dev/null +++ b/rose/utils/lora_utils.py @@ -0,0 +1,516 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# https://github.com/bmaltais/kohya_ss + +import hashlib +import math +import os +from collections import defaultdict +from io import BytesIO +from typing import List, Optional, Type, Union + +import safetensors.torch +import torch +import torch.utils.checkpoint +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from safetensors.torch import load_file +from transformers import T5EncoderModel + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x, *args, **kwargs): + weight_dtype = x.dtype + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x.to(self.lora_down.weight.dtype)) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = 
mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +class LoRANetwork(torch.nn.Module): + TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"] + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + def __init__( + self, + text_encoder: Union[List[T5EncoderModel], T5EncoderModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + module_class: Type[object] = LoRAModule, + skip_name: str = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.dropout = dropout + + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print(f"neuron dropout: p={self.dropout}") + + # create module instances + def create_modules( + is_unet: bool, + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_TRANSFORMER + if is_unet + else self.LORA_PREFIX_TEXT_ENCODER + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if skip_name is not None and skip_name in child_name: + continue + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1: + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if text_encoder is not None: + text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + info = self.load_state_dict(weights_sd, False) + return info + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + 
save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + text_encoder: Union[T5EncoderModel, List[T5EncoderModel]], + transformer, + neuron_dropout: Optional[float] = None, + skip_name: str = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + network = LoRANetwork( + text_encoder, + transformer, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + skip_name=skip_name, + varbose=True, + ) + return network + +def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False): + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + if state_dict is None: + state_dict = load_file(lora_path, device=device) + else: + state_dict = state_dict + updates = defaultdict(dict) + for key, value in state_dict.items(): + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + offload_device = pipeline._offload_device + + for layer, elems in updates.items(): + + if "lora_te" in layer: + if transformer_only: + continue + else: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") + curr_layer = pipeline.transformer + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print('Error loading layer') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data += multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=offload_device) + return pipeline + +# TODO: Refactor with merge_lora. 
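`merge_lora` above folds the LoRA deltas directly into the matching transformer / text-encoder layers (each weight gets `multiplier * alpha * (up @ down)` added), so inference afterwards runs at full speed with no extra modules; `unmerge_lora` below subtracts the same deltas to restore the base weights. A hedged usage sketch: the import path is inferred from the file location, the checkpoint path and 0.55 strength are placeholders, and `pipe` stands for an already-constructed Wan pipeline object.

```python
import torch
# Import path assumed from rose/utils/lora_utils.py
from rose.utils.lora_utils import merge_lora, unmerge_lora

pipe = ...  # an already-built WanFunInpaintPipeline; construction omitted here
lora_ckpt = "checkpoints/my_lora.safetensors"  # placeholder path

# Fold the LoRA deltas into the base weights before running inference...
pipe = merge_lora(pipe, lora_ckpt, multiplier=0.55, device="cuda", dtype=torch.bfloat16)

# ...and subtract them again afterwards to recover the original weights.
pipe = unmerge_lora(pipe, lora_ckpt, multiplier=0.55, device="cuda", dtype=torch.bfloat16)
```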
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32): + """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.""" + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + state_dict = load_file(lora_path, device=device) + + updates = defaultdict(dict) + for key, value in state_dict.items(): + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + + for layer, elems in updates.items(): + + if "lora_te" in layer: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.transformer + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print('Error loading layer') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data -= multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=device) + return pipeline diff --git a/rose/utils/utils.py b/rose/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe8d8139df3531f8d12ab48e9f599d64f0594aeb --- /dev/null +++ b/rose/utils/utils.py @@ -0,0 +1,318 @@ +import os +import gc +import imageio +import inspect +import numpy as np +import torch +import torchvision +import cv2 +from einops import rearrange +from PIL import Image + +def filter_kwargs(cls, kwargs): + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def get_width_and_height_from_image_and_base_resolution(image, base_resolution): + target_pixels = int(base_resolution) * int(base_resolution) + original_width, original_height = Image.open(image).size + ratio = (target_pixels / (original_width * original_height)) ** 0.5 + width_slider = round(original_width * ratio) + height_slider = round(original_height * ratio) + return height_slider, width_slider + +def color_transfer(sc, dc): + """ + Transfer color distribution from of sc, referred to dc. + + Args: + sc (numpy.ndarray): input image to be transfered. 
+ dc (numpy.ndarray): reference image + + Returns: + numpy.ndarray: Transferred color distribution on the sc. + """ + + def get_mean_and_std(img): + x_mean, x_std = cv2.meanStdDev(img) + x_mean = np.hstack(np.around(x_mean, 2)) + x_std = np.hstack(np.around(x_std, 2)) + return x_mean, x_std + + sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB) + s_mean, s_std = get_mean_and_std(sc) + dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB) + t_mean, t_std = get_mean_and_std(dc) + img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean + np.putmask(img_n, img_n > 255, 255) + np.putmask(img_n, img_n < 0, 0) + dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB) + return dst + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(Image.fromarray(x)) + + if color_transfer_post_process: + for i in range(1, len(outputs)): + outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0]))) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if imageio_backend: + if path.endswith("mp4"): + imageio.mimsave(path, outputs, fps=fps) + else: + imageio.mimsave(path, outputs, duration=(1000 * 1/fps)) + else: + if path.endswith("mp4"): + path = path.replace('.mp4', '.gif') + outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0) + +def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): + if validation_image_start is not None and validation_image_end is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + + if type(validation_image_end) is str and os.path.isfile(validation_image_end): + image_end = Image.open(validation_image_end).convert("RGB") + image_end = image_end.resize([sample_size[1], sample_size[0]]) + else: + image_end = validation_image_end + image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:] = 255 + + if type(image_end) is list: + image_end = 
[_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end] + end_video = torch.cat( + [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], + dim=2 + ) + input_video[:, :, -len(end_video):] = end_video + + input_video_mask[:, :, -len(image_end):] = 0 + else: + image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) + input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + input_video_mask[:, :, -1:] = 0 + + input_video = input_video / 255 + + elif validation_image_start is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + image_end = None + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + input_video = input_video / 255 + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) / 255 + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:, ] = 255 + else: + image_start = None + image_end = None + input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) + input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 + clip_image = None + + del image_start + del image_end + gc.collect() + + return input_video, input_video_mask, clip_image + +def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): + if input_video_path is not None: + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] + + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else int(original_fps // fps) + + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + frame_count += 1 + + cap.release() + else: + input_video = input_video_path + + input_video = torch.from_numpy(np.array(input_video))[:video_length] + input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = 
torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None + + if ref_image is not None: + if isinstance(ref_image, str): + clip_image = Image.open(ref_image).convert("RGB") + else: + clip_image = Image.fromarray(np.array(ref_image, np.uint8)) + else: + clip_image = None + + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + return input_video, input_video_mask, ref_image, clip_image + + +def get_video_and_mask(input_video_path, video_length, sample_size, fps=None, input_mask_path=None, ref_image=None): + if input_video_path is not None: + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] + + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else int(original_fps // fps) + + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + frame_count += 1 + + cap.release() + else: + input_video = input_video_path + + input_video = torch.from_numpy(np.array(input_video))[:video_length] + input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255.0 + + else: + input_video = None + + if input_mask_path is not None: + if isinstance(input_mask_path, str): + cap = cv2.VideoCapture(input_mask_path) + mask_frames = [] + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else int(original_fps // fps) + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + mask_frames.append(gray) + frame_count += 1 + cap.release() + else: + mask_frames = input_mask_path + + mask_np = np.array(mask_frames)[:video_length] # (F, H, W), uint8 + mask_bin = np.where(mask_np < 240, 0, 1).astype(np.uint8) # (F,H,W) + mask_tensor = torch.from_numpy(mask_bin) + mask_tensor = mask_tensor.unsqueeze(1) + mask_tensor = mask_tensor.unsqueeze(0) + input_mask = mask_tensor.permute(0,2,1,3,4) + input_mask = input_mask.float() + else: + input_mask = None + + if ref_image is not None: + if isinstance(ref_image, str): + clip_image = Image.open(ref_image).convert("RGB") + else: + clip_image = Image.fromarray(np.array(ref_image, np.uint8)) + else: + clip_image = None + + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image 
= ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + return input_video, input_mask, ref_image, clip_image diff --git a/test_sample/test-sample0.mp4 b/test_sample/test-sample0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8bdc127da76beff17e384b2152b8b95002679328 --- /dev/null +++ b/test_sample/test-sample0.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d63abe5ce28c76c83b0f575e7b8cd2707ffb58e3e478b7f500865700d4738a2 +size 476512 diff --git a/test_sample/test-sample1.mp4 b/test_sample/test-sample1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..88e6cbf54eefe62c695bb0f8f543466d16d7d386 --- /dev/null +++ b/test_sample/test-sample1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54369037f2fb7c42ada4165618894cedcd87d8368d40547571b43e5fd4ff0025 +size 975899 diff --git a/test_sample/test-sample2.mp4 b/test_sample/test-sample2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..36f585a88144ffaa238a1493f0ef8dca00129924 --- /dev/null +++ b/test_sample/test-sample2.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11e3e5b0aea5881bf8c5edf8a8196b51bef11fac019ee96192cc9b24146ae07b +size 286526 diff --git a/test_sample/test-sample3.mp4 b/test_sample/test-sample3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cdccc497b8ce107ee2fe8d2e9c70ed52c047cf5e --- /dev/null +++ b/test_sample/test-sample3.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29f38c57585dad9946aaa32f7ef9faae10bcb87913708462114da5fb164d4775 +size 146705 diff --git a/test_sample/test-sample4.mp4 b/test_sample/test-sample4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4305c6792e03d811cb3ddcb9ac557d9b1d28391f --- /dev/null +++ b/test_sample/test-sample4.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0a42fa9abcfa72baee2a437b06b4e1981b5109c815795bf005d4f7d5cd47096 +size 1465415 diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..2b975bb779b47485f9e6ba7435646b4db40a2c6a --- /dev/null +++ b/tools/base_segmenter.py @@ -0,0 +1,129 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter + + +class BaseSegmenter: + def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): + """ + device: model device + SAM_checkpoint: path of SAM checkpoint + model_type: vit_b, vit_l, vit_h + """ + print(f"Initializing BaseSegmenter to {device}") + assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' + + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) + self.model.to(device=self.device) + self.predictor = SamPredictor(self.model) + self.embedded = False + + @torch.no_grad() + def set_image(self, image: np.ndarray): + # PIL.open(image_path) 3channel: RGB + # image embedding: avoid encode the same image multiple times + self.orignal_image = image 
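+        # note: the attribute keeps the original (misspelled) name
+        # "orignal_image" because SamControler.first_frame_click reads it
+        # under exactly that name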
+ if self.embedded: + print('repeat embedding, please reset_image.') + return + self.predictor.set_image(image) + self.embedded = True + return + + @torch.no_grad() + def reset_image(self): + # reset image embeding + self.predictor.reset_image() + self.embedded = False + + def predict(self, prompts, mode, multimask=True): + """ + image: numpy array, h, w, 3 + prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' + prompts['point_coords']: numpy array [N,2] + prompts['point_labels']: numpy array [1,N] + prompts['mask_input']: numpy array [1,256,256] + mode: 'point' (points only), 'mask' (mask only), 'both' (consider both) + mask_outputs: True (return 3 masks), False (return 1 mask only) + whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :] + """ + assert self.embedded, 'prediction is called before set_image (feature embedding).' + assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both' + + if mode == 'point': + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + multimask_output=multimask) + elif mode == 'mask': + masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'], + multimask_output=multimask) + elif mode == 'both': # both + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + mask_input=prompts['mask_input'], + multimask_output=multimask) + else: + raise("Not implement now!") + # masks (n, h, w), scores (n,), logits (n, 256, 256) + return masks, scores, logits + + +if __name__ == "__main__": + # load and show an image + image = cv2.imread('/hhd3/gaoshang/truck.jpg') + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3) + + # initialise BaseSegmenter + SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + model_type = 'vit_h' + device = "cuda:4" + base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) + + # image embedding (once embedded, multiple prompts can be applied) + base_segmenter.set_image(image) + + # examples + # point only ------------------------ + mode = 'point' + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 1]), + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) + + # both ------------------------ + mode = 'both' + mask_input = logits[np.argmax(scores), :, :] + prompts = {'mask_input': mask_input [None, :, :]} + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 0]), + 'mask_input': mask_input[None, :, :] + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image) + + # mask only ------------------------ + mode = 'mask' + mask_input = logits[np.argmax(scores), :, :] + + prompts = {'mask_input': 
mask_input[None, :, :]} + + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image) diff --git a/tools/interact_tools.py b/tools/interact_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c70b8c40c6c57f19242b42bfcde60f378b1ce7ba --- /dev/null +++ b/tools/interact_tools.py @@ -0,0 +1,99 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter as mask_painter2 +from .base_segmenter import BaseSegmenter +from .painter import mask_painter, point_painter +import os +import requests +import sys + + +mask_color = 3 +mask_alpha = 0.7 +contour_color = 1 +contour_width = 5 +point_color_ne = 8 +point_color_ps = 50 +point_alpha = 0.9 +point_radius = 15 +contour_color = 2 +contour_width = 5 + + +class SamControler(): + def __init__(self, SAM_checkpoint, model_type, device): + ''' + initialize sam controler + ''' + self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) + + + # def seg_again(self, image: np.ndarray): + # ''' + # it is used when interact in video + # ''' + # self.sam_controler.reset_image() + # self.sam_controler.set_image(image) + # return + + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): + ''' + it is used in first frame in video + return: mask, logit, painted image(mask+point) + ''' + # self.sam_controler.set_image(image) + origal_image = self.sam_controler.orignal_image + neg_flag = labels[-1] + if neg_flag==1: + #find neg + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + prompts = { + 'point_coords': points, + 'point_labels': labels, + 'mask_input': logit[None, :, :] + } + masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + else: + #find positive + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + + assert len(points)==len(labels) + + painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + painted_image = Image.fromarray(painted_image) + + return mask, logit, painted_image + + + + + + + + + + + + \ No newline at end of file diff --git a/tools/mask_painter.py b/tools/mask_painter.py new file mode 100644 index 
0000000000000000000000000000000000000000..f471ea0116d656e2cc236832893b07c6d7be1643 --- /dev/null +++ b/tools/mask_painter.py @@ -0,0 +1,288 @@ +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha): + background_color = np.array(background_color) + contour_color = np.array(contour_color) + + # background_mask = 1 - background_mask + # contour_mask = 1 - contour_mask + + for i in range(3): + image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \ + + background_color[i] * (background_alpha-background_mask*background_alpha) + + image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \ + + contour_color[i] * (contour_alpha-contour_mask*contour_alpha) + + return image.astype('uint8') + + +def mask_generator_00(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... 
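+    # |clipped signed distance| is ~0 on the mask boundary and grows away from
+    # it; after normalisation and the >0.5 clamp, pixels far from the boundary
+    # become 1 (left untouched by vis_add_mask) while the thin band around the
+    # boundary stays below 1 and is rendered as the contour colour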
+ contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + return mask, contour_mask + + +def mask_generator_01(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return mask, contour_mask + + +def mask_generator_10(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + return background_mask, contour_mask + + +def mask_generator_11(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... 
+ contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return background_mask, contour_mask + + +def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'): + """ + Input: + input_image: numpy array + input_mask: numpy array + background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing + background_blur_radius: radius of background blur, must be odd number + contour_width: width of mask contour, must be odd number + contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others + contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted + mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both + + Output: + painted_image: numpy array + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape' + assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD' + assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11' + + # downsample input image and mask + width, height = input_image.shape[0], input_image.shape[1] + res = 1024 + ratio = min(1.0 * res / max(width, height), 1.0) + input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio))) + input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio))) + + # 0: background, 1: foreground + msk = np.clip(input_mask, 0, 1) + + # generate masks for background and contour pixels + background_radius = (background_blur_radius - 1) // 2 + contour_radius = (contour_width - 1) // 2 + generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11} + background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius) + + # paint + painted_image = vis_add_mask\ + (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background + + return painted_image + + +if __name__ == '__main__': + + background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing + background_blur_radius = 31 # radius of background blur, must be odd number + contour_width = 11 # contour width, must be odd number + contour_color = 3 # id in color map, 0: black, 1: white, >1: others + contour_alpha = 1 # transparency of background, 0: no contour highlighted + + # load input image and mask + input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P')) + + # paint + overall_time_1 = 0 + overall_time_2 = 0 + overall_time_3 = 0 + overall_time_4 = 0 + overall_time_5 = 0 + + for i in range(50): + t2 = time.time() + painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00') + e2 = time.time() + + t3 = time.time() + painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10') + e3 = time.time() + + t1 = time.time() + painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha) + e1 = time.time() + + t4 = time.time() + painted_image_01 = 
mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01') + e4 = time.time() + + t5 = time.time() + painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11') + e5 = time.time() + + overall_time_1 += (e1 - t1) + overall_time_2 += (e2 - t2) + overall_time_3 += (e3 - t3) + overall_time_4 += (e4 - t4) + overall_time_5 += (e5 - t5) + + print(f'average time w gaussian: {overall_time_1/50}') + print(f'average time w/o gaussian00: {overall_time_2/50}') + print(f'average time w/o gaussian10: {overall_time_3/50}') + print(f'average time w/o gaussian01: {overall_time_4/50}') + print(f'average time w/o gaussian11: {overall_time_5/50}') + + # save + painted_image_00 = Image.fromarray(painted_image_00) + painted_image_00.save('./test_img/painter_output_image_00.png') + + painted_image_10 = Image.fromarray(painted_image_10) + painted_image_10.save('./test_img/painter_output_image_10.png') + + painted_image_01 = Image.fromarray(painted_image_01) + painted_image_01.save('./test_img/painter_output_image_01.png') + + painted_image_11 = Image.fromarray(painted_image_11) + painted_image_11.save('./test_img/painter_output_image_11.png') diff --git a/tools/painter.py b/tools/painter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e711d35aa8348d15cdad9d1cd413da41ea4f1ab --- /dev/null +++ b/tools/painter.py @@ -0,0 +1,215 @@ +# paint masks, contours, or points on images, with specified colors +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 
0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, mask, color, alpha): + color = np.array(color_list[color]) + mask = mask > 0.5 + image[mask] = image[mask] * (1-alpha) + color * alpha + return image.astype('uint8') + +def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5): + h, w = input_image.shape[:2] + point_mask = np.zeros((h, w)).astype('uint8') + for point in input_points: + point_mask[point[1], point[0]] = 1 + + kernel = cv2.getStructuringElement(2, (point_radius, point_radius)) + point_mask = cv2.dilate(point_mask, kernel) + + contour_radius = (contour_width - 1) // 2 + dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + return painted_image + +def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.clip(input_mask, 0, 1) + contour_radius = (contour_width - 1) // 2 + + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. 
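+    # two passes of vis_add_mask follow: blend the fill colour over the mask at
+    # mask_alpha, then draw the contour band (1-contour_mask) on top at full
+    # opacity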
+ + # paint mask + painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + + return painted_image + +def background_remover(input_image, input_mask): + """ + input_image: H, W, 3, np.array + input_mask: H, W, np.array + + image_wo_background: PIL.Image + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255 + image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4 + image_wo_background = Image.fromarray(image_wo_background).convert('RGBA') + + return image_wo_background + +if __name__ == '__main__': + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P')) + + # example of mask painter + mask_color = 3 + mask_alpha = 0.7 + contour_color = 1 + contour_width = 5 + + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original.png') + + painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width) + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original1.png') + + # example of point painter + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_points = np.array([[500, 375], [70, 600]]) # x, y + point_color = 5 + point_alpha = 0.9 + point_radius = 15 + contour_color = 2 + contour_width = 5 + painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width) + # save + painted_image = Image.fromarray(painted_image_1) + painted_image.save('images/point_painter_1.png') + + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29) + # save + painted_image = Image.fromarray(painted_image_2) + painted_image.save('images/point_painter_2.png') + + # example of background remover + input_image = np.array(Image.open('images/original.png').convert('RGB')) + image_wo_background = background_remover(input_image, input_mask) # return PIL.Image + image_wo_background.save('images/image_wo_background.png') diff --git a/track_anything.py b/track_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..57ea61282c0845324a8fe07e46538225b3c16c96 --- /dev/null +++ b/track_anything.py @@ -0,0 +1,40 @@ +import numpy as np +from tqdm import tqdm + +from tools.interact_tools import SamControler +from tracker.base_tracker import BaseTracker +# from inpainter.base_inpainter import ProInpainter + + +class TrackingAnything(): + def __init__(self, sam_checkpoint, cutie_checkpoint, args): + self.args = args + self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device) + self.cutie = BaseTracker(cutie_checkpoint, device=args.device) + # self.baseinpainter = ProInpainter(propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args.device) + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + return mask, logit, painted_image + + def generator(self, images: list, 
template_mask:np.ndarray): + masks = [] + logits = [] + painted_images = [] + for i in tqdm(range(len(images)), desc="Tracking image"): + if i==0: + mask, logit, painted_image = self.cutie.track(images[i], template_mask) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + else: + mask, logit, painted_image = self.cutie.track(images[i]) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + return masks, logits, painted_images + + + + + \ No newline at end of file diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec527326c1df4b4bfb9fdd90bc925460cca39f2 --- /dev/null +++ b/tracker/base_tracker.py @@ -0,0 +1,103 @@ +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +import sys +sys.path.append('../') + +from tracker.config import CONFIG +from tracker.model.cutie import CUTIE +from tracker.inference.inference_core import InferenceCore +from tracker.utils.mask_mapper import MaskMapper + +from tools.painter import mask_painter + + +class BaseTracker: + def __init__(self, cutie_checkpoint, device) -> None: + """ + device: model device + cutie_checkpoint: checkpoint of XMem model + """ + config = OmegaConf.create(CONFIG) + + # initialise XMem + network = CUTIE(config).to(device).eval() + model_weights = torch.load(cutie_checkpoint, map_location=device) + network.load_weights(model_weights) + + # initialise IncerenceCore + self.tracker = InferenceCore(network, config) + self.device = device + + # changable properties + self.mapper = MaskMapper() + self.initialised = False + + @torch.no_grad() + def resize_mask(self, mask): + # mask transform is applied AFTER mapper, so we need to post-process it in eval.py + h, w = mask.shape[-2:] + min_hw = min(h, w) + return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), + mode='nearest') + + @torch.no_grad() + def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'): + # frame: H*W*3 numpy array + frame = frame.transpose(2, 0, 1) + frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255 + return frame + + @torch.no_grad() + def track(self, frame, first_frame_annotation=None): + """ + Input: + frames: numpy arrays (H, W, 3) + logit: numpy array (H, W), logit + + Output: + mask: numpy arrays (H, W) + logit: numpy arrays, probability map (H, W) + painted_image: numpy array (H, W, 3) + """ + + if first_frame_annotation is not None: # first frame mask + # initialisation + mask, labels = self.mapper.convert_mask(first_frame_annotation) + mask = torch.Tensor(mask).to(self.device) + else: + mask = None + labels = None + + # prepare inputs + frame_tensor = self.image_to_torch(frame, self.device) + + # track one frame + probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W + + # convert to mask + out_mask = torch.argmax(probs, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + final_mask = np.zeros_like(out_mask) + + # map back + for k, v in self.mapper.remappings.items(): + final_mask[out_mask == v] = k + + num_objs = final_mask.max() + painted_image = frame + for obj in range(1, num_objs+1): + if np.max(final_mask==obj) == 0: + continue + painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1) + + return final_mask, final_mask, painted_image + + @torch.no_grad() + def clear_memory(self): + self.tracker.clear_memory() + 
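+        # also reset the object-id mapping and release cached CUDA memory so
+        # the tracker can be reused on a new sequence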
self.mapper.clear_labels() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/tracker/config/__init__.py b/tracker/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a26a5c2f36fabe3c9870cf0c1ba3c4414e15cd42 --- /dev/null +++ b/tracker/config/__init__.py @@ -0,0 +1 @@ +CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}} \ No newline at end of file diff --git a/tracker/inference/__init__.py b/tracker/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/inference/image_feature_store.py b/tracker/inference/image_feature_store.py new file mode 100644 index 0000000000000000000000000000000000000000..1d02a634b009bb06b935e49036adb121325f5816 --- /dev/null +++ b/tracker/inference/image_feature_store.py @@ -0,0 +1,49 @@ +import warnings +from typing import Iterable +import torch +from tracker.model.cutie import CUTIE + + +class ImageFeatureStore: + """ + A cache for image features. + These features might be reused at different parts of the inference pipeline. + This class provide an interface for reusing these features. + It is the user's responsibility to delete redundant features. + + Feature of a frame should be associated with a unique index -- typically the frame id. 
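+
+    Typical flow (a sketch): get_features(index, image) / get_key(index, image)
+    encode and cache a frame on first access; delete(index) should be called
+    once that frame has been processed so the store does not keep growing.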
+ """ + def __init__(self, network: CUTIE, no_warning: bool = False): + self.network = network + self._store = {} + self.no_warning = no_warning + + def _encode_feature(self, index: int, image: torch.Tensor) -> None: + ms_features, pix_feat = self.network.encode_image(image) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) + + def get_features(self, index: int, + image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + if index not in self._store: + self._encode_feature(index, image) + + return self._store[index][:2] + + def get_key(self, index: int, + image: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + if index not in self._store: + self._encode_feature(index, image) + + return self._store[index][2:] + + def delete(self, index: int) -> None: + if index in self._store: + del self._store[index] + + def __len__(self): + return len(self._store) + + def __del__(self): + if len(self._store) > 0 and not self.no_warning: + warnings.warn(f'Leaking {self._store.keys()} in the image feature store') diff --git a/tracker/inference/inference_core.py b/tracker/inference/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d5b24cf9aaed19656a194b1d6dfc0aaffe91c6 --- /dev/null +++ b/tracker/inference/inference_core.py @@ -0,0 +1,381 @@ +from typing import List, Optional, Iterable, Dict +import logging +from omegaconf import DictConfig + +import numpy as np +import torch +import torch.nn.functional as F + +from tracker.inference.memory_manager import MemoryManager +from tracker.inference.object_manager import ObjectManager +from tracker.inference.image_feature_store import ImageFeatureStore +from tracker.model.cutie import CUTIE +from tracker.utils.tensor_utils import pad_divide_by, unpad, aggregate + +log = logging.getLogger() + + +class InferenceCore: + def __init__(self, + network: CUTIE, + cfg: DictConfig, + *, + image_feature_store: ImageFeatureStore = None): + self.network = network + self.cfg = cfg + self.mem_every = cfg.mem_every + stagger_updates = cfg.stagger_updates + self.chunk_size = cfg.chunk_size + self.save_aux = cfg.save_aux + self.max_internal_size = cfg.max_internal_size + self.flip_aug = cfg.flip_aug + + self.curr_ti = -1 + self.last_mem_ti = 0 + # at which time indices should we update the sensory memory + if stagger_updates >= self.mem_every: + self.stagger_ti = set(range(1, self.mem_every + 1)) + else: + self.stagger_ti = set( + np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int)) + self.object_manager = ObjectManager() + self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager) + + if image_feature_store is None: + self.image_feature_store = ImageFeatureStore(self.network) + else: + self.image_feature_store = image_feature_store + + self.last_mask = None + + def clear_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager) + + def clear_non_permanent_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_sensory_memory() + + def update_config(self, cfg): + self.mem_every = cfg['mem_every'] + self.memory.update_config(cfg) + + def _add_memory(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + prob: torch.Tensor, + key: torch.Tensor, + shrinkage: torch.Tensor, + 
selection: torch.Tensor, + *, + is_deep_update: bool = True, + force_permanent: bool = False) -> None: + """ + Memorize the given segmentation in all memory stores. + + The batch dimension is 1 if flip augmentation is not used. + image: RGB image, (1/2)*3*H*W + pix_feat: from the key encoder, (1/2)*_*H*W + prob: (1/2)*num_objects*H*W, in [0, 1] + key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W + selection can be None if not using long-term memory + is_deep_update: whether to use deep update (e.g. with the mask encoder) + force_permanent: whether to force the memory to be permanent + """ + if prob.shape[1] == 0: + # nothing to add + log.warn('Trying to add an empty object mask to memory!') + return + + if force_permanent: + as_permanent = 'all' + else: + as_permanent = 'first' + + self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids) + msk_value, sensory, obj_value, self.obj_logits = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + prob, + deep_update=is_deep_update, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.memory.add_memory(key, + shrinkage, + msk_value, + obj_value, + self.object_manager.all_obj_ids, + selection=selection, + as_permanent=as_permanent) + self.last_mem_ti = self.curr_ti + if is_deep_update: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + + def _segment(self, + key: torch.Tensor, + selection: torch.Tensor, + pix_feat: torch.Tensor, + ms_features: Iterable[torch.Tensor], + update_sensory: bool = True) -> torch.Tensor: + """ + Produce a segmentation using the given features and the memory + + The batch dimension is 1 if flip augmentation is not used. + key/selection: for anisotropic l2: (1/2) * _ * H * W + pix_feat: from the key encoder, (1/2) * _ * H * W + ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W + with strides 16, 8, and 4 respectively + update_sensory: whether to update the sensory memory + + Returns: (num_objects+1)*H*W normalized probability; the first channel is the background + """ + bs = key.shape[0] + if self.flip_aug: + assert bs == 2 + else: + assert bs == 1 + + if not self.memory.engaged: + log.warn('Trying to segment without any memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + + memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network) + memory_readout = self.object_manager.realize_dict(memory_readout) + sensory, _, pred_prob_with_bg = self.network.segment(ms_features, + memory_readout, + self.memory.get_sensory( + self.object_manager.all_obj_ids), + chunk_size=self.chunk_size, + update_sensory=update_sensory) + # remove batch dim + if self.flip_aug: + # average predictions of the non-flipped and flipped version + pred_prob_with_bg = (pred_prob_with_bg[0] + + torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2 + else: + pred_prob_with_bg = pred_prob_with_bg[0] + if update_sensory: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + return pred_prob_with_bg + + def step(self, + image: torch.Tensor, + mask: Optional[torch.Tensor] = None, + objects: Optional[List[int]] = None, + *, + idx_mask: bool = True, + end: bool = False, + delete_buffer: bool = True, + force_permanent: bool = False) -> torch.Tensor: + """ + Take a step with a new incoming image. + If there is an incoming mask with new objects, we will memorize them. 
+ If there is no incoming mask, we will segment the image using the memory. + In both cases, we will update the memory and return a segmentation. + + image: 3*H*W + mask: H*W (if idx mask) or len(objects)*H*W or None + objects: list of object ids that are valid in the mask Tensor. + The ids themselves do not need to be consecutive/in order, but they need to be + in the same position in the list as the corresponding mask + in the tensor in non-idx-mask mode. + objects is ignored if the mask is None. + If idx_mask is False and objects is None, we sequentially infer the object ids. + idx_mask: if True, mask is expected to contain an object id at every pixel. + If False, mask should have multiple channels with each channel representing one object. + end: if we are at the end of the sequence, we do not need to update memory + if unsure just set it to False + delete_buffer: whether to delete the image feature buffer after this step + force_permanent: the memory recorded this frame will be added to the permanent memory + """ + if objects is None and mask is not None: + assert not idx_mask + objects = list(range(1, mask.shape[0] + 1)) + + # resize input if needed -- currently only used for the GUI + resize_needed = False + if self.max_internal_size > 0: + h, w = image.shape[-2:] + min_side = min(h, w) + if min_side > self.max_internal_size: + resize_needed = True + new_h = int(h / min_side * self.max_internal_size) + new_w = int(w / min_side * self.max_internal_size) + image = F.interpolate(image.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + if mask is not None: + if idx_mask: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), + size=(new_h, new_w), + mode='nearest', + align_corners=False)[0, 0].round().long() + else: + mask = F.interpolate(mask.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + + self.curr_ti += 1 + + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + if self.flip_aug: + image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0) + + # whether to update the working memory + is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or + (mask is not None)) and (not end) + # segment when there is no input mask or when the input mask is incomplete + need_segment = (mask is None) or (self.object_manager.num_obj > 0 + and not self.object_manager.has_all(objects)) + update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end) + + # encoding the image + ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image) + + # segmentation from memory if needed + if need_segment: + pred_prob_with_bg = self._segment(key, + selection, + pix_feat, + ms_feat, + update_sensory=update_sensory) + + # use the input mask if provided + if mask is not None: + # inform the manager of the new objects, and get a list of temporary id + # temporary ids -- indicates the position of objects in the tensor + # (starts with 1 due to the background channel) + corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects) + + mask, _ = pad_divide_by(mask, 16) + if need_segment: + # merge predicted mask with the incomplete input mask + pred_prob_no_bg = pred_prob_with_bg[1:] + # use the mutual exclusivity of segmentation + if idx_mask: + pred_prob_no_bg[:, mask > 0] = 0 + else: + pred_prob_no_bg[:, mask.max(0) > 0.5] = 0 + + new_masks = [] + 
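+                # walk the temporary ids: overwrite the matching predicted
+                # channel with the user-provided mask, or collect it in
+                # new_masks if its id is beyond the predicted channels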
for mask_id, tmp_id in enumerate(corresponding_tmp_ids): + if idx_mask: + this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg) + else: + this_mask = mask[tmp_id] + if tmp_id >= pred_prob_no_bg.shape[0]: + new_masks.append(this_mask.unsqueeze(0)) + else: + # +1 for padding the background channel + pred_prob_no_bg[tmp_id + 1] = this_mask + # new_masks are always in the order of tmp_id + mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0) + elif idx_mask: + # simply convert cls to one-hot representation + if len(objects) == 0: + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + log.warn('Trying to insert an empty mask as memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + mask = torch.stack( + [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)], + dim=0) + pred_prob_with_bg = aggregate(mask, dim=0) + pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0) + + self.last_mask = pred_prob_with_bg[1:].unsqueeze(0) + if self.flip_aug: + self.last_mask = torch.cat( + [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0) + + # save as memory if needed + if is_mem_frame or force_permanent: + self._add_memory(image, + pix_feat, + self.last_mask, + key, + shrinkage, + selection, + force_permanent=force_permanent) + + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + + output_prob = unpad(pred_prob_with_bg, self.pad) + if resize_needed: + # restore output to the original size + output_prob = F.interpolate(output_prob.unsqueeze(0), + size=(h, w), + mode='bilinear', + align_corners=False)[0] + + return output_prob + + def get_aux_outputs(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: + image, pads = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + _, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + + aux_inputs = self.memory.aux + aux_outputs = self.network.compute_aux(pix_feat, aux_inputs, selector=None) + aux_outputs['q_weights'] = aux_inputs['q_weights'] + aux_outputs['p_weights'] = aux_inputs['p_weights'] + + for k, v in aux_outputs.items(): + if len(v.shape) == 5: + aux_outputs[k] = F.interpolate(v[0], + size=image.shape[-2:], + mode='bilinear', + align_corners=False) + elif 'weights' in k: + b, num_objects, num_heads, num_queries, h, w = v.shape + v = v.view(num_objects * num_heads, num_queries, h, w) + v = F.interpolate(v, size=image.shape[-2:], mode='bilinear', align_corners=False) + aux_outputs[k] = v.view(num_objects, num_heads, num_queries, *image.shape[-2:]) + else: + aux_outputs[k] = F.interpolate(v, + size=image.shape[-2:], + mode='bilinear', + align_corners=False)[0] + aux_outputs[k] = unpad(aux_outputs[k], pads) + if 'weights' in k: + weights = aux_outputs[k] + weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0] + + 1e-8) + aux_outputs[k] = (weights * 255).cpu().numpy() + else: + aux_outputs[k] = (aux_outputs[k].softmax(dim=0) * 255).cpu().numpy() + + self.image_feature_store.delete(self.curr_ti) + return aux_outputs + + def get_aux_object_weights(self, image: torch.Tensor) -> np.ndarray: + image, pads = pad_divide_by(image, 16) + # B*num_objects*H*W*num_queries -> num_objects*num_queries*H*W + # weights = F.softmax(self.obj_logits, dim=-1)[0] + weights = F.sigmoid(self.obj_logits)[0] + weights = weights.permute(0, 3, 1, 2).contiguous() + weights = F.interpolate(weights, + size=image.shape[-2:], + mode='bilinear', + align_corners=False) + # 
weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0]) + weights = unpad(weights, pads) + weights = (weights * 255).cpu().numpy() + return weights diff --git a/tracker/inference/kv_memory_store.py b/tracker/inference/kv_memory_store.py new file mode 100644 index 0000000000000000000000000000000000000000..e50b794dc227a772e8a7478d26d662749c0b1c6c --- /dev/null +++ b/tracker/inference/kv_memory_store.py @@ -0,0 +1,348 @@ +from typing import Dict, List, Optional, Literal +from collections import defaultdict +import torch + + +def _add_last_dim(dictionary, key, new_value, prepend=False): + # append/prepend a new value to the last dimension of a tensor in a dictionary + # if the key does not exist, put the new value in + # append by default + if key in dictionary: + dictionary[key] = torch.cat([dictionary[key], new_value], -1) + else: + dictionary[key] = new_value + + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + def __init__(self, save_selection: bool = False, save_usage: bool = False): + """ + We store keys and values of objects that first appear in the same frame in a bucket. + Each bucket contains a set of object ids. + Each bucket is associated with a single key tensor + and a dictionary of value tensors indexed by object id. + + The keys and values are stored as the concatenation of a permanent part and a temporary part. + """ + self.save_selection = save_selection + self.save_usage = save_usage + + self.global_bucket_id = 0 # does not reduce even if buckets are removed + self.buckets: Dict[int, List[int]] = {} # indexed by bucket id + self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id + self.v: Dict[int, torch.Tensor] = {} # indexed by object id + + # indexed by bucket id; the end point of permanent memory + self.perm_end_pt: Dict[int, int] = defaultdict(int) + + # shrinkage and selection are just like the keys + self.s = {} + if self.save_selection: + self.e = {} # does not contain the permanent memory part + + # usage + if self.save_usage: + self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part + self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part + + def add(self, + key: torch.Tensor, + values: Dict[int, torch.Tensor], + shrinkage: torch.Tensor, + selection: torch.Tensor, + supposed_bucket_id: int = -1, + as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: + """ + key: (1/2)*C*N + values: dict of values ((1/2)*C*N), object ids are used as keys + shrinkage: (1/2)*1*N + selection: (1/2)*C*N + + supposed_bucket_id: used to sync the bucket id between working and long-term memory + if provided, the input should all be in a single bucket indexed by this id + as_permanent: whether to store the input as permanent memory + 'no': don't + 'first': only store it as permanent memory if the bucket is empty + 'all': always store it as permanent memory + """ + bs = key.shape[0] + ne = key.shape[-1] + assert len(key.shape) == 3 + assert len(shrinkage.shape) == 3 + assert not self.save_selection or len(selection.shape) == 3 + assert as_permanent in ['no', 'first', 'all'] + + # add the value and create new buckets if necessary + if supposed_bucket_id >= 0: + enabled_buckets = [supposed_bucket_id] + bucket_exist = supposed_bucket_id in self.buckets + for obj, value in values.items(): + if bucket_exist: + assert obj in self.v + assert obj in self.buckets[supposed_bucket_id] + _add_last_dim(self.v, obj, value, 
prepend=(as_permanent == 'all')) + else: + assert obj not in self.v + self.v[obj] = value + self.buckets[supposed_bucket_id] = list(values.keys()) + else: + new_bucket_id = None + enabled_buckets = set() + for obj, value in values.items(): + assert len(value.shape) == 3 + if obj in self.v: + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + bucket_used = [ + bucket_id for bucket_id, object_ids in self.buckets.items() + if obj in object_ids + ] + assert len(bucket_used) == 1 # each object should only be in one bucket + enabled_buckets.add(bucket_used[0]) + else: + self.v[obj] = value + if new_bucket_id is None: + # create new bucket + new_bucket_id = self.global_bucket_id + self.global_bucket_id += 1 + self.buckets[new_bucket_id] = [] + # put the new object into the corresponding bucket + self.buckets[new_bucket_id].append(obj) + enabled_buckets.add(new_bucket_id) + + # increment the permanent size if necessary + add_as_permanent = {} # indexed by bucket id + for bucket_id in enabled_buckets: + add_as_permanent[bucket_id] = False + if as_permanent == 'all': + self.perm_end_pt[bucket_id] += ne + add_as_permanent[bucket_id] = True + elif as_permanent == 'first': + if self.perm_end_pt[bucket_id] == 0: + self.perm_end_pt[bucket_id] = ne + add_as_permanent[bucket_id] = True + + # create new counters for usage if necessary + if self.save_usage and as_permanent != 'all': + new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key to every bucket + for bucket_id in self.buckets: + if bucket_id not in enabled_buckets: + # if we are not adding new values to a bucket, we should skip it + continue + + _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) + _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) + if not add_as_permanent[bucket_id]: + if self.save_selection: + _add_last_dim(self.e, bucket_id, selection) + if self.save_usage: + _add_last_dim(self.use_cnt, bucket_id, new_count) + _add_last_dim(self.life_cnt, bucket_id, new_life) + + def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: + # increase all life count by 1 + # increase use of indexed elements + if not self.save_usage: + return + + usage = usage[:, self.perm_end_pt[bucket_id]:] + if usage.shape[-1] == 0: + # if there is no temporary memory, we don't need to update + return + self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) + self.life_cnt[bucket_id] += 1 + + def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: + # keep only the temporary elements *outside* of this range (with some boundary conditions) + # the permanent elements are ignored in this computation + # i.e., concat (a[:start], a[end:]) + # bucket with size <= min_size are not modified + + assert start >= 0 + assert end <= 0 + + object_ids = self.buckets[bucket_id] + bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] + if bucket_num_elements <= min_size: + return + + if end == 0: + # negative 0 would not work as the end index! 
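The comment above flags a slicing detail worth spelling out: `end` is passed as a non-positive offset, and an `end` of 0 is meant to remove everything from `start` to the last element, but `tensor[..., 0:]` would keep everything instead. A tiny standalone illustration of the workaround applied on the next lines:

```python
import torch

a = torch.arange(10).view(1, 1, 10)
start, end = 3, 0                 # end == 0 means "drop everything after start"
if end == 0:
    end = a.shape[-1] + 1         # any index past the last element gives an empty slice
kept = torch.cat([a[:, :, :start], a[:, :, end:]], dim=-1)
print(kept.flatten().tolist())    # [0, 1, 2]
```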
+ # effectively make the second part an empty slice + end = self.k[bucket_id].shape[-1] + 1 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + k = self.k[bucket_id] + s = self.s[bucket_id] + if self.save_selection: + e = self.e[bucket_id] + if self.save_usage: + use_cnt = self.use_cnt[bucket_id] + life_cnt = self.life_cnt[bucket_id] + + self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) + self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) + if self.save_selection: + self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) + if self.save_usage: + self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) + self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], + -1) + for obj_id in object_ids: + v = self.v[obj_id] + self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) + + def remove_old_memory(self, bucket_id: int, max_len: int) -> None: + self.sieve_by_range(bucket_id, 0, -max_len, max_len) + + def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: + # for long-term memory only + object_ids = self.buckets[bucket_id] + + assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory + + # normalize with life duration + usage = self.get_usage(bucket_id) + bs = usage.shape[0] + + survivals = [] + + for bi in range(bs): + _, survived = torch.topk(usage[bi], k=max_size) + survivals.append(survived.flatten()) + assert survived.shape[-1] == survivals[0].shape[-1] + + self.k[bucket_id] = torch.stack( + [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + self.s[bucket_id] = torch.stack( + [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + if self.save_selection: + # Long-term memory does not store selection so this should not be needed + self.e[bucket_id] = torch.stack( + [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + for obj_id in object_ids: + self.v[obj_id] = torch.stack( + [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + self.use_cnt[bucket_id] = torch.stack( + [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + self.life_cnt[bucket_id] = torch.stack( + [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + + def get_usage(self, bucket_id: int) -> torch.Tensor: + # return normalized usage + if not self.save_usage: + raise RuntimeError('I did not count usage!') + else: + usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] + return usage + + def get_all_sliced( + self, bucket_id: int, start: int, end: int + ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # return k, sk, ek, value, normalized usage in order, sliced by start and end + # this only queries the temporary memory + + assert start >= 0 + assert end <= 0 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + if end == 0: + # negative 0 would not work as the end index! 
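The usage bookkeeping above (`use_cnt`, `life_cnt`, `get_usage`) reduces to a per-token running average of how much attention each memory token receives across reads. A toy sketch with assumed sizes, mirroring `update_bucket_usage` followed by `get_usage`:

```python
import torch

num_tokens = 5
use_cnt = torch.zeros(1, num_tokens)           # attention mass received so far
life_cnt = torch.full((1, num_tokens), 1e-7)   # number of reads each token has lived through

for _ in range(3):                             # three reads from this bucket
    usage = torch.softmax(torch.randn(1, num_tokens), dim=-1)
    use_cnt += usage                           # what update_bucket_usage accumulates
    life_cnt += 1

print(use_cnt / life_cnt)                      # get_usage: frequently-read tokens score higher
```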
+ k = self.k[bucket_id][:, :, start:] + sk = self.s[bucket_id][:, :, start:] + ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None + else: + k = self.k[bucket_id][:, :, start:end] + sk = self.s[bucket_id][:, :, start:end] + ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None + + return k, sk, ek, value, usage + + def purge_except(self, obj_keep_idx: List[int]): + # purge certain objects from the memory except the one listed + obj_keep_idx = set(obj_keep_idx) + + # remove objects that are not in the keep list from the buckets + buckets_to_remove = [] + for bucket_id, object_ids in self.buckets.items(): + self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] + if len(self.buckets[bucket_id]) == 0: + buckets_to_remove.append(bucket_id) + + # remove object values that are not in the keep list + self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} + + # remove buckets that are empty + for bucket_id in buckets_to_remove: + del self.buckets[bucket_id] + del self.k[bucket_id] + del self.s[bucket_id] + if self.save_selection: + del self.e[bucket_id] + if self.save_usage: + del self.use_cnt[bucket_id] + del self.life_cnt[bucket_id] + + def clear_non_permanent_memory(self): + # clear all non-permanent memory + for bucket_id in self.buckets: + self.sieve_by_range(bucket_id, 0, 0, 0) + + def get_v_size(self, obj_id: int) -> int: + return self.v[obj_id].shape[-1] + + def size(self, bucket_id: int) -> int: + if bucket_id not in self.k: + return 0 + else: + return self.k[bucket_id].shape[-1] + + def perm_size(self, bucket_id: int) -> int: + return self.perm_end_pt[bucket_id] + + def non_perm_size(self, bucket_id: int) -> int: + return self.size(bucket_id) - self.perm_size(bucket_id) + + def engaged(self, bucket_id: Optional[int] = None) -> bool: + if bucket_id is None: + return len(self.buckets) > 0 + else: + return bucket_id in self.buckets + + @property + def num_objects(self) -> int: + return len(self.v) + + @property + def key(self) -> Dict[int, torch.Tensor]: + return self.k + + @property + def value(self) -> Dict[int, torch.Tensor]: + return self.v + + @property + def shrinkage(self) -> Dict[int, torch.Tensor]: + return self.s + + @property + def selection(self) -> Dict[int, torch.Tensor]: + return self.e + + def __contains__(self, key): + return key in self.v diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..53995a0f3c0998191858b955415a549dfcef248e --- /dev/null +++ b/tracker/inference/memory_manager.py @@ -0,0 +1,378 @@ +import logging +from omegaconf import DictConfig +from typing import List, Dict +import torch + +from tracker.inference.object_manager import ObjectManager +from tracker.inference.kv_memory_store import KeyValueMemoryStore +from tracker.model.cutie import CUTIE +from tracker.model.utils.memory_utils import * + +log = logging.getLogger() + + +class MemoryManager: + """ + Manages all three memory stores and the transition between working/long-term memory + """ + def __init__(self, cfg: DictConfig, object_manager: ObjectManager): + 
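The pruning path above (`purge_except`) is easy to miss: dropping objects also empties their buckets, and an emptied bucket is deleted together with its per-bucket key, shrinkage, selection and usage tensors. A minimal dictionary-level sketch with placeholder values:

```python
buckets = {0: [1, 2], 1: [3]}          # bucket id -> object ids
values = {1: 'v1', 2: 'v2', 3: 'v3'}   # object id -> value tensor (placeholders here)
keep = {1, 3}

values = {obj: v for obj, v in values.items() if obj in keep}
for bucket_id in list(buckets):
    buckets[bucket_id] = [o for o in buckets[bucket_id] if o in keep]
    if not buckets[bucket_id]:
        # an emptied bucket is removed together with its key/shrinkage tensors
        del buckets[bucket_id]

print(buckets, values)                 # {0: [1], 1: [3]} {1: 'v1', 3: 'v3'}
```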
self.object_manager = object_manager + self.sensory_dim = cfg.model.sensory_dim + self.top_k = cfg.top_k + self.chunk_size = cfg.chunk_size + + self.save_aux = cfg.save_aux + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + # subtract 1 because the first-frame is now counted as "permanent memory" + # and is not counted towards max_mem_frames + # but we want to keep the hyperparameters consistent as before for the same behavior + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + # dimensions will be inferred from input later + self.CK = self.CV = None + self.H = self.W = None + + # The sensory memory is stored as a dictionary indexed by object ids + # each of shape bs * C^h * H * W + self.sensory = {} + + # a dictionary indexed by object ids, each of shape bs * T * Q * C + self.obj_v = {} + + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + if self.use_long_term: + self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage) + + self.config_stale = True + self.engaged = False + + def update_config(self, cfg: DictConfig) -> None: + self.config_stale = True + self.top_k = cfg['top_k'] + + assert self.use_long_term == cfg.use_long_term, 'cannot update this' + assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this' + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + def _readout(self, affinity, v) -> torch.Tensor: + # affinity: bs*N*HW + # v: bs*C*N or bs*num_objects*C*N + # returns bs*C*HW or bs*num_objects*C*HW + if len(v.shape) == 3: + # single object + return v @ affinity + else: + bs, num_objects, C, N = v.shape + v = v.view(bs, num_objects * C, N) + out = v @ affinity + return out.view(bs, num_objects, C, -1) + + def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor: + # -1 because the mask does not contain the background channel + return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]] + + def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1) + + def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1) + + def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + # All the values that the object ids refer to should have the same shape + value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1) + if self.use_long_term and obj_ids[0] in self.long_mem.value: + lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1) + value = torch.cat([lt_value, value], dim=-1) + + return value + + def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor, + last_mask: 
torch.Tensor, network: CUTIE) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert query_key.shape[0] == bs + assert selection.shape[0] == bs + assert last_mask.shape[0] == bs + + query_key = query_key.flatten(start_dim=2) # bs*C^k*HW + selection = selection.flatten(start_dim=2) # bs*C^k*HW + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + if self.use_long_term and self.long_mem.engaged(bucket_id): + # Use long-term memory + long_mem_size = self.long_mem.size(bucket_id) + memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]], + -1) + shrinkage = torch.cat( + [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1) + + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + """ + Record memory usage for working and long-term memory + """ + # ignore the index return for long-term memory + work_usage = usage[:, long_mem_size:] + self.work_mem.update_bucket_usage(bucket_id, work_usage) + + if self.count_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_bucket_usage(bucket_id, long_usage) + else: + # no long-term memory + memory_key = self.work_mem.key[bucket_id] + shrinkage = self.work_mem.shrinkage[bucket_id] + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + + if self.use_long_term: + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + self.work_mem.update_bucket_usage(bucket_id, usage) + else: + affinity = do_softmax(similarity, top_k=self.top_k, inplace=True) + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + visual_readout = self._readout(affinity, + this_msk_value).view(bs, len(objects), self.CV, h, w) + pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + 'sensory': this_sensory, + 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + 'q_weights': aux_features['q_weights'] if aux_features else None, + 'p_weights': aux_features['p_weights'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def add_memory(self, + key: torch.Tensor, + shrinkage: torch.Tensor, + msk_value: torch.Tensor, + obj_value: 
torch.Tensor, + objects: List[int], + selection: torch.Tensor = None, + *, + as_permanent: bool = False) -> None: + # key: (1/2)*C*H*W + # msk_value: (1/2)*num_objects*C*H*W + # obj_value: (1/2)*num_objects*Q*C + # objects contains a list of object ids corresponding to the objects in msk_value/obj_value + bs = key.shape[0] + assert shrinkage.shape[0] == bs + assert msk_value.shape[0] == bs + assert obj_value.shape[0] == bs + + self.engaged = True + if self.H is None or self.config_stale: + self.config_stale = False + self.H, self.W = msk_value.shape[-2:] + self.HW = self.H * self.W + # convert from num. frames to num. tokens + self.max_work_tokens = self.max_mem_frames * self.HW + if self.use_long_term: + self.min_work_tokens = self.min_mem_frames * self.HW + + # key: bs*C*N + # value: bs*num_objects*C*N + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + self.CK = key.shape[1] + + msk_value = msk_value.flatten(start_dim=3) + self.CV = msk_value.shape[2] + + if selection is not None: + # not used in non-long-term mode + selection = selection.flatten(start_dim=2) + + # insert object values into object memory + for obj_id, obj in enumerate(objects): + if obj in self.obj_v: + """streaming average + each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1) + first embed_dim keeps track of the sum of embeddings + the last dim keeps the total count + averaging in done inside the object transformer + + incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1) + self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0) + """ + last_acc = self.obj_v[obj][:, :, -1] + new_acc = last_acc + obj_value[:, obj_id, :, -1] + + self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] + + obj_value[:, obj_id, :, :-1]) + self.obj_v[obj][:, :, -1] = new_acc + else: + self.obj_v[obj] = obj_value[:, obj_id] + + # convert mask value tensor into a dict for insertion + msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)} + self.work_mem.add(key, + msk_values, + shrinkage, + selection=selection, + as_permanent=as_permanent) + + for bucket_id in self.work_mem.buckets.keys(): + # long-term memory cleanup + if self.use_long_term: + # Do memory compressed if needed + if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens: + # Remove obsolete features if needed + if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens - + self.num_prototypes): + self.long_mem.remove_obsolete_features( + bucket_id, + self.max_long_tokens - self.num_prototypes - self.buffer_tokens) + + self.compress_features(bucket_id) + else: + # FIFO + self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens) + + def purge_except(self, obj_keep_idx: List[int]) -> None: + # purge certain objects from the memory except the one listed + self.work_mem.purge_except(obj_keep_idx) + if self.use_long_term and self.long_mem.engaged(): + self.long_mem.purge_except(obj_keep_idx) + self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx} + + if not self.work_mem.engaged(): + # everything is removed! 
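The frame-to-token conversion in `add_memory` above decides when trimming kicks in: each memorised frame contributes `H*W` stride-16 tokens, and the working memory is trimmed (FIFO, or consolidated into long-term memory) once its non-permanent part exceeds `max_mem_frames` worth of tokens. A back-of-the-envelope sketch with assumed numbers (the input is padded so both sides are multiples of 16):

```python
frame_h, frame_w = 480, 864        # assumed padded input size
H, W = frame_h // 16, frame_w // 16
HW = H * W                         # tokens contributed by one memorised frame
max_mem_frames = 5                 # assumed value of cfg.max_mem_frames
max_work_tokens = max_mem_frames * HW
print(H, W, HW, max_work_tokens)   # 30 54 1620 8100
```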
+ self.engaged = False + + def compress_features(self, bucket_id: int) -> None: + HW = self.HW + + # perform memory consolidation + prototype_key, prototype_value, prototype_shrinkage = self.consolidation( + *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens)) + + # remove consolidated working memory + self.work_mem.sieve_by_range(bucket_id, + 0, + -self.min_work_tokens, + min_size=self.min_work_tokens) + + # add to long-term memory + self.long_mem.add(prototype_key, + prototype_value, + prototype_shrinkage, + selection=None, + supposed_bucket_id=bucket_id) + + def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor, + candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor], + usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # find the indices with max usage + bs = candidate_key.shape[0] + assert bs in [1, 2] + + prototype_key = [] + prototype_selection = [] + for bi in range(bs): + _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True) + prototype_indices = max_usage_indices.flatten() + prototype_key.append(candidate_key[bi, :, prototype_indices]) + prototype_selection.append(candidate_selection[bi, :, prototype_indices]) + prototype_key = torch.stack(prototype_key, dim=0) + prototype_selection = torch.stack(prototype_selection, dim=0) + """ + Potentiation step + """ + similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, + prototype_selection) + affinity = do_softmax(similarity) + + # readout the values + prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()} + + # readout the shrinkage term + prototype_shrinkage = self._readout(affinity, candidate_shrinkage) + + return prototype_key, prototype_value, prototype_shrinkage + + def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]): + for obj in ids: + if obj not in self.sensory: + # also initializes the sensory memory + bs, _, h, w = sample_key.shape + self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w), + device=sample_key.device) + + def update_sensory(self, sensory: torch.Tensor, ids: List[int]): + # sensory: 1*num_objects*C*H*W + for obj_id, obj in enumerate(ids): + self.sensory[obj] = sensory[:, obj_id] + + def get_sensory(self, ids: List[int]): + # returns (1/2)*num_objects*C*H*W + return self._get_sensory_by_ids(ids) + + def clear_non_permanent_memory(self): + self.work_mem.clear_non_permanent_memory() + if self.use_long_term: + self.long_mem.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.sensory = {} diff --git a/tracker/inference/object_info.py b/tracker/inference/object_info.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e0bd45b10d0361c3ebc19783155e9ab29c8ad0 --- /dev/null +++ b/tracker/inference/object_info.py @@ -0,0 +1,24 @@ +class ObjectInfo: + """ + Store meta information for an object + """ + def __init__(self, id: int): + self.id = id + self.poke_count = 0 # count number of detections missed + + def poke(self) -> None: + self.poke_count += 1 + + def unpoke(self) -> None: + self.poke_count = 0 + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other): + if type(other) == int: + return self.id == other + return self.id == other.id + + def __repr__(self): + return f'(ID: {self.id})' diff --git a/tracker/inference/object_manager.py b/tracker/inference/object_manager.py new file mode 100644 index 
0000000000000000000000000000000000000000..a9e5a8533d3f4e5c72150abe7e79d05fcd2f7bd9 --- /dev/null +++ b/tracker/inference/object_manager.py @@ -0,0 +1,148 @@ +from typing import Union, List, Dict + +import torch +from tracker.inference.object_info import ObjectInfo + + +class ObjectManager: + """ + Object IDs are immutable. The same ID always represent the same object. + Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. + Temporary IDs start from 1. + """ + def __init__(self): + self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} + self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} + self.obj_id_to_obj: Dict[int, ObjectInfo] = {} + + self.all_historical_object_ids: List[int] = [] + + def _recompute_obj_id_to_obj_mapping(self) -> None: + self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} + + def add_new_objects( + self, objects: Union[List[ObjectInfo], ObjectInfo, + List[int]]) -> (List[int], List[int]): + if not isinstance(objects, list): + objects = [objects] + + corresponding_tmp_ids = [] + corresponding_obj_ids = [] + for obj in objects: + if isinstance(obj, int): + obj = ObjectInfo(id=obj) + + if obj in self.obj_to_tmp_id: + # old object + corresponding_tmp_ids.append(self.obj_to_tmp_id[obj]) + corresponding_obj_ids.append(obj.id) + else: + # new object + new_obj = ObjectInfo(id=obj) + + # new object + new_tmp_id = len(self.obj_to_tmp_id) + 1 + self.obj_to_tmp_id[new_obj] = new_tmp_id + self.tmp_id_to_obj[new_tmp_id] = new_obj + self.all_historical_object_ids.append(new_obj.id) + corresponding_tmp_ids.append(new_tmp_id) + corresponding_obj_ids.append(new_obj.id) + + self._recompute_obj_id_to_obj_mapping() + assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) + return corresponding_tmp_ids, corresponding_obj_ids + + def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None: + # delete an object or a list of objects + # re-sort the tmp ids + if isinstance(obj_ids_to_remove, int): + obj_ids_to_remove = [obj_ids_to_remove] + + new_tmp_id = 1 + total_num_id = len(self.obj_to_tmp_id) + + local_obj_to_tmp_id = {} + local_tmp_to_obj_id = {} + + for tmp_iter in range(1, total_num_id + 1): + obj = self.tmp_id_to_obj[tmp_iter] + if obj.id not in obj_ids_to_remove: + local_obj_to_tmp_id[obj] = new_tmp_id + local_tmp_to_obj_id[new_tmp_id] = obj + new_tmp_id += 1 + + self.obj_to_tmp_id = local_obj_to_tmp_id + self.tmp_id_to_obj = local_tmp_to_obj_id + self._recompute_obj_id_to_obj_mapping() + + def purge_inactive_objects(self, + max_missed_detection_count: int) -> (bool, List[int], List[int]): + # remove tmp ids of objects that are removed + obj_id_to_be_deleted = [] + tmp_id_to_be_deleted = [] + tmp_id_to_keep = [] + obj_id_to_keep = [] + + for obj in self.obj_to_tmp_id: + if obj.poke_count > max_missed_detection_count: + obj_id_to_be_deleted.append(obj.id) + tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) + else: + tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) + obj_id_to_keep.append(obj.id) + + purge_activated = len(obj_id_to_be_deleted) > 0 + if purge_activated: + self.delete_object(obj_id_to_be_deleted) + return purge_activated, tmp_id_to_keep, obj_id_to_keep + + def tmp_to_obj_cls(self, mask) -> torch.Tensor: + # remap tmp id cls representation to the true object id representation + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + return new_mask + + def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]: + # returns the mapping in a dict 
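The ID scheme described above keeps object IDs immutable while re-packing temporary IDs (tensor positions) whenever objects are deleted. A toy walk-through with plain dictionaries, not the repo's class:

```python
tmp_id_to_obj = {1: 10, 2: 20, 3: 30}    # temporary slot -> immutable object id

def delete_objects(mapping, obj_ids_to_remove):
    kept = [obj for obj in mapping.values() if obj not in obj_ids_to_remove]
    # re-pack so temporary ids stay contiguous starting from 1
    return {tmp: obj for tmp, obj in enumerate(kept, start=1)}

tmp_id_to_obj = delete_objects(tmp_id_to_obj, {20})
print(tmp_id_to_obj)                     # {1: 10, 2: 30}: object 30 moved from slot 3 to slot 2
```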
format for saving it with pickle + return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} + + def realize_dict(self, obj_dict, dim=1) -> torch.Tensor: + # turns a dict indexed by obj id into a tensor, ordered by tmp IDs + output = [] + for _, obj in self.tmp_id_to_obj.items(): + if obj.id not in obj_dict: + raise NotImplementedError + output.append(obj_dict[obj.id]) + output = torch.stack(output, dim=dim) + return output + + def make_one_hot(self, cls_mask) -> torch.Tensor: + output = [] + for _, obj in self.tmp_id_to_obj.items(): + output.append(cls_mask == obj.id) + if len(output) == 0: + output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) + else: + output = torch.stack(output, dim=0) + return output + + @property + def all_obj_ids(self) -> List[int]: + return [k.id for k in self.obj_to_tmp_id] + + @property + def num_obj(self) -> int: + return len(self.obj_to_tmp_id) + + def has_all(self, objects: List[int]) -> bool: + for obj in objects: + if obj not in self.obj_to_tmp_id: + return False + return True + + def find_object_by_id(self, obj_id) -> ObjectInfo: + return self.obj_id_to_obj[obj_id] + + def find_tmp_by_id(self, obj_id) -> int: + return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]] diff --git a/tracker/inference/utils/__init__.py b/tracker/inference/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/inference/utils/args_utils.py b/tracker/inference/utils/args_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a771ccaa080af2acd9757c7139c60c24652a1442 --- /dev/null +++ b/tracker/inference/utils/args_utils.py @@ -0,0 +1,30 @@ +import logging +from omegaconf import DictConfig + +log = logging.getLogger() + + +def get_dataset_cfg(cfg: DictConfig): + dataset_name = cfg.dataset + data_cfg = cfg.datasets[dataset_name] + + potential_overrides = [ + 'image_directory', + 'mask_directory', + 'json_directory', + 'size', + 'save_all', + 'use_all_masks', + 'use_long_term', + 'mem_every', + ] + + for override in potential_overrides: + if cfg[override] is not None: + log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') + data_cfg[override] = cfg[override] + # escalte all potential overrides to the top-level config + if override in data_cfg: + cfg[override] = data_cfg[override] + + return data_cfg diff --git a/tracker/inference/utils/burst_utils.py b/tracker/inference/utils/burst_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..570442848c83378f8562485aa7cca3502910440c --- /dev/null +++ b/tracker/inference/utils/burst_utils.py @@ -0,0 +1,19 @@ +from os import path +import copy +import json + + +class BURSTResultHandler: + def __init__(self, dataset_json): + self.dataset_json = copy.deepcopy(dataset_json) + + # get rid of the segmentations while keeping the metadata + self.dataset_json['sequences'] = [] + + def add_sequence(self, sequence_json): + self.dataset_json['sequences'].append(sequence_json) + + def dump(self, root): + json_path = path.join(root, 'predictions.json') + with open(json_path, 'w') as f: + json.dump(self.dataset_json, f) \ No newline at end of file diff --git a/tracker/inference/utils/frame_utils.py b/tracker/inference/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc18f35ed7c9ee4ff1aa41fa1ef988d372e91039 --- /dev/null +++ b/tracker/inference/utils/frame_utils.py @@ -0,0 +1,26 @@ +from typing import 
Dict, List, Tuple +import torch + +from inference.object_info import ObjectInfo + + +class FrameInfo: + def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo], + ti: int, info: Dict): + self.image = image + self.mask = mask + self.segments_info = segments_info + self.ti = ti + self.info = info + + @property + def name(self) -> str: + return self.info['frame'] + + @property + def shape(self) -> Tuple(int): + return self.info['shape'] + + @property + def need_save(self) -> bool: + return self.info['save'] diff --git a/tracker/inference/utils/results_utils.py b/tracker/inference/utils/results_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db970ffea9ff89df45bdb4932a5972373a4f23f2 --- /dev/null +++ b/tracker/inference/utils/results_utils.py @@ -0,0 +1,256 @@ +from typing import Tuple, Optional, Dict +import logging +import os +import shutil +from os import path +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np + +import pycocotools.mask as mask_util +from threading import Thread +from queue import Queue +from dataclasses import dataclass +import copy + +from tracker.utils.pano_utils import ID2RGBConverter +from tracker.utils.palette import davis_palette_np +from tracker.inference.object_manager import ObjectManager +from tracker.inference.object_info import ObjectInfo + +log = logging.getLogger() + +try: + import hickle as hkl +except ImportError: + log.warning('Failed to import hickle. Fine if not using multi-scale testing.') + + +class ResultSaver: + def __init__(self, + output_root, + video_name, + *, + dataset, + object_manager: ObjectManager, + use_long_id, + palette=None, + save_mask=True, + save_scores=False, + score_output_root=None, + visualize_output_root=None, + visualize=False, + init_json=None): + self.output_root = output_root + self.video_name = video_name + self.dataset = dataset.lower() + self.use_long_id = use_long_id + self.palette = palette + self.object_manager = object_manager + self.save_mask = save_mask + self.save_scores = save_scores + self.score_output_root = score_output_root + self.visualize_output_root = visualize_output_root + self.visualize = visualize + + if self.visualize: + if self.palette is not None: + self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3) + else: + self.colors = davis_palette_np + + self.need_remapping = True + self.json_style = None + self.id2rgb_converter = ID2RGBConverter() + + if 'burst' in self.dataset: + assert init_json is not None + self.input_segmentations = init_json['segmentations'] + self.segmentations = [{} for _ in init_json['segmentations']] + self.annotated_frames = init_json['annotated_image_paths'] + self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'} + self.video_json['segmentations'] = self.segmentations + self.json_style = 'burst' + + self.queue = Queue(maxsize=10) + self.thread = Thread(target=save_result, args=(self.queue, )) + self.thread.daemon = True + self.thread.start() + + def process(self, + prob: torch.Tensor, + frame_name: str, + resize_needed: bool = False, + shape: Optional[Tuple[int, int]] = None, + last_frame: bool = False, + path_to_image: str = None): + + if resize_needed: + prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, + 0] + # Probability mask -> index mask + mask = torch.argmax(prob, dim=0) + if self.save_scores: + # also need to pass prob + prob = prob.cpu() + else: + prob = None + + # remap indices + if 
self.need_remapping: + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + mask = new_mask + + args = ResultArgs(saver=self, + prob=prob, + mask=mask.cpu(), + frame_name=frame_name, + path_to_image=path_to_image, + tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj), + obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id), + last_frame=last_frame) + + self.queue.put(args) + + def end(self): + self.queue.put(None) + self.queue.join() + self.thread.join() + + +@dataclass +class ResultArgs: + saver: ResultSaver + prob: torch.Tensor + mask: torch.Tensor + frame_name: str + path_to_image: str + tmp_id_to_obj: Dict[int, ObjectInfo] + obj_to_tmp_id: Dict[ObjectInfo, int] + last_frame: bool + + +def save_result(queue: Queue): + while True: + args: ResultArgs = queue.get() + if args is None: + queue.task_done() + break + + saver = args.saver + prob = args.prob + mask = args.mask + frame_name = args.frame_name + path_to_image = args.path_to_image + tmp_id_to_obj = args.tmp_id_to_obj + obj_to_tmp_id = args.obj_to_tmp_id + last_frame = args.last_frame + all_obj_ids = [k.id for k in obj_to_tmp_id] + + # record output in the json file + if saver.json_style == 'burst': + if frame_name in saver.annotated_frames: + frame_index = saver.annotated_frames.index(frame_name) + input_segments = saver.input_segmentations[frame_index] + frame_segments = saver.segmentations[frame_index] + + for id in all_obj_ids: + if id in input_segments: + # if this frame has been given as input, just copy + frame_segments[id] = input_segments[id] + continue + + segment = {} + segment_mask = (mask == id) + if segment_mask.sum() > 0: + coco_mask = mask_util.encode(np.asfortranarray(segment_mask.numpy())) + segment['rle'] = coco_mask['counts'].decode('utf-8') + frame_segments[id] = segment + + # save the mask to disk + if saver.save_mask: + if saver.use_long_id: + out_mask = mask.numpy().astype(np.uint32) + rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8) + for id in all_obj_ids: + _, image = saver.id2rgb_converter.convert(id) + obj_mask = (out_mask == id) + rgb_mask[obj_mask] = image + out_img = Image.fromarray(rgb_mask) + else: + rgb_mask = None + out_mask = mask.numpy().astype(np.uint8) + out_img = Image.fromarray(out_mask) + if saver.palette is not None: + out_img.putpalette(saver.palette) + + this_out_path = path.join(saver.output_root, saver.video_name) + os.makedirs(this_out_path, exist_ok=True) + out_img.save(os.path.join(this_out_path, frame_name[:-4] + '.png')) + + # save scores for multi-scale testing + if saver.save_scores: + this_out_path = path.join(saver.score_output_root, saver.video_name) + os.makedirs(this_out_path, exist_ok=True) + + prob = (prob.detach().numpy() * 255).astype(np.uint8) + + if last_frame: + tmp_to_obj_mapping = {obj.id: tmp_id for obj, tmp_id in tmp_id_to_obj.items()} + hkl.dump(tmp_to_obj_mapping, path.join(this_out_path, f'backward.hkl'), mode='w') + + hkl.dump(prob, + path.join(this_out_path, f'{frame_name[:-4]}.hkl'), + mode='w', + compression='lzf') + + if saver.visualize: + if path_to_image is not None: + image_np = np.array(Image.open(path_to_image)) + else: + raise ValueError('Cannot visualize without path_to_image') + + if rgb_mask is None: + # we need to apply a palette + rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8) + for id in all_obj_ids: + image = saver.colors[id] + obj_mask = (out_mask == id) + rgb_mask[obj_mask] = image + + alpha = (out_mask == 
0).astype(np.float32) * 0.5 + 0.5 + alpha = alpha[:, :, None] + blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8) + + # find a place to save the visualization + this_vis_path = path.join(saver.visualize_output_root, saver.video_name) + os.makedirs(this_vis_path, exist_ok=True) + Image.fromarray(blend).save(path.join(this_vis_path, frame_name[:-4] + '.jpg')) + + queue.task_done() + + +def make_zip(dataset, run_dir, exp_id, mask_output_root): + if dataset.startswith('y'): + # YoutubeVOS + log.info('Making zip for YouTubeVOS...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir, + 'Annotations') + elif dataset == 'd17-test-dev': + # DAVIS 2017 test-dev -- zip from within the Annotation folder + log.info('Making zip for DAVIS test-dev...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root) + elif dataset == 'mose-val': + # MOSE validation -- same as DAVIS test-dev + log.info('Making zip for MOSE validation...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root) + elif dataset == 'lvos-test': + # LVOS test -- same as YouTubeVOS + log.info('Making zip for LVOS test...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir, + 'Annotations') + else: + log.info(f'Not making zip for {dataset}.') diff --git a/tracker/model/__init__.py b/tracker/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/model/aux_modules.py b/tracker/model/aux_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4a8ef3913dc5f057b29ad5917cb7bc5541d004 --- /dev/null +++ b/tracker/model/aux_modules.py @@ -0,0 +1,80 @@ +""" +For computing auxiliary outputs for auxiliary losses +""" +from typing import Dict +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from tracker.model.group_modules import GConv2d +from tracker.utils.tensor_utils import aggregate + + +class LinearPredictor(nn.Module): + def __init__(self, x_dim: int, pix_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) + + def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # pixel_feat: B*pix_dim*H*W + # x: B*num_objects*x_dim*H*W + num_objects = x.shape[1] + x = self.projection(x) + + pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] + return logits + + +class DirectPredictor(nn.Module): + def __init__(self, x_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: B*num_objects*x_dim*H*W + logits = self.projection(x).squeeze(2) + return logits + + +class AuxComputer(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + use_sensory_aux = cfg.model.aux_loss.sensory.enabled + self.use_query_aux = cfg.model.aux_loss.query.enabled + + sensory_dim = cfg.model.sensory_dim + embed_dim = cfg.model.embed_dim + + if use_sensory_aux: + self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) + else: + self.sensory_aux = None + + def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + logits = aggregate(prob, dim=1) + return logits + + def forward(self, pix_feat: torch.Tensor, aux_input: 
Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + aux_output['attn_mask'] = aux_input['attn_mask'] + + if self.sensory_aux is not None: + # B*num_objects*H*W + logits = self.sensory_aux(pix_feat, sensory) + aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) + if self.use_query_aux: + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output \ No newline at end of file diff --git a/tracker/model/big_modules.py b/tracker/model/big_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1daaf0d72811694922476e63e018cefa6c5656 --- /dev/null +++ b/tracker/model/big_modules.py @@ -0,0 +1,304 @@ +""" +big_modules.py - This file stores higher-level network blocks. + +x - usually denotes features that are shared between objects. +g - usually denotes features that are not shared between objects + with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). + +The trailing number of a variable usually denotes the stride +""" + +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tracker.model.group_modules import * +from tracker.model.utils import resnet +from tracker.model.modules import * + + +class PixelEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type + if self.is_resnet: + if model_cfg.pixel_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=True) + elif model_cfg.pixel_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=True) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.res2 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + f4 = self.res2(x) + f8 = self.layer2(f4) + f16 = self.layer3(f8) + + return f16, f8, f4 + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class KeyProjection(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + in_dim = model_cfg.pixel_encoder.ms_dims[0] + mid_dim = model_cfg.pixel_dim + key_dim = model_cfg.key_dim + + self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) + self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x: torch.Tensor, *, need_s: bool, + need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.pix_feat_proj(x) + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class MaskEncoder(nn.Module): + 
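As a quick illustration of the key projection above, here is a stripped-down sketch with toy dimensions (the real module first runs the features through a 1x1 projection and orthogonally initialises the key convolution, which is omitted here): one feature map yields an attention key, a positive shrinkage term and a sigmoid selection gate.

```python
import torch
import torch.nn as nn

pix_dim, key_dim = 256, 64                     # toy dimensions
x = torch.randn(1, pix_dim, 30, 54)            # stride-16 pixel features

key_proj = nn.Conv2d(pix_dim, key_dim, kernel_size=3, padding=1)
d_proj = nn.Conv2d(pix_dim, 1, kernel_size=3, padding=1)
e_proj = nn.Conv2d(pix_dim, key_dim, kernel_size=3, padding=1)

key = key_proj(x)                              # attention key
shrinkage = d_proj(x) ** 2 + 1                 # always >= 1
selection = torch.sigmoid(e_proj(x))           # gate in (0, 1)
print(key.shape, shrinkage.shape, selection.shape)
```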
def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + pixel_dim = model_cfg.pixel_dim + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + final_dim = model_cfg.mask_encoder.final_dim + + self.single_object = single_object + extra_dim = 1 if single_object else 2 + + if model_cfg.mask_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=True, extra_dim=extra_dim) + elif model_cfg.mask_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=True, extra_dim=extra_dim) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.layer1 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + + self.distributor = MainToGroupDistributor() + self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) + + self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) + + def forward(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + sensory: torch.Tensor, + masks: torch.Tensor, + others: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1) -> (torch.Tensor, torch.Tensor): + # ms_features are from the key encoder + # we only use the first one (lowest resolution), following XMem + if self.single_object: + g = masks.unsqueeze(2) + else: + g = torch.stack([masks, others], dim=2) + + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if deep_update: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_g = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + g_chunk = g + else: + g_chunk = g[:, i:i + chunk_size] + actual_chunk_size = g_chunk.shape[1] + g_chunk = g_chunk.flatten(start_dim=0, end_dim=1) + + g_chunk = self.conv1(g_chunk) + g_chunk = self.bn1(g_chunk) # 1/2, 64 + g_chunk = self.maxpool(g_chunk) # 1/4, 64 + g_chunk = self.relu(g_chunk) + + g_chunk = self.layer1(g_chunk) # 1/4 + g_chunk = self.layer2(g_chunk) # 1/8 + g_chunk = self.layer3(g_chunk) # 1/16 + + g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:]) + g_chunk = self.fuser(pix_feat, g_chunk) + all_g.append(g_chunk) + if deep_update: + if fast_path: + new_sensory = self.sensory_update(g_chunk, sensory) + else: + new_sensory[:, i:i + chunk_size] = self.sensory_update( + g_chunk, sensory[:, i:i + chunk_size]) + g = torch.cat(all_g, dim=1) + + return g, new_sensory + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class PixelFeatureFuser(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + pixel_dim = model_cfg.pixel_dim + embed_dim = model_cfg.embed_dim + self.single_object = single_object + + self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim) + if self.single_object: + self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1) + else: + self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1) + + def forward(self, + pix_feat: torch.Tensor, + pixel_memory: torch.Tensor, + sensory_memory: torch.Tensor, + last_mask: torch.Tensor, + 
last_others: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + batch_size, num_objects = pixel_memory.shape[:2] + + if self.single_object: + last_mask = last_mask.unsqueeze(2) + else: + last_mask = torch.stack([last_mask, last_others], dim=2) + + if chunk_size < 1: + chunk_size = num_objects + + # chunk-by-chunk inference + all_p16 = [] + for i in range(0, num_objects, chunk_size): + sensory_readout = self.sensory_compress( + torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2)) + p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout + p16 = self.fuser(pix_feat, p16) + all_p16.append(p16) + p16 = torch.cat(all_p16, dim=1) + + return p16 + + +class MaskDecoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + embed_dim = model_cfg.embed_dim + sensory_dim = model_cfg.sensory_dim + ms_image_dims = model_cfg.pixel_encoder.ms_dims + up_dims = model_cfg.mask_decoder.up_dims + + assert embed_dim == up_dims[0] + + self.sensory_update = SensoryUpdater([up_dims[0], up_dims[1], up_dims[2] + 1], sensory_dim, + sensory_dim) + + self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1]) + self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1]) + self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2]) + + self.pred = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + + def forward(self, + ms_image_feat: Iterable[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + chunk_size: int = -1, + update_sensory: bool = True) -> (torch.Tensor, torch.Tensor): + + batch_size, num_objects = memory_readout.shape[:2] + f8, f4 = self.decoder_feat_proc(ms_image_feat[1:]) + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if update_sensory: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_logits = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + p16 = memory_readout + else: + p16 = memory_readout[:, i:i + chunk_size] + actual_chunk_size = p16.shape[1] + + p8 = self.up_16_8(p16, f8) + p4 = self.up_8_4(p8, f4) + with torch.cuda.amp.autocast(enabled=False): + logits = self.pred(F.relu(p4.flatten(start_dim=0, end_dim=1).float())) + + if update_sensory: + p4 = torch.cat( + [p4, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2) + if fast_path: + new_sensory = self.sensory_update([p16, p8, p4], sensory) + else: + new_sensory[:, + i:i + chunk_size] = self.sensory_update([p16, p8, p4], + sensory[:, + i:i + chunk_size]) + all_logits.append(logits) + logits = torch.cat(all_logits, dim=0) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return new_sensory, logits diff --git a/tracker/model/channel_attn.py b/tracker/model/channel_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a2096c1c4b4745a3ea2060bb25af3b19ff9cf3ec --- /dev/null +++ b/tracker/model/channel_attn.py @@ -0,0 +1,39 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CAResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, residual: bool = True): + super().__init__() + self.residual = residual + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + t = int((abs(math.log2(out_dim)) + 1) // 2) + k = t if t % 2 else t + 1 + self.pool = 
nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False) + + if self.residual: + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.conv1(F.relu(x)) + x = self.conv2(F.relu(x)) + + b, c = x.shape[:2] + w = self.pool(x).view(b, 1, c) + w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 + + if self.residual: + x = x * w + self.downsample(r) + else: + x = x * w + + return x diff --git a/tracker/model/cutie.py b/tracker/model/cutie.py new file mode 100644 index 0000000000000000000000000000000000000000..82c5652a3f3d657ab71ed208cd11ca2322608d7a --- /dev/null +++ b/tracker/model/cutie.py @@ -0,0 +1,249 @@ +from typing import List, Dict +import logging +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from tracker.model.modules import * +from tracker.model.big_modules import * +from tracker.model.aux_modules import AuxComputer +from tracker.model.utils.memory_utils import * +from tracker.model.transformer.object_transformer import QueryTransformer +from tracker.model.transformer.object_summarizer import ObjectSummarizer +from tracker.utils.tensor_utils import aggregate + +log = logging.getLogger() + + +class CUTIE(nn.Module): + def __init__(self, cfg: DictConfig, *, single_object=False): + super().__init__() + model_cfg = cfg.model + self.ms_dims = model_cfg.pixel_encoder.ms_dims + self.key_dim = model_cfg.key_dim + self.value_dim = model_cfg.value_dim + self.sensory_dim = model_cfg.sensory_dim + self.pixel_dim = model_cfg.pixel_dim + self.embed_dim = model_cfg.embed_dim + self.single_object = single_object + + log.info(f'Single object: {self.single_object}') + + self.pixel_encoder = PixelEncoder(model_cfg) + self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1) + self.key_proj = KeyProjection(model_cfg) + self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object) + self.mask_decoder = MaskDecoder(model_cfg) + self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object) + self.object_transformer = QueryTransformer(model_cfg) + self.object_summarizer = ObjectSummarizer(model_cfg) + self.aux_computer = AuxComputer(cfg) + + self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False) + + def _get_others(self, masks: torch.Tensor) -> torch.Tensor: + # for each object, return the sum of masks of all other objects + if self.single_object: + return None + + num_objects = masks.shape[1] + if num_objects >= 1: + others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1) + else: + others = torch.zeros_like(masks) + return others + + def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + image = (image - self.pixel_mean) / self.pixel_std + ms_image_feat = self.pixel_encoder(image) + return ms_image_feat, self.pix_feat_proj(ms_image_feat[0]) + + def encode_mask( + self, + image: torch.Tensor, + ms_features: List[torch.Tensor], + sensory: torch.Tensor, + masks: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + image = (image - self.pixel_mean) / self.pixel_std + others = self._get_others(masks) + mask_value, new_sensory = self.mask_encoder(image, + 
ms_features, + sensory, + masks, + others, + deep_update=deep_update, + chunk_size=chunk_size) + object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights) + return mask_value, new_sensory, object_summaries, object_logits + + def transform_key(self, + final_pix_feat: torch.Tensor, + *, + need_sk: bool = True, + need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor): + key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek) + return key, shrinkage, selection + + # Used in training only. + # This step is replaced by MemoryManager in test time + def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor, + memory_key: torch.Tensor, memory_shrinkage: torch.Tensor, + msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor) -> (torch.Tensor, Dict[str, torch.Tensor]): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + batch_size, num_objects = msk_value.shape[:2] + + # read using visual attention + with torch.cuda.amp.autocast(enabled=False): + affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(), + query_selection.float()) + + msk_value = msk_value.flatten(start_dim=1, end_dim=2).float() + + # B * (num_objects*CV) * H * W + pixel_readout = readout(affinity, msk_value) + pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim, + *pixel_readout.shape[-2:]) + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output + + def pixel_fusion(self, + pix_feat: torch.Tensor, + pixel: torch.Tensor, + sensory: torch.Tensor, + last_mask: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area') + last_others = self._get_others(last_mask) + fused = self.pixel_fuser(pix_feat, + pixel, + sensory, + last_mask, + last_others, + chunk_size=chunk_size) + return fused + + def readout_query(self, + pixel_readout, + obj_memory, + *, + selector=None, + need_weights=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + return self.object_transformer(pixel_readout, + obj_memory, + selector=selector, + need_weights=need_weights) + + def segment(self, + ms_image_feat: List[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + selector: bool = None, + chunk_size: int = -1, + update_sensory: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + multi_scale_features is from the key encoder for skip-connection + memory_readout is from working/long-term memory + sensory is the sensory memory + last_mask is the mask from the last frame, supplementing sensory memory + selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects + during training. 
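+        Returns the updated sensory memory, the aggregated object logits (the sigmoid
+        probabilities are masked by `selector` before aggregation, then the logits are
+        bilinearly upsampled by a factor of 4), and `prob`, the softmax of those
+        logits over the object dimension.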
+ """ + sensory, logits = self.mask_decoder(ms_image_feat, + memory_readout, + sensory, + chunk_size=chunk_size, + update_sensory=update_sensory) + + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + + # Softmax over all objects[] + logits = aggregate(prob, dim=1) + logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False) + prob = F.softmax(logits, dim=1) + + return sensory, logits, prob + + def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + return self.aux_computer(pix_feat, aux_inputs, selector) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None: + if not self.single_object: + # Map single-object weight to multi-object weight (4->5 out channels in conv1) + for k in list(src_dict.keys()): + if k == 'mask_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif k == 'pixel_fuser.sensory_compress.weight': + if src_dict[k].shape[1] == self.sensory_dim + 1: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif self.single_object: + """ + If the model is multiple-object and we are training in single-object, + we strip the last channel of conv1. + This is not supposed to happen in standard training except when users are trying to + finetune a trained model with single object datasets. + """ + if src_dict['mask_encoder.conv1.weight'].shape[1] == 5: + log.warning(f'Converting {k} from multiple objects to single object.' 
+ 'This is not supposed to happen in standard training.') + src_dict[k] = src_dict[k][:, :-1] + + for k in src_dict: + if k not in self.state_dict(): + log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!') + for k in self.state_dict(): + if k not in src_dict: + log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!') + + self.load_state_dict(src_dict, strict=False) + + @property + def device(self) -> torch.device: + return self.pixel_mean.device \ No newline at end of file diff --git a/tracker/model/group_modules.py b/tracker/model/group_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a9bf64c51613be619705bee2da9b7508378cbb46 --- /dev/null +++ b/tracker/model/group_modules.py @@ -0,0 +1,127 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.channel_attn import CAResBlock + + +def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, + align_corners: bool) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = F.interpolate(g.flatten(start_dim=0, end_dim=1), + scale_factor=ratio, + mode=mode, + align_corners=align_corners) + g = g.view(batch_size, num_objects, *g.shape[1:]) + return g + + +def upsample_groups(g: torch.Tensor, + ratio: float = 2, + mode: str = 'bilinear', + align_corners: bool = False) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +def downsample_groups(g: torch.Tensor, + ratio: float = 1 / 2, + mode: str = 'area', + align_corners: bool = None) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +class GConv2d(nn.Conv2d): + def forward(self, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = super().forward(g.flatten(start_dim=0, end_dim=1)) + return g.view(batch_size, num_objects, *g.shape[1:]) + + +class GroupResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g + + +class MainToGroupDistributor(nn.Module): + def __init__(self, + x_transform: Optional[nn.Module] = None, + g_transform: Optional[nn.Module] = None, + method: str = 'cat', + reverse_order: bool = False): + super().__init__() + + self.x_transform = x_transform + self.g_transform = g_transform + self.method = method + self.reverse_order = reverse_order + + def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: + num_objects = g.shape[1] + + if self.x_transform is not None: + x = self.x_transform(x) + + if self.g_transform is not None: + g = self.g_transform(g) + + if not skip_expand: + x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + if self.method == 'cat': + if self.reverse_order: + g = torch.cat([g, x], 2) + else: + g = torch.cat([x, g], 2) + elif self.method == 'add': + g = x + g + elif self.method == 'mulcat': + g = torch.cat([x * g, g], dim=2) + elif self.method == 'muladd': + g = x * g + g + else: + raise NotImplementedError + + return g + + +class GroupFeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): + 
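+        """
+        Fuses a per-image feature map with per-object group features: `x`
+        (batch * x_in_dim * H * W) and `g` (batch * num_objects * g_in_dim * H * W)
+        are projected to `out_dim`, summed by a MainToGroupDistributor, and refined
+        by two channel-attention residual blocks with the batch and object
+        dimensions flattened.
+        """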
super().__init__() + + x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) + g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) + + self.distributor = MainToGroupDistributor(x_transform=x_transform, + g_transform=g_transform, + method='add') + self.block1 = CAResBlock(out_dim, out_dim) + self.block2 = CAResBlock(out_dim, out_dim) + + def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + + g = g.flatten(start_dim=0, end_dim=1) + + g = self.block1(g) + g = self.block2(g) + + g = g.view(batch_size, num_objects, *g.shape[1:]) + + return g \ No newline at end of file diff --git a/tracker/model/losses.py b/tracker/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9c7c5c3bbd7e7bd909a82aaa8a3b5c8046d7ee --- /dev/null +++ b/tracker/model/losses.py @@ -0,0 +1,97 @@ +from typing import List, Dict +from omegaconf import DictConfig +from collections import defaultdict +import torch +import torch.nn.functional as F + +from tracker.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness +from tracker.utils.tensor_utils import cls_to_one_hot + + +@torch.jit.script +def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: + # logits: T*C*num_points + loss = F.cross_entropy(logits, soft_gt, reduction='none') + # sum over temporal dimension + return loss.sum(0).mean() + + +@torch.jit.script +def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: + # mask: T*C*num_points + # soft_gt: T*C*num_points + # ignores the background + mask = mask[:, 1:].flatten(start_dim=2) + gt = soft_gt[:, 1:].float().flatten(start_dim=2) + numerator = 2 * (mask * gt).sum(-1) + denominator = mask.sum(-1) + gt.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum(0).mean() + + +class LossComputer: + def __init__(self, cfg: DictConfig, stage_cfg: DictConfig): + super().__init__() + self.point_supervision = stage_cfg.point_supervision + self.num_points = stage_cfg.train_num_points + self.oversample_ratio = stage_cfg.oversample_ratio + self.importance_sample_ratio = stage_cfg.importance_sample_ratio + + self.sensory_weight = cfg.model.aux_loss.sensory.weight + self.query_weight = cfg.model.aux_loss.query.weight + + def mask_loss(self, logits: torch.Tensor, + soft_gt: torch.Tensor) -> (torch.Tensor, torch.Tensor): + assert self.point_supervision + + with torch.no_grad(): + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + logits, lambda x: calculate_uncertainty(x), self.num_points, self.oversample_ratio, + self.importance_sample_ratio) + # get gt labels + point_labels = point_sample(soft_gt, point_coords, align_corners=False) + point_logits = point_sample(logits, point_coords, align_corners=False) + # point_labels and point_logits: B*C*num_points + + loss_ce = ce_loss(point_logits, point_labels) + loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels) + + return loss_ce, loss_dice + + def compute(self, data: Dict[str, torch.Tensor], + num_objects: List[int]) -> Dict[str, torch.Tensor]: + batch_size, num_frames = data['rgb'].shape[:2] + losses = defaultdict(float) + t_range = range(1, num_frames) + + for bi in range(batch_size): + logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], + dim=0) + cls_gt = data['cls_gt'][bi, 1:] # remove gt for the first frame + soft_gt = cls_to_one_hot(cls_gt, num_objects[bi]) + + 
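+            # For this sample, logits is (num_frames-1) x (num_objects[bi]+1) x H x W and
+            # soft_gt is the matching soft/one-hot target built by cls_to_one_hot.
+            # mask_loss() point-samples both tensors (point supervision); its CE and dice
+            # terms are accumulated below, each divided by batch_size.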
loss_ce, loss_dice = self.mask_loss(logits, soft_gt) + losses['loss_ce'] += loss_ce / batch_size + losses['loss_dice'] += loss_dice / batch_size + + aux = [data[f'aux_{ti}'] for ti in t_range] + if 'sensory_logits' in aux[0]: + sensory_log = torch.stack( + [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0) + loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt) + losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight + losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight + if 'q_logits' in aux[0]: + num_levels = aux[0]['q_logits'].shape[2] + + for l in range(num_levels): + query_log = torch.stack( + [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0) + loss_ce, loss_dice = self.mask_loss(query_log, soft_gt) + losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight + losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight + + losses['total_loss'] = sum(losses.values()) + + return losses diff --git a/tracker/model/modules.py b/tracker/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..59c4170af5abfecf7b85ed7804fc390285e0194d --- /dev/null +++ b/tracker/model/modules.py @@ -0,0 +1,85 @@ +from typing import List, Iterable +import torch +import torch.nn as nn + +from tracker.model.group_modules import * + + +class MaskUpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = upsample_groups(in_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class DecoderFeatureProcessor(nn.Module): + def __init__(self, decoder_dims: List[int], out_dims: List[int]): + super().__init__() + self.transforms = nn.ModuleList([ + nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) + ]) + + def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: + outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] + return outputs + + +# @torch.jit.script +def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + # h: batch_size * num_objects * hidden_dim * h * w + # values: batch_size * num_objects * (hidden_dim*3) * h * w + dim = values.shape[2] // 3 + forget_gate = torch.sigmoid(values[:, :, :dim]) + update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) + new_value = torch.tanh(values[:, :, dim * 2:]) + new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value + return new_h + + +class SensoryUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + + with torch.cuda.amp.autocast(enabled=False): + g 
= g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class SensoryDeepUpdater(nn.Module): + def __init__(self, f_dim: int, sensory_dim: int): + super().__init__() + self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + with torch.cuda.amp.autocast(enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h diff --git a/tracker/model/transformer/__init__.py b/tracker/model/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/model/transformer/object_summarizer.py b/tracker/model/transformer/object_summarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42ee1b5385d607f34145e25b0362678f196064a2 --- /dev/null +++ b/tracker/model/transformer/object_summarizer.py @@ -0,0 +1,89 @@ +from typing import List, Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.transformer.positional_encoding import PositionalEncoding + + +# @torch.jit.script +def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, + logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + # value: B*num_objects*H*W*value_dim + # logits: B*num_objects*H*W*num_summaries + # masks: B*num_objects*H*W*num_summaries: 1 if allowed + weights = logits.sigmoid() * masks + # B*num_objects*num_summaries*value_dim + sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) + # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 + area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) + + # B*num_objects*num_summaries*value_dim + return sums, area + + +class ObjectSummarizer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_summarizer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_summaries = this_cfg.num_summaries + self.add_pe = this_cfg.add_pe + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + + if self.add_pe: + self.pos_enc = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature) + + self.input_proj = nn.Linear(self.value_dim, self.embed_dim) + self.feature_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.weights_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.num_summaries), + ) + + def forward(self, + masks: torch.Tensor, + value: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): + # masks: B*num_objects*(H0)*(W0) + # value: B*num_objects*value_dim*H*W + # -> B*num_objects*H*W*value_dim + h, w = value.shape[-2:] + masks = F.interpolate(masks, size=(h, w), mode='area') + masks = masks.unsqueeze(-1) + inv_masks = 1 - masks + repeated_masks = torch.cat([ + masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + ], + dim=-1) + + value = value.permute(0, 1, 3, 4, 2) + value = 
self.input_proj(value) + if self.add_pe: + pe = self.pos_enc(value) + value = value + pe + + with torch.cuda.amp.autocast(enabled=False): + value = value.float() + feature = self.feature_pred(value) + logits = self.weights_pred(value) + sums, area = _weighted_pooling(repeated_masks, feature, logits) + + summaries = torch.cat([sums, area], dim=-1) + + if need_weights: + return summaries, logits + else: + return summaries, None \ No newline at end of file diff --git a/tracker/model/transformer/object_transformer.py b/tracker/model/transformer/object_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..71f0830247495df161f6126dd40ea9ff7f30b9f2 --- /dev/null +++ b/tracker/model/transformer/object_transformer.py @@ -0,0 +1,205 @@ +from typing import Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +from tracker.model.group_modules import GConv2d +from tracker.utils.tensor_utils import aggregate +from tracker.model.transformer.positional_encoding import PositionalEncoding +from tracker.model.transformer.transformer_layers import * + + +class QueryTransformerBlock(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + self.ff_dim = this_cfg.ff_dim + + self.read_from_pixel = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv) + self.self_attn = SelfAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv) + self.ffn = FFN(self.embed_dim, self.ff_dim) + self.read_from_query = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv, + norm=this_cfg.read_from_query.output_norm) + self.pixel_ffn = PixelFFN(self.embed_dim) + + def forward( + self, + x: torch.Tensor, + pixel: torch.Tensor, + query_pe: torch.Tensor, + pixel_pe: torch.Tensor, + attn_mask: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + # x: (bs*num_objects)*num_queries*embed_dim + # pixel: bs*num_objects*C*H*W + # query_pe: (bs*num_objects)*num_queries*embed_dim + # pixel_pe: (bs*num_objects)*(H*W)*C + # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W) + + # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C + pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + x, q_weights = self.read_from_pixel(x, + pixel_flat, + query_pe, + pixel_pe, + attn_mask=attn_mask, + need_weights=need_weights) + x = self.self_attn(x, query_pe) + x = self.ffn(x) + + pixel_flat, p_weights = self.read_from_query(pixel_flat, + x, + pixel_pe, + query_pe, + need_weights=need_weights) + pixel = self.pixel_ffn(pixel, pixel_flat) + + if need_weights: + bs, num_objects, _, h, w = pixel.shape + q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w) + p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads, + self.num_queries, h, w) + + return x, pixel, q_weights, p_weights + + +class QueryTransformer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + + # query initialization and embedding + self.query_init = 
nn.Embedding(self.num_queries, self.embed_dim) + self.query_emb = nn.Embedding(self.num_queries, self.embed_dim) + + # projection from object summaries to query initialization and embedding + self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim) + self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim) + + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.spatial_pe = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature, + channel_last=False, + transpose_output=True) + + # transformer blocks + self.num_blocks = this_cfg.num_blocks + self.blocks = nn.ModuleList( + QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks)) + self.mask_pred = nn.ModuleList( + nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1)) + for _ in range(self.num_blocks + 1)) + + self.act = nn.ReLU(inplace=True) + + def forward(self, + pixel: torch.Tensor, + obj_summaries: torch.Tensor, + selector: Optional[torch.Tensor] = None, + need_weights: bool = False) -> (torch.Tensor, Dict[str, torch.Tensor]): + # pixel: B*num_objects*embed_dim*H*W + # obj_summaries: B*num_objects*T*num_queries*embed_dim + T = obj_summaries.shape[2] + bs, num_objects, _, H, W = pixel.shape + + # normalize object values + # the last channel is the cumulative area of the object + obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries, + self.embed_dim + 1) + # sum over time + # during inference, T=1 as we already did streaming average in memory_manager + obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1) + obj_area = obj_summaries[:, :, :, -1:].sum(dim=1) + obj_values = obj_sums / (obj_area + 1e-4) + obj_init = self.summary_to_query_init(obj_values) + obj_emb = self.summary_to_query_emb(obj_values) + + # positional embeddings for object queries + query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init + query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb + + # positional embeddings for pixel features + pixel_init = self.pixel_init_proj(pixel) + pixel_emb = self.pixel_emb_proj(pixel) + pixel_pe = self.spatial_pe(pixel.flatten(0, 1)) + pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb + + pixel = pixel_init + + # run the transformer + aux_features = {'logits': []} + + # first aux output + aux_logits = self.mask_pred[0](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector) + aux_features['logits'].append(aux_logits) + for i in range(self.num_blocks): + query, pixel, q_weights, p_weights = self.blocks[i](query, + pixel, + query_emb, + pixel_pe, + attn_mask, + need_weights=need_weights) + + if self.training or i <= self.num_blocks - 1 or need_weights: + aux_logits = self.mask_pred[i + 1](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector) + aux_features['logits'].append(aux_logits) + + aux_features['q_weights'] = q_weights # last layer only + aux_features['p_weights'] = p_weights # last layer only + + if self.training: + # no need to save all heads + aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads, + self.num_queries, H, W)[:, :, 0] + + return pixel, aux_features + + def _get_aux_mask(self, logits: 
torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + # logits: batch_size*num_objects*H*W + # selector: batch_size*num_objects*1*1 + # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W) + # where True means the attention is blocked + + if selector is None: + prob = logits.sigmoid() + else: + prob = logits.sigmoid() * selector + logits = aggregate(prob, dim=1) + + is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]) + foreground_mask = is_foreground.bool().flatten(start_dim=2) + inv_foreground_mask = ~foreground_mask + inv_background_mask = foreground_mask + + aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + + aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1) + + aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False + + return aux_mask \ No newline at end of file diff --git a/tracker/model/transformer/positional_encoding.py b/tracker/model/transformer/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..6c15bb73784d3e5fcb1a5d2f9713069e7a933f34 --- /dev/null +++ b/tracker/model/transformer/positional_encoding.py @@ -0,0 +1,108 @@ +# Reference: +# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py +# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py + +import math + +import numpy as np +import torch +from torch import nn + + +def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding(nn.Module): + def __init__(self, + dim: int, + scale: float = math.pi * 2, + temperature: float = 10000, + normalize: bool = True, + channel_last: bool = True, + transpose_output: bool = False): + super().__init__() + dim = int(np.ceil(dim / 4) * 2) + self.dim = dim + inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.normalize = normalize + self.scale = scale + self.eps = 1e-6 + self.channel_last = channel_last + self.transpose_output = transpose_output + + self.cached_penc = None # the cache is irrespective of the number of objects + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: A 4/5d tensor of size + channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) + channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) + :return: positional encoding tensor that has the same shape as the input if the input is 4d + if the input is 5d, the output is broadcastable along the k-dimension + """ + if len(tensor.shape) != 4 and len(tensor.shape) != 5: + raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') + + if len(tensor.shape) == 5: + # take a sample from the k dimension + num_objects = tensor.shape[1] + tensor = tensor[:, 0] + else: + num_objects = None + + if self.channel_last: + batch_size, h, w, c = tensor.shape + else: + batch_size, c, h, w = tensor.shape + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + if num_objects 
is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + self.cached_penc = None + + pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) + pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) + if self.normalize: + pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale + pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale + + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_y = get_emb(sin_inp_y).unsqueeze(1) + emb_x = get_emb(sin_inp_x) + + emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) + emb[:, :, :self.dim] = emb_x + emb[:, :, self.dim:] = emb_y + + if not self.channel_last and self.transpose_output: + # cancelled out + pass + elif (not self.channel_last) or (self.transpose_output): + emb = emb.permute(2, 0, 1) + + self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + +if __name__ == '__main__': + pe = PositionalEncoding(8).cuda() + input = torch.ones((1, 8, 8, 8)).cuda() + output = pe(input) + # print(output) + print(output[0, :, 0, 0]) + print(output[0, :, 0, 5]) + print(output[0, 0, :, 0]) + print(output[0, 0, 0, :]) diff --git a/tracker/model/transformer/transformer_layers.py b/tracker/model/transformer/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..80cf522dad092e8282f43e7f0f0dc05cfd15aa9b --- /dev/null +++ b/tracker/model/transformer/transformer_layers.py @@ -0,0 +1,161 @@ +# Modified from PyTorch nn.Transformer + +from typing import List, Callable + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.channel_attn import CAResBlock + + +class SelfAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False]): + super().__init__() + self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) + self.norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + + def forward(self, + x: torch.Tensor, + pe: torch.Tensor, + attn_mask: bool = None, + key_padding_mask: bool = None) -> torch.Tensor: + x = self.norm(x) + if any(self.add_pe_to_qkv): + x_with_pe = x + pe + q = x_with_pe if self.add_pe_to_qkv[0] else x + k = x_with_pe if self.add_pe_to_qkv[1] else x + v = x_with_pe if self.add_pe_to_qkv[2] else x + else: + q = k = v = x + + r = x + x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + return r + self.dropout(x) + + +# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention +class CrossAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False], + residual: bool = True, + norm: bool = True): + super().__init__() + self.cross_attn = nn.MultiheadAttention(dim, + nhead, + dropout=dropout, + batch_first=batch_first) + if norm: + self.norm = nn.LayerNorm(dim) + else: + self.norm = nn.Identity() + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + self.residual = residual + + def forward(self, + x: torch.Tensor, + mem: torch.Tensor, + x_pe: torch.Tensor, + mem_pe: 
torch.Tensor, + attn_mask: bool = None, + *, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor): + x = self.norm(x) + if self.add_pe_to_qkv[0]: + q = x + x_pe + else: + q = x + + if any(self.add_pe_to_qkv[1:]): + mem_with_pe = mem + mem_pe + k = mem_with_pe if self.add_pe_to_qkv[1] else mem + v = mem_with_pe if self.add_pe_to_qkv[2] else mem + else: + k = v = mem + r = x + x, weights = self.cross_attn(q, + k, + v, + attn_mask=attn_mask, + need_weights=need_weights, + average_attn_weights=False) + + if self.residual: + return r + self.dropout(x), weights + else: + return self.dropout(x), weights + + +class FFN(nn.Module): + def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_ff) + self.linear2 = nn.Linear(dim_ff, dim_in) + self.norm = nn.LayerNorm(dim_in) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.norm(x) + x = self.linear2(self.activation(self.linear1(x))) + x = r + x + return x + + +class PixelFFN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + self.conv = CAResBlock(dim, dim) + + def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor: + # pixel: batch_size * num_objects * dim * H * W + # pixel_flat: (batch_size*num_objects) * (H*W) * dim + bs, num_objects, _, h, w = pixel.shape + pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim) + pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous() + + x = self.conv(pixel_flat) + x = x.view(bs, num_objects, self.dim, h, w) + return x + + +class OutputFFN(nn.Module): + def __init__(self, dim_in: int, dim_out: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_out) + self.linear2 = nn.Linear(dim_out, dim_out) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) diff --git a/tracker/model/utils/__init__.py b/tracker/model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/model/utils/memory_utils.py b/tracker/model/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8726ce6bca51cbd76fff814db9cf03544397cd --- /dev/null +++ b/tracker/model/utils/memory_utils.py @@ -0,0 +1,95 @@ +import math +import torch +from typing import Optional, Union, Tuple + + +# @torch.jit.script +def get_similarity(mk: torch.Tensor, + ms: torch.Tensor, + qk: torch.Tensor, + qe: torch.Tensor, + add_batch_dim: bool = False) -> torch.Tensor: + # used for training/inference and memory reading/memory potentiation + # mk: B x CK x [N] - Memory keys + # ms: B x 1 x [N] - Memory shrinkage + # qk: B x CK x [HW/P] - Query keys + # qe: B x CK x [HW/P] - Query selection + # Dimensions in [] are flattened + if add_batch_dim: + mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) + qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) + + CK = mk.shape[1] + mk = mk.flatten(start_dim=2) + ms = 
ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None + qk = qk.flatten(start_dim=2) + qe = qe.flatten(start_dim=2) if qe is not None else None + + if qe is not None: + # See XMem's appendix for derivation + mk = mk.transpose(1, 2) + a_sq = (mk.pow(2) @ qe) + two_ab = 2 * (mk @ (qk * qe)) + b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) + similarity = (-a_sq + two_ab - b_sq) + else: + # similar to STCN if we don't have the selection term + a_sq = mk.pow(2).sum(1).unsqueeze(2) + two_ab = 2 * (mk.transpose(1, 2) @ qk) + similarity = (-a_sq + two_ab) + + if ms is not None: + similarity = similarity * ms / math.sqrt(CK) # B*N*HW + else: + similarity = similarity / math.sqrt(CK) # B*N*HW + + return similarity + + +def do_softmax( + similarity: torch.Tensor, + top_k: Optional[int] = None, + inplace: bool = False, + return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # normalize similarity with top-k softmax + # similarity: B x N x [HW/P] + # use inplace with care + if top_k is not None: + values, indices = torch.topk(similarity, k=top_k, dim=1) + + x_exp = values.exp_() + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + if inplace: + similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW + affinity = similarity + else: + affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW + else: + maxes = torch.max(similarity, dim=1, keepdim=True)[0] + x_exp = torch.exp(similarity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + indices = None + + if return_usage: + return affinity, affinity.sum(dim=2) + + return affinity + + +def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, + qe: torch.Tensor) -> torch.Tensor: + # shorthand used in training with no top-k + similarity = get_similarity(mk, ms, qk, qe) + affinity = do_softmax(similarity) + return affinity + + +def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor: + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T * H * W) + mem = torch.bmm(mo, affinity) + mem = mem.view(B, CV, H, W) + + return mem diff --git a/tracker/model/utils/parameter_groups.py b/tracker/model/utils/parameter_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..177866af48de5e6d8795bdf6734b0dccb5a1947b --- /dev/null +++ b/tracker/model/utils/parameter_groups.py @@ -0,0 +1,72 @@ +import logging + +log = logging.getLogger() + + +def get_parameter_groups(model, stage_cfg, print_log=False): + """ + Assign different weight decays and learning rates to different parameters. + Returns a parameter group which can be passed to the optimizer. 
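+    Usage sketch (hypothetical names; assumes stage_cfg provides learning_rate,
+    weight_decay, embed_weight_decay and backbone_lr_ratio, as read below):
+
+        param_groups = get_parameter_groups(model, stage_cfg, print_log=True)
+        optimizer = torch.optim.AdamW(param_groups)  # per-group lr / weight_decay apply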
+ """ + weight_decay = stage_cfg.weight_decay + embed_weight_decay = stage_cfg.embed_weight_decay + backbone_lr_ratio = stage_cfg.backbone_lr_ratio + base_lr = stage_cfg.learning_rate + + backbone_params = [] + embed_params = [] + other_params = [] + + embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] + embedding_names = [e + '.weight' for e in embedding_names] + + # inspired by detectron2 + memo = set() + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # Avoid duplicating parameters + if param in memo: + continue + memo.add(param) + + if name.startswith('module'): + name = name[7:] + + inserted = False + if name.startswith('pixel_encoder.'): + backbone_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as a backbone parameter.') + else: + for e in embedding_names: + if name.endswith(e): + embed_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as an embedding parameter.') + break + + if not inserted: + other_params.append(param) + + parameter_groups = [ + { + 'params': backbone_params, + 'lr': base_lr * backbone_lr_ratio, + 'weight_decay': weight_decay + }, + { + 'params': embed_params, + 'lr': base_lr, + 'weight_decay': embed_weight_decay + }, + { + 'params': other_params, + 'lr': base_lr, + 'weight_decay': weight_decay + }, + ] + + return parameter_groups \ No newline at end of file diff --git a/tracker/model/utils/resnet.py b/tracker/model/utils/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3a07d2fd12da0b951ed6c724f97aa2f203877e7e --- /dev/null +++ b/tracker/model/utils/resnet.py @@ -0,0 +1,179 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = 
self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model diff --git a/tracker/utils/__init__.py b/tracker/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tracker/utils/image_saver.py b/tracker/utils/image_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..c3edfa96e60fea0e5ec8fd087da85d2efaa6444c --- /dev/null +++ b/tracker/utils/image_saver.py @@ -0,0 +1,230 @@ +import cv2 +import numpy as np + +import torch +from collections import defaultdict + + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + + +def detach_to_cpu(x): + return x.detach().cpu() + + +def transpose_np(x): + return np.transpose(x, [1, 2, 0]) + + +def tensor_to_gray_im(x): + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + + +def tensor_to_im(x): + x = detach_to_cpu(x).clamp(0, 1) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + + +# Predefined key <-> caption dict +key_captions = { + 'im': 'Image', + 'gt': 'GT', +} +""" +Return an image array with captions +keys in dictionary will be used as caption if not provided +values should contain lists of cv2 images +""" + + +def get_image_array(images, grid_shape, captions={}): + h, w = grid_shape + cate_counts = len(images) + rows_counts = len(next(iter(images.values()))) + + font = cv2.FONT_HERSHEY_SIMPLEX + + output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8) + col_cnt = 0 + for k, v in images.items(): + + # Default as key value itself + caption = captions.get(k, k) + + # Handles new line character + dy = 40 + for i, line in enumerate(caption.split('\n')): + cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy), font, 0.8, + (255, 255, 255), 2, cv2.LINE_AA) + + # Put images + for row_cnt, img in enumerate(v): + im_shape = img.shape + if len(im_shape) == 2: + img = img[..., np.newaxis] + + img = (img * 255).astype('uint8') + + output_image[(col_cnt + 0) * w:(col_cnt + 1) * w, + (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img + + col_cnt += 1 + + return output_image + + +def base_transform(im, size): + im = tensor_to_np_float(im) + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, 
interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + + +def im_transform(im, size): + return base_transform(detach_to_cpu(im), size=size) + + +def mask_transform(mask, size): + return base_transform(detach_to_cpu(mask), size=size) + + +def logits_transform(mask, size): + return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) + + +def add_attention(mask, pos): + mask = mask[:, :, None].repeat(3, axis=2) + pos = (pos + 1) / 2 + for i in range(pos.shape[0]): + y = int(pos[i][0] * mask.shape[0]) + x = int(pos[i][1] * mask.shape[1]) + y = max(min(y, mask.shape[0] - 1), 0) + x = max(min(x, mask.shape[1] - 1), 0) + # mask[y, x, :] = (255, 0, 0) + cv2.circle(mask, (x, y), 5, (1, 0, 0), -1) + return mask + + +def vis(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + max_num_objects = max(num_objects[:b]) + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + for bi in range(b): + for ti in range(t): + req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size)) + aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape + if 'sensory_logits' in aux: + sensory_aux = aux['sensory_logits'][bi].softmax(dim=0) + # batch_size * num_objects * num_levels * H * W + q_mask_aux = aux['q_logits'][bi].softmax(dim=0) + num_levels = q_mask_aux.shape[1] + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + req_images[f'Mask_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + req_images[f'S-Aux_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for l in range(num_levels): + req_images[f'Q-Aux-L{l}_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + else: + mask = mask_transform(images[f'masks_{ti}'][bi][oi], size) + req_images[f'Mask_{oi}'].append(mask) + if 'sensory_logits' in aux: + req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size)) + + for l in range(num_levels): + mask = mask_transform(q_mask_aux[oi + 1, l], size) + req_images[f'Q-Aux-L{l}_{oi}'].append(mask) + + req_images[f'GT_{oi}_{GT_suffix}'].append( + mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size)) + + return get_image_array(req_images, size, key_captions) + + +def vis_debug(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + max_num_objects = max(num_objects[:b]) + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + for bi in range(b): + for ti in range(t): + req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size)) + aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape + sensory_aux = aux['sensory_logits'][bi].softmax(dim=0) + # batch_size * num_objects * num_levels * H * W + q_mask_aux = aux['q_logits'][bi].softmax(dim=0) + attn_mask = aux['attn_mask'][bi] + num_levels = q_mask_aux.shape[1] + num_queries = attn_mask.shape[1] + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + req_images[f'Mask_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + req_images[f'S-Aux_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for l in range(num_levels): + req_images[f'Q-Aux-L{l}_{oi}'].append( + 
mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for q in range(num_queries): + req_images[f'Attn-Mask-Q{q}_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + else: + mask = mask_transform(images[f'masks_{ti}'][bi][oi], size) + req_images[f'Mask_{oi}'].append(mask) + req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size)) + + for l in range(num_levels): + mask = mask_transform(q_mask_aux[oi + 1, l], size) + req_images[f'Q-Aux-L{l}_{oi}'].append(mask) + for q in range(num_queries): + mask = mask_transform(1 - attn_mask[oi, q].float(), size) + req_images[f'Attn-Mask-Q{q}_{oi}'].append(mask) + + req_images[f'GT_{oi}_{GT_suffix}'].append( + mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size)) + + return get_image_array(req_images, size, key_captions) \ No newline at end of file diff --git a/tracker/utils/load_subset.py b/tracker/utils/load_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..c16ed0391ae745a736290bb7b956c98539e087ca --- /dev/null +++ b/tracker/utils/load_subset.py @@ -0,0 +1,13 @@ +import json + + +def load_subset(path): + with open(path, mode='r') as f: + subset = set(f.read().splitlines()) + return subset + + +def load_empty_masks(path): + with open(path, mode='r') as f: + empty_masks = json.load(f) + return empty_masks diff --git a/tracker/utils/log_integrator.py b/tracker/utils/log_integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e12ebb393cfd6b734859fad3c243850ba60ea7 --- /dev/null +++ b/tracker/utils/log_integrator.py @@ -0,0 +1,84 @@ +""" +Integrate numerical values for some iterations +Typically used for loss computation / logging to tensorboard +Call finalize and create a new Integrator when you want to display/log +""" +from typing import Dict, Callable, Tuple +import torch +from tracker.utils.logger import TensorboardLogger + + +class Integrator: + def __init__(self, logger: TensorboardLogger, distributed: bool = True): + self.values = {} + self.counts = {} + self.hooks = [] # List is used here to maintain insertion order + + self.logger = logger + + self.distributed = distributed + self.local_rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_tensor(self, key: str, tensor: torch.Tensor): + if key not in self.values: + self.counts[key] = 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] = tensor + else: + self.values[key] = tensor.mean().item() + else: + self.counts[key] += 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] += tensor + else: + self.values[key] += tensor.mean().item() + + def add_dict(self, tensor_dict: Dict[str, torch.Tensor]): + for k, v in tensor_dict.items(): + self.add_tensor(k, v) + + def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]): + """ + Adds a custom hook, i.e. compute new metrics using values in the dict + The hook takes the dict as argument, and returns a (k, v) tuple + e.g. 
for computing IoU + """ + if type(hook) == list: + self.hooks.extend(hook) + else: + self.hooks.append(hook) + + def reset_except_hooks(self): + self.values = {} + self.counts = {} + + # Average and output the metrics + def finalize(self, exp_id: str, prefix: str, it: int) -> None: + + for hook in self.hooks: + k, v = hook(self.values) + self.add_tensor(k, v) + + outputs = {} + for k, v in self.values.items(): + + if k[:4] == 'hide': + continue + + avg = v / self.counts[k] + + if self.distributed: + # Inplace operation + avg = torch.tensor(avg).cuda() + torch.distributed.reduce(avg, dst=0) + + if self.local_rank == 0: + avg = (avg / self.world_size).cpu().item() + outputs[k] = avg + else: + # Simple does it + outputs[k] = avg + + if (not self.distributed) or (self.local_rank == 0): + self.logger.log_metrics(exp_id, prefix, outputs, it) diff --git a/tracker/utils/logger.py b/tracker/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..399f76d685d2f0fa8a68b69222c1400815a0a2e3 --- /dev/null +++ b/tracker/utils/logger.py @@ -0,0 +1,107 @@ +""" +Dumps things to tensorboard and console +""" + +import os +import logging +import datetime +from typing import Dict +import numpy as np +from PIL import Image + +from torch.utils.tensorboard import SummaryWriter +from tracker.utils.time_estimator import TimeEstimator + + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + + +def detach_to_cpu(x): + return x.detach().cpu() + + +def fix_width_trunc(x): + return ('{:.9s}'.format('{:0.9f}'.format(x))) + + +class TensorboardLogger: + def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb): + self.run_dir = run_dir + self.py_log = py_logger + if enabled_tb: + self.tb_log = SummaryWriter(run_dir) + else: + self.tb_log = None + + # Get current git info for logging + try: + import git + repo = git.Repo(".") + git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha) + except (ImportError, RuntimeError): + print('Failed to fetch git info. 
Defaulting to None') + git_info = 'None' + + self.log_string('git', git_info) + + # used when logging metrics + self.time_estimator: TimeEstimator = None + + def log_scalar(self, tag, x, it): + if self.tb_log is None: + return + self.tb_log.add_scalar(tag, x, it) + + def log_metrics(self, exp_id, prefix, metrics: Dict, it): + msg = f'{exp_id}-{prefix} - it {it:6d}: ' + metrics_msg = '' + for k, v in sorted(metrics.items()): + self.log_scalar(f'{prefix}/{k}', v, it) + metrics_msg += f'{k: >10}:{v:.7f},\t' + + if self.time_estimator is not None: + self.time_estimator.update() + avg_time = self.time_estimator.get_and_reset_avg_time() + est = self.time_estimator.get_est_remaining(it) + est = datetime.timedelta(seconds=est) + if est.days > 0: + remaining_str = f'{est.days}d {est.seconds // 3600}h' + else: + remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m' + eta = datetime.datetime.now() + est + eta_str = eta.strftime('%Y-%m-%d %H:%M:%S') + time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t' + msg = f'{msg} {time_msg}' + + msg = f'{msg} {metrics_msg}' + self.py_log.info(msg) + + def log_image(self, stage_name, tag, image, it): + image_dir = os.path.join(self.run_dir, f'{stage_name}_images') + os.makedirs(image_dir, exist_ok=True) + + image = Image.fromarray(image) + image.save(os.path.join(image_dir, f'{tag}_{it}.png')) + + def log_string(self, tag, x): + self.py_log.info(f'{tag} - {x}') + if self.tb_log is None: + return + self.tb_log.add_text(tag, x) + + def debug(self, x): + self.py_log.debug(x) + + def info(self, x): + self.py_log.info(x) + + def warning(self, x): + self.py_log.warning(x) + + def error(self, x): + self.py_log.error(x) + + def critical(self, x): + self.py_log.critical(x) diff --git a/tracker/utils/mask_mapper.py b/tracker/utils/mask_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..617af0c712d20f429a05274438b77a5afc88d2db --- /dev/null +++ b/tracker/utils/mask_mapper.py @@ -0,0 +1,78 @@ +import numpy as np +import torch + +def all_to_onehot(masks, labels): + if len(masks.shape) == 3: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) + else: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) + + for ni, l in enumerate(labels): + Ms[ni] = (masks == l).astype(np.uint8) + + return Ms + +class MaskMapper: + """ + This class is used to convert a indexed-mask to a one-hot representation. + It also takes care of remapping non-continuous indices + It has two modes: + 1. Default. Only masks with new indices are supposed to go into the remapper. + This is also the case for YouTubeVOS. + i.e., regions with index 0 are not "background", but "don't care". + + 2. Exhaustive. Regions with index 0 are considered "background". + Every single pixel is considered to be "labeled". 
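A minimal usage sketch for `all_to_onehot` above, assuming the function from `tracker/utils/mask_mapper.py` is in scope; the 4x4 mask and the labels 3 and 7 are made-up values for illustration:

```python
import numpy as np

# Hypothetical indexed mask: 0 is background / don't-care, objects carry labels 3 and 7.
mask = np.zeros((4, 4), dtype=np.uint8)
mask[0, 0] = 3
mask[2, 2] = 7

# One binary channel per label, shape (num_labels, H, W) = (2, 4, 4).
onehot = all_to_onehot(mask, [3, 7])
assert onehot.shape == (2, 4, 4)
assert onehot[0, 0, 0] == 1 and onehot[1, 2, 2] == 1
```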
+ """ + def __init__(self): + self.labels = [] + self.remappings = {} + + # if coherent, no mapping is required + self.coherent = True + + def clear_labels(self): + self.labels = [] + self.remappings = {} + # if coherent, no mapping is required + self.coherent = True + + def convert_mask(self, mask, exhaustive=False): + # mask is in index representation, H*W numpy array + labels = np.unique(mask).astype(np.uint8) + labels = labels[labels!=0].tolist() + + new_labels = list(set(labels) - set(self.labels)) + if not exhaustive: + assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' + + # add new remappings + for i, l in enumerate(new_labels): + self.remappings[l] = i+len(self.labels)+1 + if self.coherent and i+len(self.labels)+1 != l: + self.coherent = False + + if exhaustive: + new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) + else: + if self.coherent: + new_mapped_labels = new_labels + else: + new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) + + self.labels.extend(new_labels) + # mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() + mask = torch.from_numpy(mask).float() + # mask num_objects*H*W + return mask, new_mapped_labels + + + def remap_index_mask(self, mask): + # mask is in index representation, H*W numpy array + if self.coherent: + return mask + + new_mask = np.zeros_like(mask) + for l, i in self.remappings.items(): + new_mask[mask==i] = l + return new_mask \ No newline at end of file diff --git a/tracker/utils/palette.py b/tracker/utils/palette.py new file mode 100644 index 0000000000000000000000000000000000000000..26a773c88bdd15fdb372fa9f552602a751625fc4 --- /dev/null +++ b/tracker/utils/palette.py @@ -0,0 +1,9 @@ +import numpy as np + +davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 
\xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' + +youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' + +davis_palette_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3) + +youtube_palette_np = np.frombuffer(youtube_palette, dtype=np.uint8).reshape(-1, 3) \ No newline at end of file diff --git a/tracker/utils/pano_utils.py b/tracker/utils/pano_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0e2fe9448ec977f0f177c2b5dc9aaccf38250d --- /dev/null +++ b/tracker/utils/pano_utils.py @@ -0,0 +1,30 @@ +import numpy as np +from threading import Lock + + +class ID2RGBConverter: + def __init__(self): + self.all_id = [] + self.obj_to_id = {} + self.lock = Lock() + + def _id_to_rgb(self, id: int): + rgb = np.zeros((3, ), dtype=np.uint8) + for i in range(3): + rgb[i] = id % 256 + id = id // 256 + return rgb + + def convert(self, obj: int): + with self.lock: + if obj in self.obj_to_id: + id = self.obj_to_id[obj] + else: + while True: + id = np.random.randint(255, 256**3) + if id not in self.all_id: + break + self.obj_to_id[obj] = id + self.all_id.append(id) + + return id, self._id_to_rgb(id) diff --git a/tracker/utils/point_features.py b/tracker/utils/point_features.py new file mode 100644 index 0000000000000000000000000000000000000000..87b794ef23c856bba022215c84581ba38e6d030b --- /dev/null +++ b/tracker/utils/point_features.py @@ -0,0 +1,111 @@ +# This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +# such that users do not need to install detectron2 just for these two functions +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import List +import torch +from torch.nn import functional as F + + +def cat(tensors: List[torch.Tensor], dim: int = 0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + + +def calculate_uncertainty(sem_seg_logits): + """ + For each location of the prediction `sem_seg_logits` we estimate uncerainty as the + difference between top first and top second predicted logits. + Args: + mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and + C is the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. 
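The uncertainty defined above is the negated margin between the two largest logits at each location, so confident predictions get strongly negative scores. A small sketch with invented logits, assuming `calculate_uncertainty` from this file is in scope:

```python
import torch

# One sample, three classes, two locations: the first is confident, the second is ambiguous.
logits = torch.tensor([[[4.0, 1.1],
                        [0.5, 1.0],
                        [0.1, 0.9]]])       # (N=1, C=3, P=2)

scores = calculate_uncertainty(logits)       # (N, 1, P)
print(scores)  # ~tensor([[[-3.5000, -0.1000]]]); the less negative score marks the uncertain location
```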
+ """ + if sem_seg_logits.shape[1] == 2: + # binary segmentation + return -(torch.abs(sem_seg_logits[:, 1:2])) + top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +def point_sample(input, point_coords, **kwargs): + """ + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points, + oversample_ratio, importance_sample_ratio): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) + point_logits = point_sample(coarse_logits, point_coords, align_corners=False) + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 
+ point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, + 2) + if num_random_points > 0: + point_coords = cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), + ], + dim=1, + ) + return point_coords \ No newline at end of file diff --git a/tracker/utils/range_transform.py b/tracker/utils/range_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb --- /dev/null +++ b/tracker/utils/range_transform.py @@ -0,0 +1,12 @@ +import torchvision.transforms as transforms + +im_mean = (124, 116, 104) + +im_normalization = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + +inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) diff --git a/tracker/utils/tensor_utils.py b/tracker/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e33c8936134ec4efdcc412945469853edc6498 --- /dev/null +++ b/tracker/utils/tensor_utils.py @@ -0,0 +1,62 @@ +from typing import List, Iterable +import torch +import torch.nn.functional as F + + +# STM +def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): + h, w = in_img.shape[-2:] + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) + lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + + +def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: + if len(img.shape) == 4: + if pad[2] + pad[3] > 0: + img = img[:, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, pad[0]:-pad[1]] + elif len(img.shape) == 3: + if pad[2] + pad[3] > 0: + img = img[:, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, pad[0]:-pad[1]] + elif len(img.shape) == 5: + if pad[2] + pad[3] > 0: + img = img[:, :, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, :, pad[0]:-pad[1]] + else: + raise NotImplementedError + return img + + +# @torch.jit.script +def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor: + with torch.cuda.amp.autocast(enabled=False): + prob = prob.float() + new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob], + dim).clamp(1e-7, 1 - 1e-7) + logits = torch.log((new_prob / (1 - new_prob))) + + return logits + + +# @torch.jit.script +def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor: + # cls_gt: B*1*H*W + B, _, H, W = cls_gt.shape + one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1) + return one_hot \ No newline at end of file diff --git a/tracker/utils/time_estimator.py b/tracker/utils/time_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..8d429b0404b641472ff84701305c570cc26280b7 --- /dev/null +++ b/tracker/utils/time_estimator.py @@ -0,0 +1,43 @@ +import 
time + + +class TimeEstimator: + def __init__(self, total_iter, step_size): + self.avg_time_window = [] # window-based average + self.exp_avg_time = None # exponential moving average + self.alpha = 0.7 # for exponential moving average + + self.last_time = time.time() # would not be accurate for the first iteration but well + self.total_iter = total_iter + self.step_size = step_size + + self.buffering_exp = True + + # call this at a fixed interval + # does not have to be every step + def update(self): + curr_time = time.time() + time_per_iter = curr_time - self.last_time + self.last_time = curr_time + + self.avg_time_window.append(time_per_iter) + + if self.buffering_exp: + if self.exp_avg_time is not None: + # discard the first iteration call to not pollute the ema + self.buffering_exp = False + self.exp_avg_time = time_per_iter + else: + self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter + + def get_est_remaining(self, it): + if self.exp_avg_time is None: + return 0 + + remaining_iter = self.total_iter - it + return remaining_iter * self.exp_avg_time / self.step_size + + def get_and_reset_avg_time(self): + avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size + self.avg_time_window = [] + return avg diff --git a/utils/download_util.py b/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8fb1b00522309d0c0931f5396355011fb200e7 --- /dev/null +++ b/utils/download_util.py @@ -0,0 +1,109 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
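`sizeof_fmt` above simply walks the binary unit ladder until the value drops below 1024. A quick sketch of its output, assuming the function is imported from `utils/download_util.py`:

```python
print(sizeof_fmt(512))            # '512.0 B'
print(sizeof_fmt(1536))           # '1.5 KB'
print(sizeof_fmt(3 * 1024 ** 3))  # '3.0 GB'
```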
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. 
+ """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/utils/evaluate_video.py b/utils/evaluate_video.py new file mode 100644 index 0000000000000000000000000000000000000000..47382fa43863ec0b4ed8772b3c7d4b9950b39a65 --- /dev/null +++ b/utils/evaluate_video.py @@ -0,0 +1,133 @@ +import torch +import numpy as np +from skimage.metrics import structural_similarity as ssim +from skimage.metrics import peak_signal_noise_ratio as psnr +import lpips +from tqdm import tqdm +import cv2 + +class VideoQualityEvaluator: + def __init__(self, device='cuda'): + """Initialize video quality evaluator with specified computation device + + Args: + device (str): Computation device ('cuda' or 'cpu') + """ + self.device = device + # Initialize LPIPS model (perceptual similarity metric) + self.lpips_model = lpips.LPIPS(net='alex').to(device) + + def _preprocess_frame(self, frame): + """Convert frame to standardized format for evaluation + + Args: + frame: Input frame (numpy array or torch tensor) + + Returns: + Processed frame in HWC format with values in [0,1] + """ + if isinstance(frame, torch.Tensor): + frame = frame.detach().cpu().numpy() + + # Normalize to [0,1] if needed + if frame.max() > 1: + frame = frame / 255.0 + # Convert CHW to HWC if needed + if len(frame.shape) == 3 and frame.shape[0] == 3: + frame = frame.transpose(1, 2, 0) + return frame + + def calculate_psnr(self, vid1, vid2): + """Calculate average PSNR between two videos + + Args: + vid1: First video (list/array of frames) + vid2: Second video (list/array of frames) + + Returns: + Mean PSNR value across all frames + """ + psnrs = [] + for f1, f2 in zip(vid1, vid2): + f1 = self._preprocess_frame(f1) + f2 = self._preprocess_frame(f2) + # Calculate PSNR for this frame pair + psnrs.append(psnr(f1, f2, data_range=1.0)) + return np.mean(psnrs) + + def calculate_ssim(self, vid1, vid2): + """Calculate average SSIM between two videos + + Args: + vid1: First video (list/array of frames) + vid2: Second video (list/array of frames) + + Returns: + Mean SSIM value across all frames + """ + ssims = [] + for f1, f2 in zip(vid1, vid2): + f1 = self._preprocess_frame(f1) + f2 = self._preprocess_frame(f2) + # Calculate SSIM for this frame pair (multichannel for color images) + ssims.append(ssim(f1, f2, channel_axis=2, data_range=1.0)) + return np.mean(ssims) + + def calculate_lpips(self, vid1, vid2): + """Calculate average LPIPS (perceptual similarity) between two videos + + Args: + vid1: First video (list/array of frames) + vid2: Second video (list/array of frames) + + Returns: + Mean LPIPS value across all frames (lower is better) + """ + lpips_values = [] + for f1, f2 in zip(vid1, vid2): + # Convert to torch tensor if needed + if not isinstance(f1, torch.Tensor): + f1 = torch.from_numpy(f1).permute(2, 0, 1).unsqueeze(0).float() # HWC -> 1CHW + f2 = torch.from_numpy(f2).permute(2, 0, 1).unsqueeze(0).float() + + # Normalize to [-1,1] if needed + if f1.max() > 1: + f1 = f1 / 127.5 - 1.0 + f2 = f2 / 127.5 - 1.0 + + f1 = f1.to(self.device) + 
f2 = f2.to(self.device) + + # Calculate LPIPS with no gradients + with torch.no_grad(): + lpips_values.append(self.lpips_model(f1, f2).item()) + return np.mean(lpips_values) + + def evaluate_videos(self, generated_video, reference_video, metrics=['psnr','lpips','ssim']): + """Comprehensive video quality evaluation between generated and reference videos + + Args: + generated_video: Model-generated video [T,H,W,C] or [T,C,H,W] + reference_video: Ground truth reference video [T,H,W,C] or [T,C,H,W] + metrics: List of metrics to compute ('psnr', 'ssim', 'lpips') + + Returns: + Dictionary containing computed metric values + """ + results = {} + + # Verify video lengths match + assert len(generated_video) == len(reference_video), "Videos must have same number of frames" + + # Calculate requested metrics + if 'psnr' in metrics: + results['psnr'] = self.calculate_psnr(generated_video, reference_video) + + if 'ssim' in metrics: + results['ssim'] = self.calculate_ssim(generated_video, reference_video) + + if 'lpips' in metrics: + results['lpips'] = self.calculate_lpips(generated_video, reference_video) + + return results + diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..43b849902245dd338a36f4f4ff09e33425365af6 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,131 @@ +import os +import re +import random +import time +import torch +import torch.nn as nn +import logging +import numpy as np +from os import path as osp + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +initialized_logger = {} +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. 
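A short sketch of how `get_root_logger` above is typically used; the logger name matches the default and the log file name is illustrative:

```python
import logging

logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file='train.log')
logger.info('starting run')  # goes to the console and, because log_file is set, to train.log

# Later calls with the same name return the already-configured logger instance.
assert get_root_logger('basicsr') is logger
```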
+ """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + + if log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) \ No newline at end of file