NeoPy committed · Commit 2dcbf9e · verified · 1 Parent(s): e95fe14
infer/lib/predictors/CREPE/CREPE.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ import sys
+ import torch
+ import librosa
+ import scipy.stats
+
+ import numpy as np
+
+ sys.path.append(os.getcwd())
+
+ CENTS_PER_BIN, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 360, 16000, 1024
+
+ class CREPE:
+     def __init__(
+         self,
+         model_path,
+         model_size="full",
+         hop_length=512,
+         batch_size=None,
+         f0_min=50,
+         f0_max=1100,
+         device=None,
+         sample_rate=16000,
+         providers=None,
+         onnx=False,
+         return_periodicity=False
+     ):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.hop_length = hop_length
+         self.batch_size = batch_size
+         self.sample_rate = sample_rate
+         self.onnx = onnx
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.return_periodicity = return_periodicity
+
+         if self.onnx:
+             import onnxruntime as ort
+
+             sess_options = ort.SessionOptions()
+             sess_options.log_severity_level = 3
+             self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
+         else:
+             from infer.lib.predictors.CREPE.model import CREPEE
+
+             model = CREPEE(model_size)
+             model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
+             model.eval()
+             self.model = model.to(self.device)
+
+     def bins_to_frequency(self, bins):
+         # Convert bin indices to Hz, dithering with triangular noise (as in
+         # torchcrepe) to undo the quantization of the 20-cent bins.
+         if str(bins.device).startswith(("ocl", "privateuseone")):
+             bins = bins.to(torch.float32)
+
+         cents = CENTS_PER_BIN * bins + 1997.3794084376191
+         cents = (
+             cents + cents.new_tensor(
+                 scipy.stats.triang.rvs(
+                     c=0.5,
+                     loc=-CENTS_PER_BIN,
+                     scale=2 * CENTS_PER_BIN,
+                     size=cents.size()
+                 )
+             )
+         ) / 1200
+
+         return 10 * 2 ** cents
+
+     def frequency_to_bins(self, frequency, quantize_fn=torch.floor):
+         return quantize_fn(((1200 * (frequency / 10).log2()) - 1997.3794084376191) / CENTS_PER_BIN).int()
+
+     def viterbi(self, logits):
+         # Lazily build the transition matrix: jumps of up to 11 bins are
+         # allowed, with probability decreasing linearly with distance.
+         if not hasattr(self, "transition"):
+             xx, yy = np.meshgrid(range(PITCH_BINS), range(PITCH_BINS))
+             transition = np.maximum(12 - abs(xx - yy), 0)
+             self.transition = transition / transition.sum(axis=1, keepdims=True)
+
+         with torch.no_grad():
+             probs = torch.nn.functional.softmax(logits, dim=1)
+
+         bins = torch.tensor(
+             np.array([
+                 librosa.sequence.viterbi(sequence, self.transition).astype(np.int64)
+                 for sequence in probs.cpu().numpy()
+             ]),
+             device=probs.device
+         )
+
+         return bins, self.bins_to_frequency(bins)
+
+     def preprocess(self, audio, pad=True):
+         hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length
+
+         if self.sample_rate != SAMPLE_RATE:
+             audio = torch.tensor(
+                 librosa.resample(
+                     audio.detach().cpu().numpy().squeeze(0),
+                     orig_sr=self.sample_rate,
+                     target_sr=SAMPLE_RATE,
+                     res_type="soxr_vhq"
+                 ),
+                 device=audio.device
+             ).unsqueeze(0)
+
+             hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate)
+
+         if pad:
+             total_frames = 1 + int(audio.size(1) // hop_length)
+             audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
+         else:
+             total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
+
+         batch_size = total_frames if self.batch_size is None else self.batch_size
+
+         for i in range(0, total_frames, batch_size):
+             # Slice out the audio covered by this batch and unfold it into
+             # overlapping WINDOW_SIZE-sample analysis windows.
+             frames = torch.nn.functional.unfold(
+                 audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)],
+                 kernel_size=(1, WINDOW_SIZE),
+                 stride=(1, hop_length)
+             )
+
+             if str(self.device).startswith(("ocl", "privateuseone")):
+                 frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device)
+             else:
+                 frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device)
+
+             # Normalize each window to zero mean and (clamped) unit variance.
+             frames -= frames.mean(dim=1, keepdim=True)
+             frames /= torch.tensor(1e-10, device=frames.device).max(frames.std(dim=1, keepdim=True))
+
+             yield frames
+
+     def periodicity(self, probabilities, bins):
+         probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
+         periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
+
+         return periodicity.reshape(probabilities.size(0), probabilities.size(2))
+
+     def postprocess(self, probabilities):
+         # Mask bins outside [f0_min, f0_max] before Viterbi decoding.
+         probabilities = probabilities.detach()
+         probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float("inf")
+         probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float("inf")
+
+         bins, pitch = self.viterbi(probabilities)
+
+         if not self.return_periodicity:
+             return pitch
+
+         return pitch, self.periodicity(probabilities, bins)
+
+     def compute_f0(self, audio, pad=True):
+         results = []
+
+         for frames in self.preprocess(audio, pad):
+             if self.onnx:
+                 probabilities = torch.tensor(
+                     self.model.run(
+                         [self.model.get_outputs()[0].name],
+                         {self.model.get_inputs()[0].name: frames.cpu().numpy()}
+                     )[0].transpose(1, 0)[None],
+                     device=self.device
+                 )
+             else:
+                 with torch.no_grad():
+                     probabilities = self.model(frames, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
+
+             result = self.postprocess(probabilities)
+             results.append(
+                 (result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device)
+             )
+
+         if self.return_periodicity:
+             pitch, periodicity = zip(*results)
+             return torch.cat(pitch, 1), torch.cat(periodicity, 1)
+
+         return torch.cat(results, 1)
infer/lib/predictors/CREPE/filter.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+
+ def mean(signals, win_length=9):
+     """NaN-aware sliding-window mean over the last dimension of a 2-D tensor."""
+     assert signals.dim() == 2
+
+     signals = signals.unsqueeze(1)
+     mask = ~torch.isnan(signals)
+     padding = win_length // 2
+
+     ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
+
+     # Sum the valid samples in each window, then divide by the valid count.
+     avg_pooled = torch.nn.functional.conv1d(
+         torch.where(mask, signals, torch.zeros_like(signals)),
+         ones_kernel,
+         stride=1,
+         padding=padding
+     ) / torch.nn.functional.conv1d(
+         mask.float(),
+         ones_kernel,
+         stride=1,
+         padding=padding
+     ).clamp(min=1)
+
+     # Windows that contained no valid samples become NaN again.
+     avg_pooled[avg_pooled == 0] = float("nan")
+
+     return avg_pooled.squeeze(1)
+
+ def median(signals, win_length):
+     """NaN-aware sliding-window median over the last dimension of a 2-D tensor."""
+     assert signals.dim() == 2
+
+     signals = signals.unsqueeze(1)
+     mask = ~torch.isnan(signals)
+     padding = win_length // 2
+
+     x = torch.nn.functional.pad(
+         torch.where(mask, signals, torch.zeros_like(signals)),
+         (padding, padding),
+         mode="reflect"
+     )
+
+     mask = torch.nn.functional.pad(
+         mask.float(),
+         (padding, padding),
+         mode="constant",
+         value=0
+     )
+
+     x = x.unfold(2, win_length, 1)
+     mask = mask.unfold(2, win_length, 1)
+
+     x = x.contiguous().view(x.size()[:3] + (-1,))
+     mask = mask.contiguous().view(mask.size()[:3] + (-1,))
+
+     # Push invalid samples to +inf so they sort to the end, then pick the
+     # middle of the valid samples in each window.
+     x_sorted, _ = torch.where(mask.bool(), x.float(), float("inf")).to(x).sort(dim=-1)
+
+     median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
+     median_pooled[torch.isinf(median_pooled)] = float("nan")
+
+     return median_pooled.squeeze(1)
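A short sketch of how these NaN-aware smoothers might be combined with the predictor's periodicity output; the 0.1 voicing threshold is illustrative, not part of this commit:

```python
# Sketch: median-smooth pitch where periodicity is low (threshold is illustrative).
import torch

from infer.lib.predictors.CREPE.filter import mean, median

pitch = torch.tensor([[100.0, 101.0, 400.0, 102.0, 103.0]])  # (batch, frames), Hz
periodicity = torch.tensor([[0.9, 0.8, 0.05, 0.85, 0.9]])

pitch[periodicity < 0.1] = float("nan")  # mark unvoiced frames; both filters skip NaNs

smoothed = median(pitch, win_length=3)   # NaN-aware sliding median removes the outlier
trend = mean(smoothed, win_length=3)     # NaN-aware sliding mean for a smoother contour
```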
infer/lib/predictors/CREPE/model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import functools
+
+ PITCH_BINS = 360
+
+ class CREPEE(torch.nn.Module):
+     def __init__(self, model='full'):
+         super().__init__()
+         in_channels = {"full": [1, 1024, 128, 128, 128, 256], "large": [1, 768, 96, 96, 96, 192], "medium": [1, 512, 64, 64, 64, 128], "small": [1, 256, 32, 32, 32, 64], "tiny": [1, 128, 16, 16, 16, 32]}[model]
+         out_channels = {"full": [1024, 128, 128, 128, 256, 512], "large": [768, 96, 96, 96, 192, 384], "medium": [512, 64, 64, 64, 128, 256], "small": [256, 32, 32, 32, 64, 128], "tiny": [128, 16, 16, 16, 32, 64]}[model]
+         self.in_features = {"full": 2048, "large": 1536, "medium": 1024, "small": 512, "tiny": 256}[model]
+
+         kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
+         strides = [(4, 1)] + 5 * [(1, 1)]
+         batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
+
+         self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
+         self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
+
+         self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
+         self.conv2_BN = batch_norm_fn(num_features=out_channels[1])
+
+         self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
+         self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
+
+         self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
+         self.conv4_BN = batch_norm_fn(num_features=out_channels[3])
+
+         self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
+         self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
+
+         self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
+         self.conv6_BN = batch_norm_fn(num_features=out_channels[5])
+
+         self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
+
+     def forward(self, x, embed=False):
+         # x: (num_frames, window_size) normalized audio windows.
+         x = self.embed(x)
+         if embed:
+             return x
+
+         x = self.layer(x, self.conv6, self.conv6_BN)
+         x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
+         return self.classifier(x).sigmoid()
+
+     def embed(self, x):
+         x = x[:, None, :, None]
+         x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
+         x = self.layer(x, self.conv2, self.conv2_BN)
+         x = self.layer(x, self.conv3, self.conv3_BN)
+         x = self.layer(x, self.conv4, self.conv4_BN)
+         return self.layer(x, self.conv5, self.conv5_BN)
+
+     def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
+         # pad -> conv -> relu -> batch norm -> max pool, as in the original CREPE.
+         return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))
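A quick shape check for the network above (sketch; the input frames are random and sized by the `WINDOW_SIZE` constant from CREPE.py):

```python
# Sketch: verify the frame-level output shape of the "full" model.
import torch

from infer.lib.predictors.CREPE.model import CREPEE

model = CREPEE("full").eval()

frames = torch.randn(8, 1024)              # 8 analysis windows of WINDOW_SIZE samples
with torch.no_grad():
    probs = model(frames)                  # (8, 360): per-frame pitch-bin probabilities
    features = model(frames, embed=True)   # intermediate features before conv6

print(probs.shape)  # torch.Size([8, 360])
```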