| import torch | |
| def torch_gc(): | |
| if torch.cuda.is_available(): | |
| # with torch.cuda.device(DEVICE): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| elif torch.backends.mps.is_available(): | |
| try: | |
| from torch.mps import empty_cache | |
| empty_cache() | |
| except Exception as e: | |
| print(e) | |
| print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") |