Elea Zhong committed
Commit 92d8df6 · 1 Parent(s): 1b9d6c7

add 2step pipe and app

app.py CHANGED
@@ -12,6 +12,9 @@ import gradio as gr
 import spaces
 
 import subprocess
+
+from qwenimage.models.attention_processors import QwenDoubleStreamAttnProcessorFA3
+from qwenimage.optimization import optimize_pipeline_
 GIT_TOKEN = os.environ.get("GIT_TOKEN")
 import subprocess
 
@@ -34,6 +37,7 @@ import subprocess
 from qwenimage.debug import ctimed
 from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
 from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
+from qwenimage.experiments.quantize_experiments import conf_fp8darow_nolast, quantize_transformer_fp8darow_nolast
 
 # --- Model Loading ---
 
@@ -64,7 +68,26 @@ pipe.load_lora_weights(
     "checkpoints/distill_5k_lora.safetensors",
     adapter_name="fast_5k",
 )
-# pipe.unload_lora_weights()
+pipe.set_adapters(["fast_5k"], adapter_weights=[1.0])
+pipe.fuse_lora(adapter_names=["fast_5k"], lora_scale=1.0)
+pipe.unload_lora_weights()
+
+pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+pipe.transformer.fuse_qkv_projections()
+pipe.transformer.check_fused_qkv()
+
+optimize_pipeline_(
+    pipe,
+    cache_compiled=True,
+    quantize=True,
+    suffix="_fp8darow_nolast_fa3_fast5k",
+    quantize_config=conf_fp8darow_nolast(),
+    pipe_kwargs={
+        "image": [Image.new("RGB", (1024, 1024))],
+        "prompt":"prompt",
+        "num_inference_steps":2,
+    }
+)
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -73,12 +96,12 @@ MAX_SEED = np.iinfo(np.int32).max
 def run_pipe(
     image,
     prompt,
+    num_runs,
     seed,
     randomize_seed,
     num_inference_steps,
     shift,
-    prev_output = None,
-    progress=gr.Progress(track_tqdm=True)
+    prompt_cached,
 ):
     with ctimed("pre pipe"):
 
@@ -90,35 +113,40 @@ def run_pipe(
 
         # Choose input image (prefer uploaded, else last output)
         pil_images = []
-        if image is not None:
-            if isinstance(image, Image.Image):
-                pil_images.append(image.convert("RGB"))
-            elif hasattr(image, "name"):
-                pil_images.append(Image.open(image.name).convert("RGB"))
-        elif prev_output:
-            pil_images.append(prev_output.convert("RGB"))
-
-        if len(pil_images) == 0:
+        if image is None:
             raise gr.Error("Please upload an image first.")
-
-        print(f"{len(pil_images)=}")
+        if isinstance(image, Image.Image):
+            pil_images.append(image.convert("RGB"))
+        elif hasattr(image, "name"):
+            pil_images.append(Image.open(image.name).convert("RGB"))
 
         # finetuner.enable()
         pipe.scheduler.config["base_shift"] = shift
        pipe.scheduler.config["max_shift"] = shift
 
-    result = pipe(
-        image=pil_images,
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        generator=generator,
-    ).images[0]
-
-    return result, seed
+    gallery_images = []
+
+    for i in range(num_runs):
+        result = pipe(
+            image=pil_images,
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            vae_image_override=1024 * 1024, #512 * 512,
+            latent_size_override=1024 * 1024,
+            prompt_cached=prompt_cached,
+            return_dict=True,
+        ).images[0]
+        prompt_cached = True
+        gallery_images.append(result)
+
+        yield gallery_images, seed, prompt_cached
 
 
 # --- UI ---
 
+def reset_prompt_cache():
+    return False
 
 with gr.Blocks(theme=gr.themes.Citrus()) as demo:
 
@@ -127,32 +155,40 @@ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
     with gr.Row():
         with gr.Column():
             image = gr.Image(label="Input Image", type="pil")
-            prev_output = gr.Image(value=None, visible=False)
-            is_reset = gr.Checkbox(value=False, visible=False)
             prompt = gr.Textbox(label="Prompt", placeholder="Prompt", lines=2)
 
+            num_runs = gr.Slider(label="Run Consecutively", minimum=0, maximum=100, step=1, value=16)
 
            run_btn = gr.Button("Generate", variant="primary")
 
            with gr.Accordion("Advanced Settings", open=False):
+                prompt_cached = gr.Checkbox(label="Auto-Cached embeds", value=False)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=2)
                shift = gr.Slider(label="Timestep Shift", minimum=0.0, maximum=4.0, step=0.1, value=2.0)
 
        with gr.Column():
-            result = gr.Image(label="Output Image", interactive=False)
+            result = gr.Gallery(
+                label="Output Image",
+                interactive=False,
+                # type="filepath",
+                columns=4,
+                height=800,
+                object_fit="scale-down",
+            )
 
    inputs = [
        image,
        prompt,
+        num_runs,
        seed,
        randomize_seed,
        num_inference_steps,
        shift,
-        prev_output,
+        prompt_cached,
    ]
-    outputs = [result, seed]
+    outputs = [result, seed, prompt_cached]
 
 
    run_event = run_btn.click(
@@ -161,6 +197,17 @@ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
        outputs=outputs
    )
 
-    run_event.then(lambda img, *_: img, inputs=[result], outputs=[prev_output])
+
+    image.upload(
+        fn=reset_prompt_cache,
+        inputs=[],
+        outputs=[prompt_cached],
+    )
+
+    prompt.input(
+        fn=reset_prompt_cache,
+        inputs=[],
+        outputs=[prompt_cached],
+    )
 
 demo.launch()
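Review note on the app wiring above: run_pipe is now a generator that yields a growing gallery and flips the prompt_cached flag to True after the first iteration, while image.upload and prompt.input reset it, so stale embeddings are never reused for new inputs. A minimal, self-contained sketch of that pattern with a placeholder generation step (fake_generate is illustrative and stands in for the real pipe call):

import gradio as gr

def fake_generate(image, prompt, num_runs, cached):
    # Stand-in for run_pipe: yield a growing gallery, mark embeds as cached after run 1.
    if image is None:
        raise gr.Error("Please upload an image first.")
    gallery = []
    for _ in range(int(num_runs)):
        gallery.append(image)          # real code: pipe(..., prompt_cached=cached)
        cached = True                  # conditioning is reusable from now on
        yield gallery, cached

def reset_prompt_cache():
    return False

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    prompt = gr.Textbox(label="Prompt")
    num_runs = gr.Slider(label="Run Consecutively", minimum=1, maximum=8, step=1, value=2)
    prompt_cached = gr.Checkbox(label="Auto-Cached embeds", value=False)
    result = gr.Gallery(label="Output Image")
    run_btn = gr.Button("Generate")

    run_btn.click(fake_generate, [image, prompt, num_runs, prompt_cached], [result, prompt_cached])
    image.upload(reset_prompt_cache, [], [prompt_cached])   # new image invalidates the cache
    prompt.input(reset_prompt_cache, [], [prompt_cached])   # so does an edited prompt

demo.launch()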
qwenimage/experiments/quantize_experiments.py CHANGED
@@ -224,11 +224,11 @@ class Qwen_FA3_AoT_fp8darow_nolast(QwenBaseExperiment):
         }
     )
 
-def quantize_transformer_fp8darow_nolast(model):
+
+def quantize_transformer_fp8da_nolast(model):
     module_fqn_to_config = ModuleFqnToConfig(
         OrderedDict([
             (ATTN_LAST_LAYER, None),
-            # ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
             ("_default",Float8DynamicActivationFloat8WeightConfig(),),
         ])
     )
@@ -237,6 +237,26 @@ def quantize_transformer_fp8darow_nolast(model):
     print_first_param(model)
     print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")
 
+def quantize_transformer_fp8darow_nolast(model):
+    module_fqn_to_config = ModuleFqnToConfig(
+        OrderedDict([
+            (ATTN_LAST_LAYER, None),
+            ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
+        ])
+    )
+    print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")
+    quantize_(model, module_fqn_to_config)
+    print_first_param(model)
+    print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB")
+
+def conf_fp8darow_nolast():
+    module_fqn_to_config = ModuleFqnToConfig(
+        OrderedDict([
+            (ATTN_LAST_LAYER, None),
+            ("_default",Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),),
+        ])
+    )
+    return module_fqn_to_config
 
 @ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow_nofirstlast")
 class Qwen_FA3_AoT_fp8darow_nofirstlast(QwenBaseExperiment):
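Review note: conf_fp8darow_nolast() above only builds the module-name-to-config mapping; app.py hands it to optimize_pipeline_, which is expected to pass it to torchao's quantize_. A minimal sketch of how such a mapping is applied, assuming a recent torchao build where ModuleFqnToConfig, Float8DynamicActivationFloat8WeightConfig, PerRow and quantize_ are importable from torchao.quantization, and an fp8-capable GPU; the toy model and the "proj_out" FQN are illustrative, not part of this repo:

from collections import OrderedDict

import torch
from torch import nn
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    ModuleFqnToConfig,
    PerRow,
    quantize_,
)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(64, 128)
        self.proj_out = nn.Linear(128, 64)

    def forward(self, x):
        return self.proj_out(torch.relu(self.proj_in(x)))

model = Toy().to(device="cuda", dtype=torch.bfloat16)

# Same shape as conf_fp8darow_nolast: fp8 dynamic-activation/weight quantization for
# everything ("_default"), except one named module explicitly left alone (None),
# analogous to ATTN_LAST_LAYER in this commit.
config = ModuleFqnToConfig(OrderedDict([
    ("proj_out", None),
    ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
]))
quantize_(model, config)

The only difference from quantize_transformer_fp8da_nolast, which foundation.py now calls, is the granularity argument: per-row scales instead of torchao's default.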
qwenimage/foundation.py CHANGED
@@ -15,7 +15,7 @@ from einops import rearrange
 from qwenimage.datamodels import QwenConfig, QwenInputs
 from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
 from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
-from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
+from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8da_nolast
 from qwenimage.loss import LossAccumulator
 from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
 from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
@@ -110,7 +110,7 @@ class QwenImageFoundation(WandModel):
             quantize_text_encoder_int4wo_linear(self.text_encoder)
 
         if self.config.quantize_transformer:
-            quantize_transformer_fp8darow_nolast(self.transformer)
+            quantize_transformer_fp8da_nolast(self.transformer)
 
 
     def load(self, load_path):
qwenimage/models/autoencoder_kl_qwenimage.py CHANGED
@@ -33,6 +33,8 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.autoencoders.vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
 
+from qwenimage.debug import texam
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -870,11 +872,14 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
+        texam(z, "z")
         if self.use_slicing and z.shape[0] > 1:
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
             decoded = self._decode(z).sample
+
+        texam(decoded, "decoded")
 
         if not return_dict:
             return (decoded,)
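Review note: texam (added to the VAE decode path above) and ctimed (used throughout this commit) come from qwenimage.debug, whose implementation is not part of this diff. Hypothetical stand-ins with the same call shape, for readers without the repo, might look like:

import time
from contextlib import contextmanager

import torch

def texam(t: torch.Tensor, name: str = "tensor") -> None:
    # Hypothetical: dump shape/dtype/device and value range of a tensor.
    f = t.detach().float()
    print(f"{name}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
          f"min={f.min().item():.4g} max={f.max().item():.4g} mean={f.mean().item():.4g}")

@contextmanager
def ctimed(label: str):
    # Hypothetical: wall-clock a code block; a real GPU timer would call
    # torch.cuda.synchronize() before reading the clock, since CUDA kernels launch asynchronously.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"[{label}] {time.perf_counter() - start:.4f}s")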
qwenimage/models/pipeline_qwenimage_edit_plus.py CHANGED
@@ -15,6 +15,7 @@
 import inspect
 import math
 from typing import Any, Callable, Dict, List, Optional, Union
+import uuid
 import warnings
 
 from PIL import Image
@@ -24,6 +25,7 @@ import torch
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from transformers.models.qwen2 import Qwen2Tokenizer
 from transformers.models.qwen2_vl import Qwen2VLProcessor
+from torchvision.io import encode_jpeg, write_file, write_jpeg
 
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import QwenImageLoraLoaderMixin
@@ -226,6 +228,10 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         self.prompt_template_encode_start_idx = 64
         self.default_sample_size = 128
 
+        self.prompt_embeds, self.prompt_embeds_mask = None, None
+        self.image_latents = None
+        self.latents_mean, self.latents_std = None, None
+
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
     def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
         bool_mask = mask.bool()
@@ -571,6 +577,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         channels_last_format: bool = False,
         vae_image_override: int | None = None,
         latent_size_override: int | None = None,
+        prompt_cached: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -708,23 +715,24 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         device = self._execution_device
         # 3. Preprocess image
-        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
-            if not isinstance(image, list):
-                image = [image]
-            condition_image_sizes = []
-            condition_images = []
-            vae_image_sizes = []
-            vae_images = []
-            for img in image:
-                image_width, image_height = img.size
-                condition_width, condition_height = calculate_dimensions(
-                    CONDITION_IMAGE_SIZE, image_width / image_height
-                )
-                vae_width, vae_height = calculate_dimensions(vae_image_size, image_width / image_height)
-                condition_image_sizes.append((condition_width, condition_height))
-                vae_image_sizes.append((vae_width, vae_height))
-                condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
-                vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+        if not prompt_cached:
+            if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+                if not isinstance(image, list):
+                    image = [image]
+                condition_image_sizes = []
+                condition_images = []
+                vae_image_sizes = []
+                vae_images = []
+                for img in image:
+                    image_width, image_height = img.size
+                    condition_width, condition_height = calculate_dimensions(
+                        CONDITION_IMAGE_SIZE, image_width / image_height
+                    )
+                    vae_width, vae_height = calculate_dimensions(vae_image_size, image_width / image_height)
+                    condition_image_sizes.append((condition_width, condition_height))
+                    vae_image_sizes.append((vae_width, vae_height))
+                    condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+                    vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
 
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -741,15 +749,19 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
         with ctimed("Encode Prompt"):
             do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-            prompt_embeds, prompt_embeds_mask = self.encode_prompt(
-                image=condition_images,
-                prompt=prompt,
-                prompt_embeds=prompt_embeds,
-                prompt_embeds_mask=prompt_embeds_mask,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-            )
+            if prompt_cached:
+                prompt_embeds, prompt_embeds_mask = self.prompt_embeds, self.prompt_embeds_mask
+            else:
+                prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+                    image=condition_images,
+                    prompt=prompt,
+                    prompt_embeds=prompt_embeds,
+                    prompt_embeds_mask=prompt_embeds_mask,
+                    device=device,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                )
+                self.prompt_embeds, self.prompt_embeds_mask = prompt_embeds, prompt_embeds_mask
            if do_true_cfg:
                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                    image=condition_images,
@@ -764,26 +776,37 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        with ctimed("Prep gen"):
            # 4. Prepare latent variables
            num_channels_latents = self.transformer.config.in_channels // 4
-            latents, image_latents = self.prepare_latents(
-                vae_images,
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height,
-                width,
-                prompt_embeds.dtype,
-                device,
-                generator,
-                latents,
-            )
-            img_shapes = [
-                [
-                    (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
-                    *[
-                        (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
-                        for vae_width, vae_height in vae_image_sizes
-                    ],
-                ]
-            ] * batch_size
+            if prompt_cached:
+                image_latents = self.image_latents
+                _height = 2 * (int(height) // (self.vae_scale_factor * 2))
+                _width = 2 * (int(width) // (self.vae_scale_factor * 2))
+                shape = (batch_size * num_images_per_prompt, 1, num_channels_latents, _height, _width)
+                latents = randn_tensor(shape, generator=generator, device=device, dtype=image_latents.dtype)
+                latents = self._pack_latents(latents, batch_size * num_images_per_prompt, num_channels_latents, _height, _width)
+                img_shapes = self.img_shapes
+            else:
+                latents, image_latents = self.prepare_latents(
+                    vae_images,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height,
+                    width,
+                    prompt_embeds.dtype,
+                    device,
+                    generator,
+                    latents,
+                )
+                self.image_latents = image_latents
+                img_shapes = [
+                    [
+                        (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+                        *[
+                            (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+                            for vae_width, vae_height in vae_image_sizes
+                        ],
+                    ]
+                ] * batch_size
+                self.img_shapes = img_shapes
 
        # 5. Prepare timesteps
        # print(f"{num_inference_steps=}")
@@ -857,18 +880,23 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
            for i in range(len(ts)-1):
                t = ts[i]
                with ctimed(f"loop {i}"):
-                    if self.interrupt:
-                        continue
+
+                    with ctimed("pre trans"):
+                        if self.interrupt:
+                            continue
 
-                    # self._current_timestep = t
+                        # self._current_timestep = t
 
-                    latent_model_input = latents
-                    if image_latents is not None:
-                        latent_model_input = torch.cat([latents, image_latents], dim=1)
+                    with ctimed("cat lats"):
+                        latent_model_input = latents
+                        if image_latents is not None:
+                            latent_model_input = torch.cat([latents, image_latents], dim=1)
 
-                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                    in_t = t.expand(latents.shape[0]).to(latents.dtype)
-                    with self.transformer.cache_context("cond"):
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    with ctimed("broadcast lats"):
+                        in_t = t.expand(latents.shape[0]).to(latents.dtype)
+
+                    with ctimed("transformer proper"):
                        noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=in_t,
@@ -882,7 +910,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    noise_pred = noise_pred[:, : latents.size(1)]
 
                    if do_true_cfg:
-                        warnings.warn("doing true CFG")
+                        raise NotImplementedError()
                        with self.transformer.cache_context("uncond"):
                            neg_noise_pred = self.transformer(
                                hidden_states=latent_model_input,
@@ -907,29 +935,29 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
                    latents = t_utils.inference_ode_step(noise_pred, latents, i, ts)
 
+                    with ctimed("dtype stuff"):
+                        if latents.dtype != latents_dtype:
+                            if torch.backends.mps.is_available():
+                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                                latents = latents.to(latents_dtype)
 
-                    if latents.dtype != latents_dtype:
-                        if torch.backends.mps.is_available():
-                            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                            latents = latents.to(latents_dtype)
-
-                    if callback_on_step_end is not None:
-                        callback_kwargs = {}
-                        for k in callback_on_step_end_tensor_inputs:
-                            callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    with ctimed("callback and shenanagans"):
+                        if callback_on_step_end is not None:
+                            callback_kwargs = {}
+                            for k in callback_on_step_end_tensor_inputs:
+                                callback_kwargs[k] = locals()[k]
+                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
-                    # call the callback, if provided
-                    # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                        # call the callback, if provided
+                        # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
 
-                    if XLA_AVAILABLE:
-                        xm.mark_step()
+                        if XLA_AVAILABLE:
+                            xm.mark_step()
 
-        # with ctimed("Post (vae)"):
        self._current_timestep = None
        if output_type == "latent":
            image = latents
@@ -940,16 +968,51 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean)
                    .view(1, self.vae.config.z_dim, 1, 1, 1)
-                    .to(latents.device, latents.dtype)
+                    .to(device, self.vae.dtype)
                )
                latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                    latents.device, latents.dtype
+                    device, self.vae.dtype
                )
                latents = latents / latents_std + latents_mean
-                with ctimed("vae.decode"):
                image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
-                with ctimed("post process"):
-                    image = self.image_processor.postprocess(image, output_type=output_type)
+
+                image = image.squeeze(0).add(1).mul(127.5).to(torch.uint8).cpu()
+
+                image_path = f"/tmp/{str(uuid.uuid4())[:8]}.jpg"
+                write_jpeg(image, image_path)
+                image = (image_path,)
+
+
+        # with ctimed("Post (vae)"):
+        #     self._current_timestep = None
+        #     if output_type == "latent":
+        #         image = latents
+        #     else:
+        #         with ctimed("pre decode"):
+        #             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+
+        #             if prompt_cached:
+        #                 latents_mean, latents_std = self.latents_mean, self.latents_std
+        #             else:
+        #                 latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=latents.dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
+        #                 latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=latents.dtype).view(1, self.vae.config.z_dim, 1, 1, 1)
+
+        #             self.latents_mean, self.latents_std = latents_mean, latents_std
+
+        #             latents = latents / latents_std + latents_mean
+        #         with ctimed("todtype"):
+        #             latents = latents.to(self.vae.dtype)
+        #         with ctimed("vae.decode"):
+        #             image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] # [B,C,H,W]
+
+        #         with ctimed("post process"):
+        #             with ctimed("convert"):
+        #                 image = image.squeeze(0).add(1).mul(127.5).to(torch.uint8).cpu()
+        #             with ctimed("write"):
+        #                 image_path = f"/tmp/{str(uuid.uuid4())[:8]}.jpg"
+        #                 write_jpeg(image, image_path)
+        #                 image = (image_path,)
+
 
 
        # Offload all models
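Review note on the pipeline changes above: the prompt_cached path amounts to computing the prompt embeddings, image latents and img_shapes once, stashing them on the pipeline object, and on later calls redrawing only the initial noise. A stripped-down sketch of that pattern (encode_prompt/encode_image here are placeholders, not the pipeline's real signatures):

import torch

class CachedConditioning:
    """Compute conditioning once, reuse it while prompt_cached=True."""

    def __init__(self):
        self.prompt_embeds = None
        self.prompt_embeds_mask = None
        self.image_latents = None

    def encode_prompt(self, prompt):
        # Placeholder for the expensive text/VL encoder call.
        return torch.randn(1, 77, 64), torch.ones(1, 77, dtype=torch.bool)

    def encode_image(self, image):
        # Placeholder for VAE encoding of the conditioning image.
        return torch.randn(1, 16, 32, 32)

    def __call__(self, image, prompt, prompt_cached=False, generator=None):
        if prompt_cached:
            prompt_embeds, mask = self.prompt_embeds, self.prompt_embeds_mask
            image_latents = self.image_latents
        else:
            prompt_embeds, mask = self.encode_prompt(prompt)
            image_latents = self.encode_image(image)
            self.prompt_embeds, self.prompt_embeds_mask = prompt_embeds, mask
            self.image_latents = image_latents

        # Only the noise is fresh on cached calls; everything else is reused.
        latents = torch.randn(image_latents.shape, generator=generator)
        return latents  # the real pipeline would run the denoising loop here

The cached tensors are only valid while the prompt and the input image are unchanged, which is why app.py resets the prompt_cached checkbox on image.upload and prompt.input.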
qwenimage/models/transformer_qwenimage.py CHANGED
@@ -35,6 +35,7 @@ from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
 
 from qwenimage.activation_record import ActivationReport
+from qwenimage.debug import ctimed
 from qwenimage.models.attention_processors import QwenDoubleStreamAttnProcessor2_0
 
 
@@ -511,61 +512,66 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         else:
             lora_scale = 1.0
 
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
-        hidden_states = self.img_in(hidden_states)
-
-        timestep = timestep.to(hidden_states.dtype)
-        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
-        encoder_hidden_states = self.txt_in(encoder_hidden_states)
-
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
-
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                warnings.warn("Gradient ckpt?")
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    encoder_hidden_states_mask,
-                    temb,
-                    image_rotary_emb,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=attention_kwargs,
-                )
-            self.arec(f"encoder_hidden_states.{index_block}", encoder_hidden_states)
-            self.arec(f"hidden_states.{index_block}", hidden_states)
-
-        # Use only the image part (hidden_states) from the dual-stream blocks
-        hidden_states = self.norm_out(hidden_states, temb)
-        output = self.proj_out(hidden_states)
-
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
+        with ctimed("scale lora"):
+            if USE_PEFT_BACKEND:
+                # weight the lora layers by setting `lora_scale` for each PEFT layer
+                scale_lora_layers(self, lora_scale)
+            else:
+                if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                    logger.warning(
+                        "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                    )
+
+        with ctimed("pre blocks"):
+            hidden_states = self.img_in(hidden_states)
+
+            timestep = timestep.to(hidden_states.dtype)
+            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+            encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+            if guidance is not None:
+                guidance = guidance.to(hidden_states.dtype) * 1000
+
+            temb = (
+                self.time_text_embed(timestep, hidden_states)
+                if guidance is None
+                else self.time_text_embed(timestep, guidance, hidden_states)
+            )
+
+        with ctimed("blocks"):
+            for index_block, block in enumerate(self.transformer_blocks):
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
+                    warnings.warn("Gradient ckpt?")
+                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                        block,
+                        hidden_states,
+                        encoder_hidden_states,
+                        encoder_hidden_states_mask,
+                        temb,
+                        image_rotary_emb,
+                    )
+
+                else:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_hidden_states_mask=encoder_hidden_states_mask,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=attention_kwargs,
+                    )
+                self.arec(f"encoder_hidden_states.{index_block}", encoder_hidden_states)
+                self.arec(f"hidden_states.{index_block}", hidden_states)
+
+        with ctimed("post blocks"):
+            # Use only the image part (hidden_states) from the dual-stream blocks
+            hidden_states = self.norm_out(hidden_states, temb)
+            output = self.proj_out(hidden_states)
+
+        with ctimed("lora"):
+            if USE_PEFT_BACKEND:
+                # remove `lora_scale` from each PEFT layer
+                unscale_lora_layers(self, lora_scale)
 
        if not return_dict:
            return (output,)