cbensimon (HF Staff) committed
Commit 38cc643 · verified · 1 parent: 780a177

Update optimization.py

Files changed (1): optimization.py (+6 -4)
optimization.py CHANGED
@@ -28,14 +28,16 @@ INDUCTOR_CONFIGS = {
 
 def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
+    blocks = pipeline.transformer.transformer_blocks
+
     @spaces.GPU(duration=120)
     def compile_transformer_block():
-
-        with spaces.aoti_capture(pipeline.transformer.transformer_blocks[0]) as call:
+
+        with spaces.aoti_capture(blocks[0]) as call:
             pipeline(*args, **kwargs)
 
         exported = torch.export.export(
-            mod=pipeline.transformer,
+            mod=blocks[0],
             args=call.args,
             kwargs=call.kwargs,
         )
@@ -46,5 +48,5 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
 
     archive_file = compile_transformer_block()
-    for block in pipeline.transformer.transformer_blocks:
+    for block in blocks:
         block.forward = ZeroGPUCompiledModel(archive_file, ZeroGPUWeights(block.state_dict()))
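
A minimal usage sketch of how optimize_pipeline_ might be invoked, assuming a diffusers FluxPipeline (suggested by FlashFusedFluxAttnProcessor3_0 in the diff); the checkpoint name, prompt, and step count below are illustrative assumptions, not part of this commit:

import torch
from diffusers import FluxPipeline

# Load the pipeline to optimize (hypothetical checkpoint choice).
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# optimize_pipeline_ replays this exact call under spaces.aoti_capture to
# record the inputs of the first transformer block, exports that block with
# torch.export.export, and installs the compiled artifact on every block.
optimize_pipeline_(pipeline, prompt="a photo of an astronaut", num_inference_steps=4)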