# NOTE(review): the three lines below were header residue from the Spaces
# error page ("Spaces: / Runtime error / Runtime error") captured during
# extraction — not part of the module. Kept as a comment for provenance.
| """ | |
| """ | |
| from typing import Any | |
| from typing import Callable | |
| from typing import ParamSpec | |
| import spaces | |
| import torch | |
| from spaces.zero.torch.aoti import ZeroGPUCompiledModel | |
| from spaces.zero.torch.aoti import ZeroGPUWeights | |
| from fa3 import FlashFusedFluxAttnProcessor3_0 | |
| P = ParamSpec('P') | |
| INDUCTOR_CONFIGS = { | |
| 'conv_1x1_as_mm': True, | |
| 'epilogue_fusion': False, | |
| 'coordinate_descent_tuning': True, | |
| 'coordinate_descent_check_all_directions': True, | |
| 'max_autotune': True, | |
| 'triton.cudagraphs': True, | |
| } | |
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """AoT-compile the pipeline's transformer blocks in place.

    Fuses the QKV projections, installs the FlashAttention-3 processor, exports
    and compiles the first block of each group (`transformer_blocks` and
    `single_transformer_blocks` — one compiled artifact is reused for every
    block in its group), then replaces each block's ``forward`` with the
    compiled model bound to that block's own weights.

    Args:
        pipeline: Flux-style pipeline whose ``transformer`` exposes
            ``transformer_blocks`` and ``single_transformer_blocks``.
        *args: Positional arguments for one representative pipeline call,
            used to capture example inputs for ``torch.export``.
        **kwargs: Keyword arguments for the same representative call.
    """
    blocks_A = pipeline.transformer.transformer_blocks
    blocks_B = pipeline.transformer.single_transformer_blocks

    def compile_transformer_block_AB():
        # Run the pipeline once per group to capture the concrete
        # args/kwargs the first block of that group receives.
        with spaces.aoti_capture(blocks_A[0]) as call_A:
            pipeline(*args, **kwargs)
        with spaces.aoti_capture(blocks_B[0]) as call_B:
            pipeline(*args, **kwargs)
        # BUG FIX: the original referenced an undefined name `call`
        # (NameError at runtime); each export must use its own capture.
        exported_A = torch.export.export(
            mod=blocks_A[0],
            args=call_A.args,
            kwargs=call_A.kwargs,
        )
        exported_B = torch.export.export(
            mod=blocks_B[0],
            args=call_B.args,
            kwargs=call_B.kwargs,
        )
        return (
            spaces.aoti_compile(exported_A, INDUCTOR_CONFIGS).archive_file,
            spaces.aoti_compile(exported_B, INDUCTOR_CONFIGS).archive_file,
        )

    # Reconfigure attention BEFORE capturing/exporting so the compiled graph
    # matches what will run at inference time.
    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
    archive_file_A, archive_file_B = compile_transformer_block_AB()
    for blocks, archive_file in zip((blocks_A, blocks_B), (archive_file_A, archive_file_B)):
        for block in blocks:
            # Each block shares the compiled artifact but keeps its own weights.
            weights = ZeroGPUWeights(block.state_dict())
            block.forward = ZeroGPUCompiledModel(archive_file, weights)