# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
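"""Triton autotuner that memoizes the best config for each kernel on disk, keyed by GPU model."""
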
import builtins
import json
import os
import pickle
import time

import torch
import torch.distributed as dist
from triton.runtime.autotuner import Autotuner


class CustomAutotuner(Autotuner):
    """Autotuner that persists the best config per kernel and per GPU to disk."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache file lives under ~/.triton/best_config_cache/<device_name>/<kernel_name>.pkl
        self.best_config_cache_path = os.path.expanduser(
            os.path.join(
                "~",
                ".triton",
                "best_config_cache",
                torch.cuda.get_device_name(0).replace(" ", "_"),
                self.base_fn.__name__ + ".pkl",
            )
        )
        # Warm-start the in-memory cache from a previous run, if one exists.
        if os.path.exists(self.best_config_cache_path):
            with open(self.best_config_cache_path, "rb") as f:
                self.cache = pickle.load(f)

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = []
            for name in self.arg_names:
                if name in all_args:
                    _args.append(all_args[name])
            # Build the cache key from the tuning-key arguments plus the dtypes
            # of every tensor-like argument.
            key = [_args[i] for i in self.key_idx]
            for arg in _args:
                if hasattr(arg, "dtype"):
                    key.append(str(arg.dtype))
            key = tuple(key)
            if key not in self.cache:
                # Prune the config space, benchmark each survivor, and cache the winner.
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.pre_hook(args, reset_only=True)
                self.configs_timings = timings
                # Only rank 0 writes the on-disk cache, to avoid concurrent writes
                # from distributed workers.
                if not dist.is_initialized() or dist.get_rank() == 0:
                    best_config_cache_dir = os.path.dirname(self.best_config_cache_path)
                    os.makedirs(best_config_cache_dir, exist_ok=True)
                    with open(self.best_config_cache_path, "wb") as f:
                        pickle.dump(self.cache, f)
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret


def custom_autotune(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
):
    """Decorator counterpart of `triton.autotune` that builds a CustomAutotuner, so best configs persist across runs."""

    def decorator(fn):
        return CustomAutotuner(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            pre_hook=pre_hook,
            post_hook=post_hook,
            prune_configs_by=prune_configs_by,
            warmup=warmup,
            rep=rep,
            use_cuda_graph=use_cuda_graph,
        )

    return decorator
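

# Usage sketch (illustrative only, not part of the original module): `custom_autotune`
# is applied to a kernel the same way `triton.autotune` would be. The kernel body,
# config values, and tuning key below are assumptions made purely for demonstration;
# uncomment on a CUDA machine with Triton installed to try it.
#
# import triton
# import triton.language as tl
#
# @custom_autotune(
#     configs=[
#         triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
#         triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
#     ],
#     key=["n_elements"],  # re-benchmark only when these argument values change
# )
# @triton.jit
# def _demo_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
#     pid = tl.program_id(axis=0)
#     offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#     mask = offsets < n_elements
#     x = tl.load(x_ptr + offsets, mask=mask)
#     y = tl.load(y_ptr + offsets, mask=mask)
#     tl.store(out_ptr + offsets, x + y, mask=mask)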