Pj12 committed
Commit 7c31b05 · verified · 1 Parent(s): 35e2e4b

Upload 4 files
extract_feature_print.py ADDED
@@ -0,0 +1,301 @@
+ import os, sys, traceback
+ from transformers import HubertModel
+ import librosa
+ from torch import nn
+ import torch
+
+ import json
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
+
+ device = sys.argv[1]
+ n_part = int(sys.argv[2])
+ i_part = int(sys.argv[3])
+ if len(sys.argv) == 6:
+     exp_dir = sys.argv[4]
+     version = sys.argv[5]
+ else:
+     i_gpu = sys.argv[4]
+     exp_dir = sys.argv[5]
+     os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
+     version = sys.argv[6]
+ import torch.nn.functional as F
+ import soundfile as sf
+ import numpy as np
+ from fairseq import checkpoint_utils
+
+ # device = "cpu"
+ if torch.cuda.is_available():
+     device = "cuda"
+ elif torch.backends.mps.is_available():
+     device = "mps"
+
+ version_config_paths = [
+     os.path.join("", "32k.json"),
+     os.path.join("", "40k.json"),
+     os.path.join("", "48k.json"),
+     os.path.join("", "48k_v2.json"),
+     os.path.join("", "32k_v2.json"),
+ ]
+
+
+ class Config:
+     def __init__(self):
+         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         self.is_half = self.device != "cpu"
+         self.gpu_name = (
+             torch.cuda.get_device_name(int(self.device.split(":")[-1]))
+             if self.device.startswith("cuda")
+             else None
+         )
+         self.json_config = self.load_config_json()
+         self.gpu_mem = None
+         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
+
+     def load_config_json(self) -> dict:
+         configs = {}
+         for config_file in version_config_paths:
+             config_path = os.path.join("configs", config_file)
+             with open(config_path, "r") as f:
+                 configs[config_file] = json.load(f)
+         return configs
+
+     def has_mps(self) -> bool:
+         # Check if Metal Performance Shaders are available (macOS 12.3+).
+         return torch.backends.mps.is_available()
+
+     def has_xpu(self) -> bool:
+         # Check if an Intel XPU is available.
+         return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+     def set_precision(self, precision):
+         if precision not in ["fp32", "fp16"]:
+             raise ValueError("Invalid precision type. Must be 'fp32' or 'fp16'.")
+
+         fp16_run_value = precision == "fp16"
+         preprocess_target_version = "3.7" if precision == "fp16" else "3.0"
+         preprocess_path = os.path.join(
+             os.path.dirname(__file__),
+             os.pardir,
+             "preprocess.py",
+         )
+
+         for config_path in version_config_paths:
+             full_config_path = os.path.join("configs", config_path)
+             try:
+                 with open(full_config_path, "r") as f:
+                     config = json.load(f)
+                 config["train"]["fp16_run"] = fp16_run_value
+                 with open(full_config_path, "w") as f:
+                     json.dump(config, f, indent=4)
+             except FileNotFoundError:
+                 print(f"File not found: {full_config_path}")
+
+         if os.path.exists(preprocess_path):
+             with open(preprocess_path, "r") as f:
+                 preprocess_content = f.read()
+             preprocess_content = preprocess_content.replace(
+                 "3.0" if precision == "fp16" else "3.7", preprocess_target_version
+             )
+             with open(preprocess_path, "w") as f:
+                 f.write(preprocess_content)
+
+         return f"Overwrote preprocess.py and the configs to use {precision}."
+
+     def get_precision(self):
+         if not version_config_paths:
+             raise FileNotFoundError("No configuration paths provided.")
+
+         full_config_path = os.path.join("configs", version_config_paths[0])
+         try:
+             with open(full_config_path, "r") as f:
+                 config = json.load(f)
+             fp16_run_value = config["train"].get("fp16_run", False)
+             precision = "fp16" if fp16_run_value else "fp32"
+             return precision
+         except FileNotFoundError:
+             print(f"File not found: {full_config_path}")
+             return None
+
+     def device_config(self) -> tuple:
+         if self.device.startswith("cuda"):
+             self.set_cuda_config()
+         elif self.has_mps():
+             self.device = "mps"
+             self.is_half = False
+             self.set_precision("fp32")
+         else:
+             self.device = "cpu"
+             self.is_half = False
+             self.set_precision("fp32")
+
+         # Configuration for 6GB GPU memory
+         x_pad, x_query, x_center, x_max = (
+             (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
+         )
+         if self.gpu_mem is not None and self.gpu_mem <= 4:
+             # Configuration for 5GB GPU memory
+             x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
+
+         return x_pad, x_query, x_center, x_max
+
+     def set_cuda_config(self):
+         i_device = int(self.device.split(":")[-1])
+         self.gpu_name = torch.cuda.get_device_name(i_device)
+         low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
+         if (
+             any(gpu in self.gpu_name for gpu in low_end_gpus)
+             and "V100" not in self.gpu_name.upper()
+         ):
+             self.is_half = False
+             self.set_precision("fp32")
+
+         self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (
+             1024**3
+         )
+
+
+ config = Config()
+
+
+ def load_audio(file, sample_rate):
+     try:
+         file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
+         audio, sr = sf.read(file)
+         if len(audio.shape) > 1:
+             audio = librosa.to_mono(audio.T)
+         if sr != sample_rate:
+             audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
+     except Exception as error:
+         raise RuntimeError(f"An error occurred loading the audio: {error}")
+
+     return audio.flatten()
+
+
+ class HubertModelWithFinalProj(HubertModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
+         print(config.hidden_size, config.classifier_proj_size)
+
+
+ f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
+
+
+ def printt(strr):
+     print(strr)
+     f.write("%s\n" % strr)
+     f.flush()
+
+
+ printt(sys.argv)
+ model_path = sys.argv[7]
+ Custom_Embed = False
+ sample_embedding = sys.argv[8]
+ if os.path.split(model_path)[-1] == "Custom" and sample_embedding == "hubert_base":
+     model_path = "hubert_base.pt"
+     Custom_Embed = True
+ elif os.path.split(model_path)[-1] == "Custom" and sample_embedding == "contentvec_base":
+     model_path = "contentvec_base.pt"
+     Custom_Embed = True
+ elif os.path.split(model_path)[-1] == "Custom" and sample_embedding == "hubert_base_japanese":
+     model_path = "japanese_hubert_base.pt"
+     Custom_Embed = True
+ elif os.path.split(model_path)[-1] == "Custom" and sample_embedding == "hubert_large_ll60k":
+     model_path = "hubert_large_ll60k.pt"
+     Custom_Embed = True
+
+ printt(exp_dir)
+ wavPath = "%s/1_16k_wavs" % exp_dir
+ outPath = (
+     "%s/3_feature256" % exp_dir
+     if version == "v1"
+     else "%s/3_feature768" % exp_dir
+     if version == "v2" and sample_embedding != "hubert_large_ll60k"
+     else "%s/3_feature1024" % exp_dir
+ )
+ os.makedirs(outPath, exist_ok=True)
+
+
+ # Input waves must be 16 kHz; the HuBERT hop size is 320 samples.
+ def readwave(wav_path, normalize=False):
+     wav, sr = sf.read(wav_path)
+     assert sr == 16000
+     if not Custom_Embed:
+         feats = torch.from_numpy(wav).float()
+     else:
+         feats = torch.from_numpy(load_audio(wav_path, sr)).to(dtype).to(device)
+     if feats.dim() == 2:  # stereo: average the two channels
+         feats = feats.mean(-1)
+     assert feats.dim() == 1, feats.dim()
+     if normalize:
+         with torch.no_grad():
+             feats = F.layer_norm(feats, feats.shape)
+     feats = feats.view(1, -1)
+     return feats
+
+
+ # HuBERT model
+ printt("load model(s) from {}".format(model_path))
+ # Abort if the embedder checkpoint is missing.
+ if not os.path.exists(model_path):
+     printt(
+         "Error: extraction stopped because %s does not exist; you can download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
+         % model_path
+     )
+     exit(0)
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([model_path])
+ if not Custom_Embed:
+     model = models[0]
+     if device not in ["mps", "cpu"]:
+         model = model.half()
+ else:
+     # hubert_large_ll60k and the other custom embedders currently load the same
+     # local Hugging Face checkpoint from the "Custom/" directory.
+     dtype = torch.float16 if config.is_half and "cuda" in device else torch.float32
+     model = HubertModelWithFinalProj.from_pretrained("Custom/").to(dtype).to(device)
+ model = model.to(device)
+ printt("move model to %s" % device)
+ model.eval()
+
+ todo = sorted(list(os.listdir(wavPath)))[i_part::n_part]
+ n = max(1, len(todo) // 10)
+ if len(todo) == 0:
+     printt("no-feature-todo")
+ else:
+     printt("all-feature-%s" % len(todo))
+     for idx, file in enumerate(todo):
+         try:
+             if file.endswith(".wav"):
+                 wav_path = "%s/%s" % (wavPath, file)
+                 out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))
+
+                 if os.path.exists(out_path):
+                     continue
+
+                 feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
+                 padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+                 inputs = {
+                     "source": feats.half().to(device)
+                     if device not in ["mps", "cpu"]
+                     else feats.to(device),
+                     "padding_mask": padding_mask.to(device),
+                     # output layer: 9 for v1, 12 for v2, 24 for hubert_large_ll60k
+                     "output_layer": 9 if version == "v1" else 12 if sample_embedding != "hubert_large_ll60k" else 24,
+                 }
+                 with torch.no_grad():
+                     if not Custom_Embed:
+                         logits = model.extract_features(**inputs)
+                         feats = (
+                             model.final_proj(logits[0]) if version == "v1" else logits[0]
+                         )
+                     else:
+                         feats = model(feats)["last_hidden_state"]
+                         feats = (
+                             model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
+                         )
+
+                 feats = feats.squeeze(0).float().cpu().numpy()
+                 if np.isnan(feats).sum() == 0:
+                     np.save(out_path, feats, allow_pickle=False)
+                 else:
+                     printt("%s-contains nan" % file)
+                 if idx % n == 0:
+                     printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
+         except Exception:
+             printt(traceback.format_exc())
+ printt("all-feature-done")
models.py ADDED
@@ -0,0 +1,1410 @@
+ import math, pdb, os
+ from time import time as ttime
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from lib.infer_pack import modules
+ from lib.infer_pack import attentions
+ from lib.infer_pack import commons
+ from lib.infer_pack.commons import init_weights, get_padding
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+ import numpy as np
+
+
+ class TextEncoder256(nn.Module):
+     def __init__(
+         self,
+         out_channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         f0=True,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.emb_phone = nn.Linear(256, hidden_channels)
+         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
+         if f0:
+             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
+         self.encoder = attentions.Encoder(
+             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, phone, pitch, lengths):
+         if pitch is None:
+             x = self.emb_phone(phone)
+         else:
+             x = self.emb_phone(phone) + self.emb_pitch(pitch)
+         x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
+         x = self.lrelu(x)
+         x = torch.transpose(x, 1, -1)  # [b, h, t]
+         x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+         x = self.encoder(x * x_mask, x_mask)
+         stats = self.proj(x) * x_mask
+
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         return m, logs, x_mask
+
+
+ class TextEncoder768(nn.Module):
+     def __init__(
+         self,
+         out_channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         f0=True,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.emb_phone = nn.Linear(768, hidden_channels)
+         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
+         if f0:
+             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
+         self.encoder = attentions.Encoder(
+             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, phone, pitch, lengths):
+         if pitch is None:
+             x = self.emb_phone(phone)
+         else:
+             x = self.emb_phone(phone) + self.emb_pitch(pitch)
+         x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
+         x = self.lrelu(x)
+         x = torch.transpose(x, 1, -1)  # [b, h, t]
+         x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+         x = self.encoder(x * x_mask, x_mask)
+         stats = self.proj(x) * x_mask
+
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         return m, logs, x_mask
+
+
+ class TextEncoder1024(nn.Module):
+     def __init__(
+         self,
+         out_channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         f0=True,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.emb_phone = nn.Linear(1024, hidden_channels)
+         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
+         if f0:
+             self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
+         self.encoder = attentions.Encoder(
+             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, phone, pitch, lengths):
+         if pitch is None:
+             x = self.emb_phone(phone)
+         else:
+             x = self.emb_phone(phone) + self.emb_pitch(pitch)
+         x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
+         x = self.lrelu(x)
+         x = torch.transpose(x, 1, -1)  # [b, h, t]
+         x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+         x = self.encoder(x * x_mask, x_mask)
+         stats = self.proj(x) * x_mask
+
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         return m, logs, x_mask
+
+
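The three text encoders above are identical except for the `nn.Linear` input width (256, 768, or 1024, matching v1 projected HuBERT features, v2 HuBERT-base features, and HuBERT-large features). A single width-parameterized class could replace all three; the sketch below is an illustration of that observation, not part of the uploaded file:

# Sketch only: a width-parameterized encoder equivalent to TextEncoder256/768/1024.
# Assumes the same lib.infer_pack imports used at the top of models.py.
class TextEncoderN(nn.Module):
    def __init__(self, in_dim, out_channels, hidden_channels, filter_channels,
                 n_heads, n_layers, kernel_size, p_dropout, f0=True):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.emb_phone = nn.Linear(in_dim, hidden_channels)  # the only per-width difference
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        x = self.emb_phone(phone)
        if pitch is not None:
            x = x + self.emb_pitch(pitch)
        x = self.lrelu(x * math.sqrt(self.hidden_channels)).transpose(1, -1)
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        x = self.encoder(x * x_mask, x_mask)
        m, logs = torch.split(self.proj(x) * x_mask, self.out_channels, dim=1)
        return m, logs, x_mask

The explicit classes are kept in the upload itself so that external RVC code importing them by name keeps working.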
+ class ResidualCouplingBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         n_flows=4,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.flows = nn.ModuleList()
+         for i in range(n_flows):
+             self.flows.append(
+                 modules.ResidualCouplingLayer(
+                     channels,
+                     hidden_channels,
+                     kernel_size,
+                     dilation_rate,
+                     n_layers,
+                     gin_channels=gin_channels,
+                     mean_only=True,
+                 )
+             )
+             self.flows.append(modules.Flip())
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         if not reverse:
+             for flow in self.flows:
+                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
+         else:
+             for flow in reversed(self.flows):
+                 x = flow(x, x_mask, g=g, reverse=reverse)
+         return x
+
+     def remove_weight_norm(self):
+         for i in range(self.n_flows):
+             self.flows[i * 2].remove_weight_norm()
+
+
+ class PosteriorEncoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.gin_channels = gin_channels
+
+         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+         self.enc = modules.WN(
+             hidden_channels,
+             kernel_size,
+             dilation_rate,
+             n_layers,
+             gin_channels=gin_channels,
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, x, x_lengths, g=None):
+         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+         x = self.pre(x) * x_mask
+         x = self.enc(x, x_mask, g=g)
+         stats = self.proj(x) * x_mask
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+         return z, m, logs, x_mask
+
+     def remove_weight_norm(self):
+         self.enc.remove_weight_norm()
+
+
+ class Generator(torch.nn.Module):
+     def __init__(
+         self,
+         initial_channel,
+         resblock,
+         resblock_kernel_sizes,
+         resblock_dilation_sizes,
+         upsample_rates,
+         upsample_initial_channel,
+         upsample_kernel_sizes,
+         gin_channels=0,
+     ):
+         super(Generator, self).__init__()
+         self.num_kernels = len(resblock_kernel_sizes)
+         self.num_upsamples = len(upsample_rates)
+         self.conv_pre = Conv1d(
+             initial_channel, upsample_initial_channel, 7, 1, padding=3
+         )
+         resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+             self.ups.append(
+                 weight_norm(
+                     ConvTranspose1d(
+                         upsample_initial_channel // (2**i),
+                         upsample_initial_channel // (2 ** (i + 1)),
+                         k,
+                         u,
+                         padding=(k - u) // 2,
+                     )
+                 )
+             )
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(
+                 zip(resblock_kernel_sizes, resblock_dilation_sizes)
+             ):
+                 self.resblocks.append(resblock(ch, k, d))
+
+         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+         self.ups.apply(init_weights)
+
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+     def forward(self, x, g=None):
+         x = self.conv_pre(x)
+         if g is not None:
+             x = x + self.cond(g)
+
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+
+
+ class SineGen(torch.nn.Module):
+     """Definition of the sine generator
+     SineGen(samp_rate, harmonic_num=0,
+             sine_amp=0.1, noise_std=0.003,
+             voiced_threshold=0,
+             flag_for_pulse=False)
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of the sine waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: this SineGen is used inside PulseGen (default False)
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(
+         self,
+         samp_rate,
+         harmonic_num=0,
+         sine_amp=0.1,
+         noise_std=0.003,
+         voiced_threshold=0,
+         flag_for_pulse=False,
+     ):
+         super(SineGen, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.dim = self.harmonic_num + 1
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+
+     def _f02uv(self, f0):
+         # generate the voiced/unvoiced (uv) signal
+         uv = torch.ones_like(f0)
+         uv = uv * (f0 > self.voiced_threshold)
+         return uv
+
+     def forward(self, f0, upp):
+         """sine_tensor, uv = forward(f0)
+         input F0: tensor(batchsize=1, length, dim=1)
+         f0 for unvoiced steps should be 0
+         output sine_tensor: tensor(batchsize=1, length, dim)
+         output uv: tensor(batchsize=1, length, 1)
+         """
+         with torch.no_grad():
+             f0 = f0[:, None].transpose(1, 2)
+             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
+             # fundamental component
+             f0_buf[:, :, 0] = f0[:, :, 0]
+             for idx in np.arange(self.harmonic_num):
+                 f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
+                     idx + 2
+                 )  # idx + 2: the (idx+1)-th overtone, i.e. the (idx+2)-th harmonic
+             # the % 1 here means the harmonic products can no longer be optimized away later
+             rad_values = (f0_buf / self.sampling_rate) % 1
+             rand_ini = torch.rand(
+                 f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
+             )
+             rand_ini[:, 0] = 0
+             rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+             # no % 1 here: applying % 1 would keep the cumsum below from being optimized
+             tmp_over_one = torch.cumsum(rad_values, 1)
+             tmp_over_one *= upp
+             tmp_over_one = F.interpolate(
+                 tmp_over_one.transpose(2, 1),
+                 scale_factor=upp,
+                 mode="linear",
+                 align_corners=True,
+             ).transpose(2, 1)
+             rad_values = F.interpolate(
+                 rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
+             ).transpose(2, 1)
+             tmp_over_one %= 1
+             tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
+             cumsum_shift = torch.zeros_like(rad_values)
+             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+             sine_waves = torch.sin(
+                 torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
+             )
+             sine_waves = sine_waves * self.sine_amp
+             uv = self._f02uv(f0)
+             uv = F.interpolate(
+                 uv.transpose(2, 1), scale_factor=upp, mode="nearest"
+             ).transpose(2, 1)
+             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+             noise = noise_amp * torch.randn_like(sine_waves)
+             sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
+
+
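As the docstring describes, SineGen turns a frame-rate F0 contour into a sample-rate sine-plus-noise excitation, with `upp` being the upsampling factor (hop size). A minimal shape check with dummy values, all numbers illustrative:

# Sanity-check sketch for SineGen: 100 F0 frames upsampled by 400
# (e.g. hop 400 at 40 kHz); not part of the uploaded file.
sine_gen = SineGen(samp_rate=40000, harmonic_num=0)
f0 = torch.full((1, 100), 220.0)   # 100 voiced frames at 220 Hz
f0[:, 50:] = 0.0                   # second half unvoiced (F0 = 0)
sine, uv, noise = sine_gen(f0, upp=400)
print(sine.shape, uv.shape)        # torch.Size([1, 40000, 1]) torch.Size([1, 40000, 1])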
419
+ class SourceModuleHnNSF(torch.nn.Module):
420
+ """SourceModule for hn-nsf
421
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
422
+ add_noise_std=0.003, voiced_threshod=0)
423
+ sampling_rate: sampling_rate in Hz
424
+ harmonic_num: number of harmonic above F0 (default: 0)
425
+ sine_amp: amplitude of sine source signal (default: 0.1)
426
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
427
+ note that amplitude of noise in unvoiced is decided
428
+ by sine_amp
429
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
430
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
431
+ F0_sampled (batchsize, length, 1)
432
+ Sine_source (batchsize, length, 1)
433
+ noise_source (batchsize, length 1)
434
+ uv (batchsize, length, 1)
435
+ """
436
+
437
+ def __init__(
438
+ self,
439
+ sampling_rate,
440
+ harmonic_num=0,
441
+ sine_amp=0.1,
442
+ add_noise_std=0.003,
443
+ voiced_threshod=0,
444
+ is_half=True,
445
+ ):
446
+ super(SourceModuleHnNSF, self).__init__()
447
+
448
+ self.sine_amp = sine_amp
449
+ self.noise_std = add_noise_std
450
+ self.is_half = is_half
451
+ # to produce sine waveforms
452
+ self.l_sin_gen = SineGen(
453
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
454
+ )
455
+
456
+ # to merge source harmonics into a single excitation
457
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
458
+ self.l_tanh = torch.nn.Tanh()
459
+
460
+ def forward(self, x, upp=None):
461
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
462
+ if self.is_half:
463
+ sine_wavs = sine_wavs.half()
464
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
465
+ return sine_merge, None, None # noise, uv
466
+
467
+
468
+ class GeneratorNSF(torch.nn.Module):
469
+ def __init__(
470
+ self,
471
+ initial_channel,
472
+ resblock,
473
+ resblock_kernel_sizes,
474
+ resblock_dilation_sizes,
475
+ upsample_rates,
476
+ upsample_initial_channel,
477
+ upsample_kernel_sizes,
478
+ gin_channels,
479
+ sr,
480
+ is_half=False,
481
+ ):
482
+ super(GeneratorNSF, self).__init__()
483
+ self.num_kernels = len(resblock_kernel_sizes)
484
+ self.num_upsamples = len(upsample_rates)
485
+
486
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
487
+ self.m_source = SourceModuleHnNSF(
488
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
489
+ )
490
+ self.noise_convs = nn.ModuleList()
491
+ self.conv_pre = Conv1d(
492
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
493
+ )
494
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
495
+
496
+ self.ups = nn.ModuleList()
497
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
498
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
499
+ self.ups.append(
500
+ weight_norm(
501
+ ConvTranspose1d(
502
+ upsample_initial_channel // (2**i),
503
+ upsample_initial_channel // (2 ** (i + 1)),
504
+ k,
505
+ u,
506
+ padding=(k - u) // 2,
507
+ )
508
+ )
509
+ )
510
+ if i + 1 < len(upsample_rates):
511
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
512
+ self.noise_convs.append(
513
+ Conv1d(
514
+ 1,
515
+ c_cur,
516
+ kernel_size=stride_f0 * 2,
517
+ stride=stride_f0,
518
+ padding=stride_f0 // 2,
519
+ )
520
+ )
521
+ else:
522
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
523
+
524
+ self.resblocks = nn.ModuleList()
525
+ for i in range(len(self.ups)):
526
+ ch = upsample_initial_channel // (2 ** (i + 1))
527
+ for j, (k, d) in enumerate(
528
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
529
+ ):
530
+ self.resblocks.append(resblock(ch, k, d))
531
+
532
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
533
+ self.ups.apply(init_weights)
534
+
535
+ if gin_channels != 0:
536
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
537
+
538
+ self.upp = np.prod(upsample_rates)
539
+
540
+ def forward(self, x, f0, g=None):
541
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
542
+ har_source = har_source.transpose(1, 2)
543
+ x = self.conv_pre(x)
544
+ if g is not None:
545
+ x = x + self.cond(g)
546
+
547
+ for i in range(self.num_upsamples):
548
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
549
+ x = self.ups[i](x)
550
+ x_source = self.noise_convs[i](har_source)
551
+ x = x + x_source
552
+ xs = None
553
+ for j in range(self.num_kernels):
554
+ if xs is None:
555
+ xs = self.resblocks[i * self.num_kernels + j](x)
556
+ else:
557
+ xs += self.resblocks[i * self.num_kernels + j](x)
558
+ x = xs / self.num_kernels
559
+ x = F.leaky_relu(x)
560
+ x = self.conv_post(x)
561
+ x = torch.tanh(x)
562
+ return x
563
+
564
+ def remove_weight_norm(self):
565
+ for l in self.ups:
566
+ remove_weight_norm(l)
567
+ for l in self.resblocks:
568
+ l.remove_weight_norm()
569
+
570
+
571
+ sr2sr = {
572
+ "32k": 32000,
573
+ "40k": 40000,
574
+ "48k": 48000,
575
+ }
576
+
577
+
578
+ class SynthesizerTrnMs256NSFsid(nn.Module):
579
+ def __init__(
580
+ self,
581
+ spec_channels,
582
+ segment_size,
583
+ inter_channels,
584
+ hidden_channels,
585
+ filter_channels,
586
+ n_heads,
587
+ n_layers,
588
+ kernel_size,
589
+ p_dropout,
590
+ resblock,
591
+ resblock_kernel_sizes,
592
+ resblock_dilation_sizes,
593
+ upsample_rates,
594
+ upsample_initial_channel,
595
+ upsample_kernel_sizes,
596
+ spk_embed_dim,
597
+ gin_channels,
598
+ sr,
599
+ **kwargs
600
+ ):
601
+ super().__init__()
602
+ if type(sr) == type("strr"):
603
+ sr = sr2sr[sr]
604
+ self.spec_channels = spec_channels
605
+ self.inter_channels = inter_channels
606
+ self.hidden_channels = hidden_channels
607
+ self.filter_channels = filter_channels
608
+ self.n_heads = n_heads
609
+ self.n_layers = n_layers
610
+ self.kernel_size = kernel_size
611
+ self.p_dropout = p_dropout
612
+ self.resblock = resblock
613
+ self.resblock_kernel_sizes = resblock_kernel_sizes
614
+ self.resblock_dilation_sizes = resblock_dilation_sizes
615
+ self.upsample_rates = upsample_rates
616
+ self.upsample_initial_channel = upsample_initial_channel
617
+ self.upsample_kernel_sizes = upsample_kernel_sizes
618
+ self.segment_size = segment_size
619
+ self.gin_channels = gin_channels
620
+ # self.hop_length = hop_length#
621
+ self.spk_embed_dim = spk_embed_dim
622
+ self.enc_p = TextEncoder256(
623
+ inter_channels,
624
+ hidden_channels,
625
+ filter_channels,
626
+ n_heads,
627
+ n_layers,
628
+ kernel_size,
629
+ p_dropout,
630
+ )
631
+ self.dec = GeneratorNSF(
632
+ inter_channels,
633
+ resblock,
634
+ resblock_kernel_sizes,
635
+ resblock_dilation_sizes,
636
+ upsample_rates,
637
+ upsample_initial_channel,
638
+ upsample_kernel_sizes,
639
+ gin_channels=gin_channels,
640
+ sr=sr,
641
+ is_half=kwargs["is_half"],
642
+ )
643
+ self.enc_q = PosteriorEncoder(
644
+ spec_channels,
645
+ inter_channels,
646
+ hidden_channels,
647
+ 5,
648
+ 1,
649
+ 16,
650
+ gin_channels=gin_channels,
651
+ )
652
+ self.flow = ResidualCouplingBlock(
653
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
654
+ )
655
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
656
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
657
+
658
+ def remove_weight_norm(self):
659
+ self.dec.remove_weight_norm()
660
+ self.flow.remove_weight_norm()
661
+ self.enc_q.remove_weight_norm()
662
+
663
+ def forward(
664
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
665
+ ): # 这里ds是id,[bs,1]
666
+ # print(1,pitch.shape)#[bs,t]
667
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
668
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
669
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
670
+ z_p = self.flow(z, y_mask, g=g)
671
+ z_slice, ids_slice = commons.rand_slice_segments(
672
+ z, y_lengths, self.segment_size
673
+ )
674
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
675
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
676
+ # print(-2,pitchf.shape,z_slice.shape)
677
+ o = self.dec(z_slice, pitchf, g=g)
678
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
679
+
680
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
681
+ g = self.emb_g(sid).unsqueeze(-1)
682
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
683
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
684
+ if rate:
685
+ head = int(z_p.shape[2] * rate)
686
+ z_p = z_p[:, :, -head:]
687
+ x_mask = x_mask[:, :, -head:]
688
+ nsff0 = nsff0[:, -head:]
689
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
690
+ o = self.dec(z * x_mask, nsff0, g=g)
691
+ return o, x_mask, (z, z_p, m_p, logs_p)
692
+
693
+
694
+ class SynthesizerTrnMs768NSFsid(nn.Module):
695
+ def __init__(
696
+ self,
697
+ spec_channels,
698
+ segment_size,
699
+ inter_channels,
700
+ hidden_channels,
701
+ filter_channels,
702
+ n_heads,
703
+ n_layers,
704
+ kernel_size,
705
+ p_dropout,
706
+ resblock,
707
+ resblock_kernel_sizes,
708
+ resblock_dilation_sizes,
709
+ upsample_rates,
710
+ upsample_initial_channel,
711
+ upsample_kernel_sizes,
712
+ spk_embed_dim,
713
+ gin_channels,
714
+ sr,
715
+ **kwargs
716
+ ):
717
+ super().__init__()
718
+ if type(sr) == type("strr"):
719
+ sr = sr2sr[sr]
720
+ self.spec_channels = spec_channels
721
+ self.inter_channels = inter_channels
722
+ self.hidden_channels = hidden_channels
723
+ self.filter_channels = filter_channels
724
+ self.n_heads = n_heads
725
+ self.n_layers = n_layers
726
+ self.kernel_size = kernel_size
727
+ self.p_dropout = p_dropout
728
+ self.resblock = resblock
729
+ self.resblock_kernel_sizes = resblock_kernel_sizes
730
+ self.resblock_dilation_sizes = resblock_dilation_sizes
731
+ self.upsample_rates = upsample_rates
732
+ self.upsample_initial_channel = upsample_initial_channel
733
+ self.upsample_kernel_sizes = upsample_kernel_sizes
734
+ self.segment_size = segment_size
735
+ self.gin_channels = gin_channels
736
+ # self.hop_length = hop_length#
737
+ self.spk_embed_dim = spk_embed_dim
738
+ self.enc_p = TextEncoder768(
739
+ inter_channels,
740
+ hidden_channels,
741
+ filter_channels,
742
+ n_heads,
743
+ n_layers,
744
+ kernel_size,
745
+ p_dropout,
746
+ )
747
+ self.dec = GeneratorNSF(
748
+ inter_channels,
749
+ resblock,
750
+ resblock_kernel_sizes,
751
+ resblock_dilation_sizes,
752
+ upsample_rates,
753
+ upsample_initial_channel,
754
+ upsample_kernel_sizes,
755
+ gin_channels=gin_channels,
756
+ sr=sr,
757
+ is_half=kwargs["is_half"],
758
+ )
759
+ self.enc_q = PosteriorEncoder(
760
+ spec_channels,
761
+ inter_channels,
762
+ hidden_channels,
763
+ 5,
764
+ 1,
765
+ 16,
766
+ gin_channels=gin_channels,
767
+ )
768
+ self.flow = ResidualCouplingBlock(
769
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
770
+ )
771
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
772
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
773
+
774
+ def remove_weight_norm(self):
775
+ self.dec.remove_weight_norm()
776
+ self.flow.remove_weight_norm()
777
+ self.enc_q.remove_weight_norm()
778
+
779
+ def forward(
780
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
781
+ ): # 这里ds是id,[bs,1]
782
+ # print(1,pitch.shape)#[bs,t]
783
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
784
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
785
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
786
+ z_p = self.flow(z, y_mask, g=g)
787
+ z_slice, ids_slice = commons.rand_slice_segments(
788
+ z, y_lengths, self.segment_size
789
+ )
790
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
791
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
792
+ # print(-2,pitchf.shape,z_slice.shape)
793
+ o = self.dec(z_slice, pitchf, g=g)
794
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
795
+
796
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
797
+ g = self.emb_g(sid).unsqueeze(-1)
798
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
799
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
800
+ if rate:
801
+ head = int(z_p.shape[2] * rate)
802
+ z_p = z_p[:, :, -head:]
803
+ x_mask = x_mask[:, :, -head:]
804
+ nsff0 = nsff0[:, -head:]
805
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
806
+ o = self.dec(z * x_mask, nsff0, g=g)
807
+ return o, x_mask, (z, z_p, m_p, logs_p)
808
+
809
+ class SynthesizerTrnMs1024NSFsid(nn.Module):
810
+ def __init__(
811
+ self,
812
+ spec_channels,
813
+ segment_size,
814
+ inter_channels,
815
+ hidden_channels,
816
+ filter_channels,
817
+ n_heads,
818
+ n_layers,
819
+ kernel_size,
820
+ p_dropout,
821
+ resblock,
822
+ resblock_kernel_sizes,
823
+ resblock_dilation_sizes,
824
+ upsample_rates,
825
+ upsample_initial_channel,
826
+ upsample_kernel_sizes,
827
+ spk_embed_dim,
828
+ gin_channels,
829
+ sr,
830
+ **kwargs
831
+ ):
832
+ super().__init__()
833
+ if type(sr) == type("strr"):
834
+ sr = sr2sr[sr]
835
+ self.spec_channels = spec_channels
836
+ self.inter_channels = inter_channels
837
+ self.hidden_channels = hidden_channels
838
+ self.filter_channels = filter_channels
839
+ self.n_heads = n_heads
840
+ self.n_layers = n_layers
841
+ self.kernel_size = kernel_size
842
+ self.p_dropout = p_dropout
843
+ self.resblock = resblock
844
+ self.resblock_kernel_sizes = resblock_kernel_sizes
845
+ self.resblock_dilation_sizes = resblock_dilation_sizes
846
+ self.upsample_rates = upsample_rates
847
+ self.upsample_initial_channel = upsample_initial_channel
848
+ self.upsample_kernel_sizes = upsample_kernel_sizes
849
+ self.segment_size = segment_size
850
+ self.gin_channels = gin_channels
851
+ # self.hop_length = hop_length#
852
+ self.spk_embed_dim = spk_embed_dim
853
+ self.enc_p = TextEncoder1024(
854
+ inter_channels,
855
+ hidden_channels,
856
+ filter_channels,
857
+ n_heads,
858
+ n_layers,
859
+ kernel_size,
860
+ p_dropout,
861
+ )
862
+ self.dec = GeneratorNSF(
863
+ inter_channels,
864
+ resblock,
865
+ resblock_kernel_sizes,
866
+ resblock_dilation_sizes,
867
+ upsample_rates,
868
+ upsample_initial_channel,
869
+ upsample_kernel_sizes,
870
+ gin_channels=gin_channels,
871
+ sr=sr,
872
+ is_half=kwargs["is_half"],
873
+ )
874
+ self.enc_q = PosteriorEncoder(
875
+ spec_channels,
876
+ inter_channels,
877
+ hidden_channels,
878
+ 5,
879
+ 1,
880
+ 16,
881
+ gin_channels=gin_channels,
882
+ )
883
+ self.flow = ResidualCouplingBlock(
884
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
885
+ )
886
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
887
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
888
+
889
+ def remove_weight_norm(self):
890
+ self.dec.remove_weight_norm()
891
+ self.flow.remove_weight_norm()
892
+ self.enc_q.remove_weight_norm()
893
+
894
+ def forward(
895
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
896
+ ): # 这里ds是id,[bs,1]
897
+ # print(1,pitch.shape)#[bs,t]
898
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
899
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
900
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
901
+ z_p = self.flow(z, y_mask, g=g)
902
+ z_slice, ids_slice = commons.rand_slice_segments(
903
+ z, y_lengths, self.segment_size
904
+ )
905
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
906
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
907
+ # print(-2,pitchf.shape,z_slice.shape)
908
+ o = self.dec(z_slice, pitchf, g=g)
909
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
910
+
911
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
912
+ g = self.emb_g(sid).unsqueeze(-1)
913
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
914
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
915
+ if rate:
916
+ head = int(z_p.shape[2] * rate)
917
+ z_p = z_p[:, :, -head:]
918
+ x_mask = x_mask[:, :, -head:]
919
+ nsff0 = nsff0[:, -head:]
920
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
921
+ o = self.dec(z * x_mask, nsff0, g=g)
922
+ return o, x_mask, (z, z_p, m_p, logs_p)
923
+
924
+
925
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
926
+ def __init__(
927
+ self,
928
+ spec_channels,
929
+ segment_size,
930
+ inter_channels,
931
+ hidden_channels,
932
+ filter_channels,
933
+ n_heads,
934
+ n_layers,
935
+ kernel_size,
936
+ p_dropout,
937
+ resblock,
938
+ resblock_kernel_sizes,
939
+ resblock_dilation_sizes,
940
+ upsample_rates,
941
+ upsample_initial_channel,
942
+ upsample_kernel_sizes,
943
+ spk_embed_dim,
944
+ gin_channels,
945
+ sr=None,
946
+ **kwargs
947
+ ):
948
+ super().__init__()
949
+ self.spec_channels = spec_channels
950
+ self.inter_channels = inter_channels
951
+ self.hidden_channels = hidden_channels
952
+ self.filter_channels = filter_channels
953
+ self.n_heads = n_heads
954
+ self.n_layers = n_layers
955
+ self.kernel_size = kernel_size
956
+ self.p_dropout = p_dropout
957
+ self.resblock = resblock
958
+ self.resblock_kernel_sizes = resblock_kernel_sizes
959
+ self.resblock_dilation_sizes = resblock_dilation_sizes
960
+ self.upsample_rates = upsample_rates
961
+ self.upsample_initial_channel = upsample_initial_channel
962
+ self.upsample_kernel_sizes = upsample_kernel_sizes
963
+ self.segment_size = segment_size
964
+ self.gin_channels = gin_channels
965
+ # self.hop_length = hop_length#
966
+ self.spk_embed_dim = spk_embed_dim
967
+ self.enc_p = TextEncoder256(
968
+ inter_channels,
969
+ hidden_channels,
970
+ filter_channels,
971
+ n_heads,
972
+ n_layers,
973
+ kernel_size,
974
+ p_dropout,
975
+ f0=False,
976
+ )
977
+ self.dec = Generator(
978
+ inter_channels,
979
+ resblock,
980
+ resblock_kernel_sizes,
981
+ resblock_dilation_sizes,
982
+ upsample_rates,
983
+ upsample_initial_channel,
984
+ upsample_kernel_sizes,
985
+ gin_channels=gin_channels,
986
+ )
987
+ self.enc_q = PosteriorEncoder(
988
+ spec_channels,
989
+ inter_channels,
990
+ hidden_channels,
991
+ 5,
992
+ 1,
993
+ 16,
994
+ gin_channels=gin_channels,
995
+ )
996
+ self.flow = ResidualCouplingBlock(
997
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
998
+ )
999
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
1000
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
1001
+
1002
+ def remove_weight_norm(self):
1003
+ self.dec.remove_weight_norm()
1004
+ self.flow.remove_weight_norm()
1005
+ self.enc_q.remove_weight_norm()
1006
+
1007
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
1008
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
1009
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1010
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1011
+ z_p = self.flow(z, y_mask, g=g)
1012
+ z_slice, ids_slice = commons.rand_slice_segments(
1013
+ z, y_lengths, self.segment_size
1014
+ )
1015
+ o = self.dec(z_slice, g=g)
1016
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
1017
+
1018
+ def infer(self, phone, phone_lengths, sid, rate=None):
1019
+ g = self.emb_g(sid).unsqueeze(-1)
1020
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1021
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
1022
+ if rate:
1023
+ head = int(z_p.shape[2] * rate)
1024
+ z_p = z_p[:, :, -head:]
1025
+ x_mask = x_mask[:, :, -head:]
1026
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
1027
+ o = self.dec(z * x_mask, g=g)
1028
+ return o, x_mask, (z, z_p, m_p, logs_p)
1029
+
1030
+
1031
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
1032
+ def __init__(
1033
+ self,
1034
+ spec_channels,
1035
+ segment_size,
1036
+ inter_channels,
1037
+ hidden_channels,
1038
+ filter_channels,
1039
+ n_heads,
1040
+ n_layers,
1041
+ kernel_size,
1042
+ p_dropout,
1043
+ resblock,
1044
+ resblock_kernel_sizes,
1045
+ resblock_dilation_sizes,
1046
+ upsample_rates,
1047
+ upsample_initial_channel,
1048
+ upsample_kernel_sizes,
1049
+ spk_embed_dim,
1050
+ gin_channels,
1051
+ sr=None,
1052
+ **kwargs
1053
+ ):
1054
+ super().__init__()
1055
+ self.spec_channels = spec_channels
1056
+ self.inter_channels = inter_channels
1057
+ self.hidden_channels = hidden_channels
1058
+ self.filter_channels = filter_channels
1059
+ self.n_heads = n_heads
1060
+ self.n_layers = n_layers
1061
+ self.kernel_size = kernel_size
1062
+ self.p_dropout = p_dropout
1063
+ self.resblock = resblock
1064
+ self.resblock_kernel_sizes = resblock_kernel_sizes
1065
+ self.resblock_dilation_sizes = resblock_dilation_sizes
1066
+ self.upsample_rates = upsample_rates
1067
+ self.upsample_initial_channel = upsample_initial_channel
1068
+ self.upsample_kernel_sizes = upsample_kernel_sizes
1069
+ self.segment_size = segment_size
1070
+ self.gin_channels = gin_channels
1071
+ # self.hop_length = hop_length#
1072
+ self.spk_embed_dim = spk_embed_dim
1073
+ self.enc_p = TextEncoder768(
1074
+ inter_channels,
1075
+ hidden_channels,
1076
+ filter_channels,
1077
+ n_heads,
1078
+ n_layers,
1079
+ kernel_size,
1080
+ p_dropout,
1081
+ f0=False,
1082
+ )
1083
+ self.dec = Generator(
1084
+ inter_channels,
1085
+ resblock,
1086
+ resblock_kernel_sizes,
1087
+ resblock_dilation_sizes,
1088
+ upsample_rates,
1089
+ upsample_initial_channel,
1090
+ upsample_kernel_sizes,
1091
+ gin_channels=gin_channels,
1092
+ )
1093
+ self.enc_q = PosteriorEncoder(
1094
+ spec_channels,
1095
+ inter_channels,
1096
+ hidden_channels,
1097
+ 5,
1098
+ 1,
1099
+ 16,
1100
+ gin_channels=gin_channels,
1101
+ )
1102
+ self.flow = ResidualCouplingBlock(
1103
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
1104
+ )
1105
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
1106
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
1107
+
1108
+ def remove_weight_norm(self):
1109
+ self.dec.remove_weight_norm()
1110
+ self.flow.remove_weight_norm()
1111
+ self.enc_q.remove_weight_norm()
1112
+
1113
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
1114
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
1115
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1116
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1117
+ z_p = self.flow(z, y_mask, g=g)
1118
+ z_slice, ids_slice = commons.rand_slice_segments(
1119
+ z, y_lengths, self.segment_size
1120
+ )
1121
+ o = self.dec(z_slice, g=g)
1122
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
1123
+
1124
+ def infer(self, phone, phone_lengths, sid, rate=None):
1125
+ g = self.emb_g(sid).unsqueeze(-1)
1126
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1127
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
1128
+ if rate:
1129
+ head = int(z_p.shape[2] * rate)
1130
+ z_p = z_p[:, :, -head:]
1131
+ x_mask = x_mask[:, :, -head:]
1132
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
1133
+ o = self.dec(z * x_mask, g=g)
1134
+ return o, x_mask, (z, z_p, m_p, logs_p)
1135
+
1136
+ class SynthesizerTrnMs1024NSFsid_nono(nn.Module):
1137
+ def __init__(
1138
+ self,
1139
+ spec_channels,
1140
+ segment_size,
1141
+ inter_channels,
1142
+ hidden_channels,
1143
+ filter_channels,
1144
+ n_heads,
1145
+ n_layers,
1146
+ kernel_size,
1147
+ p_dropout,
1148
+ resblock,
1149
+ resblock_kernel_sizes,
1150
+ resblock_dilation_sizes,
1151
+ upsample_rates,
1152
+ upsample_initial_channel,
1153
+ upsample_kernel_sizes,
1154
+ spk_embed_dim,
1155
+ gin_channels,
1156
+ sr=None,
1157
+ **kwargs
1158
+ ):
1159
+ super().__init__()
1160
+ self.spec_channels = spec_channels
1161
+ self.inter_channels = inter_channels
1162
+ self.hidden_channels = hidden_channels
1163
+ self.filter_channels = filter_channels
1164
+ self.n_heads = n_heads
1165
+ self.n_layers = n_layers
1166
+ self.kernel_size = kernel_size
1167
+ self.p_dropout = p_dropout
1168
+ self.resblock = resblock
1169
+ self.resblock_kernel_sizes = resblock_kernel_sizes
1170
+ self.resblock_dilation_sizes = resblock_dilation_sizes
1171
+ self.upsample_rates = upsample_rates
1172
+ self.upsample_initial_channel = upsample_initial_channel
1173
+ self.upsample_kernel_sizes = upsample_kernel_sizes
1174
+ self.segment_size = segment_size
1175
+ self.gin_channels = gin_channels
1176
+ # self.hop_length = hop_length#
1177
+ self.spk_embed_dim = spk_embed_dim
1178
+ self.enc_p = TextEncoder1024(
1179
+ inter_channels,
1180
+ hidden_channels,
1181
+ filter_channels,
1182
+ n_heads,
1183
+ n_layers,
1184
+ kernel_size,
1185
+ p_dropout,
1186
+ f0=False,
1187
+ )
1188
+ self.dec = Generator(
1189
+ inter_channels,
1190
+ resblock,
1191
+ resblock_kernel_sizes,
1192
+ resblock_dilation_sizes,
1193
+ upsample_rates,
1194
+ upsample_initial_channel,
1195
+ upsample_kernel_sizes,
1196
+ gin_channels=gin_channels,
1197
+ )
1198
+ self.enc_q = PosteriorEncoder(
1199
+ spec_channels,
1200
+ inter_channels,
1201
+ hidden_channels,
1202
+ 5,
1203
+ 1,
1204
+ 16,
1205
+ gin_channels=gin_channels,
1206
+ )
1207
+ self.flow = ResidualCouplingBlock(
1208
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
1209
+ )
1210
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
1211
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
1212
+
1213
+ def remove_weight_norm(self):
1214
+ self.dec.remove_weight_norm()
1215
+ self.flow.remove_weight_norm()
1216
+ self.enc_q.remove_weight_norm()
1217
+
1218
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
1219
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
1220
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1221
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1222
+ z_p = self.flow(z, y_mask, g=g)
1223
+ z_slice, ids_slice = commons.rand_slice_segments(
1224
+ z, y_lengths, self.segment_size
1225
+ )
1226
+ o = self.dec(z_slice, g=g)
1227
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
1228
+
1229
+ def infer(self, phone, phone_lengths, sid, rate=None):
1230
+ g = self.emb_g(sid).unsqueeze(-1)
1231
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
1232
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
1233
+ if rate:
1234
+ head = int(z_p.shape[2] * rate)
1235
+ z_p = z_p[:, :, -head:]
1236
+ x_mask = x_mask[:, :, -head:]
1237
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
1238
+ o = self.dec(z * x_mask, g=g)
1239
+ return o, x_mask, (z, z_p, m_p, logs_p)
1240
+
1241
+
1242
+ class MultiPeriodDiscriminator(torch.nn.Module):
1243
+ def __init__(self, use_spectral_norm=False):
1244
+ super(MultiPeriodDiscriminator, self).__init__()
1245
+ periods = [2, 3, 5, 7, 11, 17]
1246
+ # periods = [3, 5, 7, 11, 17, 23, 37]
1247
+
1248
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1249
+ discs = discs + [
1250
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1251
+ ]
1252
+ self.discriminators = nn.ModuleList(discs)
1253
+
1254
+ def forward(self, y, y_hat):
1255
+ y_d_rs = [] #
1256
+ y_d_gs = []
1257
+ fmap_rs = []
1258
+ fmap_gs = []
1259
+ for i, d in enumerate(self.discriminators):
1260
+ y_d_r, fmap_r = d(y)
1261
+ y_d_g, fmap_g = d(y_hat)
1262
+ # for j in range(len(fmap_r)):
1263
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1264
+ y_d_rs.append(y_d_r)
1265
+ y_d_gs.append(y_d_g)
1266
+ fmap_rs.append(fmap_r)
1267
+ fmap_gs.append(fmap_g)
1268
+
1269
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1270
+
1271
+
1272
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
1273
+ def __init__(self, use_spectral_norm=False):
1274
+ super(MultiPeriodDiscriminatorV2, self).__init__()
1275
+ # periods = [2, 3, 5, 7, 11, 17]
1276
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
1277
+
1278
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1279
+ discs = discs + [
1280
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1281
+ ]
1282
+ self.discriminators = nn.ModuleList(discs)
1283
+
1284
+ def forward(self, y, y_hat):
1285
+ y_d_rs = [] #
1286
+ y_d_gs = []
1287
+ fmap_rs = []
1288
+ fmap_gs = []
1289
+ for i, d in enumerate(self.discriminators):
1290
+ y_d_r, fmap_r = d(y)
1291
+ y_d_g, fmap_g = d(y_hat)
1292
+ # for j in range(len(fmap_r)):
1293
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1294
+ y_d_rs.append(y_d_r)
1295
+ y_d_gs.append(y_d_g)
1296
+ fmap_rs.append(fmap_r)
1297
+ fmap_gs.append(fmap_g)
1298
+
1299
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1300
+
1301
+
1302
+
1303
+
+ 
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super(DiscriminatorS, self).__init__()
+         norm_f = spectral_norm if use_spectral_norm else weight_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                 norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                 norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                 norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+             ]
+         )
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+ 
+     def forward(self, x):
+         fmap = []
+ 
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+ 
+         return x, fmap
+ 
+ 
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super(DiscriminatorP, self).__init__()
+         self.period = period
+         self.use_spectral_norm = use_spectral_norm
+         norm_f = spectral_norm if use_spectral_norm else weight_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                 norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                 norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                 norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                 norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
+             ]
+         )
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+ 
+     def forward(self, x):
+         fmap = []
+ 
+         # 1d to 2d: pad the waveform to a multiple of the period, then fold it
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+ 
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+ 
+         return x, fmap
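The pad-and-fold step in `DiscriminatorP.forward` reshapes a waveform into period-aligned columns before the 2-D convolutions; a quick numeric check of just that transformation (values illustrative):

import torch
import torch.nn.functional as F

period = 5
x = torch.arange(12.0).view(1, 1, 12)    # (batch, channels, t), t % period != 0
n_pad = period - (x.shape[-1] % period)  # 3 samples of reflect padding
x = F.pad(x, (0, n_pad), "reflect")
x = x.view(1, 1, x.shape[-1] // period, period)
print(x.shape)                           # torch.Size([1, 1, 3, 5])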
mute.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de08d058b9abfdb1d51e06e7ec8941ab9d2c41f09483e84eb0cb1cdb7368b717
+ size 1056896
train_nsf_sim_cache_sid_load_pretrain.py ADDED
@@ -0,0 +1,771 @@
+ import sys, os
+ import pickle as p
+ 
+ now_dir = os.getcwd()
+ sys.path.append(os.path.join(now_dir))
+ sys.path.append(os.path.join(now_dir, "train"))
+ import utils
+ 
+ Loss_Gen_Per_Epoch = []
+ Loss_Disc_Per_Epoch = []
+ elapsed_time_record = []
+ Lowest_lg = 0
+ Lowest_ld = 0
+ import datetime
+ 
+ hps = utils.get_hparams()
+ overtrain = hps.overtrain
+ experiment_name = hps.name
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
+ n_gpus = len(hps.gpus.split("-"))
+ from random import shuffle, randint
+ import traceback, json, argparse, itertools, math, torch, pdb
+ 
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = False
+ from torch import nn, optim
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ import torch.multiprocessing as mp
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.cuda.amp import autocast, GradScaler
+ from lib.infer_pack import commons
+ from time import sleep
+ from time import time as ttime
+ from data_utils import (
+     TextAudioLoaderMultiNSFsid,
+     TextAudioLoader,
+     TextAudioCollateMultiNSFsid,
+     TextAudioCollate,
+     DistributedBucketSampler,
+ )
+ 
+ import csv
+ 
+ if hps.version == "v1":
+     from lib.infer_pack.models import (
+         SynthesizerTrnMs256NSFsid as RVC_Model_f0,
+         SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
+         MultiPeriodDiscriminator,
+     )
+ elif hps.version == "v2" and hps.Large_HuBert:
+     from lib.infer_pack.models import (
+         SynthesizerTrnMs1024NSFsid as RVC_Model_f0,
+         SynthesizerTrnMs1024NSFsid_nono as RVC_Model_nof0,
+         MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
+     )
+ else:
+     from lib.infer_pack.models import (
+         SynthesizerTrnMs768NSFsid as RVC_Model_f0,
+         SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
+         MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
+     )
+ from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+ from process_ckpt import savee
+ 
+ global_step = 0
+ 
+ 
+ def Calculate_format_elapsed_time(elapsed_time):
+     # Split a duration in seconds into hours, minutes, seconds and milliseconds.
+     h = int(elapsed_time / 3600)
+     m = int(elapsed_time / 60 - h * 60)
+     s = int(elapsed_time % 60)
+     ms = round((elapsed_time - int(elapsed_time)) * 1000)  # was *10000, which over-scaled the millisecond part
+     return h, m, s, ms
+ 
+ 
+ def right_index(List, Value):
+     # Index of the right-most occurrence of Value in List.
+     return len(List) - 1 - List[::-1].index(Value)
+ 
+ 
+ def formating_time(time):
+     # Zero-pad a time component to two digits.
+     return time if time >= 10 else f"0{time}"
+ 
+ 
+ class EpochRecorder:
+     def __init__(self):
+         self.last_time = ttime()
+ 
+     def record(self):
+         now_time = ttime()
+         elapsed_time = now_time - self.last_time
+         self.last_time = now_time
+         elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
+         current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         return f"[{current_time}] | ({elapsed_time_str})"
+ 
+ 
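A quick sanity check of the time helpers above (assuming they are in scope; the duration is made up):

h, m, s, ms = Calculate_format_elapsed_time(3725.042)
print(f"{h}h:{formating_time(m)}m:{formating_time(s)}s:{ms}ms")  # 1h:02m:05s:42ms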
+ def main():
+     n_gpus = torch.cuda.device_count()
+     if not torch.cuda.is_available() and torch.backends.mps.is_available():
+         n_gpus = 1
+     os.environ["MASTER_ADDR"] = "localhost"
+     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
+     children = []
+     for i in range(n_gpus):
+         subproc = mp.Process(
+             target=run,
+             args=(
+                 i,
+                 n_gpus,
+                 hps,
+             ),
+         )
+         children.append(subproc)
+         subproc.start()
+     for i in range(n_gpus):
+         children[i].join()
+ 
+ 
+ def run(rank, n_gpus, hps):
+     global global_step, loss_disc, loss_gen_all, Loss_Disc_Per_Epoch, Loss_Gen_Per_Epoch, elapsed_time_record, best_epoch, best_global_step, Min_for_Single_epoch, prev_best_epoch
+     if rank == 0:
+         logger = utils.get_logger(hps.model_dir)
+         logger.info(hps)
+         # utils.check_git_hash(hps.model_dir)
+         writer = SummaryWriter(log_dir=hps.model_dir)
+         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
+ 
+     dist.init_process_group(
+         backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
+     )
+     torch.manual_seed(hps.train.seed)
+     if torch.cuda.is_available():
+         torch.cuda.set_device(rank)
+ 
+     if hps.if_f0 == 1:
+         train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
+     else:
+         train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
+     train_sampler = DistributedBucketSampler(
+         train_dataset,
+         hps.train.batch_size * n_gpus,
+         # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200, 1400],  # 16s
+         [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
+         num_replicas=n_gpus,
+         rank=rank,
+         shuffle=True,
+     )
+     # The dataloader's workers may run out of shared memory; if so, raise the shared-memory limit.
+     # num_workers=8 -> num_workers=4
+     if hps.if_f0 == 1:
+         collate_fn = TextAudioCollateMultiNSFsid()
+     else:
+         collate_fn = TextAudioCollate()
+     train_loader = DataLoader(
+         train_dataset,
+         num_workers=4,
+         shuffle=False,
+         pin_memory=True,
+         collate_fn=collate_fn,
+         batch_sampler=train_sampler,
+         persistent_workers=True,
+         prefetch_factor=8,
+     )
+     if hps.if_f0 == 1:
+         net_g = RVC_Model_f0(
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             **hps.model,
+             is_half=hps.train.fp16_run,
+             sr=hps.sample_rate,
+         )
+     else:
+         net_g = RVC_Model_nof0(
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             **hps.model,
+             is_half=hps.train.fp16_run,
+         )
+     if torch.cuda.is_available():
+         net_g = net_g.cuda(rank)
+     net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
+     if torch.cuda.is_available():
+         net_d = net_d.cuda(rank)
+     optim_g = torch.optim.AdamW(
+         net_g.parameters(),
+         hps.train.learning_rate,
+         betas=hps.train.betas,
+         eps=hps.train.eps,
+     )
+     optim_d = torch.optim.AdamW(
+         net_d.parameters(),
+         hps.train.learning_rate,
+         betas=hps.train.betas,
+         eps=hps.train.eps,
+     )
+     # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+     # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+     if torch.cuda.is_available():
+         net_g = DDP(net_g, device_ids=[rank])
+         net_d = DDP(net_d, device_ids=[rank])
+     else:
+         net_g = DDP(net_g)
+         net_d = DDP(net_d)
+ 
+     try:  # resume automatically if checkpoints exist
+         _, _, _, epoch_str = utils.load_checkpoint(
+             utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
+         )  # loading D rarely fails
+         if rank == 0:
+             logger.info("loaded D")
+         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g, load_opt=0)
+         _, _, _, epoch_str = utils.load_checkpoint(
+             utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
+         )
+         global_step = (epoch_str - 1) * len(train_loader)
+         # epoch_str = 1
+         # global_step = 0
+     except:  # first run: nothing to resume from, so fall back to the pretrained weights
+         # traceback.print_exc()
+         epoch_str = 1
+         global_step = 0
+         if hps.pretrainG != "":
+             if rank == 0:
+                 logger.info("loaded pretrained %s" % (hps.pretrainG))
+             print(
+                 net_g.module.load_state_dict(
+                     torch.load(hps.pretrainG, map_location="cpu")["model"]
+                 )
+             )  ## deliberately not loading the optimizer state
+         if hps.pretrainD != "":
+             if rank == 0:
+                 logger.info("loaded pretrained %s" % (hps.pretrainD))
+             print(
+                 net_d.module.load_state_dict(
+                     torch.load(hps.pretrainD, map_location="cpu")["model"]
+                 )
+             )
+ 
+     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+         optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
+     )
+     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+         optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
+     )
+ 
+     scaler = GradScaler(enabled=hps.train.fp16_run)
+ 
+     # if hps.total_epoch < 100:
+     #     Min_for_Single_epoch = int(hps.total_epoch / 2)
+     # else:
+     #     Min_for_Single_epoch = 50
+     Min_for_Single_epoch = 1
+ 
+     # Restore the loss histories on resume and trim them back to epoch_str - 1 entries
+     if os.path.exists(f"Loss_Gen_Per_Epoch_{hps.name}.p") and os.path.exists(f"Loss_Disc_Per_Epoch_{hps.name}.p"):
+         with open(f"Loss_Gen_Per_Epoch_{hps.name}.p", "rb") as Loss_Gen:
+             Loss_Gen_Per_Epoch = p.load(Loss_Gen)
+         for i in range(len(Loss_Gen_Per_Epoch) - epoch_str + 1):
+             Loss_Gen_Per_Epoch.pop()
+         with open(f"Loss_Disc_Per_Epoch_{hps.name}.p", "rb") as Loss_Disc:
+             Loss_Disc_Per_Epoch = p.load(Loss_Disc)
+         for i in range(len(Loss_Disc_Per_Epoch) - epoch_str + 1):
+             Loss_Disc_Per_Epoch.pop()
+     if os.path.exists(f"prev_best_epoch_{hps.name}.p"):
+         with open(f"prev_best_epoch_{hps.name}.p", "rb") as prev_best_epoch_f:
+             prev_best_epoch = p.load(prev_best_epoch_f)
+ 
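To illustrate the trimming above: when resuming at `epoch_str`, any loss entries recorded at or after that epoch are popped so the history ends at epoch `epoch_str - 1` (numbers illustrative):

history = [31.0, 30.2, 29.7, 29.9, 30.4]  # losses for epochs 1..5
epoch_str = 4                             # resuming from epoch 4
for _ in range(len(history) - epoch_str + 1):
    history.pop()
print(history)                            # [31.0, 30.2, 29.7] -> epochs 1..3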
+     cache = []
+     for epoch in range(epoch_str, hps.train.epochs + 1):
+         start_time = ttime()
+         if rank == 0:
+             train_and_evaluate(
+                 rank,
+                 epoch,
+                 hps,
+                 [net_g, net_d],
+                 [optim_g, optim_d],
+                 [scheduler_g, scheduler_d],
+                 scaler,
+                 [train_loader, None],
+                 logger,
+                 [writer, writer_eval],
+                 cache,
+             )
+ 
+             # Per-epoch bookkeeping: record, persist and report the losses
+             loss_gen_all = loss_gen_all.item()
+             loss_disc = loss_disc.item()
+ 
+             Loss_Gen_Per_Epoch.append(loss_gen_all)
+             Loss_Disc_Per_Epoch.append(loss_disc)
+             # print(hps.train.epochs, epoch_str)
+ 
+             with open(f"Loss_Gen_Per_Epoch_{hps.name}.p", "wb") as Loss_Gen:
+                 p.dump(Loss_Gen_Per_Epoch, Loss_Gen)
+             with open(f"Loss_Disc_Per_Epoch_{hps.name}.p", "wb") as Loss_Disc:
+                 p.dump(Loss_Disc_Per_Epoch, Loss_Disc)
+ 
+             Lowest_lg = f"{min(Loss_Gen_Per_Epoch):.5f}, epoch: {right_index(Loss_Gen_Per_Epoch, min(Loss_Gen_Per_Epoch)) + 1}"
+             Lowest_ld = f"{min(Loss_Disc_Per_Epoch):.5f}, epoch: {right_index(Loss_Disc_Per_Epoch, min(Loss_Disc_Per_Epoch)) + 1}"
+             print(f"{hps.name}_e{epoch}_s{global_step} | Loss gen total: {Loss_Gen_Per_Epoch[-1]:.5f} | Lowest loss G: {Lowest_lg}\n Loss disc: {Loss_Disc_Per_Epoch[-1]:.5f} | Lowest loss D: {Lowest_ld}")
+             print(f"Specific values: loss gen={loss_gen:.3f}, loss fm={loss_fm:.3f}, loss mel={loss_mel:.3f}, loss kl={loss_kl:.3f}")
+ 
+             # Save an extra copy whenever this epoch sets a new lowest generator loss
+             if len(Loss_Gen_Per_Epoch) > Min_for_Single_epoch and epoch % hps.save_every_epoch != 0:
+                 if min(Loss_Gen_Per_Epoch[Min_for_Single_epoch:]) == Loss_Gen_Per_Epoch[-1]:
+                     if hasattr(net_g, "module"):
+                         ckpt = net_g.module.state_dict()
+                     else:
+                         ckpt = net_g.state_dict()
+                     savee(ckpt, hps.sample_rate, hps.if_f0, hps.name + "_e%s_s%s" % (epoch, global_step), epoch, hps.version, hps, experiment_name)
+                     os.rename(
+                         f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}.pth",
+                         f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth",
+                     )
+                     print(f"Saved: {hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth")
+                     try:
+                         os.remove(prev_best_epoch)
+                     except:
+                         print("Nothing to remove; if a previous best epoch should exist, check the weights folder.")
+                     else:
+                         print(f"{os.path.split(prev_best_epoch)[-1]} removed")
+                     best_epoch = epoch
+                     best_global_step = global_step
+                     prev_best_epoch = f"logs/{hps.name}/weights/{hps.name}_e{best_epoch}_s{best_global_step}_Best_Epoch.pth"
+                     with open(f"prev_best_epoch_{hps.name}.p", "wb") as prev_best_epoch_f:
+                         p.dump(prev_best_epoch, prev_best_epoch_f)
+ 
+             elapsed_time = ttime() - start_time
+             elapsed_time_record.append(elapsed_time)
+             if epoch - 1 == epoch_str:
+                 # Drop the first recorded epoch so the warm-up does not skew the average
+                 elapsed_time_record.pop(0)
+             elapsed_time_avg = sum(elapsed_time_record) / len(elapsed_time_record)
+             time_left = elapsed_time_avg * (hps.total_epoch - epoch)
+             hour, minute, second, millisec = Calculate_format_elapsed_time(elapsed_time)
+             hour_left, minute_left, second_left, millisec_left = Calculate_format_elapsed_time(time_left)
+             print(f"Time elapsed: {hour}h:{formating_time(minute)}m:{formating_time(second)}s:{millisec}ms || Time left: {hour_left}h:{formating_time(minute_left)}m:{formating_time(second_left)}s:{millisec_left}ms\n")
+ 
+             # Early stop: no new lowest generator loss for `overtrain` epochs
+             if (len(Loss_Gen_Per_Epoch) - right_index(Loss_Gen_Per_Epoch, min(Loss_Gen_Per_Epoch)) + 1) > overtrain and overtrain != -1:
+                 logger.info("Overtraining threshold reached. Training is done.")
+                 print("Overtraining threshold reached. Training is done.")
+ 
+                 if hasattr(net_g, "module"):
+                     ckpt = net_g.module.state_dict()
+                 else:
+                     ckpt = net_g.state_dict()
+                 logger.info(
+                     "saving final ckpt:%s"
+                     % (
+                         savee(
+                             ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
+                         )
+                     )
+                 )
+                 sleep(1)
+                 with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
+                     csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
+                     csv_writer.writerow(["False"])
+                 os._exit(2333333)
+ 
+         else:
+             train_and_evaluate(
+                 rank,
+                 epoch,
+                 hps,
+                 [net_g, net_d],
+                 [optim_g, optim_d],
+                 [scheduler_g, scheduler_d],
+                 scaler,
+                 [train_loader, None],
+                 None,
+                 None,
+                 cache,
+             )
+         scheduler_g.step()
+         scheduler_d.step()
+         # gathered_tensors_gen = [torch.zeros_like(loss_gen_all) for _ in range(n_gpus)]
+         # gathered_tensors_disc = [torch.zeros_like(loss_disc) for _ in range(n_gpus)]
+         # dist.all_gather(gathered_tensors_gen, loss_gen_all)
+         # dist.all_gather(gathered_tensors_disc, loss_disc)
+ 
+ 
+ #######
+ 
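The early-stop rule in `run` above halts once the right-most minimum of the generator-loss history is more than `overtrain` epochs old; a self-contained illustration with made-up numbers:

history = [32.1, 30.5, 29.8, 30.2, 30.9, 31.0]               # per-epoch generator losses
overtrain = 3
best = len(history) - 1 - history[::-1].index(min(history))  # what right_index computes
print((len(history) - best + 1) > overtrain)                 # True -> training would stop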
+ def train_and_evaluate(
+     rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
+ ):
+     global loss_gen_all, loss_disc, ckpt, loss_kl, loss_fm, loss_gen, loss_mel
+     net_g, net_d = nets
+     optim_g, optim_d = optims
+     train_loader, eval_loader = loaders
+     if writers is not None:
+         writer, writer_eval = writers
+ 
+     train_loader.batch_sampler.set_epoch(epoch)
+     global global_step
+ 
+     net_g.train()
+     net_d.train()
+ 
+     # Prepare data iterator
+     if hps.if_cache_data_in_gpu:
+         # Use cache
+         data_iterator = cache
+         if cache == []:
+             # Make new cache
+             for batch_idx, info in enumerate(train_loader):
+                 # Unpack
+                 if hps.if_f0 == 1:
+                     (
+                         phone,
+                         phone_lengths,
+                         pitch,
+                         pitchf,
+                         spec,
+                         spec_lengths,
+                         wave,
+                         wave_lengths,
+                         sid,
+                     ) = info
+                 else:
+                     (
+                         phone,
+                         phone_lengths,
+                         spec,
+                         spec_lengths,
+                         wave,
+                         wave_lengths,
+                         sid,
+                     ) = info
+                 # Load on CUDA
+                 if torch.cuda.is_available():
+                     phone = phone.cuda(rank, non_blocking=True)
+                     phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
+                     if hps.if_f0 == 1:
+                         pitch = pitch.cuda(rank, non_blocking=True)
+                         pitchf = pitchf.cuda(rank, non_blocking=True)
+                     sid = sid.cuda(rank, non_blocking=True)
+                     spec = spec.cuda(rank, non_blocking=True)
+                     spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
+                     wave = wave.cuda(rank, non_blocking=True)
+                     wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
+                 # Cache on list
+                 if hps.if_f0 == 1:
+                     cache.append(
+                         (
+                             batch_idx,
+                             (
+                                 phone,
+                                 phone_lengths,
+                                 pitch,
+                                 pitchf,
+                                 spec,
+                                 spec_lengths,
+                                 wave,
+                                 wave_lengths,
+                                 sid,
+                             ),
+                         )
+                     )
+                 else:
+                     cache.append(
+                         (
+                             batch_idx,
+                             (
+                                 phone,
+                                 phone_lengths,
+                                 spec,
+                                 spec_lengths,
+                                 wave,
+                                 wave_lengths,
+                                 sid,
+                             ),
+                         )
+                     )
+         else:
+             # Load shuffled cache
+             shuffle(cache)
+     else:
+         # Loader
+         data_iterator = enumerate(train_loader)
+ 
+     # Run steps
+     epoch_recorder = EpochRecorder()
+ 
+     for batch_idx, info in data_iterator:
+         # Data
+         ## Unpack
+         if hps.if_f0 == 1:
+             (
+                 phone,
+                 phone_lengths,
+                 pitch,
+                 pitchf,
+                 spec,
+                 spec_lengths,
+                 wave,
+                 wave_lengths,
+                 sid,
+             ) = info
+         else:
+             phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
+         ## Load on CUDA
+         if not hps.if_cache_data_in_gpu and torch.cuda.is_available():
+             phone = phone.cuda(rank, non_blocking=True)
+             phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
+             if hps.if_f0 == 1:
+                 pitch = pitch.cuda(rank, non_blocking=True)
+                 pitchf = pitchf.cuda(rank, non_blocking=True)
+             sid = sid.cuda(rank, non_blocking=True)
+             spec = spec.cuda(rank, non_blocking=True)
+             spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
+             wave = wave.cuda(rank, non_blocking=True)
+             # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
+ 
+         # Calculate
+         with autocast(enabled=hps.train.fp16_run):
+             if hps.if_f0 == 1:
+                 (
+                     y_hat,
+                     ids_slice,
+                     x_mask,
+                     z_mask,
+                     (z, z_p, m_p, logs_p, m_q, logs_q),
+                 ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
+             else:
+                 (
+                     y_hat,
+                     ids_slice,
+                     x_mask,
+                     z_mask,
+                     (z, z_p, m_p, logs_p, m_q, logs_q),
+                 ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
+             mel = spec_to_mel_torch(
+                 spec,
+                 hps.data.filter_length,
+                 hps.data.n_mel_channels,
+                 hps.data.sampling_rate,
+                 hps.data.mel_fmin,
+                 hps.data.mel_fmax,
+             )
+             y_mel = commons.slice_segments(
+                 mel, ids_slice, hps.train.segment_size // hps.data.hop_length
+             )
+             with autocast(enabled=False):
+                 y_hat_mel = mel_spectrogram_torch(
+                     y_hat.float().squeeze(1),
+                     hps.data.filter_length,
+                     hps.data.n_mel_channels,
+                     hps.data.sampling_rate,
+                     hps.data.hop_length,
+                     hps.data.win_length,
+                     hps.data.mel_fmin,
+                     hps.data.mel_fmax,
+                 )
+             if hps.train.fp16_run:
+                 y_hat_mel = y_hat_mel.half()
+             wave = commons.slice_segments(
+                 wave, ids_slice * hps.data.hop_length, hps.train.segment_size
+             )  # slice
+ 
+             # Discriminator
+             y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
+             with autocast(enabled=False):
+                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
+                     y_d_hat_r, y_d_hat_g
+                 )
+         optim_d.zero_grad()
+         scaler.scale(loss_disc).backward()
+         scaler.unscale_(optim_d)
+         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
+         scaler.step(optim_d)
+ 
+         with autocast(enabled=hps.train.fp16_run):
+             # Generator
+             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
+             with autocast(enabled=False):
+                 loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
+                 loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
+                 loss_fm = feature_loss(fmap_r, fmap_g)
+                 loss_gen, losses_gen = generator_loss(y_d_hat_g)
+                 loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
+         optim_g.zero_grad()
+         scaler.scale(loss_gen_all).backward()
+         scaler.unscale_(optim_g)
+         grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
+         scaler.step(optim_g)
+         scaler.update()
+ 
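The discriminator/generator updates above follow the standard two-optimizer AMP pattern with one shared GradScaler; distilled to a skeleton (hypothetical optimizers and precomputed losses, mirroring the loop above):

from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=True)

def gan_step(loss_disc, loss_gen_all, optim_d, optim_g):
    # D step: the generator output was detached, so only net_d receives gradients
    optim_d.zero_grad()
    scaler.scale(loss_disc).backward()
    scaler.unscale_(optim_d)   # unscale before clipping so gradient norms are real
    scaler.step(optim_d)

    # G step, then a single scale update per iteration
    optim_g.zero_grad()
    scaler.scale(loss_gen_all).backward()
    scaler.unscale_(optim_g)
    scaler.step(optim_g)
    scaler.update()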
+         if rank == 0:
+             if global_step % hps.train.log_interval == 0:
+                 lr = optim_g.param_groups[0]["lr"]
+                 logger.info(
+                     ""
+                     # "Train Epoch: {} [{:.0f}%]".format(
+                     #     epoch, 100.0 * batch_idx / len(train_loader)
+                     # )
+                 )
+                 # Clamp outliers so the TensorBoard curves stay readable
+                 if loss_mel > 75:
+                     loss_mel = 75
+                 if loss_kl > 9:
+                     loss_kl = 9
+ 
+                 logger.info([global_step, lr])
+                 logger.info(
+                     ""
+                     # f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
+                 )
+                 scalar_dict = {
+                     "loss/g/total": loss_gen_all,
+                     "loss/d/total": loss_disc,
+                     "learning_rate": lr,
+                     "grad_norm_d": grad_norm_d,
+                     "grad_norm_g": grad_norm_g,
+                 }
+                 scalar_dict.update(
+                     {
+                         "loss/g/fm": loss_fm,
+                         "loss/g/mel": loss_mel,
+                         "loss/g/kl": loss_kl,
+                     }
+                 )
+ 
+                 scalar_dict.update(
+                     {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
+                 )
+                 scalar_dict.update(
+                     {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
+                 )
+                 scalar_dict.update(
+                     {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
+                 )
+                 image_dict = {
+                     "slice/mel_org": utils.plot_spectrogram_to_numpy(
+                         y_mel[0].data.cpu().numpy()
+                     ),
+                     "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+                         y_hat_mel[0].data.cpu().numpy()
+                     ),
+                     "all/mel": utils.plot_spectrogram_to_numpy(
+                         mel[0].data.cpu().numpy()
+                     ),
+                 }
+                 utils.summarize(
+                     writer=writer,
+                     global_step=global_step,
+                     images=image_dict,
+                     scalars=scalar_dict,
+                 )
+         global_step += 1
+     # /Run steps
+ 
+     if epoch % hps.save_every_epoch == 0 and rank == 0:
+         print(f"Saved: {hps.name}_e{epoch}_s{global_step}.pth")
+         if hps.if_latest == 0:
+             utils.save_checkpoint(
+                 net_g,
+                 optim_g,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
+             )
+             utils.save_checkpoint(
+                 net_d,
+                 optim_d,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
+             )
+         else:
+             # Keep only one rolling checkpoint pair (fixed suffix 2333333)
+             utils.save_checkpoint(
+                 net_g,
+                 optim_g,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
+             )
+             utils.save_checkpoint(
+                 net_d,
+                 optim_d,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
+             )
+     if rank == 0 and hps.save_every_weights == "1":
+         if hasattr(net_g, "module"):
+             ckpt = net_g.module.state_dict()
+         else:
+             ckpt = net_g.state_dict()
+         logger.info(
+             "saving ckpt %s_e%s:%s"
+             % (
+                 hps.name,
+                 epoch,
+                 savee(
+                     ckpt,
+                     hps.sample_rate,
+                     hps.if_f0,
+                     hps.name + "_e%s_s%s" % (epoch, global_step),
+                     epoch,
+                     hps.version,
+                     hps,
+                     experiment_name,
+                 ),
+             )
+         )
+ 
+     # Poll the UI stop flag: csvdb/stop.csv holds a single "True"/"False" cell.
+     try:
+         with open("csvdb/stop.csv") as CSVStop:
+             csv_reader = list(csv.reader(CSVStop))
+             if not csv_reader or not csv_reader[0]:
+                 raise ValueError("No data")
+             stopbtn = csv_reader[0][0].strip().lower() == "true"
+     except (ValueError, TypeError, IndexError, FileNotFoundError):
+         stopbtn = False
+ 
+     if stopbtn:
+         logger.info("Stop button was pressed. The program is closed.")
+         if hasattr(net_g, "module"):
+             ckpt = net_g.module.state_dict()
+         else:
+             ckpt = net_g.state_dict()
+         logger.info(
+             "saving final ckpt:%s"
+             % (
+                 savee(
+                     ckpt,
+                     hps.sample_rate,
+                     hps.if_f0,
+                     hps.name,
+                     epoch,
+                     hps.version,
+                     hps,
+                     experiment_name,
+                 )
+             )
+         )
+         sleep(1)
+         # Reset the stop flag before exiting so the next run is not stopped immediately
+         with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
+             csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
+             csv_writer.writerow(["False"])
+         os._exit(2333333)
+ 
+     if rank == 0:
+         logger.info("")  # "====> Epoch: {} {}".format(epoch, epoch_recorder.record())
+     if epoch > hps.total_epoch and rank == 0:
+         logger.info("Training is done. The program is closed.")
+ 
+         if hasattr(net_g, "module"):
+             ckpt = net_g.module.state_dict()
+         else:
+             ckpt = net_g.state_dict()
+         logger.info(
+             "saving final ckpt:%s"
+             % (
+                 savee(
+                     ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
+                 )
+             )
+         )
+         sleep(1)
+         with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
+             csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
+             csv_writer.writerow(["False"])
+         os._exit(2333333)
+ 
+ 
+ if __name__ == "__main__":
+     torch.multiprocessing.set_start_method("spawn")
+     main()