sayakpaul HF Staff commited on
Commit
fc56851
·
verified ·
1 Parent(s): 2e63ca7

Upload inputs.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inputs.py +96 -0
inputs.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam, ModularPipeline, PipelineState
17
+ import numpy as np
18
+ import torch
19
+ import PIL
20
+ from typing import List
21
+ from diffusers.modular_pipelines.wan.before_denoise import WanInputStep
22
+
23
+
24
def calculate_dimensions(image, mod_value, max_area=720 * 1280):
    """Compute output (width, height) for *image*, preserving its aspect ratio.

    The returned dimensions cover approximately ``max_area`` pixels and are
    each floor-aligned to a multiple of ``mod_value``.

    Args:
        image: PIL Image (anything exposing ``height`` and ``width``).
        mod_value: Modulo value for dimension alignment (e.g. the VAE spatial
            scale factor times the transformer's spatial patch size).
        max_area: Target pixel area of the output. Defaults to ``720 * 1280``
            (720p), matching the previously hard-coded value.

    Returns:
        Tuple of (width, height), each a multiple of ``mod_value``.
    """
    # Solve w * h ≈ max_area with h / w equal to the image's aspect ratio,
    # then floor-align each dimension to mod_value.
    aspect_ratio = image.height / image.width
    calculated_height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    calculated_width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value

    return calculated_width, calculated_height
45
+
46
+
47
+ # Make the input step aware of `negative_prompt_embeds`.
48
+ # ChronoEdit uses a `guidance_scale` of 1.
49
class ChronoEditInputStep(WanInputStep):
    """Input step for ChronoEdit.

    Extends the Wan input step so it also accepts `negative_prompt_embeds`
    (ChronoEdit uses a `guidance_scale` of 1).
    """

    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        # Required positive embeddings plus optional negative embeddings.
        prompt_param = InputParam(
            "prompt_embeds",
            required=True,
            type_hint=torch.Tensor,
            description="Pre-generated text embeddings. Can be generated from text_encoder step.",
        )
        negative_param = InputParam(
            "negative_prompt_embeds",
            type_hint=torch.Tensor,
            description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
        )
        return [
            InputParam("num_videos_per_prompt", default=1),
            prompt_param,
            negative_param,
        ]
68
+
69
+
70
class ChronoEditImageInputStep(ModularPipelineBlocks):
    """Resize the input image to model-compatible dimensions.

    Picks a (width, height) near 720p that preserves the image's aspect ratio
    and is aligned to the VAE/transformer patch grid, then stores the resized
    image and the chosen dimensions back into the pipeline state.
    """

    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam(name="image")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="image", type_hint=PIL.Image.Image),
            OutputParam(name="height", type_hint=int, description="The height set w.r.t input image and specs"),
            OutputParam(name="width", type_hint=int, description="The width set w.r.t input image and specs"),
        ]

    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # Output dimensions must be divisible by the VAE spatial scale factor
        # times the transformer's spatial patch size.
        alignment = components.vae_scale_factor_spatial * components.transformer.config.patch_size[1]
        new_width, new_height = calculate_dimensions(block_state.image, alignment)

        block_state.image = block_state.image.resize((new_width, new_height))
        block_state.width = new_width
        block_state.height = new_height

        self.set_block_state(state, block_state)
        return components, state