cbensimon HF Staff committed on
Commit
c628e2d
·
verified ·
1 Parent(s): 38cc643

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +25 -10
optimization.py CHANGED
@@ -28,25 +28,40 @@ INDUCTOR_CONFIGS = {
28
 
29
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Ahead-of-time compile the pipeline's transformer blocks in place.

    Captures one example invocation of the first transformer block, exports
    and AOT-compiles it on a ZeroGPU worker, then patches every block's
    ``forward`` with the compiled model (each block keeps its own weights
    via ``state_dict()``). Modifies ``pipeline`` in place; returns nothing.
    """

    transformer_blocks = pipeline.transformer.transformer_blocks
    first_block = transformer_blocks[0]

    @spaces.GPU(duration=120)
    def _compile_first_block():
        # Record one real call to the first block by running the pipeline once.
        with spaces.aoti_capture(first_block) as call:
            pipeline(*args, **kwargs)

        # Export the block with the captured example inputs, then AOT-compile.
        exported = torch.export.export(mod=first_block, args=call.args, kwargs=call.kwargs)
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS).archive_file

    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    # All blocks share the one compiled artifact; weights stay per-block.
    archive = _compile_first_block()
    for block in transformer_blocks:
        compiled = ZeroGPUCompiledModel(archive, ZeroGPUWeights(block.state_dict()))
        block.forward = compiled
28
 
29
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Ahead-of-time compile both transformer block types in place.

    Captures one example call for the first block of each block list
    (``transformer_blocks`` and ``single_transformer_blocks``), exports and
    AOT-compiles each on a ZeroGPU worker, then patches every block's
    ``forward`` with the matching compiled model (weights stay per-block
    via ``state_dict()``). Modifies ``pipeline`` in place; returns nothing.
    """

    blocks_A = pipeline.transformer.transformer_blocks
    blocks_B = pipeline.transformer.single_transformer_blocks

    @spaces.GPU(duration=1500)
    def compile_transformer_block_AB():

        # Run the pipeline once per capture so each block type records its
        # own example args/kwargs.
        with spaces.aoti_capture(blocks_A[0]) as call_A:
            pipeline(*args, **kwargs)

        with spaces.aoti_capture(blocks_B[0]) as call_B:
            pipeline(*args, **kwargs)

        exported_A = torch.export.export(
            mod=blocks_A[0],
            # BUG FIX: was `call.args` / `call.kwargs` — no `call` exists in
            # this scope (captures are `call_A` / `call_B`), so this raised
            # NameError. Use the capture that belongs to blocks_A.
            args=call_A.args,
            kwargs=call_A.kwargs,
        )

        exported_B = torch.export.export(
            mod=blocks_B[0],
            # BUG FIX: likewise, blocks_B must be exported with its own
            # captured inputs, not blocks_A's.
            args=call_B.args,
            kwargs=call_B.kwargs,
        )

        return (
            spaces.aoti_compile(exported_A, INDUCTOR_CONFIGS).archive_file,
            spaces.aoti_compile(exported_B, INDUCTOR_CONFIGS).archive_file,
        )

    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    archive_file_A, archive_file_B = compile_transformer_block_AB()
    for blocks, archive_file in zip((blocks_A, blocks_B), (archive_file_A, archive_file_B)):
        for block in blocks:
            weights = ZeroGPUWeights(block.state_dict())
            block.forward = ZeroGPUCompiledModel(archive_file, weights)