""" """ from typing import Any from typing import Callable from typing import ParamSpec import spaces import torch from spaces.zero.torch.aoti import ZeroGPUCompiledModel from spaces.zero.torch.aoti import ZeroGPUWeights from fa3 import FlashFusedFluxAttnProcessor3_0 P = ParamSpec('P') INDUCTOR_CONFIGS = { 'conv_1x1_as_mm': True, 'epilogue_fusion': False, 'coordinate_descent_tuning': True, 'coordinate_descent_check_all_directions': True, 'max_autotune': True, 'triton.cudagraphs': True, } def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs): blocks_A = pipeline.transformer.transformer_blocks blocks_B = pipeline.transformer.single_transformer_blocks @spaces.GPU(duration=1500) def compile_transformer_block_AB(): with spaces.aoti_capture(blocks_A[0]) as call_A: pipeline(*args, **kwargs) with spaces.aoti_capture(blocks_B[0]) as call_B: pipeline(*args, **kwargs) exported_A = torch.export.export( mod=blocks_A[0], args=call.args, kwargs=call.kwargs, ) exported_B = torch.export.export( mod=blocks_B[0], args=call.args, kwargs=call.kwargs, ) return ( spaces.aoti_compile(exported_A, INDUCTOR_CONFIGS).archive_file, spaces.aoti_compile(exported_B, INDUCTOR_CONFIGS).archive_file, ) pipeline.transformer.fuse_qkv_projections() pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0()) archive_file_A, archive_file_B = compile_transformer_block_AB() for blocks, archive_file in zip((blocks_A, blocks_B), (archive_file_A, archive_file_B)): for block in blocks: weights = ZeroGPUWeights(block.state_dict()) block.forward = ZeroGPUCompiledModel(archive_file, weights)