"""
Ahead-of-time compilation of a pipeline's transformer blocks for ZeroGPU.
"""
| |
|
| | from typing import Any |
| | from typing import Callable |
| | from typing import ParamSpec |
| | import spaces |
| | import torch |
| | from spaces.zero.torch.aoti import ZeroGPUCompiledModel |
| | from spaces.zero.torch.aoti import ZeroGPUWeights |
| | from torch.utils._pytree import tree_map |
| |
|
| | P = ParamSpec('P') |
| |
|
| | TRANSFORMER_IMAGE_DIM = torch.export.Dim('image_seq_length', min=4096, max=16384) |
| |
|
| | TRANSFORMER_DYNAMIC_SHAPES = { |
| | 'double': { |
| | 'hidden_states': { |
| | 1: TRANSFORMER_IMAGE_DIM, |
| | }, |
| | 'image_rotary_emb': ( |
| | {0: TRANSFORMER_IMAGE_DIM + 512}, |
| | {0: TRANSFORMER_IMAGE_DIM + 512}, |
| | ), |
| | }, |
| | 'single': { |
| | 'hidden_states': { |
| | 1: TRANSFORMER_IMAGE_DIM + 512, |
| | }, |
| | 'image_rotary_emb': ( |
| | {0: TRANSFORMER_IMAGE_DIM + 512}, |
| | {0: TRANSFORMER_IMAGE_DIM + 512}, |
| | ), |
| | }, |
| | } |
| |
|
| | INDUCTOR_CONFIGS = { |
| | 'conv_1x1_as_mm': True, |
| | 'epilogue_fusion': False, |
| | 'coordinate_descent_tuning': True, |
| | 'coordinate_descent_check_all_directions': True, |
| | 'max_autotune': True, |
| | 'triton.cudagraphs': True, |
| | } |
| |
|
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Compile the pipeline's transformer blocks and patch them in place.

    For each block kind ('double', then 'single') the first block of that
    kind is captured during one run of ``pipeline(*args, **kwargs)``,
    exported with ``torch.export.export`` and compiled through
    ``spaces.aoti_compile``. The resulting archive file is shared by every
    block of that kind: each block's ``forward`` is replaced by a
    ``ZeroGPUCompiledModel`` built from the archive and that block's own
    ``state_dict()``.

    Args:
        pipeline: Callable exposing ``transformer.transformer_blocks`` and
            ``transformer.single_transformer_blocks``.
        *args: Positional arguments for the capture run of ``pipeline``.
        **kwargs: Keyword arguments for the capture run of ``pipeline``.

    Returns:
        None. The blocks of ``pipeline.transformer`` are modified in place.
    """
    transformer = pipeline.transformer
    blocks = {
        'double': transformer.transformer_blocks,
        'single': transformer.single_transformer_blocks,
    }

    # NOTE(review): presumably requests a GPU for up to 1200 seconds per
    # call -- confirm against the `spaces` documentation.
    @spaces.GPU(duration=1200)
    def compile_block(kind: str):
        # Only the first block of a kind is captured and exported; the
        # archive it yields is reused for all blocks of that kind.
        sample = blocks[kind][0]
        with spaces.aoti_capture(sample) as captured:
            pipeline(*args, **kwargs)

        # Start from None for every captured keyword argument, then overlay
        # the entries declared for this kind.
        shape_spec = tree_map(lambda _: None, captured.kwargs)
        shape_spec |= TRANSFORMER_DYNAMIC_SHAPES[kind]

        with torch.no_grad():
            program = torch.export.export(
                mod=sample,
                args=captured.args,
                kwargs=captured.kwargs,
                dynamic_shapes=shape_spec,
            )

        compiled = spaces.aoti_compile(program, INDUCTOR_CONFIGS)
        return compiled.archive_file

    for kind, members in blocks.items():
        archive = compile_block(kind)
        for member in members:
            # Every block keeps its own weights but runs the shared archive.
            own_weights = ZeroGPUWeights(member.state_dict())
            member.forward = ZeroGPUCompiledModel(archive, own_weights)
|