import os
import torch
from torch.utils.data import DataLoader
import omegaconf
from omegaconf import DictConfig
import wandb
from checkpoing_saver import CheckpointSaver
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss
from datasets import get_datasets
from testing.metrics import Metrics
from datasets.minimal import Minimal
from tqdm import tqdm

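
# Log in to Weights & Biases and start (or resume) the run. On resume, look up the
# previous run with the same name that is no longer running and download its logged
# model artifact.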
def init_wandb(cfg):
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
               resume=cfg['wandb']['resume'],
               name=cfg['wandb']['run_name'])
    if wandb.run.resumed:
        api = wandb.Api()
        runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
                        order='train_pesq')
        run = [run for run in runs if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
        artifacts = run.logged_artifacts()
        best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
        best_model.download()

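
# Full training loop: build the model, optimizer, loss and datasets from the config,
# alternate train/val phases each epoch, track PESQ/STOI alongside the loss, log
# everything to W&B, and checkpoint after each validation phase.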
def train(cfg: DictConfig):
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
    init_wandb(cfg)
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
                                       decreasing=False)
    metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)

    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)
    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)

    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
                            num_workers=cfg['dataloader']['num_workers']),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
                          num_workers=cfg['dataloader']['num_workers']),
        'minimal': DataLoader(minimal_dataset)
    }
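
    # One pass over 'train' and 'val' per epoch: gradients are enabled only in the
    # train phase; running loss/PESQ/STOI are shown per batch in tqdm, logged to W&B
    # every log_interval batches during training, and logged per epoch for both phases.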
    wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
    epoch = 0
    while epoch < cfg['training']['num_epochs']:
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            loop = tqdm(dataloaders[phase])
            for i, (inputs, labels) in enumerate(loop):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_metrics = metrics(denoised=outputs, clean=labels)
                running_loss += loss.item() * inputs.size(0)
                running_pesq += running_metrics['PESQ']
                running_stoi += running_metrics['STOI']

                loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
                loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
                                 pesq=running_pesq / (i + 1) / inputs.size(0),
                                 stoi=running_stoi / (i + 1) / inputs.size(0))

                if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
                    wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
                               "train_pesq": running_pesq / (i + 1) / inputs.size(0),
                               "train_stoi": running_stoi / (i + 1) / inputs.size(0)})

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_pesq = running_pesq / len(dataloaders[phase].dataset)
            epoch_stoi = running_stoi / len(dataloaders[phase].dataset)
            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})
            if phase == 'val':
                # Example logging only; no gradients are needed here.
                with torch.no_grad():
                    for i, (wav, rate) in enumerate(dataloaders['minimal']):
                        if cfg['dataloader']['normalize']:
                            # Normalize the input by its std and rescale the prediction back.
                            std = torch.std(wav)
                            wav = wav / std
                            prediction = model(wav.to(device))
                            prediction = prediction * std
                        else:
                            prediction = model(wav.to(device))
                        # rate comes out of the DataLoader as a 1-element tensor;
                        # wandb.Audio expects a plain int sample rate.
                        wandb.log({
                            f"{i}_example": wandb.Audio(
                                prediction.detach().cpu().numpy()[0][0],
                                sample_rate=int(rate))})

                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)
        epoch += 1


if __name__ == "__main__":
    pass
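
# Note: this module defines train() but the __main__ block is a no-op, so running the
# file directly does nothing. A minimal sketch of a possible entry point, assuming the
# DictConfig is composed with Hydra (the config_path/config_name values below are
# illustrative assumptions, not part of this project):
#
#   import hydra
#
#   @hydra.main(config_path="conf", config_name="config", version_base=None)
#   def main(cfg: DictConfig) -> None:
#       train(cfg)
#
#   if __name__ == "__main__":
#       main()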