cbensimon (HF Staff) committed
Commit 780a177 · verified · 1 parent: d6785d3

Regional AoT
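In short: instead of AoT-compiling the whole transformer, this commit compiles a single transformer block and reuses the resulting artifact for every block, pairing it with each block's own weights. Because the blocks in `transformer_blocks` share one architecture, one exported graph serves them all, and the GPU time reserved for compilation drops accordingly (`@spaces.GPU(duration=1500)` → `@spaces.GPU(duration=120)`). A minimal sketch of the reuse pattern, with a toy `Block` standing in for the real transformer block (hypothetical, not from this repo):

```python
import torch


class Block(torch.nn.Module):
    """Stand-in for one of N architecturally identical transformer blocks."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


blocks = torch.nn.ModuleList(Block() for _ in range(4))

# Export the first block once: the traced graph depends only on the block's
# architecture, so it is valid for any block with the same structure.
exported = torch.export.export(blocks[0], (torch.randn(1, 8),))

# Regional AoT then compiles this one graph and runs it once per block,
# swapping in each block's own weights; on ZeroGPU that pairing is done by
# ZeroGPUCompiledModel + ZeroGPUWeights, as in the diff below.
```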

Files changed (1): optimization.py (+11 −5)
optimization.py CHANGED
```diff
@@ -4,8 +4,11 @@
 from typing import Any
 from typing import Callable
 from typing import ParamSpec
+
 import spaces
 import torch
+from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+from spaces.zero.torch.aoti import ZeroGPUWeights
 
 from fa3 import FlashFusedFluxAttnProcessor3_0
 
@@ -25,10 +28,10 @@ INDUCTOR_CONFIGS = {
 
 def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
-    @spaces.GPU(duration=1500)
-    def compile_transformer():
+    @spaces.GPU(duration=120)
+    def compile_transformer_block():
 
-        with spaces.aoti_capture(pipeline.transformer) as call:
+        with spaces.aoti_capture(pipeline.transformer.transformer_blocks[0]) as call:
             pipeline(*args, **kwargs)
 
         exported = torch.export.export(
@@ -37,8 +40,11 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
             kwargs=call.kwargs,
         )
 
-        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
+        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS).archive_file
 
     pipeline.transformer.fuse_qkv_projections()
     pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
-    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
+
+    archive_file = compile_transformer_block()
+    for block in pipeline.transformer.transformer_blocks:
+        block.forward = ZeroGPUCompiledModel(archive_file, ZeroGPUWeights(block.state_dict()))
```
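For context, a hedged usage sketch: `optimize_pipeline_` is meant to be called once at startup with representative arguments, so that `spaces.aoti_capture` can record a real call into the first transformer block before exporting it. The pipeline class, checkpoint, and prompt below are assumptions for illustration, not part of this commit:

```python
import torch
from diffusers import FluxPipeline  # assumed pipeline class

from optimization import optimize_pipeline_

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# One representative call: aoti_capture intercepts the inputs to
# transformer_blocks[0], torch.export traces the block, and the compiled
# archive is then attached to every block paired with that block's weights.
optimize_pipeline_(pipeline, "a cat wearing a spacesuit", num_inference_steps=4)
```

Note that `compile_transformer_block` now returns only the `.archive_file` of the compiled model rather than the compiled model itself: the archive is presumably the block-agnostic part, and each `ZeroGPUCompiledModel` instance rebinds it to one block's `ZeroGPUWeights`.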