EXP
Browse files
infer/lib/predictors/CREPE/CREPE.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import librosa
|
| 5 |
+
import scipy.stats
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
sys.path.append(os.getcwd())
|
| 10 |
+
|
| 11 |
+
CENTS_PER_BIN, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 360, 16000, 1024
|
| 12 |
+
|
| 13 |
+
class CREPE:
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
model_path,
|
| 17 |
+
model_size="full",
|
| 18 |
+
hop_length=512,
|
| 19 |
+
batch_size=None,
|
| 20 |
+
f0_min=50,
|
| 21 |
+
f0_max=1100,
|
| 22 |
+
device=None,
|
| 23 |
+
sample_rate=16000,
|
| 24 |
+
providers=None,
|
| 25 |
+
onnx=False,
|
| 26 |
+
return_periodicity=False
|
| 27 |
+
):
|
| 28 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
self.hop_length = hop_length
|
| 30 |
+
self.batch_size = batch_size
|
| 31 |
+
self.sample_rate = sample_rate
|
| 32 |
+
self.onnx = onnx
|
| 33 |
+
self.f0_min = f0_min
|
| 34 |
+
self.f0_max = f0_max
|
| 35 |
+
self.return_periodicity = return_periodicity
|
| 36 |
+
|
| 37 |
+
if self.onnx:
|
| 38 |
+
import onnxruntime as ort
|
| 39 |
+
|
| 40 |
+
sess_options = ort.SessionOptions()
|
| 41 |
+
sess_options.log_severity_level = 3
|
| 42 |
+
self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
|
| 43 |
+
else:
|
| 44 |
+
from main.library.predictors.CREPE.model import CREPEE
|
| 45 |
+
|
| 46 |
+
model = CREPEE(model_size)
|
| 47 |
+
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
|
| 48 |
+
model.eval()
|
| 49 |
+
self.model = model.to(device)
|
| 50 |
+
|
| 51 |
+
def bins_to_frequency(self, bins):
|
| 52 |
+
if str(bins.device).startswith(("ocl", "privateuseone")): bins = bins.to(torch.float32)
|
| 53 |
+
|
| 54 |
+
cents = CENTS_PER_BIN * bins + 1997.3794084376191
|
| 55 |
+
cents = (
|
| 56 |
+
cents + cents.new_tensor(
|
| 57 |
+
scipy.stats.triang.rvs(
|
| 58 |
+
c=0.5,
|
| 59 |
+
loc=-CENTS_PER_BIN,
|
| 60 |
+
scale=2 * CENTS_PER_BIN,
|
| 61 |
+
size=cents.size()
|
| 62 |
+
)
|
| 63 |
+
)
|
| 64 |
+
) / 1200
|
| 65 |
+
|
| 66 |
+
return 10 * 2 ** cents
|
| 67 |
+
|
| 68 |
+
def frequency_to_bins(self, frequency, quantize_fn=torch.floor):
|
| 69 |
+
return quantize_fn(((1200 * (frequency / 10).log2()) - 1997.3794084376191) / CENTS_PER_BIN).int()
|
| 70 |
+
|
| 71 |
+
def viterbi(self, logits):
|
| 72 |
+
if not hasattr(self, 'transition'):
|
| 73 |
+
xx, yy = np.meshgrid(range(360), range(360))
|
| 74 |
+
transition = np.maximum(12 - abs(xx - yy), 0)
|
| 75 |
+
self.transition = transition / transition.sum(axis=1, keepdims=True)
|
| 76 |
+
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
| 79 |
+
|
| 80 |
+
bins = torch.tensor(
|
| 81 |
+
np.array([
|
| 82 |
+
librosa.sequence.viterbi(sequence, self.transition).astype(np.int64)
|
| 83 |
+
for sequence in probs.cpu().numpy()
|
| 84 |
+
]),
|
| 85 |
+
device=probs.device
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return bins, self.bins_to_frequency(bins)
|
| 89 |
+
|
| 90 |
+
def preprocess(self, audio, pad=True):
|
| 91 |
+
hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length
|
| 92 |
+
|
| 93 |
+
if self.sample_rate != SAMPLE_RATE:
|
| 94 |
+
audio = torch.tensor(
|
| 95 |
+
librosa.resample(
|
| 96 |
+
audio.detach().cpu().numpy().squeeze(0),
|
| 97 |
+
orig_sr=self.sample_rate,
|
| 98 |
+
target_sr=SAMPLE_RATE,
|
| 99 |
+
res_type="soxr_vhq"
|
| 100 |
+
),
|
| 101 |
+
device=audio.device
|
| 102 |
+
).unsqueeze(0)
|
| 103 |
+
|
| 104 |
+
hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate)
|
| 105 |
+
|
| 106 |
+
if pad:
|
| 107 |
+
total_frames = 1 + int(audio.size(1) // hop_length)
|
| 108 |
+
audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
|
| 109 |
+
else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
|
| 110 |
+
|
| 111 |
+
batch_size = total_frames if self.batch_size is None else self.batch_size
|
| 112 |
+
|
| 113 |
+
for i in range(0, total_frames, batch_size):
|
| 114 |
+
frames = torch.nn.functional.unfold(
|
| 115 |
+
audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)],
|
| 116 |
+
kernel_size=(1, WINDOW_SIZE),
|
| 117 |
+
stride=(1, hop_length)
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if self.device.startswith(("ocl", "privateuseone")):
|
| 121 |
+
frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device)
|
| 122 |
+
else:
|
| 123 |
+
frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device)
|
| 124 |
+
|
| 125 |
+
frames -= frames.mean(dim=1, keepdim=True)
|
| 126 |
+
frames /= torch.tensor(1e-10, device=frames.device).max(frames.std(dim=1, keepdim=True))
|
| 127 |
+
|
| 128 |
+
yield frames
|
| 129 |
+
|
| 130 |
+
def periodicity(self, probabilities, bins):
|
| 131 |
+
probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
|
| 132 |
+
periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
|
| 133 |
+
|
| 134 |
+
return periodicity.reshape(probabilities.size(0), probabilities.size(2))
|
| 135 |
+
|
| 136 |
+
def postprocess(self, probabilities):
|
| 137 |
+
probabilities = probabilities.detach()
|
| 138 |
+
probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float('inf')
|
| 139 |
+
probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float('inf')
|
| 140 |
+
|
| 141 |
+
bins, pitch = self.viterbi(probabilities)
|
| 142 |
+
|
| 143 |
+
if not self.return_periodicity: return pitch
|
| 144 |
+
return pitch, self.periodicity(probabilities, bins)
|
| 145 |
+
|
| 146 |
+
def compute_f0(self, audio, pad=True):
|
| 147 |
+
results = []
|
| 148 |
+
|
| 149 |
+
for frames in self.preprocess(audio, pad):
|
| 150 |
+
if self.onnx:
|
| 151 |
+
model = torch.tensor(
|
| 152 |
+
self.model.run(
|
| 153 |
+
[self.model.get_outputs()[0].name],
|
| 154 |
+
{
|
| 155 |
+
self.model.get_inputs()[0].name: frames.cpu().numpy()
|
| 156 |
+
}
|
| 157 |
+
)[0].transpose(1, 0)[None],
|
| 158 |
+
device=self.device
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
with torch.no_grad():
|
| 162 |
+
model = self.model(
|
| 163 |
+
frames,
|
| 164 |
+
embed=False
|
| 165 |
+
).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
|
| 166 |
+
|
| 167 |
+
result = self.postprocess(model)
|
| 168 |
+
results.append(
|
| 169 |
+
(result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if self.return_periodicity:
|
| 173 |
+
pitch, periodicity = zip(*results)
|
| 174 |
+
return torch.cat(pitch, 1), torch.cat(periodicity, 1)
|
| 175 |
+
|
| 176 |
+
return torch.cat(results, 1)
|
infer/lib/predictors/CREPE/filter.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def mean(signals, win_length=9):
|
| 4 |
+
assert signals.dim() == 2
|
| 5 |
+
|
| 6 |
+
signals = signals.unsqueeze(1)
|
| 7 |
+
mask = ~torch.isnan(signals)
|
| 8 |
+
padding = win_length // 2
|
| 9 |
+
|
| 10 |
+
ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
|
| 11 |
+
|
| 12 |
+
avg_pooled = torch.nn.functional.conv1d(
|
| 13 |
+
torch.where(
|
| 14 |
+
mask,
|
| 15 |
+
signals,
|
| 16 |
+
torch.zeros_like(signals)
|
| 17 |
+
),
|
| 18 |
+
ones_kernel,
|
| 19 |
+
stride=1,
|
| 20 |
+
padding=padding
|
| 21 |
+
) / torch.nn.functional.conv1d(
|
| 22 |
+
mask.float(),
|
| 23 |
+
ones_kernel,
|
| 24 |
+
stride=1,
|
| 25 |
+
padding=padding
|
| 26 |
+
).clamp(min=1)
|
| 27 |
+
|
| 28 |
+
avg_pooled[avg_pooled == 0] = float("nan")
|
| 29 |
+
|
| 30 |
+
return avg_pooled.squeeze(1)
|
| 31 |
+
|
| 32 |
+
def median(signals, win_length):
|
| 33 |
+
assert signals.dim() == 2
|
| 34 |
+
|
| 35 |
+
signals = signals.unsqueeze(1)
|
| 36 |
+
mask = ~torch.isnan(signals)
|
| 37 |
+
padding = win_length // 2
|
| 38 |
+
|
| 39 |
+
x = torch.nn.functional.pad(
|
| 40 |
+
torch.where(
|
| 41 |
+
mask,
|
| 42 |
+
signals,
|
| 43 |
+
torch.zeros_like(signals)
|
| 44 |
+
),
|
| 45 |
+
(padding, padding),
|
| 46 |
+
mode="reflect"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
mask = torch.nn.functional.pad(
|
| 50 |
+
mask.float(),
|
| 51 |
+
(padding, padding),
|
| 52 |
+
mode="constant",
|
| 53 |
+
value=0
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
x = x.unfold(2, win_length, 1)
|
| 57 |
+
mask = mask.unfold(2, win_length, 1)
|
| 58 |
+
|
| 59 |
+
x = x.contiguous().view(x.size()[:3] + (-1,))
|
| 60 |
+
mask = mask.contiguous().view(mask.size()[:3] + (-1,))
|
| 61 |
+
|
| 62 |
+
x_sorted, _ = torch.where(mask.bool(), x.float(), float("inf")).to(x).sort(dim=-1)
|
| 63 |
+
|
| 64 |
+
median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
|
| 65 |
+
median_pooled[torch.isinf(median_pooled)] = float("nan")
|
| 66 |
+
|
| 67 |
+
return median_pooled.squeeze(1)
|
infer/lib/predictors/CREPE/model.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
PITCH_BINS = 360
|
| 5 |
+
|
| 6 |
+
class CREPEE(torch.nn.Module):
|
| 7 |
+
def __init__(self, model='full'):
|
| 8 |
+
super().__init__()
|
| 9 |
+
in_channels = {"full": [1, 1024, 128, 128, 128, 256], "large": [1, 768, 96, 96, 96, 192], "medium": [1, 512, 64, 64, 64, 128], "small": [1, 256, 32, 32, 32, 64], "tiny": [1, 128, 16, 16, 16, 32]}[model]
|
| 10 |
+
out_channels = {"full": [1024, 128, 128, 128, 256, 512], "large": [768, 96, 96, 96, 192, 384], "medium": [512, 64, 64, 64, 128, 256], "small": [256, 32, 32, 32, 64, 128], "tiny": [128, 16, 16, 16, 32, 64]}[model]
|
| 11 |
+
self.in_features = {"full": 2048, "large": 1536, "medium": 1024, "small": 512, "tiny": 256}[model]
|
| 12 |
+
|
| 13 |
+
kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
|
| 14 |
+
strides = [(4, 1)] + 5 * [(1, 1)]
|
| 15 |
+
batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
|
| 16 |
+
|
| 17 |
+
self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
|
| 18 |
+
self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
|
| 19 |
+
|
| 20 |
+
self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
|
| 21 |
+
self.conv2_BN = batch_norm_fn(num_features=out_channels[1])
|
| 22 |
+
|
| 23 |
+
self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
|
| 24 |
+
self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
|
| 25 |
+
|
| 26 |
+
self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
|
| 27 |
+
self.conv4_BN = batch_norm_fn(num_features=out_channels[3])
|
| 28 |
+
|
| 29 |
+
self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
|
| 30 |
+
self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
|
| 31 |
+
|
| 32 |
+
self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
|
| 33 |
+
self.conv6_BN = batch_norm_fn(num_features=out_channels[5])
|
| 34 |
+
|
| 35 |
+
self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
|
| 36 |
+
|
| 37 |
+
def forward(self, x, embed=False):
|
| 38 |
+
x = self.embed(x)
|
| 39 |
+
if embed: return x
|
| 40 |
+
return self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)).sigmoid()
|
| 41 |
+
|
| 42 |
+
def embed(self, x):
|
| 43 |
+
x = x[:, None, :, None]
|
| 44 |
+
return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)
|
| 45 |
+
|
| 46 |
+
def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
|
| 47 |
+
return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))
|