Spaces: Running on Zero
Elea Zhong committed
Commit 79e79fb · 1 Parent(s): f687863
nit space
qwenimage/optimization.py +15 -17
qwenimage/optimization.py CHANGED
@@ -68,6 +68,7 @@ def drain_module_parameters(module: torch.nn.Module):
 
 
 @ftimed
+@spaces.GPU(duration=1500)
 def optimize_pipeline_(
     pipeline: Callable[P, Any],
     cache_compiled=True,
@@ -96,23 +97,20 @@ def optimize_pipeline_(
         zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
         compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
     else:
-        @spaces.GPU(duration=1500)
-        def compile_transformer():
-            with spaces.aoti_capture(pipeline.transformer) as call:
-                pipeline(**pipe_kwargs)
-
-            dynamic_shapes = tree_map(lambda t: None, call.kwargs)
-            dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
-
-            exported = torch.export.export(
-                mod=pipeline.transformer,
-                args=call.args,
-                kwargs=call.kwargs,
-                dynamic_shapes=dynamic_shapes,
-            )
-
-            return spaces.aoti_compile(exported, inductor_config)
-        compiled_transformer = compile_transformer()
+        with spaces.aoti_capture(pipeline.transformer) as call:
+            pipeline(**pipe_kwargs)
+
+        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
+        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+        exported = torch.export.export(
+            mod=pipeline.transformer,
+            args=call.args,
+            kwargs=call.kwargs,
+            dynamic_shapes=dynamic_shapes,
+        )
+
+        compiled_transformer = spaces.aoti_compile(exported, inductor_config)
         with open(transformer_pt2_cache_path, "wb") as f:
            f.write(compiled_transformer.archive_file.getvalue())
         torch.save(compiled_transformer.weights, transformer_weights_cache_path)
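For reference, a minimal sketch of how the newly decorated entry point might be wired up at Space startup. Only optimize_pipeline_, its cache_compiled flag, and the @spaces.GPU(duration=1500) decorator come from this file; the checkpoint id, pipeline class, and example keyword arguments below are assumptions for illustration, and the sketch assumes extra keyword arguments are forwarded to the pipeline as pipe_kwargs.

# Hypothetical app.py wiring (sketch, not part of this commit).
import torch
from diffusers import DiffusionPipeline

from qwenimage.optimization import optimize_pipeline_

# Assumed checkpoint; any pipeline exposing a .transformer module would do.
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=torch.bfloat16,
).to("cuda")

# One warm-up call at startup: with @spaces.GPU(duration=1500) now on
# optimize_pipeline_ itself, the aoti_capture / torch.export / aoti_compile
# path (or the cached .pt2 load) runs inside a single ZeroGPU allocation.
optimize_pipeline_(
    pipe,
    cache_compiled=True,
    prompt="a cup of coffee on a desk",  # example kwargs captured as pipe_kwargs
)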