NeoPy committed · Commit 0a0615c (verified) · 1 Parent(s): 8d16f65
infer/lib/predictors/FCPE/FCPE.py ADDED
@@ -0,0 +1,581 @@
import os
import sys
import torch

import numpy as np
import torch.nn as nn
import onnxruntime as ort
import torch.nn.functional as F

from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())
os.environ["LRU_CACHE_CAPACITY"] = "3"

from main.library.predictors.FCPE.wav2mel import Wav2Mel
from main.library.predictors.FCPE.encoder import EncoderLayer, ConformerNaiveEncoder
from main.library.predictors.FCPE.utils import batch_interp_with_replacement_detach, decrypt_model, DotDict

@torch.no_grad()
def cent_to_f0(cent):
    return 10 * 2 ** (cent / 1200)

@torch.no_grad()
def f0_to_cent(f0):
    return 1200 * (f0 / 10).log2()

@torch.no_grad()
def latent2cents_decoder(cent_table, y, threshold = 0.05, mask = True):
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    ci = cent_table[None, None, :].expand(B, N, -1)
    rtn = (ci * y).sum(dim=-1, keepdim=True) / y.sum(dim=-1, keepdim=True)

    if mask:
        confident = y.max(dim=-1, keepdim=True)[0]
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return rtn

@torch.no_grad()
def latent2cents_local_decoder(cent_table, out_dims, y, threshold = 0.05, mask = True):
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    ci = cent_table[None, None, :].expand(B, N, -1)
    confident, max_index = y.max(dim=-1, keepdim=True)

    local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
    local_argmax_index[local_argmax_index < 0] = 0
    local_argmax_index[local_argmax_index >= out_dims] = out_dims - 1

    y_l = y.gather(-1, local_argmax_index)
    rtn = (ci.gather(-1, local_argmax_index) * y_l).sum(dim=-1, keepdim=True) / y_l.sum(dim=-1, keepdim=True)

    if mask:
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return rtn

def cents_decoder(cent_table, y, confidence, threshold = 0.05, mask=True):
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    rtn = (cent_table[None, None, :].expand(B, N, -1) * y).sum(dim=-1, keepdim=True) / y.sum(dim=-1, keepdim=True)

    if mask:
        confident = y.max(dim=-1, keepdim=True)[0]
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return (rtn, confident) if confidence else rtn

def cents_local_decoder(cent_table, y, n_out, confidence, threshold = 0.05, mask=True):
    if str(y.device).startswith("privateuseone"):
        cent_table = cent_table.cpu()
        y = y.cpu()

    B, N, _ = y.size()
    confident, max_index = y.max(dim=-1, keepdim=True)
    local_argmax_index = (torch.arange(0, 9).to(max_index.device) + (max_index - 4)).clamp(0, n_out - 1)
    y_l = y.gather(-1, local_argmax_index)
    rtn = (cent_table[None, None, :].expand(B, N, -1).gather(-1, local_argmax_index) * y_l).sum(dim=-1, keepdim=True) / y_l.sum(dim=-1, keepdim=True)

    if mask:
        confident_mask = torch.ones_like(confident)
        confident_mask[confident <= threshold] = float("-INF")
        rtn = rtn * confident_mask

    return (rtn, confident) if confidence else rtn

class PCmer(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout
        self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)])

    def forward(self, phone, mask=None):
        for layer in self._layers:
            phone = layer(phone, mask)

        return phone

class CFNaiveMelPE(nn.Module):
    def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
        super().__init__()
        self.input_channels = input_channels
        self.out_dims = out_dims
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.use_fa_norm = use_fa_norm

        self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
        self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
        self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
        self.norm = nn.LayerNorm(hidden_dims)
        self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))

        self.cent_table_b = torch.linspace(f0_to_cent(torch.Tensor([f0_min]))[0], f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
        self.gaussian_blurred_cent_mask_b = (1200 * torch.Tensor([self.f0_max / 10.]).log2())[0].detach()

        self.register_buffer("cent_table", self.cent_table_b)
        self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)

    def forward(self, x, _h_emb=None):
        x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)

        if self.harmonic_emb is not None:
            if _h_emb is None: x += self.harmonic_emb(torch.LongTensor([0]).to(x.device))
            else: x += self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))

        return self.output_proj(self.norm(self.net(x))).sigmoid()

    @torch.no_grad()
    def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
        latent = self.forward(mel)
        return cent_to_f0(latent2cents_decoder(self.cent_table, latent, threshold=threshold) if decoder == "argmax" else latent2cents_local_decoder(self.cent_table, self.out_dims, latent, threshold=threshold))

class FCPE_LEGACY(nn.Module):
    def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
        super().__init__()
        self.n_out = out_dims
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.confidence = confidence
        self.threshold = threshold
        self.use_input_conv = use_input_conv

        self.cent_table_b = torch.Tensor(np.linspace(f0_to_cent(torch.Tensor([f0_min]))[0], f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
        self.register_buffer("cent_table", self.cent_table_b)

        self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
        self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
        self.norm = nn.LayerNorm(n_chans)
        self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))

    def forward(self, mel, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
        x = self.decoder(self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)
        x = self.dense_out(self.norm(x)).sigmoid()

        x = cent_to_f0(cents_decoder(self.cent_table, x, self.confidence, threshold=self.threshold, mask=True) if cdecoder == "argmax" else cents_local_decoder(self.cent_table, x, self.n_out, self.confidence, threshold=self.threshold, mask=True))
        x = (1 + x / 700).log() if not return_hz_f0 else x

        if output_interp_target_length is not None:
            x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
            x = torch.where(x.isnan(), float(0.0), x)

        return x

    def gaussian_blurred_cent(self, cents):
        B, N, _ = cents.size()
        # combine the valid-range conditions into one boolean mask before multiplying
        # (applying `&` directly to float tensors, as in the un-parenthesized form, would fail)
        valid = ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()
        return (-(self.cent_table[None, None, :].expand(B, N, -1) - cents).square() / 1250).exp() * valid

class InferCFNaiveMelPE(torch.nn.Module):
    def __init__(self, args, state_dict):
        super().__init__()
        self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)

    def forward(self, mel, decoder_mode = "local_argmax", threshold = 0.006):
        with torch.no_grad():
            mels = rearrange(torch.stack([mel], -1), "B T C K -> (B K) T C")
            f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=1)

        return f0s

    def infer(self, mel, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False):
        f0 = self.__call__(mel, decoder_mode, threshold)
        f0_for_uv = f0

        uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
        f0 = f0 * (1 - uv)

        if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
        if f0_max is not None: f0[f0 > f0_max] = f0_max

        if output_interp_target_length is not None:
            f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
            f0 = torch.where(f0.isnan(), float(0.0), f0)

        if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
        else: return f0

class FCPEInfer_LEGACY:
    def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            self.args = DotDict(ckpt["config"])
            model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
            model.to(self.device).to(self.dtype)
            model.load_state_dict(ckpt["model"])
            model.eval()
            self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        if not self.onnx: self.model.threshold = threshold
        if not hasattr(self, "numpy_threshold") and self.onnx: self.numpy_threshold = np.array(threshold, dtype=np.float32)

        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            return torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel.detach().cpu().numpy(), self.model.get_inputs()[1].name: self.numpy_threshold})[0], dtype=self.dtype, device=self.device)
        else:
            return self.model(mel=mel, return_hz_f0=True, output_interp_target_length=p_len)

class FCPEInfer:
    def __init__(self, configs, model_path, device=None, dtype=torch.float32, providers=None, onnx=False, f0_min=50, f0_max=1100):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.onnx = onnx
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        if self.onnx:
            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3
            self.model = ort.InferenceSession(decrypt_model(configs, model_path), sess_options=sess_options, providers=providers)
        else:
            ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
            ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
            self.args = DotDict(ckpt["config_dict"])
            model = InferCFNaiveMelPE(self.args, ckpt["model"])
            self.model = model.to(device).to(self.dtype).eval()

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        if not hasattr(self, "numpy_threshold") and self.onnx: self.numpy_threshold = np.array(threshold, dtype=np.float32)
        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)

        if self.onnx:
            return torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel.detach().cpu().numpy(), self.model.get_inputs()[1].name: self.numpy_threshold})[0], dtype=self.dtype, device=self.device)
        else:
            return self.model.infer(mel, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len)

class FCPE:
    def __init__(self, configs, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, providers=None, onnx=False, legacy=False):
        self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
        self.fcpe = self.model(configs, model_path, device=device, dtype=dtype, providers=providers, onnx=onnx, f0_min=f0_min, f0_max=f0_max)
        self.hop_length = hop_length
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.legacy = legacy

    def compute_f0(self, wav, p_len=None):
        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len

        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
        f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]

        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len)
        return f0.cpu().numpy()
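
For orientation, a minimal usage sketch of the `FCPE` wrapper above. The config dict, checkpoint path, hop length, and audio below are placeholders, and the `main.library` package layout referenced by the file's imports is assumed to be importable; a real FCPE checkpoint must exist at the chosen path.

import numpy as np

# Hypothetical values: configs and model_path depend on the host project;
# for onnx=False the checkpoint must contain a "config_dict" entry.
fcpe = FCPE(configs={}, model_path="assets/models/predictors/fcpe.pt", hop_length=160, sample_rate=16000, device="cpu")

wav = np.zeros(16000, dtype=np.float32)   # one second of 16 kHz audio (silence here)
f0 = fcpe.compute_f0(wav)                 # per-frame f0 in Hz as a numpy array
print(f0.shape)                           # roughly len(wav) // hop_length frames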
infer/lib/predictors/FCPE/attentions.py ADDED
@@ -0,0 +1,448 @@
import math
import torch

import torch.nn.functional as F

from torch import nn, einsum
from functools import partial
from einops import rearrange, repeat, pack, unpack

def exists(val):
    return val is not None

def default(value, d):
    return value if exists(value) else d

def empty(tensor):
    return tensor.numel() == 0

def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple

    if m.is_integer(): return False, tensor
    return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)

    return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)

def rotate_half(x):
    x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(q, k, freqs, scale = 1):
    q_len = q.shape[-2]
    q_freqs = freqs[..., -q_len:, :]

    inv_scale = scale ** -1
    if scale.ndim == 2: scale = scale[-q_len:, :]

    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)

    return q, k

def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
    unstructured_block = torch.randn((cols, cols), device=device)

    q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
    q, r = map(lambda t: t.to(device), (q, r))

    if qr_uniform_q:
        d = r.diag(0)
        q *= d.sign()

    return q.t()

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []

    for _ in range(nb_full_blocks):
        block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])

    if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    elif scaling == 1: multiplier = math.sqrt(float(nb_columns)) * torch.ones((nb_rows,), device=device)
    else: raise ValueError(f"{scaling} != 0, 1")

    return multiplier.diag() @ torch.cat(block_list)

def linear_attention(q, k, v):
    if v is None: return einsum("...ed,...nd->...ne", k, q)
    return einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))

def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
    ratio = projection_matrix.shape[0] ** -0.5

    data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
    diag_data = (((data**2).sum(dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)

    return (ratio * ((data_dash - diag_data - data_dash.max(dim=-1, keepdim=True).values).exp() + eps) if is_query else ratio * ((data_dash - diag_data + eps).exp())).type_as(data)

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        assert not (use_xpos and not exists(scale_base))
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

    def forward(self, x):
        seq_len, device = x.shape[-2], x.device
        t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos: return freqs, torch.ones(1, device = device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')

        return freqs, torch.cat((scale, scale), dim = -1)

class LocalAttention(nn.Module):
    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
        super().__init__()
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0)
        self.scale = scale
        self.window_size = window_size
        self.autopad = autopad
        self.exact_windowsize = exact_windowsize
        self.causal = causal
        self.look_backward = look_backward
        self.look_forward = look_forward
        self.dropout = nn.Dropout(dropout)
        self.shared_qk = shared_qk
        self.rel_pos = None
        self.use_xpos = use_xpos
        if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
            if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
            self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))

    def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
        mask = default(mask, input_mask)
        assert not (exists(window_size) and not self.use_xpos)

        _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk

        (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))

        if autopad:
            orig_seq_len = q.shape[1]
            (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))

        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
        scale = default(self.scale, dim_head ** -0.5)

        assert (n % window_size) == 0
        windows = n // window_size

        if shared_qk: k = F.normalize(k, dim = -1).type(k.dtype)

        seq = torch.arange(n, device = device)
        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
        bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))

        bq = bq * scale
        look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)

        bk = look_around(bk, **look_around_kwargs)
        bv = look_around(bv, **look_around_kwargs)

        if exists(self.rel_pos):
            pos_emb, xpos_scale = self.rel_pos(bk)
            bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)

        bq_t = b_t
        bq_k = look_around(b_t, **look_around_kwargs)
        bq_t = rearrange(bq_t, '... i -> ... i 1')
        bq_k = rearrange(bq_k, '... j -> ... 1 j')

        pad_mask = bq_k == pad_value
        sim = einsum('b h i e, b h j e -> b h i j', bq, bk)

        if exists(attn_bias):
            heads = attn_bias.shape[0]
            assert (b % heads) == 0

            attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
            sim = sim + attn_bias

        mask_value = -torch.finfo(sim.dtype).max
        if shared_qk:
            self_mask = bq_t == bq_k
            sim = sim.masked_fill(self_mask, -5e4)
            del self_mask

        if causal:
            causal_mask = bq_t < bq_k
            if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
            sim = sim.masked_fill(causal_mask, mask_value)
            del causal_mask

        sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)

        if exists(mask):
            batch = mask.shape[0]
            assert (b % batch) == 0

            h = b // mask.shape[0]
            if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)

            mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
            sim = sim.masked_fill(~mask, mask_value)
            del mask

        out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')

        if autopad: out = out[:, :orig_seq_len, :]
        out, *_ = unpack(out, packed_shape, '* n d')

        return out

class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
        projection_matrix = self.create_projection()
        self.register_buffer("projection_matrix", projection_matrix)
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        self.no_projection = no_projection
        self.causal = causal

    @torch.no_grad()
    def redraw_projection_matrix(self):
        projections = self.create_projection()
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        if self.no_projection: q, k = q.softmax(dim=-1), (k.exp() if self.causal else k.softmax(dim=-2))
        else:
            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
            q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)

        # note: the causal branch refers to self.causal_linear_fn, which is never defined in this file;
        # this module is only instantiated with causal=False in this package
        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)

class SelfAttention(nn.Module):
    def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
        super().__init__()
        assert dim % heads == 0
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def redraw_projection_matrix(self):
        self.fast_attention.redraw_projection_matrix()

    def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
        cross_attend = exists(context)
        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
            if cross_attend: pass  # cross-attention is not implemented for the global heads
            else: out = self.fast_attention(q, k, v)

            attn_outs.append(out)

        if not empty(lq):
            assert (not cross_attend), "not cross_attend"

            out = self.local_attn(lq, lk, lv, input_mask=mask)
            attn_outs.append(out)

        return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
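
A small shape check for the `SelfAttention` module above (Performer-style linear attention for the global heads, with optional local windowed heads). The sizes below are arbitrary and chosen only for illustration.

import torch

attn = SelfAttention(dim=512, heads=8, dim_head=64)   # all heads global by default
x = torch.randn(2, 128, 512)                          # (batch, time, dim)
out = attn(x)                                         # linear-attention output, same shape as the input
assert out.shape == (2, 128, 512)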
infer/lib/predictors/FCPE/encoder.py ADDED
@@ -0,0 +1,183 @@
import os
import sys

import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.getcwd())

from main.library.predictors.FCPE.attentions import SelfAttention
from main.library.predictors.FCPE.utils import calc_same_padding, Transpose, GLU, Swish

class ConformerConvModule_LEGACY(nn.Module):
    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class DepthWiseConv1d_LEGACY(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        return self.conv(F.pad(x, self.padding))

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
        super().__init__()
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)

    def forward(self, x):
        return self.conv(x)

class EncoderLayer(nn.Module):
    def __init__(self, parent):
        super().__init__()
        self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
        self.norm = nn.LayerNorm(parent.dim_model)
        self.dropout = nn.Dropout(parent.residual_dropout)
        self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)

    def forward(self, phone, mask=None):
        phone = phone + (self.attn(self.norm(phone), mask=mask))
        return phone + (self.conformer(phone))

class ConformerNaiveEncoder(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.use_norm = use_norm
        self.residual_dropout = 0.1
        self.attention_dropout = 0.1
        self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for (_, layer) in enumerate(self.encoder_layers):
            x = layer(x, mask)

        return x

class CFNEncoderLayer(nn.Module):
    def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()
        self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
        self.norm = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(0.1)
        # note: SelfAttention (attentions.py) does not define a use_norm argument, so this branch
        # only works when conv_only=True (attn stays None); with conv_only=False it would raise a TypeError
        self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, use_norm=use_norm, dropout=atten_dropout) if not conv_only else None

    def forward(self, x, mask=None):
        if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
        return x + (self.conformer(x))
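
A minimal shape check for `ConformerNaiveEncoder` above; conv_only=True is used here because, as noted in the code, the attention branch passes a use_norm keyword that SelfAttention does not accept. Sizes are arbitrary.

import torch

encoder = ConformerNaiveEncoder(num_layers=2, num_heads=8, dim_model=512, conv_only=True)
x = torch.randn(1, 50, 512)    # (batch, frames, dim_model)
print(encoder(x).shape)        # torch.Size([1, 50, 512])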
infer/lib/predictors/FCPE/stft.py ADDED
@@ -0,0 +1,102 @@
import os
import sys
import torch

import numpy as np
import torch.nn.functional as F

from librosa.filters import mel

sys.path.append(os.getcwd())

class STFT:
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length

        fmax = self.fmax
        factor = 2 ** (keyshift / 12)

        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        mel_basis = self.mel_basis if not train else {}
        hann_window = self.hann_window if not train else {}
        mel_basis_key = str(fmax) + "_" + str(y.device)

        if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(mel(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)

        keyshift_key = str(keyshift) + "_" + str(y.device)
        if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)

        pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1)
        n_fft = int(np.round(n_fft * factor))

        if str(y.device).startswith(("ocl", "privateuseone")):
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT as _STFT
                self.stft = _STFT(filter_length=n_fft, hop_length=hop_length_new, win_length=win_size_new).to(y.device)

            spec = self.stft.transform(pad, 1e-9)
        else:
            spec = torch.stft(pad, n_fft, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)

        spec = (spec.real.pow(2) + spec.imag.pow(2) + 1e-9).sqrt()

        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new

        return (mel_basis[mel_basis_key] @ spec).clamp(min=self.clip_val).log()
infer/lib/predictors/FCPE/utils.py ADDED
@@ -0,0 +1,92 @@
import os
import torch

from torch import nn
from io import BytesIO
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

def decrypt_model(configs, input_path):
    with open(input_path, "rb") as f:
        data = f.read()

    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()

    return BytesIO(unpad(AES.new(key, AES.MODE_CBC, data[:16]).decrypt(data[16:]), AES.block_size)).read()

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

def torch_interp(x, xp, fp):
    sort_idx = xp.argsort()
    xp = xp[sort_idx]
    fp = fp[sort_idx]

    right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left_idxs = (right_idxs - 1).clamp(min=0)
    x_left = xp[left_idxs]
    y_left = fp[left_idxs]

    interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
    interp_vals[x < xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]

    return interp_vals

def batch_interp_with_replacement_detach(uv, f0):
    result = f0.clone()

    for i in range(uv.shape[0]):
        interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
        result[i][uv[i]] = interp_vals

    return result

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
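
`torch_interp` above mirrors `numpy.interp` for 1-D tensors, clamping values outside the sample range to the endpoint values; a quick illustration with made-up numbers:

import torch

xp = torch.tensor([0.0, 1.0, 2.0])
fp = torch.tensor([0.0, 10.0, 20.0])
x = torch.tensor([0.5, 1.5, 3.0])
print(torch_interp(x, xp, fp))   # tensor([ 5., 15., 20.])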
infer/lib/predictors/FCPE/wav2mel.py ADDED
@@ -0,0 +1,78 @@
import os
import sys
import torch

from torchaudio.transforms import Resample

sys.path.append(os.getcwd())

from main.library.predictors.FCPE.stft import STFT

class Wav2Mel:
    def __init__(self, device=None, dtype=torch.float32):
        self.sample_rate = 16000
        self.hop_size = 160

        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)

        if sample_rate == self.sample_rate: audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)

            self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)

        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
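
A shape sketch for `Wav2Mel` above, assuming librosa and torchaudio are installed and the `main.library` package layout used by the imports is importable. The output is a log-mel spectrogram with 128 bins and roughly one frame per hop (160 samples) at 16 kHz.

import torch

wav2mel = Wav2Mel(device="cpu")
audio = torch.zeros(1, 16000)               # (batch, samples) at 16 kHz
mel = wav2mel(audio, sample_rate=16000)     # (batch, frames, n_mels)
print(mel.shape)                            # about len(audio) // 160 + 1 frames, 128 mel bins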