sayakpaul HF Staff commited on
Commit
e9b4bd3
·
verified ·
1 Parent(s): 1b761f5

Upload denoise.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. denoise.py +217 -0
denoise.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ TODO: need to implement temporal reasoning:
18
+ https://huggingface.co/spaces/nvidia/ChronoEdit/blob/main/chronoedit_diffusers/pipeline_chronoedit.py
19
+ """
20
+
21
+ from diffusers.modular_pipelines import (
22
+ ModularPipelineBlocks,
23
+ ComponentSpec,
24
+ BlockState,
25
+ PipelineState,
26
+ ModularPipeline,
27
+ InputParam,
28
+ LoopSequentialPipelineBlocks,
29
+ )
30
+ from diffusers.configuration_utils import FrozenDict
31
+ from diffusers.guiders import ClassifierFreeGuidance
32
+ from typing import List
33
+ from diffusers import AutoModel, UniPCMultistepScheduler
34
+ import torch
35
+ from diffusers.modular_pipelines.wan.denoise import WanLoopAfterDenoiser, WanDenoiseLoopWrapper
36
+
37
+
38
+ class ChronoEditLoopBeforeDenoiser(ModularPipelineBlocks):
39
+ model_name = "chronoedit"
40
+
41
+ @property
42
+ def inputs(self) -> List[InputParam]:
43
+ return [
44
+ InputParam(
45
+ "latents",
46
+ required=True,
47
+ type_hint=torch.Tensor,
48
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
49
+ ),
50
+ InputParam(
51
+ "condition",
52
+ required=True,
53
+ type_hint=torch.Tensor,
54
+ description="The conditioning latents to use for the denoising process. Can be generated in prepare_latent step.",
55
+ ),
56
+ ]
57
+
58
+ @torch.no_grad()
59
+ def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
60
+ latent_model_input = torch.cat([block_state.latents, block_state.condition], dim=1)
61
+ block_state.latent_model_input = latent_model_input.to(block_state.latents.dtype)
62
+ block_state.timestep = t.expand(block_state.latents.shape[0])
63
+ return components, block_state
64
+
65
+
66
+ class ChronoEditLoopDenoiser(ModularPipelineBlocks):
67
+ model_name = "chronoedit"
68
+
69
+ @property
70
+ def expected_components(self) -> List[ComponentSpec]:
71
+ return [
72
+ ComponentSpec(
73
+ "guider",
74
+ ClassifierFreeGuidance,
75
+ config=FrozenDict({"guidance_scale": 1.0}),
76
+ default_creation_method="from_config",
77
+ ),
78
+ ComponentSpec("transformer", AutoModel),
79
+ ]
80
+
81
+ @property
82
+ def inputs(self) -> List[InputParam]:
83
+ return [
84
+ InputParam("attention_kwargs"),
85
+ InputParam(
86
+ "latents",
87
+ required=True,
88
+ type_hint=torch.Tensor,
89
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
90
+ ),
91
+ InputParam(
92
+ "condition",
93
+ required=True,
94
+ type_hint=torch.Tensor,
95
+ description="The conditioning latents to use for the denoising process. Can be generated in prepare_latent step.",
96
+ ),
97
+ InputParam(
98
+ "image_embeds",
99
+ required=True,
100
+ type_hint=torch.Tensor,
101
+ description="The conditioning image embeddings to use for the denoising process. Can be generated in prepare_latent step.",
102
+ ),
103
+ InputParam(
104
+ "num_inference_steps",
105
+ required=True,
106
+ type_hint=int,
107
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
108
+ ),
109
+ InputParam(
110
+ kwargs_type="denoiser_input_fields",
111
+ description=(
112
+ "All conditional model inputs that need to be prepared with guider. "
113
+ "It should contain prompt_embeds/negative_prompt_embeds. "
114
+ "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
115
+ ),
116
+ ),
117
+ ]
118
+
119
+ @torch.no_grad()
120
+ def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor) -> PipelineState:
121
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
122
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
123
+ guider_inputs = {
124
+ "prompt_embeds": (
125
+ getattr(block_state, "prompt_embeds", None),
126
+ getattr(block_state, "negative_prompt_embeds", None),
127
+ ),
128
+ }
129
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
130
+
131
+ guider_state = components.guider.prepare_inputs(guider_inputs)
132
+
133
+ # run the denoiser for each guidance batch
134
+ for guider_state_batch in guider_state:
135
+ components.guider.prepare_models(components.transformer)
136
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
137
+ prompt_embeds = cond_kwargs.pop("prompt_embeds")
138
+
139
+ # Predict the noise residual
140
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
141
+ guider_state_batch.noise_pred = components.transformer(
142
+ hidden_states=block_state.latent_model_input,
143
+ timestep=block_state.timestep,
144
+ encoder_hidden_states=prompt_embeds,
145
+ encoder_hidden_states_image=block_state.image_embeds,
146
+ attention_kwargs=block_state.attention_kwargs,
147
+ return_dict=False,
148
+ )[0]
149
+ components.guider.cleanup_models(components.transformer)
150
+
151
+ # Perform guidance
152
+ block_state.noise_pred = components.guider(guider_state)[0]
153
+
154
+ return components, block_state
155
+
156
+
157
+ class ChronoEditDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
158
+ model_name = "chronoedit"
159
+
160
+ @property
161
+ def loop_expected_components(self) -> List[ComponentSpec]:
162
+ return [
163
+ ComponentSpec(
164
+ "guider",
165
+ ClassifierFreeGuidance,
166
+ config=FrozenDict({"guidance_scale": 1.0}),
167
+ default_creation_method="from_config",
168
+ ),
169
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
170
+ ComponentSpec("transformer", AutoModel),
171
+ ]
172
+
173
+ @property
174
+ def loop_inputs(self) -> List[InputParam]:
175
+ return [
176
+ InputParam(
177
+ "timesteps",
178
+ required=True,
179
+ type_hint=torch.Tensor,
180
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
181
+ ),
182
+ InputParam(
183
+ "num_inference_steps",
184
+ required=True,
185
+ type_hint=int,
186
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
187
+ ),
188
+ ]
189
+
190
+ @torch.no_grad()
191
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
192
+ block_state = self.get_block_state(state)
193
+
194
+ block_state.num_warmup_steps = max(
195
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
196
+ )
197
+
198
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
199
+ for i, t in enumerate(block_state.timesteps):
200
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
201
+ if i == len(block_state.timesteps) - 1 or (
202
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
203
+ ):
204
+ progress_bar.update()
205
+
206
+ self.set_block_state(state, block_state)
207
+
208
+ return components, state
209
+
210
+
211
+ class ChronoEditLoopAfterDenoiser(WanLoopAfterDenoiser):
212
+ model_name = "chronoedit"
213
+
214
+
215
+ class ChronoEditDenoiseStep(ChronoEditDenoiseLoopWrapper):
216
+ block_classes = [ChronoEditLoopBeforeDenoiser, ChronoEditLoopDenoiser, ChronoEditLoopAfterDenoiser]
217
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]