add scorer for quick start
Files changed:
- predict.py +85 -0
- score.py +122 -0
predict.py (new file, 85 lines)

import argparse
import pathlib

import torch
import torchaudio
import tqdm
from torch.utils.data import DataLoader, Dataset

from score import Score


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--bs", required=False, default=None, type=int)
    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
    parser.add_argument("--ckpt_path", required=False, default="epoch=3-step=7459.ckpt", type=pathlib.Path)
    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--out_path", required=True, type=pathlib.Path)
    parser.add_argument("--num_workers", required=False, default=0, type=int)
    return parser.parse_args()


class WavDataset(Dataset):
    """Dataset over the *.wav files in a directory.

    Renamed from ``Dataset`` so it no longer shadows torch.utils.data.Dataset.
    """

    def __init__(self, dir_path: pathlib.Path):
        self.wavlist = list(dir_path.glob("*.wav"))
        # Assume every file in the directory shares the sample rate of the first one.
        _, self.sr = torchaudio.load(self.wavlist[0])

    def __len__(self):
        return len(self.wavlist)

    def __getitem__(self, idx):
        fname = self.wavlist[idx]
        wav, _ = torchaudio.load(fname)
        return {"wav": wav}

    def collate_fn(self, batch):
        # Repeat-pad every clip to the length of the longest clip in the batch.
        max_len = max(x["wav"].shape[1] for x in batch)
        out = []
        for t in batch:
            wav = t["wav"]
            amount_to_pad = max_len - wav.shape[1]
            padding_tensor = wav.repeat(1, 1 + amount_to_pad // wav.size(1))
            out.append(torch.cat((wav, padding_tensor[:, :amount_to_pad]), dim=1))
        return torch.stack(out, dim=0)


def main():
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        assert args.inp_path is not None, "inp_path is required when mode is predict_file."
        assert args.inp_dir is None, "inp_dir should be None."
        assert args.inp_path.exists()
        assert args.inp_path.is_file()
        wav, sr = torchaudio.load(args.inp_path)
        scorer = Score(ckpt_path=args.ckpt_path, input_sample_rate=sr, device=device)
        score = scorer.score(wav.to(device))
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
    else:
        assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
        assert args.bs is not None, "bs is required when mode is predict_dir."
        assert args.inp_path is None, "inp_path should be None."
        assert args.inp_dir.exists()
        assert args.inp_dir.is_dir()
        dataset = WavDataset(dir_path=args.inp_dir)
        loader = DataLoader(
            dataset,
            batch_size=args.bs,
            collate_fn=dataset.collate_fn,
            shuffle=False,  # keep file order so output lines match dataset.wavlist
            num_workers=args.num_workers)
        sr = dataset.sr
        scorer = Score(ckpt_path=args.ckpt_path, input_sample_rate=sr, device=device)
        # Truncate the output file, then append one score per line, batch by batch.
        with open(args.out_path, "w"):
            pass
        for batch in tqdm.tqdm(loader):
            scores = scorer.score(batch.to(device))
            with open(args.out_path, "a") as fw:
                # Trailing newline so consecutive batches do not run together.
                fw.write("\n".join(str(s) for s in scores) + "\n")


if __name__ == "__main__":
    main()
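A sketch of how the script might be invoked (the .wav paths here are illustrative; the checkpoint defaults to epoch=3-step=7459.ckpt in the working directory):

python predict.py --mode predict_file --inp_path sample.wav --out_path score.txt
python predict.py --mode predict_dir --inp_dir wavs/ --bs 8 --num_workers 4 --out_path scores.txt
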
score.py (new file, 122 lines)

import unittest

import torch
import torchaudio

import lightning_module


class Score:
    """Predict a MOS-style quality score for each audio clip."""

    def __init__(
            self,
            ckpt_path: str = "epoch=3-step=7459.ckpt",
            input_sample_rate: int = 16000,
            device: str = "cpu"):
        """
        Args:
            ckpt_path: path to a pretrained checkpoint of the UTMOS strong learner.
            input_sample_rate: sampling rate of the input audio tensor. The input
                audio tensor is automatically resampled to 16 kHz.
            device: device to run inference on, e.g. "cpu" or "cuda".
        """
        print(f"Using device: {device}")
        self.device = device
        self.model = lightning_module.BaselineLightningModule.load_from_checkpoint(
            ckpt_path).eval().to(device)
        self.in_sr = input_sample_rate
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=input_sample_rate,
            new_freq=16000,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        ).to(device)

    def score(self, wavs: torch.Tensor):
        """
        Args:
            wavs: audio waveform to be evaluated. A 1- or 2-dimensional tensor
                is treated as a single audio clip; a 3-dimensional tensor of
                shape (batch, 1, time) is processed as a batch.
        Returns:
            A numpy array with one predicted score per clip.
        """
        if len(wavs.shape) == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif len(wavs.shape) == 2:
            out_wavs = wavs.unsqueeze(0)
        elif len(wavs.shape) == 3:
            out_wavs = wavs
        else:
            raise ValueError("Dimension of input tensor needs to be <= 3.")
        if self.in_sr != 16000:
            out_wavs = self.resampler(out_wavs)
        bs = out_wavs.shape[0]
        batch = {
            "wav": out_wavs,
            # Fixed domain label and judge id expected by the UTMOS model.
            "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
            "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288
        }
        with torch.no_grad():
            output = self.model(batch)

        # Map the model output from [-1, 1] back to the 1-5 MOS scale.
        return output.mean(dim=1).squeeze(1).cpu().detach().numpy() * 2 + 3


class TestFunc(unittest.TestCase):
    """Smoke tests for Score over 1-, 2-, and 3-dimensional inputs."""

    def test_1dim_0(self):
        scorer = Score(input_sample_rate=16000)
        seq_len = 10000
        inp_audio = torch.ones(seq_len)
        pred = scorer.score(inp_audio)
        self.assertGreaterEqual(pred, 0.)
        self.assertLessEqual(pred, 5.)

    def test_1dim_1(self):
        scorer = Score(input_sample_rate=24000)
        seq_len = 10000
        inp_audio = torch.ones(seq_len)
        pred = scorer.score(inp_audio)
        self.assertGreaterEqual(pred, 0.)
        self.assertLessEqual(pred, 5.)

    def test_2dim_0(self):
        scorer = Score(input_sample_rate=16000)
        seq_len = 10000
        inp_audio = torch.ones(1, seq_len)
        pred = scorer.score(inp_audio)
        self.assertGreaterEqual(pred, 0.)
        self.assertLessEqual(pred, 5.)

    def test_2dim_1(self):
        scorer = Score(input_sample_rate=24000)
        seq_len = 10000
        inp_audio = torch.ones(1, seq_len)
        pred = scorer.score(inp_audio)
        self.assertGreaterEqual(pred, 0.)
        self.assertLessEqual(pred, 5.)

    def test_3dim_0(self):
        scorer = Score(input_sample_rate=16000)
        seq_len = 10000
        batch = 8
        inp_audio = torch.ones(batch, 1, seq_len)
        pred = scorer.score(inp_audio)
        for p in pred:
            self.assertGreaterEqual(p, 0.)
            self.assertLessEqual(p, 5.)

    def test_3dim_1(self):
        scorer = Score(input_sample_rate=24000)
        seq_len = 10000
        batch = 8
        inp_audio = torch.ones(batch, 1, seq_len)
        pred = scorer.score(inp_audio)
        for p in pred:
            self.assertGreaterEqual(p, 0.)
            self.assertLessEqual(p, 5.)


if __name__ == "__main__":
    unittest.main()
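For quick sanity checks, a minimal usage sketch of the Score class (the file name sample.wav is illustrative; this assumes the default checkpoint and the lightning_module package are available in the working directory):

import torch
import torchaudio
from score import Score

device = "cuda" if torch.cuda.is_available() else "cpu"
wav, sr = torchaudio.load("sample.wav")              # wav: (channels, time)
scorer = Score(input_sample_rate=sr, device=device)
print(scorer.score(wav.to(device)))                  # numpy array, one MOS-scale value per clip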