import logging
import multiprocessing
import subprocess
import time

import fsspec
import torch

from .train import unwrap_model


def remote_sync_s3(local_dir, remote_dir):
    # Skip epoch_latest.pt, which can change while the sync is in flight.
    result = subprocess.run(
        ["aws", "s3", "sync", local_dir, remote_dir, "--exclude", "*epoch_latest.pt"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
        return False

    logging.info("Successfully synced with S3 bucket")
    return True
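
# For reference, the subprocess call above is equivalent to running the AWS
# CLI directly, so the `aws` binary must be installed and credentials
# configured:
#
#     aws s3 sync <local_dir> <remote_dir> --exclude '*epoch_latest.pt'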


def remote_sync_fsspec(local_dir, remote_dir):
    # FIXME currently this is slow and not recommended. Look into speeding up.
    a = fsspec.get_mapper(local_dir)
    b = fsspec.get_mapper(remote_dir)

    for k in a:
        # Skip epoch_latest.pt, which can change while the sync is in flight.
        if 'epoch_latest.pt' in k:
            continue

        logging.info(f'Attempting to sync {k}')
        if k in b and len(a[k]) == len(b[k]):
            logging.debug(f'Skipping remote sync for {k}.')
            continue

        try:
            b[k] = a[k]
            logging.info(f'Successful sync for {k}.')
        except Exception as e:
            logging.error(f'Error during remote sync for {k}: {e}')
            return False

    return True


def remote_sync(local_dir, remote_dir, protocol):
    logging.info('Starting remote sync.')
    if protocol == 's3':
        return remote_sync_s3(local_dir, remote_dir)
    elif protocol == 'fsspec':
        return remote_sync_fsspec(local_dir, remote_dir)
    else:
        logging.error('Remote protocol not known')
        return False
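
# Usage sketch for a one-off sync (hypothetical paths; assumes the target
# bucket exists and is writable):
#
#     success = remote_sync('/tmp/checkpoints', 's3://my-bucket/run-1', 's3')
#     if not success:
#         logging.warning('Remote sync failed; will retry on the next pass.')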


def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
    # Periodically push the local checkpoint directory to remote storage.
    while True:
        time.sleep(sync_every)
        remote_sync(local_dir, remote_dir, protocol)


def start_sync_process(sync_every, local_dir, remote_dir, protocol):
    p = multiprocessing.Process(
        target=keep_running_remote_sync,
        args=(sync_every, local_dir, remote_dir, protocol),
    )
    return p
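
# start_sync_process only constructs the Process; the caller is expected to
# launch and eventually stop it. A minimal sketch with hypothetical paths:
#
#     p = start_sync_process(
#         sync_every=300,  # seconds between syncs
#         local_dir='/tmp/checkpoints',
#         remote_dir='s3://my-bucket/run-1',
#         protocol='s3',
#     )
#     p.start()
#     ...  # training loop
#     p.terminate()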


# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    # Write through the opened fsspec file object so remote paths work too.
    with fsspec.open(file_path, "wb") as f:
        torch.save(pt_obj, f)


def pt_load(file_path, map_location=None):
    if file_path.startswith('s3'):
        logging.info('Loading remote checkpoint, which may take a bit.')
    with fsspec.open(file_path, "rb") as f:
        out = torch.load(f, map_location=map_location)
    return out
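
# Usage sketch (hypothetical path and model). Passing map_location='cpu'
# keeps tensors off the GPU at load time, which is the usual choice when the
# resuming job may have a different device layout:
#
#     checkpoint = pt_load('s3://my-bucket/run-1/epoch_5.pt', map_location='cpu')
#     model.load_state_dict(checkpoint['state_dict'])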


def check_exists(file_path):
    try:
        with fsspec.open(file_path):
            pass
    except FileNotFoundError:
        return False
    return True


def save_ckpt(args, model, scaler, optimizer):
    assert args.save_path is not None
    ckpt_path = args.save_path

    # Unwrap any distributed wrapper (e.g., DDP) so state dict keys are clean.
    model = unwrap_model(model)

    checkpoint_dict = {
        "iterations": args.iterations,
        "name": args.name,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()

    torch.save(checkpoint_dict, ckpt_path)
    logging.info(f"saved {ckpt_path}")