from functools import partial
from typing import Optional

import torch
from diffusers import UniPCMultistepScheduler
from PIL import Image
from torchvision.transforms import functional as TF

from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.wan21.wan21 import Wan21, AggressiveWanUnloadPipeline
from toolkit.models.wan21.wan_utils import add_first_frame_conditioning_v22
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)

from .wan22_pipeline import Wan22Pipeline

# scheduler config used only for generation (see get_generation_pipeline)
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 5.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}

# scheduler config used for training (see get_train_scheduler)
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 5.0,
    "use_dynamic_shifting": False,
}


# TODO: temporary monkeypatch of the time/text embedding so that per-token
# timesteps work with batch sizes greater than 1. Remove once the diffusers
# library is fixed.
def time_text_monkeypatch(
    self,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: Optional[torch.Tensor] = None,
    timestep_seq_len: Optional[int] = None,
):
    timestep = self.timesteps_proj(timestep)
    if timestep_seq_len is not None:
        # restore the (batch, seq_len) layout that was flattened for projection
        timestep = timestep.unflatten(
            0, (encoder_hidden_states.shape[0], timestep_seq_len)
        )

    time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
    if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
        timestep = timestep.to(time_embedder_dtype)
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
    timestep_proj = self.time_proj(self.act_fn(temb))

    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
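
# A minimal sketch of the fix above (illustration only; this helper is an
# editorial addition and is not called anywhere in the toolkit). When
# per-token timesteps are used, timesteps_proj receives a flattened
# (batch * seq_len,) tensor, so its output must be unflattened back to
# (batch, seq_len, dim) before the time embedder runs; otherwise tokens from
# different samples in the batch get mixed together.
def _demo_timestep_unflatten():
    batch, seq_len, dim = 2, 6, 8
    projected = torch.randn(batch * seq_len, dim)  # flattened projection output
    per_token = projected.unflatten(0, (batch, seq_len))  # (batch, seq_len, dim)
    assert per_token.shape == (batch, seq_len, dim)
    return per_token
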
class Wan225bModel(Wan21):
    arch = "wan22_5b"
    _wan_generation_scheduler_config = scheduler_configUniPC
    _wan_expand_timesteps = True

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device=device,
            model_config=model_config,
            dtype=dtype,
            custom_pipeline=custom_pipeline,
            noise_scheduler=noise_scheduler,
            **kwargs,
        )
        self._wan_cache = None

    def load_model(self):
        super().load_model()
        # patch the condition embedder so batched per-token timesteps work
        self.model.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.condition_embedder
        )

    def get_bucket_divisibility(self):
        # 16x spatial VAE compression times the 2x2 patch size
        return 32

    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        pipeline = Wan22Pipeline(
            vae=self.vae,
            # the 5B model has a single transformer, so it fills both slots
            transformer=self.model,
            transformer_2=self.model,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
            expand_timesteps=self._wan_expand_timesteps,
            device=self.device_torch,
            aggressive_offload=self.model_config.low_vram,
        )
        pipeline = pipeline.to(self.device_torch)
        return pipeline

    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def get_base_model_version(self):
        return "wan_2.2_5b"

    def generate_single_image(
        self,
        pipeline: AggressiveWanUnloadPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # reactivate the progress bar since this is slow
        pipeline.set_progress_bar_config(disable=False)

        # num_frames must be one more than a multiple of 4; round down to the
        # nearest valid value
        num_frames = ((gen_config.num_frames - 1) // 4) * 4 + 1
        gen_config.num_frames = num_frames

        height = gen_config.height
        width = gen_config.width

        noise_mask = None
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
            d = self.get_bucket_divisibility()
            # make sure the dimensions are divisible by d
            height = height // d * d
            width = width // d * d
            # resize the control image to match
            control_img = control_img.resize((width, height), Image.LANCZOS)

            # prepare latent variables
            num_channels_latents = self.transformer.config.in_channels
            latents = pipeline.prepare_latents(
                1,
                num_channels_latents,
                height,
                width,
                gen_config.num_frames,
                torch.float32,
                self.device_torch,
                generator,
                None,
            ).to(self.torch_dtype)

            # normalize the first frame to [-1, 1]
            first_frame_n1p1 = (
                TF.to_tensor(control_img)
                .unsqueeze(0)
                .to(self.device_torch, dtype=self.torch_dtype)
                * 2.0
                - 1.0
            )
            gen_config.latents, noise_mask = add_first_frame_conditioning_v22(
                latent_model_input=latents,
                first_frame=first_frame_n1p1,
                vae=self.vae,
            )

        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype
            ),
            height=height,
            width=width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            noise_mask=noise_mask,
            **extra,
        )[0]  # shape = [1, frames, channels, height, width]

        batch_item = output[0]  # list of PIL images
        if gen_config.num_frames > 1:
            # return all frames
            return batch_item
        else:
            # get just the first image
            img = batch_item[0]
            return img
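
    # Training builds a per-token timestep tensor: latent tokens that come from
    # the clean conditioning frame keep timestep 0 through the noise mask, while
    # every other token carries the sampled timestep. A shape sketch, assuming a
    # (bs, c, f, h, w) latent and the transformer's 2x2 spatial patches:
    #
    #   noise_mask[0][0]              -> (f, h, w)      one channel, one sample
    #   noise_mask[0][0][:, ::2, ::2] -> (f, h/2, w/2)  one value per patch token
    #   (mask * t).flatten()          -> (seq_len,)     per-token timesteps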
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs,
    ):
        # videos come in (bs, num_frames, channels, height, width)
        # images come in (bs, channels, height, width)
        # for wan, only do i2v for video for now; images do normal t2i
        conditioned_latent = latent_model_input
        noise_mask = None
        if batch.dataset_config.do_i2v:
            with torch.no_grad():
                frames = batch.tensor
                if len(frames.shape) == 4:
                    # single images fall through to plain t2i; no conditioning
                    first_frames = frames
                elif len(frames.shape) == 5:
                    first_frames = frames[:, 0]

                    # add conditioning using the standalone function
                    conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
                        latent_model_input=latent_model_input.to(
                            self.device_torch, self.torch_dtype
                        ),
                        first_frame=first_frames.to(
                            self.device_torch, self.torch_dtype
                        ),
                        vae=self.vae,
                    )
                else:
                    raise ValueError(f"Unknown frame shape {frames.shape}")

        # make the noise mask if one was not created above
        if noise_mask is None:
            noise_mask = torch.ones(
                conditioned_latent.shape,
                dtype=conditioned_latent.dtype,
                device=conditioned_latent.device,
            )

        # TODO: write this better
        # expand each sample's scalar timestep into per-token timesteps,
        # zeroing out tokens the noise mask marks as clean conditioning
        t_chunks = torch.chunk(timestep, timestep.shape[0])
        out_t_chunks = []
        for t in t_chunks:
            # seq_len = num_latent_frames * (latent_height // 2) * (latent_width // 2)
            temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
            # (1, seq_len); concatenated below into (batch_size, seq_len)
            temp_ts = temp_ts.unsqueeze(0)
            out_t_chunks.append(temp_ts)
        timestep = torch.cat(out_t_chunks, dim=0)

        noise_pred = self.model(
            hidden_states=conditioned_latent,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs,
        )[0]
        return noise_pred
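
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the toolkit's entry points. Assumes a
# ModelConfig pointing at Wan 2.2 5B weights (the repo id below is a
# placeholder) and a CUDA device with enough memory; normally this class is
# constructed by the toolkit's training/generation flow instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_config = ModelConfig(
        name_or_path="Wan-AI/Wan2.2-TI2V-5B-Diffusers",  # placeholder weights path
    )
    model = Wan225bModel(device="cuda:0", model_config=model_config, dtype="bf16")
    model.load_model()
    pipeline = model.get_generation_pipeline()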