diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..7442d2d39efdf3870535ae1e130b6743be9c33ec 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -31,5 +31,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..cce2271ec487a5debb14ea9051db762035caa45a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+__pycache__/
+.vscode/
+docs/
+debug_images/
+images/
+result/
+vots/
+vots.py
diff --git a/README.md b/README.md
index 7196705a2587a95a5500c163d4b8bdfce39cc1fe..ac470bb9a3137ce77162340c20255116d87a0ae9 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,74 @@
---
-title: ROSE
-emoji: 👁
-colorFrom: green
-colorTo: indigo
+title: ROSE Awesome Space
+emoji: 🚀
+colorFrom: blue
+colorTo: pink
sdk: gradio
-sdk_version: 5.34.2
+sdk_version: 4.15.0
app_file: app.py
pinned: false
-license: apache-2.0
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## Get Started
+
+1. Install ProPainter Dependencies
+
+ You can follow the [Dependencies and Installation](https://github.com/Luo-Yihang/ProPainter-pr/tree/dev_yihang#dependencies-and-installation).
+
+2. Install Demo Dependencies
+```shell
+cd web-demos/hugging_face
+
+# install python dependencies
+pip3 install -r requirements.txt
+
+# Run the demo
+python app.py
+```
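+
+Optionally, `app.py` accepts a few command-line flags (defined in `parse_augment` in `app.py`), for example to pick the SAM backbone or the compute device:
+
+```shell
+# optional flags; the defaults are vit_h and an auto-detected device
+python app.py --sam_model_type vit_h --device cuda:0
+```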
+
+## Usage Guidance
+* Step 1: Upload your video and click the `Get video info` button.
+ 
+
+* Step 2:
+ 1. *[Optional]* Specify the tracking period for the currently added mask by dragging the `Track start frame` or `Track end frame`.
+ 2. Click the image on the left to select the mask area.
+ 3. - Click `Add mask` if you are satisfied with the mask, or
+ - *[Optional]* Click `Clear clicks` if you want to reselect the mask area, or
+ - *[Optional]* Click `Remove mask` to remove all masks.
+ 4. *[Optional]* Go back to step 2.1 to add another mask.
+ 
+
+* Step 3:
+  1. Click the `Tracking` button to track the masks for the whole video.
+  2. *[Optional]* Select the ProPainter parameters in the `ProPainter Parameters` dropdown.
+  3. Then click `Inpainting` to get the inpainting results.
+ 
+
+*You can always refer to the `Highlighted Text` box on the page for guidance on the next step!*
+
+
+## Citation
+If you find our repo useful for your research, please consider citing our paper:
+```bibtex
+@inproceedings{zhou2023propainter,
+ title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
+ author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change},
+ booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
+ year={2023}
+}
+```
+
+
+## License
+
+This project is licensed under the NTU S-Lab License 1.0. Redistribution and use should follow this license.
+
+
+## Acknowledgements
+
+The project harnesses the capabilities of [Track Anything](https://github.com/gaomingqi/Track-Anything), [Segment Anything](https://github.com/facebookresearch/segment-anything) and [Cutie](https://github.com/hkchengrex/Cutie). Thanks for their awesome work.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..51252fe0a70ea77a52364eec919de42ae7a95576
--- /dev/null
+++ b/app.py
@@ -0,0 +1,679 @@
+import sys
+sys.path.append("./")
+
+import os
+import json
+import time
+import psutil
+import argparse
+
+import cv2
+import torch
+import torchvision
+import numpy as np
+import gradio as gr
+
+from tools.painter import mask_painter
+from track_anything import TrackingAnything
+
+from utils.misc import get_device
+from utils.download_util import load_file_from_url
+from transformers import AutoTokenizer, AutoModel
+from omegaconf import OmegaConf
+from torchvision.transforms import functional as TF
+from torchvision.utils import save_image
+from einops import rearrange
+from PIL import Image
+
+from rose.models import AutoencoderKLWan, CLIPModel, WanT5EncoderModel, WanTransformer3DModel
+from rose.pipeline import WanFunInpaintPipeline
+from diffusers import FlowMatchEulerDiscreteScheduler
+
+def filter_kwargs(cls, kwargs):
+ import inspect
+ sig = inspect.signature(cls.__init__)
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
+ return {k: v for k, v in kwargs.items() if k in valid_params}
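+# e.g. the scheduler_kwargs in the YAML config include `scheduler_subpath`, which is not an
+# __init__ argument of FlowMatchEulerDiscreteScheduler, so filter_kwargs drops it when the
+# noise scheduler is built below.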
+
+# pretrained_model_path = "./models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
+pretrained_model_path = "alibaba-pai/Wan2.1-Fun-1.3B-InP"
+transformer_path = "Kunbyte/ROSE"
+# config_path = "configs/wan2.1/wan_civitai.yaml"
+config_path = "./configs/wan2.1/wan_civitai.yaml"
+config = OmegaConf.load(config_path)
+
+tokenizer_subpath = config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')
+tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_model_path}/{tokenizer_subpath}")
+
+text_encoder_subpath = config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')
+text_encoder = WanT5EncoderModel.from_pretrained(
+ f"{pretrained_model_path}/{text_encoder_subpath}",
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
+ low_cpu_mem_usage=True,
+)
+
+transformer_subpath = config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')
+transformer3d = WanTransformer3DModel.from_pretrained(
+ f"{transformer_path}/{transformer_subpath}",
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+)
+
+image_encoder_subpath = config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')
+clip_image_encoder = CLIPModel.from_pretrained(f"{pretrained_model_path}/{image_encoder_subpath}")
+
+vae_subpath = config['vae_kwargs'].get('vae_subpath', 'vae')
+vae = AutoencoderKLWan.from_pretrained(
+ f"{pretrained_model_path}/{vae_subpath}",
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
+)
+
+noise_scheduler = FlowMatchEulerDiscreteScheduler(
+ **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
+)
+
+# tokenizer = AutoTokenizer.from_pretrained(
+# os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
+# )
+# text_encoder = WanT5EncoderModel.from_pretrained(
+# os.path.join(pretrained_model_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
+# additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
+# low_cpu_mem_usage=True,
+# )
+# clip_image_encoder = CLIPModel.from_pretrained(
+# os.path.join(pretrained_model_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
+# )
+# vae = AutoencoderKLWan.from_pretrained(
+# os.path.join(pretrained_model_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
+# additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
+# )
+# transformer3d = WanTransformer3DModel.from_pretrained(
+# os.path.join(transformer_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
+# transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+# )
+# noise_scheduler = FlowMatchEulerDiscreteScheduler(
+# **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
+# )
+
+pipeline = WanFunInpaintPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer3d,
+ scheduler=noise_scheduler,
+ clip_image_encoder=clip_image_encoder
+).to("cuda", torch.float16)
+
+
+def parse_augment():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--device', type=str, default=None)
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
+ parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
+ parser.add_argument('--mask_save', default=False)
+ args = parser.parse_args()
+
+ if not args.device:
+ args.device = str(get_device())
+
+ return args
+
+# convert points input to prompt state
+def get_prompt(click_state, click_input):
+ inputs = json.loads(click_input)
+ points = click_state[0]
+ labels = click_state[1]
+ for input in inputs:
+ points.append(input[:2])
+ labels.append(input[2])
+ click_state[0] = points
+ click_state[1] = labels
+ prompt = {
+ "prompt_type":["click"],
+ "input_point":click_state[0],
+ "input_label":click_state[1],
+ "multimask_output":"True",
+ }
+ return prompt
+
+# extract frames from upload video
+def get_frames_from_video(video_input, video_state):
+    """
+    Extract frames from the uploaded video and initialize the video state.
+    Args:
+        video_input: str, path of the uploaded video
+        video_state: dict, shared video state to initialize
+    Returns:
+        the new video_state, a video-info string, the first frame, and gr.update objects for the UI components
+    """
+ video_path = video_input
+ frames = []
+ user_name = time.time()
+ operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)]
+ try:
+ cap = cv2.VideoCapture(video_path)
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret == True:
+ current_memory_usage = psutil.virtual_memory().percent
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ if current_memory_usage > 90:
+ operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
+ print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
+ break
+ else:
+ break
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
+ image_size = (frames[0].shape[0],frames[0].shape[1])
+ # initialize video_state
+ video_state = {
+ "user_name": user_name,
+ "video_name": os.path.split(video_path)[-1],
+ "origin_images": frames,
+ "painted_images": frames.copy(),
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
+ "logits": [None]*len(frames),
+ "select_frame_number": 0,
+ "fps": fps
+ }
+ video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
+ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
+ gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True),\
+ gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=True), gr.update(visible=True, choices=[], value=[]), \
+ gr.update(visible=True, value=operation_log), gr.update(visible=True, value=operation_log)
+
+# get the selected frame from the gradio slider
+def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown=None):
+
+ # images = video_state[1]
+ image_selection_slider -= 1
+ video_state["select_frame_number"] = image_selection_slider
+
+ # once select a new template frame, set the image in sam
+
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
+
+ operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")]
+
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log
+
+# set the tracking end frame
+def get_end_number(track_pause_number_slider, video_state, interactive_state):
+ interactive_state["track_end_number"] = track_pause_number_slider
+    operation_log = [("",""),("Select tracking finish frame {}. Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]
+
+ return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
+
+# use sam to get the mask
+def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
+    """
+    Args:
+        video_state: dict, current video state
+        point_prompt: "Positive" or "Negative", set by the point prompt radio button
+        click_state: [[points], [labels]]
+        interactive_state: dict, click statistics
+        evt: gr.SelectData carrying the clicked pixel coordinates
+    """
+ if point_prompt == "Positive":
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+ interactive_state["positive_click_times"] += 1
+ else:
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
+ interactive_state["negative_click_times"] += 1
+
+ # prompt for sam model
+ model.samcontroler.sam_controler.reset_image()
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
+
+ mask, logit, painted_image = model.first_frame_click(
+ image=video_state["origin_images"][video_state["select_frame_number"]],
+ points=np.array(prompt["input_point"]),
+ labels=np.array(prompt["input_label"]),
+ multimask=prompt["multimask_output"],
+ )
+
+ video_state["masks"][video_state["select_frame_number"]] = mask
+ video_state["logits"][video_state["select_frame_number"]] = logit
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
+
+ operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None),
+ ("[Optional]", "Remove mask"), (": remove all added masks.\n", None),
+ ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None),
+ ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
+ return painted_image, video_state, interactive_state, operation_log, operation_log
+
+def add_multi_mask(video_state, interactive_state, mask_dropdown):
+    select_frame = None  # fallback so the return below is safe if mask generation failed
+    try:
+        mask = video_state["masks"][video_state["select_frame_number"]]
+        interactive_state["multi_mask"]["masks"].append(mask)
+        interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+        mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+        select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
+        operation_log = [("",""),("Added a mask. Use the mask selection for target tracking or inpainting.","Normal")]
+    except Exception:
+        operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
+    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
+
+def clear_click(video_state, click_state):
+ click_state = [[],[]]
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
+    operation_log = [("",""), ("Cleared the click history and refreshed the image.","Normal")]
+ return template_frame, click_state, operation_log, operation_log
+
+def remove_multi_mask(interactive_state, mask_dropdown):
+ interactive_state["multi_mask"]["mask_names"]= []
+ interactive_state["multi_mask"]["masks"] = []
+
+    operation_log = [("",""), ("Removed all masks. Try to add new masks.","Normal")]
+ return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log
+
+def show_mask(video_state, interactive_state, mask_dropdown):
+ mask_dropdown.sort()
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
+ for i in range(len(mask_dropdown)):
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
+
+ operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
+ return select_frame, operation_log, operation_log
+
+# tracking vos
+def vos_tracking_video(video_state, interactive_state, mask_dropdown):
+ operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
+ model.cutie.clear_memory()
+ if interactive_state["track_end_number"]:
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
+ else:
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
+
+ if interactive_state["multi_mask"]["masks"]:
+ if len(mask_dropdown) == 0:
+ mask_dropdown = ["mask_001"]
+ mask_dropdown.sort()
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
+ for i in range(1,len(mask_dropdown)):
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
+ else:
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
+
+ fps = float(video_state["fps"])
+ # operation error
+ if len(np.unique(template_mask))==1:
+ template_mask[0][0]=1
+ operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
+ # return video_output, video_state, interactive_state, operation_error
+ masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
+ # clear GPU memory
+ model.cutie.clear_memory()
+
+ if interactive_state["track_end_number"]:
+ video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
+ video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
+ video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
+ else:
+ video_state["masks"][video_state["select_frame_number"]:] = masks
+ video_state["logits"][video_state["select_frame_number"]:] = logits
+ video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
+
+    video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # name the output video after the input video
+ interactive_state["inference_times"] += 1
+
+ print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
+ interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
+ interactive_state["positive_click_times"],
+ interactive_state["negative_click_times"]))
+
+ #### shanggao code for mask save
+ if interactive_state["mask_save"]:
+ if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
+ os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
+ i = 0
+ print("save mask")
+ for mask in video_state["masks"]:
+ np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
+ i+=1
+ # save_mask(video_state["masks"], video_state["video_name"])
+ #### shanggao code for mask save
+ return video_output, video_state, interactive_state, operation_log, operation_log
+
+def inpaint_video(video_state, *_):
+ operation_log = [("", ""), ("Inpainting finished!", "Normal")]
+
+ # import pdb;pdb.set_trace()
+ frames = video_state["origin_images"]
+ masks = video_state["masks"]
+ # masks = masks * 255
+ fps = int(video_state["fps"])
+
+ total_frames = len(frames)
+ target_frame_count = (total_frames - 1) // 16 * 16 + 1
+ frames = frames[:target_frame_count]
+ masks = masks[:target_frame_count]
+
+ frames_resized = [cv2.resize(frame, (720, 480), interpolation=cv2.INTER_CUBIC) for frame in frames]
+    masks_resized = [cv2.resize(mask, (720, 480), interpolation=cv2.INTER_NEAREST) for mask in masks]  # nearest-neighbor keeps the mask labels discrete
+
+ with torch.no_grad():
+ video_tensor = torch.stack([TF.to_tensor(Image.fromarray(f)) for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)
+ mask_tensor = torch.stack([TF.to_tensor(Image.fromarray(m*255)) for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)
+ #video_tensor = torch.stack([torch.from_numpy(f).float() for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16)
+ #mask_tensor = torch.stack([torch.from_numpy(m).float() for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.bfloat16)
+
+ output = pipeline(
+ prompt="",
+ video=video_tensor,
+ mask_video=mask_tensor,
+ num_frames=video_tensor.shape[2],
+ num_inference_steps=50
+ ).videos
+
+ output = output.clamp(0, 1).cpu()
+ output_np = (output[0].permute(1, 2, 3, 0).numpy() * 255).astype(np.uint8)
+
+ output_path = f"./result/inpaint/{video_state['video_name']}"
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+ torchvision.io.write_video(output_path, torch.from_numpy(output_np), fps=fps, video_codec="libx264")
+
+ return output_path, operation_log, operation_log
+
+
+# generate video after vos inference
+def generate_video_from_frames(frames, output_path, fps=30):
+ """
+ Generates a video from a list of frames.
+
+ Args:
+ frames (list of numpy arrays): The frames to include in the video.
+ output_path (str): The path to save the generated video.
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
+ """
+ frames = torch.from_numpy(np.asarray(frames))
+ if not os.path.exists(os.path.dirname(output_path)):
+ os.makedirs(os.path.dirname(output_path))
+ fps = int(fps)
+ # import pdb;pdb.set_trace()
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
+ return output_path
+
+def restart():
+ operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]
+ return {
+ "user_name": "",
+ "video_name": "",
+ "origin_images": None,
+ "painted_images": None,
+ "masks": None,
+ "inpaint_masks": None,
+ "logits": None,
+ "select_frame_number": 0,
+ "fps": 30
+ }, {
+ "inference_times": 0,
+ "negative_click_times" : 0,
+ "positive_click_times": 0,
+ "mask_save": args.mask_save,
+ "multi_mask": {
+ "mask_names": [],
+ "masks": []
+ },
+ "track_end_number": None,
+ }, [[],[]], None, None, None, \
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \
+ gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log)
+
+
+# parse command-line arguments
+args = parse_augment()
+pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
+sam_checkpoint_url_dict = {
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+}
+checkpoint_folder = os.path.join('.', 'weights')
+
+sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
+cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_folder)
+# propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_folder)
+# raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_folder)
+# flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_folder)
+
+# initialize the SAM and Cutie models
+model = TrackingAnything(sam_checkpoint, cutie_checkpoint, args)
+
+
+title = r"""
+ROSE: Remove Objects with Side Effects in Videos
+"""
+
+description = r"""
+
+Official Gradio demo for Remove Objects with Side Effects in Videos.
+🔥 ROSE is a robust inpainting algorithm.
+🤗 Try to drop your video, add the masks and get the inpainting results!
+"""
+
+css = """
+.gradio-container {width: 85% !important; margin: 0 auto !important;}
+.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important}
+button {border-radius: 8px !important;}
+.add_button {background-color: #4CAF50 !important;}
+.remove_button {background-color: #f44336 !important;}
+.mask_button_group {gap: 10px !important;}
+.video {height: 300px !important;}
+.image {height: 300px !important;}
+.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
+.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
+.margin_center {width: 50% !important; margin: auto !important;}
+.jc_center {justify-content: center !important;}
+body {
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ min-height: 100vh;
+ margin: 0;
+}
+"""
+
+with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
+ click_state = gr.State([[],[]])
+
+ interactive_state = gr.State({
+ "inference_times": 0,
+ "negative_click_times" : 0,
+ "positive_click_times": 0,
+ "mask_save": args.mask_save,
+ "multi_mask": {
+ "mask_names": [],
+ "masks": []
+ },
+ "track_end_number": None,
+ }
+ )
+
+ video_state = gr.State(
+ {
+ "user_name": "",
+ "video_name": "",
+ "origin_images": None,
+ "painted_images": None,
+ "masks": None,
+ "inpaint_masks": None,
+ "logits": None,
+ "select_frame_number": 0,
+ "fps": 30
+ }
+ )
+
+ gr.Markdown(title)
+ gr.Markdown(description)
+
+ with gr.Column():
+ # input video
+ gr.Markdown("## Step1: Upload video")
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=2):
+ video_input = gr.Video(elem_classes="video")
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
+ with gr.Column(scale=2):
+ run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")],
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
+ video_info = gr.Textbox(label="Video Info")
+
+
+ # add masks
+ step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=2):
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
+ with gr.Column(scale=2, elem_classes="jc_center"):
+ run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")],
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
+ with gr.Row():
+ with gr.Column(scale=2, elem_classes="mask_button_group"):
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False)
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
+ point_prompt = gr.Radio(
+ choices=["Positive", "Negative"],
+ value="Positive",
+ label="Point prompt",
+ interactive=True,
+ visible=False,
+ min_width=100,
+ scale=1)
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
+
+ # output video
+ step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=2):
+ tracking_video_output = gr.Video(visible=False, elem_classes="video")
+ tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center")
+ with gr.Column(scale=2):
+ inpaiting_video_output = gr.Video(visible=False, elem_classes="video")
+ inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center")
+
+ # first step: get the video information
+ extract_frames_button.click(
+ fn=get_frames_from_video,
+ inputs=[
+ video_input, video_state
+ ],
+ outputs=[video_state, video_info, template_frame,
+ image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
+ tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2]
+ )
+
+ # second step: select images from slider
+ image_selection_slider.release(fn=select_template,
+ inputs=[image_selection_slider, video_state, interactive_state],
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image")
+ track_pause_number_slider.release(fn=get_end_number,
+ inputs=[track_pause_number_slider, video_state, interactive_state],
+ outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image")
+
+ # click select image to get mask using sam
+ template_frame.select(
+ fn=sam_refine,
+ inputs=[video_state, point_prompt, click_state, interactive_state],
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
+ )
+
+ # add different mask
+ Add_mask_button.click(
+ fn=add_multi_mask,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2]
+ )
+
+ remove_mask_button.click(
+ fn=remove_multi_mask,
+ inputs=[interactive_state, mask_dropdown],
+ outputs=[interactive_state, mask_dropdown, run_status, run_status2]
+ )
+
+ # tracking video from select image and mask
+ tracking_video_predict_button.click(
+ fn=vos_tracking_video,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2]
+ )
+
+ # inpaint video from select image and mask
+ inpaint_video_predict_button.click(
+ fn=inpaint_video,
+ #inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown],
+ inputs=[video_state, mask_dropdown],
+ outputs=[inpaiting_video_output, run_status, run_status2]
+ )
+
+ # click to get mask
+ mask_dropdown.change(
+ fn=show_mask,
+ inputs=[video_state, interactive_state, mask_dropdown],
+ outputs=[template_frame, run_status, run_status2]
+ )
+
+ # clear input
+ video_input.change(
+ fn=restart,
+ inputs=[],
+ outputs=[
+ video_state,
+ interactive_state,
+ click_state,
+ tracking_video_output, inpaiting_video_output,
+ template_frame,
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
+ ],
+ queue=False,
+ show_progress=False)
+
+ video_input.clear(
+ fn=restart,
+ inputs=[],
+ outputs=[
+ video_state,
+ interactive_state,
+ click_state,
+ tracking_video_output, inpaiting_video_output,
+ template_frame,
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
+ ],
+ queue=False,
+ show_progress=False)
+
+ # points clear
+ clear_button_click.click(
+ fn = clear_click,
+ inputs = [video_state, click_state,],
+ outputs = [template_frame,click_state, run_status, run_status2],
+ )
+
+ # set example
+ gr.Markdown("## Examples")
+ gr.Examples(
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]],
+ inputs=[video_input],
+ )
+ # gr.Markdown(article)
+
+# iface.queue(concurrency_count=1)
+iface.queue()
+iface.launch(debug=True)
\ No newline at end of file
diff --git a/assets/rose_logo.png b/assets/rose_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..8385233fcf98549a0e38e2b321bd079833237ce5
--- /dev/null
+++ b/assets/rose_logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf454f99eaabcece90cb664c39a45d17b58f8201eba8b220fa155ac22a014c4a
+size 71765
diff --git a/assets/step1.png b/assets/step1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5b389402c6853ec138e101719ec06096b5bb7a1
--- /dev/null
+++ b/assets/step1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c93010fa938c75ae671e3aa362205f3f2692783930f67a6623e0a438479e7326
+size 308524
diff --git a/assets/step2.png b/assets/step2.png
new file mode 100644
index 0000000000000000000000000000000000000000..04af9bcbc28c46f26bcae95426b0ca32e14bed3e
--- /dev/null
+++ b/assets/step2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48bdd827f2581da65df1e731325163bd7e5095511ed6e753346f74b71156dc8d
+size 457568
diff --git a/assets/step3.png b/assets/step3.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5f980a93e575887ee438c285ca41deae7e38f84
--- /dev/null
+++ b/assets/step3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8273104c2e558cb0d1edfe91f5e4ca27483815f5e31aad153f199a351c87b12
+size 1143298
diff --git a/configs/wan2.1/wan_civitai.yaml b/configs/wan2.1/wan_civitai.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3a2747bdc9e86a6812c3ca09d2a8653aaea6cb54
--- /dev/null
+++ b/configs/wan2.1/wan_civitai.yaml
@@ -0,0 +1,39 @@
+format: civitai
+pipeline: Wan
+transformer_additional_kwargs:
+ transformer_subpath: ./
+ dict_mapping:
+ in_dim: in_channels
+ dim: hidden_size
+
+vae_kwargs:
+ vae_subpath: Wan2.1_VAE.pth
+ temporal_compression_ratio: 4
+ spatial_compression_ratio: 8
+
+text_encoder_kwargs:
+ text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+ tokenizer_subpath: google/umt5-xxl
+ text_length: 512
+ vocab: 256384
+ dim: 4096
+ dim_attn: 4096
+ dim_ffn: 10240
+ num_heads: 64
+ num_layers: 24
+ num_buckets: 32
+ shared_pos: False
+ dropout: 0.0
+
+scheduler_kwargs:
+ scheduler_subpath: null
+ num_train_timesteps: 1000
+ shift: 5.0
+ use_dynamic_shifting: false
+ base_shift: 0.5
+ max_shift: 1.15
+ base_image_seq_len: 256
+ max_image_seq_len: 4096
+
+image_encoder_kwargs:
+ image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
\ No newline at end of file
diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py
new file mode 100644
index 0000000000000000000000000000000000000000..795663dcdf22fbdc8c3e010918884f40e3d2d40f
--- /dev/null
+++ b/inpainter/base_inpainter.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+import os
+import sys
+import cv2
+import numpy as np
+import scipy.ndimage
+from PIL import Image
+from tqdm import tqdm
+
+import torch
+import torchvision
+
+from model.modules.flow_comp_raft import RAFT_bi
+from model.recurrent_flow_completion import RecurrentFlowCompleteNet
+from model.propainter import InpaintGenerator
+from core.utils import to_tensors
+
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def resize_frames(frames, size=None):
+ if size is not None:
+ out_size = size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ frames = [f.resize(process_size) for f in frames]
+ else:
+ out_size = frames[0].size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ if not out_size == process_size:
+ frames = [f.resize(process_size) for f in frames]
+
+ return frames, process_size, out_size
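+# e.g. resize_frames(frames, size=(721, 481)) resizes every frame to (720, 480), the nearest
+# size divisible by 8, and returns that process size together with the requested out_size (721, 481).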
+
+
+def read_frame_from_videos(frame_root):
+ if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+ video_name = os.path.basename(frame_root)[:-4]
+ vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB
+ frames = list(vframes.numpy())
+ frames = [Image.fromarray(f) for f in frames]
+ fps = info['video_fps']
+ else:
+ video_name = os.path.basename(frame_root)
+ frames = []
+ fr_lst = sorted(os.listdir(frame_root))
+ for fr in fr_lst:
+ frame = cv2.imread(os.path.join(frame_root, fr))
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ frames.append(frame)
+ fps = None
+ size = frames[0].size
+
+ return frames, fps, size, video_name
+
+
+def binary_mask(mask, th=0.1):
+ mask[mask>th] = 1
+ mask[mask<=th] = 0
+ return mask
+
+
+def extrapolation(video_ori, scale):
+ """Prepares the data for video outpainting.
+ """
+ nFrame = len(video_ori)
+ imgW, imgH = video_ori[0].size
+
+ # Defines new FOV.
+ imgH_extr = int(scale[0] * imgH)
+ imgW_extr = int(scale[1] * imgW)
+ imgH_extr = imgH_extr - imgH_extr % 8
+ imgW_extr = imgW_extr - imgW_extr % 8
+ H_start = int((imgH_extr - imgH) / 2)
+ W_start = int((imgW_extr - imgW) / 2)
+
+ # Extrapolates the FOV for video.
+ frames = []
+ for v in video_ori:
+ frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8)
+ frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
+ frames.append(Image.fromarray(frame))
+
+ # Generates the mask for missing region.
+ masks_dilated = []
+ flow_masks = []
+
+ dilate_h = 4 if H_start > 10 else 0
+ dilate_w = 4 if W_start > 10 else 0
+ mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8)
+
+ mask[H_start+dilate_h: H_start+imgH-dilate_h,
+ W_start+dilate_w: W_start+imgW-dilate_w] = 0
+ flow_masks.append(Image.fromarray(mask * 255))
+
+ mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0
+ masks_dilated.append(Image.fromarray(mask * 255))
+
+ flow_masks = flow_masks * nFrame
+ masks_dilated = masks_dilated * nFrame
+
+ return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)
+
+
+def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
+ ref_index = []
+ if ref_num == -1:
+ for i in range(0, length, ref_stride):
+ if i not in neighbor_ids:
+ ref_index.append(i)
+ else:
+ start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
+ end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
+ for i in range(start_idx, end_idx, ref_stride):
+ if i not in neighbor_ids:
+ if len(ref_index) > ref_num:
+ break
+ ref_index.append(i)
+ return ref_index
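+# e.g. with length=50, ref_stride=10, ref_num=-1 and neighbor_ids=[0, 1, ..., 10],
+# the returned non-local reference indices are [20, 30, 40].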
+
+
+def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5):
+ masks_img = []
+ masks_dilated = []
+ flow_masks = []
+
+ for mp in masks:
+ masks_img.append(Image.fromarray(mp.astype('uint8')))
+
+ for mask_img in masks_img:
+ if size is not None:
+ mask_img = mask_img.resize(size, Image.NEAREST)
+ mask_img = np.array(mask_img.convert('L'))
+
+        # Dilate by 8 pixels so that all known pixels are trustworthy
+ if flow_mask_dilates > 0:
+ flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
+ else:
+ flow_mask_img = binary_mask(mask_img).astype(np.uint8)
+
+ flow_masks.append(Image.fromarray(flow_mask_img * 255))
+
+ if mask_dilates > 0:
+ mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
+ else:
+ mask_img = binary_mask(mask_img).astype(np.uint8)
+ masks_dilated.append(Image.fromarray(mask_img * 255))
+
+ if len(masks_img) == 1:
+ flow_masks = flow_masks * length
+ masks_dilated = masks_dilated * length
+
+ return flow_masks, masks_dilated
+
+
+class ProInpainter:
+ def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True):
+ self.device = device
+ self.use_half = use_half
+        if str(self.device) == 'cpu':  # robust whether device is a str or a torch.device
+ self.use_half = False
+
+ ##############################################
+        # set up RAFT and flow completion model
+ ##############################################
+ self.fix_raft = RAFT_bi(raft_checkpoint, self.device)
+
+ self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint)
+ for p in self.fix_flow_complete.parameters():
+ p.requires_grad = False
+ self.fix_flow_complete.to(self.device)
+ self.fix_flow_complete.eval()
+
+ ##############################################
+ # set up ProPainter model
+ ##############################################
+ self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device)
+ self.model.eval()
+
+ if self.use_half:
+ self.fix_flow_complete = self.fix_flow_complete.half()
+ self.model = self.model.half()
+
+ def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10):
+ """
+ Perform Inpainting for video subsets
+
+ Output:
+            inpainted_frames: list of T frames, each an H x W x 3 uint8 numpy array
+ """
+
+ frames = []
+ for i in range(len(npframes)):
+ frames.append(Image.fromarray(npframes[i].astype('uint8'), mode="RGB"))
+ del npframes
+
+ size = frames[0].size
+        # The output size should be divisible by 2 so that it can be encoded by libx264
+ size = (int(ratio*size[0])//2*2, int(ratio*size[1])//2*2)
+
+ frames_len = len(frames)
+ frames, size, out_size = resize_frames(frames, size)
+ flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius)
+ w, h = size
+
+ frames_inp = [np.array(f).astype(np.uint8) for f in frames]
+ frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
+ flow_masks = to_tensors()(flow_masks).unsqueeze(0)
+ masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
+ frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device)
+
+ ##############################################
+ # ProPainter inference
+ ##############################################
+ video_length = frames.size(1)
+ with torch.no_grad():
+ # ---- compute flow ----
+ if frames.size(-1) <= 640:
+ short_clip_len = 12
+ elif frames.size(-1) <= 720:
+ short_clip_len = 8
+ elif frames.size(-1) <= 1280:
+ short_clip_len = 4
+ else:
+ short_clip_len = 2
+
+ # use fp32 for RAFT
+ if frames.size(1) > short_clip_len:
+ gt_flows_f_list, gt_flows_b_list = [], []
+ for f in range(0, video_length, short_clip_len):
+ end_f = min(video_length, f + short_clip_len)
+ if f == 0:
+ flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter)
+ else:
+ flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter)
+
+ gt_flows_f_list.append(flows_f)
+ gt_flows_b_list.append(flows_b)
+ torch.cuda.empty_cache()
+
+ gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
+ gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
+ gt_flows_bi = (gt_flows_f, gt_flows_b)
+ else:
+ gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
+ torch.cuda.empty_cache()
+
+ if self.use_half:
+ frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
+ gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
+
+ # ---- complete flow ----
+ flow_length = gt_flows_bi[0].size(1)
+ if flow_length > subvideo_length:
+ pred_flows_f, pred_flows_b = [], []
+ pad_len = 5
+ for f in range(0, flow_length, subvideo_length):
+ s_f = max(0, f - pad_len)
+ e_f = min(flow_length, f + subvideo_length + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(flow_length, f + subvideo_length)
+ pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
+ flow_masks[:, s_f:e_f+1])
+ pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
+ pred_flows_bi_sub,
+ flow_masks[:, s_f:e_f+1])
+
+ pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
+ pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ pred_flows_f = torch.cat(pred_flows_f, dim=1)
+ pred_flows_b = torch.cat(pred_flows_b, dim=1)
+ pred_flows_bi = (pred_flows_f, pred_flows_b)
+ else:
+ pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
+ pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
+ torch.cuda.empty_cache()
+
+ # ---- image propagation ----
+ masked_frames = frames * (1 - masks_dilated)
+            subvideo_length_img_prop = min(100, subvideo_length) # cap the sub-video length for image propagation at 100 frames
+ if video_length > subvideo_length_img_prop:
+ updated_frames, updated_masks = [], []
+ pad_len = 10
+ for f in range(0, video_length, subvideo_length_img_prop):
+ s_f = max(0, f - pad_len)
+ e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
+
+ b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
+ pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
+ prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
+ pred_flows_bi_sub,
+ masks_dilated[:, s_f:e_f],
+ 'nearest')
+ updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
+ prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
+ updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
+
+ updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ updated_frames = torch.cat(updated_frames, dim=1)
+ updated_masks = torch.cat(updated_masks, dim=1)
+ else:
+ b, t, _, _, _ = masks_dilated.size()
+ prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
+ updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
+ updated_masks = updated_local_masks.view(b, t, 1, h, w)
+ torch.cuda.empty_cache()
+
+ ori_frames = frames_inp
+ comp_frames = [None] * video_length
+
+ neighbor_stride = neighbor_length // 2
+ if video_length > subvideo_length:
+ ref_num = subvideo_length // ref_stride
+ else:
+ ref_num = -1
+
+ # ---- feature propagation + transformer ----
+ for f in tqdm(range(0, video_length, neighbor_stride)):
+ neighbor_ids = [
+ i for i in range(max(0, f - neighbor_stride),
+ min(video_length, f + neighbor_stride + 1))
+ ]
+ ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
+ selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
+ selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
+ selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
+ selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
+
+ with torch.no_grad():
+ # 1.0 indicates mask
+ l_t = len(neighbor_ids)
+
+ # pred_img = selected_imgs # results of image propagation
+ pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
+
+ pred_img = pred_img.view(-1, 3, h, w)
+
+ pred_img = (pred_img + 1) / 2
+ pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+ binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
+ 0, 2, 3, 1).numpy().astype(np.uint8)
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ + ori_frames[idx] * (1 - binary_masks[i])
+ if comp_frames[idx] is None:
+ comp_frames[idx] = img
+ else:
+ comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
+
+ comp_frames[idx] = comp_frames[idx].astype(np.uint8)
+
+ torch.cuda.empty_cache()
+
+ # need to return numpy array, T, H, W, 3
+ comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
+
+ return comp_frames
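+
+# Usage sketch (paths are illustrative; they correspond to the checkpoints referenced in app.py):
+#   inpainter = ProInpainter('weights/ProPainter.pth', 'weights/raft-things.pth',
+#                            'weights/recurrent_flow_completion.pth', device='cuda:0')
+#   comp_frames = inpainter.inpaint(npframes, masks, ratio=1.0)  # lists of HxWx3 frames / HxW masks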
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5f9a8a4ea08061ca95a057897b8601de94ead7e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+progressbar2
+gdown
+gitpython
+git+https://github.com/cheind/py-thin-plate-spline
+hickle
+tensorboard
+numpy
+git+https://github.com/facebookresearch/segment-anything.git
+gradio
+opencv-python
+matplotlib
+pyyaml
+av
+openmim
+tqdm
+psutil
+omegaconf
\ No newline at end of file
diff --git a/rose/__init__.py b/rose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/rose/data/bucket_sampler.py b/rose/data/bucket_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..24b4160f3d2bbadca1d23e90c23e887ea6d15f70
--- /dev/null
+++ b/rose/data/bucket_sampler.py
@@ -0,0 +1,379 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
+ Sized, TypeVar, Union)
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import BatchSampler, Dataset, Sampler
+
+ASPECT_RATIO_512 = {
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
+}
+ASPECT_RATIO_RANDOM_CROP_512 = {
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
+}
+ASPECT_RATIO_RANDOM_CROP_PROB = [
+ 1, 2,
+ 4, 4, 4, 4,
+ 8, 8, 8,
+ 4, 4, 4, 4,
+ 2, 1
+]
+ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
+
+def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
+ aspect_ratio = height / width
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
+ return ratios[closest_ratio], float(closest_ratio)
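+# e.g. get_closest_ratio(480, 720) with the default ASPECT_RATIO_512 table returns
+# ([416.0, 608.0], 0.68), the bucket whose aspect ratio is closest to 480 / 720.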
+
+def get_image_size_without_loading(path):
+ with Image.open(path) as img:
+ return img.size # (width, height)
+
+class RandomSampler(Sampler[int]):
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+
+ If with replacement, then user can specify :attr:`num_samples` to draw.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
+ generator (Generator): Generator used in sampling.
+ """
+
+ data_source: Sized
+ replacement: bool
+
+ def __init__(self, data_source: Sized, replacement: bool = False,
+ num_samples: Optional[int] = None, generator=None) -> None:
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.generator = generator
+ self._pos_start = 0
+
+ if not isinstance(self.replacement, bool):
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
+
+ @property
+ def num_samples(self) -> int:
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self) -> Iterator[int]:
+ n = len(self.data_source)
+ if self.generator is None:
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ generator = torch.Generator()
+ generator.manual_seed(seed)
+ else:
+ generator = self.generator
+
+ if self.replacement:
+ for _ in range(self.num_samples // 32):
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+ else:
+ for _ in range(self.num_samples // n):
+ xx = torch.randperm(n, generator=generator).tolist()
+ if self._pos_start >= n:
+ self._pos_start = 0
+ print("xx top 10", xx[:10], self._pos_start)
+ for idx in range(self._pos_start, n):
+ yield xx[idx]
+ self._pos_start = (self._pos_start + 1) % n
+ self._pos_start = 0
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+class AspectRatioBatchImageSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+ def __init__(
+ self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ train_folder: str = None,
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False,
+ config=None,
+ **kwargs
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.train_folder = train_folder
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+ self.config = config
+ # buckets for each aspect ratio
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+ # [str(k) for k, v in aspect_ratios]
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+ def __iter__(self):
+ for idx in self.sampler:
+ try:
+ image_dict = self.dataset[idx]
+
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
+ if width is None or height is None:
+ image_id, name = image_dict['file_path'], image_dict['text']
+ if self.train_folder is None:
+ image_dir = image_id
+ else:
+ image_dir = os.path.join(self.train_folder, image_id)
+
+ width, height = get_image_size_without_loading(image_dir)
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+ print(e)
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self._aspect_ratio_buckets[closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+
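+# Illustrative wiring (a sketch only, never invoked by this module): the batch samplers
+# in this file are meant to be handed to a DataLoader through `batch_sampler`, so that
+# every yielded batch falls into a single aspect-ratio bucket. The names `train_dataset`
+# and its `.dataset` annotation list are assumptions of the sketch; the sampler only
+# requires that `dataset[idx]` returns a dict with "width"/"height" or a resolvable
+# "file_path".
+#
+#     from torch.utils.data import DataLoader, RandomSampler
+#
+#     batch_sampler = AspectRatioBatchImageSampler(
+#         sampler=RandomSampler(train_dataset),
+#         dataset=train_dataset.dataset,   # raw annotation dicts, not decoded samples
+#         batch_size=4,
+#         aspect_ratios=ASPECT_RATIO_512,
+#     )
+#     loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
+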
+class AspectRatioBatchSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+ def __init__(
+ self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ video_folder: str = None,
+ train_data_format: str = "webvid",
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False,
+ config=None,
+ **kwargs
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.video_folder = video_folder
+ self.train_data_format = train_data_format
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+ self.config = config
+ # buckets for each aspect ratio
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+ # [str(k) for k, v in aspect_ratios]
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+ def __iter__(self):
+ for idx in self.sampler:
+ try:
+ video_dict = self.dataset[idx]
+                width, height = video_dict.get("width", None), video_dict.get("height", None)
+
+ if width is None or height is None:
+ if self.train_data_format == "normal":
+ video_id, name = video_dict['file_path'], video_dict['text']
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+ else:
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+ cap = cv2.VideoCapture(video_dir)
+
+                    # Read the frame size from the container (the float properties are cast to int)
+                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                    cap.release()
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+                print(e, self.dataset[idx], "This item is invalid, please check it.")
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self._aspect_ratio_buckets[closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+
+class AspectRatioBatchImageVideoSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+
+ def __init__(self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ train_folder: str = None,
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.train_folder = train_folder
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+
+ # buckets for each aspect ratio
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+ self.bucket = {
+ 'image':{ratio: [] for ratio in aspect_ratios},
+ 'video':{ratio: [] for ratio in aspect_ratios}
+ }
+
+ def __iter__(self):
+ for idx in self.sampler:
+ content_type = self.dataset[idx].get('type', 'image')
+ if content_type == 'image':
+ try:
+ image_dict = self.dataset[idx]
+
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
+ if width is None or height is None:
+ image_id, name = image_dict['file_path'], image_dict['text']
+ if self.train_folder is None:
+ image_dir = image_id
+ else:
+ image_dir = os.path.join(self.train_folder, image_id)
+
+ width, height = get_image_size_without_loading(image_dir)
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+                    print(e, self.dataset[idx], "This item is invalid, please check it.")
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self.bucket['image'][closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+ else:
+ try:
+ video_dict = self.dataset[idx]
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
+
+ if width is None or height is None:
+ video_id, name = video_dict['file_path'], video_dict['text']
+ if self.train_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.train_folder, video_id)
+ cap = cv2.VideoCapture(video_dir)
+
+                        # Read the frame size from the container (the float properties are cast to int)
+                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        cap.release()
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+                    print(e, self.dataset[idx], "This item is invalid, please check it.")
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self.bucket['video'][closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
\ No newline at end of file
diff --git a/rose/data/dataset_image.py b/rose/data/dataset_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..098d49a4044f8daa351cd01b4cb1ec5415412e80
--- /dev/null
+++ b/rose/data/dataset_image.py
@@ -0,0 +1,76 @@
+import json
+import os
+import random
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+
+class CC15M(Dataset):
+ def __init__(
+ self,
+ json_path,
+ video_folder=None,
+ resolution=512,
+ enable_bucket=False,
+ ):
+ print(f"loading annotations from {json_path} ...")
+ self.dataset = json.load(open(json_path, 'r'))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.enable_bucket = enable_bucket
+ self.video_folder = video_folder
+
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
+ self.pixel_transforms = transforms.Compose([
+ transforms.Resize(resolution[0]),
+ transforms.CenterCrop(resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ video_id, name = video_dict['file_path'], video_dict['text']
+
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+
+ pixel_values = Image.open(video_dir).convert("RGB")
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+ except Exception as e:
+ print(e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ else:
+ pixel_values = np.array(pixel_values)
+
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+if __name__ == "__main__":
+ dataset = CC15M(
+        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
+ resolution=512,
+ )
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/rose/data/dataset_image_video.py b/rose/data/dataset_image_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..0288432e4503791991389435677f3419a8f3ce11
--- /dev/null
+++ b/rose/data/dataset_image_video.py
@@ -0,0 +1,589 @@
+import csv
+import io
+import json
+import math
+import os
+import random
+from threading import Thread
+
+import albumentations
+import cv2
+import gc
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+
+from func_timeout import func_timeout, FunctionTimedOut
+from decord import VideoReader
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+from contextlib import contextmanager
+
+VIDEO_READER_TIMEOUT = 20
+
+def get_random_mask(shape, image_start_only=False):
+ f, c, h, w = shape
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+
+ if not image_start_only:
+ if f != 1:
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
+ else:
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
+ if mask_index == 0:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
+            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
+ elif mask_index == 1:
+ mask[:, :, :, :] = 1
+ elif mask_index == 2:
+ mask_frame_index = np.random.randint(1, 5)
+ mask[mask_frame_index:, :, :, :] = 1
+ elif mask_index == 3:
+ mask_frame_index = np.random.randint(1, 5)
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+ elif mask_index == 4:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
+            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+
+ mask_frame_before = np.random.randint(0, f // 2)
+ mask_frame_after = np.random.randint(f // 2, f)
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+ elif mask_index == 5:
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
+ elif mask_index == 6:
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
+
+ for i in frames_to_mask:
+ block_height = random.randint(1, h // 4)
+ block_width = random.randint(1, w // 4)
+ top_left_y = random.randint(0, h - block_height)
+ top_left_x = random.randint(0, w - block_width)
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
+ elif mask_index == 7:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+            a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()  # semi-major axis
+            b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()  # semi-minor axis
+
+ for i in range(h):
+ for j in range(w):
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
+ mask[:, :, i, j] = 1
+ elif mask_index == 8:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
+ for i in range(h):
+ for j in range(w):
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
+ mask[:, :, i, j] = 1
+ elif mask_index == 9:
+ for idx in range(f):
+ if np.random.rand() > 0.5:
+ mask[idx, :, :, :] = 1
+ else:
+            raise ValueError(f"The mask_index {mask_index} is not defined")
+ else:
+ if f != 1:
+ mask[1:, :, :, :] = 1
+ else:
+ mask[:, :, :, :] = 1
+ return mask
+
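+# Minimal sketch of how the mask above is consumed later in this file (see
+# ImageVideoDataset.__getitem__): masked regions are filled with -1, matching the
+# [-1, 1] normalization of the pixel values. The tensor shape is illustrative only.
+#
+#     clip = torch.zeros(16, 3, 64, 64)    # (frames, channels, height, width) in [-1, 1]
+#     mask = get_random_mask(clip.size())  # (16, 1, 64, 64) with values in {0, 1}
+#     masked_clip = clip * (1 - mask) + torch.ones_like(clip) * -1 * mask
+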
+class ImageVideoSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+
+ def __init__(self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ drop_last: bool = False
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # buckets for each aspect ratio
+ self.bucket = {'image':[], 'video':[]}
+
+ def __iter__(self):
+ for idx in self.sampler:
+ content_type = self.dataset.dataset[idx].get('type', 'image')
+ self.bucket[content_type].append(idx)
+
+ # yield a batch of indices in the same aspect ratio group
+ if len(self.bucket['video']) == self.batch_size:
+ bucket = self.bucket['video']
+ yield bucket[:]
+ del bucket[:]
+ elif len(self.bucket['image']) == self.batch_size:
+ bucket = self.bucket['image']
+ yield bucket[:]
+ del bucket[:]
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+ vr = VideoReader(*args, **kwargs)
+ try:
+ yield vr
+ finally:
+ del vr
+ gc.collect()
+
+def get_video_reader_batch(video_reader, batch_index):
+ frames = video_reader.get_batch(batch_index).asnumpy()
+ return frames
+
+def resize_frame(frame, target_short_side):
+ h, w, _ = frame.shape
+ if h < w:
+ if target_short_side > h:
+ return frame
+ new_h = target_short_side
+ new_w = int(target_short_side * w / h)
+ else:
+ if target_short_side > w:
+ return frame
+ new_w = target_short_side
+ new_h = int(target_short_side * h / w)
+
+ resized_frame = cv2.resize(frame, (new_w, new_h))
+ return resized_frame
+
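+# The three helpers above are combined in the datasets below into a guarded read path:
+# VideoReader_contextmanager guarantees the decord reader is freed, func_timeout aborts
+# reads that hang for longer than VIDEO_READER_TIMEOUT, and resize_frame caps the short
+# side before any torch processing. A condensed sketch (the file path is hypothetical):
+#
+#     with VideoReader_contextmanager("clip.mp4", num_threads=2) as vr:
+#         frame_idx = np.linspace(0, len(vr) - 1, 16, dtype=int)
+#         frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(vr, frame_idx))
+#         frames = np.array([resize_frame(f, 512) for f in frames])
+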
+class ImageVideoDataset(Dataset):
+ def __init__(
+ self,
+ ann_path, data_root=None,
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+ image_sample_size=512,
+ video_repeat=0,
+ text_drop_ratio=0.1,
+ enable_bucket=False,
+ video_length_drop_start=0.0,
+ video_length_drop_end=1.0,
+ enable_inpaint=False,
+ ):
+ # Loading annotations from files
+ print(f"loading annotations from {ann_path} ...")
+ if ann_path.endswith('.csv'):
+ with open(ann_path, 'r') as csvfile:
+ dataset = list(csv.DictReader(csvfile))
+ elif ann_path.endswith('.json'):
+ dataset = json.load(open(ann_path))
+
+ self.data_root = data_root
+
+        # Used to balance the numbers of images and videos.
+ self.dataset = []
+ for data in dataset:
+ if data.get('type', 'image') != 'video':
+ self.dataset.append(data)
+ if video_repeat > 0:
+ for _ in range(video_repeat):
+ for data in dataset:
+ if data.get('type', 'image') == 'video':
+ self.dataset.append(data)
+ del dataset
+
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+ # TODO: enable bucket training
+ self.enable_bucket = enable_bucket
+ self.text_drop_ratio = text_drop_ratio
+ self.enable_inpaint = enable_inpaint
+
+ self.video_length_drop_start = video_length_drop_start
+ self.video_length_drop_end = video_length_drop_end
+
+ # Video params
+ self.video_sample_stride = video_sample_stride
+ self.video_sample_n_frames = video_sample_n_frames
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+ self.video_transforms = transforms.Compose(
+ [
+ transforms.Resize(min(self.video_sample_size)),
+ transforms.CenterCrop(self.video_sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ # Image params
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+ self.image_transforms = transforms.Compose([
+ transforms.Resize(min(self.image_sample_size)),
+ transforms.CenterCrop(self.image_sample_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+ ])
+
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
+ def get_batch(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+
+ if data_info.get('type', 'image')=='video':
+ video_id, text = data_info['file_path'], data_info['text']
+
+ if self.data_root is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.data_root, video_id)
+
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+ min_sample_n_frames = min(
+ self.video_sample_n_frames,
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+ )
+ if min_sample_n_frames == 0:
+                    raise ValueError("No frames in the video.")
+
+ video_length = int(self.video_length_drop_end * len(video_reader))
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+ try:
+ sample_args = (video_reader, batch_index)
+ pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ resized_frames = []
+ for i in range(len(pixel_values)):
+ frame = pixel_values[i]
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+ resized_frames.append(resized_frame)
+ pixel_values = np.array(resized_frames)
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = pixel_values
+
+ if not self.enable_bucket:
+ pixel_values = self.video_transforms(pixel_values)
+
+            # Randomly drop the text (empty prompt) with probability text_drop_ratio
+ if random.random() < self.text_drop_ratio:
+ text = ''
+ return pixel_values, text, 'video'
+ else:
+ image_path, text = data_info['file_path'], data_info['text']
+ if self.data_root is not None:
+ image_path = os.path.join(self.data_root, image_path)
+ image = Image.open(image_path).convert('RGB')
+ if not self.enable_bucket:
+ image = self.image_transforms(image).unsqueeze(0)
+ else:
+ image = np.expand_dims(np.array(image), 0)
+ if random.random() < self.text_drop_ratio:
+ text = ''
+ return image, text, 'image'
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+ data_type = data_info.get('type', 'image')
+ while True:
+ sample = {}
+ try:
+ data_info_local = self.dataset[idx % len(self.dataset)]
+ data_type_local = data_info_local.get('type', 'image')
+ if data_type_local != data_type:
+ raise ValueError("data_type_local != data_type")
+
+ pixel_values, name, data_type = self.get_batch(idx)
+ sample["pixel_values"] = pixel_values
+ sample["text"] = name
+ sample["data_type"] = data_type
+ sample["idx"] = idx
+
+ if len(sample) > 0:
+ break
+ except Exception as e:
+ print(e, self.dataset[idx % len(self.dataset)])
+ idx = random.randint(0, self.length-1)
+
+ if self.enable_inpaint and not self.enable_bucket:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample["mask_pixel_values"] = mask_pixel_values
+ sample["mask"] = mask
+
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+ sample["clip_pixel_values"] = clip_pixel_values
+
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+ if (mask == 1).all():
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+ sample["ref_pixel_values"] = ref_pixel_values
+
+ return sample
+
+
+class ImageVideoControlDataset(Dataset):
+ def __init__(
+ self,
+ ann_path, data_root=None,
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+ image_sample_size=512,
+ video_repeat=0,
+ text_drop_ratio=0.1,
+ enable_bucket=False,
+ video_length_drop_start=0.0,
+ video_length_drop_end=1.0,
+ enable_inpaint=False,
+ ):
+ # Loading annotations from files
+ print(f"loading annotations from {ann_path} ...")
+ if ann_path.endswith('.csv'):
+ with open(ann_path, 'r') as csvfile:
+ dataset = list(csv.DictReader(csvfile))
+ elif ann_path.endswith('.json'):
+ dataset = json.load(open(ann_path))
+
+ self.data_root = data_root
+
+        # Used to balance the numbers of images and videos.
+ self.dataset = []
+ for data in dataset:
+ if data.get('type', 'image') != 'video':
+ self.dataset.append(data)
+ if video_repeat > 0:
+ for _ in range(video_repeat):
+ for data in dataset:
+ if data.get('type', 'image') == 'video':
+ self.dataset.append(data)
+ del dataset
+
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+ # TODO: enable bucket training
+ self.enable_bucket = enable_bucket
+ self.text_drop_ratio = text_drop_ratio
+ self.enable_inpaint = enable_inpaint
+
+ self.video_length_drop_start = video_length_drop_start
+ self.video_length_drop_end = video_length_drop_end
+
+ # Video params
+ self.video_sample_stride = video_sample_stride
+ self.video_sample_n_frames = video_sample_n_frames
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+ self.video_transforms = transforms.Compose(
+ [
+ transforms.Resize(min(self.video_sample_size)),
+ transforms.CenterCrop(self.video_sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ # Image params
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+ self.image_transforms = transforms.Compose([
+ transforms.Resize(min(self.image_sample_size)),
+ transforms.CenterCrop(self.image_sample_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+ ])
+
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
+ def get_batch(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+ video_id, text = data_info['file_path'], data_info['text']
+
+ if data_info.get('type', 'image')=='video':
+ if self.data_root is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.data_root, video_id)
+
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+ min_sample_n_frames = min(
+ self.video_sample_n_frames,
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+ )
+ if min_sample_n_frames == 0:
+                    raise ValueError("No frames in the video.")
+
+ video_length = int(self.video_length_drop_end * len(video_reader))
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+ try:
+ sample_args = (video_reader, batch_index)
+ pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ resized_frames = []
+ for i in range(len(pixel_values)):
+ frame = pixel_values[i]
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+ resized_frames.append(resized_frame)
+ pixel_values = np.array(resized_frames)
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = pixel_values
+
+ if not self.enable_bucket:
+ pixel_values = self.video_transforms(pixel_values)
+
+            # Randomly drop the text (empty prompt) with probability text_drop_ratio
+ if random.random() < self.text_drop_ratio:
+ text = ''
+
+ control_video_id = data_info['control_file_path']
+
+ if self.data_root is None:
+ control_video_id = control_video_id
+ else:
+ control_video_id = os.path.join(self.data_root, control_video_id)
+
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+ try:
+ sample_args = (control_video_reader, batch_index)
+ control_pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ resized_frames = []
+ for i in range(len(control_pixel_values)):
+ frame = control_pixel_values[i]
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+ resized_frames.append(resized_frame)
+ control_pixel_values = np.array(resized_frames)
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+ control_pixel_values = control_pixel_values / 255.
+ del control_video_reader
+ else:
+ control_pixel_values = control_pixel_values
+
+ if not self.enable_bucket:
+ control_pixel_values = self.video_transforms(control_pixel_values)
+ return pixel_values, control_pixel_values, text, "video"
+ else:
+ image_path, text = data_info['file_path'], data_info['text']
+ if self.data_root is not None:
+ image_path = os.path.join(self.data_root, image_path)
+ image = Image.open(image_path).convert('RGB')
+ if not self.enable_bucket:
+ image = self.image_transforms(image).unsqueeze(0)
+ else:
+ image = np.expand_dims(np.array(image), 0)
+
+ if random.random() < self.text_drop_ratio:
+ text = ''
+
+ control_image_id = data_info['control_file_path']
+
+ if self.data_root is None:
+ control_image_id = control_image_id
+ else:
+ control_image_id = os.path.join(self.data_root, control_image_id)
+
+ control_image = Image.open(control_image_id).convert('RGB')
+ if not self.enable_bucket:
+ control_image = self.image_transforms(control_image).unsqueeze(0)
+ else:
+ control_image = np.expand_dims(np.array(control_image), 0)
+ return image, control_image, text, 'image'
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+ data_type = data_info.get('type', 'image')
+ while True:
+ sample = {}
+ try:
+ data_info_local = self.dataset[idx % len(self.dataset)]
+ data_type_local = data_info_local.get('type', 'image')
+ if data_type_local != data_type:
+ raise ValueError("data_type_local != data_type")
+
+ pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+ sample["pixel_values"] = pixel_values
+ sample["control_pixel_values"] = control_pixel_values
+ sample["text"] = name
+ sample["data_type"] = data_type
+ sample["idx"] = idx
+
+ if len(sample) > 0:
+ break
+ except Exception as e:
+ print(e, self.dataset[idx % len(self.dataset)])
+ idx = random.randint(0, self.length-1)
+
+ if self.enable_inpaint and not self.enable_bucket:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample["mask_pixel_values"] = mask_pixel_values
+ sample["mask"] = mask
+
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+ sample["clip_pixel_values"] = clip_pixel_values
+
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+ if (mask == 1).all():
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+ sample["ref_pixel_values"] = ref_pixel_values
+
+ return sample
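+
+
+if __name__ == "__main__":
+    # Smoke test mirroring dataset_image.py and dataset_video.py. The annotation path
+    # below is a placeholder; the JSON is expected to hold dicts with "file_path",
+    # "text" and an optional "type" ("image" or "video").
+    dataset = ImageVideoDataset(
+        ann_path="path/to/annotations.json",
+        video_sample_size=256, video_sample_n_frames=16,
+        image_sample_size=256,
+    )
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)
+    for idx, batch in enumerate(dataloader):
+        print(batch["pixel_values"].shape, len(batch["text"]))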
diff --git a/rose/data/dataset_video.py b/rose/data/dataset_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78367d0973ceb1abdcd005947612d16e2480831
--- /dev/null
+++ b/rose/data/dataset_video.py
@@ -0,0 +1,262 @@
+import csv
+import gc
+import io
+import json
+import math
+import os
+import random
+from contextlib import contextmanager
+from threading import Thread
+
+import albumentations
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from decord import VideoReader
+from einops import rearrange
+from func_timeout import FunctionTimedOut, func_timeout
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+
+VIDEO_READER_TIMEOUT = 20
+
+def get_random_mask(shape):
+ f, c, h, w = shape
+
+ mask_index = np.random.randint(0, 4)
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+ if mask_index == 0:
+ mask[1:, :, :, :] = 1
+ elif mask_index == 1:
+ mask_frame_index = 1
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+ elif mask_index == 2:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
+ elif mask_index == 3:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+
+ mask_frame_before = np.random.randint(0, f // 2)
+ mask_frame_after = np.random.randint(f // 2, f)
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+ else:
+        raise ValueError(f"The mask_index {mask_index} is not defined")
+ return mask
+
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+ vr = VideoReader(*args, **kwargs)
+ try:
+ yield vr
+ finally:
+ del vr
+ gc.collect()
+
+
+def get_video_reader_batch(video_reader, batch_index):
+ frames = video_reader.get_batch(batch_index).asnumpy()
+ return frames
+
+
+class WebVid10M(Dataset):
+ def __init__(
+ self,
+ csv_path, video_folder,
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ enable_bucket=False, enable_inpaint=False, is_image=False,
+ ):
+ print(f"loading annotations from {csv_path} ...")
+ with open(csv_path, 'r') as csvfile:
+ self.dataset = list(csv.DictReader(csvfile))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.video_folder = video_folder
+ self.sample_stride = sample_stride
+ self.sample_n_frames = sample_n_frames
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
+ self.is_image = is_image
+
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+ self.pixel_transforms = transforms.Compose([
+ transforms.Resize(sample_size[0]),
+ transforms.CenterCrop(sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+ video_reader = VideoReader(video_dir)
+ video_length = len(video_reader)
+
+ if not self.is_image:
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+ else:
+ batch_index = [random.randint(0, video_length - 1)]
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
+
+ if self.is_image:
+ pixel_values = pixel_values[0]
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+
+ except Exception as e:
+ print("Error info:", e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ if self.enable_inpaint:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+ else:
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ json_path, video_folder=None,
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ enable_bucket=False, enable_inpaint=False
+ ):
+ print(f"loading annotations from {json_path} ...")
+ self.dataset = json.load(open(json_path, 'r'))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.video_folder = video_folder
+ self.sample_stride = sample_stride
+ self.sample_n_frames = sample_n_frames
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
+
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+ self.pixel_transforms = transforms.Compose(
+ [
+ transforms.Resize(sample_size[0]),
+ transforms.CenterCrop(sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ video_id, name = video_dict['file_path'], video_dict['text']
+
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+ video_length = len(video_reader)
+
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+
+ try:
+ sample_args = (video_reader, batch_index)
+ pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = pixel_values
+
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+
+ except Exception as e:
+ print("Error info:", e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ if self.enable_inpaint:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+ else:
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+
+if __name__ == "__main__":
+ if 1:
+ dataset = VideoDataset(
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
+ sample_size=256,
+ sample_stride=4, sample_n_frames=16,
+ )
+
+ if 0:
+ dataset = WebVid10M(
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+ sample_size=256,
+ sample_stride=4, sample_n_frames=16,
+ is_image=False,
+ )
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/rose/dist/__init__.py b/rose/dist/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea4c21bb4f0398efed1d17da98e03520d2c4d13
--- /dev/null
+++ b/rose/dist/__init__.py
@@ -0,0 +1,43 @@
+import importlib.util
+
+from .fsdp import shard_model
+from .fuser import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size, get_sp_group,
+ get_world_group, init_distributed_environment,
+ initialize_model_parallel, set_multi_gpus_devices,
+ xFuserLongContextAttention)
+from .wan_xfuser import usp_attn_forward
+
+# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
+if importlib.util.find_spec("pai_fuser") is not None:
+ from pai_fuser.core import parallel_magvit_vae
+ from pai_fuser.core.attention import wan_usp_sparse_attention_wrapper
+ from . import wan_xfuser
+
+    # simple_wrapper works around conflicts between Cython and torch.compile
+ def simple_wrapper(func):
+ def inner(*args, **kwargs):
+ return func(*args, **kwargs)
+ return inner
+
+ wan_xfuser.usp_attn_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
+ usp_attn_forward = simple_wrapper(wan_xfuser.usp_attn_forward)
+ print("Import PAI VAE Turbo and Sparse Attention")
+
+ from pai_fuser.core.rope import ENABLE_KERNEL, usp_fast_rope_apply_qk
+
+ if ENABLE_KERNEL:
+ import torch
+ from .wan_xfuser import rope_apply
+
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
+ if torch.is_grad_enabled():
+ q = rope_apply(q, grid_sizes, freqs)
+ k = rope_apply(k, grid_sizes, freqs)
+ return q, k
+ else:
+ return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
+
+ wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
+ rope_apply_qk = adaptive_fast_usp_rope_apply_qk
+ print("Import PAI Fast rope")
\ No newline at end of file
diff --git a/rose/dist/fsdp.py b/rose/dist/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..555479f9b4e33d5d4bed54913ceab68adceac713
--- /dev/null
+++ b/rose/dist/fsdp.py
@@ -0,0 +1,43 @@
+# Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+from functools import partial
+
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.distributed.utils import _free_storage
+
+
+def shard_model(
+ model,
+ device_id,
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.float32,
+ process_group=None,
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ sync_module_states=True,
+):
+ model = FSDP(
+ module=model,
+ process_group=process_group,
+ sharding_strategy=sharding_strategy,
+ auto_wrap_policy=partial(
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+ mixed_precision=MixedPrecision(
+ param_dtype=param_dtype,
+ reduce_dtype=reduce_dtype,
+ buffer_dtype=buffer_dtype),
+ device_id=device_id,
+ sync_module_states=sync_module_states)
+ return model
+
+def free_model(model):
+ for m in model.modules():
+ if isinstance(m, FSDP):
+ _free_storage(m._handle.flat_param.data)
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/rose/dist/fuser.py b/rose/dist/fuser.py
new file mode 100644
index 0000000000000000000000000000000000000000..de12a2bdec6a309f3a83bed1b429d053d3569527
--- /dev/null
+++ b/rose/dist/fuser.py
@@ -0,0 +1,54 @@
+import importlib.util
+
+import torch
+import torch.distributed as dist
+
+try:
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
+ if importlib.util.find_spec("pai_fuser") is not None:
+ import pai_fuser
+ from pai_fuser.core.distributed import (
+ get_sequence_parallel_rank, get_sequence_parallel_world_size,
+ get_sp_group, get_world_group, init_distributed_environment,
+ initialize_model_parallel)
+ from pai_fuser.core.long_ctx_attention import \
+ xFuserLongContextAttention
+ print("Import PAI DiT Turbo")
+ else:
+ import xfuser
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group, get_world_group,
+ init_distributed_environment,
+ initialize_model_parallel)
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+        print("xfuser imported successfully")
+except Exception as ex:
+ get_sequence_parallel_world_size = None
+ get_sequence_parallel_rank = None
+ xFuserLongContextAttention = None
+ get_sp_group = None
+ get_world_group = None
+ init_distributed_environment = None
+ initialize_model_parallel = None
+
+def set_multi_gpus_devices(ulysses_degree, ring_degree):
+ if ulysses_degree > 1 or ring_degree > 1:
+ if get_sp_group is None:
+ raise RuntimeError("xfuser is not installed.")
+ dist.init_process_group("nccl")
+ print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
+ ulysses_degree, ring_degree, dist.get_rank(),
+ dist.get_world_size()))
+ assert dist.get_world_size() == ring_degree * ulysses_degree, \
+ "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+ initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
+ ring_degree=ring_degree,
+ ulysses_degree=ulysses_degree)
+ # device = torch.device("cuda:%d" % dist.get_rank())
+ device = torch.device(f"cuda:{get_world_group().local_rank}")
+ print('rank=%d device=%s' % (get_world_group().rank, str(device)))
+ else:
+ device = "cuda"
+ return device
\ No newline at end of file
diff --git a/rose/dist/wan_xfuser.py b/rose/dist/wan_xfuser.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc59c8de2529df114caab61e675f406c1d12bd3
--- /dev/null
+++ b/rose/dist/wan_xfuser.py
@@ -0,0 +1,111 @@
+import torch
+import torch.cuda.amp as amp
+
+from .fuser import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size, get_sp_group,
+ init_distributed_environment, initialize_model_parallel,
+ xFuserLongContextAttention)
+
+
+def pad_freqs(original_tensor, target_len):
+ seq_len, s1, s2 = original_tensor.shape
+ pad_size = target_len - seq_len
+ padding_tensor = torch.ones(
+ pad_size,
+ s1,
+ s2,
+ dtype=original_tensor.dtype,
+ device=original_tensor.device)
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+ return padded_tensor
+
+@amp.autocast(enabled=False)
+@torch.compiler.disable()
+def rope_apply(x, grid_sizes, freqs):
+ """
+ x: [B, L, N, C].
+ grid_sizes: [B, 3].
+ freqs: [M, C // 2].
+ """
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
+ s, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ sp_size = get_sequence_parallel_world_size()
+ sp_rank = get_sequence_parallel_rank()
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
+ s_per_rank = s
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+ s_per_rank), :, :]
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+ x_i = torch.cat([x_i, x[i, s:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output)
+
+def rope_apply_qk(q, k, grid_sizes, freqs):
+ q = rope_apply(q, grid_sizes, freqs)
+ k = rope_apply(k, grid_sizes, freqs)
+ return q, k
+
+def usp_attn_forward(self,
+ x,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ dtype=torch.bfloat16,
+ t=0):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+ half_dtypes = (torch.float16, torch.bfloat16)
+
+ def half(x):
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q,k,v for attention.
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
+ # if k_lens is not None:
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+ x = xFuserLongContextAttention()(
+ None,
+ query=half(q),
+ key=half(k),
+ value=half(v),
+ window_size=self.window_size)
+
+ # TODO: padding after attention.
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
\ No newline at end of file
diff --git a/rose/models/__init__.py b/rose/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e7bd02faf879f2824d17138c93a8ac6d73b507
--- /dev/null
+++ b/rose/models/__init__.py
@@ -0,0 +1,6 @@
+from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+from .wan_image_encoder import CLIPModel
+from .wan_text_encoder import WanT5EncoderModel
+from .wan_transformer3d import WanTransformer3DModel
+from .wan_vae import AutoencoderKLWan
diff --git a/rose/models/cache_utils.py b/rose/models/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9209fa7909065a2d4c99c696a704e4bd26856ec
--- /dev/null
+++ b/rose/models/cache_utils.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+
+
+def get_teacache_coefficients(model_name):
+ if "wan2.1-t2v-1.3b" or "wan2.1-fun-1.3b" in model_name.lower():
+ return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+ elif "wan2.1-t2v-14b" in model_name.lower():
+ return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+ elif "wan2.1-i2v-14b-480p" in model_name.lower():
+ return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+ elif "wan2.1-i2v-14b-720p" or "wan2.1-fun-14b" in model_name.lower():
+ return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+ else:
+ print(f"The model {model_name} is not supported by TeaCache.")
+ return None
+
+
+class TeaCache():
+ """
+ Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
+ the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
+ Please refer to:
+ 1. https://github.com/ali-vilab/TeaCache.
+ 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
+ """
+ def __init__(
+ self,
+ coefficients: list[float],
+ num_steps: int,
+ rel_l1_thresh: float = 0.0,
+ num_skip_start_steps: int = 0,
+ offload: bool = True,
+ ):
+ if num_steps < 1:
+ raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
+ if rel_l1_thresh < 0:
+ raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
+ if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
+ raise ValueError(
+                "`num_skip_start_steps` must be greater than or equal to 0 and "
+ f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
+ )
+ self.coefficients = coefficients
+ self.num_steps = num_steps
+ self.rel_l1_thresh = rel_l1_thresh
+ self.num_skip_start_steps = num_skip_start_steps
+ self.offload = offload
+ self.rescale_func = np.poly1d(self.coefficients)
+
+ self.cnt = 0
+ self.should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+        # Some pipelines concatenate the unconditional and text-conditioned inputs in a single forward pass.
+ self.previous_residual = None
+        # Other pipelines run the unconditional and text-conditioned forward passes separately.
+ self.previous_residual_cond = None
+ self.previous_residual_uncond = None
+
+ @staticmethod
+ def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
+ rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
+
+ return rel_l1_distance.cpu().item()
+
+ def reset(self):
+ self.cnt = 0
+ self.should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.previous_residual = None
+ self.previous_residual_cond = None
+ self.previous_residual_uncond = None
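+
+# Sketch of the decision rule the attributes above support (the actual check lives in
+# the transformer's forward pass, not in this file); `modulated` stands for the current
+# modulated timestep embedding and is an assumption of the sketch:
+#
+#     cache = TeaCache(get_teacache_coefficients("wan2.1-t2v-1.3b"),
+#                      num_steps=50, rel_l1_thresh=0.05)
+#     delta = cache.rescale_func(
+#         cache.compute_rel_l1_distance(cache.previous_modulated_input, modulated))
+#     cache.accumulated_rel_l1_distance += delta
+#     # recompute during warm-up or once the accumulated change crosses the threshold;
+#     # otherwise reuse previous_residual from the last full forward pass
+#     cache.should_calc = (
+#         cache.cnt < cache.num_skip_start_steps
+#         or cache.accumulated_rel_l1_distance >= cache.rel_l1_thresh
+#     )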
\ No newline at end of file
diff --git a/rose/models/diff_mask_predictor.py b/rose/models/diff_mask_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ecd526fbac9725762db412a28996e0ca626782
--- /dev/null
+++ b/rose/models/diff_mask_predictor.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+class DiffMaskPredictor(nn.Module):
+ def __init__(self, input_dim=4608, patch_grid=(10, 15, 189), output_grid=(81, 480, 720), hidden_dim=256):
+ """
+ Args:
+ input_dim (int): concatenated feature dimension, e.g. 1536 * num_selected_layers
+ patch_grid (tuple): (F_p, H_p, W_p) - patch token grid shape (e.g., from transformer block)
+ output_grid (tuple): (F, H, W) - final full resolution shape for mask
+ hidden_dim (int): intermediate conv/linear hidden dim
+ """
+ super().__init__()
+ self.F_p, self.H_p, self.W_p = patch_grid
+ self.F, self.H, self.W = output_grid
+
+ self.project = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, 1)
+ )
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): shape [B, L, D_total], L = F_p H_p W_p
+ Returns:
+ Tensor: predicted diff mask, shape [B, 1, F, H, W]
+ """
+ B, L, D = x.shape
+ assert L == self.F_p * self.H_p * self.W_p, \
+ f"Input token length {L} doesn't match patch grid ({self.F_p}, {self.H_p}, {self.W_p})"
+
+ x = self.project(x) # [B, L, 1]
+ x = x.view(B, 1, self.F_p, self.H_p, self.W_p) # [B, 1, F_p, H_p, W_p]
+ x = F.interpolate(
+ x, size=(self.F, self.H, self.W),
+ mode="trilinear", align_corners=False # upsample to match ground truth resolution
+ )
+ return x # [B, 1, F, H, W]
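+
+
+if __name__ == "__main__":
+    # Shape check with deliberately small, illustrative grids (the defaults above are far
+    # larger); the token length must equal F_p * H_p * W_p.
+    predictor = DiffMaskPredictor(input_dim=1536, patch_grid=(2, 4, 4), output_grid=(8, 32, 32), hidden_dim=64)
+    tokens = torch.randn(2, 2 * 4 * 4, 1536)
+    print(predictor(tokens).shape)  # torch.Size([2, 1, 8, 32, 32])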
diff --git a/rose/models/wan_image_encoder.py b/rose/models/wan_image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c442a79cef34081a9e67cdf0d9bf8d3ec67c9f
--- /dev/null
+++ b/rose/models/wan_image_encoder.py
@@ -0,0 +1,553 @@
+# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+from .wan_transformer3d import attention
+from .wan_xlm_roberta import XLMRoberta
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.models.modeling_utils import ModelMixin
+
+
+__all__ = [
+ 'XLMRobertaCLIP',
+ 'clip_xlm_roberta_vit_h_14',
+ 'CLIPModel',
+]
+
+
+def pos_interpolate(pos, seq_len):
+ if pos.size(1) == seq_len:
+ return pos
+ else:
+ src_grid = int(math.sqrt(pos.size(1)))
+ tar_grid = int(math.sqrt(seq_len))
+ n = pos.size(1) - src_grid * src_grid
+ return torch.cat([
+ pos[:, :n],
+ F.interpolate(
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
+ 0, 3, 1, 2),
+ size=(tar_grid, tar_grid),
+ mode='bicubic',
+ align_corners=False).flatten(2).transpose(1, 2)
+ ],
+ dim=1)
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ causal=False,
+ attn_dropout=0.0,
+ proj_dropout=0.0):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.causal = causal
+ self.attn_dropout = attn_dropout
+ self.proj_dropout = proj_dropout
+
+ # layers
+ self.to_qkv = nn.Linear(dim, dim * 3)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+ # compute attention
+ p = self.attn_dropout if self.training else 0.0
+ x = attention(q, k, v, dropout_p=p, causal=self.causal)
+ x = x.reshape(b, s, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+ return x
+
+
+class SwiGLU(nn.Module):
+
+ def __init__(self, dim, mid_dim):
+ super().__init__()
+ self.dim = dim
+ self.mid_dim = mid_dim
+
+ # layers
+ self.fc1 = nn.Linear(dim, mid_dim)
+ self.fc2 = nn.Linear(dim, mid_dim)
+ self.fc3 = nn.Linear(mid_dim, dim)
+
+ def forward(self, x):
+ x = F.silu(self.fc1(x)) * self.fc2(x)
+ x = self.fc3(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ post_norm=False,
+ causal=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.causal = causal
+ self.norm_eps = norm_eps
+
+ # layers
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
+ proj_dropout)
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
+ if activation == 'swi_glu':
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ if self.post_norm:
+ x = x + self.norm1(self.attn(x))
+ x = x + self.norm2(self.mlp(x))
+ else:
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class AttentionPool(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ activation='gelu',
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.proj_dropout = proj_dropout
+ self.norm_eps = norm_eps
+
+ # layers
+ gain = 1.0 / math.sqrt(dim)
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.to_q = nn.Linear(dim, dim)
+ self.to_kv = nn.Linear(dim, dim * 2)
+ self.proj = nn.Linear(dim, dim)
+ self.norm = LayerNorm(dim, eps=norm_eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+ # compute attention
+ x = flash_attention(q, k, v, version=2)
+ x = x.reshape(b, 1, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+
+ # mlp
+ x = x + self.mlp(self.norm(x))
+ return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(self,
+ image_size=224,
+ patch_size=16,
+ dim=768,
+ mlp_ratio=4,
+ out_dim=512,
+ num_heads=12,
+ num_layers=12,
+ pool_type='token',
+ pre_norm=True,
+ post_norm=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ if image_size % patch_size != 0:
+ print(
+ '[WARNING] image_size is not divisible by patch_size',
+ flush=True)
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
+ out_dim = out_dim or dim
+ super().__init__()
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = (image_size // patch_size)**2
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.pool_type = pool_type
+ self.post_norm = post_norm
+ self.norm_eps = norm_eps
+
+ # embeddings
+ gain = 1.0 / math.sqrt(dim)
+ self.patch_embedding = nn.Conv2d(
+ 3,
+ dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=not pre_norm)
+ if pool_type in ('token', 'token_fc'):
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
+ 1, self.num_patches +
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
+ self.dropout = nn.Dropout(embedding_dropout)
+
+ # transformer
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+ self.transformer = nn.Sequential(*[
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
+ activation, attn_dropout, proj_dropout, norm_eps)
+ for _ in range(num_layers)
+ ])
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+ # head
+ if pool_type == 'token':
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+ elif pool_type == 'token_fc':
+ self.head = nn.Linear(dim, out_dim)
+ elif pool_type == 'attn_pool':
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
+ proj_dropout, norm_eps)
+
+ def forward(self, x, interpolation=False, use_31_block=False):
+ b = x.size(0)
+
+ # embeddings
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+ if self.pool_type in ('token', 'token_fc'):
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
+ if interpolation:
+ e = pos_interpolate(self.pos_embedding, x.size(1))
+ else:
+ e = self.pos_embedding
+ x = self.dropout(x + e)
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+
+ # transformer
+ if use_31_block:
+ x = self.transformer[:-1](x)
+ return x
+ else:
+ x = self.transformer(x)
+ return x
+
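+# Note on `use_31_block` above: with the default 32-layer ViT-H configuration this
+# returns the tokens after the first 31 transformer blocks (the penultimate layer)
+# instead of running all 32 blocks.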
+
+class XLMRobertaWithHead(XLMRoberta):
+
+ def __init__(self, **kwargs):
+ self.out_dim = kwargs.pop('out_dim')
+ super().__init__(**kwargs)
+
+ # head
+ mid_dim = (self.dim + self.out_dim) // 2
+ self.head = nn.Sequential(
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
+ nn.Linear(mid_dim, self.out_dim, bias=False))
+
+ def forward(self, ids):
+ # xlm-roberta
+ x = super().forward(ids)
+
+ # average pooling
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
+
+ # head
+ x = self.head(x)
+ return x
+
+
+class XLMRobertaCLIP(nn.Module):
+
+ def __init__(self,
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ vision_pre_norm=True,
+ vision_post_norm=False,
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.vision_dim = vision_dim
+ self.vision_mlp_ratio = vision_mlp_ratio
+ self.vision_heads = vision_heads
+ self.vision_layers = vision_layers
+ self.vision_pre_norm = vision_pre_norm
+ self.vision_post_norm = vision_post_norm
+ self.activation = activation
+ self.vocab_size = vocab_size
+ self.max_text_len = max_text_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.text_dim = text_dim
+ self.text_heads = text_heads
+ self.text_layers = text_layers
+ self.text_post_norm = text_post_norm
+ self.norm_eps = norm_eps
+
+ # models
+ self.visual = VisionTransformer(
+ image_size=image_size,
+ patch_size=patch_size,
+ dim=vision_dim,
+ mlp_ratio=vision_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=vision_heads,
+ num_layers=vision_layers,
+ pool_type=vision_pool,
+ pre_norm=vision_pre_norm,
+ post_norm=vision_post_norm,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps)
+ self.textual = XLMRobertaWithHead(
+ vocab_size=vocab_size,
+ max_seq_len=max_text_len,
+ type_size=type_size,
+ pad_id=pad_id,
+ dim=text_dim,
+ out_dim=embed_dim,
+ num_heads=text_heads,
+ num_layers=text_layers,
+ post_norm=text_post_norm,
+ dropout=text_dropout)
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+ def forward(self, imgs, txt_ids):
+ """
+ imgs: [B, 3, H, W] of torch.float32.
+ - mean: [0.48145466, 0.4578275, 0.40821073]
+ - std: [0.26862954, 0.26130258, 0.27577711]
+ txt_ids: [B, L] of torch.long.
+ Encoded by data.CLIPTokenizer.
+ """
+ xi = self.visual(imgs)
+ xt = self.textual(txt_ids)
+ return xi, xt
+
+ def param_groups(self):
+ groups = [{
+ 'params': [
+ p for n, p in self.named_parameters()
+ if 'norm' in n or n.endswith('bias')
+ ],
+ 'weight_decay': 0.0
+ }, {
+ 'params': [
+ p for n, p in self.named_parameters()
+ if not ('norm' in n or n.endswith('bias'))
+ ]
+ }]
+ return groups
+
+
+def _clip(pretrained=False,
+ pretrained_name=None,
+ model_cls=XLMRobertaCLIP,
+ return_transforms=False,
+ return_tokenizer=False,
+ tokenizer_padding='eos',
+ dtype=torch.float32,
+ device='cpu',
+ **kwargs):
+ # init a model on device
+ with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # set device
+ model = model.to(dtype=dtype, device=device)
+ output = (model,)
+
+ # init transforms
+ if return_transforms:
+ # mean and std
+ if 'siglip' in pretrained_name.lower():
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+ else:
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+
+ # transforms
+ transforms = T.Compose([
+ T.Resize((model.image_size, model.image_size),
+ interpolation=T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=mean, std=std)
+ ])
+ output += (transforms,)
+ return output[0] if len(output) == 1 else output
+
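+# Minimal usage sketch (assumption: randomly initialized weights on CPU, since this
+# helper does not load the `pretrained` checkpoint itself):
+#   model, transforms = _clip(
+#       pretrained=False,
+#       pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
+#       return_transforms=True)
+#   # `transforms` resizes inputs to model.image_size and applies the CLIP mean/std.
+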
+
+def clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
+ **kwargs):
+ cfg = dict(
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0)
+ cfg.update(**kwargs)
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+
+ def __init__(self):
+ super(CLIPModel, self).__init__()
+ # init model
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ return_transforms=True,
+ return_tokenizer=False)
+
+ def forward(self, videos):
+ # preprocess
+ size = (self.model.image_size,) * 2
+ videos = torch.cat([
+ F.interpolate(
+ u.transpose(0, 1),
+ size=size,
+ mode='bicubic',
+ align_corners=False) for u in videos
+ ])
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+ # forward
+ with torch.cuda.amp.autocast(dtype=self.dtype):
+ out = self.model.visual(videos, use_31_block=True)
+ return out
+
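+ # Note: `videos` is expected to be an iterable of [C, T, H, W] tensors in [-1, 1];
+ # frames are resized to the CLIP input size, mapped to [0, 1] and normalized before
+ # being fed to the penultimate-layer visual encoder (use_31_block=True).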
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
+ def filter_kwargs(cls, kwargs):
+ import inspect
+ sig = inspect.signature(cls.__init__)
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+ return filtered_kwargs
+
+ model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
+ if pretrained_model_path.endswith(".safetensors"):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(pretrained_model_path)
+ else:
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+ tmp_state_dict = {}
+ for key in state_dict:
+ tmp_state_dict["model." + key] = state_dict[key]
+ state_dict = tmp_state_dict
+ m, u = model.load_state_dict(state_dict, strict=False)
+
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m, u)
+ return model
\ No newline at end of file
diff --git a/rose/models/wan_text_encoder.py b/rose/models/wan_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..49fc936bb9953387f9848d7d6340a1f6f1815695
--- /dev/null
+++ b/rose/models/wan_text_encoder.py
@@ -0,0 +1,376 @@
+# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.models.modeling_utils import ModelMixin
+
+
+def fp16_clamp(x):
+ if x.dtype == torch.float16 and torch.isinf(x).any():
+ clamp = torch.finfo(x.dtype).max - 1000
+ x = torch.clamp(x, min=-clamp, max=clamp)
+ return x
+
+
+def init_weights(m):
+ if isinstance(m, T5LayerNorm):
+ nn.init.ones_(m.weight)
+ elif isinstance(m, T5FeedForward):
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
+ elif isinstance(m, T5Attention):
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
+ elif isinstance(m, T5RelativeEmbedding):
+ nn.init.normal_(
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
+
+
+class GELU(nn.Module):
+ def forward(self, x):
+ return 0.5 * x * (1.0 + torch.tanh(
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5LayerNorm(nn.Module):
+ def __init__(self, dim, eps=1e-6):
+ super(T5LayerNorm, self).__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
+ self.eps)
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ x = x.type_as(self.weight)
+ return self.weight * x
+
+
+class T5Attention(nn.Module):
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
+ assert dim_attn % num_heads == 0
+ super(T5Attention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.num_heads = num_heads
+ self.head_dim = dim_attn // num_heads
+
+ # layers
+ self.q = nn.Linear(dim, dim_attn, bias=False)
+ self.k = nn.Linear(dim, dim_attn, bias=False)
+ self.v = nn.Linear(dim, dim_attn, bias=False)
+ self.o = nn.Linear(dim_attn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, context=None, mask=None, pos_bias=None):
+ """
+ x: [B, L1, C].
+ context: [B, L2, C] or None.
+ mask: [B, L2] or [B, L1, L2] or None.
+ """
+ # check inputs
+ context = x if context is None else context
+ b, n, c = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).view(b, -1, n, c)
+ k = self.k(context).view(b, -1, n, c)
+ v = self.v(context).view(b, -1, n, c)
+
+ # attention bias
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
+ if pos_bias is not None:
+ attn_bias += pos_bias
+ if mask is not None:
+ assert mask.ndim in [2, 3]
+ mask = mask.view(b, 1, 1,
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
+
+ # compute attention (T5 does not use scaling)
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
+
+ # output
+ x = x.reshape(b, -1, n * c)
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_ffn, dropout=0.1):
+ super(T5FeedForward, self).__init__()
+ self.dim = dim
+ self.dim_ffn = dim_ffn
+
+ # layers
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.fc1(x) * self.gate(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5SelfAttention(nn.Module):
+ def __init__(self,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5SelfAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True)
+
+ def forward(self, x, mask=None, pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(
+ x.size(1), x.size(1))
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
+ return x
+
+
+class T5CrossAttention(nn.Module):
+ def __init__(self,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5CrossAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm3 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=False)
+
+ def forward(self,
+ x,
+ mask=None,
+ encoder_states=None,
+ encoder_mask=None,
+ pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(
+ x.size(1), x.size(1))
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.cross_attn(
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
+ return x
+
+
+class T5RelativeEmbedding(nn.Module):
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+ super(T5RelativeEmbedding, self).__init__()
+ self.num_buckets = num_buckets
+ self.num_heads = num_heads
+ self.bidirectional = bidirectional
+ self.max_dist = max_dist
+
+ # layers
+ self.embedding = nn.Embedding(num_buckets, num_heads)
+
+ def forward(self, lq, lk):
+ device = self.embedding.weight.device
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
+ # torch.arange(lq).unsqueeze(1).to(device)
+ if torch.device(type="meta") != device:
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
+ torch.arange(lq, device=device).unsqueeze(1)
+ else:
+ rel_pos = torch.arange(lk).unsqueeze(0) - \
+ torch.arange(lq).unsqueeze(1)
+ rel_pos = self._relative_position_bucket(rel_pos)
+ rel_pos_embeds = self.embedding(rel_pos)
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
+ 0) # [1, N, Lq, Lk]
+ return rel_pos_embeds.contiguous()
+
+ def _relative_position_bucket(self, rel_pos):
+ # preprocess
+ if self.bidirectional:
+ num_buckets = self.num_buckets // 2
+ rel_buckets = (rel_pos > 0).long() * num_buckets
+ rel_pos = torch.abs(rel_pos)
+ else:
+ num_buckets = self.num_buckets
+ rel_buckets = 0
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
+
+ # embeddings for small and large positions
+ max_exact = num_buckets // 2
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
+ math.log(self.max_dist / max_exact) *
+ (num_buckets - max_exact)).long()
+ rel_pos_large = torch.min(
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
+ return rel_buckets
+
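+# Worked example of the bucketing above (assuming the usual T5 setting of
+# num_buckets=32, max_dist=128, bidirectional=True): the sign of the relative
+# position selects the upper or lower 16 buckets, absolute distances below 8 map
+# one-to-one, larger distances are binned logarithmically up to max_dist, and
+# anything farther shares the last bucket.
+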
+class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ def __init__(self,
+ vocab,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_layers,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(WanT5EncoderModel, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
+ else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList([
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
+ shared_pos, dropout) for _ in range(num_layers)
+ ])
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ x = self.token_embedding(input_ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1),
+ x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, attention_mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return (x, )
+
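+ # Minimal usage sketch (hypothetical small config, not the released checkpoint):
+ #   enc = WanT5EncoderModel(vocab=32128, dim=512, dim_attn=512, dim_ffn=1024,
+ #                           num_heads=8, num_layers=2, num_buckets=32)
+ #   ids = torch.randint(0, 32128, (1, 16))
+ #   (hidden,) = enc(ids, attention_mask=torch.ones(1, 16, dtype=torch.long))
+ #   # hidden: [1, 16, 512]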
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
+ def filter_kwargs(cls, kwargs):
+ import inspect
+ sig = inspect.signature(cls.__init__)
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+ return filtered_kwargs
+
+ if low_cpu_mem_usage:
+ try:
+ import re
+
+ from diffusers.models.modeling_utils import \
+ load_model_dict_into_meta
+ from diffusers.utils import is_accelerate_available
+ if is_accelerate_available():
+ import accelerate
+
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ model = cls(**filter_kwargs(cls, additional_kwargs))
+
+ param_device = "cpu"
+ if pretrained_model_path.endswith(".safetensors"):
+ from safetensors.torch import load_file
+ state_dict = load_file(pretrained_model_path)
+ else:
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+ # move the params from meta device to cpu
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if len(missing_keys) > 0:
+ raise ValueError(
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+ " those weights or else make sure your checkpoint file is correct."
+ )
+
+ unexpected_keys = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device=param_device,
+ dtype=torch_dtype,
+ model_name_or_path=pretrained_model_path,
+ )
+
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ print(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
+ return model
+ except Exception as e:
+ print(
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
+ )
+
+ model = cls(**filter_kwargs(cls, additional_kwargs))
+ if pretrained_model_path.endswith(".safetensors"):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(pretrained_model_path)
+ else:
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m, u)
+ return model
\ No newline at end of file
diff --git a/rose/models/wan_transformer3d.py b/rose/models/wan_transformer3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7f5041d168a4d8b5b5a2dab9a60e2c586e5fb0f
--- /dev/null
+++ b/rose/models/wan_transformer3d.py
@@ -0,0 +1,1203 @@
+# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import glob
+import json
+import math
+import os
+import types
+import warnings
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import is_torch_version, logging
+from torch import nn
+
+from ..dist import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size, get_sp_group,
+ xFuserLongContextAttention)
+from ..dist.wan_xfuser import usp_attn_forward
+from .cache_utils import TeaCache
+
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+try:
+ import flash_attn
+ FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False
+
+
+def flash_attention(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ version=None,
+):
+ """
+ q: [B, Lq, Nq, C1].
+ k: [B, Lk, Nk, C1].
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+ q_lens: [B].
+ k_lens: [B].
+ dropout_p: float. Dropout probability.
+ softmax_scale: float. The scaling of QK^T before applying softmax.
+ causal: bool. Whether to apply causal attention mask.
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
+ deterministic: bool. If True, slightly slower and uses more memory.
+ dtype: torch.dtype. Target dtype for q/k/v when they are not already float16/bfloat16.
+ """
+ half_dtypes = (torch.float16, torch.bfloat16)
+ assert dtype in half_dtypes
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
+
+ # params
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+ def half(x):
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # preprocess query
+ if q_lens is None:
+ q = half(q.flatten(0, 1))
+ q_lens = torch.tensor(
+ [lq] * b, dtype=torch.int32).to(
+ device=q.device, non_blocking=True)
+ else:
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+ # preprocess key, value
+ if k_lens is None:
+ k = half(k.flatten(0, 1))
+ v = half(v.flatten(0, 1))
+ k_lens = torch.tensor(
+ [lk] * b, dtype=torch.int32).to(
+ device=k.device, non_blocking=True)
+ else:
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+ q = q.to(v.dtype)
+ k = k.to(v.dtype)
+
+ if q_scale is not None:
+ q = q * q_scale
+
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+ warnings.warn(
+ 'Flash attention 3 is not available; falling back to flash attention 2.'
+ )
+
+ # apply attention
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+ # Note: dropout_p, window_size are not supported in FA3 now.
+ x = flash_attn_interface.flash_attn_varlen_func(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
+ seqused_q=None,
+ seqused_k=None,
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
+ else:
+ assert FLASH_ATTN_2_AVAILABLE
+ x = flash_attn.flash_attn_varlen_func(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
+ max_seqlen_q=lq,
+ max_seqlen_k=lk,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=window_size,
+ deterministic=deterministic).unflatten(0, (b, lq))
+
+ # output
+ return x.type(out_dtype)
+
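+# Shape sketch for the varlen packing above (illustrative sizes): with q of shape
+# [2, 1024, 16, 64] and no q_lens/k_lens, q is flattened to [2048, 16, 64],
+# cu_seqlens_q becomes [0, 1024, 2048], and the result is unflattened back to
+# [2, 1024, 16, 64] and cast to the original dtype.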
+
+def attention(
+ q,
+ k,
+ v,
+ q_lens=None,
+ k_lens=None,
+ dropout_p=0.,
+ softmax_scale=None,
+ q_scale=None,
+ causal=False,
+ window_size=(-1, -1),
+ deterministic=False,
+ dtype=torch.bfloat16,
+ fa_version=None,
+):
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+ return flash_attention(
+ q=q,
+ k=k,
+ v=v,
+ q_lens=q_lens,
+ k_lens=k_lens,
+ dropout_p=dropout_p,
+ softmax_scale=softmax_scale,
+ q_scale=q_scale,
+ causal=causal,
+ window_size=window_size,
+ deterministic=deterministic,
+ dtype=dtype,
+ version=fa_version,
+ )
+ else:
+ if q_lens is not None or k_lens is not None:
+ warnings.warn(
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+ )
+ attn_mask = None
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ out = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+
+ out = out.transpose(1, 2).contiguous()
+ return out
+
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float64)
+
+ # calculation
+ sinusoid = torch.outer(
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
+
+
+@amp.autocast(enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ torch.arange(max_seq_len),
+ 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+# modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
+@amp.autocast(enabled=False)
+def get_1d_rotary_pos_embed_riflex(
+ pos: Union[np.ndarray, int],
+ dim: int,
+ theta: float = 10000.0,
+ use_real=False,
+ k: Optional[int] = None,
+ L_test: Optional[int] = None,
+ L_test_scale: Optional[int] = None,
+):
+ """
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function computes a frequency tensor with complex exponentials using the given dimension 'dim' and the
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
+ in complex64 data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
+ L_test_scale (`int`, *optional*, defaults to None): extra divisor applied to the intrinsic frequency after the RIFLEx adjustment
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ freqs = 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float64).div(dim))
+
+ # === Riflex modification start ===
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
+ if k is not None:
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
+ # === Riflex modification end ===
+ if L_test_scale is not None:
+ freqs[k-1] = freqs[k-1] / L_test_scale
+
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
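+# Worked example of the RIFLEx adjustment above: with k=6 and L_test=66 (the
+# defaults used by enable_riflex below), freqs[5] is set to 0.9 * 2*pi / 66,
+# roughly 0.0857, so that rotary component sweeps at most ~90% of one period over
+# the 66 inference frames; L_test_scale, if given, divides this frequency further.
+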
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ n, c = x.size(2), x.size(3) // 2
+
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
+ seq_len, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
+
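+# Descriptive note: rope_apply splits the per-head rotary channels into three
+# groups -- (c - 2 * (c // 3)) for the temporal axis and (c // 3) each for height
+# and width -- so every (f, h, w) position receives its own combined rotary phase.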
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return self._norm(x.float()).type_as(x) * self.weight
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+
+class WanLayerNorm(nn.LayerNorm):
+
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class WanSelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
+ seq_lens(Tensor): Shape [B]
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
+ k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
+ v = self.v(x.to(dtype)).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+
+ x = attention(
+ q=rope_apply(q, grid_sizes, freqs).to(dtype),
+ k=rope_apply(k, grid_sizes, freqs).to(dtype),
+ v=v.to(dtype),
+ k_lens=seq_lens,
+ window_size=self.window_size)
+ x = x.to(dtype)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, context_lens, dtype):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
+
+ # compute attention
+ x = attention(
+ q.to(dtype),
+ k.to(dtype),
+ v.to(dtype),
+ k_lens=context_lens
+ )
+ x = x.to(dtype)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
+
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, context, context_lens, dtype):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ context_img = context[:, :257]
+ context = context[:, 257:]
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
+ k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
+ v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
+
+ img_x = attention(
+ q.to(dtype),
+ k_img.to(dtype),
+ v_img.to(dtype),
+ k_lens=None
+ )
+ img_x = img_x.to(dtype)
+ # compute attention
+ x = attention(
+ q.to(dtype),
+ k.to(dtype),
+ v.to(dtype),
+ k_lens=context_lens
+ )
+ x = x.to(dtype)
+
+ # output
+ x = x.flatten(2)
+ img_x = img_x.flatten(2)
+ x = x + img_x
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ 't2v_cross_attn': WanT2VCrossAttention,
+ 'i2v_cross_attn': WanI2VCrossAttention,
+}
+
+
+class WanAttentionBlock(nn.Module):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = WanLayerNorm(dim, eps)
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
+ eps)
+ self.norm3 = WanLayerNorm(
+ dim, eps,
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+ num_heads,
+ (-1, -1),
+ qk_norm,
+ eps)
+ self.norm2 = WanLayerNorm(dim, eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
+ nn.Linear(ffn_dim, dim))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ x,
+ e,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ context,
+ context_lens,
+ dtype=torch.float32
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ e = (self.modulation + e).chunk(6, dim=1)
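+ # adaLN-style modulation: e[0]/e[1]/e[2] are the shift, scale and gate of the
+ # self-attention branch; e[3]/e[4]/e[5] play the same roles for the FFN branch.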
+
+ # self-attention
+ temp_x = self.norm1(x) * (1 + e[1]) + e[0]
+ temp_x = temp_x.to(dtype)
+
+ y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype)
+ x = x + y * e[2]
+
+ # cross-attention & ffn function
+ def cross_attn_ffn(x, context, context_lens, e):
+ # cross-attention
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype)
+
+ # ffn function
+ temp_x = self.norm2(x) * (1 + e[4]) + e[3]
+ temp_x = temp_x.to(dtype)
+
+ y = self.ffn(temp_x)
+ x = x + y * e[5]
+ return x
+
+ x = cross_attn_ffn(x, context, context_lens, e)
+ return x
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = WanLayerNorm(dim, eps)
+ self.head = nn.Linear(dim, out_dim)
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
+ torch.nn.LayerNorm(out_dim))
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ # ignore_for_config = [
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
+ # ]
+ # _no_split_modules = ['WanAttentionBlock']
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ in_channels=16,
+ hidden_size=2048,
+ ):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 2048):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 8192):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 32):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+ cross_attn_norm (`bool`, *optional*, defaults to True):
+ Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ """
+
+ super().__init__()
+
+ assert model_type in ['t2v', 'i2v']
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # embeddings
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim))
+
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+
+ # blocks
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+ self.blocks = nn.ModuleList([
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps)
+ for _ in range(num_layers)
+ ])
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps)
+
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+ d = dim // num_heads
+ self.d = d
+ self.freqs = torch.cat(
+ [
+ rope_params(1024, d - 4 * (d // 6)),
+ rope_params(1024, 2 * (d // 6)),
+ rope_params(1024, 2 * (d // 6))
+ ],
+ dim=1
+ )
+
+ if model_type == 'i2v':
+ self.img_emb = MLPProj(1280, dim)
+
+ self.teacache = None
+ self.gradient_checkpointing = False
+ self.sp_world_size = 1
+ self.sp_world_rank = 0
+
+ def enable_teacache(
+ self,
+ coefficients,
+ num_steps: int,
+ rel_l1_thresh: float,
+ num_skip_start_steps: int = 0,
+ offload: bool = True
+ ):
+ self.teacache = TeaCache(
+ coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
+ )
+
+ def disable_teacache(self):
+ self.teacache = None
+
+ def enable_riflex(
+ self,
+ k = 6,
+ L_test = 66,
+ L_test_scale = 4.886,
+ ):
+ device = self.freqs.device
+ self.freqs = torch.cat(
+ [
+ get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
+ rope_params(1024, 2 * (self.d // 6)),
+ rope_params(1024, 2 * (self.d // 6))
+ ],
+ dim=1
+ ).to(device)
+
+ def disable_riflex(self):
+ device = self.freqs.device
+ self.freqs = torch.cat(
+ [
+ rope_params(1024, self.d - 4 * (self.d // 6)),
+ rope_params(1024, 2 * (self.d // 6)),
+ rope_params(1024, 2 * (self.d // 6))
+ ],
+ dim=1
+ ).to(device)
+
+ def enable_multi_gpus_inference(self,):
+ self.sp_world_size = get_sequence_parallel_world_size()
+ self.sp_world_rank = get_sequence_parallel_rank()
+ for block in self.blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def forward(
+ self,
+ x,
+ t,
+ context,
+ seq_len,
+ clip_fea=None,
+ y=None,
+ cond_flag=True,
+ return_intermediate=False, # whether to return intermediate layers' outputs
+ selected_layers=(5, 15, 25) # indices of the blocks whose outputs are returned
+ ):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ clip_fea (Tensor, *optional*):
+ CLIP image features for image-to-video mode
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+ cond_flag (`bool`, *optional*, defaults to True):
+ Flag to indicate whether to forward the condition input
+ return_intermediate (`bool`, *optional*, defaults to False):
+ Whether to also return the hidden states of the blocks listed in `selected_layers`
+ selected_layers (`tuple`, *optional*, defaults to (5, 15, 25)):
+ Indices of the transformer blocks whose outputs are collected when `return_intermediate` is True
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ if self.model_type == 'i2v':
+ assert clip_fea is not None and y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ dtype = x.dtype
+ if self.freqs.device != device and torch.device(type="meta") != device:
+ self.freqs = self.freqs.to(device)
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ if self.sp_world_size > 1:
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in x
+ ])
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ # cast to the working dtype (e.g. bfloat16) to save memory
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
+ e0 = e0.to(dtype)
+ e = e.to(dtype)
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat(
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ if clip_fea is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+
+ # Context Parallel
+ if self.sp_world_size > 1:
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
+
+ # TeaCache
+ if self.teacache is not None:
+ if cond_flag:
+ modulated_inp = e0
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
+ if self.teacache.cnt == 0 or self.teacache.cnt == self.teacache.num_steps - 1 or skip_flag:
+ should_calc = True
+ self.teacache.accumulated_rel_l1_distance = 0
+ else:
+ if cond_flag:
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.teacache.accumulated_rel_l1_distance = 0
+ self.teacache.previous_modulated_input = modulated_inp
+ self.teacache.cnt += 1
+ if self.teacache.cnt == self.teacache.num_steps:
+ self.teacache.reset()
+ self.teacache.should_calc = should_calc
+ else:
+ should_calc = self.teacache.should_calc
+
+ intermediate_features = []
+
+ # TeaCache
+ if self.teacache is not None:
+ if not should_calc:
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
+ x = x + previous_residual.to(x.device)
+ else:
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
+
+ for idx, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x,
+ e0,
+ seq_lens,
+ grid_sizes,
+ self.freqs,
+ context,
+ context_lens,
+ dtype,
+ **ckpt_kwargs,
+ )
+ else:
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens,
+ dtype=dtype
+ )
+ x = block(x, **kwargs)
+
+ if return_intermediate and idx in selected_layers:
+ intermediate_features.append(x.clone())
+
+ if cond_flag:
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
+ else:
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
+ else:
+ for idx, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x,
+ e0,
+ seq_lens,
+ grid_sizes,
+ self.freqs,
+ context,
+ context_lens,
+ dtype,
+ **ckpt_kwargs,
+ )
+ else:
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens,
+ dtype=dtype
+ )
+ x = block(x, **kwargs)
+
+ if return_intermediate and idx in selected_layers:
+ intermediate_features.append(x.clone())
+
+
+ if self.sp_world_size > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ x = torch.stack(x)
+
+ if return_intermediate:
+ mid_feat = intermediate_features
+ return x, mid_feat
+ else:
+ return x
+
+
+ def unpatchify(self, x, grid_sizes):
+ r"""
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x (List[Tensor]):
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
+ grid_sizes (Tensor):
+ Original spatial-temporal grid dimensions before patching,
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+ Returns:
+ List[Tensor]:
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+ """
+
+ c = self.out_dim
+ out = []
+ for u, v in zip(x, grid_sizes.tolist()):
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
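+ # The einsum reorders axes from (f, h, w, p, q, r, c) to (c, f, p, h, q, w, r) so
+ # the reshape below interleaves each grid axis with its patch axis, giving a
+ # [C_out, F * p_t, H * p_h, W * p_w] video tensor.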
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+ out.append(u)
+ return out
+
+ def init_weights(self):
+ r"""
+ Initialize model parameters using Xavier initialization.
+ """
+
+ # basic init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ # init embeddings
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+ for m in self.text_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+ for m in self.time_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+
+ # init output layer
+ nn.init.zeros_(self.head.head.weight)
+
+ @classmethod
+ def from_pretrained(
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
+ ):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+
+ if "dict_mapping" in transformer_additional_kwargs.keys():
+ for key in transformer_additional_kwargs["dict_mapping"]:
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
+
+ if low_cpu_mem_usage:
+ try:
+ import re
+
+ from diffusers.models.modeling_utils import \
+ load_model_dict_into_meta
+ from diffusers.utils import is_accelerate_available
+ if is_accelerate_available():
+ import accelerate
+
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ model = cls.from_config(config, **transformer_additional_kwargs)
+
+ param_device = "cpu"
+ if os.path.exists(model_file):
+ state_dict = torch.load(model_file, map_location="cpu")
+ elif os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ from safetensors.torch import load_file, safe_open
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+ state_dict = {}
+ print(model_files_safetensors)
+ for _model_file_safetensors in model_files_safetensors:
+ _state_dict = load_file(_model_file_safetensors)
+ for key in _state_dict:
+ state_dict[key] = _state_dict[key]
+ model._convert_deprecated_attention_blocks(state_dict)
+ # move the params from meta device to cpu
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if len(missing_keys) > 0:
+ raise ValueError(
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+ " those weights or else make sure your checkpoint file is correct."
+ )
+
+ unexpected_keys = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device=param_device,
+ dtype=torch_dtype,
+ model_name_or_path=pretrained_model_path,
+ )
+
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ print(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
+ return model
+ except Exception as e:
+ print(
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
+ )
+
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ if os.path.exists(model_file):
+ state_dict = torch.load(model_file, map_location="cpu")
+ elif os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ from safetensors.torch import load_file, safe_open
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+ state_dict = {}
+ for _model_file_safetensors in model_files_safetensors:
+ _state_dict = load_file(_model_file_safetensors)
+ for key in _state_dict:
+ state_dict[key] = _state_dict[key]
+
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight']
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m)
+
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
+ print(f"### All Parameters: {sum(params) / 1e6} M")
+
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+ model = model.to(torch_dtype)
+ return model
\ No newline at end of file
diff --git a/rose/models/wan_vae.py b/rose/models/wan_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..f25b7ad15f42fd700f76eb53e511cef3233362d5
--- /dev/null
+++ b/rose/models/wan_vae.py
@@ -0,0 +1,705 @@
+# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.models.autoencoders.vae import (DecoderOutput,
+ DiagonalGaussianDistribution)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from einops import rearrange
+
+CACHE_T = 2
+
+
+class CausalConv3d(nn.Conv3d):
+ """
+    Causal 3D convolution.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
+ self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
+
+ def forward(self, x):
+ return F.normalize(
+ x, dim=(1 if self.channel_first else
+ -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
+ 'downsample3d')
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == 'upsample2d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ elif mode == 'upsample3d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ self.time_conv = CausalConv3d(
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == 'downsample2d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == 'downsample3d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == 'upsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = 'Rep'
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] != 'Rep':
+                    # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] == 'Rep':
+ cache_x = torch.cat([
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2)
+ if feat_cache[idx] == 'Rep':
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.resample(x)
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+
+ if self.mode == 'downsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -1:, :, :].clone()
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
+ # # cache last frame of last two chunk
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False), nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
+ if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ Causal self-attention with a single head.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.norm(x)
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
+ -1).permute(0, 1, 3,
+ 2).contiguous().chunk(
+ 3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ )
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
+
+ # output
+ x = self.proj(x)
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = 'downsample3d' if temperal_downsample[
+ i] else 'downsample2d'
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout))
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2**(len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout))
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i == 1 or i == 2 or i == 3:
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, 3, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # also cache the last frame of the previous chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class AutoencoderKLWan_(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_upsample, dropout)
+
+ def forward(self, x):
+ mu, log_var = self.encode(x)
+ z = self.reparameterize(mu, log_var)
+ x_recon = self.decode(z)
+ return x_recon, mu, log_var
+
+ def encode(self, x, scale):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+ scale = [item.to(x.device, x.dtype) for item in scale]
+        ## split the input x along the time axis into chunks of 1, 4, 4, 4, ... frames for encoding
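+        # e.g. a 17-frame clip gives iter_ = 1 + (17 - 1) // 4 = 5 chunks of
+        # 1, 4, 4, 4 and 4 frames; each chunk contributes one latent frame, so the
+        # latent ends up with (17 - 1) // 4 + 1 = 5 time steps.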
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(
+ x[:, :, :1, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ if isinstance(scale[0], torch.Tensor):
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ mu = (mu - scale[0]) * scale[1]
+ x = torch.cat([mu, log_var], dim = 1)
+ self.clear_cache()
+ return x
+
+ def decode(self, z, scale):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ scale = [item.to(z.device, z.dtype) for item in scale]
+ if isinstance(scale[0], torch.Tensor):
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ z = z / scale[1] + scale[0]
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+ self.clear_cache()
+ return out
+
+ def reparameterize(self, mu, log_var):
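+        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), which
+        # keeps the sample differentiable with respect to mu and log_var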
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ #cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(z_dim=None, **kwargs):
+ """
+    Build the causal 3D video autoencoder (a KL-VAE design adapted from Stable Diffusion 1.x, 2.x and XL).
+ """
+ # params
+ cfg = dict(
+ dim=96,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[False, True, True],
+ dropout=0.0)
+ cfg.update(**kwargs)
+
+ # init model
+ model = AutoencoderKLWan_(**cfg)
+
+ return model
+
+
+class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+
+ @register_to_config
+ def __init__(
+ self,
+ latent_channels=16,
+ temporal_compression_ratio=4,
+ spacial_compression_ratio=8
+ ):
+ super().__init__()
+ mean = [
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+ ]
+ std = [
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+ ]
+ self.mean = torch.tensor(mean, dtype=torch.float32)
+ self.std = torch.tensor(std, dtype=torch.float32)
+ self.scale = [self.mean, 1.0 / self.std]
+
+ # init model
+ self.model = _video_vae(
+ z_dim=latent_channels,
+ )
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ x = [
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
+ for u in x
+ ]
+ x = torch.stack(x)
+ return x
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, zs):
+ dec = [
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
+ for u in zs
+ ]
+ dec = torch.stack(dec)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
+ def filter_kwargs(cls, kwargs):
+ import inspect
+ sig = inspect.signature(cls.__init__)
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+ return filtered_kwargs
+
+ model = cls(**filter_kwargs(cls, additional_kwargs))
+ if pretrained_model_path.endswith(".safetensors"):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(pretrained_model_path)
+ else:
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+ tmp_state_dict = {}
+ for key in state_dict:
+ tmp_state_dict["model." + key] = state_dict[key]
+ state_dict = tmp_state_dict
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m, u)
+ return model
\ No newline at end of file
diff --git a/rose/models/wan_xlm_roberta.py b/rose/models/wan_xlm_roberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..755baf394431bee95e1eac835b5dafe6ed37c5b9
--- /dev/null
+++ b/rose/models/wan_xlm_roberta.py
@@ -0,0 +1,170 @@
+# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['XLMRoberta', 'xlm_roberta_large']
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, mask):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+
+ # compute attention
+ p = self.dropout.p if self.training else 0.0
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
+
+ # output
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # layers
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
+ nn.Dropout(dropout))
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, x, mask):
+ if self.post_norm:
+ x = self.norm1(x + self.attn(x, mask))
+ x = self.norm2(x + self.ffn(x))
+ else:
+ x = x + self.attn(self.norm1(x), mask)
+ x = x + self.ffn(self.norm2(x))
+ return x
+
+
+class XLMRoberta(nn.Module):
+ """
+ XLMRobertaModel with no pooler and no LM head.
+ """
+
+ def __init__(self,
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_seq_len = max_seq_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.dim = dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # embeddings
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
+ self.type_embedding = nn.Embedding(type_size, dim)
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
+ self.dropout = nn.Dropout(dropout)
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
+ for _ in range(num_layers)
+ ])
+
+ # norm layer
+ self.norm = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, ids):
+ """
+ ids: [B, L] of torch.LongTensor.
+ """
+ b, s = ids.shape
+ mask = ids.ne(self.pad_id).long()
+
+ # embeddings
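+        # position ids follow the RoBERTa convention: padding positions stay at pad_id
+        # and real tokens are numbered pad_id + 1, pad_id + 2, ...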
+ x = self.token_embedding(ids) + \
+ self.type_embedding(torch.zeros_like(ids)) + \
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
+ if self.post_norm:
+ x = self.norm(x)
+ x = self.dropout(x)
+
+ # blocks
+ mask = torch.where(
+ mask.view(b, 1, 1, s).gt(0), 0.0,
+ torch.finfo(x.dtype).min)
+ for block in self.blocks:
+ x = block(x, mask)
+
+ # output
+ if not self.post_norm:
+ x = self.norm(x)
+ return x
+
+
+def xlm_roberta_large(pretrained=False,
+ return_tokenizer=False,
+ device='cpu',
+ **kwargs):
+ """
+ XLMRobertaLarge adapted from Huggingface.
+ """
+ # params
+ cfg = dict(
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5)
+ cfg.update(**kwargs)
+
+ # init a model on device
+ with torch.device(device):
+ model = XLMRoberta(**cfg)
+ return model
\ No newline at end of file
diff --git a/rose/pipeline/__init__.py b/rose/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b18563edf028a869bc978af354eeb799dafcc1eb
--- /dev/null
+++ b/rose/pipeline/__init__.py
@@ -0,0 +1,6 @@
+from .pipeline_wan_fun import WanFunPipeline
+from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline
+from .pipeline_wan_fun_control import WanFunControlPipeline
+
+WanPipeline = WanFunPipeline
+WanI2VPipeline = WanFunInpaintPipeline
\ No newline at end of file
diff --git a/rose/pipeline/pipeline_wan_fun.py b/rose/pipeline/pipeline_wan_fun.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a6317dd4e1145c0a369f3ab03afa20c37db3a5d
--- /dev/null
+++ b/rose/pipeline/pipeline_wan_fun.py
@@ -0,0 +1,558 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+from ..models import (AutoencoderKLWan, AutoTokenizer,
+ WanT5EncoderModel, WanTransformer3DModel)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ pass
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
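+# Usage sketch (illustrative): `retrieve_timesteps` is typically driven by exactly one
+# schedule source, e.g.
+#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, 50, device)
+# or, for a custom sigma schedule (if the scheduler supports it),
+#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, device=device, sigmas=[1.0, 0.5, 0.25])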
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+ r"""
+    Output class for Wan pipelines.
+
+ Args:
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ videos: torch.Tensor
+
+
+class WanFunPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: WanT5EncoderModel,
+ vae: AutoencoderKLWan,
+ transformer: WanTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
+ height // self.vae.spacial_compression_ratio,
+ width // self.vae.spacial_compression_ratio,
+ )
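+        # Shape example (illustrative): with the call defaults height=480, width=720,
+        # num_frames=49, 16 latent channels and 4x/8x temporal/spatial compression,
+        # shape == (batch_size, 16, 13, 60, 90).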
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
+ frames = (frames / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.cpu().float().numpy()
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "numpy",
+ return_dict: bool = False,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ comfyui_progressbar: bool = False,
+ ) -> Union[WanPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+ Args:
+
+ Examples:
+
+ Returns:
+
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ weight_dtype = self.text_encoder.dtype
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
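+            # `prompt_embeds` is a list of per-prompt tensors here, so `+` concatenates
+            # the negative and positive embeddings into one batch for classifier-free guidance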
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
+
+ # 4. Prepare timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+ if comfyui_progressbar:
+ from comfy.utils import ProgressBar
+ pbar = ProgressBar(num_inference_steps + 1)
+
+ # 5. Prepare latents
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ latents,
+ )
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
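+        # Worked example (illustrative, assuming the transformer's patch_size is (1, 2, 2)):
+        # for 480x720 and 49 frames, target_shape is (16, 13, 90, 60) and
+        # seq_len = ceil((90 * 60) / (2 * 2) * 13) = 17550 tokens.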
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
+ noise_pred = self.transformer(
+ x=latent_model_input,
+ context=prompt_embeds,
+ t=timestep,
+ seq_len=seq_len,
+ )
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return WanPipelineOutput(videos=video)
diff --git a/rose/pipeline/pipeline_wan_fun_control.py b/rose/pipeline/pipeline_wan_fun_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..6146bea405bd7ceae40ba16183a135f03ea04cb2
--- /dev/null
+++ b/rose/pipeline/pipeline_wan_fun_control.py
@@ -0,0 +1,723 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from einops import rearrange
+from PIL import Image
+from transformers import T5Tokenizer
+
+from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
+ WanT5EncoderModel, WanTransformer3DModel)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ pass
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+ latent_size = latent.size()
+ batch_size, channels, num_frames, height, width = mask.shape
+
+ if process_first_frame_only:
+ target_size = list(latent_size[2:])
+ target_size[0] = 1
+ first_frame_resized = F.interpolate(
+ mask[:, :, 0:1, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+
+ target_size = list(latent_size[2:])
+ target_size[0] = target_size[0] - 1
+ if target_size[0] != 0:
+ remaining_frames_resized = F.interpolate(
+ mask[:, :, 1:, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+ else:
+ resized_mask = first_frame_resized
+ else:
+ target_size = list(latent_size[2:])
+ resized_mask = F.interpolate(
+ mask,
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ return resized_mask
+
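+# Usage sketch (illustrative): `resize_mask` downsamples a pixel-space mask to the
+# latent grid, handling the first frame separately to mirror the VAE's 1 + 4k
+# temporal chunking, e.g.
+#   mask = torch.ones(1, 1, 49, 480, 720)
+#   latent = torch.zeros(1, 16, 13, 60, 90)
+#   resize_mask(mask, latent).shape   # torch.Size([1, 1, 13, 60, 90])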
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+ r"""
+    Output class for Wan pipelines.
+
+ Args:
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ videos: torch.Tensor
+
+
+class WanFunControlPipeline(DiffusionPipeline):
+ r"""
+    Pipeline for controllable video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: WanT5EncoderModel,
+ vae: AutoencoderKLWan,
+ transformer: WanTransformer3DModel,
+ clip_image_encoder: CLIPModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
+ height // self.vae.spacial_compression_ratio,
+ width // self.vae.spacial_compression_ratio,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_control_latents(
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the control to latents shape as we concatenate the control to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+
+ if control is not None:
+ control = control.to(device=device, dtype=dtype)
+ bs = 1
+ new_control = []
+ for i in range(0, control.shape[0], bs):
+ control_bs = control[i : i + bs]
+ control_bs = self.vae.encode(control_bs)[0]
+ control_bs = control_bs.mode()
+ new_control.append(control_bs)
+ control = torch.cat(new_control, dim = 0)
+
+ if control_image is not None:
+ control_image = control_image.to(device=device, dtype=dtype)
+ bs = 1
+ new_control_pixel_values = []
+ for i in range(0, control_image.shape[0], bs):
+ control_pixel_values_bs = control_image[i : i + bs]
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
+ control_pixel_values_bs = control_pixel_values_bs.mode()
+ new_control_pixel_values.append(control_pixel_values_bs)
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
+ else:
+ control_image_latents = None
+
+ return control, control_image_latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
+ frames = (frames / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.cpu().float().numpy()
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ control_video: Optional[torch.FloatTensor] = None,
+ ref_image: Optional[torch.FloatTensor] = None,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "numpy",
+ return_dict: bool = False,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ clip_image: Optional[Image.Image] = None,
+ max_sequence_length: int = 512,
+ comfyui_progressbar: bool = False,
+ ) -> Union[WanPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+ Args:
+
+ Examples:
+
+ Returns:
+
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ weight_dtype = self.text_encoder.dtype
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
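+ # encode_prompt returns lists of per-prompt tensors here, so `+` concatenates them with the
+ # unconditional embeddings first, matching the uncond/text order of noise_pred.chunk(2) below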
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
+
+ # 4. Prepare timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+ if comfyui_progressbar:
+ from comfy.utils import ProgressBar
+ pbar = ProgressBar(num_inference_steps + 2)
+
+ # 5. Prepare latents.
+ latent_channels = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ latents,
+ )
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # Prepare mask latent variables
+ if control_video is not None:
+ video_length = control_video.shape[2]
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ control_video = control_video.to(dtype=torch.float32)
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
+ control_video_latents = self.prepare_control_latents(
+ None,
+ control_video,
+ batch_size,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ do_classifier_free_guidance
+ )[1]
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
+ ).to(device, weight_dtype)
+ else:
+ control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
+ ).to(device, weight_dtype)
+
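+ # the optional reference image is VAE-encoded into the first latent frame of an otherwise zero tensor
+ # and concatenated to the control latents along the channel dimension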
+ if ref_image is not None:
+ video_length = ref_image.shape[2]
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
+ ref_image = ref_image.to(dtype=torch.float32)
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
+
+ ref_image_latentes = self.prepare_control_latents(
+ None,
+ ref_image,
+ batch_size,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ do_classifier_free_guidance
+ )[1]
+
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
+ if latents.size()[2] != 1:
+ ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes
+ ref_image_latentes_conv_in = (
+ torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in
+ ).to(device, weight_dtype)
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
+ else:
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
+ ref_image_latentes_conv_in = (
+ torch.cat([ref_image_latentes_conv_in] * 2) if do_classifier_free_guidance else ref_image_latentes_conv_in
+ ).to(device, weight_dtype)
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
+
+ # Prepare clip latent variables
+ if clip_image is not None:
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
+ clip_context = (
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
+ )
+ else:
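+ # no CLIP reference image was given: encode a black placeholder to obtain correctly shaped features, then zero them out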
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
+ clip_context = (
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
+ )
+ clip_context = torch.zeros_like(clip_context)
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
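+ # transformer token count: (latent height x latent width) / patch area x latent frames, rounded up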
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
+ noise_pred = self.transformer(
+ x=latent_model_input,
+ context=prompt_embeds,
+ t=timestep,
+ seq_len=seq_len,
+ y=control_latents,
+ clip_fea=clip_context,
+ )
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif output_type != "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return WanPipelineOutput(videos=video)
diff --git a/rose/pipeline/pipeline_wan_fun_inpaint.py b/rose/pipeline/pipeline_wan_fun_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e164274be03157e3abc29852debdac5d6b9a763
--- /dev/null
+++ b/rose/pipeline/pipeline_wan_fun_inpaint.py
@@ -0,0 +1,729 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from einops import rearrange
+from PIL import Image
+from transformers import T5Tokenizer
+from torchvision.utils import save_image
+
+from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
+ WanT5EncoderModel, WanTransformer3DModel)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ pass
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+ latent_size = latent.size()
+ batch_size, channels, num_frames, height, width = mask.shape
+
+ if process_first_frame_only:
+ target_size = list(latent_size[2:])
+ target_size[0] = 1
+ first_frame_resized = F.interpolate(
+ mask[:, :, 0:1, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+
+ target_size = list(latent_size[2:])
+ target_size[0] = target_size[0] - 1
+ if target_size[0] != 0:
+ remaining_frames_resized = F.interpolate(
+ mask[:, :, 1:, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+ else:
+ resized_mask = first_frame_resized
+ else:
+ target_size = list(latent_size[2:])
+ resized_mask = F.interpolate(
+ mask,
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ return resized_mask
+
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+ r"""
+ Output class for Wan pipelines.
+
+ Args:
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ videos: torch.Tensor
+
+
+class WanFunInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for video inpainting using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: WanT5EncoderModel,
+ vae: AutoencoderKLWan,
+ transformer: WanTransformer3DModel,
+ clip_image_encoder: CLIPModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
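+ # keep only the valid (non-padding) tokens of each prompt embedding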
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
+ height // self.vae.spacial_compression_ratio,
+ width // self.vae.spacial_compression_ratio,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+
+ if mask is not None:
+ mask = mask.to(device=device, dtype=self.vae.dtype)
+ bs = 1
+ new_mask = []
+ for i in range(0, mask.shape[0], bs):
+ mask_bs = mask[i : i + bs]
+ mask_bs = self.vae.encode(mask_bs)[0]
+ mask_bs = mask_bs.mode()
+ new_mask.append(mask_bs)
+ mask = torch.cat(new_mask, dim = 0)
+ # mask = mask * self.vae.config.scaling_factor
+
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
+ bs = 1
+ new_mask_pixel_values = []
+ for i in range(0, masked_image.shape[0], bs):
+ mask_pixel_values_bs = masked_image[i : i + bs]
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
+ new_mask_pixel_values.append(mask_pixel_values_bs)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+ else:
+ masked_image_latents = None
+
+ return mask, masked_image_latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
+ frames = (frames / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.cpu().float().numpy()
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ video: Optional[torch.FloatTensor] = None,
+ mask_video: Optional[torch.FloatTensor] = None,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "numpy",
+ return_dict: bool = False,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ clip_image: Optional[Image.Image] = None,
+ max_sequence_length: int = 512,
+ comfyui_progressbar: bool = False,
+ ) -> Union[WanPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+ Args:
+
+ Examples:
+
+ Returns:
+
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ weight_dtype = self.text_encoder.dtype
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
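+ # encode_prompt returns lists of per-prompt tensors here, so `+` concatenates them with the
+ # unconditional embeddings first, matching the uncond/text order of noise_pred.chunk(2) below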
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
+
+ # 4. Prepare timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+ if comfyui_progressbar:
+ from comfy.utils import ProgressBar
+ pbar = ProgressBar(num_inference_steps + 2)
+
+ # 5. Prepare latents.
+ if video is not None:
+ video_length = video.shape[2]
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ init_video = init_video.to(dtype=torch.float32)
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ init_video = None
+
+ latent_channels = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ latents,
+ )
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # Prepare mask latent variables
+ if init_video is not None:
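+ # an all-white (255) mask means every pixel is regenerated: skip VAE encoding and use zero mask / video latents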
+ if (mask_video == 255).all():
+ mask_latents = torch.tile(
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
+ )
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
+
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
+ else:
+ bs, _, video_length, height, width = video.size()
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ mask_condition = mask_condition.to(dtype=torch.float32)
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
+
+ # masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
+ masked_video = init_video
+
+ _, masked_video_latents = self.prepare_mask_latents(
+ None,
+ masked_video,
+ batch_size,
+ height,
+ width,
+ weight_dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ noise_aug_strength=None,
+ )
+
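+ # repeat the first mask frame 4x and fold groups of 4 frames into the channel dim so the temporal
+ # length matches the latents, then resize the inverted mask to the latent grid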
+ mask_condition = torch.concat(
+ [
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
+ mask_condition[:, :, 1:]
+ ], dim=2
+ )
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
+ mask_condition = mask_condition.transpose(1, 2)
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
+
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
+
+ # Prepare clip latent variables
+ if clip_image is not None:
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
+ clip_context = (
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
+ )
+ else:
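+ # no CLIP reference image was given: encode a black placeholder to obtain correctly shaped features, then zero them out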
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
+ clip_context = (
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
+ )
+ clip_context = torch.zeros_like(clip_context)
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
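+ # transformer token count: (latent height x latent width) / patch area x latent frames, rounded up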
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
+ noise_pred = self.transformer(
+ x=latent_model_input,
+ context=prompt_embeds,
+ t=timestep,
+ seq_len=seq_len,
+ y=y,
+ clip_fea=clip_context,
+ )
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif output_type != "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return WanPipelineOutput(videos=video)
diff --git a/rose/utils/__init__.py b/rose/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/rose/utils/discrete_sampler.py b/rose/utils/discrete_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..149dbe7beb94dfea2e6fe0ca3b5acf9437be60f7
--- /dev/null
+++ b/rose/utils/discrete_sampler.py
@@ -0,0 +1,46 @@
+"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
+"""
+import torch
+
+class DiscreteSampling:
+ def __init__(self, num_idx, uniform_sampling=False):
+ self.num_idx = num_idx
+ self.uniform_sampling = uniform_sampling
+ self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
+
+ if self.is_distributed and self.uniform_sampling:
+ world_size = torch.distributed.get_world_size()
+ self.rank = torch.distributed.get_rank()
+
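+ # choose the largest number of rank groups that divides world_size and evenly partitions num_idx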
+ i = 1
+ while True:
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
+ i += 1
+ else:
+ self.group_num = world_size // i
+ break
+ assert self.group_num > 0
+ assert world_size % self.group_num == 0
+ # the number of rank in one group
+ self.group_width = world_size // self.group_num
+ self.sigma_interval = self.num_idx // self.group_num
+ print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
+ self.rank, world_size, self.group_num,
+ self.group_width, self.sigma_interval))
+
+ def __call__(self, n_samples, generator=None, device=None):
+ if self.is_distributed and self.uniform_sampling:
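+ # each rank draws indices only from its group's sub-interval, so the groups jointly cover [0, num_idx)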
+ group_index = self.rank // self.group_width
+ idx = torch.randint(
+ group_index * self.sigma_interval,
+ (group_index + 1) * self.sigma_interval,
+ (n_samples,),
+ generator=generator, device=device,
+ )
+ print('proc[%d] idx=%s' % (self.rank, idx))
+ else:
+ idx = torch.randint(
+ 0, self.num_idx, (n_samples,),
+ generator=generator, device=device,
+ )
+ return idx
\ No newline at end of file
diff --git a/rose/utils/fp8_optimization.py b/rose/utils/fp8_optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa6d26fe9a0a0365401ab77ba4103fcc723fca9
--- /dev/null
+++ b/rose/utils/fp8_optimization.py
@@ -0,0 +1,56 @@
+"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
+"""
+import torch
+import torch.nn as nn
+
+def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
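+ # temporarily cast the module (stored in low precision, typically fp8) and its inputs back to the
+ # original dtype for the forward pass, then restore the low-precision weights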
+ weight_dtype = cls.weight.dtype
+ cls.to(origin_dtype)
+
+ # Convert all inputs to the original dtype
+ inputs = [input.to(origin_dtype) for input in inputs]
+ out = cls.original_forward(*inputs, **kwargs)
+
+ cls.to(weight_dtype)
+ return out
+
+def replace_parameters_by_name(module, name_keywords, device):
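+ # replace matching nn.Parameter attributes with plain tensors moved to `device`, recursing through all children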
+ from torch import nn
+ for name, param in list(module.named_parameters(recurse=False)):
+ if any(keyword in name for keyword in name_keywords):
+ if isinstance(param, nn.Parameter):
+ tensor = param.data
+ delattr(module, name)
+ setattr(module, name, tensor.to(device=device))
+ for child_name, child_module in module.named_children():
+ replace_parameters_by_name(child_module, name_keywords, device)
+
+def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
+ for name, module in model.named_modules():
+ flag = False
+ for _exclude_module_name in exclude_module_name:
+ if _exclude_module_name in name:
+ flag = True
+ if flag:
+ continue
+ for param_name, param in module.named_parameters():
+ flag = False
+ for _exclude_module_name in exclude_module_name:
+ if _exclude_module_name in param_name:
+ flag = True
+ if flag:
+ continue
+ param.data = param.data.to(torch.float8_e4m3fn)
+
+def convert_weight_dtype_wrapper(module, origin_dtype):
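+ # wrap every weighted submodule's forward so its low-precision weights are autocast to origin_dtype
+ # at call time; the root module and embed_tokens are left untouched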
+ for name, module in module.named_modules():
+ if name == "" or "embed_tokens" in name:
+ continue
+ original_forward = module.forward
+ if hasattr(module, "weight") and module.weight is not None:
+ setattr(module, "original_forward", original_forward)
+ setattr(
+ module,
+ "forward",
+ lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
+ )
diff --git a/rose/utils/lora_utils.py b/rose/utils/lora_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f801fa8f7a5a4024a4d4b003513813ed38f6649c
--- /dev/null
+++ b/rose/utils/lora_utils.py
@@ -0,0 +1,516 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+# https://github.com/bmaltais/kohya_ss
+
+import hashlib
+import math
+import os
+from collections import defaultdict
+from io import BytesIO
+from typing import List, Optional, Type, Union
+
+import safetensors.torch
+import torch
+import torch.utils.checkpoint
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from safetensors.torch import load_file
+from transformers import T5EncoderModel
+
+
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ ):
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ self.lora_dim = lora_dim
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha))
+
+ # same as microsoft's
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def forward(self, x, *args, **kwargs):
+ weight_dtype = x.dtype
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
+
+
+def addnet_hash_legacy(b):
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+ m = hashlib.sha256()
+
+ b.seek(0x100000)
+ m.update(b.read(0x10000))
+ return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
+ save time on indexing the model later."""
+
+ # Because writing user metadata to the file can change the result of
+ # sd_models.model_hash(), only retain the training metadata for purposes of
+ # calculating the hash, as they are meant to be immutable
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+ bytes = safetensors.torch.save(tensors, metadata)
+ b = BytesIO(bytes)
+
+ model_hash = addnet_hash_safetensors(b)
+ legacy_hash = addnet_hash_legacy(b)
+ return model_hash, legacy_hash
+
+
+class LoRANetwork(torch.nn.Module):
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"]
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ def __init__(
+ self,
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
+ unet,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ module_class: Type[object] = LoRAModule,
+ skip_name: str = None,
+ varbose: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+ self.multiplier = multiplier
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.dropout = dropout
+
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ print(f"neuron dropout: p={self.dropout}")
+
+ # create module instances
+ def create_modules(
+ is_unet: bool,
+ root_module: torch.nn.Module,
+ target_replace_modules: List[torch.nn.Module],
+ ) -> List[LoRAModule]:
+ prefix = (
+ self.LORA_PREFIX_TRANSFORMER
+ if is_unet
+ else self.LORA_PREFIX_TEXT_ENCODER
+ )
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if module.__class__.__name__ in target_replace_modules:
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+ if skip_name is not None and skip_name in child_name:
+ continue
+
+ if is_linear or is_conv2d:
+ lora_name = prefix + "." + name + "." + child_name
+ lora_name = lora_name.replace(".", "_")
+
+ dim = None
+ alpha = None
+
+ if is_linear or is_conv2d_1x1:
+ dim = self.lora_dim
+ alpha = self.alpha
+
+ if dim is None or dim == 0:
+ if is_linear or is_conv2d_1x1:
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ )
+ loras.append(lora)
+ return loras, skipped
+
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+ self.text_encoder_loras = []
+ skipped_te = []
+ for i, text_encoder in enumerate(text_encoders):
+ if text_encoder is not None:
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+ self.text_encoder_loras.extend(text_encoder_loras)
+ skipped_te += skipped
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+ if apply_text_encoder:
+ print("enable LoRA for text encoder")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ print("enable LoRA for U-Net")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+ self.requires_grad_(True)
+ all_params = []
+
+ def enumerate_params(loras):
+ params = []
+ for lora in loras:
+ params.extend(lora.parameters())
+ return params
+
+ if self.text_encoder_loras:
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
+ if text_encoder_lr is not None:
+ param_data["lr"] = text_encoder_lr
+ all_params.append(param_data)
+
+ if self.unet_loras:
+ param_data = {"params": enumerate_params(self.unet_loras)}
+ if unet_lr is not None:
+ param_data["lr"] = unet_lr
+ all_params.append(param_data)
+
+ return all_params
+
+ def enable_gradient_checkpointing(self):
+ pass
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+
+ # Precalculate model hashes to save time on indexing
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+def create_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
+ transformer,
+ neuron_dropout: Optional[float] = None,
+ skip_name: str = None,
+ **kwargs,
+):
+ if network_dim is None:
+ network_dim = 4 # default
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ network = LoRANetwork(
+ text_encoder,
+ transformer,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ skip_name=skip_name,
+ varbose=True,
+ )
+ return network
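+# Usage sketch for create_network (hypothetical modules and hyperparameters; assumes
+# `text_encoder` and `transformer` are the models this LoRA is meant to wrap):
+#   network = create_network(multiplier=1.0, network_dim=32, network_alpha=16.0,
+#                            text_encoder=text_encoder, transformer=transformer)
+#   network.apply_to(text_encoder, transformer, apply_text_encoder=True, apply_unet=True)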
+
+def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
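+ """Merge a kohya-style LoRA state dict (keys prefixed with "lora_unet"/"lora_te")
+ into the pipeline's transformer and text encoder weights in place."""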
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ if state_dict is None:
+ state_dict = load_file(lora_path, device=device)
+ updates = defaultdict(dict)
+ for key, value in state_dict.items():
+ layer, elem = key.split('.', 1)
+ updates[layer][elem] = value
+
+ sequential_cpu_offload_flag = False
+ if pipeline.transformer.device == torch.device(type="meta"):
+ pipeline.remove_all_hooks()
+ sequential_cpu_offload_flag = True
+ offload_device = pipeline._offload_device
+
+ for layer, elems in updates.items():
+
+ if "lora_te" in layer:
+ if transformer_only:
+ continue
+ else:
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+ curr_layer = pipeline.text_encoder
+ else:
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+ curr_layer = pipeline.transformer
+
+ try:
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
+ except Exception:
+ temp_name = layer_infos.pop(0)
+ while len(layer_infos) > -1:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
+ break
+ except Exception:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name)
+ if len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ elif len(layer_infos) == 0:
+ break
+ except Exception:
+ if len(layer_infos) == 0:
+ print('Error loading layer')
+ if len(temp_name) > 0:
+ temp_name += "_" + layer_infos.pop(0)
+ else:
+ temp_name = layer_infos.pop(0)
+
+ origin_dtype = curr_layer.weight.data.dtype
+ origin_device = curr_layer.weight.data.device
+
+ curr_layer = curr_layer.to(device, dtype)
+ weight_up = elems['lora_up.weight'].to(device, dtype)
+ weight_down = elems['lora_down.weight'].to(device, dtype)
+
+ if 'alpha' in elems.keys():
+ alpha = elems['alpha'].item() / weight_up.shape[1]
+ else:
+ alpha = 1.0
+
+ if len(weight_up.shape) == 4:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
+ ).unsqueeze(2).unsqueeze(3)
+ else:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
+
+ if sequential_cpu_offload_flag:
+ pipeline.enable_sequential_cpu_offload(device=offload_device)
+ return pipeline
+
+# TODO: Refactor with merge_lora.
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
+ LORA_PREFIX_UNET = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ state_dict = load_file(lora_path, device=device)
+
+ updates = defaultdict(dict)
+ for key, value in state_dict.items():
+ layer, elem = key.split('.', 1)
+ updates[layer][elem] = value
+
+ sequential_cpu_offload_flag = False
+ if pipeline.transformer.device == torch.device(type="meta"):
+ pipeline.remove_all_hooks()
+ sequential_cpu_offload_flag = True
+ offload_device = pipeline._offload_device
+
+ for layer, elems in updates.items():
+
+ if "lora_te" in layer:
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+ curr_layer = pipeline.text_encoder
+ else:
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+ curr_layer = pipeline.transformer
+
+ try:
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
+ except Exception:
+ temp_name = layer_infos.pop(0)
+ while len(layer_infos) > -1:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
+ break
+ except Exception:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name)
+ if len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ elif len(layer_infos) == 0:
+ break
+ except Exception:
+ if len(layer_infos) == 0:
+ print('Error loading layer')
+ if len(temp_name) > 0:
+ temp_name += "_" + layer_infos.pop(0)
+ else:
+ temp_name = layer_infos.pop(0)
+
+ origin_dtype = curr_layer.weight.data.dtype
+ origin_device = curr_layer.weight.data.device
+
+ curr_layer = curr_layer.to(device, dtype)
+ weight_up = elems['lora_up.weight'].to(device, dtype)
+ weight_down = elems['lora_down.weight'].to(device, dtype)
+
+ if 'alpha' in elems.keys():
+ alpha = elems['alpha'].item() / weight_up.shape[1]
+ else:
+ alpha = 1.0
+
+ if len(weight_up.shape) == 4:
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
+ ).unsqueeze(2).unsqueeze(3)
+ else:
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
+
+ if sequential_cpu_offload_flag:
+ pipeline.enable_sequential_cpu_offload(device=offload_device)
+ return pipeline
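+# Usage sketch (hypothetical LoRA path; the pipeline is expected to expose `transformer`
+# and `text_encoder` attributes, as both helpers above assume):
+#   pipeline = merge_lora(pipeline, "path/to/lora.safetensors", multiplier=0.8)
+#   # ... run inference ...
+#   pipeline = unmerge_lora(pipeline, "path/to/lora.safetensors", multiplier=0.8)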
diff --git a/rose/utils/utils.py b/rose/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8d8139df3531f8d12ab48e9f599d64f0594aeb
--- /dev/null
+++ b/rose/utils/utils.py
@@ -0,0 +1,318 @@
+import os
+import gc
+import imageio
+import inspect
+import numpy as np
+import torch
+import torchvision
+import cv2
+from einops import rearrange
+from PIL import Image
+
+def filter_kwargs(cls, kwargs):
+ sig = inspect.signature(cls.__init__)
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+ return filtered_kwargs
+
+def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
+ target_pixels = int(base_resolution) * int(base_resolution)
+ original_width, original_height = Image.open(image).size
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
+ width_slider = round(original_width * ratio)
+ height_slider = round(original_height * ratio)
+ return height_slider, width_slider
+
+def color_transfer(sc, dc):
+ """
+ Transfer color distribution from of sc, referred to dc.
+
+ Args:
+ sc (numpy.ndarray): input image to be transfered.
+ dc (numpy.ndarray): reference image
+
+ Returns:
+ numpy.ndarray: Transferred color distribution on the sc.
+ """
+
+ def get_mean_and_std(img):
+ x_mean, x_std = cv2.meanStdDev(img)
+ x_mean = np.hstack(np.around(x_mean, 2))
+ x_std = np.hstack(np.around(x_std, 2))
+ return x_mean, x_std
+
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
+ s_mean, s_std = get_mean_and_std(sc)
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
+ t_mean, t_std = get_mean_and_std(dc)
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
+ np.putmask(img_n, img_n > 255, 255)
+ np.putmask(img_n, img_n < 0, 0)
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
+ return dst
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(Image.fromarray(x))
+
+ if color_transfer_post_process:
+ for i in range(1, len(outputs)):
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ if imageio_backend:
+ if path.endswith("mp4"):
+ imageio.mimsave(path, outputs, fps=fps)
+ else:
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
+ else:
+ if path.endswith("mp4"):
+ path = path.replace('.mp4', '.gif')
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
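+# Usage sketch: `videos` is a (batch, channels, frames, height, width) tensor in [0, 1]
+# (or in [-1, 1] with rescale=True); the values below are hypothetical:
+#   dummy = torch.rand(1, 3, 16, 256, 256)
+#   save_videos_grid(dummy, "result/sample.mp4", fps=12)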
+
+def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
+ if validation_image_start is not None and validation_image_end is not None:
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
+ else:
+ image_start = clip_image = validation_image_start
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
+
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
+ image_end = Image.open(validation_image_end).convert("RGB")
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
+ else:
+ image_end = validation_image_end
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
+
+ if type(image_start) is list:
+ clip_image = clip_image[0]
+ start_video = torch.cat(
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
+ dim=2
+ )
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+ input_video[:, :, :len(image_start)] = start_video
+
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, len(image_start):] = 255
+ else:
+ input_video = torch.tile(
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
+ [1, 1, video_length, 1, 1]
+ )
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, 1:] = 255
+
+ if type(image_end) is list:
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
+ end_video = torch.cat(
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
+ dim=2
+ )
+ input_video[:, :, -len(image_end):] = end_video
+
+ input_video_mask[:, :, -len(image_end):] = 0
+ else:
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
+ input_video_mask[:, :, -1:] = 0
+
+ input_video = input_video / 255
+
+ elif validation_image_start is not None:
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
+ else:
+ image_start = clip_image = validation_image_start
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
+ image_end = None
+
+ if type(image_start) is list:
+ clip_image = clip_image[0]
+ start_video = torch.cat(
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
+ dim=2
+ )
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+ input_video[:, :, :len(image_start)] = start_video
+ input_video = input_video / 255
+
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, len(image_start):] = 255
+ else:
+ input_video = torch.tile(
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
+ [1, 1, video_length, 1, 1]
+ ) / 255
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, 1:] = 255
+ else:
+ image_start = None
+ image_end = None
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
+ clip_image = None
+
+ del image_start
+ del image_end
+ gc.collect()
+
+ return input_video, input_video_mask, clip_image
+
+def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
+ if input_video_path is not None:
+ if isinstance(input_video_path, str):
+ cap = cv2.VideoCapture(input_video_path)
+ input_video = []
+
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))
+
+ frame_count = 0
+
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ if frame_count % frame_skip == 0:
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+ frame_count += 1
+
+ cap.release()
+ else:
+ input_video = input_video_path
+
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+
+ if validation_video_mask is not None:
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
+
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
+ else:
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, :] = 255
+ else:
+ input_video, input_video_mask = None, None
+
+ if ref_image is not None:
+ if isinstance(ref_image, str):
+ clip_image = Image.open(ref_image).convert("RGB")
+ else:
+ clip_image = Image.fromarray(np.array(ref_image, np.uint8))
+ else:
+ clip_image = None
+
+ if ref_image is not None:
+ if isinstance(ref_image, str):
+ ref_image = Image.open(ref_image).convert("RGB")
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+ ref_image = torch.from_numpy(np.array(ref_image))
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+ else:
+ ref_image = torch.from_numpy(np.array(ref_image))
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+ return input_video, input_video_mask, ref_image, clip_image
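+# Usage sketch (hypothetical frame count and sample size; the clip is one of the bundled
+# test videos): returns a (1, 3, F, H, W) video tensor in [0, 1] plus a mask tensor:
+#   video, mask, ref, clip = get_video_to_video_latent(
+#       "test_sample/test-sample0.mp4", video_length=49, sample_size=(384, 672))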
+
+
+def get_video_and_mask(input_video_path, video_length, sample_size, fps=None, input_mask_path=None, ref_image=None):
+ if input_video_path is not None:
+ if isinstance(input_video_path, str):
+ cap = cv2.VideoCapture(input_video_path)
+ input_video = []
+
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))
+
+ frame_count = 0
+
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ if frame_count % frame_skip == 0:
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+ frame_count += 1
+
+ cap.release()
+ else:
+ input_video = input_video_path
+
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255.0
+
+ else:
+ input_video = None
+
+ if input_mask_path is not None:
+ if isinstance(input_mask_path, str):
+ cap = cv2.VideoCapture(input_mask_path)
+ mask_frames = []
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
+ frame_skip = 1 if fps is None else max(1, int(original_fps // fps))
+ frame_count = 0
+
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ if frame_count % frame_skip == 0:
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ mask_frames.append(gray)
+ frame_count += 1
+ cap.release()
+ else:
+ mask_frames = input_mask_path
+
+ mask_np = np.array(mask_frames)[:video_length] # (F, H, W), uint8
+ mask_bin = np.where(mask_np < 240, 0, 1).astype(np.uint8) # (F,H,W)
+ mask_tensor = torch.from_numpy(mask_bin) # (F, H, W)
+ mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(0) # (1, F, 1, H, W)
+ input_mask = mask_tensor.permute(0, 2, 1, 3, 4).float() # (1, 1, F, H, W)
+ else:
+ input_mask = None
+
+ if ref_image is not None:
+ if isinstance(ref_image, str):
+ clip_image = Image.open(ref_image).convert("RGB")
+ else:
+ clip_image = Image.fromarray(np.array(ref_image, np.uint8))
+ else:
+ clip_image = None
+
+ if ref_image is not None:
+ if isinstance(ref_image, str):
+ ref_image = Image.open(ref_image).convert("RGB")
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+ ref_image = torch.from_numpy(np.array(ref_image))
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+ else:
+ ref_image = torch.from_numpy(np.array(ref_image))
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+ return input_video, input_mask, ref_image, clip_image
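+# Usage sketch (hypothetical paths and sizes): loads an RGB clip and a mask video, returning
+# an input video shaped (1, 3, F, H, W) in [0, 1] and a binary mask shaped (1, 1, F, H, W):
+#   video, mask, ref, clip = get_video_and_mask(
+#       "test_sample/test-sample1.mp4", video_length=49, sample_size=(384, 672),
+#       input_mask_path="path/to/mask.mp4")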
diff --git a/test_sample/test-sample0.mp4 b/test_sample/test-sample0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8bdc127da76beff17e384b2152b8b95002679328
--- /dev/null
+++ b/test_sample/test-sample0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d63abe5ce28c76c83b0f575e7b8cd2707ffb58e3e478b7f500865700d4738a2
+size 476512
diff --git a/test_sample/test-sample1.mp4 b/test_sample/test-sample1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..88e6cbf54eefe62c695bb0f8f543466d16d7d386
--- /dev/null
+++ b/test_sample/test-sample1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54369037f2fb7c42ada4165618894cedcd87d8368d40547571b43e5fd4ff0025
+size 975899
diff --git a/test_sample/test-sample2.mp4 b/test_sample/test-sample2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..36f585a88144ffaa238a1493f0ef8dca00129924
--- /dev/null
+++ b/test_sample/test-sample2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11e3e5b0aea5881bf8c5edf8a8196b51bef11fac019ee96192cc9b24146ae07b
+size 286526
diff --git a/test_sample/test-sample3.mp4 b/test_sample/test-sample3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..cdccc497b8ce107ee2fe8d2e9c70ed52c047cf5e
--- /dev/null
+++ b/test_sample/test-sample3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29f38c57585dad9946aaa32f7ef9faae10bcb87913708462114da5fb164d4775
+size 146705
diff --git a/test_sample/test-sample4.mp4 b/test_sample/test-sample4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4305c6792e03d811cb3ddcb9ac557d9b1d28391f
--- /dev/null
+++ b/test_sample/test-sample4.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0a42fa9abcfa72baee2a437b06b4e1981b5109c815795bf005d4f7d5cd47096
+size 1465415
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b975bb779b47485f9e6ba7435646b4db40a2c6a
--- /dev/null
+++ b/tools/base_segmenter.py
@@ -0,0 +1,129 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter
+
+
+class BaseSegmenter:
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
+ """
+ device: model device
+ SAM_checkpoint: path of SAM checkpoint
+ model_type: vit_b, vit_l, vit_h
+ """
+ print(f"Initializing BaseSegmenter to {device}")
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
+
+ self.device = device
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
+ self.model.to(device=self.device)
+ self.predictor = SamPredictor(self.model)
+ self.embedded = False
+
+ @torch.no_grad()
+ def set_image(self, image: np.ndarray):
+ # expects a 3-channel RGB image (e.g. np.array(PIL.Image.open(image_path)))
+ # image embedding: avoid encoding the same image multiple times
+ self.orignal_image = image
+ if self.embedded:
+ print('Image already embedded; call reset_image() before embedding a new one.')
+ return
+ self.predictor.set_image(image)
+ self.embedded = True
+ return
+
+ @torch.no_grad()
+ def reset_image(self):
+ # reset image embedding
+ self.predictor.reset_image()
+ self.embedded = False
+
+ def predict(self, prompts, mode, multimask=True):
+ """
+ image: numpy array, h, w, 3
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
+ prompts['point_coords']: numpy array [N,2]
+ prompts['point_labels']: numpy array [1,N]
+ prompts['mask_input']: numpy array [1,256,256]
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
+ multimask: True (return 3 masks), False (return 1 mask only)
+ when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
+ """
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
+
+ if mode == 'point':
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+ point_labels=prompts['point_labels'],
+ multimask_output=multimask)
+ elif mode == 'mask':
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
+ multimask_output=multimask)
+ elif mode == 'both': # both
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
+ point_labels=prompts['point_labels'],
+ mask_input=prompts['mask_input'],
+ multimask_output=multimask)
+ else:
+ raise NotImplementedError(f"mode {mode} is not supported")
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ return masks, scores, logits
+
+
+if __name__ == "__main__":
+ # load and show an image
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
+
+ # initialise BaseSegmenter
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
+ model_type = 'vit_h'
+ device = "cuda:4"
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
+
+ # image embedding (once embedded, multiple prompts can be applied)
+ base_segmenter.set_image(image)
+
+ # examples
+ # point only ------------------------
+ mode = 'point'
+ prompts = {
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
+ 'point_labels': np.array([1, 1]),
+ }
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
+
+ # both ------------------------
+ mode = 'both'
+ mask_input = logits[np.argmax(scores), :, :]
+ prompts = {'mask_input': mask_input [None, :, :]}
+ prompts = {
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
+ 'point_labels': np.array([1, 0]),
+ 'mask_input': mask_input[None, :, :]
+ }
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
+
+ # mask only ------------------------
+ mode = 'mask'
+ mask_input = logits[np.argmax(scores), :, :]
+
+ prompts = {'mask_input': mask_input[None, :, :]}
+
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
diff --git a/tools/interact_tools.py b/tools/interact_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70b8c40c6c57f19242b42bfcde60f378b1ce7ba
--- /dev/null
+++ b/tools/interact_tools.py
@@ -0,0 +1,99 @@
+import time
+import torch
+import cv2
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from typing import Union
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+import PIL
+from .mask_painter import mask_painter as mask_painter2
+from .base_segmenter import BaseSegmenter
+from .painter import mask_painter, point_painter
+import os
+import requests
+import sys
+
+
+mask_color = 3
+mask_alpha = 0.7
+point_color_ne = 8
+point_color_ps = 50
+point_alpha = 0.9
+point_radius = 15
+contour_color = 2
+contour_width = 5
+
+
+class SamControler():
+ def __init__(self, SAM_checkpoint, model_type, device):
+ '''
+ initialize the SAM controller
+ '''
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+
+
+ # def seg_again(self, image: np.ndarray):
+ # '''
+ # it is used when interact in video
+ # '''
+ # self.sam_controler.reset_image()
+ # self.sam_controler.set_image(image)
+ # return
+
+
+ def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True, mask_color=3):
+ '''
+ Used on the first frame of a video.
+ return: mask, logit, painted image (mask + points)
+ '''
+ # self.sam_controler.set_image(image)
+ origal_image = self.sam_controler.orignal_image
+ neg_flag = labels[-1]
+ if neg_flag==1:
+ # last click is positive: predict from points, then refine with the mask prompt
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ 'mask_input': logit[None, :, :]
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+ else:
+ # last click is negative: a single point-based prediction
+ prompts = {
+ 'point_coords': points,
+ 'point_labels': labels,
+ }
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+
+ assert len(points)==len(labels)
+
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+ painted_image = Image.fromarray(painted_image)
+
+ return mask, logit, painted_image
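+ # Usage sketch (hypothetical checkpoint path and click coordinates): embed the frame once
+ # through the underlying BaseSegmenter, then prompt with clicks:
+ #   controler = SamControler("checkpoints/sam_vit_h_4b8939.pth", "vit_h", "cuda:0")
+ #   controler.sam_controler.set_image(frame)  # frame: (H, W, 3) RGB numpy array
+ #   mask, logit, painted = controler.first_frame_click(frame, np.array([[x, y]]), np.array([1]))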
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tools/mask_painter.py b/tools/mask_painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f471ea0116d656e2cc236832893b07c6d7be1643
--- /dev/null
+++ b/tools/mask_painter.py
@@ -0,0 +1,288 @@
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+ color_list = np.array(
+ [
+ 0.000, 0.000, 0.000,
+ 1.000, 1.000, 1.000,
+ 1.000, 0.498, 0.313,
+ 0.392, 0.581, 0.929,
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3)) * 255
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
+ background_color = np.array(background_color)
+ contour_color = np.array(contour_color)
+
+ # background_mask = 1 - background_mask
+ # contour_mask = 1 - contour_mask
+
+ for i in range(3):
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
+
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+
+ return image.astype('uint8')
+
+
+def mask_generator_00(mask, background_radius, contour_radius):
+ # no background width when '00'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ return mask, contour_mask
+
+
+def mask_generator_01(mask, background_radius, contour_radius):
+ # no background width when '01'
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return mask, contour_mask
+
+
+def mask_generator_10(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+ return background_mask, contour_mask
+
+
+def mask_generator_11(mask, background_radius, contour_radius):
+ # distance map
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # .....:::::!!!!!
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
+ background_mask = (background_mask - np.min(background_mask))
+ background_mask = background_mask / np.max(background_mask)
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ return background_mask, contour_mask
+
+
+def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
+ """
+ Input:
+ input_image: numpy array
+ input_mask: numpy array
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+ background_blur_radius: radius of background blur, must be odd number
+ contour_width: width of mask contour, must be odd number
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
+
+ Output:
+ painted_image: numpy array
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
+
+ # downsample input image and mask
+ height, width = input_image.shape[0], input_image.shape[1]
+ res = 1024
+ ratio = min(1.0 * res / max(width, height), 1.0)
+ input_image = cv2.resize(input_image, (int(width*ratio), int(height*ratio)))
+ input_mask = cv2.resize(input_mask, (int(width*ratio), int(height*ratio)))
+
+ # 0: background, 1: foreground
+ msk = np.clip(input_mask, 0, 1)
+
+ # generate masks for background and contour pixels
+ background_radius = (background_blur_radius - 1) // 2
+ contour_radius = (contour_width - 1) // 2
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
+
+ # paint
+ painted_image = vis_add_mask(
+ input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
+
+ return painted_image
+
+
+if __name__ == '__main__':
+
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
+ background_blur_radius = 31 # radius of background blur, must be odd number
+ contour_width = 11 # contour width, must be odd number
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
+ contour_alpha = 1 # transparency of contour, 0: no contour highlighted
+
+ # load input image and mask
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
+
+ # paint
+ overall_time_1 = 0
+ overall_time_2 = 0
+ overall_time_3 = 0
+ overall_time_4 = 0
+ overall_time_5 = 0
+
+ for i in range(50):
+ t2 = time.time()
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
+ e2 = time.time()
+
+ t3 = time.time()
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
+ e3 = time.time()
+
+ t1 = time.time()
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
+ e1 = time.time()
+
+ t4 = time.time()
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
+ e4 = time.time()
+
+ t5 = time.time()
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
+ e5 = time.time()
+
+ overall_time_1 += (e1 - t1)
+ overall_time_2 += (e2 - t2)
+ overall_time_3 += (e3 - t3)
+ overall_time_4 += (e4 - t4)
+ overall_time_5 += (e5 - t5)
+
+ print(f'average time w gaussian: {overall_time_1/50}')
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
+
+ # save
+ painted_image_00 = Image.fromarray(painted_image_00)
+ painted_image_00.save('./test_img/painter_output_image_00.png')
+
+ painted_image_10 = Image.fromarray(painted_image_10)
+ painted_image_10.save('./test_img/painter_output_image_10.png')
+
+ painted_image_01 = Image.fromarray(painted_image_01)
+ painted_image_01.save('./test_img/painter_output_image_01.png')
+
+ painted_image_11 = Image.fromarray(painted_image_11)
+ painted_image_11.save('./test_img/painter_output_image_11.png')
diff --git a/tools/painter.py b/tools/painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e711d35aa8348d15cdad9d1cd413da41ea4f1ab
--- /dev/null
+++ b/tools/painter.py
@@ -0,0 +1,215 @@
+# paint masks, contours, or points on images, with specified colors
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+import copy
+import time
+
+
+def colormap(rgb=True):
+ color_list = np.array(
+ [
+ 0.000, 0.000, 0.000,
+ 1.000, 1.000, 1.000,
+ 1.000, 0.498, 0.313,
+ 0.392, 0.581, 0.929,
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3)) * 255
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list
+
+
+color_list = colormap()
+color_list = color_list.astype('uint8').tolist()
+
+
+def vis_add_mask(image, mask, color, alpha):
+ color = np.array(color_list[color])
+ mask = mask > 0.5
+ image[mask] = image[mask] * (1-alpha) + color * alpha
+ return image.astype('uint8')
+
+def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
+ h, w = input_image.shape[:2]
+ point_mask = np.zeros((h, w)).astype('uint8')
+ for point in input_points:
+ point_mask[point[1], point[0]] = 1
+
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_radius, point_radius))
+ point_mask = cv2.dilate(point_mask, kernel)
+
+ contour_radius = (contour_width - 1) // 2
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ # paint mask
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
+ # paint contour
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+ return painted_image
+
+def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+ # 0: background, 1: foreground
+ mask = np.clip(input_mask, 0, 1)
+ contour_radius = (contour_width - 1) // 2
+
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+ dist_map = dist_transform_fore - dist_transform_back
+ # ...:::!!!:::...
+ contour_radius += 2
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+ contour_mask = contour_mask / np.max(contour_mask)
+ contour_mask[contour_mask>0.5] = 1.
+
+ # paint mask
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
+ # paint contour
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
+
+ return painted_image
+
+def background_remover(input_image, input_mask):
+ """
+ input_image: H, W, 3, np.array
+ input_mask: H, W, np.array
+
+ image_wo_background: PIL.Image
+ """
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
+ # 0: background, 1: foreground
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
+
+ return image_wo_background
+
+if __name__ == '__main__':
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
+
+ # example of mask painter
+ mask_color = 3
+ mask_alpha = 0.7
+ contour_color = 1
+ contour_width = 5
+
+ # save
+ painted_image = Image.fromarray(input_image)
+ painted_image.save('images/original.png')
+
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
+ # save
+ painted_image = Image.fromarray(painted_image)
+ painted_image.save('images/original1.png')
+
+ # example of point painter
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
+ point_color = 5
+ point_alpha = 0.9
+ point_radius = 15
+ contour_color = 2
+ contour_width = 5
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
+ # save
+ painted_image = Image.fromarray(painted_image_1)
+ painted_image.save('images/point_painter_1.png')
+
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
+ # save
+ painted_image = Image.fromarray(painted_image_2)
+ painted_image.save('images/point_painter_2.png')
+
+ # example of background remover
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
+ image_wo_background.save('images/image_wo_background.png')
diff --git a/track_anything.py b/track_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ea61282c0845324a8fe07e46538225b3c16c96
--- /dev/null
+++ b/track_anything.py
@@ -0,0 +1,40 @@
+import numpy as np
+from tqdm import tqdm
+
+from tools.interact_tools import SamControler
+from tracker.base_tracker import BaseTracker
+# from inpainter.base_inpainter import ProInpainter
+
+
+class TrackingAnything():
+ def __init__(self, sam_checkpoint, cutie_checkpoint, args):
+ self.args = args
+ self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
+ self.cutie = BaseTracker(cutie_checkpoint, device=args.device)
+ # self.baseinpainter = ProInpainter(propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args.device)
+
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
+ return mask, logit, painted_image
+
+ def generator(self, images: list, template_mask:np.ndarray):
+ masks = []
+ logits = []
+ painted_images = []
+ for i in tqdm(range(len(images)), desc="Tracking image"):
+ if i==0:
+ mask, logit, painted_image = self.cutie.track(images[i], template_mask)
+ masks.append(mask)
+ logits.append(logit)
+ painted_images.append(painted_image)
+ else:
+ mask, logit, painted_image = self.cutie.track(images[i])
+ masks.append(mask)
+ logits.append(logit)
+ painted_images.append(painted_image)
+ return masks, logits, painted_images
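+ # Usage sketch (hypothetical checkpoints, args and frames): click once on the first frame
+ # to obtain a template mask, then propagate it through the clip:
+ #   ta = TrackingAnything("checkpoints/sam_vit_h_4b8939.pth", "checkpoints/cutie.pth", args)
+ #   ta.samcontroler.sam_controler.set_image(frames[0])
+ #   mask, logit, painted = ta.first_frame_click(frames[0], np.array([[x, y]]), np.array([1]))
+ #   masks, logits, painted_images = ta.generator(frames, mask)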
+
+
+
+
+
\ No newline at end of file
diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec527326c1df4b4bfb9fdd90bc925460cca39f2
--- /dev/null
+++ b/tracker/base_tracker.py
@@ -0,0 +1,103 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+import sys
+sys.path.append('../')
+
+from tracker.config import CONFIG
+from tracker.model.cutie import CUTIE
+from tracker.inference.inference_core import InferenceCore
+from tracker.utils.mask_mapper import MaskMapper
+
+from tools.painter import mask_painter
+
+
+class BaseTracker:
+ def __init__(self, cutie_checkpoint, device) -> None:
+ """
+ device: model device
+ cutie_checkpoint: checkpoint of the Cutie model
+ """
+ config = OmegaConf.create(CONFIG)
+
+ # initialise Cutie
+ network = CUTIE(config).to(device).eval()
+ model_weights = torch.load(cutie_checkpoint, map_location=device)
+ network.load_weights(model_weights)
+
+ # initialise InferenceCore
+ self.tracker = InferenceCore(network, config)
+ self.device = device
+
+ # changeable properties
+ self.mapper = MaskMapper()
+ self.initialised = False
+
+ @torch.no_grad()
+ def resize_mask(self, mask):
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
+ h, w = mask.shape[-2:]
+ min_hw = min(h, w)
+ return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
+ mode='nearest')
+
+ @torch.no_grad()
+ def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'):
+ # frame: H*W*3 numpy array
+ frame = frame.transpose(2, 0, 1)
+ frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255
+ return frame
+
+ @torch.no_grad()
+ def track(self, frame, first_frame_annotation=None):
+ """
+ Input:
+ frame: numpy array (H, W, 3)
+ first_frame_annotation: numpy array (H, W), integer object mask (first frame only, optional)
+
+ Output:
+ mask: numpy arrays (H, W)
+ logit: numpy array (H, W); the current implementation returns the integer mask again in place of a probability map
+ painted_image: numpy array (H, W, 3)
+ """
+
+ if first_frame_annotation is not None: # first frame mask
+ # initialisation
+ mask, labels = self.mapper.convert_mask(first_frame_annotation)
+ mask = torch.Tensor(mask).to(self.device)
+ else:
+ mask = None
+ labels = None
+
+ # prepare inputs
+ frame_tensor = self.image_to_torch(frame, self.device)
+
+ # track one frame
+ probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
+
+ # convert to mask
+ out_mask = torch.argmax(probs, dim=0)
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
+
+ final_mask = np.zeros_like(out_mask)
+
+ # map back
+ for k, v in self.mapper.remappings.items():
+ final_mask[out_mask == v] = k
+
+ num_objs = final_mask.max()
+ painted_image = frame
+ for obj in range(1, num_objs+1):
+ if np.max(final_mask==obj) == 0:
+ continue
+ painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
+
+ return final_mask, final_mask, painted_image
+
+ @torch.no_grad()
+ def clear_memory(self):
+ self.tracker.clear_memory()
+ self.mapper.clear_labels()
+ torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/tracker/config/__init__.py b/tracker/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26a5c2f36fabe3c9870cf0c1ba3c4414e15cd42
--- /dev/null
+++ b/tracker/config/__init__.py
@@ -0,0 +1 @@
+CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}}
\ No newline at end of file
diff --git a/tracker/inference/__init__.py b/tracker/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/inference/image_feature_store.py b/tracker/inference/image_feature_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d02a634b009bb06b935e49036adb121325f5816
--- /dev/null
+++ b/tracker/inference/image_feature_store.py
@@ -0,0 +1,49 @@
+import warnings
+from typing import Iterable, Tuple
+import torch
+from tracker.model.cutie import CUTIE
+
+
+class ImageFeatureStore:
+ """
+ A cache for image features.
+ These features might be reused at different parts of the inference pipeline.
+ This class provides an interface for reusing these features.
+ It is the user's responsibility to delete redundant features.
+
+ Feature of a frame should be associated with a unique index -- typically the frame id.
+ """
+ def __init__(self, network: CUTIE, no_warning: bool = False):
+ self.network = network
+ self._store = {}
+ self.no_warning = no_warning
+
+ def _encode_feature(self, index: int, image: torch.Tensor) -> None:
+ ms_features, pix_feat = self.network.encode_image(image)
+ key, shrinkage, selection = self.network.transform_key(ms_features[0])
+ self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
+
+ def get_features(self, index: int,
+ image: torch.Tensor) -> Tuple[Iterable[torch.Tensor], torch.Tensor]:
+ if index not in self._store:
+ self._encode_feature(index, image)
+
+ return self._store[index][:2]
+
+ def get_key(self, index: int,
+ image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if index not in self._store:
+ self._encode_feature(index, image)
+
+ return self._store[index][2:]
+
+ def delete(self, index: int) -> None:
+ if index in self._store:
+ del self._store[index]
+
+ def __len__(self):
+ return len(self._store)
+
+ def __del__(self):
+ if len(self._store) > 0 and not self.no_warning:
+ warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
diff --git a/tracker/inference/inference_core.py b/tracker/inference/inference_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7d5b24cf9aaed19656a194b1d6dfc0aaffe91c6
--- /dev/null
+++ b/tracker/inference/inference_core.py
@@ -0,0 +1,381 @@
+from typing import List, Optional, Iterable, Dict
+import logging
+from omegaconf import DictConfig
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from tracker.inference.memory_manager import MemoryManager
+from tracker.inference.object_manager import ObjectManager
+from tracker.inference.image_feature_store import ImageFeatureStore
+from tracker.model.cutie import CUTIE
+from tracker.utils.tensor_utils import pad_divide_by, unpad, aggregate
+
+log = logging.getLogger()
+
+
+class InferenceCore:
+ def __init__(self,
+ network: CUTIE,
+ cfg: DictConfig,
+ *,
+ image_feature_store: ImageFeatureStore = None):
+ self.network = network
+ self.cfg = cfg
+ self.mem_every = cfg.mem_every
+ stagger_updates = cfg.stagger_updates
+ self.chunk_size = cfg.chunk_size
+ self.save_aux = cfg.save_aux
+ self.max_internal_size = cfg.max_internal_size
+ self.flip_aug = cfg.flip_aug
+
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ # at which time indices should we update the sensory memory
+ if stagger_updates >= self.mem_every:
+ self.stagger_ti = set(range(1, self.mem_every + 1))
+ else:
+ self.stagger_ti = set(
+ np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
+ self.object_manager = ObjectManager()
+ self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)
+
+ if image_feature_store is None:
+ self.image_feature_store = ImageFeatureStore(self.network)
+ else:
+ self.image_feature_store = image_feature_store
+
+ self.last_mask = None
+
+ def clear_memory(self):
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)
+
+ def clear_non_permanent_memory(self):
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ self.memory.clear_non_permanent_memory()
+
+ def clear_sensory_memory(self):
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ self.memory.clear_sensory_memory()
+
+ def update_config(self, cfg):
+ self.mem_every = cfg['mem_every']
+ self.memory.update_config(cfg)
+
+ def _add_memory(self,
+ image: torch.Tensor,
+ pix_feat: torch.Tensor,
+ prob: torch.Tensor,
+ key: torch.Tensor,
+ shrinkage: torch.Tensor,
+ selection: torch.Tensor,
+ *,
+ is_deep_update: bool = True,
+ force_permanent: bool = False) -> None:
+ """
+ Memorize the given segmentation in all memory stores.
+
+ The batch dimension is 1 if flip augmentation is not used.
+ image: RGB image, (1/2)*3*H*W
+ pix_feat: from the key encoder, (1/2)*_*H*W
+ prob: (1/2)*num_objects*H*W, in [0, 1]
+ key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
+ selection can be None if not using long-term memory
+ is_deep_update: whether to use deep update (e.g. with the mask encoder)
+ force_permanent: whether to force the memory to be permanent
+ """
+ if prob.shape[1] == 0:
+ # nothing to add
+            log.warning('Trying to add an empty object mask to memory!')
+ return
+
+ if force_permanent:
+ as_permanent = 'all'
+ else:
+ as_permanent = 'first'
+
+ self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
+ msk_value, sensory, obj_value, self.obj_logits = self.network.encode_mask(
+ image,
+ pix_feat,
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
+ prob,
+ deep_update=is_deep_update,
+ chunk_size=self.chunk_size,
+ need_weights=self.save_aux)
+ self.memory.add_memory(key,
+ shrinkage,
+ msk_value,
+ obj_value,
+ self.object_manager.all_obj_ids,
+ selection=selection,
+ as_permanent=as_permanent)
+ self.last_mem_ti = self.curr_ti
+ if is_deep_update:
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
+
+ def _segment(self,
+ key: torch.Tensor,
+ selection: torch.Tensor,
+ pix_feat: torch.Tensor,
+ ms_features: Iterable[torch.Tensor],
+ update_sensory: bool = True) -> torch.Tensor:
+ """
+ Produce a segmentation using the given features and the memory
+
+ The batch dimension is 1 if flip augmentation is not used.
+ key/selection: for anisotropic l2: (1/2) * _ * H * W
+ pix_feat: from the key encoder, (1/2) * _ * H * W
+ ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
+ with strides 16, 8, and 4 respectively
+ update_sensory: whether to update the sensory memory
+
+ Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
+ """
+ bs = key.shape[0]
+ if self.flip_aug:
+ assert bs == 2
+ else:
+ assert bs == 1
+
+ if not self.memory.engaged:
+            log.warning('Trying to segment without any memory!')
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
+ device=key.device,
+ dtype=key.dtype)
+
+ memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network)
+ memory_readout = self.object_manager.realize_dict(memory_readout)
+ sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
+ memory_readout,
+ self.memory.get_sensory(
+ self.object_manager.all_obj_ids),
+ chunk_size=self.chunk_size,
+ update_sensory=update_sensory)
+ # remove batch dim
+ if self.flip_aug:
+ # average predictions of the non-flipped and flipped version
+ pred_prob_with_bg = (pred_prob_with_bg[0] +
+ torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
+ else:
+ pred_prob_with_bg = pred_prob_with_bg[0]
+ if update_sensory:
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
+ return pred_prob_with_bg
+
+ def step(self,
+ image: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ objects: Optional[List[int]] = None,
+ *,
+ idx_mask: bool = True,
+ end: bool = False,
+ delete_buffer: bool = True,
+ force_permanent: bool = False) -> torch.Tensor:
+ """
+ Take a step with a new incoming image.
+ If there is an incoming mask with new objects, we will memorize them.
+ If there is no incoming mask, we will segment the image using the memory.
+ In both cases, we will update the memory and return a segmentation.
+
+ image: 3*H*W
+ mask: H*W (if idx mask) or len(objects)*H*W or None
+ objects: list of object ids that are valid in the mask Tensor.
+ The ids themselves do not need to be consecutive/in order, but they need to be
+ in the same position in the list as the corresponding mask
+ in the tensor in non-idx-mask mode.
+ objects is ignored if the mask is None.
+ If idx_mask is False and objects is None, we sequentially infer the object ids.
+ idx_mask: if True, mask is expected to contain an object id at every pixel.
+ If False, mask should have multiple channels with each channel representing one object.
+ end: if we are at the end of the sequence, we do not need to update memory
+ if unsure just set it to False
+ delete_buffer: whether to delete the image feature buffer after this step
+        force_permanent: the memory recorded at this frame will be added to the permanent memory
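+
+        A minimal usage sketch (illustrative only; `frames` is a hypothetical list of
+        3*H*W tensors and `init_mask` an H*W index mask containing object ids 1 and 2):
+            core = InferenceCore(network, cfg)
+            prob = core.step(frames[0], init_mask, objects=[1, 2], idx_mask=True)
+            for frame in frames[1:]:
+                prob = core.step(frame)      # segment purely from memory
+            index_mask = prob.argmax(dim=0)  # (H, W); 0 is the background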
+ """
+ if objects is None and mask is not None:
+ assert not idx_mask
+ objects = list(range(1, mask.shape[0] + 1))
+
+ # resize input if needed -- currently only used for the GUI
+ resize_needed = False
+ if self.max_internal_size > 0:
+ h, w = image.shape[-2:]
+ min_side = min(h, w)
+ if min_side > self.max_internal_size:
+ resize_needed = True
+ new_h = int(h / min_side * self.max_internal_size)
+ new_w = int(w / min_side * self.max_internal_size)
+ image = F.interpolate(image.unsqueeze(0),
+ size=(new_h, new_w),
+ mode='bilinear',
+ align_corners=False)[0]
+ if mask is not None:
+ if idx_mask:
+                        # nearest-neighbour keeps the integer object ids intact;
+                        # align_corners is not a valid argument for mode='nearest'
+                        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
+                                             size=(new_h, new_w),
+                                             mode='nearest')[0, 0].round().long()
+ else:
+ mask = F.interpolate(mask.unsqueeze(0),
+ size=(new_h, new_w),
+ mode='bilinear',
+ align_corners=False)[0]
+
+ self.curr_ti += 1
+
+ image, self.pad = pad_divide_by(image, 16)
+ image = image.unsqueeze(0) # add the batch dimension
+ if self.flip_aug:
+ image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)
+
+ # whether to update the working memory
+ is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
+ (mask is not None)) and (not end)
+ # segment when there is no input mask or when the input mask is incomplete
+ need_segment = (mask is None) or (self.object_manager.num_obj > 0
+ and not self.object_manager.has_all(objects))
+ update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)
+
+ # encoding the image
+ ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
+ key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)
+
+ # segmentation from memory if needed
+ if need_segment:
+ pred_prob_with_bg = self._segment(key,
+ selection,
+ pix_feat,
+ ms_feat,
+ update_sensory=update_sensory)
+
+ # use the input mask if provided
+ if mask is not None:
+ # inform the manager of the new objects, and get a list of temporary id
+ # temporary ids -- indicates the position of objects in the tensor
+ # (starts with 1 due to the background channel)
+ corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)
+
+ mask, _ = pad_divide_by(mask, 16)
+ if need_segment:
+ # merge predicted mask with the incomplete input mask
+ pred_prob_no_bg = pred_prob_with_bg[1:]
+ # use the mutual exclusivity of segmentation
+ if idx_mask:
+ pred_prob_no_bg[:, mask > 0] = 0
+ else:
+                    pred_prob_no_bg[:, mask.max(0).values > 0.5] = 0
+
+ new_masks = []
+ for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
+ if idx_mask:
+ this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
+ else:
+                        this_mask = mask[mask_id]
+                    if tmp_id > pred_prob_no_bg.shape[0]:
+                        new_masks.append(this_mask.unsqueeze(0))
+                    else:
+                        # tmp ids start at 1 because of the background channel,
+                        # while pred_prob_no_bg has the background stripped, hence -1
+                        pred_prob_no_bg[tmp_id - 1] = this_mask
+ # new_masks are always in the order of tmp_id
+ mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
+ elif idx_mask:
+ # simply convert cls to one-hot representation
+ if len(objects) == 0:
+ if delete_buffer:
+ self.image_feature_store.delete(self.curr_ti)
+                    log.warning('Trying to insert an empty mask as memory!')
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
+ device=key.device,
+ dtype=key.dtype)
+ mask = torch.stack(
+ [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
+ dim=0)
+ pred_prob_with_bg = aggregate(mask, dim=0)
+ pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)
+
+ self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
+ if self.flip_aug:
+ self.last_mask = torch.cat(
+ [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
+
+ # save as memory if needed
+ if is_mem_frame or force_permanent:
+ self._add_memory(image,
+ pix_feat,
+ self.last_mask,
+ key,
+ shrinkage,
+ selection,
+ force_permanent=force_permanent)
+
+ if delete_buffer:
+ self.image_feature_store.delete(self.curr_ti)
+
+ output_prob = unpad(pred_prob_with_bg, self.pad)
+ if resize_needed:
+ # restore output to the original size
+ output_prob = F.interpolate(output_prob.unsqueeze(0),
+ size=(h, w),
+ mode='bilinear',
+ align_corners=False)[0]
+
+ return output_prob
+
+ def get_aux_outputs(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+ image, pads = pad_divide_by(image, 16)
+ image = image.unsqueeze(0) # add the batch dimension
+ _, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
+
+ aux_inputs = self.memory.aux
+ aux_outputs = self.network.compute_aux(pix_feat, aux_inputs, selector=None)
+ aux_outputs['q_weights'] = aux_inputs['q_weights']
+ aux_outputs['p_weights'] = aux_inputs['p_weights']
+
+ for k, v in aux_outputs.items():
+ if len(v.shape) == 5:
+ aux_outputs[k] = F.interpolate(v[0],
+ size=image.shape[-2:],
+ mode='bilinear',
+ align_corners=False)
+ elif 'weights' in k:
+ b, num_objects, num_heads, num_queries, h, w = v.shape
+ v = v.view(num_objects * num_heads, num_queries, h, w)
+ v = F.interpolate(v, size=image.shape[-2:], mode='bilinear', align_corners=False)
+ aux_outputs[k] = v.view(num_objects, num_heads, num_queries, *image.shape[-2:])
+ else:
+ aux_outputs[k] = F.interpolate(v,
+ size=image.shape[-2:],
+ mode='bilinear',
+ align_corners=False)[0]
+ aux_outputs[k] = unpad(aux_outputs[k], pads)
+ if 'weights' in k:
+ weights = aux_outputs[k]
+ weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0] +
+ 1e-8)
+ aux_outputs[k] = (weights * 255).cpu().numpy()
+ else:
+ aux_outputs[k] = (aux_outputs[k].softmax(dim=0) * 255).cpu().numpy()
+
+ self.image_feature_store.delete(self.curr_ti)
+ return aux_outputs
+
+ def get_aux_object_weights(self, image: torch.Tensor) -> np.ndarray:
+ image, pads = pad_divide_by(image, 16)
+ # B*num_objects*H*W*num_queries -> num_objects*num_queries*H*W
+ # weights = F.softmax(self.obj_logits, dim=-1)[0]
+ weights = F.sigmoid(self.obj_logits)[0]
+ weights = weights.permute(0, 3, 1, 2).contiguous()
+ weights = F.interpolate(weights,
+ size=image.shape[-2:],
+ mode='bilinear',
+ align_corners=False)
+ # weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0])
+ weights = unpad(weights, pads)
+ weights = (weights * 255).cpu().numpy()
+ return weights
diff --git a/tracker/inference/kv_memory_store.py b/tracker/inference/kv_memory_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..e50b794dc227a772e8a7478d26d662749c0b1c6c
--- /dev/null
+++ b/tracker/inference/kv_memory_store.py
@@ -0,0 +1,348 @@
+from typing import Dict, List, Optional, Literal
+from collections import defaultdict
+import torch
+
+
+def _add_last_dim(dictionary, key, new_value, prepend=False):
+    # append (or prepend) a new value to the last dimension of a tensor in a dictionary
+    # if the key does not exist, put the new value in
+    # appends by default; prepending is used when storing permanent memory
+    if key in dictionary:
+        if prepend:
+            dictionary[key] = torch.cat([new_value, dictionary[key]], -1)
+        else:
+            dictionary[key] = torch.cat([dictionary[key], new_value], -1)
+    else:
+        dictionary[key] = new_value
+
+
+class KeyValueMemoryStore:
+ """
+ Works for key/value pairs type storage
+ e.g., working and long-term memory
+ """
+ def __init__(self, save_selection: bool = False, save_usage: bool = False):
+ """
+ We store keys and values of objects that first appear in the same frame in a bucket.
+ Each bucket contains a set of object ids.
+ Each bucket is associated with a single key tensor
+ and a dictionary of value tensors indexed by object id.
+
+ The keys and values are stored as the concatenation of a permanent part and a temporary part.
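+        Along the last (token) dimension of each stored tensor, indices [0, perm_end_pt)
+        hold the permanent part and the remaining indices hold the temporary part.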
+ """
+ self.save_selection = save_selection
+ self.save_usage = save_usage
+
+ self.global_bucket_id = 0 # does not reduce even if buckets are removed
+ self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
+ self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
+ self.v: Dict[int, torch.Tensor] = {} # indexed by object id
+
+ # indexed by bucket id; the end point of permanent memory
+ self.perm_end_pt: Dict[int, int] = defaultdict(int)
+
+ # shrinkage and selection are just like the keys
+ self.s = {}
+ if self.save_selection:
+ self.e = {} # does not contain the permanent memory part
+
+ # usage
+ if self.save_usage:
+ self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
+ self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
+
+ def add(self,
+ key: torch.Tensor,
+ values: Dict[int, torch.Tensor],
+ shrinkage: torch.Tensor,
+ selection: torch.Tensor,
+ supposed_bucket_id: int = -1,
+ as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
+ """
+ key: (1/2)*C*N
+ values: dict of values ((1/2)*C*N), object ids are used as keys
+ shrinkage: (1/2)*1*N
+ selection: (1/2)*C*N
+
+ supposed_bucket_id: used to sync the bucket id between working and long-term memory
+ if provided, the input should all be in a single bucket indexed by this id
+ as_permanent: whether to store the input as permanent memory
+ 'no': don't
+ 'first': only store it as permanent memory if the bucket is empty
+ 'all': always store it as permanent memory
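+
+        Shape sketch (illustrative numbers only, with save_selection=False): with C=64,
+        C_v=512 and N=1024 tokens,
+            add(key=torch.zeros(1, 64, 1024), values={7: torch.zeros(1, 512, 1024)},
+                shrinkage=torch.zeros(1, 1, 1024), selection=None)
+        creates a new bucket containing object 7.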
+ """
+ bs = key.shape[0]
+ ne = key.shape[-1]
+ assert len(key.shape) == 3
+ assert len(shrinkage.shape) == 3
+ assert not self.save_selection or len(selection.shape) == 3
+ assert as_permanent in ['no', 'first', 'all']
+
+ # add the value and create new buckets if necessary
+ if supposed_bucket_id >= 0:
+ enabled_buckets = [supposed_bucket_id]
+ bucket_exist = supposed_bucket_id in self.buckets
+ for obj, value in values.items():
+ if bucket_exist:
+ assert obj in self.v
+ assert obj in self.buckets[supposed_bucket_id]
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
+ else:
+ assert obj not in self.v
+ self.v[obj] = value
+ self.buckets[supposed_bucket_id] = list(values.keys())
+ else:
+ new_bucket_id = None
+ enabled_buckets = set()
+ for obj, value in values.items():
+ assert len(value.shape) == 3
+ if obj in self.v:
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
+ bucket_used = [
+ bucket_id for bucket_id, object_ids in self.buckets.items()
+ if obj in object_ids
+ ]
+ assert len(bucket_used) == 1 # each object should only be in one bucket
+ enabled_buckets.add(bucket_used[0])
+ else:
+ self.v[obj] = value
+ if new_bucket_id is None:
+ # create new bucket
+ new_bucket_id = self.global_bucket_id
+ self.global_bucket_id += 1
+ self.buckets[new_bucket_id] = []
+ # put the new object into the corresponding bucket
+ self.buckets[new_bucket_id].append(obj)
+ enabled_buckets.add(new_bucket_id)
+
+ # increment the permanent size if necessary
+ add_as_permanent = {} # indexed by bucket id
+ for bucket_id in enabled_buckets:
+ add_as_permanent[bucket_id] = False
+ if as_permanent == 'all':
+ self.perm_end_pt[bucket_id] += ne
+ add_as_permanent[bucket_id] = True
+ elif as_permanent == 'first':
+ if self.perm_end_pt[bucket_id] == 0:
+ self.perm_end_pt[bucket_id] = ne
+ add_as_permanent[bucket_id] = True
+
+ # create new counters for usage if necessary
+ if self.save_usage and as_permanent != 'all':
+ new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
+ new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
+
+ # add the key to every bucket
+ for bucket_id in self.buckets:
+ if bucket_id not in enabled_buckets:
+ # if we are not adding new values to a bucket, we should skip it
+ continue
+
+ _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
+ _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
+ if not add_as_permanent[bucket_id]:
+ if self.save_selection:
+ _add_last_dim(self.e, bucket_id, selection)
+ if self.save_usage:
+ _add_last_dim(self.use_cnt, bucket_id, new_count)
+ _add_last_dim(self.life_cnt, bucket_id, new_life)
+
+ def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
+ # increase all life count by 1
+ # increase use of indexed elements
+ if not self.save_usage:
+ return
+
+ usage = usage[:, self.perm_end_pt[bucket_id]:]
+ if usage.shape[-1] == 0:
+ # if there is no temporary memory, we don't need to update
+ return
+ self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
+ self.life_cnt[bucket_id] += 1
+
+ def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
+ # keep only the temporary elements *outside* of this range (with some boundary conditions)
+ # the permanent elements are ignored in this computation
+ # i.e., concat (a[:start], a[end:])
+ # bucket with size <= min_size are not modified
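+        # e.g. sieve_by_range(b, 0, -n, n) keeps only the last n temporary tokens,
+        # which is how remove_old_memory() below implements a FIFO policy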
+
+ assert start >= 0
+ assert end <= 0
+
+ object_ids = self.buckets[bucket_id]
+ bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
+ if bucket_num_elements <= min_size:
+ return
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ # effectively make the second part an empty slice
+ end = self.k[bucket_id].shape[-1] + 1
+
+ p_size = self.perm_end_pt[bucket_id]
+ start = start + p_size
+
+ k = self.k[bucket_id]
+ s = self.s[bucket_id]
+ if self.save_selection:
+ e = self.e[bucket_id]
+ if self.save_usage:
+ use_cnt = self.use_cnt[bucket_id]
+ life_cnt = self.life_cnt[bucket_id]
+
+ self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
+ self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
+ if self.save_selection:
+ self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
+ if self.save_usage:
+ self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
+ self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
+ -1)
+ for obj_id in object_ids:
+ v = self.v[obj_id]
+ self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
+
+ def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
+ self.sieve_by_range(bucket_id, 0, -max_len, max_len)
+
+ def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
+ # for long-term memory only
+ object_ids = self.buckets[bucket_id]
+
+ assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
+
+ # normalize with life duration
+ usage = self.get_usage(bucket_id)
+ bs = usage.shape[0]
+
+ survivals = []
+
+ for bi in range(bs):
+ _, survived = torch.topk(usage[bi], k=max_size)
+ survivals.append(survived.flatten())
+ assert survived.shape[-1] == survivals[0].shape[-1]
+
+ self.k[bucket_id] = torch.stack(
+ [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
+ self.s[bucket_id] = torch.stack(
+ [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
+
+ if self.save_selection:
+ # Long-term memory does not store selection so this should not be needed
+ self.e[bucket_id] = torch.stack(
+ [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
+ for obj_id in object_ids:
+ self.v[obj_id] = torch.stack(
+ [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
+
+ self.use_cnt[bucket_id] = torch.stack(
+ [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
+ self.life_cnt[bucket_id] = torch.stack(
+ [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
+
+ def get_usage(self, bucket_id: int) -> torch.Tensor:
+ # return normalized usage
+ if not self.save_usage:
+ raise RuntimeError('I did not count usage!')
+ else:
+ usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
+ return usage
+
+ def get_all_sliced(
+ self, bucket_id: int, start: int, end: int
+ ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
+ # return k, sk, ek, value, normalized usage in order, sliced by start and end
+ # this only queries the temporary memory
+
+ assert start >= 0
+ assert end <= 0
+
+ p_size = self.perm_end_pt[bucket_id]
+ start = start + p_size
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ k = self.k[bucket_id][:, :, start:]
+ sk = self.s[bucket_id][:, :, start:]
+ ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
+ value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
+ usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
+ else:
+ k = self.k[bucket_id][:, :, start:end]
+ sk = self.s[bucket_id][:, :, start:end]
+ ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
+ value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
+ usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
+
+ return k, sk, ek, value, usage
+
+ def purge_except(self, obj_keep_idx: List[int]):
+        # purge all objects from the memory except the ones listed
+ obj_keep_idx = set(obj_keep_idx)
+
+ # remove objects that are not in the keep list from the buckets
+ buckets_to_remove = []
+ for bucket_id, object_ids in self.buckets.items():
+ self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
+ if len(self.buckets[bucket_id]) == 0:
+ buckets_to_remove.append(bucket_id)
+
+ # remove object values that are not in the keep list
+ self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
+
+ # remove buckets that are empty
+ for bucket_id in buckets_to_remove:
+ del self.buckets[bucket_id]
+ del self.k[bucket_id]
+ del self.s[bucket_id]
+ if self.save_selection:
+ del self.e[bucket_id]
+ if self.save_usage:
+ del self.use_cnt[bucket_id]
+ del self.life_cnt[bucket_id]
+
+ def clear_non_permanent_memory(self):
+ # clear all non-permanent memory
+ for bucket_id in self.buckets:
+ self.sieve_by_range(bucket_id, 0, 0, 0)
+
+ def get_v_size(self, obj_id: int) -> int:
+ return self.v[obj_id].shape[-1]
+
+ def size(self, bucket_id: int) -> int:
+ if bucket_id not in self.k:
+ return 0
+ else:
+ return self.k[bucket_id].shape[-1]
+
+ def perm_size(self, bucket_id: int) -> int:
+ return self.perm_end_pt[bucket_id]
+
+ def non_perm_size(self, bucket_id: int) -> int:
+ return self.size(bucket_id) - self.perm_size(bucket_id)
+
+ def engaged(self, bucket_id: Optional[int] = None) -> bool:
+ if bucket_id is None:
+ return len(self.buckets) > 0
+ else:
+ return bucket_id in self.buckets
+
+ @property
+ def num_objects(self) -> int:
+ return len(self.v)
+
+ @property
+ def key(self) -> Dict[int, torch.Tensor]:
+ return self.k
+
+ @property
+ def value(self) -> Dict[int, torch.Tensor]:
+ return self.v
+
+ @property
+ def shrinkage(self) -> Dict[int, torch.Tensor]:
+ return self.s
+
+ @property
+ def selection(self) -> Dict[int, torch.Tensor]:
+ return self.e
+
+ def __contains__(self, key):
+ return key in self.v
diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..53995a0f3c0998191858b955415a549dfcef248e
--- /dev/null
+++ b/tracker/inference/memory_manager.py
@@ -0,0 +1,378 @@
+import logging
+from omegaconf import DictConfig
+from typing import List, Dict
+import torch
+
+from tracker.inference.object_manager import ObjectManager
+from tracker.inference.kv_memory_store import KeyValueMemoryStore
+from tracker.model.cutie import CUTIE
+from tracker.model.utils.memory_utils import *
+
+log = logging.getLogger()
+
+
+class MemoryManager:
+ """
+ Manages all three memory stores and the transition between working/long-term memory
+ """
+ def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
+ self.object_manager = object_manager
+ self.sensory_dim = cfg.model.sensory_dim
+ self.top_k = cfg.top_k
+ self.chunk_size = cfg.chunk_size
+
+ self.save_aux = cfg.save_aux
+
+ self.use_long_term = cfg.use_long_term
+ self.count_long_term_usage = cfg.long_term.count_usage
+ # subtract 1 because the first-frame is now counted as "permanent memory"
+ # and is not counted towards max_mem_frames
+ # but we want to keep the hyperparameters consistent as before for the same behavior
+ if self.use_long_term:
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
+ self.num_prototypes = cfg.long_term.num_prototypes
+ self.max_long_tokens = cfg.long_term.max_num_tokens
+ self.buffer_tokens = cfg.long_term.buffer_tokens
+ else:
+ self.max_mem_frames = cfg.max_mem_frames - 1
+
+ # dimensions will be inferred from input later
+ self.CK = self.CV = None
+ self.H = self.W = None
+
+ # The sensory memory is stored as a dictionary indexed by object ids
+ # each of shape bs * C^h * H * W
+ self.sensory = {}
+
+ # a dictionary indexed by object ids, each of shape bs * T * Q * C
+ self.obj_v = {}
+
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
+ save_usage=self.use_long_term)
+ if self.use_long_term:
+ self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)
+
+ self.config_stale = True
+ self.engaged = False
+
+ def update_config(self, cfg: DictConfig) -> None:
+ self.config_stale = True
+ self.top_k = cfg['top_k']
+
+ assert self.use_long_term == cfg.use_long_term, 'cannot update this'
+ assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'
+
+ self.use_long_term = cfg.use_long_term
+ self.count_long_term_usage = cfg.long_term.count_usage
+ if self.use_long_term:
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
+ self.num_prototypes = cfg.long_term.num_prototypes
+ self.max_long_tokens = cfg.long_term.max_num_tokens
+ self.buffer_tokens = cfg.long_term.buffer_tokens
+ else:
+ self.max_mem_frames = cfg.max_mem_frames - 1
+
+ def _readout(self, affinity, v) -> torch.Tensor:
+ # affinity: bs*N*HW
+ # v: bs*C*N or bs*num_objects*C*N
+ # returns bs*C*HW or bs*num_objects*C*HW
+ if len(v.shape) == 3:
+ # single object
+ return v @ affinity
+ else:
+ bs, num_objects, C, N = v.shape
+ v = v.view(bs, num_objects * C, N)
+ out = v @ affinity
+ return out.view(bs, num_objects, C, -1)
+
+ def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
+ # -1 because the mask does not contain the background channel
+ return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]
+
+ def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
+ return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)
+
+ def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
+ return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)
+
+ def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
+ # All the values that the object ids refer to should have the same shape
+ value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
+ if self.use_long_term and obj_ids[0] in self.long_mem.value:
+ lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
+ value = torch.cat([lt_value, value], dim=-1)
+
+ return value
+
+ def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
+ last_mask: torch.Tensor, network: CUTIE) -> Dict[int, torch.Tensor]:
+ """
+ Read from all memory stores and returns a single memory readout tensor for each object
+
+ pix_feat: (1/2) x C x H x W
+ query_key: (1/2) x C^k x H x W
+ selection: (1/2) x C^k x H x W
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
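+
+        Per bucket, the flow is: anisotropic-L2 similarity -> top-k softmax affinity ->
+        value readout -> pixel fusion -> object transformer readout.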
+ """
+ h, w = pix_feat.shape[-2:]
+ bs = pix_feat.shape[0]
+ assert query_key.shape[0] == bs
+ assert selection.shape[0] == bs
+ assert last_mask.shape[0] == bs
+
+ query_key = query_key.flatten(start_dim=2) # bs*C^k*HW
+ selection = selection.flatten(start_dim=2) # bs*C^k*HW
+ """
+ Compute affinity and perform readout
+ """
+ all_readout_mem = {}
+ buckets = self.work_mem.buckets
+ for bucket_id, bucket in buckets.items():
+ if self.use_long_term and self.long_mem.engaged(bucket_id):
+ # Use long-term memory
+ long_mem_size = self.long_mem.size(bucket_id)
+ memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
+ -1)
+ shrinkage = torch.cat(
+ [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)
+
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
+ affinity, usage = do_softmax(similarity,
+ top_k=self.top_k,
+ inplace=True,
+ return_usage=True)
+ """
+ Record memory usage for working and long-term memory
+ """
+ # ignore the index return for long-term memory
+ work_usage = usage[:, long_mem_size:]
+ self.work_mem.update_bucket_usage(bucket_id, work_usage)
+
+ if self.count_long_term_usage:
+ # ignore the index return for working memory
+ long_usage = usage[:, :long_mem_size]
+ self.long_mem.update_bucket_usage(bucket_id, long_usage)
+ else:
+ # no long-term memory
+ memory_key = self.work_mem.key[bucket_id]
+ shrinkage = self.work_mem.shrinkage[bucket_id]
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
+
+ if self.use_long_term:
+ affinity, usage = do_softmax(similarity,
+ top_k=self.top_k,
+ inplace=True,
+ return_usage=True)
+ self.work_mem.update_bucket_usage(bucket_id, usage)
+ else:
+ affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)
+
+ if self.chunk_size < 1:
+ object_chunks = [bucket]
+ else:
+ object_chunks = [
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
+ ]
+
+ for objects in object_chunks:
+ this_sensory = self._get_sensory_by_ids(objects)
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
+ visual_readout = self._readout(affinity,
+ this_msk_value).view(bs, len(objects), self.CV, h, w)
+ pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
+ this_last_mask)
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
+ for i, obj in enumerate(objects):
+ all_readout_mem[obj] = readout_memory[:, i]
+
+ if self.save_aux:
+ aux_output = {
+ 'sensory': this_sensory,
+ 'pixel_readout': pixel_readout,
+ 'q_logits': aux_features['logits'] if aux_features else None,
+ 'q_weights': aux_features['q_weights'] if aux_features else None,
+ 'p_weights': aux_features['p_weights'] if aux_features else None,
+ 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
+ }
+ self.aux = aux_output
+
+ return all_readout_mem
+
+ def add_memory(self,
+ key: torch.Tensor,
+ shrinkage: torch.Tensor,
+ msk_value: torch.Tensor,
+ obj_value: torch.Tensor,
+ objects: List[int],
+ selection: torch.Tensor = None,
+ *,
+ as_permanent: bool = False) -> None:
+ # key: (1/2)*C*H*W
+ # msk_value: (1/2)*num_objects*C*H*W
+ # obj_value: (1/2)*num_objects*Q*C
+ # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
+ bs = key.shape[0]
+ assert shrinkage.shape[0] == bs
+ assert msk_value.shape[0] == bs
+ assert obj_value.shape[0] == bs
+
+ self.engaged = True
+ if self.H is None or self.config_stale:
+ self.config_stale = False
+ self.H, self.W = msk_value.shape[-2:]
+ self.HW = self.H * self.W
+ # convert from num. frames to num. tokens
+ self.max_work_tokens = self.max_mem_frames * self.HW
+ if self.use_long_term:
+ self.min_work_tokens = self.min_mem_frames * self.HW
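+            # e.g. with max_mem_frames=5 and a 30x54 stride-16 key map,
+            # at most 5 * 1620 = 8100 temporary working-memory tokens are kept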
+
+ # key: bs*C*N
+ # value: bs*num_objects*C*N
+ key = key.flatten(start_dim=2)
+ shrinkage = shrinkage.flatten(start_dim=2)
+ self.CK = key.shape[1]
+
+ msk_value = msk_value.flatten(start_dim=3)
+ self.CV = msk_value.shape[2]
+
+ if selection is not None:
+ # not used in non-long-term mode
+ selection = selection.flatten(start_dim=2)
+
+ # insert object values into object memory
+ for obj_id, obj in enumerate(objects):
+ if obj in self.obj_v:
+ """streaming average
+ each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
+ first embed_dim keeps track of the sum of embeddings
+ the last dim keeps the total count
+                averaging is done inside the object transformer
+
+ incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
+ self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
+ """
+ last_acc = self.obj_v[obj][:, :, -1]
+ new_acc = last_acc + obj_value[:, obj_id, :, -1]
+
+ self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
+ obj_value[:, obj_id, :, :-1])
+ self.obj_v[obj][:, :, -1] = new_acc
+ else:
+ self.obj_v[obj] = obj_value[:, obj_id]
+
+ # convert mask value tensor into a dict for insertion
+ msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
+ self.work_mem.add(key,
+ msk_values,
+ shrinkage,
+ selection=selection,
+ as_permanent=as_permanent)
+
+ for bucket_id in self.work_mem.buckets.keys():
+ # long-term memory cleanup
+ if self.use_long_term:
+                # compress the working memory if needed
+ if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
+ # Remove obsolete features if needed
+ if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
+ self.num_prototypes):
+ self.long_mem.remove_obsolete_features(
+ bucket_id,
+ self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
+
+ self.compress_features(bucket_id)
+ else:
+ # FIFO
+ self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)
+
+ def purge_except(self, obj_keep_idx: List[int]) -> None:
+        # purge all objects from the memory except the ones listed
+ self.work_mem.purge_except(obj_keep_idx)
+ if self.use_long_term and self.long_mem.engaged():
+ self.long_mem.purge_except(obj_keep_idx)
+ self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}
+
+ if not self.work_mem.engaged():
+ # everything is removed!
+ self.engaged = False
+
+ def compress_features(self, bucket_id: int) -> None:
+ HW = self.HW
+
+ # perform memory consolidation
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
+ *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))
+
+ # remove consolidated working memory
+ self.work_mem.sieve_by_range(bucket_id,
+ 0,
+ -self.min_work_tokens,
+ min_size=self.min_work_tokens)
+
+ # add to long-term memory
+ self.long_mem.add(prototype_key,
+ prototype_value,
+ prototype_shrinkage,
+ selection=None,
+ supposed_bucket_id=bucket_id)
+
+ def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
+ candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
+ usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
+ # find the indices with max usage
+ bs = candidate_key.shape[0]
+ assert bs in [1, 2]
+
+ prototype_key = []
+ prototype_selection = []
+ for bi in range(bs):
+ _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
+ prototype_indices = max_usage_indices.flatten()
+ prototype_key.append(candidate_key[bi, :, prototype_indices])
+ prototype_selection.append(candidate_selection[bi, :, prototype_indices])
+ prototype_key = torch.stack(prototype_key, dim=0)
+ prototype_selection = torch.stack(prototype_selection, dim=0)
+ """
+ Potentiation step
+ """
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
+ prototype_selection)
+ affinity = do_softmax(similarity)
+
+ # readout the values
+ prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}
+
+ # readout the shrinkage term
+ prototype_shrinkage = self._readout(affinity, candidate_shrinkage)
+
+ return prototype_key, prototype_value, prototype_shrinkage
+
+ def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
+ for obj in ids:
+ if obj not in self.sensory:
+ # also initializes the sensory memory
+ bs, _, h, w = sample_key.shape
+ self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
+ device=sample_key.device)
+
+ def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
+ # sensory: 1*num_objects*C*H*W
+ for obj_id, obj in enumerate(ids):
+ self.sensory[obj] = sensory[:, obj_id]
+
+ def get_sensory(self, ids: List[int]):
+ # returns (1/2)*num_objects*C*H*W
+ return self._get_sensory_by_ids(ids)
+
+ def clear_non_permanent_memory(self):
+ self.work_mem.clear_non_permanent_memory()
+ if self.use_long_term:
+ self.long_mem.clear_non_permanent_memory()
+
+ def clear_sensory_memory(self):
+ self.sensory = {}
diff --git a/tracker/inference/object_info.py b/tracker/inference/object_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e0bd45b10d0361c3ebc19783155e9ab29c8ad0
--- /dev/null
+++ b/tracker/inference/object_info.py
@@ -0,0 +1,24 @@
+class ObjectInfo:
+ """
+ Store meta information for an object
+ """
+ def __init__(self, id: int):
+ self.id = id
+ self.poke_count = 0 # count number of detections missed
+
+ def poke(self) -> None:
+ self.poke_count += 1
+
+ def unpoke(self) -> None:
+ self.poke_count = 0
+
+ def __hash__(self):
+ return hash(self.id)
+
+ def __eq__(self, other):
+ if type(other) == int:
+ return self.id == other
+ return self.id == other.id
+
+ def __repr__(self):
+ return f'(ID: {self.id})'
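+
+
+# Note: the __hash__/__eq__ pair above makes an ObjectInfo interchangeable with its
+# plain integer id as a dictionary key, e.g. {ObjectInfo(3): 'obj'}[3] == 'obj'.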
diff --git a/tracker/inference/object_manager.py b/tracker/inference/object_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9e5a8533d3f4e5c72150abe7e79d05fcd2f7bd9
--- /dev/null
+++ b/tracker/inference/object_manager.py
@@ -0,0 +1,148 @@
+from typing import Union, List, Dict
+
+import torch
+from tracker.inference.object_info import ObjectInfo
+
+
+class ObjectManager:
+ """
+    Object IDs are immutable. The same ID always represents the same object.
+    Temporary IDs are the positions of each object in the tensor; they change as objects are removed.
+    Temporary IDs start from 1.
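+
+    For example, after add_new_objects([10, 20]) the object with ID 10 gets temporary ID 1
+    and the object with ID 20 gets temporary ID 2; deleting object 10 re-packs the mapping
+    so that object 20 gets temporary ID 1.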
+ """
+ def __init__(self):
+ self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
+ self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
+ self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
+
+ self.all_historical_object_ids: List[int] = []
+
+ def _recompute_obj_id_to_obj_mapping(self) -> None:
+ self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
+
+ def add_new_objects(
+ self, objects: Union[List[ObjectInfo], ObjectInfo,
+ List[int]]) -> (List[int], List[int]):
+ if not isinstance(objects, list):
+ objects = [objects]
+
+ corresponding_tmp_ids = []
+ corresponding_obj_ids = []
+ for obj in objects:
+ if isinstance(obj, int):
+ obj = ObjectInfo(id=obj)
+
+ if obj in self.obj_to_tmp_id:
+ # old object
+ corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
+ corresponding_obj_ids.append(obj.id)
+ else:
+                # new object: obj is already an ObjectInfo at this point
+                # (plain ints were wrapped above), so reuse it directly
+                new_obj = obj
+
+                # assign the next temporary id
+ new_tmp_id = len(self.obj_to_tmp_id) + 1
+ self.obj_to_tmp_id[new_obj] = new_tmp_id
+ self.tmp_id_to_obj[new_tmp_id] = new_obj
+ self.all_historical_object_ids.append(new_obj.id)
+ corresponding_tmp_ids.append(new_tmp_id)
+ corresponding_obj_ids.append(new_obj.id)
+
+ self._recompute_obj_id_to_obj_mapping()
+ assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
+ return corresponding_tmp_ids, corresponding_obj_ids
+
+ def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
+ # delete an object or a list of objects
+ # re-sort the tmp ids
+ if isinstance(obj_ids_to_remove, int):
+ obj_ids_to_remove = [obj_ids_to_remove]
+
+ new_tmp_id = 1
+ total_num_id = len(self.obj_to_tmp_id)
+
+ local_obj_to_tmp_id = {}
+ local_tmp_to_obj_id = {}
+
+ for tmp_iter in range(1, total_num_id + 1):
+ obj = self.tmp_id_to_obj[tmp_iter]
+ if obj.id not in obj_ids_to_remove:
+ local_obj_to_tmp_id[obj] = new_tmp_id
+ local_tmp_to_obj_id[new_tmp_id] = obj
+ new_tmp_id += 1
+
+ self.obj_to_tmp_id = local_obj_to_tmp_id
+ self.tmp_id_to_obj = local_tmp_to_obj_id
+ self._recompute_obj_id_to_obj_mapping()
+
+ def purge_inactive_objects(self,
+ max_missed_detection_count: int) -> (bool, List[int], List[int]):
+ # remove tmp ids of objects that are removed
+ obj_id_to_be_deleted = []
+ tmp_id_to_be_deleted = []
+ tmp_id_to_keep = []
+ obj_id_to_keep = []
+
+ for obj in self.obj_to_tmp_id:
+ if obj.poke_count > max_missed_detection_count:
+ obj_id_to_be_deleted.append(obj.id)
+ tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
+ else:
+ tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
+ obj_id_to_keep.append(obj.id)
+
+ purge_activated = len(obj_id_to_be_deleted) > 0
+ if purge_activated:
+ self.delete_object(obj_id_to_be_deleted)
+ return purge_activated, tmp_id_to_keep, obj_id_to_keep
+
+ def tmp_to_obj_cls(self, mask) -> torch.Tensor:
+ # remap tmp id cls representation to the true object id representation
+ new_mask = torch.zeros_like(mask)
+ for tmp_id, obj in self.tmp_id_to_obj.items():
+ new_mask[mask == tmp_id] = obj.id
+ return new_mask
+
+ def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
+ # returns the mapping in a dict format for saving it with pickle
+        return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}
+
+ def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
+ # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
+ output = []
+ for _, obj in self.tmp_id_to_obj.items():
+ if obj.id not in obj_dict:
+ raise NotImplementedError
+ output.append(obj_dict[obj.id])
+ output = torch.stack(output, dim=dim)
+ return output
+
+ def make_one_hot(self, cls_mask) -> torch.Tensor:
+ output = []
+ for _, obj in self.tmp_id_to_obj.items():
+ output.append(cls_mask == obj.id)
+ if len(output) == 0:
+ output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
+ else:
+ output = torch.stack(output, dim=0)
+ return output
+
+ @property
+ def all_obj_ids(self) -> List[int]:
+ return [k.id for k in self.obj_to_tmp_id]
+
+ @property
+ def num_obj(self) -> int:
+ return len(self.obj_to_tmp_id)
+
+ def has_all(self, objects: List[int]) -> bool:
+ for obj in objects:
+ if obj not in self.obj_to_tmp_id:
+ return False
+ return True
+
+ def find_object_by_id(self, obj_id) -> ObjectInfo:
+ return self.obj_id_to_obj[obj_id]
+
+ def find_tmp_by_id(self, obj_id) -> int:
+ return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
diff --git a/tracker/inference/utils/__init__.py b/tracker/inference/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/inference/utils/args_utils.py b/tracker/inference/utils/args_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a771ccaa080af2acd9757c7139c60c24652a1442
--- /dev/null
+++ b/tracker/inference/utils/args_utils.py
@@ -0,0 +1,30 @@
+import logging
+from omegaconf import DictConfig
+
+log = logging.getLogger()
+
+
+def get_dataset_cfg(cfg: DictConfig):
+ dataset_name = cfg.dataset
+ data_cfg = cfg.datasets[dataset_name]
+
+ potential_overrides = [
+ 'image_directory',
+ 'mask_directory',
+ 'json_directory',
+ 'size',
+ 'save_all',
+ 'use_all_masks',
+ 'use_long_term',
+ 'mem_every',
+ ]
+
+ for override in potential_overrides:
+ if cfg[override] is not None:
+ log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
+ data_cfg[override] = cfg[override]
+        # escalate all potential overrides to the top-level config
+ if override in data_cfg:
+ cfg[override] = data_cfg[override]
+
+ return data_cfg
diff --git a/tracker/inference/utils/burst_utils.py b/tracker/inference/utils/burst_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..570442848c83378f8562485aa7cca3502910440c
--- /dev/null
+++ b/tracker/inference/utils/burst_utils.py
@@ -0,0 +1,19 @@
+from os import path
+import copy
+import json
+
+
+class BURSTResultHandler:
+ def __init__(self, dataset_json):
+ self.dataset_json = copy.deepcopy(dataset_json)
+
+ # get rid of the segmentations while keeping the metadata
+ self.dataset_json['sequences'] = []
+
+ def add_sequence(self, sequence_json):
+ self.dataset_json['sequences'].append(sequence_json)
+
+ def dump(self, root):
+ json_path = path.join(root, 'predictions.json')
+ with open(json_path, 'w') as f:
+ json.dump(self.dataset_json, f)
\ No newline at end of file
diff --git a/tracker/inference/utils/frame_utils.py b/tracker/inference/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc18f35ed7c9ee4ff1aa41fa1ef988d372e91039
--- /dev/null
+++ b/tracker/inference/utils/frame_utils.py
@@ -0,0 +1,26 @@
+from typing import Dict, List, Tuple
+import torch
+
+from tracker.inference.object_info import ObjectInfo
+
+
+class FrameInfo:
+ def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo],
+ ti: int, info: Dict):
+ self.image = image
+ self.mask = mask
+ self.segments_info = segments_info
+ self.ti = ti
+ self.info = info
+
+ @property
+ def name(self) -> str:
+ return self.info['frame']
+
+ @property
+    def shape(self) -> Tuple[int, ...]:
+ return self.info['shape']
+
+ @property
+ def need_save(self) -> bool:
+ return self.info['save']
diff --git a/tracker/inference/utils/results_utils.py b/tracker/inference/utils/results_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db970ffea9ff89df45bdb4932a5972373a4f23f2
--- /dev/null
+++ b/tracker/inference/utils/results_utils.py
@@ -0,0 +1,256 @@
+from typing import Tuple, Optional, Dict
+import logging
+import os
+import shutil
+from os import path
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+import pycocotools.mask as mask_util
+from threading import Thread
+from queue import Queue
+from dataclasses import dataclass
+import copy
+
+from tracker.utils.pano_utils import ID2RGBConverter
+from tracker.utils.palette import davis_palette_np
+from tracker.inference.object_manager import ObjectManager
+from tracker.inference.object_info import ObjectInfo
+
+log = logging.getLogger()
+
+try:
+ import hickle as hkl
+except ImportError:
+ log.warning('Failed to import hickle. Fine if not using multi-scale testing.')
+
+
+class ResultSaver:
+ def __init__(self,
+ output_root,
+ video_name,
+ *,
+ dataset,
+ object_manager: ObjectManager,
+ use_long_id,
+ palette=None,
+ save_mask=True,
+ save_scores=False,
+ score_output_root=None,
+ visualize_output_root=None,
+ visualize=False,
+ init_json=None):
+ self.output_root = output_root
+ self.video_name = video_name
+ self.dataset = dataset.lower()
+ self.use_long_id = use_long_id
+ self.palette = palette
+ self.object_manager = object_manager
+ self.save_mask = save_mask
+ self.save_scores = save_scores
+ self.score_output_root = score_output_root
+ self.visualize_output_root = visualize_output_root
+ self.visualize = visualize
+
+ if self.visualize:
+ if self.palette is not None:
+ self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3)
+ else:
+ self.colors = davis_palette_np
+
+ self.need_remapping = True
+ self.json_style = None
+ self.id2rgb_converter = ID2RGBConverter()
+
+ if 'burst' in self.dataset:
+ assert init_json is not None
+ self.input_segmentations = init_json['segmentations']
+ self.segmentations = [{} for _ in init_json['segmentations']]
+ self.annotated_frames = init_json['annotated_image_paths']
+ self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'}
+ self.video_json['segmentations'] = self.segmentations
+ self.json_style = 'burst'
+
+ self.queue = Queue(maxsize=10)
+ self.thread = Thread(target=save_result, args=(self.queue, ))
+ self.thread.daemon = True
+ self.thread.start()
+
+ def process(self,
+ prob: torch.Tensor,
+ frame_name: str,
+ resize_needed: bool = False,
+ shape: Optional[Tuple[int, int]] = None,
+ last_frame: bool = False,
+ path_to_image: str = None):
+
+ if resize_needed:
+            prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear',
+                                 align_corners=False)[:, 0]
+ # Probability mask -> index mask
+ mask = torch.argmax(prob, dim=0)
+ if self.save_scores:
+ # also need to pass prob
+ prob = prob.cpu()
+ else:
+ prob = None
+
+ # remap indices
+ if self.need_remapping:
+ new_mask = torch.zeros_like(mask)
+ for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
+ new_mask[mask == tmp_id] = obj.id
+ mask = new_mask
+
+ args = ResultArgs(saver=self,
+ prob=prob,
+ mask=mask.cpu(),
+ frame_name=frame_name,
+ path_to_image=path_to_image,
+ tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj),
+ obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id),
+ last_frame=last_frame)
+
+ self.queue.put(args)
+
+ def end(self):
+ self.queue.put(None)
+ self.queue.join()
+ self.thread.join()
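+
+    # Typical lifecycle (a sketch of the intended use): construct one ResultSaver per
+    # video, call process() for each frame's output probabilities, then call end()
+    # exactly once so the background thread drains the queue before results are read.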
+
+
+@dataclass
+class ResultArgs:
+ saver: ResultSaver
+ prob: torch.Tensor
+ mask: torch.Tensor
+ frame_name: str
+ path_to_image: str
+ tmp_id_to_obj: Dict[int, ObjectInfo]
+ obj_to_tmp_id: Dict[ObjectInfo, int]
+ last_frame: bool
+
+
+def save_result(queue: Queue):
+ while True:
+ args: ResultArgs = queue.get()
+ if args is None:
+ queue.task_done()
+ break
+
+ saver = args.saver
+ prob = args.prob
+ mask = args.mask
+ frame_name = args.frame_name
+ path_to_image = args.path_to_image
+ tmp_id_to_obj = args.tmp_id_to_obj
+ obj_to_tmp_id = args.obj_to_tmp_id
+ last_frame = args.last_frame
+ all_obj_ids = [k.id for k in obj_to_tmp_id]
+
+ # record output in the json file
+ if saver.json_style == 'burst':
+ if frame_name in saver.annotated_frames:
+ frame_index = saver.annotated_frames.index(frame_name)
+ input_segments = saver.input_segmentations[frame_index]
+ frame_segments = saver.segmentations[frame_index]
+
+ for id in all_obj_ids:
+ if id in input_segments:
+ # if this frame has been given as input, just copy
+ frame_segments[id] = input_segments[id]
+ continue
+
+ segment = {}
+ segment_mask = (mask == id)
+ if segment_mask.sum() > 0:
+ coco_mask = mask_util.encode(np.asfortranarray(segment_mask.numpy()))
+ segment['rle'] = coco_mask['counts'].decode('utf-8')
+ frame_segments[id] = segment
+
+ # save the mask to disk
+ if saver.save_mask:
+ if saver.use_long_id:
+ out_mask = mask.numpy().astype(np.uint32)
+ rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8)
+ for id in all_obj_ids:
+ _, image = saver.id2rgb_converter.convert(id)
+ obj_mask = (out_mask == id)
+ rgb_mask[obj_mask] = image
+ out_img = Image.fromarray(rgb_mask)
+ else:
+ rgb_mask = None
+ out_mask = mask.numpy().astype(np.uint8)
+ out_img = Image.fromarray(out_mask)
+ if saver.palette is not None:
+ out_img.putpalette(saver.palette)
+
+ this_out_path = path.join(saver.output_root, saver.video_name)
+ os.makedirs(this_out_path, exist_ok=True)
+ out_img.save(os.path.join(this_out_path, frame_name[:-4] + '.png'))
+
+ # save scores for multi-scale testing
+ if saver.save_scores:
+ this_out_path = path.join(saver.score_output_root, saver.video_name)
+ os.makedirs(this_out_path, exist_ok=True)
+
+ prob = (prob.detach().numpy() * 255).astype(np.uint8)
+
+ if last_frame:
+ tmp_to_obj_mapping = {obj.id: tmp_id for obj, tmp_id in tmp_id_to_obj.items()}
+ hkl.dump(tmp_to_obj_mapping, path.join(this_out_path, f'backward.hkl'), mode='w')
+
+ hkl.dump(prob,
+ path.join(this_out_path, f'{frame_name[:-4]}.hkl'),
+ mode='w',
+ compression='lzf')
+
+ if saver.visualize:
+ if path_to_image is not None:
+ image_np = np.array(Image.open(path_to_image))
+ else:
+ raise ValueError('Cannot visualize without path_to_image')
+
+ if rgb_mask is None:
+ # we need to apply a palette
+ rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8)
+ for id in all_obj_ids:
+ image = saver.colors[id]
+ obj_mask = (out_mask == id)
+ rgb_mask[obj_mask] = image
+
+ alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5
+ alpha = alpha[:, :, None]
+ blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8)
+
+ # find a place to save the visualization
+ this_vis_path = path.join(saver.visualize_output_root, saver.video_name)
+ os.makedirs(this_vis_path, exist_ok=True)
+ Image.fromarray(blend).save(path.join(this_vis_path, frame_name[:-4] + '.jpg'))
+
+ queue.task_done()
+
+
+def make_zip(dataset, run_dir, exp_id, mask_output_root):
+ if dataset.startswith('y'):
+ # YoutubeVOS
+ log.info('Making zip for YouTubeVOS...')
+ shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir,
+ 'Annotations')
+ elif dataset == 'd17-test-dev':
+ # DAVIS 2017 test-dev -- zip from within the Annotation folder
+ log.info('Making zip for DAVIS test-dev...')
+ shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root)
+ elif dataset == 'mose-val':
+ # MOSE validation -- same as DAVIS test-dev
+ log.info('Making zip for MOSE validation...')
+ shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root)
+ elif dataset == 'lvos-test':
+ # LVOS test -- same as YouTubeVOS
+ log.info('Making zip for LVOS test...')
+ shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir,
+ 'Annotations')
+ else:
+ log.info(f'Not making zip for {dataset}.')
diff --git a/tracker/model/__init__.py b/tracker/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/model/aux_modules.py b/tracker/model/aux_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf4a8ef3913dc5f057b29ad5917cb7bc5541d004
--- /dev/null
+++ b/tracker/model/aux_modules.py
@@ -0,0 +1,80 @@
+"""
+For computing auxiliary outputs for auxiliary losses
+"""
+from typing import Dict
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+
+from tracker.model.group_modules import GConv2d
+from tracker.utils.tensor_utils import aggregate
+
+
+class LinearPredictor(nn.Module):
+ def __init__(self, x_dim: int, pix_dim: int):
+ super().__init__()
+ self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
+
+ def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ # pixel_feat: B*pix_dim*H*W
+ # x: B*num_objects*x_dim*H*W
+ num_objects = x.shape[1]
+ x = self.projection(x)
+
+ pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
+ logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
+ return logits
+
+
+class DirectPredictor(nn.Module):
+ def __init__(self, x_dim: int):
+ super().__init__()
+ self.projection = GConv2d(x_dim, 1, kernel_size=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: B*num_objects*x_dim*H*W
+ logits = self.projection(x).squeeze(2)
+ return logits
+
+
+class AuxComputer(nn.Module):
+ def __init__(self, cfg: DictConfig):
+ super().__init__()
+
+ use_sensory_aux = cfg.model.aux_loss.sensory.enabled
+ self.use_query_aux = cfg.model.aux_loss.query.enabled
+
+ sensory_dim = cfg.model.sensory_dim
+ embed_dim = cfg.model.embed_dim
+
+ if use_sensory_aux:
+ self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
+ else:
+ self.sensory_aux = None
+
+ def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
+ prob = torch.sigmoid(logits)
+ if selector is not None:
+ prob = prob * selector
+ logits = aggregate(prob, dim=1)
+ return logits
+
+ def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
+ selector: torch.Tensor) -> Dict[str, torch.Tensor]:
+ sensory = aux_input['sensory']
+ q_logits = aux_input['q_logits']
+
+ aux_output = {}
+ aux_output['attn_mask'] = aux_input['attn_mask']
+
+ if self.sensory_aux is not None:
+ # B*num_objects*H*W
+ logits = self.sensory_aux(pix_feat, sensory)
+ aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
+ if self.use_query_aux:
+ # B*num_objects*num_levels*H*W
+ aux_output['q_logits'] = self._aggregate_with_selector(
+ torch.stack(q_logits, dim=2),
+ selector.unsqueeze(2) if selector is not None else None)
+
+ return aux_output
\ No newline at end of file
diff --git a/tracker/model/big_modules.py b/tracker/model/big_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc1daaf0d72811694922476e63e018cefa6c5656
--- /dev/null
+++ b/tracker/model/big_modules.py
@@ -0,0 +1,304 @@
+"""
+big_modules.py - This file stores higher-level network blocks.
+
+x - usually denotes features that are shared between objects.
+g - usually denotes features that are not shared between objects
+ with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
+
+The trailing number of a variable usually denotes the stride
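+(e.g. f16, f8 and f4 returned by PixelEncoder below are the stride-16, 8 and 4 feature maps).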
+"""
+
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tracker.model.group_modules import *
+from tracker.model.utils import resnet
+from tracker.model.modules import *
+
+
+class PixelEncoder(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+
+ self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
+ if self.is_resnet:
+ if model_cfg.pixel_encoder.type == 'resnet18':
+ network = resnet.resnet18(pretrained=True)
+ elif model_cfg.pixel_encoder.type == 'resnet50':
+ network = resnet.resnet50(pretrained=True)
+ else:
+ raise NotImplementedError
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu
+ self.maxpool = network.maxpool
+
+ self.res2 = network.layer1
+ self.layer2 = network.layer2
+ self.layer3 = network.layer3
+ else:
+ raise NotImplementedError
+
+ def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ f4 = self.res2(x)
+ f8 = self.layer2(f4)
+ f16 = self.layer3(f8)
+
+ return f16, f8, f4
+
+ # override the default train() to freeze BN statistics
+ def train(self, mode=True):
+ self.training = False
+ for module in self.children():
+ module.train(False)
+ return self
+
+
+class KeyProjection(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+ in_dim = model_cfg.pixel_encoder.ms_dims[0]
+ mid_dim = model_cfg.pixel_dim
+ key_dim = model_cfg.key_dim
+
+ self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
+ self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
+ # shrinkage
+ self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
+ # selection
+ self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
+
+ nn.init.orthogonal_(self.key_proj.weight.data)
+ nn.init.zeros_(self.key_proj.bias.data)
+
+ def forward(self, x: torch.Tensor, *, need_s: bool,
+ need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
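+        # Shrinkage (>= 1) and selection (in [0, 1]) implement the anisotropic L2
+        # similarity of XMem; either term can be skipped via need_s / need_e.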
+ x = self.pix_feat_proj(x)
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
+
+ return self.key_proj(x), shrinkage, selection
+
+
+class MaskEncoder(nn.Module):
+ def __init__(self, model_cfg: DictConfig, single_object=False):
+ super().__init__()
+ pixel_dim = model_cfg.pixel_dim
+ value_dim = model_cfg.value_dim
+ sensory_dim = model_cfg.sensory_dim
+ final_dim = model_cfg.mask_encoder.final_dim
+
+ self.single_object = single_object
+ extra_dim = 1 if single_object else 2
+
+ if model_cfg.mask_encoder.type == 'resnet18':
+ network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
+ elif model_cfg.mask_encoder.type == 'resnet50':
+ network = resnet.resnet50(pretrained=True, extra_dim=extra_dim)
+ else:
+ raise NotImplementedError
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu
+ self.maxpool = network.maxpool
+
+ self.layer1 = network.layer1
+ self.layer2 = network.layer2
+ self.layer3 = network.layer3
+
+ self.distributor = MainToGroupDistributor()
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
+
+ self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
+
+ def forward(self,
+ image: torch.Tensor,
+ pix_feat: torch.Tensor,
+ sensory: torch.Tensor,
+ masks: torch.Tensor,
+ others: torch.Tensor,
+ *,
+ deep_update: bool = True,
+ chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
+        # pix_feat is the (projected) lowest-resolution feature from the key encoder;
+        # following XMem, only this first scale of the multi-scale features is used
+ if self.single_object:
+ g = masks.unsqueeze(2)
+ else:
+ g = torch.stack([masks, others], dim=2)
+
+ g = self.distributor(image, g)
+
+ batch_size, num_objects = g.shape[:2]
+ if chunk_size < 1 or chunk_size >= num_objects:
+ chunk_size = num_objects
+ fast_path = True
+ new_sensory = sensory
+ else:
+ if deep_update:
+ new_sensory = torch.empty_like(sensory)
+ else:
+ new_sensory = sensory
+ fast_path = False
+
+ # chunk-by-chunk inference
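+        # Chunking the object dimension bounds peak memory when many objects are tracked;
+        # chunk_size < 1 (the default) takes the fast path and encodes all objects at once.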
+ all_g = []
+ for i in range(0, num_objects, chunk_size):
+ if fast_path:
+ g_chunk = g
+ else:
+ g_chunk = g[:, i:i + chunk_size]
+ actual_chunk_size = g_chunk.shape[1]
+ g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
+
+ g_chunk = self.conv1(g_chunk)
+ g_chunk = self.bn1(g_chunk) # 1/2, 64
+ g_chunk = self.maxpool(g_chunk) # 1/4, 64
+ g_chunk = self.relu(g_chunk)
+
+ g_chunk = self.layer1(g_chunk) # 1/4
+ g_chunk = self.layer2(g_chunk) # 1/8
+ g_chunk = self.layer3(g_chunk) # 1/16
+
+ g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
+ g_chunk = self.fuser(pix_feat, g_chunk)
+ all_g.append(g_chunk)
+ if deep_update:
+ if fast_path:
+ new_sensory = self.sensory_update(g_chunk, sensory)
+ else:
+ new_sensory[:, i:i + chunk_size] = self.sensory_update(
+ g_chunk, sensory[:, i:i + chunk_size])
+ g = torch.cat(all_g, dim=1)
+
+ return g, new_sensory
+
+ # override the default train() to freeze BN statistics
+ def train(self, mode=True):
+ self.training = False
+ for module in self.children():
+ module.train(False)
+ return self
+
+
+class PixelFeatureFuser(nn.Module):
+ def __init__(self, model_cfg: DictConfig, single_object=False):
+ super().__init__()
+ value_dim = model_cfg.value_dim
+ sensory_dim = model_cfg.sensory_dim
+ pixel_dim = model_cfg.pixel_dim
+ embed_dim = model_cfg.embed_dim
+ self.single_object = single_object
+
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
+ if self.single_object:
+ self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
+ else:
+ self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
+
+ def forward(self,
+ pix_feat: torch.Tensor,
+ pixel_memory: torch.Tensor,
+ sensory_memory: torch.Tensor,
+ last_mask: torch.Tensor,
+ last_others: torch.Tensor,
+ *,
+ chunk_size: int = -1) -> torch.Tensor:
+ batch_size, num_objects = pixel_memory.shape[:2]
+
+ if self.single_object:
+ last_mask = last_mask.unsqueeze(2)
+ else:
+ last_mask = torch.stack([last_mask, last_others], dim=2)
+
+ if chunk_size < 1:
+ chunk_size = num_objects
+
+ # chunk-by-chunk inference
+ all_p16 = []
+ for i in range(0, num_objects, chunk_size):
+ sensory_readout = self.sensory_compress(
+ torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
+ p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
+ p16 = self.fuser(pix_feat, p16)
+ all_p16.append(p16)
+ p16 = torch.cat(all_p16, dim=1)
+
+ return p16
+
+
+class MaskDecoder(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+ embed_dim = model_cfg.embed_dim
+ sensory_dim = model_cfg.sensory_dim
+ ms_image_dims = model_cfg.pixel_encoder.ms_dims
+ up_dims = model_cfg.mask_decoder.up_dims
+
+ assert embed_dim == up_dims[0]
+
+ self.sensory_update = SensoryUpdater([up_dims[0], up_dims[1], up_dims[2] + 1], sensory_dim,
+ sensory_dim)
+
+ self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
+ self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
+ self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
+
+ self.pred = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
+
+ def forward(self,
+ ms_image_feat: Iterable[torch.Tensor],
+ memory_readout: torch.Tensor,
+ sensory: torch.Tensor,
+ *,
+ chunk_size: int = -1,
+ update_sensory: bool = True) -> (torch.Tensor, torch.Tensor):
+
+ batch_size, num_objects = memory_readout.shape[:2]
+ f8, f4 = self.decoder_feat_proc(ms_image_feat[1:])
+ if chunk_size < 1 or chunk_size >= num_objects:
+ chunk_size = num_objects
+ fast_path = True
+ new_sensory = sensory
+ else:
+ if update_sensory:
+ new_sensory = torch.empty_like(sensory)
+ else:
+ new_sensory = sensory
+ fast_path = False
+
+ # chunk-by-chunk inference
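+        # As in the mask encoder, chunking bounds peak memory over the object dimension;
+        # the final prediction conv below runs in fp32 with autocast disabled.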
+ all_logits = []
+ for i in range(0, num_objects, chunk_size):
+ if fast_path:
+ p16 = memory_readout
+ else:
+ p16 = memory_readout[:, i:i + chunk_size]
+ actual_chunk_size = p16.shape[1]
+
+ p8 = self.up_16_8(p16, f8)
+ p4 = self.up_8_4(p8, f4)
+ with torch.cuda.amp.autocast(enabled=False):
+ logits = self.pred(F.relu(p4.flatten(start_dim=0, end_dim=1).float()))
+
+ if update_sensory:
+ p4 = torch.cat(
+ [p4, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
+ if fast_path:
+ new_sensory = self.sensory_update([p16, p8, p4], sensory)
+ else:
+ new_sensory[:,
+ i:i + chunk_size] = self.sensory_update([p16, p8, p4],
+ sensory[:,
+ i:i + chunk_size])
+ all_logits.append(logits)
+ logits = torch.cat(all_logits, dim=0)
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
+
+ return new_sensory, logits
diff --git a/tracker/model/channel_attn.py b/tracker/model/channel_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2096c1c4b4745a3ea2060bb25af3b19ff9cf3ec
--- /dev/null
+++ b/tracker/model/channel_attn.py
@@ -0,0 +1,39 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CAResBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
+ super().__init__()
+ self.residual = residual
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
+
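+        # Channel attention with a 1D conv whose kernel size adapts to the channel count
+        # (an ECA-style heuristic: roughly half of log2(out_dim), rounded up to an odd number).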
+ t = int((abs(math.log2(out_dim)) + 1) // 2)
+ k = t if t % 2 else t + 1
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
+
+ if self.residual:
+ if in_dim == out_dim:
+ self.downsample = nn.Identity()
+ else:
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ r = x
+ x = self.conv1(F.relu(x))
+ x = self.conv2(F.relu(x))
+
+ b, c = x.shape[:2]
+ w = self.pool(x).view(b, 1, c)
+ w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
+
+ if self.residual:
+ x = x * w + self.downsample(r)
+ else:
+ x = x * w
+
+ return x
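+
+
+if __name__ == '__main__':
+    # Minimal smoke test with assumed shapes: a 64->128 channel block on a 2x64x32x32 input.
+    block = CAResBlock(64, 128)
+    out = block(torch.randn(2, 64, 32, 32))
+    print(out.shape)  # torch.Size([2, 128, 32, 32])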
diff --git a/tracker/model/cutie.py b/tracker/model/cutie.py
new file mode 100644
index 0000000000000000000000000000000000000000..82c5652a3f3d657ab71ed208cd11ca2322608d7a
--- /dev/null
+++ b/tracker/model/cutie.py
@@ -0,0 +1,249 @@
+from typing import List, Dict, Iterable
+import logging
+from omegaconf import DictConfig
+import torch
+import torch.nn as nn
+
+from tracker.model.modules import *
+from tracker.model.big_modules import *
+from tracker.model.aux_modules import AuxComputer
+from tracker.model.utils.memory_utils import *
+from tracker.model.transformer.object_transformer import QueryTransformer
+from tracker.model.transformer.object_summarizer import ObjectSummarizer
+from tracker.utils.tensor_utils import aggregate
+
+log = logging.getLogger()
+
+
+class CUTIE(nn.Module):
+ def __init__(self, cfg: DictConfig, *, single_object=False):
+ super().__init__()
+ model_cfg = cfg.model
+ self.ms_dims = model_cfg.pixel_encoder.ms_dims
+ self.key_dim = model_cfg.key_dim
+ self.value_dim = model_cfg.value_dim
+ self.sensory_dim = model_cfg.sensory_dim
+ self.pixel_dim = model_cfg.pixel_dim
+ self.embed_dim = model_cfg.embed_dim
+ self.single_object = single_object
+
+ log.info(f'Single object: {self.single_object}')
+
+ self.pixel_encoder = PixelEncoder(model_cfg)
+ self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
+ self.key_proj = KeyProjection(model_cfg)
+ self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
+ self.mask_decoder = MaskDecoder(model_cfg)
+ self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
+ self.object_transformer = QueryTransformer(model_cfg)
+ self.object_summarizer = ObjectSummarizer(model_cfg)
+ self.aux_computer = AuxComputer(cfg)
+
+ self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
+
+ def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
+ # for each object, return the sum of masks of all other objects
+ if self.single_object:
+ return None
+
+ num_objects = masks.shape[1]
+ if num_objects >= 1:
+ others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
+ else:
+ others = torch.zeros_like(masks)
+ return others
+
+ def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
+ image = (image - self.pixel_mean) / self.pixel_std
+ ms_image_feat = self.pixel_encoder(image)
+ return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
+
+ def encode_mask(
+ self,
+ image: torch.Tensor,
+ ms_features: List[torch.Tensor],
+ sensory: torch.Tensor,
+ masks: torch.Tensor,
+ *,
+ deep_update: bool = True,
+ chunk_size: int = -1,
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+ image = (image - self.pixel_mean) / self.pixel_std
+ others = self._get_others(masks)
+ mask_value, new_sensory = self.mask_encoder(image,
+ ms_features,
+ sensory,
+ masks,
+ others,
+ deep_update=deep_update,
+ chunk_size=chunk_size)
+ object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
+ return mask_value, new_sensory, object_summaries, object_logits
+
+ def transform_key(self,
+ final_pix_feat: torch.Tensor,
+ *,
+ need_sk: bool = True,
+ need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+ key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
+ return key, shrinkage, selection
+
+ # Used in training only.
+ # This step is replaced by MemoryManager in test time
+ def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
+ memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
+ msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
+ sensory: torch.Tensor, last_mask: torch.Tensor,
+ selector: torch.Tensor) -> (torch.Tensor, Dict[str, torch.Tensor]):
+ """
+ query_key : B * CK * H * W
+ query_selection : B * CK * H * W
+ memory_key : B * CK * T * H * W
+ memory_shrinkage: B * 1 * T * H * W
+ msk_value : B * num_objects * CV * T * H * W
+ obj_memory : B * num_objects * T * num_summaries * C
+ pixel_feature : B * C * H * W
+ """
+ batch_size, num_objects = msk_value.shape[:2]
+
+ # read using visual attention
+ with torch.cuda.amp.autocast(enabled=False):
+ affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
+ query_selection.float())
+
+ msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
+
+ # B * (num_objects*CV) * H * W
+ pixel_readout = readout(affinity, msk_value)
+ pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
+ *pixel_readout.shape[-2:])
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
+
+ # read from query transformer
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector)
+
+ aux_output = {
+ 'sensory': sensory,
+ 'q_logits': aux_features['logits'] if aux_features else None,
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
+ }
+
+ return mem_readout, aux_output
+
+ def pixel_fusion(self,
+ pix_feat: torch.Tensor,
+ pixel: torch.Tensor,
+ sensory: torch.Tensor,
+ last_mask: torch.Tensor,
+ *,
+ chunk_size: int = -1) -> torch.Tensor:
+ last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
+ last_others = self._get_others(last_mask)
+ fused = self.pixel_fuser(pix_feat,
+ pixel,
+ sensory,
+ last_mask,
+ last_others,
+ chunk_size=chunk_size)
+ return fused
+
+ def readout_query(self,
+ pixel_readout,
+ obj_memory,
+ *,
+ selector=None,
+ need_weights=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
+ return self.object_transformer(pixel_readout,
+ obj_memory,
+ selector=selector,
+ need_weights=need_weights)
+
+ def segment(self,
+ ms_image_feat: List[torch.Tensor],
+ memory_readout: torch.Tensor,
+ sensory: torch.Tensor,
+ *,
+ selector: bool = None,
+ chunk_size: int = -1,
+ update_sensory: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        """
+        ms_image_feat is the multi-scale feature from the key encoder, used for skip-connections
+        memory_readout is from working/long-term memory
+        sensory is the sensory memory
+        selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
+        during training.
+        """
+ sensory, logits = self.mask_decoder(ms_image_feat,
+ memory_readout,
+ sensory,
+ chunk_size=chunk_size,
+ update_sensory=update_sensory)
+
+ prob = torch.sigmoid(logits)
+ if selector is not None:
+ prob = prob * selector
+
+        # Soft aggregation over all objects (including the background)
+ logits = aggregate(prob, dim=1)
+ logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
+ prob = F.softmax(logits, dim=1)
+
+ return sensory, logits, prob
+
+ def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
+ selector: torch.Tensor) -> Dict[str, torch.Tensor]:
+ return self.aux_computer(pix_feat, aux_inputs, selector)
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
+ if not self.single_object:
+ # Map single-object weight to multi-object weight (4->5 out channels in conv1)
+ for k in list(src_dict.keys()):
+ if k == 'mask_encoder.conv1.weight':
+ if src_dict[k].shape[1] == 4:
+ log.info(f'Converting {k} from single object to multiple objects.')
+ pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
+ if not init_as_zero_if_needed:
+ nn.init.orthogonal_(pads)
+ log.info(f'Randomly initialized padding for {k}.')
+ else:
+ log.info(f'Zero-initialized padding for {k}.')
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
+ elif k == 'pixel_fuser.sensory_compress.weight':
+ if src_dict[k].shape[1] == self.sensory_dim + 1:
+ log.info(f'Converting {k} from single object to multiple objects.')
+ pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
+ if not init_as_zero_if_needed:
+ nn.init.orthogonal_(pads)
+ log.info(f'Randomly initialized padding for {k}.')
+ else:
+ log.info(f'Zero-initialized padding for {k}.')
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
+ elif self.single_object:
+ """
+ If the model is multiple-object and we are training in single-object,
+ we strip the last channel of conv1.
+ This is not supposed to happen in standard training except when users are trying to
+ finetune a trained model with single object datasets.
+ """
+            key = 'mask_encoder.conv1.weight'
+            if src_dict[key].shape[1] == 5:
+                log.warning(f'Converting {key} from multiple objects to single object. '
+                            'This is not supposed to happen in standard training.')
+                src_dict[key] = src_dict[key][:, :-1]
+
+ for k in src_dict:
+ if k not in self.state_dict():
+ log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
+ for k in self.state_dict():
+ if k not in src_dict:
+ log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
+
+ self.load_state_dict(src_dict, strict=False)
+
+ @property
+ def device(self) -> torch.device:
+ return self.pixel_mean.device
\ No newline at end of file
diff --git a/tracker/model/group_modules.py b/tracker/model/group_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9bf64c51613be619705bee2da9b7508378cbb46
--- /dev/null
+++ b/tracker/model/group_modules.py
@@ -0,0 +1,127 @@
+from typing import Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tracker.model.channel_attn import CAResBlock
+
+
+def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
+ align_corners: bool) -> torch.Tensor:
+ batch_size, num_objects = g.shape[:2]
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
+ scale_factor=ratio,
+ mode=mode,
+ align_corners=align_corners)
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+ return g
+
+
+def upsample_groups(g: torch.Tensor,
+ ratio: float = 2,
+ mode: str = 'bilinear',
+ align_corners: bool = False) -> torch.Tensor:
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+
+def downsample_groups(g: torch.Tensor,
+ ratio: float = 1 / 2,
+ mode: str = 'area',
+ align_corners: bool = None) -> torch.Tensor:
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+
+class GConv2d(nn.Conv2d):
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
+ batch_size, num_objects = g.shape[:2]
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
+ return g.view(batch_size, num_objects, *g.shape[1:])
+
+
+class GroupResBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int):
+ super().__init__()
+
+ if in_dim == out_dim:
+ self.downsample = nn.Identity()
+ else:
+ self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
+
+ self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
+ self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
+
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
+ out_g = self.conv1(F.relu(g))
+ out_g = self.conv2(F.relu(out_g))
+
+ g = self.downsample(g)
+
+ return out_g + g
+
+
+class MainToGroupDistributor(nn.Module):
+ def __init__(self,
+ x_transform: Optional[nn.Module] = None,
+ g_transform: Optional[nn.Module] = None,
+ method: str = 'cat',
+ reverse_order: bool = False):
+ super().__init__()
+
+ self.x_transform = x_transform
+ self.g_transform = g_transform
+ self.method = method
+ self.reverse_order = reverse_order
+
+ def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
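+        # Broadcast the shared (per-image) feature x across the object dimension and
+        # combine it with the per-object feature g; concatenation along channels by default.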
+ num_objects = g.shape[1]
+
+ if self.x_transform is not None:
+ x = self.x_transform(x)
+
+ if self.g_transform is not None:
+ g = self.g_transform(g)
+
+ if not skip_expand:
+ x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
+ if self.method == 'cat':
+ if self.reverse_order:
+ g = torch.cat([g, x], 2)
+ else:
+ g = torch.cat([x, g], 2)
+ elif self.method == 'add':
+ g = x + g
+ elif self.method == 'mulcat':
+ g = torch.cat([x * g, g], dim=2)
+ elif self.method == 'muladd':
+ g = x * g + g
+ else:
+ raise NotImplementedError
+
+ return g
+
+
+class GroupFeatureFusionBlock(nn.Module):
+ def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
+ super().__init__()
+
+ x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
+ g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
+
+ self.distributor = MainToGroupDistributor(x_transform=x_transform,
+ g_transform=g_transform,
+ method='add')
+ self.block1 = CAResBlock(out_dim, out_dim)
+ self.block2 = CAResBlock(out_dim, out_dim)
+
+ def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
+ batch_size, num_objects = g.shape[:2]
+
+ g = self.distributor(x, g)
+
+ g = g.flatten(start_dim=0, end_dim=1)
+
+ g = self.block1(g)
+ g = self.block2(g)
+
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+
+ return g
\ No newline at end of file
diff --git a/tracker/model/losses.py b/tracker/model/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9c7c5c3bbd7e7bd909a82aaa8a3b5c8046d7ee
--- /dev/null
+++ b/tracker/model/losses.py
@@ -0,0 +1,97 @@
+from typing import List, Dict
+from omegaconf import DictConfig
+from collections import defaultdict
+import torch
+import torch.nn.functional as F
+
+from tracker.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness
+from tracker.utils.tensor_utils import cls_to_one_hot
+
+
+@torch.jit.script
+def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
+ # logits: T*C*num_points
+ loss = F.cross_entropy(logits, soft_gt, reduction='none')
+ # sum over temporal dimension
+ return loss.sum(0).mean()
+
+
+@torch.jit.script
+def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
+ # mask: T*C*num_points
+ # soft_gt: T*C*num_points
+ # ignores the background
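+    # +1 smoothing keeps the ratio well-defined when both prediction and GT are empty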
+ mask = mask[:, 1:].flatten(start_dim=2)
+ gt = soft_gt[:, 1:].float().flatten(start_dim=2)
+ numerator = 2 * (mask * gt).sum(-1)
+ denominator = mask.sum(-1) + gt.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum(0).mean()
+
+
+class LossComputer:
+ def __init__(self, cfg: DictConfig, stage_cfg: DictConfig):
+ super().__init__()
+ self.point_supervision = stage_cfg.point_supervision
+ self.num_points = stage_cfg.train_num_points
+ self.oversample_ratio = stage_cfg.oversample_ratio
+ self.importance_sample_ratio = stage_cfg.importance_sample_ratio
+
+ self.sensory_weight = cfg.model.aux_loss.sensory.weight
+ self.query_weight = cfg.model.aux_loss.query.weight
+
+ def mask_loss(self, logits: torch.Tensor,
+ soft_gt: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+ assert self.point_supervision
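+        # Point supervision as in PointRend/Mask2Former: the loss is evaluated only at a
+        # sampled set of points, biased toward uncertain (low-margin) predictions.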
+
+ with torch.no_grad():
+ # sample point_coords
+ point_coords = get_uncertain_point_coords_with_randomness(
+ logits, lambda x: calculate_uncertainty(x), self.num_points, self.oversample_ratio,
+ self.importance_sample_ratio)
+ # get gt labels
+ point_labels = point_sample(soft_gt, point_coords, align_corners=False)
+ point_logits = point_sample(logits, point_coords, align_corners=False)
+ # point_labels and point_logits: B*C*num_points
+
+ loss_ce = ce_loss(point_logits, point_labels)
+ loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels)
+
+ return loss_ce, loss_dice
+
+ def compute(self, data: Dict[str, torch.Tensor],
+ num_objects: List[int]) -> Dict[str, torch.Tensor]:
+ batch_size, num_frames = data['rgb'].shape[:2]
+ losses = defaultdict(float)
+ t_range = range(1, num_frames)
+
+ for bi in range(batch_size):
+ logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range],
+ dim=0)
+ cls_gt = data['cls_gt'][bi, 1:] # remove gt for the first frame
+ soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])
+
+ loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
+ losses['loss_ce'] += loss_ce / batch_size
+ losses['loss_dice'] += loss_dice / batch_size
+
+ aux = [data[f'aux_{ti}'] for ti in t_range]
+ if 'sensory_logits' in aux[0]:
+ sensory_log = torch.stack(
+ [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0)
+ loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
+ losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
+ losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight
+ if 'q_logits' in aux[0]:
+ num_levels = aux[0]['q_logits'].shape[2]
+
+ for l in range(num_levels):
+ query_log = torch.stack(
+ [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0)
+ loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
+ losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight
+ losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight
+
+ losses['total_loss'] = sum(losses.values())
+
+ return losses
diff --git a/tracker/model/modules.py b/tracker/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c4170af5abfecf7b85ed7804fc390285e0194d
--- /dev/null
+++ b/tracker/model/modules.py
@@ -0,0 +1,85 @@
+from typing import List, Iterable
+import torch
+import torch.nn as nn
+
+from tracker.model.group_modules import *
+
+
+class MaskUpsampleBlock(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
+ super().__init__()
+ self.distributor = MainToGroupDistributor(method='add')
+ self.out_conv = GroupResBlock(in_dim, out_dim)
+ self.scale_factor = scale_factor
+
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
+ g = upsample_groups(in_g, ratio=self.scale_factor)
+ g = self.distributor(skip_f, g)
+ g = self.out_conv(g)
+ return g
+
+
+class DecoderFeatureProcessor(nn.Module):
+ def __init__(self, decoder_dims: List[int], out_dims: List[int]):
+ super().__init__()
+ self.transforms = nn.ModuleList([
+ nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
+ ])
+
+ def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
+ outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
+ return outputs
+
+
+# @torch.jit.script
+def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+ # h: batch_size * num_objects * hidden_dim * h * w
+ # values: batch_size * num_objects * (hidden_dim*3) * h * w
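+    # GRU-style gated update: `values` is split into a forget gate, an update gate and a
+    # candidate value. Unlike a standard GRU, the forget gate additionally scales the
+    # retained state.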
+ dim = values.shape[2] // 3
+ forget_gate = torch.sigmoid(values[:, :, :dim])
+ update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
+ new_value = torch.tanh(values[:, :, dim * 2:])
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
+ return new_h
+
+
+class SensoryUpdater(nn.Module):
+ # Used in the decoder, multi-scale feature + GRU
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
+ super().__init__()
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
+
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
+
+ with torch.cuda.amp.autocast(enabled=False):
+ g = g.float()
+ h = h.float()
+ values = self.transform(torch.cat([g, h], dim=2))
+ new_h = _recurrent_update(h, values)
+
+ return new_h
+
+
+class SensoryDeepUpdater(nn.Module):
+ def __init__(self, f_dim: int, sensory_dim: int):
+ super().__init__()
+ self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ with torch.cuda.amp.autocast(enabled=False):
+ g = g.float()
+ h = h.float()
+ values = self.transform(torch.cat([g, h], dim=2))
+ new_h = _recurrent_update(h, values)
+
+ return new_h
diff --git a/tracker/model/transformer/__init__.py b/tracker/model/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/model/transformer/object_summarizer.py b/tracker/model/transformer/object_summarizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..42ee1b5385d607f34145e25b0362678f196064a2
--- /dev/null
+++ b/tracker/model/transformer/object_summarizer.py
@@ -0,0 +1,89 @@
+from typing import List, Dict, Optional
+from omegaconf import DictConfig
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tracker.model.transformer.positional_encoding import PositionalEncoding
+
+
+# @torch.jit.script
+def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
+ logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+ # value: B*num_objects*H*W*value_dim
+ # logits: B*num_objects*H*W*num_summaries
+ # masks: B*num_objects*H*W*num_summaries: 1 if allowed
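+    # Returns a mask-gated weighted sum of the features together with the total weight
+    # ("area"); the caller divides sums by area to obtain a weighted mean.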
+ weights = logits.sigmoid() * masks
+ # B*num_objects*num_summaries*value_dim
+ sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
+ # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
+ area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
+
+ # B*num_objects*num_summaries*value_dim
+ return sums, area
+
+
+class ObjectSummarizer(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+
+ this_cfg = model_cfg.object_summarizer
+ self.value_dim = model_cfg.value_dim
+ self.embed_dim = this_cfg.embed_dim
+ self.num_summaries = this_cfg.num_summaries
+ self.add_pe = this_cfg.add_pe
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
+
+ if self.add_pe:
+ self.pos_enc = PositionalEncoding(self.embed_dim,
+ scale=self.pixel_pe_scale,
+ temperature=self.pixel_pe_temperature)
+
+ self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
+ self.feature_pred = nn.Sequential(
+ nn.Linear(self.embed_dim, self.embed_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.embed_dim, self.embed_dim),
+ )
+ self.weights_pred = nn.Sequential(
+ nn.Linear(self.embed_dim, self.embed_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.embed_dim, self.num_summaries),
+ )
+
+ def forward(self,
+ masks: torch.Tensor,
+ value: torch.Tensor,
+ need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
+ # masks: B*num_objects*(H0)*(W0)
+ # value: B*num_objects*value_dim*H*W
+ # -> B*num_objects*H*W*value_dim
+ h, w = value.shape[-2:]
+ masks = F.interpolate(masks, size=(h, w), mode='area')
+ masks = masks.unsqueeze(-1)
+ inv_masks = 1 - masks
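+        # Half of the summaries pool features inside the object mask and the other half
+        # outside it, so both foreground and background context are captured.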
+ repeated_masks = torch.cat([
+ masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
+ inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
+ ],
+ dim=-1)
+
+ value = value.permute(0, 1, 3, 4, 2)
+ value = self.input_proj(value)
+ if self.add_pe:
+ pe = self.pos_enc(value)
+ value = value + pe
+
+ with torch.cuda.amp.autocast(enabled=False):
+ value = value.float()
+ feature = self.feature_pred(value)
+ logits = self.weights_pred(value)
+ sums, area = _weighted_pooling(repeated_masks, feature, logits)
+
+ summaries = torch.cat([sums, area], dim=-1)
+
+ if need_weights:
+ return summaries, logits
+ else:
+ return summaries, None
\ No newline at end of file
diff --git a/tracker/model/transformer/object_transformer.py b/tracker/model/transformer/object_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f0830247495df161f6126dd40ea9ff7f30b9f2
--- /dev/null
+++ b/tracker/model/transformer/object_transformer.py
@@ -0,0 +1,205 @@
+from typing import Dict, Optional
+from omegaconf import DictConfig
+
+import torch
+import torch.nn as nn
+from tracker.model.group_modules import GConv2d
+from tracker.utils.tensor_utils import aggregate
+from tracker.model.transformer.positional_encoding import PositionalEncoding
+from tracker.model.transformer.transformer_layers import *
+
+
+class QueryTransformerBlock(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+
+ this_cfg = model_cfg.object_transformer
+ self.embed_dim = this_cfg.embed_dim
+ self.num_heads = this_cfg.num_heads
+ self.num_queries = this_cfg.num_queries
+ self.ff_dim = this_cfg.ff_dim
+
+ self.read_from_pixel = CrossAttention(self.embed_dim,
+ self.num_heads,
+ add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
+ self.self_attn = SelfAttention(self.embed_dim,
+ self.num_heads,
+ add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
+ self.ffn = FFN(self.embed_dim, self.ff_dim)
+ self.read_from_query = CrossAttention(self.embed_dim,
+ self.num_heads,
+ add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
+ norm=this_cfg.read_from_query.output_norm)
+ self.pixel_ffn = PixelFFN(self.embed_dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pixel: torch.Tensor,
+ query_pe: torch.Tensor,
+ pixel_pe: torch.Tensor,
+ attn_mask: torch.Tensor,
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+ # x: (bs*num_objects)*num_queries*embed_dim
+ # pixel: bs*num_objects*C*H*W
+ # query_pe: (bs*num_objects)*num_queries*embed_dim
+ # pixel_pe: (bs*num_objects)*(H*W)*C
+ # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
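+        # Object queries first read from the pixel features through masked cross-attention
+        # (plus self-attention and an FFN), then the updated queries write back into the
+        # pixel features through a second cross-attention followed by a convolutional FFN.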
+
+ # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
+ pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
+ x, q_weights = self.read_from_pixel(x,
+ pixel_flat,
+ query_pe,
+ pixel_pe,
+ attn_mask=attn_mask,
+ need_weights=need_weights)
+ x = self.self_attn(x, query_pe)
+ x = self.ffn(x)
+
+ pixel_flat, p_weights = self.read_from_query(pixel_flat,
+ x,
+ pixel_pe,
+ query_pe,
+ need_weights=need_weights)
+ pixel = self.pixel_ffn(pixel, pixel_flat)
+
+ if need_weights:
+ bs, num_objects, _, h, w = pixel.shape
+ q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
+ p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
+ self.num_queries, h, w)
+
+ return x, pixel, q_weights, p_weights
+
+
+class QueryTransformer(nn.Module):
+ def __init__(self, model_cfg: DictConfig):
+ super().__init__()
+
+ this_cfg = model_cfg.object_transformer
+ self.value_dim = model_cfg.value_dim
+ self.embed_dim = this_cfg.embed_dim
+ self.num_heads = this_cfg.num_heads
+ self.num_queries = this_cfg.num_queries
+
+ # query initialization and embedding
+ self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
+ self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
+
+ # projection from object summaries to query initialization and embedding
+ self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
+ self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
+
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
+ self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
+ self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
+ self.spatial_pe = PositionalEncoding(self.embed_dim,
+ scale=self.pixel_pe_scale,
+ temperature=self.pixel_pe_temperature,
+ channel_last=False,
+ transpose_output=True)
+
+ # transformer blocks
+ self.num_blocks = this_cfg.num_blocks
+ self.blocks = nn.ModuleList(
+ QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
+ self.mask_pred = nn.ModuleList(
+ nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
+ for _ in range(self.num_blocks + 1))
+
+ self.act = nn.ReLU(inplace=True)
+
+ def forward(self,
+ pixel: torch.Tensor,
+ obj_summaries: torch.Tensor,
+ selector: Optional[torch.Tensor] = None,
+ need_weights: bool = False) -> (torch.Tensor, Dict[str, torch.Tensor]):
+ # pixel: B*num_objects*embed_dim*H*W
+ # obj_summaries: B*num_objects*T*num_queries*embed_dim
+ T = obj_summaries.shape[2]
+ bs, num_objects, _, H, W = pixel.shape
+
+ # normalize object values
+ # the last channel is the cumulative area of the object
+ obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
+ self.embed_dim + 1)
+ # sum over time
+ # during inference, T=1 as we already did streaming average in memory_manager
+ obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
+ obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
+ obj_values = obj_sums / (obj_area + 1e-4)
+ obj_init = self.summary_to_query_init(obj_values)
+ obj_emb = self.summary_to_query_emb(obj_values)
+
+ # positional embeddings for object queries
+ query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
+ query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
+
+ # positional embeddings for pixel features
+ pixel_init = self.pixel_init_proj(pixel)
+ pixel_emb = self.pixel_emb_proj(pixel)
+ pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
+ pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
+ pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
+
+ pixel = pixel_init
+
+ # run the transformer
+ aux_features = {'logits': []}
+
+ # first aux output
+ aux_logits = self.mask_pred[0](pixel).squeeze(2)
+ attn_mask = self._get_aux_mask(aux_logits, selector)
+ aux_features['logits'].append(aux_logits)
+ for i in range(self.num_blocks):
+ query, pixel, q_weights, p_weights = self.blocks[i](query,
+ pixel,
+ query_emb,
+ pixel_pe,
+ attn_mask,
+ need_weights=need_weights)
+
+ if self.training or i <= self.num_blocks - 1 or need_weights:
+ aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
+ attn_mask = self._get_aux_mask(aux_logits, selector)
+ aux_features['logits'].append(aux_logits)
+
+ aux_features['q_weights'] = q_weights # last layer only
+ aux_features['p_weights'] = p_weights # last layer only
+
+ if self.training:
+ # no need to save all heads
+ aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
+ self.num_queries, H, W)[:, :, 0]
+
+ return pixel, aux_features
+
+ def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
+ # logits: batch_size*num_objects*H*W
+ # selector: batch_size*num_objects*1*1
+ # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
+ # where True means the attention is blocked
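+        # The first half of the queries is restricted to foreground positions and the
+        # second half to background positions; rows that would block every position are
+        # fully unblocked at the end to keep the attention softmax well-defined.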
+
+ if selector is None:
+ prob = logits.sigmoid()
+ else:
+ prob = logits.sigmoid() * selector
+ logits = aggregate(prob, dim=1)
+
+ is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
+ foreground_mask = is_foreground.bool().flatten(start_dim=2)
+ inv_foreground_mask = ~foreground_mask
+ inv_background_mask = foreground_mask
+
+ aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
+ aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
+
+ aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
+
+ aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
+
+ return aux_mask
\ No newline at end of file
diff --git a/tracker/model/transformer/positional_encoding.py b/tracker/model/transformer/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c15bb73784d3e5fcb1a5d2f9713069e7a933f34
--- /dev/null
+++ b/tracker/model/transformer/positional_encoding.py
@@ -0,0 +1,108 @@
+# Reference:
+# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
+# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
+ """
+ Gets a base embedding for one dimension with sin and cos intertwined
+ """
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
+ return torch.flatten(emb, -2, -1)
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self,
+ dim: int,
+ scale: float = math.pi * 2,
+ temperature: float = 10000,
+ normalize: bool = True,
+ channel_last: bool = True,
+ transpose_output: bool = False):
+ super().__init__()
+ dim = int(np.ceil(dim / 4) * 2)
+ self.dim = dim
+ inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = 1e-6
+ self.channel_last = channel_last
+ self.transpose_output = transpose_output
+
+ self.cached_penc = None # the cache is irrespective of the number of objects
+
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ :param tensor: A 4/5d tensor of size
+ channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
+ channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
+ :return: positional encoding tensor that has the same shape as the input if the input is 4d
+ if the input is 5d, the output is broadcastable along the k-dimension
+ """
+ if len(tensor.shape) != 4 and len(tensor.shape) != 5:
+ raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
+
+ if len(tensor.shape) == 5:
+ # take a sample from the k dimension
+ num_objects = tensor.shape[1]
+ tensor = tensor[:, 0]
+ else:
+ num_objects = None
+
+ if self.channel_last:
+ batch_size, h, w, c = tensor.shape
+ else:
+ batch_size, c, h, w = tensor.shape
+
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
+ if num_objects is None:
+ return self.cached_penc
+ else:
+ return self.cached_penc.unsqueeze(1)
+
+ self.cached_penc = None
+
+ pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
+ pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
+ if self.normalize:
+ pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
+ pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
+
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
+ emb_y = get_emb(sin_inp_y).unsqueeze(1)
+ emb_x = get_emb(sin_inp_x)
+
+ emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
+ emb[:, :, :self.dim] = emb_x
+ emb[:, :, self.dim:] = emb_y
+
+ if not self.channel_last and self.transpose_output:
+ # cancelled out
+ pass
+ elif (not self.channel_last) or (self.transpose_output):
+ emb = emb.permute(2, 0, 1)
+
+ self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+ if num_objects is None:
+ return self.cached_penc
+ else:
+ return self.cached_penc.unsqueeze(1)
+
+
+if __name__ == '__main__':
+ pe = PositionalEncoding(8).cuda()
+ input = torch.ones((1, 8, 8, 8)).cuda()
+ output = pe(input)
+ # print(output)
+ print(output[0, :, 0, 0])
+ print(output[0, :, 0, 5])
+ print(output[0, 0, :, 0])
+ print(output[0, 0, 0, :])
diff --git a/tracker/model/transformer/transformer_layers.py b/tracker/model/transformer/transformer_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..80cf522dad092e8282f43e7f0f0dc05cfd15aa9b
--- /dev/null
+++ b/tracker/model/transformer/transformer_layers.py
@@ -0,0 +1,161 @@
+# Modified from PyTorch nn.Transformer
+
+from typing import List, Callable
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+from tracker.model.channel_attn import CAResBlock
+
+
+class SelfAttention(nn.Module):
+ def __init__(self,
+ dim: int,
+ nhead: int,
+ dropout: float = 0.0,
+ batch_first: bool = True,
+ add_pe_to_qkv: List[bool] = [True, True, False]):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
+ self.norm = nn.LayerNorm(dim)
+ self.dropout = nn.Dropout(dropout)
+ self.add_pe_to_qkv = add_pe_to_qkv
+
+ def forward(self,
+ x: torch.Tensor,
+ pe: torch.Tensor,
+ attn_mask: bool = None,
+ key_padding_mask: bool = None) -> torch.Tensor:
+ x = self.norm(x)
+ if any(self.add_pe_to_qkv):
+ x_with_pe = x + pe
+ q = x_with_pe if self.add_pe_to_qkv[0] else x
+ k = x_with_pe if self.add_pe_to_qkv[1] else x
+ v = x_with_pe if self.add_pe_to_qkv[2] else x
+ else:
+ q = k = v = x
+
+ r = x
+ x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
+ return r + self.dropout(x)
+
+
+# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
+class CrossAttention(nn.Module):
+ def __init__(self,
+ dim: int,
+ nhead: int,
+ dropout: float = 0.0,
+ batch_first: bool = True,
+ add_pe_to_qkv: List[bool] = [True, True, False],
+ residual: bool = True,
+ norm: bool = True):
+ super().__init__()
+ self.cross_attn = nn.MultiheadAttention(dim,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first)
+ if norm:
+ self.norm = nn.LayerNorm(dim)
+ else:
+ self.norm = nn.Identity()
+ self.dropout = nn.Dropout(dropout)
+ self.add_pe_to_qkv = add_pe_to_qkv
+ self.residual = residual
+
+ def forward(self,
+ x: torch.Tensor,
+ mem: torch.Tensor,
+ x_pe: torch.Tensor,
+ mem_pe: torch.Tensor,
+ attn_mask: bool = None,
+ *,
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
+ x = self.norm(x)
+ if self.add_pe_to_qkv[0]:
+ q = x + x_pe
+ else:
+ q = x
+
+ if any(self.add_pe_to_qkv[1:]):
+ mem_with_pe = mem + mem_pe
+ k = mem_with_pe if self.add_pe_to_qkv[1] else mem
+ v = mem_with_pe if self.add_pe_to_qkv[2] else mem
+ else:
+ k = v = mem
+ r = x
+ x, weights = self.cross_attn(q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ need_weights=need_weights,
+ average_attn_weights=False)
+
+ if self.residual:
+ return r + self.dropout(x), weights
+ else:
+ return self.dropout(x), weights
+
+
+class FFN(nn.Module):
+ def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
+ super().__init__()
+ self.linear1 = nn.Linear(dim_in, dim_ff)
+ self.linear2 = nn.Linear(dim_ff, dim_in)
+ self.norm = nn.LayerNorm(dim_in)
+
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ r = x
+ x = self.norm(x)
+ x = self.linear2(self.activation(self.linear1(x)))
+ x = r + x
+ return x
+
+
+class PixelFFN(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+ self.conv = CAResBlock(dim, dim)
+
+ def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
+ # pixel: batch_size * num_objects * dim * H * W
+ # pixel_flat: (batch_size*num_objects) * (H*W) * dim
+ bs, num_objects, _, h, w = pixel.shape
+ pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
+ pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
+
+ x = self.conv(pixel_flat)
+ x = x.view(bs, num_objects, self.dim, h, w)
+ return x
+
+
+class OutputFFN(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
+ super().__init__()
+ self.linear1 = nn.Linear(dim_in, dim_out)
+ self.linear2 = nn.Linear(dim_out, dim_out)
+
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ else:
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear2(self.activation(self.linear1(x)))
+ return x
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return F.gelu
+
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
diff --git a/tracker/model/utils/__init__.py b/tracker/model/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/model/utils/memory_utils.py b/tracker/model/utils/memory_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b8726ce6bca51cbd76fff814db9cf03544397cd
--- /dev/null
+++ b/tracker/model/utils/memory_utils.py
@@ -0,0 +1,95 @@
+import math
+import torch
+from typing import Optional, Union, Tuple
+
+
+# @torch.jit.script
+def get_similarity(mk: torch.Tensor,
+ ms: torch.Tensor,
+ qk: torch.Tensor,
+ qe: torch.Tensor,
+ add_batch_dim: bool = False) -> torch.Tensor:
+ # used for training/inference and memory reading/memory potentiation
+ # mk: B x CK x [N] - Memory keys
+ # ms: B x 1 x [N] - Memory shrinkage
+ # qk: B x CK x [HW/P] - Query keys
+ # qe: B x CK x [HW/P] - Query selection
+ # Dimensions in [] are flattened
+ if add_batch_dim:
+ mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
+ qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
+
+ CK = mk.shape[1]
+ mk = mk.flatten(start_dim=2)
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
+ qk = qk.flatten(start_dim=2)
+ qe = qe.flatten(start_dim=2) if qe is not None else None
+
+ if qe is not None:
+ # See XMem's appendix for derivation
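+        # Expands the selection-weighted negative squared L2 distance between memory and
+        # query keys: -(a - b)^2 = -a^2 + 2ab - b^2, with qe weighting each channel.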
+ mk = mk.transpose(1, 2)
+ a_sq = (mk.pow(2) @ qe)
+ two_ab = 2 * (mk @ (qk * qe))
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
+ similarity = (-a_sq + two_ab - b_sq)
+ else:
+ # similar to STCN if we don't have the selection term
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
+ similarity = (-a_sq + two_ab)
+
+ if ms is not None:
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
+ else:
+ similarity = similarity / math.sqrt(CK) # B*N*HW
+
+ return similarity
+
+
+def do_softmax(
+ similarity: torch.Tensor,
+ top_k: Optional[int] = None,
+ inplace: bool = False,
+ return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ # normalize similarity with top-k softmax
+ # similarity: B x N x [HW/P]
+ # use inplace with care
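+    # With top_k set, only the k largest similarities per query position get non-zero
+    # (softmax-normalized) affinity; all other memory elements are zeroed out.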
+ if top_k is not None:
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
+
+ x_exp = values.exp_()
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
+ if inplace:
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
+ affinity = similarity
+ else:
+ affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
+ else:
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
+ x_exp = torch.exp(similarity - maxes)
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
+ affinity = x_exp / x_exp_sum
+ indices = None
+
+ if return_usage:
+ return affinity, affinity.sum(dim=2)
+
+ return affinity
+
+
+def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
+ qe: torch.Tensor) -> torch.Tensor:
+ # shorthand used in training with no top-k
+ similarity = get_similarity(mk, ms, qk, qe)
+ affinity = do_softmax(similarity)
+ return affinity
+
+
+def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
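+    # Weighted sum of memory values: mem[b, :, p] = sum_n affinity[b, n, p] * mv[b, :, n],
+    # where n runs over all T*H*W memory positions.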
+ B, CV, T, H, W = mv.shape
+
+ mo = mv.view(B, CV, T * H * W)
+ mem = torch.bmm(mo, affinity)
+ mem = mem.view(B, CV, H, W)
+
+ return mem
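+
+
+if __name__ == '__main__':
+    # Minimal shape check with assumed dimensions: 64-dim keys, 256-dim values,
+    # 3 memory frames of 16x16 features, and a single 16x16 query frame.
+    B, CK, CV, T, H, W = 1, 64, 256, 3, 16, 16
+    mk = torch.randn(B, CK, T, H, W)
+    ms = torch.rand(B, 1, T, H, W) + 1
+    qk = torch.randn(B, CK, H, W)
+    qe = torch.rand(B, CK, H, W)
+    mv = torch.randn(B, CV, T, H, W)
+    affinity = get_affinity(mk, ms, qk, qe)
+    print(affinity.shape)  # B x (T*H*W) x (H*W)
+    print(readout(affinity, mv).shape)  # B x CV x H x W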
diff --git a/tracker/model/utils/parameter_groups.py b/tracker/model/utils/parameter_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..177866af48de5e6d8795bdf6734b0dccb5a1947b
--- /dev/null
+++ b/tracker/model/utils/parameter_groups.py
@@ -0,0 +1,72 @@
+import logging
+
+log = logging.getLogger()
+
+
+def get_parameter_groups(model, stage_cfg, print_log=False):
+ """
+ Assign different weight decays and learning rates to different parameters.
+ Returns a parameter group which can be passed to the optimizer.
+ """
+ weight_decay = stage_cfg.weight_decay
+ embed_weight_decay = stage_cfg.embed_weight_decay
+ backbone_lr_ratio = stage_cfg.backbone_lr_ratio
+ base_lr = stage_cfg.learning_rate
+
+ backbone_params = []
+ embed_params = []
+ other_params = []
+
+ embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
+ embedding_names = [e + '.weight' for e in embedding_names]
+
+ # inspired by detectron2
+ memo = set()
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if param in memo:
+ continue
+ memo.add(param)
+
+ if name.startswith('module'):
+ name = name[7:]
+
+ inserted = False
+ if name.startswith('pixel_encoder.'):
+ backbone_params.append(param)
+ inserted = True
+ if print_log:
+ log.info(f'{name} counted as a backbone parameter.')
+ else:
+ for e in embedding_names:
+ if name.endswith(e):
+ embed_params.append(param)
+ inserted = True
+ if print_log:
+ log.info(f'{name} counted as an embedding parameter.')
+ break
+
+ if not inserted:
+ other_params.append(param)
+
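+    # Three groups: backbone parameters at a reduced learning rate, embedding-like
+    # parameters with their own weight decay, and everything else at the base settings.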
+ parameter_groups = [
+ {
+ 'params': backbone_params,
+ 'lr': base_lr * backbone_lr_ratio,
+ 'weight_decay': weight_decay
+ },
+ {
+ 'params': embed_params,
+ 'lr': base_lr,
+ 'weight_decay': embed_weight_decay
+ },
+ {
+ 'params': other_params,
+ 'lr': base_lr,
+ 'weight_decay': weight_decay
+ },
+ ]
+
+ return parameter_groups
\ No newline at end of file
diff --git a/tracker/model/utils/resnet.py b/tracker/model/utils/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a07d2fd12da0b951ed6c724f97aa2f203877e7e
--- /dev/null
+++ b/tracker/model/utils/resnet.py
@@ -0,0 +1,179 @@
+"""
+resnet.py - A modified ResNet structure
+We append extra channels to the first conv by some network surgery
+"""
+
+from collections import OrderedDict
+import math
+
+import torch
+import torch.nn as nn
+from torch.utils import model_zoo
+
+
+def load_weights_add_extra_dim(target, source_state, extra_dim=1):
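+    # Copy pretrained weights into `target`, widening the first conv to accept extra_dim
+    # additional input channels; the new channels are orthogonally initialized.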
+ new_dict = OrderedDict()
+
+ for k1, v1 in target.state_dict().items():
+ if not 'num_batches_tracked' in k1:
+ if k1 in source_state:
+ tar_v = source_state[k1]
+
+ if v1.shape != tar_v.shape:
+                    # The first conv gained extra input channels; initialize the new slice orthogonally
+ c, _, w, h = v1.shape
+ pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
+ nn.init.orthogonal_(pads)
+ tar_v = torch.cat([tar_v, pads], 1)
+
+ new_dict[k1] = tar_v
+
+ target.load_state_dict(new_dict)
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ dilation=dilation,
+ padding=dilation,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+
+def resnet18(pretrained=True, extra_dim=0):
+ model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
+ return model
+
+
+def resnet50(pretrained=True, extra_dim=0):
+ model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
+ return model
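+
+
+# Example (sketch): resnet50(pretrained=False, extra_dim=1) builds a backbone whose conv1
+# takes 3 + 1 = 4 input channels; with pretrained=True the extra conv1 channel is
+# orthogonally initialized while the RGB channels keep the ImageNet weights
+# (see load_weights_add_extra_dim above).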
diff --git a/tracker/utils/__init__.py b/tracker/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tracker/utils/image_saver.py b/tracker/utils/image_saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3edfa96e60fea0e5ec8fd087da85d2efaa6444c
--- /dev/null
+++ b/tracker/utils/image_saver.py
@@ -0,0 +1,230 @@
+import cv2
+import numpy as np
+
+import torch
+from collections import defaultdict
+
+
+def tensor_to_numpy(image):
+ image_np = (image.numpy() * 255).astype('uint8')
+ return image_np
+
+
+def tensor_to_np_float(image):
+ image_np = image.numpy().astype('float32')
+ return image_np
+
+
+def detach_to_cpu(x):
+ return x.detach().cpu()
+
+
+def transpose_np(x):
+ return np.transpose(x, [1, 2, 0])
+
+
+def tensor_to_gray_im(x):
+ x = detach_to_cpu(x)
+ x = tensor_to_numpy(x)
+ x = transpose_np(x)
+ return x
+
+
+def tensor_to_im(x):
+ x = detach_to_cpu(x).clamp(0, 1)
+ x = tensor_to_numpy(x)
+ x = transpose_np(x)
+ return x
+
+
+# Predefined key <-> caption dict
+key_captions = {
+ 'im': 'Image',
+ 'gt': 'GT',
+}
+"""
+Return an image array with captions
+keys in dictionary will be used as caption if not provided
+values should contain lists of cv2 images
+"""
+
+
+def get_image_array(images, grid_shape, captions={}):
+ h, w = grid_shape
+ cate_counts = len(images)
+ rows_counts = len(next(iter(images.values())))
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+
+ output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8)
+ col_cnt = 0
+ for k, v in images.items():
+
+ # Default as key value itself
+ caption = captions.get(k, k)
+
+ # Handles new line character
+ dy = 40
+ for i, line in enumerate(caption.split('\n')):
+ cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy), font, 0.8,
+ (255, 255, 255), 2, cv2.LINE_AA)
+
+ # Put images
+ for row_cnt, img in enumerate(v):
+ im_shape = img.shape
+ if len(im_shape) == 2:
+ img = img[..., np.newaxis]
+
+ img = (img * 255).astype('uint8')
+
+ output_image[(col_cnt + 0) * w:(col_cnt + 1) * w,
+ (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img
+
+ col_cnt += 1
+
+ return output_image
+
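+# Example (sketch): two rows ('RGB', 'GT') of three 384x384 float frames in [0, 1]:
+#   arr = get_image_array({'RGB': rgb_list, 'GT': gt_list}, (384, 384), key_captions)
+#   # -> uint8 array of shape (2*384, 4*384, 3); the first column is reserved for captions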
+
+def base_transform(im, size):
+ im = tensor_to_np_float(im)
+ if len(im.shape) == 3:
+ im = im.transpose((1, 2, 0))
+ else:
+ im = im[:, :, None]
+
+ # Resize
+ if im.shape[1] != size:
+ im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
+
+ return im.clip(0, 1)
+
+
+def im_transform(im, size):
+ return base_transform(detach_to_cpu(im), size=size)
+
+
+def mask_transform(mask, size):
+ return base_transform(detach_to_cpu(mask), size=size)
+
+
+def logits_transform(mask, size):
+ return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
+
+
+def add_attention(mask, pos):
+ mask = mask[:, :, None].repeat(3, axis=2)
+ pos = (pos + 1) / 2
+ for i in range(pos.shape[0]):
+ y = int(pos[i][0] * mask.shape[0])
+ x = int(pos[i][1] * mask.shape[1])
+ y = max(min(y, mask.shape[0] - 1), 0)
+ x = max(min(x, mask.shape[1] - 1), 0)
+ # mask[y, x, :] = (255, 0, 0)
+ cv2.circle(mask, (x, y), 5, (1, 0, 0), -1)
+ return mask
+
+
+def vis(images, size, num_objects):
+ req_images = defaultdict(list)
+
+ b, t = images['rgb'].shape[:2]
+
+ # limit the number of images saved
+ b = min(2, b)
+
+ # find max num objects
+ max_num_objects = max(num_objects[:b])
+
+ GT_suffix = ''
+ for bi in range(b):
+ GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
+
+ for bi in range(b):
+ for ti in range(t):
+ req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size))
+ aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape
+ if 'sensory_logits' in aux:
+ sensory_aux = aux['sensory_logits'][bi].softmax(dim=0)
+ # batch_size * num_objects * num_levels * H * W
+ q_mask_aux = aux['q_logits'][bi].softmax(dim=0)
+ num_levels = q_mask_aux.shape[1]
+
+ for oi in range(max_num_objects):
+ if ti == 0 or oi >= num_objects[bi]:
+ req_images[f'Mask_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ req_images[f'S-Aux_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ for l in range(num_levels):
+ req_images[f'Q-Aux-L{l}_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ else:
+ mask = mask_transform(images[f'masks_{ti}'][bi][oi], size)
+ req_images[f'Mask_{oi}'].append(mask)
+ if 'sensory_logits' in aux:
+ req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size))
+
+ for l in range(num_levels):
+ mask = mask_transform(q_mask_aux[oi + 1, l], size)
+ req_images[f'Q-Aux-L{l}_{oi}'].append(mask)
+
+ req_images[f'GT_{oi}_{GT_suffix}'].append(
+ mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size))
+
+ return get_image_array(req_images, size, key_captions)
+
+
+def vis_debug(images, size, num_objects):
+ req_images = defaultdict(list)
+
+ b, t = images['rgb'].shape[:2]
+
+ # limit the number of images saved
+ b = min(2, b)
+
+ # find max num objects
+ max_num_objects = max(num_objects[:b])
+
+ GT_suffix = ''
+ for bi in range(b):
+ GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
+
+ for bi in range(b):
+ for ti in range(t):
+ req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size))
+ aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape
+ sensory_aux = aux['sensory_logits'][bi].softmax(dim=0)
+ # batch_size * num_objects * num_levels * H * W
+ q_mask_aux = aux['q_logits'][bi].softmax(dim=0)
+ attn_mask = aux['attn_mask'][bi]
+ num_levels = q_mask_aux.shape[1]
+ num_queries = attn_mask.shape[1]
+
+ for oi in range(max_num_objects):
+ if ti == 0 or oi >= num_objects[bi]:
+ req_images[f'Mask_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ req_images[f'S-Aux_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ for l in range(num_levels):
+ req_images[f'Q-Aux-L{l}_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ for q in range(num_queries):
+ req_images[f'Attn-Mask-Q{q}_{oi}'].append(
+ mask_transform(images['first_frame_gt'][bi][0, oi], size))
+ else:
+ mask = mask_transform(images[f'masks_{ti}'][bi][oi], size)
+ req_images[f'Mask_{oi}'].append(mask)
+ req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size))
+
+ for l in range(num_levels):
+ mask = mask_transform(q_mask_aux[oi + 1, l], size)
+ req_images[f'Q-Aux-L{l}_{oi}'].append(mask)
+ for q in range(num_queries):
+ mask = mask_transform(1 - attn_mask[oi, q].float(), size)
+ req_images[f'Attn-Mask-Q{q}_{oi}'].append(mask)
+
+ req_images[f'GT_{oi}_{GT_suffix}'].append(
+ mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size))
+
+ return get_image_array(req_images, size, key_captions)
\ No newline at end of file
diff --git a/tracker/utils/load_subset.py b/tracker/utils/load_subset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c16ed0391ae745a736290bb7b956c98539e087ca
--- /dev/null
+++ b/tracker/utils/load_subset.py
@@ -0,0 +1,13 @@
+import json
+
+
+def load_subset(path):
+ with open(path, mode='r') as f:
+ subset = set(f.read().splitlines())
+ return subset
+
+
+def load_empty_masks(path):
+ with open(path, mode='r') as f:
+ empty_masks = json.load(f)
+ return empty_masks
diff --git a/tracker/utils/log_integrator.py b/tracker/utils/log_integrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2e12ebb393cfd6b734859fad3c243850ba60ea7
--- /dev/null
+++ b/tracker/utils/log_integrator.py
@@ -0,0 +1,84 @@
+"""
+Integrate numerical values for some iterations
+Typically used for loss computation / logging to tensorboard
+Call finalize and create a new Integrator when you want to display/log
+"""
+from typing import Dict, Callable, Tuple
+import torch
+from tracker.utils.logger import TensorboardLogger
+
+
+class Integrator:
+ def __init__(self, logger: TensorboardLogger, distributed: bool = True):
+ self.values = {}
+ self.counts = {}
+ self.hooks = [] # List is used here to maintain insertion order
+
+ self.logger = logger
+
+ self.distributed = distributed
+ self.local_rank = torch.distributed.get_rank()
+ self.world_size = torch.distributed.get_world_size()
+
+ def add_tensor(self, key: str, tensor: torch.Tensor):
+ if key not in self.values:
+ self.counts[key] = 1
+ if type(tensor) == float or type(tensor) == int:
+ self.values[key] = tensor
+ else:
+ self.values[key] = tensor.mean().item()
+ else:
+ self.counts[key] += 1
+ if type(tensor) == float or type(tensor) == int:
+ self.values[key] += tensor
+ else:
+ self.values[key] += tensor.mean().item()
+
+ def add_dict(self, tensor_dict: Dict[str, torch.Tensor]):
+ for k, v in tensor_dict.items():
+ self.add_tensor(k, v)
+
+ def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]):
+ """
+ Adds a custom hook, i.e. compute new metrics using values in the dict
+ The hook takes the dict as argument, and returns a (k, v) tuple
+ e.g. for computing IoU
+ """
+ if type(hook) == list:
+ self.hooks.extend(hook)
+ else:
+ self.hooks.append(hook)
+
+ def reset_except_hooks(self):
+ self.values = {}
+ self.counts = {}
+
+ # Average and output the metrics
+ def finalize(self, exp_id: str, prefix: str, it: int) -> None:
+
+ for hook in self.hooks:
+ k, v = hook(self.values)
+ self.add_tensor(k, v)
+
+ outputs = {}
+ for k, v in self.values.items():
+
+ if k[:4] == 'hide':
+ continue
+
+ avg = v / self.counts[k]
+
+ if self.distributed:
+ # Inplace operation
+ avg = torch.tensor(avg).cuda()
+ torch.distributed.reduce(avg, dst=0)
+
+ if self.local_rank == 0:
+ avg = (avg / self.world_size).cpu().item()
+ outputs[k] = avg
+ else:
+ # Simple does it
+ outputs[k] = avg
+
+ if (not self.distributed) or (self.local_rank == 0):
+ self.logger.log_metrics(exp_id, prefix, outputs, it)
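+
+
+# Minimal usage sketch (assumes torch.distributed is initialized, since the constructor
+# queries the rank and world size; names are placeholders):
+#   integrator = Integrator(tb_logger, distributed=True)
+#   integrator.add_dict({'total_loss': loss})    # every training iteration
+#   integrator.finalize('exp0', 'train', it)     # average, reduce to rank 0 and log
+#   integrator.reset_except_hooks()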
diff --git a/tracker/utils/logger.py b/tracker/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..399f76d685d2f0fa8a68b69222c1400815a0a2e3
--- /dev/null
+++ b/tracker/utils/logger.py
@@ -0,0 +1,107 @@
+"""
+Dumps things to tensorboard and console
+"""
+
+import os
+import logging
+import datetime
+from typing import Dict
+import numpy as np
+from PIL import Image
+
+from torch.utils.tensorboard import SummaryWriter
+from tracker.utils.time_estimator import TimeEstimator
+
+
+def tensor_to_numpy(image):
+ image_np = (image.numpy() * 255).astype('uint8')
+ return image_np
+
+
+def detach_to_cpu(x):
+ return x.detach().cpu()
+
+
+def fix_width_trunc(x):
+ return ('{:.9s}'.format('{:0.9f}'.format(x)))
+
+
+class TensorboardLogger:
+ def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb):
+ self.run_dir = run_dir
+ self.py_log = py_logger
+ if enabled_tb:
+ self.tb_log = SummaryWriter(run_dir)
+ else:
+ self.tb_log = None
+
+ # Get current git info for logging
+ try:
+ import git
+ repo = git.Repo(".")
+ git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
+ except (ImportError, RuntimeError):
+ print('Failed to fetch git info. Defaulting to None')
+ git_info = 'None'
+
+ self.log_string('git', git_info)
+
+ # used when logging metrics
+ self.time_estimator: TimeEstimator = None
+
+ def log_scalar(self, tag, x, it):
+ if self.tb_log is None:
+ return
+ self.tb_log.add_scalar(tag, x, it)
+
+ def log_metrics(self, exp_id, prefix, metrics: Dict, it):
+ msg = f'{exp_id}-{prefix} - it {it:6d}: '
+ metrics_msg = ''
+ for k, v in sorted(metrics.items()):
+ self.log_scalar(f'{prefix}/{k}', v, it)
+ metrics_msg += f'{k: >10}:{v:.7f},\t'
+
+ if self.time_estimator is not None:
+ self.time_estimator.update()
+ avg_time = self.time_estimator.get_and_reset_avg_time()
+ est = self.time_estimator.get_est_remaining(it)
+ est = datetime.timedelta(seconds=est)
+ if est.days > 0:
+ remaining_str = f'{est.days}d {est.seconds // 3600}h'
+ else:
+ remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m'
+ eta = datetime.datetime.now() + est
+ eta_str = eta.strftime('%Y-%m-%d %H:%M:%S')
+ time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
+ msg = f'{msg} {time_msg}'
+
+ msg = f'{msg} {metrics_msg}'
+ self.py_log.info(msg)
+
+ def log_image(self, stage_name, tag, image, it):
+ image_dir = os.path.join(self.run_dir, f'{stage_name}_images')
+ os.makedirs(image_dir, exist_ok=True)
+
+ image = Image.fromarray(image)
+ image.save(os.path.join(image_dir, f'{tag}_{it}.png'))
+
+ def log_string(self, tag, x):
+ self.py_log.info(f'{tag} - {x}')
+ if self.tb_log is None:
+ return
+ self.tb_log.add_text(tag, x)
+
+ def debug(self, x):
+ self.py_log.debug(x)
+
+ def info(self, x):
+ self.py_log.info(x)
+
+ def warning(self, x):
+ self.py_log.warning(x)
+
+ def error(self, x):
+ self.py_log.error(x)
+
+ def critical(self, x):
+ self.py_log.critical(x)
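+
+
+# Minimal usage sketch (run directory and names are placeholders):
+#   py_logger = logging.getLogger('rose')
+#   logger = TensorboardLogger('runs/exp0', py_logger, enabled_tb=True)
+#   logger.log_scalar('train/loss', 0.5, it)
+#   logger.log_metrics('exp0', 'train', {'total_loss': 0.5}, it)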
diff --git a/tracker/utils/mask_mapper.py b/tracker/utils/mask_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..617af0c712d20f429a05274438b77a5afc88d2db
--- /dev/null
+++ b/tracker/utils/mask_mapper.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+
+def all_to_onehot(masks, labels):
+ if len(masks.shape) == 3:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
+ else:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
+
+ for ni, l in enumerate(labels):
+ Ms[ni] = (masks == l).astype(np.uint8)
+
+ return Ms
+
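+# Example (sketch): for an H*W index mask containing labels {1, 3},
+# all_to_onehot(mask, [1, 3]) returns a (2, H, W) uint8 array with
+# channel 0 == (mask == 1) and channel 1 == (mask == 3).
+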
+class MaskMapper:
+ """
+    This class is used to convert an indexed mask to a one-hot representation.
+    It also takes care of remapping non-contiguous indices.
+ It has two modes:
+ 1. Default. Only masks with new indices are supposed to go into the remapper.
+ This is also the case for YouTubeVOS.
+ i.e., regions with index 0 are not "background", but "don't care".
+
+ 2. Exhaustive. Regions with index 0 are considered "background".
+ Every single pixel is considered to be "labeled".
+ """
+ def __init__(self):
+ self.labels = []
+ self.remappings = {}
+
+ # if coherent, no mapping is required
+ self.coherent = True
+
+ def clear_labels(self):
+ self.labels = []
+ self.remappings = {}
+ # if coherent, no mapping is required
+ self.coherent = True
+
+ def convert_mask(self, mask, exhaustive=False):
+ # mask is in index representation, H*W numpy array
+ labels = np.unique(mask).astype(np.uint8)
+ labels = labels[labels!=0].tolist()
+
+ new_labels = list(set(labels) - set(self.labels))
+ if not exhaustive:
+ assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
+
+ # add new remappings
+ for i, l in enumerate(new_labels):
+ self.remappings[l] = i+len(self.labels)+1
+ if self.coherent and i+len(self.labels)+1 != l:
+ self.coherent = False
+
+ if exhaustive:
+ new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
+ else:
+ if self.coherent:
+ new_mapped_labels = new_labels
+ else:
+ new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
+
+ self.labels.extend(new_labels)
+ # mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
+        mask = torch.from_numpy(mask).float()
+        # mask is returned as an H*W index map here (the one-hot conversion above is disabled)
+ return mask, new_mapped_labels
+
+
+ def remap_index_mask(self, mask):
+ # mask is in index representation, H*W numpy array
+ if self.coherent:
+ return mask
+
+ new_mask = np.zeros_like(mask)
+ for l, i in self.remappings.items():
+ new_mask[mask==i] = l
+ return new_mask
\ No newline at end of file
diff --git a/tracker/utils/palette.py b/tracker/utils/palette.py
new file mode 100644
index 0000000000000000000000000000000000000000..26a773c88bdd15fdb372fa9f552602a751625fc4
--- /dev/null
+++ b/tracker/utils/palette.py
@@ -0,0 +1,9 @@
+import numpy as np
+
+davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
+
+youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
+
+davis_palette_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3)
+
+youtube_palette_np = np.frombuffer(youtube_palette, dtype=np.uint8).reshape(-1, 3)
\ No newline at end of file
diff --git a/tracker/utils/pano_utils.py b/tracker/utils/pano_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0e2fe9448ec977f0f177c2b5dc9aaccf38250d
--- /dev/null
+++ b/tracker/utils/pano_utils.py
@@ -0,0 +1,30 @@
+import numpy as np
+from threading import Lock
+
+
+class ID2RGBConverter:
+ def __init__(self):
+ self.all_id = []
+ self.obj_to_id = {}
+ self.lock = Lock()
+
+ def _id_to_rgb(self, id: int):
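+        # Encode an integer id as little-endian base-256 RGB, e.g. 300 -> [44, 1, 0]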
+ rgb = np.zeros((3, ), dtype=np.uint8)
+ for i in range(3):
+ rgb[i] = id % 256
+ id = id // 256
+ return rgb
+
+ def convert(self, obj: int):
+ with self.lock:
+ if obj in self.obj_to_id:
+ id = self.obj_to_id[obj]
+ else:
+ while True:
+ id = np.random.randint(255, 256**3)
+ if id not in self.all_id:
+ break
+ self.obj_to_id[obj] = id
+ self.all_id.append(id)
+
+ return id, self._id_to_rgb(id)
diff --git a/tracker/utils/point_features.py b/tracker/utils/point_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b794ef23c856bba022215c84581ba38e6d030b
--- /dev/null
+++ b/tracker/utils/point_features.py
@@ -0,0 +1,111 @@
+# This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
+# such that users do not need to install detectron2 just for these two functions
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from typing import List
+import torch
+from torch.nn import functional as F
+
+
+def cat(tensors: List[torch.Tensor], dim: int = 0):
+ """
+ Efficient version of torch.cat that avoids a copy if there is only a single element in a list
+ """
+ assert isinstance(tensors, (list, tuple))
+ if len(tensors) == 1:
+ return tensors[0]
+ return torch.cat(tensors, dim)
+
+
+def calculate_uncertainty(sem_seg_logits):
+ """
+    For each location of the prediction `sem_seg_logits` we estimate the uncertainty as the
+    difference between the top first and top second predicted logits.
+    Args:
+        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
+ C is the number of foreground classes. The values are logits.
+ Returns:
+ scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
+ the most uncertain locations having the highest uncertainty score.
+ """
+ if sem_seg_logits.shape[1] == 2:
+ # binary segmentation
+ return -(torch.abs(sem_seg_logits[:, 1:2]))
+ top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
+ return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+def point_sample(input, point_coords, **kwargs):
+ """
+ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
+ Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
+ [0, 1] x [0, 1] square.
+ Args:
+ input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
+ point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
+ [0, 1] x [0, 1] normalized point coordinates.
+ Returns:
+ output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
+ features for points in `point_coords`. The features are obtained via bilinear
+ interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
+ """
+ add_dim = False
+ if point_coords.dim() == 3:
+ add_dim = True
+ point_coords = point_coords.unsqueeze(2)
+ output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
+ if add_dim:
+ output = output.squeeze(3)
+ return output
+
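+# Example (sketch): with `input` of shape (N, C, H, W) and `point_coords` of shape
+# (N, P, 2) holding normalized coordinates in [0, 1], point_sample returns features of
+# shape (N, C, P) obtained by bilinear interpolation.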
+
+def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points,
+ oversample_ratio, importance_sample_ratio):
+ """
+ Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
+ are calculated for each point using 'uncertainty_func' function that takes point's logit
+ prediction as input.
+ See PointRend paper for details.
+ Args:
+ coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
+ class-specific or class-agnostic prediction.
+ uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
+ contains logit predictions for P points and returns their uncertainties as a Tensor of
+ shape (N, 1, P).
+ num_points (int): The number of points P to sample.
+ oversample_ratio (int): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
+ Returns:
+ point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
+ sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
+ num_boxes = coarse_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
+ point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
+ # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
+ # Calculating uncertainties of the coarse predictions first and sampling them for points leads
+ # to incorrect results.
+ # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
+ # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
+ # However, if we calculate uncertainties for the coarse predictions first,
+ # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points,
+ 2)
+ if num_random_points > 0:
+ point_coords = cat(
+ [
+ point_coords,
+ torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
+ ],
+ dim=1,
+ )
+ return point_coords
\ No newline at end of file
diff --git a/tracker/utils/range_transform.py b/tracker/utils/range_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb
--- /dev/null
+++ b/tracker/utils/range_transform.py
@@ -0,0 +1,12 @@
+import torchvision.transforms as transforms
+
+im_mean = (124, 116, 104)
+
+im_normalization = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+
+inv_im_trans = transforms.Normalize(
+ mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+ std=[1/0.229, 1/0.224, 1/0.225])
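+
+# Note: `inv_im_trans` inverts `im_normalization`, i.e. inv_im_trans(im_normalization(x))
+# recovers x (up to floating point error) for an RGB tensor x in [0, 1].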
diff --git a/tracker/utils/tensor_utils.py b/tracker/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e33c8936134ec4efdcc412945469853edc6498
--- /dev/null
+++ b/tracker/utils/tensor_utils.py
@@ -0,0 +1,62 @@
+from typing import List, Iterable
+import torch
+import torch.nn.functional as F
+
+
+# STM
+def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
+ h, w = in_img.shape[-2:]
+
+ if h % d > 0:
+ new_h = h + d - h % d
+ else:
+ new_h = h
+ if w % d > 0:
+ new_w = w + d - w % d
+ else:
+ new_w = w
+ lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
+ lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
+ out = F.pad(in_img, pad_array)
+ return out, pad_array
+
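+# Example (sketch): with d=16 a (1, 3, 30, 50) tensor is padded to (1, 3, 32, 64)
+# and pad_array == (7, 7, 1, 1); `unpad` below reverses the padding exactly.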
+
+def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
+ if len(img.shape) == 4:
+ if pad[2] + pad[3] > 0:
+ img = img[:, :, pad[2]:-pad[3], :]
+ if pad[0] + pad[1] > 0:
+ img = img[:, :, :, pad[0]:-pad[1]]
+ elif len(img.shape) == 3:
+ if pad[2] + pad[3] > 0:
+ img = img[:, pad[2]:-pad[3], :]
+ if pad[0] + pad[1] > 0:
+ img = img[:, :, pad[0]:-pad[1]]
+ elif len(img.shape) == 5:
+ if pad[2] + pad[3] > 0:
+ img = img[:, :, :, pad[2]:-pad[3], :]
+ if pad[0] + pad[1] > 0:
+ img = img[:, :, :, :, pad[0]:-pad[1]]
+ else:
+ raise NotImplementedError
+ return img
+
+
+# @torch.jit.script
+def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
+ with torch.cuda.amp.autocast(enabled=False):
+ prob = prob.float()
+ new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
+ dim).clamp(1e-7, 1 - 1e-7)
+ logits = torch.log((new_prob / (1 - new_prob)))
+
+ return logits
+
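+# Example (sketch): for per-object soft masks `prob` of shape (num_objects, H, W) and
+# dim=0, `aggregate` prepends a background channel computed as prod(1 - prob) over the
+# objects and returns per-channel logits log(p / (1 - p)); taking the softmax over dim=0
+# of the result yields a normalized distribution over background + objects.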
+
+# @torch.jit.script
+def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
+ # cls_gt: B*1*H*W
+ B, _, H, W = cls_gt.shape
+ one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
+ return one_hot
\ No newline at end of file
diff --git a/tracker/utils/time_estimator.py b/tracker/utils/time_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d429b0404b641472ff84701305c570cc26280b7
--- /dev/null
+++ b/tracker/utils/time_estimator.py
@@ -0,0 +1,43 @@
+import time
+
+
+class TimeEstimator:
+ def __init__(self, total_iter, step_size):
+ self.avg_time_window = [] # window-based average
+ self.exp_avg_time = None # exponential moving average
+ self.alpha = 0.7 # for exponential moving average
+
+        self.last_time = time.time()  # not accurate for the first interval, which is acceptable
+ self.total_iter = total_iter
+ self.step_size = step_size
+
+ self.buffering_exp = True
+
+ # call this at a fixed interval
+ # does not have to be every step
+ def update(self):
+ curr_time = time.time()
+ time_per_iter = curr_time - self.last_time
+ self.last_time = curr_time
+
+ self.avg_time_window.append(time_per_iter)
+
+ if self.buffering_exp:
+ if self.exp_avg_time is not None:
+ # discard the first iteration call to not pollute the ema
+ self.buffering_exp = False
+ self.exp_avg_time = time_per_iter
+ else:
+ self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
+
+ def get_est_remaining(self, it):
+ if self.exp_avg_time is None:
+ return 0
+
+ remaining_iter = self.total_iter - it
+ return remaining_iter * self.exp_avg_time / self.step_size
+
+ def get_and_reset_avg_time(self):
+ avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size
+ self.avg_time_window = []
+ return avg
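+
+
+# Minimal usage sketch (values are placeholders):
+#   estimator = TimeEstimator(total_iter=100_000, step_size=100)
+#   # every `step_size` iterations:
+#   estimator.update()
+#   avg_sec_per_iter = estimator.get_and_reset_avg_time()
+#   eta_seconds = estimator.get_est_remaining(it)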
diff --git a/utils/download_util.py b/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8fb1b00522309d0c0931f5396355011fb200e7
--- /dev/null
+++ b/utils/download_util.py
@@ -0,0 +1,109 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+        str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ print(response_file_size)
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
diff --git a/utils/evaluate_video.py b/utils/evaluate_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..47382fa43863ec0b4ed8772b3c7d4b9950b39a65
--- /dev/null
+++ b/utils/evaluate_video.py
@@ -0,0 +1,133 @@
+import torch
+import numpy as np
+from skimage.metrics import structural_similarity as ssim
+from skimage.metrics import peak_signal_noise_ratio as psnr
+import lpips
+from tqdm import tqdm
+import cv2
+
+class VideoQualityEvaluator:
+ def __init__(self, device='cuda'):
+ """Initialize video quality evaluator with specified computation device
+
+ Args:
+ device (str): Computation device ('cuda' or 'cpu')
+ """
+ self.device = device
+ # Initialize LPIPS model (perceptual similarity metric)
+ self.lpips_model = lpips.LPIPS(net='alex').to(device)
+
+ def _preprocess_frame(self, frame):
+ """Convert frame to standardized format for evaluation
+
+ Args:
+ frame: Input frame (numpy array or torch tensor)
+
+ Returns:
+ Processed frame in HWC format with values in [0,1]
+ """
+ if isinstance(frame, torch.Tensor):
+ frame = frame.detach().cpu().numpy()
+
+ # Normalize to [0,1] if needed
+ if frame.max() > 1:
+ frame = frame / 255.0
+ # Convert CHW to HWC if needed
+ if len(frame.shape) == 3 and frame.shape[0] == 3:
+ frame = frame.transpose(1, 2, 0)
+ return frame
+
+ def calculate_psnr(self, vid1, vid2):
+ """Calculate average PSNR between two videos
+
+ Args:
+ vid1: First video (list/array of frames)
+ vid2: Second video (list/array of frames)
+
+ Returns:
+ Mean PSNR value across all frames
+ """
+ psnrs = []
+ for f1, f2 in zip(vid1, vid2):
+ f1 = self._preprocess_frame(f1)
+ f2 = self._preprocess_frame(f2)
+ # Calculate PSNR for this frame pair
+ psnrs.append(psnr(f1, f2, data_range=1.0))
+ return np.mean(psnrs)
+
+ def calculate_ssim(self, vid1, vid2):
+ """Calculate average SSIM between two videos
+
+ Args:
+ vid1: First video (list/array of frames)
+ vid2: Second video (list/array of frames)
+
+ Returns:
+ Mean SSIM value across all frames
+ """
+ ssims = []
+ for f1, f2 in zip(vid1, vid2):
+ f1 = self._preprocess_frame(f1)
+ f2 = self._preprocess_frame(f2)
+ # Calculate SSIM for this frame pair (multichannel for color images)
+ ssims.append(ssim(f1, f2, channel_axis=2, data_range=1.0))
+ return np.mean(ssims)
+
+ def calculate_lpips(self, vid1, vid2):
+ """Calculate average LPIPS (perceptual similarity) between two videos
+
+ Args:
+ vid1: First video (list/array of frames)
+ vid2: Second video (list/array of frames)
+
+ Returns:
+ Mean LPIPS value across all frames (lower is better)
+ """
+ lpips_values = []
+ for f1, f2 in zip(vid1, vid2):
+ # Convert to torch tensor if needed
+ if not isinstance(f1, torch.Tensor):
+ f1 = torch.from_numpy(f1).permute(2, 0, 1).unsqueeze(0).float() # HWC -> 1CHW
+ f2 = torch.from_numpy(f2).permute(2, 0, 1).unsqueeze(0).float()
+
+            # Map [0, 255] inputs to [-1, 1] (the range LPIPS expects);
+            # inputs that are already normalized are passed through unchanged
+ if f1.max() > 1:
+ f1 = f1 / 127.5 - 1.0
+ f2 = f2 / 127.5 - 1.0
+
+ f1 = f1.to(self.device)
+ f2 = f2.to(self.device)
+
+ # Calculate LPIPS with no gradients
+ with torch.no_grad():
+ lpips_values.append(self.lpips_model(f1, f2).item())
+ return np.mean(lpips_values)
+
+ def evaluate_videos(self, generated_video, reference_video, metrics=['psnr','lpips','ssim']):
+ """Comprehensive video quality evaluation between generated and reference videos
+
+ Args:
+ generated_video: Model-generated video [T,H,W,C] or [T,C,H,W]
+ reference_video: Ground truth reference video [T,H,W,C] or [T,C,H,W]
+ metrics: List of metrics to compute ('psnr', 'ssim', 'lpips')
+
+ Returns:
+ Dictionary containing computed metric values
+ """
+ results = {}
+
+ # Verify video lengths match
+ assert len(generated_video) == len(reference_video), "Videos must have same number of frames"
+
+ # Calculate requested metrics
+ if 'psnr' in metrics:
+ results['psnr'] = self.calculate_psnr(generated_video, reference_video)
+
+ if 'ssim' in metrics:
+ results['ssim'] = self.calculate_ssim(generated_video, reference_video)
+
+ if 'lpips' in metrics:
+ results['lpips'] = self.calculate_lpips(generated_video, reference_video)
+
+ return results
+
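+
+# Minimal usage sketch (frame arrays are placeholders; frames may be numpy arrays or
+# torch tensors, see _preprocess_frame):
+#   evaluator = VideoQualityEvaluator(device='cuda')
+#   scores = evaluator.evaluate_videos(generated_frames, reference_frames,
+#                                      metrics=['psnr', 'ssim', 'lpips'])
+#   print(scores)  # {'psnr': ..., 'ssim': ..., 'lpips': ...}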
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..43b849902245dd338a36f4f4ff09e33425365af6
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,131 @@
+import os
+import re
+import random
+import time
+import torch
+import torch.nn as nn
+import logging
+import numpy as np
+from os import path as osp
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+initialized_logger = {}
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+
+ if log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
+ torch.__version__)[0][:3])] >= [1, 12, 0]
+
+def gpu_is_available():
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return True
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+def get_device(gpu_id=None):
+ if gpu_id is None:
+ gpu_str = ''
+ elif isinstance(gpu_id, int):
+ gpu_str = f':{gpu_id}'
+ else:
+ raise TypeError('Input should be int value.')
+
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return torch.device('mps'+gpu_str)
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
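+# Example (sketch): get_device() resolves to 'mps' on recent PyTorch with Apple silicon,
+# otherwise 'cuda' when CUDA + cuDNN are available, else 'cpu'; get_device(0) -> 'cuda:0'.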
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+        A generator for all the files of interest, with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
\ No newline at end of file