Spaces: Running on Zero
Elea Zhong committed
Commit · 92d8df6
1 Parent(s): 1b9d6c7
add 2step pipe and app
app.py
CHANGED
@@ -12,6 +12,9 @@ import gradio as gr
 import spaces

 import subprocess
+
+from qwenimage.models.attention_processors import QwenDoubleStreamAttnProcessorFA3
+from qwenimage.optimization import optimize_pipeline_
 GIT_TOKEN = os.environ.get("GIT_TOKEN")
 import subprocess

@@ -34,6 +37,7 @@ import subprocess
 from qwenimage.debug import ctimed
 from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
 from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
+from qwenimage.experiments.quantize_experiments import conf_fp8darow_nolast, quantize_transformer_fp8darow_nolast

 # --- Model Loading ---

@@ -64,7 +68,26 @@ pipe.load_lora_weights(
     "checkpoints/distill_5k_lora.safetensors",
     adapter_name="fast_5k",
 )
-
+pipe.set_adapters(["fast_5k"], adapter_weights=[1.0])
+pipe.fuse_lora(adapter_names=["fast_5k"], lora_scale=1.0)
+pipe.unload_lora_weights()
+
+pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+pipe.transformer.fuse_qkv_projections()
+pipe.transformer.check_fused_qkv()
+
+optimize_pipeline_(
+    pipe,
+    cache_compiled=True,
+    quantize=True,
+    suffix="_fp8darow_nolast_fa3_fast5k",
+    quantize_config=conf_fp8darow_nolast(),
+    pipe_kwargs={
+        "image": [Image.new("RGB", (1024, 1024))],
+        "prompt":"prompt",
+        "num_inference_steps":2,
+    }
+)

 MAX_SEED = np.iinfo(np.int32).max

@@ -73,12 +96,12 @@ MAX_SEED = np.iinfo(np.int32).max
 def run_pipe(
     image,
     prompt,
+    num_runs,
     seed,
     randomize_seed,
     num_inference_steps,
     shift,
-
-    progress=gr.Progress(track_tqdm=True)
+    prompt_cached,
 ):
     with ctimed("pre pipe"):

@@ -90,35 +113,40 @@ def run_pipe(

         # Choose input image (prefer uploaded, else last output)
         pil_images = []
-        if image is
-            if isinstance(image, Image.Image):
-                pil_images.append(image.convert("RGB"))
-            elif hasattr(image, "name"):
-                pil_images.append(Image.open(image.name).convert("RGB"))
-        elif prev_output:
-            pil_images.append(prev_output.convert("RGB"))
-
-        if len(pil_images) == 0:
+        if image is None:
             raise gr.Error("Please upload an image first.")
-
-
+        if isinstance(image, Image.Image):
+            pil_images.append(image.convert("RGB"))
+        elif hasattr(image, "name"):
+            pil_images.append(Image.open(image.name).convert("RGB"))

         # finetuner.enable()
         pipe.scheduler.config["base_shift"] = shift
         pipe.scheduler.config["max_shift"] = shift

-
-
-
-
-
-
-
-
+    gallery_images = []
+
+    for i in range(num_runs):
+        result = pipe(
+            image=pil_images,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            vae_image_override=1024 * 1024, #512 * 512,
+            latent_size_override=1024 * 1024,
+            prompt_cached=prompt_cached,
+            return_dict=True,
+        ).images[0]
+        prompt_cached = True
+        gallery_images.append(result)
+
+        yield gallery_images, seed, prompt_cached


 # --- UI ---

+def reset_prompt_cache():
+    return False

 with gr.Blocks(theme=gr.themes.Citrus()) as demo:

@@ -127,32 +155,40 @@ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
     with gr.Row():
         with gr.Column():
             image = gr.Image(label="Input Image", type="pil")
-            prev_output = gr.Image(value=None, visible=False)
-            is_reset = gr.Checkbox(value=False, visible=False)
             prompt = gr.Textbox(label="Prompt", placeholder="Prompt", lines=2)

+            num_runs = gr.Slider(label="Run Consecutively", minimum=0, maximum=100, step=1, value=16)

             run_btn = gr.Button("Generate", variant="primary")

             with gr.Accordion("Advanced Settings", open=False):
+                prompt_cached = gr.Checkbox(label="Auto-Cached embeds", value=False)
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=2)
                 shift = gr.Slider(label="Timestep Shift", minimum=0.0, maximum=4.0, step=0.1, value=2.0)

         with gr.Column():
-            result = gr.
+            result = gr.Gallery(
+                label="Output Image",
+                interactive=False,
+                # type="filepath",
+                columns=4,
+                height=800,
+                object_fit="scale-down",
+            )

     inputs = [
         image,
         prompt,
+        num_runs,
         seed,
         randomize_seed,
         num_inference_steps,
         shift,
-
+        prompt_cached,
     ]
-    outputs = [result, seed]
+    outputs = [result, seed, prompt_cached]


     run_event = run_btn.click(
@@ -161,6 +197,17 @@ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
         outputs=outputs
     )

-
+
+    image.upload(
+        fn=reset_prompt_cache,
+        inputs=[],
+        outputs=[prompt_cached],
+    )
+
+    prompt.input(
+        fn=reset_prompt_cache,
+        inputs=[],
+        outputs=[prompt_cached],
+    )

 demo.launch()
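Note on the app.py changes above: run_pipe is now a generator, so each loop iteration streams the growing gallery_images list into the gr.Gallery, and the prompt_cached checkbox is wired as both an input and an output so the flag flips to True after the first call and is reset whenever the image or prompt changes. A minimal, self-contained sketch of that wiring, assuming only gradio and Pillow; fake_pipe is a hypothetical stand-in for the real QwenImageEditPlusPipeline and is not part of this commit:

# Sketch of the cached-embedding / streaming-gallery wiring, under stated assumptions.
import gradio as gr
from PIL import ImageOps

def fake_pipe(image, prompt, prompt_cached):
    # Stub: the real pipe would skip prompt/image encoding when prompt_cached=True.
    return ImageOps.invert(image.convert("RGB"))

def run_pipe(image, prompt, num_runs, prompt_cached):
    if image is None:
        raise gr.Error("Please upload an image first.")
    gallery = []
    for _ in range(int(num_runs)):
        gallery.append(fake_pipe(image, prompt, prompt_cached))
        prompt_cached = True          # embeddings count as cached after the first call
        yield gallery, prompt_cached  # stream partial results into the Gallery

def reset_prompt_cache():
    return False

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    prompt = gr.Textbox()
    num_runs = gr.Slider(1, 8, value=2, step=1)
    prompt_cached = gr.Checkbox(label="Auto-Cached embeds", value=False)
    result = gr.Gallery(columns=4)
    run_btn = gr.Button("Generate")

    run_btn.click(run_pipe, [image, prompt, num_runs, prompt_cached], [result, prompt_cached])
    # A new image or an edited prompt invalidates the cached embeddings.
    image.upload(reset_prompt_cache, [], [prompt_cached])
    prompt.input(reset_prompt_cache, [], [prompt_cached])

if __name__ == "__main__":
    demo.launch()

Returning prompt_cached from the click handler is what keeps the UI state in sync with the pipeline-side cache between consecutive runs.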
qwenimage/experiments/quantize_experiments.py
CHANGED
@@ -224,11 +224,11 @@ class Qwen_FA3_AoT_fp8darow_nolast(QwenBaseExperiment):
             }
         )

-
+
+def quantize_transformer_fp8da_nolast(model):
     module_fqn_to_config = ModuleFqnToConfig(
         OrderedDict([
             (ATTN_LAST_LAYER, None),
-            # ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
             ("_default",Float8DynamicActivationFloat8WeightConfig(),),
         ])
     )
@@ -237,6 +237,26 @@ def quantize_transformer_fp8darow_nolast(model):
     print_first_param(model)
     print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")

+def quantize_transformer_fp8darow_nolast(model):
+    module_fqn_to_config = ModuleFqnToConfig(
+        OrderedDict([
+            (ATTN_LAST_LAYER, None),
+            ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
+        ])
+    )
+    print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")
+    quantize_(model, module_fqn_to_config)
+    print_first_param(model)
+    print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")
+
+def conf_fp8darow_nolast():
+    module_fqn_to_config = ModuleFqnToConfig(
+        OrderedDict([
+            (ATTN_LAST_LAYER, None),
+            ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
+        ])
+    )
+    return module_fqn_to_config

 @ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow_nofirstlast")
 class Qwen_FA3_AoT_fp8darow_nofirstlast(QwenBaseExperiment):
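Note on conf_fp8darow_nolast above: the ModuleFqnToConfig maps one fully qualified module name (ATTN_LAST_LAYER) to None so that module is left unquantized, and every other matching module falls through to the "_default" FP8 dynamic-activation, per-row-weight config. A rough, self-contained sketch of the same pattern on a toy model, assuming a recent torchao that exports ModuleFqnToConfig, Float8DynamicActivationFloat8WeightConfig, PerRow, and quantize_ from torchao.quantization, plus FP8-capable GPU hardware; the module names here are illustrative, not the real transformer layout:

# Per-module FP8 config sketch: quantize every Linear except one named module.
from collections import OrderedDict

import torch
from torch import nn
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    ModuleFqnToConfig,
    PerRow,
    quantize_,
)

# Toy two-layer model; "proj_out" plays the role of ATTN_LAST_LAYER.
model = nn.Sequential(OrderedDict([
    ("proj_in", nn.Linear(64, 64)),
    ("proj_out", nn.Linear(64, 64)),
])).to(torch.bfloat16).cuda()  # FP8 kernels need a suitable GPU

config = ModuleFqnToConfig(OrderedDict([
    ("proj_out", None),  # None = skip quantization for this module
    ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
]))

quantize_(model, config)
print(model)  # proj_in carries Float8 weights, proj_out stays bf16

Skipping the last attention layer mirrors the "nolast" suffix in the experiment names; the earlier fp8da variant drops granularity=PerRow() and keeps the config's default scaling granularity.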
qwenimage/foundation.py
CHANGED
@@ -15,7 +15,7 @@ from einops import rearrange
 from qwenimage.datamodels import QwenConfig, QwenInputs
 from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
 from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
-from qwenimage.experiments.quantize_experiments import
+from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8da_nolast
 from qwenimage.loss import LossAccumulator
 from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
 from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
@@ -110,7 +110,7 @@ class QwenImageFoundation(WandModel):
             quantize_text_encoder_int4wo_linear(self.text_encoder)

         if self.config.quantize_transformer:
-
+            quantize_transformer_fp8da_nolast(self.transformer)


     def load(self, load_path):
qwenimage/models/autoencoder_kl_qwenimage.py
CHANGED
@@ -33,6 +33,8 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.autoencoders.vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution

+from qwenimage.debug import texam
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -870,11 +872,14 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
+        texam(z, "z")
         if self.use_slicing and z.shape[0] > 1:
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
             decoded = self._decode(z).sample
+
+        texam(decoded, "decoded")

         if not return_dict:
             return (decoded,)
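Note on the texam(z, "z") / texam(decoded, "decoded") calls above: texam lives in qwenimage.debug and is not shown in this commit. A hypothetical stand-in for what such a tensor-examination helper presumably prints (shape, dtype, device, value range), for readers without the repo:

# Hypothetical stand-in for qwenimage.debug.texam: one-line tensor summary.
import torch

def texam(t: torch.Tensor, name: str) -> None:
    # Cast to float so integer and bf16 tensors print min/max/mean cleanly.
    tf = t.detach().float()
    print(
        f"[texam] {name}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
        f"min={tf.min().item():.4f} max={tf.max().item():.4f} mean={tf.mean().item():.4f}"
    )

texam(torch.randn(1, 16, 1, 64, 64), "z")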
qwenimage/models/pipeline_qwenimage_edit_plus.py
CHANGED
@@ -15,6 +15,7 @@
 import inspect
 import math
 from typing import Any, Callable, Dict, List, Optional, Union
+import uuid
 import warnings

 from PIL import Image
@@ -24,6 +25,7 @@ import torch
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from transformers.models.qwen2 import Qwen2Tokenizer
 from transformers.models.qwen2_vl import Qwen2VLProcessor
+from torchvision.io import encode_jpeg, write_file, write_jpeg

 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import QwenImageLoraLoaderMixin
@@ -226,6 +228,10 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         self.prompt_template_encode_start_idx = 64
         self.default_sample_size = 128

+        self.prompt_embeds, self.prompt_embeds_mask = None, None
+        self.image_latents = None
+        self.latents_mean, self.latents_std = None, None
+
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
     def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
         bool_mask = mask.bool()
@@ -571,6 +577,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         channels_last_format: bool = False,
         vae_image_override: int | None = None,
         latent_size_override: int | None = None,
+        prompt_cached: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -708,23 +715,24 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):

         device = self._execution_device
         # 3. Preprocess image
-        if
-            if not isinstance(image,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if not prompt_cached:
+            if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+                if not isinstance(image, list):
+                    image = [image]
+                condition_image_sizes = []
+                condition_images = []
+                vae_image_sizes = []
+                vae_images = []
+                for img in image:
+                    image_width, image_height = img.size
+                    condition_width, condition_height = calculate_dimensions(
+                        CONDITION_IMAGE_SIZE, image_width / image_height
+                    )
+                    vae_width, vae_height = calculate_dimensions(vae_image_size, image_width / image_height)
+                    condition_image_sizes.append((condition_width, condition_height))
+                    vae_image_sizes.append((vae_width, vae_height))
+                    condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+                    vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -741,15 +749,19 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):

         with ctimed("Encode Prompt"):
             do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-
-
-
-            prompt_embeds=
-
-
-
-
-
+            if prompt_cached:
+                prompt_embeds, prompt_embeds_mask = self.prompt_embeds, self.prompt_embeds_mask
+            else:
+                prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+                    image=condition_images,
+                    prompt=prompt,
+                    prompt_embeds=prompt_embeds,
+                    prompt_embeds_mask=prompt_embeds_mask,
+                    device=device,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                )
+                self.prompt_embeds, self.prompt_embeds_mask = prompt_embeds, prompt_embeds_mask
             if do_true_cfg:
                 negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                     image=condition_images,
@@ -764,26 +776,37 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         with ctimed("Prep gen"):
             # 4. Prepare latent variables
             num_channels_latents = self.transformer.config.in_channels // 4
-
-
-
-
-
-
-
-
-
-                latents,
-
-
-
-
-
-
-
-
-
+            if prompt_cached:
+                image_latents = self.image_latents
+                _height = 2 * (int(height) // (self.vae_scale_factor * 2))
+                _width = 2 * (int(width) // (self.vae_scale_factor * 2))
+                shape = (batch_size * num_images_per_prompt, 1, num_channels_latents, _height, _width)
+                latents = randn_tensor(shape, generator=generator, device=device, dtype=image_latents.dtype)
+                latents = self._pack_latents(latents, batch_size * num_images_per_prompt, num_channels_latents, _height, _width)
+                img_shapes = self.img_shapes
+            else:
+                latents, image_latents = self.prepare_latents(
+                    vae_images,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height,
+                    width,
+                    prompt_embeds.dtype,
+                    device,
+                    generator,
+                    latents,
+                )
+                self.image_latents = image_latents
+                img_shapes = [
+                    [
+                        (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+                        *[
+                            (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+                            for vae_width, vae_height in vae_image_sizes
+                        ],
+                    ]
+                ] * batch_size
+                self.img_shapes = img_shapes

             # 5. Prepare timesteps
             # print(f"{num_inference_steps=}")
@@ -857,18 +880,23 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             for i in range(len(ts)-1):
                 t = ts[i]
                 with ctimed(f"loop {i}"):
-
-
+
+                    with ctimed("pre trans"):
+                        if self.interrupt:
+                            continue

-
+                    # self._current_timestep = t

-
-
-
+                    with ctimed("cat lats"):
+                        latent_model_input = latents
+                        if image_latents is not None:
+                            latent_model_input = torch.cat([latents, image_latents], dim=1)

-
-
-
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    with ctimed("broadcast lats"):
+                        in_t = t.expand(latents.shape[0]).to(latents.dtype)
+
+                    with ctimed("transformer proper"):
                         noise_pred = self.transformer(
                             hidden_states=latent_model_input,
                             timestep=in_t,
@@ -882,7 +910,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     noise_pred = noise_pred[:, : latents.size(1)]

                     if do_true_cfg:
-
+                        raise NotImplementedError()
                         with self.transformer.cache_context("uncond"):
                             neg_noise_pred = self.transformer(
                                 hidden_states=latent_model_input,
@@ -907,29 +935,29 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):

                     latents = t_utils.inference_ode_step(noise_pred, latents, i, ts)

+                    with ctimed("dtype stuff"):
+                        if latents.dtype != latents_dtype:
+                            if torch.backends.mps.is_available():
+                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                                latents = latents.to(latents_dtype)

-
-                    if
-
-
-
-
-                        callback_kwargs = {}
-                        for k in callback_on_step_end_tensor_inputs:
-                            callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    with ctimed("callback and shenanagans"):
+                        if callback_on_step_end is not None:
+                            callback_kwargs = {}
+                            for k in callback_on_step_end_tensor_inputs:
+                                callback_kwargs[k] = locals()[k]
+                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

-
-
+                            latents = callback_outputs.pop("latents", latents)
+                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

-
-
-
+                        # call the callback, if provided
+                        # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

-
-
+                    if XLA_AVAILABLE:
+                        xm.mark_step()

-        # with ctimed("Post (vae)"):
         self._current_timestep = None
         if output_type == "latent":
             image = latents
@@ -940,16 +968,51 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                 latents_mean = (
                     torch.tensor(self.vae.config.latents_mean)
                     .view(1, self.vae.config.z_dim, 1, 1, 1)
-                    .to(
+                    .to(device, self.vae.dtype)
                 )
                 latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-
+                    device, self.vae.dtype
                 )
                 latents = latents / latents_std + latents_mean
-                with ctimed("vae.decode"):
                 image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
-
-                image =
+
+                image = image.squeeze(0).add(1).mul(127.5).to(torch.uint8).cpu()
+
+                image_path = f"/tmp/{str(uuid.uuid4())[:8]}.jpg"
+                write_jpeg(image, image_path)
+                image = (image_path,)
+
+
+        # with ctimed("Post (vae)"):
+        #     self._current_timestep = None
+        #     if output_type == "latent":
+        #         image = latents
+        #     else:
+        #         with ctimed("pre decode"):
+        #             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+
+        #         if prompt_cached:
+        #             latents_mean, latents_std = self.latents_mean, self.latents_std
+        #         else:
+        #             latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=latents.dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
+        #             latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=latents.dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
+
+        #             self.latents_mean, self.latents_std = latents_mean, latents_std
+
+        #         latents = latents / latents_std + latents_mean
+        #         with ctimed("todtype"):
+        #             latents = latents.to(self.vae.dtype)
+        #         with ctimed("vae.decode"):
+        #             image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] # [B,C,H,W]
+
+        #         with ctimed("post process"):
+        #             with ctimed("convert"):
+        #                 image = image.squeeze(0).add(1).mul(127.5).to(torch.uint8).cpu()
+        #             with ctimed("write"):
+        #                 image_path = f"/tmp/{str(uuid.uuid4())[:8]}.jpg"
+        #                 write_jpeg(image, image_path)
+        #                 image = (image_path,)
+


         # Offload all models
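Note on the prompt_cached fast path above: the first call stores prompt_embeds, prompt_embeds_mask, image_latents, and img_shapes on the pipeline instance; a later call with prompt_cached=True skips image preprocessing, prompt encoding, and VAE encoding and only draws fresh noise. A minimal sketch of that memoization pattern, with hypothetical encode_prompt / encode_image stand-ins rather than the real pipeline internals:

# Sketch of the prompt_cached idea: heavy encoders run once, later calls reuse
# the stored tensors and only re-sample the starting noise.
import torch

class CachedEditPipeline:
    def __init__(self):
        self.prompt_embeds = None
        self.image_latents = None

    def encode_prompt(self, prompt: str) -> torch.Tensor:
        return torch.randn(1, 77, 64)      # placeholder for the VL text encoder

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        return torch.randn(1, 16, 32, 32)  # placeholder for the VAE encoder

    def __call__(self, image, prompt, prompt_cached=False, generator=None):
        if prompt_cached and self.prompt_embeds is not None:
            prompt_embeds, image_latents = self.prompt_embeds, self.image_latents
        else:
            prompt_embeds = self.encode_prompt(prompt)
            image_latents = self.encode_image(image)
            self.prompt_embeds, self.image_latents = prompt_embeds, image_latents
        # Only the initial noise changes between cached calls.
        latents = torch.randn(image_latents.shape, generator=generator)
        return latents, prompt_embeds

pipe = CachedEditPipeline()
pipe(torch.zeros(1, 3, 256, 256), "make it snowy")                      # encodes everything
pipe(torch.zeros(1, 3, 256, 256), "make it snowy", prompt_cached=True)  # reuses the cache

As in app.py, the caller is responsible for clearing the flag when the input image or prompt changes, otherwise stale embeddings are reused.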
qwenimage/models/transformer_qwenimage.py
CHANGED
@@ -35,6 +35,7 @@ from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm

 from qwenimage.activation_record import ActivationReport
+from qwenimage.debug import ctimed
 from qwenimage.models.attention_processors import QwenDoubleStreamAttnProcessor2_0


@@ -511,61 +512,66 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         else:
             lora_scale = 1.0

-
-
-
-
-        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-            logger.warning(
-                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-            )
-
-        hidden_states = self.img_in(hidden_states)
-
-        timestep = timestep.to(hidden_states.dtype)
-        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
-        encoder_hidden_states = self.txt_in(encoder_hidden_states)
-
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
-
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                warnings.warn("Gradient ckpt?")
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    encoder_hidden_states_mask,
-                    temb,
-                    image_rotary_emb,
-                )
-
-            else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with ctimed("scale lora"):
+            if USE_PEFT_BACKEND:
+                # weight the lora layers by setting `lora_scale` for each PEFT layer
+                scale_lora_layers(self, lora_scale)
+            else:
+                if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                    logger.warning(
+                        "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                    )
+
+        with ctimed("pre blocks"):
+            hidden_states = self.img_in(hidden_states)
+
+            timestep = timestep.to(hidden_states.dtype)
+            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+            encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+            if guidance is not None:
+                guidance = guidance.to(hidden_states.dtype) * 1000
+
+            temb = (
+                self.time_text_embed(timestep, hidden_states)
+                if guidance is None
+                else self.time_text_embed(timestep, guidance, hidden_states)
+            )
+
+        with ctimed("blocks"):
+            for index_block, block in enumerate(self.transformer_blocks):
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
+                    warnings.warn("Gradient ckpt?")
+                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                        block,
+                        hidden_states,
+                        encoder_hidden_states,
+                        encoder_hidden_states_mask,
+                        temb,
+                        image_rotary_emb,
+                    )
+
+                else:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_hidden_states_mask=encoder_hidden_states_mask,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=attention_kwargs,
+                    )
+                self.arec(f"encoder_hidden_states.{index_block}", encoder_hidden_states)
+                self.arec(f"hidden_states.{index_block}", hidden_states)
+
+        with ctimed("post blocks"):
+            # Use only the image part (hidden_states) from the dual-stream blocks
+            hidden_states = self.norm_out(hidden_states, temb)
+            output = self.proj_out(hidden_states)
+
+        with ctimed("lora"):
+            if USE_PEFT_BACKEND:
+                # remove `lora_scale` from each PEFT layer
+                unscale_lora_layers(self, lora_scale)

         if not return_dict:
             return (output,)
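Note on the timing wrappers above: the forward pass is now split into ctimed("scale lora"), ctimed("pre blocks"), ctimed("blocks"), ctimed("post blocks"), and ctimed("lora") sections. ctimed comes from qwenimage.debug and is not part of this diff; a hypothetical stand-in for a named timing context manager of this kind:

# Hypothetical stand-in for qwenimage.debug.ctimed: time a named block of code.
# For GPU-heavy sections the real helper would presumably synchronize CUDA first,
# otherwise the measured time only covers kernel launches.
import time
from contextlib import contextmanager

@contextmanager
def ctimed(name: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report wall-clock time for the wrapped section.
        print(f"[ctimed] {name}: {(time.perf_counter() - start) * 1000:.2f} ms")

with ctimed("blocks"):
    total = sum(i * i for i in range(100_000))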