Elea Zhong committed
Commit 4cd7f21 · 1 Parent(s): 8a268b5

add fbcache, lpips comparison, 50 step

qwenimage/experiments/experiments_qwen.py CHANGED
@@ -12,6 +12,7 @@ import torch
12
  from PIL import Image
13
  import pandas as pd
14
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
 
15
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
16
  from torchao.quantization import Int8WeightOnlyConfig
17
  import spaces
@@ -19,7 +20,8 @@ import torch
19
  from torch.utils._pytree import tree_map
20
  from torchao.utils import get_model_size_in_bytes
21
 
22
- from qwenimage.debug import ftimed, print_first_param
 
23
  from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
24
  from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
25
  from qwenimage.models.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
@@ -48,6 +50,19 @@ class ExperimentRegistry:
48
  raise KeyError(f"{name} not in {list(cls.registry.keys())}")
49
  return cls.registry[name]
50
51
  @classmethod
52
  def keys(cls):
53
  return list(cls.registry.keys())
@@ -142,7 +157,6 @@ class QwenBaseExperiment(AbstractExperiment):
142
  def optimize(self):
143
  pass
144
 
145
- @ftimed
146
  def run_once(self, *args, **kwargs):
147
  return self.pipe(*args, **kwargs).images[0]
148
 
@@ -152,7 +166,8 @@ class QwenBaseExperiment(AbstractExperiment):
152
 
153
  for i in range(self.config.iterations):
154
  inputs = self.pipe_inputs[i]
155
- output = self.run_once(**inputs)
 
156
  output.save(output_save_dir / f"{i:03d}.jpg")
157
 
158
  def report(self):
@@ -182,6 +197,40 @@ class QwenBaseExperiment(AbstractExperiment):
182
  del self.pipe.transformer
183
  del self.pipe
184
185
  @ExperimentRegistry.register(name="qwen_lightning_lora")
186
  class Qwen_Lightning_Lora(QwenBaseExperiment):
187
  @ftimed
@@ -243,28 +292,24 @@ class Qwen_Lightning_Lora(QwenBaseExperiment):
243
 
244
  @ExperimentRegistry.register(name="qwen_lightning_lora_3step")
245
  class Qwen_Lightning_Lora_3step(Qwen_Lightning_Lora):
246
- @ftimed
247
  def run_once(self, *args, **kwargs):
248
  kwargs["num_inference_steps"] = 3
249
  return self.pipe(*args, **kwargs).images[0]
250
 
251
  @ExperimentRegistry.register(name="qwen_base_3step")
252
  class Qwen_Base_3step(QwenBaseExperiment):
253
- @ftimed
254
  def run_once(self, *args, **kwargs):
255
  kwargs["num_inference_steps"] = 3
256
  return self.pipe(*args, **kwargs).images[0]
257
 
258
  @ExperimentRegistry.register(name="qwen_lightning_lora_2step")
259
  class Qwen_Lightning_Lora_2step(Qwen_Lightning_Lora):
260
- @ftimed
261
  def run_once(self, *args, **kwargs):
262
  kwargs["num_inference_steps"] = 2
263
  return self.pipe(*args, **kwargs).images[0]
264
 
265
  @ExperimentRegistry.register(name="qwen_base_2step")
266
  class Qwen_Base_2step(QwenBaseExperiment):
267
- @ftimed
268
  def run_once(self, *args, **kwargs):
269
  kwargs["num_inference_steps"] = 2
270
  return self.pipe(*args, **kwargs).images[0]
@@ -582,20 +627,129 @@ class Qwen_Lightning_FA3_AoT_int8_fuse(Qwen_Lightning_Lora):
582
 
583
  @ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_2step")
584
  class Qwen_Lightning_FA3_AoT_int8_fuse_2step(Qwen_Lightning_FA3_AoT_int8_fuse):
585
- @ftimed
586
  def run_once(self, *args, **kwargs):
587
  kwargs["num_inference_steps"] = 2
588
  return self.pipe(*args, **kwargs).images[0]
589
590
 
591
  @ExperimentRegistry.register(name="qwen_channels_last")
592
  class Qwen_Channels_Last(QwenBaseExperiment):
593
  """
594
- This experiment is fully useless: channels last format only works with NCHW tensors,
595
  i.e. 2D CNNs; the transformer is 1D and the VAE is 3D, and for it to work the inputs need to
596
  be converted in-pipe as well. Left for reference.
597
  """
598
  @ftimed
599
  def optimize(self):
600
- self.pipe.vae = self.pipe.vae.to(memory_format=torch.channels_last)
601
- self.pipe.transformer = self.pipe.transformer.to(memory_format=torch.channels_last)
12
  from PIL import Image
13
  import pandas as pd
14
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
15
+ from torchao import autoquant
16
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
17
  from torchao.quantization import Int8WeightOnlyConfig
18
  import spaces
 
20
  from torch.utils._pytree import tree_map
21
  from torchao.utils import get_model_size_in_bytes
22
 
23
+ from qwenimage.debug import ctimed, ftimed, print_first_param
24
+ from qwenimage.models.first_block_cache import apply_cache_on_pipe
25
  from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
26
  from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
27
  from qwenimage.models.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
50
  raise KeyError(f"{name} not in {list(cls.registry.keys())}")
51
  return cls.registry[name]
52
 
53
+ @classmethod
54
+ def filter(cls, startswith=None, endswith=None, contains=None, not_contain=None):
55
+ keys = list(cls.registry.keys())
56
+ if startswith is not None:
57
+ keys = [k for k in keys if k.startswith(startswith)]
58
+ if endswith is not None:
59
+ keys = [k for k in keys if k.endswith(endswith)]
60
+ if contains is not None:
61
+ keys = [k for k in keys if contains in k]
62
+ if not_contain is not None:
63
+ keys = [k for k in keys if not_contain not in k]
64
+ return keys
65
+
66
  @classmethod
67
  def keys(cls):
68
  return list(cls.registry.keys())
 
157
  def optimize(self):
158
  pass
159
 
 
160
  def run_once(self, *args, **kwargs):
161
  return self.pipe(*args, **kwargs).images[0]
162
 
 
166
 
167
  for i in range(self.config.iterations):
168
  inputs = self.pipe_inputs[i]
169
+ with ctimed("run_once"):
170
+ output = self.run_once(**inputs)
171
  output.save(output_save_dir / f"{i:03d}.jpg")
172
 
173
  def report(self):
 
197
  del self.pipe.transformer
198
  del self.pipe
199
 
200
+ @ExperimentRegistry.register(name="qwen_50step")
201
+ class Qwen_50Step(QwenBaseExperiment):
202
+ @ftimed
203
+ def load(self):
204
+ dtype = torch.bfloat16
205
+ device = "cuda" if torch.cuda.is_available() else "cpu"
206
+ print(f"experiment load cuda: {torch.cuda.is_available()=}")
207
+
208
+ pipe = QwenImageEditPlusPipeline.from_pretrained(
209
+ "Qwen/Qwen-Image-Edit-2509",
210
+ transformer=QwenImageTransformer2DModel.from_pretrained( # use our own model
211
+ "Qwen/Qwen-Image-Edit-2509",
212
+ subfolder='transformer',
213
+ torch_dtype=dtype,
214
+ device_map=device
215
+ ),
216
+ torch_dtype=dtype,
217
+ ).to(device)
218
+
219
+ pipe.load_lora_weights(
220
+ "dx8152/Qwen-Edit-2509-Multiple-angles",
221
+ weight_name="镜头转换.safetensors", adapter_name="angles"
222
+ )
223
+
224
+ pipe.set_adapters(["angles"], adapter_weights=[1.])
225
+ pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.25)
226
+ pipe.unload_lora_weights()
227
+ self.pipe = pipe
228
+
229
+ def run_once(self, *args, **kwargs):
230
+ kwargs["num_inference_steps"] = 50
231
+ return self.pipe(*args, **kwargs).images[0]
232
+
233
+
234
  @ExperimentRegistry.register(name="qwen_lightning_lora")
235
  class Qwen_Lightning_Lora(QwenBaseExperiment):
236
  @ftimed
 
292
 
293
  @ExperimentRegistry.register(name="qwen_lightning_lora_3step")
294
  class Qwen_Lightning_Lora_3step(Qwen_Lightning_Lora):
 
295
  def run_once(self, *args, **kwargs):
296
  kwargs["num_inference_steps"] = 3
297
  return self.pipe(*args, **kwargs).images[0]
298
 
299
  @ExperimentRegistry.register(name="qwen_base_3step")
300
  class Qwen_Base_3step(QwenBaseExperiment):
 
301
  def run_once(self, *args, **kwargs):
302
  kwargs["num_inference_steps"] = 3
303
  return self.pipe(*args, **kwargs).images[0]
304
 
305
  @ExperimentRegistry.register(name="qwen_lightning_lora_2step")
306
  class Qwen_Lightning_Lora_2step(Qwen_Lightning_Lora):
 
307
  def run_once(self, *args, **kwargs):
308
  kwargs["num_inference_steps"] = 2
309
  return self.pipe(*args, **kwargs).images[0]
310
 
311
  @ExperimentRegistry.register(name="qwen_base_2step")
312
  class Qwen_Base_2step(QwenBaseExperiment):
 
313
  def run_once(self, *args, **kwargs):
314
  kwargs["num_inference_steps"] = 2
315
  return self.pipe(*args, **kwargs).images[0]
 
627
 
628
  @ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_2step")
629
  class Qwen_Lightning_FA3_AoT_int8_fuse_2step(Qwen_Lightning_FA3_AoT_int8_fuse):
 
630
  def run_once(self, *args, **kwargs):
631
  kwargs["num_inference_steps"] = 2
632
  return self.pipe(*args, **kwargs).images[0]
633
 
634
+ @ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_3step")
635
+ class Qwen_Lightning_FA3_AoT_int8_fuse_3step(Qwen_Lightning_FA3_AoT_int8_fuse):
636
+ def run_once(self, *args, **kwargs):
637
+ kwargs["num_inference_steps"] = 3
638
+ return self.pipe(*args, **kwargs).images[0]
639
+
640
+ @ExperimentRegistry.register(name="qwen_fa3_aot_int8_fuse_2step")
641
+ class Qwen_FA3_AoT_int8_fuse_2step(Qwen_FA3_AoT_int8_fuse):
642
+ def run_once(self, *args, **kwargs):
643
+ kwargs["num_inference_steps"] = 2
644
+ return self.pipe(*args, **kwargs).images[0]
645
+
646
+ @ExperimentRegistry.register(name="qwen_fa3_aot_int8_fuse_3step")
647
+ class Qwen_FA3_AoT_int8_fuse_3step(Qwen_FA3_AoT_int8_fuse):
648
+ def run_once(self, *args, **kwargs):
649
+ kwargs["num_inference_steps"] = 3
650
+ return self.pipe(*args, **kwargs).images[0]
651
 
652
  @ExperimentRegistry.register(name="qwen_channels_last")
653
  class Qwen_Channels_Last(QwenBaseExperiment):
654
  """
655
+ This experiment may be useless: channels last format only works with NCHW tensors,
656
  i.e. 2D CNNs; the transformer is 1D and the VAE is 3D, and for it to work the inputs need to
657
  be converted in-pipe as well. Left for reference.
658
  """
659
  @ftimed
660
  def optimize(self):
661
+ # self.pipe.vae = self.pipe.vae.to(memory_format=torch.channels_last_3d)
662
+ self.pipe.transformer = self.pipe.transformer.to(memory_format=torch.channels_last)
663
+
664
+ @ExperimentRegistry.register(name="qwen_fbcache_05")
665
+ class Qwen_FBCache_05(QwenBaseExperiment):
666
+ @ftimed
667
+ def optimize(self):
668
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.5,)
669
+
670
+
671
+ @ExperimentRegistry.register(name="qwen_fbcache_055")
672
+ class Qwen_FBCache_055(QwenBaseExperiment):
673
+ @ftimed
674
+ def optimize(self):
675
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.55,)
676
+
677
+ @ExperimentRegistry.register(name="qwen_fbcache_054")
678
+ class Qwen_FBCache_054(QwenBaseExperiment):
679
+ @ftimed
680
+ def optimize(self):
681
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.54,)
682
+
683
+ @ExperimentRegistry.register(name="qwen_fbcache_053")
684
+ class Qwen_FBCache_053(QwenBaseExperiment):
685
+ @ftimed
686
+ def optimize(self):
687
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.53,)
688
+
689
+ @ExperimentRegistry.register(name="qwen_fbcache_052")
690
+ class Qwen_FBCache_052(QwenBaseExperiment):
691
+ @ftimed
692
+ def optimize(self):
693
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.52,)
694
+
695
+ @ExperimentRegistry.register(name="qwen_fbcache_051")
696
+ class Qwen_FBCache_051(QwenBaseExperiment):
697
+ @ftimed
698
+ def optimize(self):
699
+ apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.51,)
700
+
701
+
702
+ # @ExperimentRegistry.register(name="qwen_lightning_fa3_aot_autoquant_fuse")
703
+ class Qwen_lightning_FA3_AoT_autoquant_fuse(Qwen_Lightning_Lora):
704
+ """
705
+ Seemingly not working with AoT export
706
+ """
707
+ @ftimed
708
+ def optimize(self):
709
+ self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
710
+ self.pipe.transformer.fuse_qkv_projections()
711
+
712
+ pipe_kwargs={
713
+ "image": [Image.new("RGB", (1024, 1024))],
714
+ "prompt":"prompt",
715
+ "num_inference_steps":4
716
+ }
717
+ suffix="_autoquant_fa3_fuse"
718
+
719
+ cache_compiled=self.config.cache_compiled
720
+
721
+ transformer_pt2_cache_path = f"checkpoints/transformer_{suffix}_archive.pt2"
722
+ transformer_weights_cache_path = f"checkpoints/transformer_{suffix}_weights.pt"
723
+
724
+ print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
725
+ autoquant(self.pipe.transformer, error_on_unseen=False)
726
+ print_first_param(self.pipe.transformer)
727
+ print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
728
+
729
+ inductor_config = INDUCTOR_CONFIGS
730
+
731
+ if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
732
+ drain_module_parameters(self.pipe.transformer)
733
+ zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
734
+ compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
735
+ else:
736
+ with spaces.aoti_capture(self.pipe.transformer) as call:
737
+ self.pipe(**pipe_kwargs)
738
+
739
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
740
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
741
+
742
+ exported = torch.export.export(
743
+ mod=self.pipe.transformer,
744
+ args=call.args,
745
+ kwargs=call.kwargs,
746
+ dynamic_shapes=dynamic_shapes,
747
+ )
748
+
749
+ compiled_transformer = spaces.aoti_compile(exported, inductor_config)
750
+ with open(transformer_pt2_cache_path, "wb") as f:
751
+ f.write(compiled_transformer.archive_file.getvalue())
752
+ torch.save(compiled_transformer.weights, transformer_weights_cache_path)
753
+
754
+
755
+ aoti_apply(compiled_transformer, self.pipe.transformer)
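For context on how the new ExperimentRegistry.filter classmethod and the qwen_fbcache_* threshold sweep above are meant to be used together, here is a minimal sketch. Only filter(), get() and the registered names come from this commit; the import path is inferred from the file path, and how an experiment instance is actually constructed and driven (config, load(), optimize(), run(), report()) is outside this hunk, so it is only hinted at in comments.

    from qwenimage.experiments.experiments_qwen import ExperimentRegistry

    # Select the new FBCache sweep by name prefix (filter() is added in this commit).
    names = ExperimentRegistry.filter(startswith="qwen_fbcache")
    print(names)  # e.g. ['qwen_fbcache_05', 'qwen_fbcache_055', ..., 'qwen_fbcache_051']

    for name in names:
        exp_cls = ExperimentRegistry.get(name)  # the registered class, e.g. Qwen_FBCache_05
        print(name, exp_cls.__name__)
        # Instantiating and running the experiment depends on AbstractExperiment's
        # constructor and config, which this diff does not show.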
qwenimage/models/first_block_cache.py ADDED
@@ -0,0 +1,73 @@
1
+ import functools
2
+ import unittest
3
+
4
+ import torch
5
+
6
+ from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
7
+ from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
8
+
9
+ from para_attn.first_block_cache import utils
10
+
11
+
12
+ def apply_cache_on_transformer(
13
+ transformer: QwenImageTransformer2DModel,
14
+ ):
15
+ if getattr(transformer, "_is_cached", False):
16
+ return transformer
17
+
18
+ cached_transformer_blocks = torch.nn.ModuleList(
19
+ [
20
+ utils.CachedTransformerBlocks(
21
+ transformer.transformer_blocks,
22
+ transformer=transformer,
23
+ return_hidden_states_first=False,
24
+ )
25
+ ]
26
+ )
27
+
28
+ original_forward = transformer.forward
29
+
30
+ @functools.wraps(transformer.__class__.forward)
31
+ def new_forward(
32
+ self,
33
+ *args,
34
+ **kwargs,
35
+ ):
36
+ with unittest.mock.patch.object(
37
+ self,
38
+ "transformer_blocks",
39
+ cached_transformer_blocks,
40
+ ):
41
+ return original_forward(
42
+ *args,
43
+ **kwargs,
44
+ )
45
+
46
+ transformer.forward = new_forward.__get__(transformer)
47
+
48
+ transformer._is_cached = True
49
+
50
+ return transformer
51
+
52
+
53
+ def apply_cache_on_pipe(
54
+ pipe: QwenImageEditPlusPipeline,
55
+ *,
56
+ shallow_patch: bool = False,
57
+ **kwargs,
58
+ ):
59
+ if not getattr(pipe, "_is_cached", False):
60
+ original_call = pipe.__class__.__call__
61
+
62
+ @functools.wraps(original_call)
63
+ def new_call(self, *args, **kwargs_):
64
+ with utils.cache_context(utils.create_cache_context(**kwargs)):
65
+ return original_call(self, *args, **kwargs_)
66
+
67
+ pipe.__class__.__call__ = new_call
68
+ pipe.__class__._is_cached = True
69
+
70
+ if not shallow_patch:
71
+ apply_cache_on_transformer(pipe.transformer)
72
+
73
+ return pipe
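A standalone sketch of how this helper is used, mirroring the new qwen_fbcache_* experiments: the checkpoint name, threshold, and call arguments are taken from elsewhere in this commit, while the loading path is simplified (the experiments pass a custom transformer and device_map, omitted here).

    import torch
    from PIL import Image

    from qwenimage.models.first_block_cache import apply_cache_on_pipe
    from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline

    # Load the edit pipeline, then patch it so para_attn's first-block cache can skip
    # the remaining transformer blocks whenever the first block's residual changes
    # less than the threshold between steps.
    pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
    ).to("cuda")
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.5)  # same value as qwen_fbcache_05

    # The patched __call__ wraps generation in a cache context; these inputs mirror
    # the pipe_kwargs used in the experiments file.
    image = pipe(
        image=[Image.new("RGB", (1024, 1024))],
        prompt="prompt",
        num_inference_steps=4,
    ).images[0]
    image.save("fbcache_test.jpg")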
qwenimage/models/pipeline_qwenimage_edit_plus.py CHANGED
@@ -548,6 +548,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
548
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
549
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
550
  max_sequence_length: int = 512,
 
551
  ):
552
  r"""
553
  Function invoked when calling the pipeline for generation.
@@ -665,6 +666,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
665
  self._attention_kwargs = attention_kwargs
666
  self._current_timestep = None
667
  self._interrupt = False
 
668
 
669
  # 2. Define call parameters
670
  if prompt is not None and isinstance(prompt, str):
 
548
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
549
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
550
  max_sequence_length: int = 512,
551
+ channels_last_format: bool = False,
552
  ):
553
  r"""
554
  Function invoked when calling the pipeline for generation.
 
666
  self._attention_kwargs = attention_kwargs
667
  self._current_timestep = None
668
  self._interrupt = False
669
+ self.channels_last_format = channels_last_format
670
 
671
  # 2. Define call parameters
672
  if prompt is not None and isinstance(prompt, str):
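Note that channels_last_format is only stored on the pipeline here; nothing in this hunk consumes it yet. As the Qwen_Channels_Last docstring says, the format only applies to 4D NCHW tensors and the inputs have to be converted as well. A small illustration of that constraint, using a hypothetical Conv2d rather than anything from this pipeline:

    import torch

    conv = torch.nn.Conv2d(16, 16, 3, padding=1).to(memory_format=torch.channels_last)
    x = torch.randn(1, 16, 64, 64)  # hypothetical 4D NCHW input

    # Converting the module's weights is not enough: the input must be converted too,
    # otherwise the channels-last fast path is not taken.
    x = x.to(memory_format=torch.channels_last)
    y = conv(x)
    print(y.is_contiguous(memory_format=torch.channels_last))  # expected: True

    # 5D NCDHW tensors (3D convolutions, e.g. the VAE) would need torch.channels_last_3d,
    # and 1D token sequences (the transformer) have no channels-last layout at all.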
qwenimage/models/transformer_qwenimage.py CHANGED
@@ -15,6 +15,7 @@
15
  import functools
16
  import math
17
  from typing import Any, Dict, List, Optional, Tuple, Union
 
18
 
19
  import torch
20
  import torch.nn as nn
@@ -615,6 +616,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
615
 
616
  for index_block, block in enumerate(self.transformer_blocks):
617
  if torch.is_grad_enabled() and self.gradient_checkpointing:
 
618
  encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
619
  block,
620
  hidden_states,
 
15
  import functools
16
  import math
17
  from typing import Any, Dict, List, Optional, Tuple, Union
18
+ import warnings
19
 
20
  import torch
21
  import torch.nn as nn
 
616
 
617
  for index_block, block in enumerate(self.transformer_blocks):
618
  if torch.is_grad_enabled() and self.gradient_checkpointing:
619
+ warnings.warn("Gradient ckpt?")
620
  encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
621
  block,
622
  hidden_states,
requirements.txt CHANGED
@@ -15,4 +15,5 @@ pydantic
15
  pandas
16
  modal
17
 
 
18
  lpips
 
15
  pandas
16
  modal
17
 
18
+ para-attn
19
  lpips
scripts/lpips_compare.ipynb ADDED
The diff for this file is too large to render. See raw diff
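The notebook's diff is not rendered here, but it appears to take over the LPIPS comparison removed from scripts/visual_compare.ipynb further down. A condensed sketch of that comparison pattern, with the backbone and preprocessing taken from the removed cell and hypothetical image paths:

    import lpips
    import torch
    import torchvision.transforms as transforms
    from PIL import Image

    # LPIPS expects inputs scaled to [-1, 1]; 'alex' is the backbone used in the removed cell.
    loss_fn = lpips.LPIPS(net="alex")
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def lpips_score(path_a, path_b):
        a = Image.open(path_a).convert("RGB")
        b = Image.open(path_b).convert("RGB")
        if a.size != b.size:
            b = b.resize(a.size, Image.LANCZOS)
        with torch.no_grad():
            return loss_fn(to_tensor(a).unsqueeze(0), to_tensor(b).unsqueeze(0)).item()

    # Hypothetical output paths: compare a 4-step output against its 2-step counterpart.
    print(lpips_score("reports/qwen_base/000.jpg", "reports/qwen_base_2step/000.jpg"))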
 
scripts/plot_data.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
scripts/scratch.ipynb CHANGED
@@ -606,26 +606,7 @@
606
  "id": "e571d339",
607
  "metadata": {},
608
  "outputs": [],
609
- "source": [
610
- "mvae_params = sum(p.numel() for p in flux.blend_model.mvae.parameters())\n",
611
- "model_size_mb = sum(p.numel() * p.element_size() for p in flux.blend_model.mvae.parameters()) / (1024 ** 2)\n",
612
- "print(f\"mvae parameter count: {mvae_params:,}\")\n",
613
- "print(f\"mvae model size: {model_size_mb:.2f} MB\")\n",
614
- "\n",
615
- "transformer_params = sum(p.numel() for p in flux.transformer.parameters())\n",
616
- "model_size_mb = sum(p.numel() * p.element_size() for p in flux.transformer.parameters()) / (1024 ** 2)\n",
617
- "print(f\"flux.transformer parameter count: {transformer_params:,}\")\n",
618
- "print(f\"flux.transformer model size: {model_size_mb:.2f} MB\")\n",
619
- "\n",
620
- "vae_params = sum(p.numel() for p in flux.vae.parameters())\n",
621
- "model_size_mb = sum(p.numel() * p.element_size() for p in flux.vae.parameters()) / (1024 ** 2)\n",
622
- "print(f\"flux.vae parameter count: {vae_params:,}\")\n",
623
- "print(f\"flux.vae model size: {model_size_mb:.2f} MB\")\n",
624
- "\n",
625
- "print(f\"\\nParameter comparisons:\")\n",
626
- "print(f\"mvae vs transformer: {mvae_params / transformer_params * 100:.3f}%\")\n",
627
- "print(f\"mvae vs vae: {mvae_params / vae_params * 100:.3f}%\")\n"
628
- ]
629
  },
630
  {
631
  "cell_type": "code",
 
606
  "id": "e571d339",
607
  "metadata": {},
608
  "outputs": [],
609
+ "source": []
610
  },
611
  {
612
  "cell_type": "code",
scripts/visual_compare.ipynb CHANGED
@@ -231,119 +231,7 @@
231
  "id": "244dfe0f",
232
  "metadata": {},
233
  "outputs": [],
234
- "source": [
235
- "\n",
236
- "import lpips\n",
237
- "import torch\n",
238
- "from PIL import Image\n",
239
- "import torchvision.transforms as transforms\n",
240
- "\n",
241
- "# Initialize LPIPS model\n",
242
- "loss_fn = lpips.LPIPS(net='alex') # or 'vgg' or 'squeeze'\n",
243
- "if torch.cuda.is_available():\n",
244
- " loss_fn = loss_fn.cuda()\n",
245
- "\n",
246
- "# Transform to convert PIL images to tensors\n",
247
- "transform = transforms.Compose([\n",
248
- " transforms.ToTensor(),\n",
249
- " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
250
- "])\n",
251
- "\n",
252
- "def calculate_lpips_scores(base_paths, compare_paths):\n",
253
- " \"\"\"Calculate LPIPS scores between two sets of images.\"\"\"\n",
254
- " scores = []\n",
255
- " \n",
256
- " # Get the minimum number of images available\n",
257
- " num_images = min(len(base_paths), len(compare_paths))\n",
258
- " \n",
259
- " for idx in range(num_images):\n",
260
- " # Load images\n",
261
- " img1 = Image.open(base_paths[idx]).convert('RGB')\n",
262
- " img2 = Image.open(compare_paths[idx]).convert('RGB')\n",
263
- " \n",
264
- " # Resize if dimensions don't match\n",
265
- " if img1.size != img2.size:\n",
266
- " img2 = img2.resize(img1.size, Image.LANCZOS)\n",
267
- " \n",
268
- " # Transform to tensors\n",
269
- " img1_tensor = transform(img1).unsqueeze(0)\n",
270
- " img2_tensor = transform(img2).unsqueeze(0)\n",
271
- " \n",
272
- " if torch.cuda.is_available():\n",
273
- " img1_tensor = img1_tensor.cuda()\n",
274
- " img2_tensor = img2_tensor.cuda()\n",
275
- " \n",
276
- " # Calculate LPIPS\n",
277
- " with torch.no_grad():\n",
278
- " score = loss_fn(img1_tensor, img2_tensor)\n",
279
- " \n",
280
- " scores.append(score.item())\n",
281
- " \n",
282
- " return scores\n",
283
- "\n",
284
- "# Define experiment sets\n",
285
- "experiment_sets = {\n",
286
- " 'qwen_base': {\n",
287
- " '4step': 'qwen_base',\n",
288
- " '3step': 'qwen_base_3step',\n",
289
- " '2step': 'qwen_base_2step'\n",
290
- " },\n",
291
- " 'qwen_lightning_lora': {\n",
292
- " '4step': 'qwen_lightning_lora',\n",
293
- " '3step': 'qwen_lightning_lora_3step',\n",
294
- " '2step': 'qwen_lightning_lora_2step'\n",
295
- " }\n",
296
- "}\n",
297
- "\n",
298
- "# Calculate LPIPS scores for each set\n",
299
- "results = {}\n",
300
- "\n",
301
- "for set_name, experiments in experiment_sets.items():\n",
302
- " print(f\"\\nProcessing {set_name}...\")\n",
303
- " \n",
304
- " # Get image paths\n",
305
- " base_4step_paths = experiment_outputs[experiments['4step']]\n",
306
- " step_3_paths = experiment_outputs[experiments['3step']]\n",
307
- " step_2_paths = experiment_outputs[experiments['2step']]\n",
308
- " \n",
309
- " # Calculate LPIPS scores\n",
310
- " print(f\"Calculating LPIPS: 4-step vs 3-step...\")\n",
311
- " scores_4vs3 = calculate_lpips_scores(base_4step_paths, step_3_paths)\n",
312
- " \n",
313
- " print(f\"Calculating LPIPS: 4-step vs 2-step...\")\n",
314
- " scores_4vs2 = calculate_lpips_scores(base_4step_paths, step_2_paths)\n",
315
- " \n",
316
- " # Create results dataframe\n",
317
- " results_df = pd.DataFrame({\n",
318
- " 'comparison': ['4step_vs_3step', '4step_vs_2step'],\n",
319
- " 'mean_lpips': [\n",
320
- " np.mean(scores_4vs3),\n",
321
- " np.mean(scores_4vs2)\n",
322
- " ],\n",
323
- " 'std_lpips': [\n",
324
- " np.std(scores_4vs3),\n",
325
- " np.std(scores_4vs2)\n",
326
- " ],\n",
327
- " 'num_samples': [\n",
328
- " len(scores_4vs3),\n",
329
- " len(scores_4vs2)\n",
330
- " ]\n",
331
- " })\n",
332
- " \n",
333
- " # Save to CSV\n",
334
- " csv_path = report_dir / f\"lpips_scores_{set_name}.csv\"\n",
335
- " results_df.to_csv(csv_path, index=False)\n",
336
- " \n",
337
- " print(f\"\\nResults for {set_name}:\")\n",
338
- " print(results_df)\n",
339
- " print(f\"\\nSaved to: {csv_path}\")\n",
340
- " \n",
341
- " results[set_name] = results_df\n",
342
- "\n",
343
- "print(\"\\n\" + \"=\"*60)\n",
344
- "print(\"LPIPS Analysis Complete!\")\n",
345
- "print(\"=\"*60)\n"
346
- ]
347
  }
348
  ],
349
  "metadata": {
 
231
  "id": "244dfe0f",
232
  "metadata": {},
233
  "outputs": [],
234
+ "source": []
235
  }
236
  ],
237
  "metadata": {