| | import functools |
| | import gc |
| |
|
| | import torch |
| | try: |
| | |
| | |
| | import intel_extension_for_pytorch as ipex |
| | except Exception: |
| | pass |
| |
|
| |
|
def _probe_backend(check) -> bool:
    """Return the result of an availability ``check``, or False if it raises.

    Any exception counts as "backend unavailable". Examples are a torch build
    that lacks the backend's module, or a driver/runtime error during detection.
    """
    try:
        return check()
    except Exception:
        return False


# Each probe is wrapped in a lambda so that even the attribute lookup
# (e.g. ``torch.xpu`` on builds without XPU support) happens inside the guard.
HAS_CUDA = _probe_backend(lambda: torch.cuda.is_available())
HAS_MPS = _probe_backend(lambda: torch.backends.mps.is_available())
HAS_XPU = _probe_backend(lambda: torch.xpu.is_available())
| |
|
| |
|
def clean_memory():
    """Run the garbage collector, then empty allocator caches on detected backends.

    Backends are visited in the order CUDA, XPU, MPS. Each one is touched only
    if its module-level ``HAS_*`` flag is true.
    """
    gc.collect()
    # Names are resolved lazily via getattr so an unavailable backend's
    # module is never accessed.
    backends = ((HAS_CUDA, "cuda"), (HAS_XPU, "xpu"), (HAS_MPS, "mps"))
    for enabled, backend_name in backends:
        if enabled:
            getattr(torch, backend_name).empty_cache()
| |
|
| |
|
def clean_memory_on_device(device: torch.device):
    r"""
    Run the garbage collector and empty the allocator cache for ``device``'s backend.

    Intended to be called from training scripts. Only ``device.type`` is
    consulted. Types other than "cuda", "xpu" and "mps" (e.g. "cpu") get
    garbage collection only.
    """
    gc.collect()

    backend_name = device.type
    if backend_name in ("cuda", "xpu", "mps"):
        # torch.cuda / torch.xpu / torch.mps each expose empty_cache().
        getattr(torch, backend_name).empty_cache()
| |
|
| |
|
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Return the best available device: CUDA, then XPU, then MPS, else CPU.

    Do not call this function from training scripts. Use accelerator.device instead.

    The result is cached, so the choice (and the log line) is produced only
    on the first call.
    """
    ranked = ((HAS_CUDA, "cuda"), (HAS_XPU, "xpu"), (HAS_MPS, "mps"))
    chosen = next((name for available, name in ranked if available), "cpu")
    device = torch.device(chosen)
    print(f"get_preferred_device() -> {device}")
    return device
| |
|
| |
|
def init_ipex():
    """
    Install the IPEX CUDA hijacks provided by `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing
    anything else.

    Does nothing when XPU is not available. Failures are reported with
    ``print`` and never raised. This covers both an unsuccessful
    ``ipex_init()`` result and any exception raised while importing or
    calling it.
    """
    try:
        if not HAS_XPU:
            return
        # Imported lazily: the helper is only needed on XPU machines.
        from library.ipex import ipex_init

        ok, reason = ipex_init()
        if not ok:
            print("failed to initialize ipex:", reason)
    except Exception as e:
        print("failed to initialize ipex:", e)
| |
|