Spaces: Running on Zero

Commit cb0b907 · Elea Zhong committed
Parent(s): 1ae3ff6

fp8 experiment

Files changed:
- app.py +29 -26
- qwenimage/experiments/experiments_qwen.py +58 -3
- qwenimage/models/pipeline_qwenimage_edit_plus.py +1 -0
- qwenimage/optimization.py +1 -0
- scripts/plot_data.ipynb +0 -0
app.py CHANGED

@@ -17,8 +17,8 @@ from safetensors.torch import load_file
 from torchao.quantization import quantize_
 from torchao.quantization import Int8WeightOnlyConfig
 
-from qwenimage.debug import ftimed
-from qwenimage.experiments.experiments_qwen import Qwen_FA3_AoT_int8
+from qwenimage.debug import ctimed, ftimed
+from qwenimage.experiments.experiments_qwen import Qwen_FA3_AoT_fp8, Qwen_FA3_AoT_int8, QwenBaseExperiment
 from qwenimage.optimization import optimize_pipeline_
 from qwenimage.prompt import build_camera_prompt
 from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline

@@ -30,6 +30,7 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 exp = Qwen_FA3_AoT_int8()
+# exp = Qwen_FA3_AoT_fp8()
 exp.load()
 exp.optimize()
 pipe = exp.pipe

@@ -55,30 +56,31 @@ def infer_camera_edit(
     prev_output = None,
     progress=gr.Progress(track_tqdm=True)
 ):
-    # (the removed lines of the previous function body are not rendered in this diff view)
+    with ctimed("pre pipe"):
+        prompt = build_camera_prompt(rotate_deg, move_forward, vertical_tilt, wideangle)
+        print(f"Generated Prompt: {prompt}")
+
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+        generator = torch.Generator(device=device).manual_seed(seed)
+
+        # Choose input image (prefer uploaded, else last output)
+        pil_images = []
+        if image is not None:
+            if isinstance(image, Image.Image):
+                pil_images.append(image.convert("RGB"))
+            elif hasattr(image, "name"):
+                pil_images.append(Image.open(image.name).convert("RGB"))
+        elif prev_output:
+            pil_images.append(prev_output.convert("RGB"))
+
+        if len(pil_images) == 0:
+            raise gr.Error("Please upload an image first.")
+
+        print(f"{len(pil_images)=}")
 
+    if prompt == "no camera movement":
+        return image, seed, prompt
     result = pipe(
         image=pil_images,
         prompt=prompt,

@@ -154,7 +156,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
     true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
-    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=  # old value truncated in this diff view
+    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=3)
     height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
     width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)

@@ -202,6 +204,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
 
 
     # Live updates
+    @ftimed
     def maybe_infer(is_reset, progress=gr.Progress(track_tqdm=True), *args):
         if is_reset:
             return gr.update(), gr.update(), gr.update(), gr.update()
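The new timing hooks in app.py (the ctimed context manager around the pre-pipeline setup and the @ftimed decorator on maybe_infer) come from qwenimage.debug, which this commit does not touch. A minimal sketch of what such helpers might look like, assuming they simply print wall-clock durations (the names are reused from the diff, the bodies are hypothetical):

import time
from contextlib import contextmanager
from functools import wraps

@contextmanager
def ctimed(label):
    # Hypothetical stand-in for qwenimage.debug.ctimed: time an indented block.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"[{label}] {time.perf_counter() - start:.3f}s")

def ftimed(fn):
    # Hypothetical stand-in for qwenimage.debug.ftimed: time each call to fn.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            print(f"[{fn.__name__}] {time.perf_counter() - start:.3f}s")
    return wrapper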
qwenimage/experiments/experiments_qwen.py CHANGED

@@ -10,7 +10,7 @@ import torch
 from PIL import Image
 import pandas as pd
 from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
-from torchao.quantization import Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, quantize_
 from torchao.quantization import Int8WeightOnlyConfig
 import spaces
 import torch

@@ -238,7 +238,7 @@ class Qwen_FA3_AoT_int8(QwenBaseExperiment):
     )
 
 
-@ExperimentRegistry.register(name="qwen_fp8")
+# @ExperimentRegistry.register(name="qwen_fp8")
 class Qwen_fp8(QwenBaseExperiment):
     @ftimed
     def optimize(self):

@@ -247,7 +247,7 @@ class Qwen_fp8(QwenBaseExperiment):
         quantize_(self.pipe.transformer, Float8WeightOnlyConfig())
 
 
-@ExperimentRegistry.register(name="qwen_int8")
+# @ExperimentRegistry.register(name="qwen_int8")
 class Qwen_int8(QwenBaseExperiment):
     @ftimed
     def optimize(self):

@@ -255,3 +255,58 @@ class Qwen_int8(QwenBaseExperiment):
         self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
         quantize_(self.pipe.transformer, Int8WeightOnlyConfig())
 
+
+
+
+@ExperimentRegistry.register(name="qwen_fa3_aot_fp8")
+class Qwen_FA3_AoT_fp8(QwenBaseExperiment):
+    @ftimed
+    @spaces.GPU()
+    def optimize(self):
+        self.pipe.transformer.__class__ = QwenImageTransformer2DModel
+        self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+        pipe_kwargs={
+            "image": [Image.new("RGB", (1024, 1024))],
+            "prompt":"prompt",
+            "num_inference_steps":4
+        }
+        suffix="_fa3"
+
+        cache_compiled=self.config.cache_compiled
+
+        transformer_pt2_cache_path = f"checkpoints/transformer_fp8{suffix}_archive.pt2"
+        transformer_weights_cache_path = f"checkpoints/transformer_fp8{suffix}_weights.pt"
+
+        print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
+        quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+        print_first_param(self.pipe.transformer)
+        print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
+
+        inductor_config = INDUCTOR_CONFIGS
+
+        if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
+            drain_module_parameters(self.pipe.transformer)
+            zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
+            compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
+        else:
+            with spaces.aoti_capture(self.pipe.transformer) as call:
+                self.pipe(**pipe_kwargs)
+
+            dynamic_shapes = tree_map(lambda t: None, call.kwargs)
+            dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+            exported = torch.export.export(
+                mod=self.pipe.transformer,
+                args=call.args,
+                kwargs=call.kwargs,
+                dynamic_shapes=dynamic_shapes,
+            )
+
+            compiled_transformer = spaces.aoti_compile(exported, inductor_config)
+            with open(transformer_pt2_cache_path, "wb") as f:
+                f.write(compiled_transformer.archive_file.getvalue())
+            torch.save(compiled_transformer.weights, transformer_weights_cache_path)
+
+
+        aoti_apply(compiled_transformer, self.pipe.transformer)
+
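The core of the new Qwen_FA3_AoT_fp8 experiment is swapping weight-only int8 for torchao's dynamic fp8 activation-and-weight quantization before the ahead-of-time export. A standalone sketch of just that quantization step on a toy module (the toy module is an assumption, and get_model_size_in_bytes is imported here from torchao.utils, which may differ from the repo's own helper); it needs a GPU with fp8 support, e.g. an H100, to actually run:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.utils import get_model_size_in_bytes

# Toy stand-in for the pipeline's transformer.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).to(device="cuda", dtype=torch.bfloat16)

print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024:.1f} MB")
# Weights are stored in fp8; activations are quantized to fp8 dynamically at runtime.
quantize_(model, Float8DynamicActivationFloat8WeightConfig())
print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024:.1f} MB")

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)

The existing Qwen_FA3_AoT_int8 path differs mainly in the config passed to quantize_ (Int8WeightOnlyConfig), which quantizes weights only and leaves activations in bf16.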
qwenimage/models/pipeline_qwenimage_edit_plus.py CHANGED

@@ -521,6 +521,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
+    @ftimed
     def __call__(
         self,
         image: Optional[PipelineImageInput] = None,
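The only change here is the @ftimed decorator on QwenImageEditPlusPipeline.__call__, so every end-to-end pipeline call is timed; since decorators apply bottom-up, @ftimed wraps the raw __call__ while @torch.no_grad() stays outermost. The same measurement could also be obtained without editing the pipeline source by patching the class at runtime, for example (a sketch, not part of this commit):

from qwenimage.debug import ftimed
from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline

# Wrap the class-level __call__ so every pipeline instance gets timed.
QwenImageEditPlusPipeline.__call__ = ftimed(QwenImageEditPlusPipeline.__call__)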
qwenimage/optimization.py CHANGED

@@ -117,3 +117,4 @@ def optimize_pipeline_(
 
 
     aoti_apply(compiled_transformer, pipeline.transformer)
+
scripts/plot_data.ipynb CHANGED

The diff for this file is too large to render; see the raw diff.