Elea Zhong committed on
Commit
79e79fb
·
1 Parent(s): f687863
Files changed (1) hide show
  1. qwenimage/optimization.py +15 -17
qwenimage/optimization.py CHANGED
@@ -68,6 +68,7 @@ def drain_module_parameters(module: torch.nn.Module):
68
 
69
 
70
  @ftimed
 
71
  def optimize_pipeline_(
72
  pipeline: Callable[P, Any],
73
  cache_compiled=True,
@@ -96,23 +97,20 @@ def optimize_pipeline_(
96
  zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
97
  compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
98
  else:
99
- @spaces.GPU(duration=1500)
100
- def compile_transformer():
101
- with spaces.aoti_capture(pipeline.transformer) as call:
102
- pipeline(**pipe_kwargs)
103
-
104
- dynamic_shapes = tree_map(lambda t: None, call.kwargs)
105
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
106
-
107
- exported = torch.export.export(
108
- mod=pipeline.transformer,
109
- args=call.args,
110
- kwargs=call.kwargs,
111
- dynamic_shapes=dynamic_shapes,
112
- )
113
-
114
- return spaces.aoti_compile(exported, inductor_config)
115
- compiled_transformer = compile_transformer()
116
  with open(transformer_pt2_cache_path, "wb") as f:
117
  f.write(compiled_transformer.archive_file.getvalue())
118
  torch.save(compiled_transformer.weights, transformer_weights_cache_path)
 
68
 
69
 
70
  @ftimed
71
+ @spaces.GPU(duration=1500)
72
  def optimize_pipeline_(
73
  pipeline: Callable[P, Any],
74
  cache_compiled=True,
 
97
  zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
98
  compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
99
  else:
100
+ with spaces.aoti_capture(pipeline.transformer) as call:
101
+ pipeline(**pipe_kwargs)
102
+
103
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
104
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
105
+
106
+ exported = torch.export.export(
107
+ mod=pipeline.transformer,
108
+ args=call.args,
109
+ kwargs=call.kwargs,
110
+ dynamic_shapes=dynamic_shapes,
111
+ )
112
+
113
+ compiled_transformer = spaces.aoti_compile(exported, inductor_config)
 
 
 
114
  with open(transformer_pt2_cache_path, "wb") as f:
115
  f.write(compiled_transformer.archive_file.getvalue())
116
  torch.save(compiled_transformer.weights, transformer_weights_cache_path)