Elea Zhong committed
Commit cb0b907 · 1 parent: 1ae3ff6

fp8 experiment

app.py CHANGED
@@ -17,8 +17,8 @@ from safetensors.torch import load_file
 from torchao.quantization import quantize_
 from torchao.quantization import Int8WeightOnlyConfig
 
-from qwenimage.debug import ftimed
-from qwenimage.experiments.experiments_qwen import Qwen_FA3_AoT_int8
+from qwenimage.debug import ctimed, ftimed
+from qwenimage.experiments.experiments_qwen import Qwen_FA3_AoT_fp8, Qwen_FA3_AoT_int8, QwenBaseExperiment
 from qwenimage.optimization import optimize_pipeline_
 from qwenimage.prompt import build_camera_prompt
 from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
@@ -30,6 +30,7 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 exp = Qwen_FA3_AoT_int8()
+# exp = Qwen_FA3_AoT_fp8()
 exp.load()
 exp.optimize()
 pipe = exp.pipe
@@ -55,30 +56,31 @@ def infer_camera_edit(
     prev_output = None,
     progress=gr.Progress(track_tqdm=True)
 ):
-    prompt = build_camera_prompt(rotate_deg, move_forward, vertical_tilt, wideangle)
-    print(f"Generated Prompt: {prompt}")
-
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    # Choose input image (prefer uploaded, else last output)
-    pil_images = []
-    if image is not None:
-        if isinstance(image, Image.Image):
-            pil_images.append(image.convert("RGB"))
-        elif hasattr(image, "name"):
-            pil_images.append(Image.open(image.name).convert("RGB"))
-    elif prev_output:
-        pil_images.append(prev_output.convert("RGB"))
-
-    if len(pil_images) == 0:
-        raise gr.Error("Please upload an image first.")
-
-    print(f"{len(pil_images)=}")
+    with ctimed("pre pipe"):
+        prompt = build_camera_prompt(rotate_deg, move_forward, vertical_tilt, wideangle)
+        print(f"Generated Prompt: {prompt}")
+
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+        generator = torch.Generator(device=device).manual_seed(seed)
+
+        # Choose input image (prefer uploaded, else last output)
+        pil_images = []
+        if image is not None:
+            if isinstance(image, Image.Image):
+                pil_images.append(image.convert("RGB"))
+            elif hasattr(image, "name"):
+                pil_images.append(Image.open(image.name).convert("RGB"))
+        elif prev_output:
+            pil_images.append(prev_output.convert("RGB"))
+
+        if len(pil_images) == 0:
+            raise gr.Error("Please upload an image first.")
+
+        print(f"{len(pil_images)=}")
 
-    if prompt == "no camera movement":
-        return image, seed, prompt
+        if prompt == "no camera movement":
+            return image, seed, prompt
     result = pipe(
         image=pil_images,
         prompt=prompt,
@@ -154,7 +156,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
         true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
-        num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)
+        num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=3)
         height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
         width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
 
@@ -202,6 +204,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
 
 
     # Live updates
+    @ftimed
    def maybe_infer(is_reset, progress=gr.Progress(track_tqdm=True), *args):
        if is_reset:
            return gr.update(), gr.update(), gr.update(), gr.update()
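
Note: app.py now wraps the pre-pipeline work in `with ctimed("pre pipe"):` and decorates `maybe_infer` with `@ftimed`, but the timing helpers themselves live in `qwenimage.debug` and are not shown in this commit. A minimal sketch of what such a context manager and decorator typically look like; this is an assumption for illustration, not the repo's actual implementation:

# Hypothetical sketch of qwenimage.debug timing helpers (assumed, not from this commit).
import functools
import time
from contextlib import contextmanager

@contextmanager
def ctimed(label: str):
    """Print how long the wrapped block took."""
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"[{label}] {time.perf_counter() - start:.3f}s")

def ftimed(fn):
    """Print how long each call to `fn` took."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            print(f"[{fn.__qualname__}] {time.perf_counter() - start:.3f}s")
    return wrapper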
qwenimage/experiments/experiments_qwen.py CHANGED
@@ -10,7 +10,7 @@ import torch
 from PIL import Image
 import pandas as pd
 from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
-from torchao.quantization import Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
 from torchao.quantization import Int8WeightOnlyConfig
 import spaces
 import torch
@@ -238,7 +238,7 @@ class Qwen_FA3_AoT_int8(QwenBaseExperiment):
         )
 
 
-@ExperimentRegistry.register(name="qwen_fp8")
+# @ExperimentRegistry.register(name="qwen_fp8")
 class Qwen_fp8(QwenBaseExperiment):
     @ftimed
     def optimize(self):
@@ -247,7 +247,7 @@ class Qwen_fp8(QwenBaseExperiment):
         quantize_(self.pipe.transformer, Float8WeightOnlyConfig())
 
 
-@ExperimentRegistry.register(name="qwen_int8")
+# @ExperimentRegistry.register(name="qwen_int8")
 class Qwen_int8(QwenBaseExperiment):
     @ftimed
     def optimize(self):
@@ -255,3 +255,58 @@ class Qwen_int8(QwenBaseExperiment):
         self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
         quantize_(self.pipe.transformer, Int8WeightOnlyConfig())
 
+
+
+
+@ExperimentRegistry.register(name="qwen_fa3_aot_fp8")
+class Qwen_FA3_AoT_fp8(QwenBaseExperiment):
+    @ftimed
+    @spaces.GPU()
+    def optimize(self):
+        self.pipe.transformer.__class__ = QwenImageTransformer2DModel
+        self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+        pipe_kwargs={
+            "image": [Image.new("RGB", (1024, 1024))],
+            "prompt":"prompt",
+            "num_inference_steps":4
+        }
+        suffix="_fa3"
+
+        cache_compiled=self.config.cache_compiled
+
+        transformer_pt2_cache_path = f"checkpoints/transformer_fp8{suffix}_archive.pt2"
+        transformer_weights_cache_path = f"checkpoints/transformer_fp8{suffix}_weights.pt"
+
+        print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
+        quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+        print_first_param(self.pipe.transformer)
+        print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
+
+        inductor_config = INDUCTOR_CONFIGS
+
+        if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
+            drain_module_parameters(self.pipe.transformer)
+            zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
+            compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
+        else:
+            with spaces.aoti_capture(self.pipe.transformer) as call:
+                self.pipe(**pipe_kwargs)
+
+            dynamic_shapes = tree_map(lambda t: None, call.kwargs)
+            dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+            exported = torch.export.export(
+                mod=self.pipe.transformer,
+                args=call.args,
+                kwargs=call.kwargs,
+                dynamic_shapes=dynamic_shapes,
+            )
+
+            compiled_transformer = spaces.aoti_compile(exported, inductor_config)
+            with open(transformer_pt2_cache_path, "wb") as f:
+                f.write(compiled_transformer.archive_file.getvalue())
+            torch.save(compiled_transformer.weights, transformer_weights_cache_path)
+
+
+        aoti_apply(compiled_transformer, self.pipe.transformer)
+
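
Note: the new `Qwen_FA3_AoT_fp8` experiment swaps the int8 weight-only config for torchao's `Float8DynamicActivationFloat8WeightConfig`, which rewrites the transformer's Linear weights to fp8 in place and quantizes activations dynamically at matmul time, before handing the module to AoT export. A minimal sketch of that quantization step in isolation (the toy model and shapes are illustrative only; fp8 kernels generally need a recent GPU such as Hopper or Ada):

# Minimal sketch of fp8 dynamic-activation / fp8-weight quantization with torchao,
# isolated from the experiment class. The toy model below is an assumption for
# illustration, not the pipeline's actual transformer.
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).to(torch.bfloat16).cuda()

# quantize_ mutates the module in place: Linear weights become fp8 tensors,
# and activations are quantized dynamically right before each matmul.
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)

The surrounding caching logic in the diff (`checkpoints/transformer_fp8_fa3_archive.pt2` / `_weights.pt`) only reuses a compiled archive when `config.cache_compiled` is set and the file exists, so the capture-export-compile path runs on the first launch.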
qwenimage/models/pipeline_qwenimage_edit_plus.py CHANGED
@@ -521,6 +521,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
+    @ftimed
     def __call__(
         self,
         image: Optional[PipelineImageInput] = None,
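
Note: `@ftimed` sits closest to `def __call__`, so it wraps the raw method and is itself wrapped by `@torch.no_grad()` and the docstring decorator, since decorators apply bottom-up. A toy illustration of that ordering (the names below are made up for the example):

# Toy illustration of decorator ordering; `outer`/`inner` are hypothetical.
def outer(fn):
    def wrapped(*args, **kwargs):
        print("outer before")
        result = fn(*args, **kwargs)
        print("outer after")
        return result
    return wrapped

def inner(fn):
    def wrapped(*args, **kwargs):
        print("inner (closest to def) runs innermost")
        return fn(*args, **kwargs)
    return wrapped

@outer
@inner
def call():
    print("body")

call()
# prints: outer before / inner (closest to def) runs innermost / body / outer after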
qwenimage/optimization.py CHANGED
@@ -117,3 +117,4 @@ def optimize_pipeline_(
 
 
     aoti_apply(compiled_transformer, pipeline.transformer)
+
scripts/plot_data.ipynb CHANGED
The diff for this file is too large to render; see the raw diff.