from src.utils import *
from src.flow_utils import warp_tensor
import torch
import torchvision
import gc

"""
==========================================================================
* step(): one DDPM step with background smoothing
* inference(): translate one batch with FRESCO and background smoothing
==========================================================================
"""

def step(pipe, model_output, timestep, sample, generator, repeat_noise=False,
         visualize_pipeline=False, flows=None, occs=None, saliency=None):
    """
    DDPM step with background smoothing
    * background smoothing: warp the background region of the previous frame to the current frame
    """
    scheduler = pipe.scheduler

    # 1. get previous step value (=t-1)
    prev_timestep = scheduler.previous_timestep(timestep)
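    # (for diffusers' DDPMScheduler this is t - num_train_timesteps // num_inference_steps)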
    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 3. compute predicted original sample from predicted noise, also called
    # "predicted x_0" in formula (12) of https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
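    # i.e., x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), with eps = model_output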
| """ | |
| [HACK] add background smoothing | |
| decode the feature | |
| warp the feature of f_{i-1} | |
| fuse the warped f_{i-1} with f_{i} in the non-salient region (i.e., background) | |
| encode the fused feature | |
| """ | |
| if saliency is not None and flows is not None and occs is not None: | |
| image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample | |
| image = warp_tensor(image, flows, occs, saliency, unet_chunk_size=1) | |
| pred_original_sample = pipe.vae.config.scaling_factor * pipe.vae.encode(image).latent_dist.sample() | |
| # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t | |
| # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf | |
| pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t | |
| current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t | |
| # 5. Compute predicted previous sample µ_t | |
| # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf | |
| pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample | |
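    # 6. add noise scaled by the posterior variance of formula (7):
    # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t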
    variance = beta_prod_t_prev / beta_prod_t * current_beta_t
    variance = torch.clamp(variance, min=1e-20)
    variance = (variance ** 0.5) * torch.randn(model_output.shape, generator=generator,
                                               device=model_output.device, dtype=model_output.dtype)

    """
    [HACK] background smoothing
    applying the same noise to every frame could be good for a static background
    """
    if repeat_noise:
        variance = variance[0:1].repeat(model_output.shape[0], 1, 1, 1)
    if visualize_pipeline:  # for debug
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        viz = torchvision.utils.make_grid(torch.clamp(image, -1, 1), image.shape[0], 1)
        visualize(viz.cpu(), 90)

    pred_prev_sample = pred_prev_sample + variance
    return (pred_prev_sample, pred_original_sample)
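

# A minimal usage sketch (illustrative only): step() is meant as a drop-in
# replacement for pipe.scheduler.step() inside a DDPM denoising loop. Names
# such as `pipe`, `latents`, `prompt_embeds`, `flows`, `occs` and `saliency`
# are assumed to be prepared beforehand, e.g.:
#
#   generator = torch.Generator(device='cuda').manual_seed(0)
#   for t in pipe.scheduler.timesteps:
#       noise_pred = pipe.unet(latents, t, encoder_hidden_states=prompt_embeds,
#                              return_dict=False)[0]
#       latents, _ = step(pipe, noise_pred, t, latents, generator,
#                         flows=flows, occs=occs, saliency=saliency)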
def inference(pipe, controlnet, frescoProc,
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6,
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
              record_latents=[], propagation_mode=False, visualize_pipeline=False,
              flows=None, occs=None, saliency=None, repeat_noise=False,
              num_intraattn_steps=1, step_interattn_end=350, bg_smoothing_steps=[16, 17]):
    """
    video-to-video translation inference pipeline with FRESCO
    * add ControlNet and SDEdit
    * add FRESCO-guided attention
    * add FRESCO-guided optimization
    * add background smoothing
    * add support for inter-batch long video translation

    [input of the original pipe]
        pipe: base diffusion model
        imgs: a batch of input frames
        prompt_embeds: prompt embeddings
        num_inference_steps: number of DDPM steps
        timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
        do_classifier_free_guidance: cfg, should always be True
        guidance_scale: cfg scale
        seed: random seed
    [input of SDEdit]
        num_warmup_steps: skip the first num_warmup_steps DDPM steps
    [input of ControlNet]
        use_controlnet: bool, whether to use ControlNet
        controlnet: ControlNet model
        edges: input for ControlNet (edge/stroke/depth, etc.)
        cond_scale: ControlNet conditioning scale per step
    [input of FRESCO]
        frescoProc: FRESCO attention controller
        flows: optical flows
        occs: occlusion masks
        num_intraattn_steps: apply num_intraattn_steps steps of spatial-guided attention
        step_interattn_end: apply temporal-guided attention in [step_interattn_end, 1000] steps
    [input for background smoothing]
        saliency: saliency mask
        repeat_noise: bool, use the same noise for all frames
        bg_smoothing_steps: steps at which to apply background smoothing
    [input for long video translation]
        record_latents: latents recorded in the previous batch
        propagation_mode: bool, whether this is not the first batch
    [output]
        latents: a batch of latents of the translated frames
    """
    gc.collect()
    torch.cuda.empty_cache()

    device = pipe._execution_device
    noise_scheduler = pipe.scheduler
    generator = torch.Generator(device=device).manual_seed(seed)

    B, C, H, W = imgs.shape
    latents = pipe.prepare_latents(
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents=None,
    )
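
    # [HACK] background smoothing: share one initial noise map across all frames to favor a static background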
    if repeat_noise:
        latents = latents[0:1].repeat(B, 1, 1, 1).detach()

    if num_warmup_steps < 0:
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latents of the input images rather than pure Gaussian noise
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()
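        # (add_noise computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise)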
    # SDEdit: run num_inference_steps - num_warmup_steps steps
    with pipe.progress_bar(total=num_inference_steps - num_warmup_steps) as progress_bar:
        latents = latents_init
        for i, t in enumerate(timesteps[num_warmup_steps:]):
            """
            [HACK] control the steps at which to apply spatial/temporal-guided attention
            [HACK] record and restore latents from the previous batch
            """
            if i >= num_intraattn_steps:
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:
                frescoProc.controller.disable_interattn()

            if propagation_mode:  # restore the latents recorded in the previous batch and record those of the current batch
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0, len(latents) - 1]].detach().clone()
            else:  # first batch: record the latents of the first and last frames, i.e., record_latents[i] = [x_{1,t}, x_{N,t}]
                record_latents += [latents[[0, len(latents) - 1]].detach().clone()]

            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
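            # note: this assumes prompt_embeds stacks [uncond, text] along the batch dim,
            # matching the chunk(2) order in the guidance step below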
            if use_controlnet:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                down_block_res_samples, mid_block_res_sample = controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[i + num_warmup_steps],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None

            # predict the noise residual
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_{t-1}
            """
            [HACK] background smoothing
            note: bg_smoothing_steps should be rescaled based on num_inference_steps
            the current [16, 17] is based on num_inference_steps=20
            """
            if i + num_warmup_steps in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline,
                               flows=flows, occs=occs, saliency=saliency)[0]
            else:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline)[0]
            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    return latents
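

# A minimal usage sketch (illustrative only; `pipe`, `controlnet`, `frescoProc`,
# `imgs`, `prompt_embeds`, `edges`, `flows`, `occs` and `saliency` are assumed
# to be prepared as elsewhere in the FRESCO pipeline):
#
#   num_inference_steps = 20
#   pipe.scheduler.set_timesteps(num_inference_steps, device=pipe._execution_device)
#   latents = inference(pipe, controlnet, frescoProc,
#                       imgs, prompt_embeds, edges, pipe.scheduler.timesteps,
#                       num_inference_steps=num_inference_steps, num_warmup_steps=6,
#                       flows=flows, occs=occs, saliency=saliency)
#   frames = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample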