Pj12 committed
Commit d0b22cd · verified · 1 parent: 09f4a8b

Delete train_nsf_sim_cache_sid_load_pretrain.py

train_nsf_sim_cache_sid_load_pretrain.py DELETED
@@ -1,765 +0,0 @@
1
- import sys, os
2
- import pickle as p
3
- now_dir = os.getcwd()
4
- sys.path.append(os.path.join(now_dir))
5
- sys.path.append(os.path.join(now_dir, "train"))
6
- import utils
7
- Loss_Gen_Per_Epoch = []
8
- Loss_Disc_Per_Epoch = []
9
- elapsed_time_record = []
10
- Lowest_lg = 0
11
- Lowest_ld = 0
12
- import datetime
13
- hps = utils.get_hparams()
14
- overtrain = hps.overtrain
15
- experiment_name = hps.name
16
- os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
17
- n_gpus = len(hps.gpus.split("-"))
18
- from random import shuffle, randint
19
- import traceback, json, argparse, itertools, math, torch, pdb
20
-
21
- torch.backends.cudnn.deterministic = False
22
- torch.backends.cudnn.benchmark = False
23
- from torch import nn, optim
24
- from torch.nn import functional as F
25
- from torch.utils.data import DataLoader
26
- from torch.utils.tensorboard import SummaryWriter
27
- import torch.multiprocessing as mp
28
- import torch.distributed as dist
29
- from torch.nn.parallel import DistributedDataParallel as DDP
30
- from torch.cuda.amp import autocast, GradScaler
31
- from lib.infer_pack import commons
32
- from time import sleep
33
- from time import time as ttime
34
- from data_utils import (
35
- TextAudioLoaderMultiNSFsid,
36
- TextAudioLoader,
37
- TextAudioCollateMultiNSFsid,
38
- TextAudioCollate,
39
- DistributedBucketSampler,
40
- )
41
-
42
- import csv
43
-
44
- if hps.version == "v1":
45
- from lib.infer_pack.models import (
46
- SynthesizerTrnMs256NSFsid as RVC_Model_f0,
47
- SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
48
- MultiPeriodDiscriminator,
49
- )
50
- else:
51
- from lib.infer_pack.models import (
52
- SynthesizerTrnMs768NSFsid as RVC_Model_f0,
53
- SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
54
- MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
55
- )
56
- from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
57
- from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
58
- from process_ckpt import savee
59
-
60
- global global_step
61
- global_step = 0
62
-
63
- def Calculate_format_elapsed_time(elapsed_time):
64
- h = int(elapsed_time/3600)
65
- m, s, ms = int(elapsed_time/60 - h*60), int(elapsed_time%60), round((elapsed_time - int(elapsed_time))*1000)
66
- return h,m,s,ms
67
- def right_index(List,Value):
68
- index = len(List)-1-List[::-1].index(Value)
69
- return index
70
- def formating_time(time):
71
- time = time if time >= 10 else f"0{time}"
72
- return time
73
- class EpochRecorder:
74
- def __init__(self):
75
- self.last_time = ttime()
76
-
77
- def record(self):
78
- now_time = ttime()
79
- elapsed_time = now_time - self.last_time
80
- self.last_time = now_time
81
- elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
82
- current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
83
- return f"[{current_time}] | ({elapsed_time_str})"
84
-
85
-
86
- def main():
87
- n_gpus = torch.cuda.device_count()
88
- if not torch.cuda.is_available() and torch.backends.mps.is_available():
89
- n_gpus = 1
90
- os.environ["MASTER_ADDR"] = "localhost"
91
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
92
- children = []
93
- for i in range(n_gpus):
94
- subproc = mp.Process(
95
- target=run,
96
- args=(
97
- i,
98
- n_gpus,
99
- hps,
100
- ),
101
- )
102
- children.append(subproc)
103
- subproc.start()
104
- for i in range(n_gpus):
105
- children[i].join()
106
-
107
-
108
-
109
- def run(rank, n_gpus, hps):
110
- global global_step, loss_disc, loss_gen_all, Loss_Disc_Per_Epoch, Loss_Gen_Per_Epoch, elapsed_time_record, best_epoch, best_global_step, Min_for_Single_epoch, prev_best_epoch
111
- if rank == 0:
112
- logger = utils.get_logger(hps.model_dir)
113
- logger.info(hps)
114
- # utils.check_git_hash(hps.model_dir)
115
- writer = SummaryWriter(log_dir=hps.model_dir)
116
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
117
-
118
- dist.init_process_group(
119
- backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
120
- )
121
- torch.manual_seed(hps.train.seed)
122
- if torch.cuda.is_available():
123
- torch.cuda.set_device(rank)
124
-
125
- if hps.if_f0 == 1:
126
- train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
127
- else:
128
- train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
129
- train_sampler = DistributedBucketSampler(
130
- train_dataset,
131
- hps.train.batch_size * n_gpus,
132
- # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400], # 16s
133
- [100, 200, 300, 400, 500, 600, 700, 800, 900], # 16s
134
- num_replicas=n_gpus,
135
- rank=rank,
136
- shuffle=True,
137
- )
138
- # It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
139
- # num_workers=8 -> num_workers=4
140
- if hps.if_f0 == 1:
141
- collate_fn = TextAudioCollateMultiNSFsid()
142
- else:
143
- collate_fn = TextAudioCollate()
144
- train_loader = DataLoader(
145
- train_dataset,
146
- num_workers=4,
147
- shuffle=False,
148
- pin_memory=True,
149
- collate_fn=collate_fn,
150
- batch_sampler=train_sampler,
151
- persistent_workers=True,
152
- prefetch_factor=8,
153
- )
154
- if hps.if_f0 == 1:
155
- net_g = RVC_Model_f0(
156
- hps.data.filter_length // 2 + 1,
157
- hps.train.segment_size // hps.data.hop_length,
158
- **hps.model,
159
- is_half=hps.train.fp16_run,
160
- sr=hps.sample_rate,
161
- )
162
- else:
163
- net_g = RVC_Model_nof0(
164
- hps.data.filter_length // 2 + 1,
165
- hps.train.segment_size // hps.data.hop_length,
166
- **hps.model,
167
- is_half=hps.train.fp16_run,
168
- )
169
- if torch.cuda.is_available():
170
- net_g = net_g.cuda(rank)
171
- net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
172
- if torch.cuda.is_available():
173
- net_d = net_d.cuda(rank)
174
- optim_g = torch.optim.AdamW(
175
- net_g.parameters(),
176
- hps.train.learning_rate,
177
- betas=hps.train.betas,
178
- eps=hps.train.eps,
179
- )
180
- optim_d = torch.optim.AdamW(
181
- net_d.parameters(),
182
- hps.train.learning_rate,
183
- betas=hps.train.betas,
184
- eps=hps.train.eps,
185
- )
186
- # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
187
- # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
188
- if torch.cuda.is_available():
189
- net_g = DDP(net_g, device_ids=[rank])
190
- net_d = DDP(net_d, device_ids=[rank])
191
- else:
192
- net_g = DDP(net_g)
193
- net_d = DDP(net_d)
194
-
195
- try: # if a checkpoint can be loaded, resume automatically
196
- _, _, _, epoch_str = utils.load_checkpoint(
197
- utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
198
- ) # loading D usually works fine
199
- if rank == 0:
200
- logger.info("loaded D")
201
- # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
202
- _, _, _, epoch_str = utils.load_checkpoint(
203
- utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
204
- )
205
- global_step = (epoch_str - 1) * len(train_loader)
206
- # epoch_str = 1
207
- # global_step = 0
208
- except: # if nothing can be loaded on the first run, load the pretrained weights
209
- # traceback.print_exc()
210
- epoch_str = 1
211
- global_step = 0
212
- if hps.pretrainG != "":
213
- if rank == 0:
214
- logger.info("loaded pretrained %s" % (hps.pretrainG))
215
- print(
216
- net_g.module.load_state_dict(
217
- torch.load(hps.pretrainG, map_location="cpu")["model"]
218
- )
219
- ) ## testing: do not load the optimizer
220
- if hps.pretrainD != "":
221
- if rank == 0:
222
- logger.info("loaded pretrained %s" % (hps.pretrainD))
223
- print(
224
- net_d.module.load_state_dict(
225
- torch.load(hps.pretrainD, map_location="cpu")["model"]
226
- )
227
- )
228
-
229
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
230
- optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
231
- )
232
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
233
- optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
234
- )
235
-
236
- scaler = GradScaler(enabled=hps.train.fp16_run)
237
- #
238
- #if hps.total_epoch < 100:
239
- #Min_for_Single_epoch = int(hps.total_epoch/2)
240
- #else:
241
- #Min_for_Single_epoch = 50
242
- Min_for_Single_epoch = 1
243
- #
244
- if os.path.exists(f"Loss_Gen_Per_Epoch_{hps.name}.p") and os.path.exists(f"Loss_Disc_Per_Epoch_{hps.name}.p"):
245
- with open(f'Loss_Gen_Per_Epoch_{hps.name}.p', 'rb') as Loss_Gen:
246
- Loss_Gen_Per_Epoch = p.load(Loss_Gen)
247
- for i in range(len(Loss_Gen_Per_Epoch)-epoch_str+1):
248
- Loss_Gen_Per_Epoch.pop()
249
- with open(f'Loss_Disc_Per_Epoch_{hps.name}.p', 'rb') as Loss_Disc:
250
- Loss_Disc_Per_Epoch = p.load(Loss_Disc)
251
- for i in range(len(Loss_Disc_Per_Epoch)-epoch_str+1):
252
- Loss_Disc_Per_Epoch.pop()
253
- if os.path.exists(f"prev_best_epoch_{hps.name}.p"):
254
- with open(f'prev_best_epoch_{hps.name}.p', 'rb') as prev_best_epoch_f:
255
- prev_best_epoch = p.load(prev_best_epoch_f)
256
- #
257
- cache = []
258
- for epoch in range(epoch_str, hps.train.epochs+1):
259
- start_time = ttime()
260
- if rank == 0:
261
- train_and_evaluate(
262
- rank,
263
- epoch,
264
- hps,
265
- [net_g, net_d],
266
- [optim_g, optim_d],
267
- [scheduler_g, scheduler_d],
268
- scaler,
269
- [train_loader, None],
270
- logger,
271
- [writer, writer_eval],
272
- cache,
273
- )
274
-
275
- # Printing and Saving stuff
276
- loss_gen_all = loss_gen_all.item()
277
- loss_disc = loss_disc.item()
278
- #
279
- Loss_Gen_Per_Epoch.append(loss_gen_all)
280
- Loss_Disc_Per_Epoch.append(loss_disc)
281
- #print(hps.train.epochs, epoch_str)
282
- #
283
- with open(f'Loss_Gen_Per_Epoch_{hps.name}.p', 'wb') as Loss_Gen:
284
- p.dump(Loss_Gen_Per_Epoch, Loss_Gen)
285
- Loss_Gen.close()
286
- with open(f'Loss_Disc_Per_Epoch_{hps.name}.p', 'wb') as Loss_Disc:
287
- p.dump(Loss_Disc_Per_Epoch, Loss_Disc)
288
- Loss_Disc.close()
289
- #
290
- Lowest_lg = f"{min(Loss_Gen_Per_Epoch):.5f}, epoch: {right_index(Loss_Gen_Per_Epoch,min(Loss_Gen_Per_Epoch))+1}"
291
- Lowest_ld = f"{min(Loss_Disc_Per_Epoch):.5f}, epoch: {right_index(Loss_Disc_Per_Epoch,min(Loss_Disc_Per_Epoch))+1}"
292
- print(f"{hps.name}_e{epoch}_s{global_step} | Loss gen total: {Loss_Gen_Per_Epoch[-1]:.5f} | Lowest loss G: {Lowest_lg}\n Loss disc: {Loss_Disc_Per_Epoch[-1]:.5f} | Lowest loss D: {Lowest_ld}")
293
- print(f"Specific Value: loss gen={loss_gen:.3f}, loss fm={loss_fm:.3f},loss mel={loss_mel:.3f}, loss kl={loss_kl:.3f}")
294
- #
295
- if len(Loss_Gen_Per_Epoch) > Min_for_Single_epoch and epoch % hps.save_every_epoch != 0:
296
- if min(Loss_Gen_Per_Epoch[Min_for_Single_epoch::1]) == Loss_Gen_Per_Epoch[-1]:
297
- if hasattr(net_g, "module"):
298
- ckpt = net_g.module.state_dict()
299
- else:
300
- ckpt = net_g.state_dict()
301
- savee(ckpt, hps.sample_rate, hps.if_f0, hps.name + "_e%s_s%s" % (epoch, global_step), epoch, hps.version, hps, experiment_name)
302
- os.rename(f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}.pth",f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth")
303
- print(f"Saved: {hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth")
304
- try:
305
- os.remove(prev_best_epoch)
306
- except:
307
- print("Nothing to remove, if there's is you may need to check again")
308
- pass
309
- else:
310
- print(f"{os.path.split(prev_best_epoch)[-1]} Removed")
311
- best_epoch = epoch
312
- best_global_step = global_step
313
- prev_best_epoch = f"logs/{hps.name}/weights/{hps.name}_e{best_epoch}_s{best_global_step}_Best_Epoch.pth"
314
- with open(f'prev_best_epoch_{hps.name}.p', 'wb') as prev_best_epoch_f:
315
- p.dump(prev_best_epoch, prev_best_epoch_f)
316
- #
317
- elapsed_time = ttime() - start_time
318
- elapsed_time_record.append(elapsed_time)
319
- if epoch-1 == epoch_str:
320
- elapsed_time_record.pop(0)
321
- elapsed_time_avg = sum(elapsed_time_record)/len(elapsed_time_record)
322
- time_left = elapsed_time_avg*(hps.total_epoch-epoch)
323
- hour, minute, second, millisec = Calculate_format_elapsed_time(elapsed_time)
324
- hour_left, minute_left, second_left, millisec_left = Calculate_format_elapsed_time(time_left)
325
- print(f"Time Elapsed: {hour}h:{formating_time(minute)}m:{formating_time(second)}s:{millisec}ms || Time left: {hour_left}h:{formating_time(minute_left)}m:{formating_time(second_left)}s:{millisec_left}ms\n")
326
- #
327
- if ((len(Loss_Gen_Per_Epoch) - right_index(Loss_Gen_Per_Epoch,min(Loss_Gen_Per_Epoch)) + 1) > overtrain and overtrain != -1):
328
- logger.info("Over Train threshold reached. Training is done.")
329
- print("Over Train threshold reached. Training is done.")
330
-
331
- if hasattr(net_g, "module"):
332
- ckpt = net_g.module.state_dict()
333
- else:
334
- ckpt = net_g.state_dict()
335
- logger.info(
336
- "saving final ckpt:%s"
337
- % (
338
- savee(
339
- ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
340
- )
341
- )
342
- )
343
- sleep(1)
344
- with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
345
- csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
346
- csv_writer.writerow(["False"])
347
- os._exit(2333333)
348
-
349
- else:
350
- train_and_evaluate(
351
- rank,
352
- epoch,
353
- hps,
354
- [net_g, net_d],
355
- [optim_g, optim_d],
356
- [scheduler_g, scheduler_d],
357
- scaler,
358
- [train_loader, None],
359
- None,
360
- None,
361
- cache,
362
- )
363
- scheduler_g.step()
364
- scheduler_d.step()
365
- #gathered_tensors_gen = [torch.zeros_like(loss_gen_all) for _ in range(n_gpus)]
366
- #gathered_tensors_disc = [torch.zeros_like(loss_disc) for _ in range(n_gpus)]
367
- #dist.all_gather(gathered_tensors_gen, loss_gen_all)
368
- #dist.all_gather(gathered_tensors_disc, loss_disc)
369
-
370
-
371
-
372
- #######
373
-
374
- def train_and_evaluate(
375
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
376
- ):
377
- global loss_gen_all, loss_disc, ckpt, loss_kl, loss_fm, loss_gen, loss_mel
378
- net_g, net_d = nets
379
- optim_g, optim_d = optims
380
- train_loader, eval_loader = loaders
381
- if writers is not None:
382
- writer, writer_eval = writers
383
-
384
- train_loader.batch_sampler.set_epoch(epoch)
385
- global global_step
386
-
387
- net_g.train()
388
- net_d.train()
389
-
390
- # Prepare data iterator
391
- if hps.if_cache_data_in_gpu:
392
- # Use Cache
393
- data_iterator = cache
394
- if not cache:
395
- # Make new cache
396
- for batch_idx, info in enumerate(train_loader):
397
- # Unpack
398
- if hps.if_f0 == 1:
399
- (
400
- phone,
401
- phone_lengths,
402
- pitch,
403
- pitchf,
404
- spec,
405
- spec_lengths,
406
- wave,
407
- wave_lengths,
408
- sid,
409
- ) = info
410
- else:
411
- (
412
- phone,
413
- phone_lengths,
414
- spec,
415
- spec_lengths,
416
- wave,
417
- wave_lengths,
418
- sid,
419
- ) = info
420
- # Load on CUDA
421
- if torch.cuda.is_available():
422
- phone = phone.cuda(rank, non_blocking=True)
423
- phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
424
- if hps.if_f0 == 1:
425
- pitch = pitch.cuda(rank, non_blocking=True)
426
- pitchf = pitchf.cuda(rank, non_blocking=True)
427
- sid = sid.cuda(rank, non_blocking=True)
428
- spec = spec.cuda(rank, non_blocking=True)
429
- spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
430
- wave = wave.cuda(rank, non_blocking=True)
431
- wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
432
- # Cache on list
433
- if hps.if_f0 == 1:
434
- cache.append(
435
- (
436
- batch_idx,
437
- (
438
- phone,
439
- phone_lengths,
440
- pitch,
441
- pitchf,
442
- spec,
443
- spec_lengths,
444
- wave,
445
- wave_lengths,
446
- sid,
447
- ),
448
- )
449
- )
450
- else:
451
- cache.append(
452
- (
453
- batch_idx,
454
- (
455
- phone,
456
- phone_lengths,
457
- spec,
458
- spec_lengths,
459
- wave,
460
- wave_lengths,
461
- sid,
462
- ),
463
- )
464
- )
465
- else:
466
- # Load shuffled cache
467
- shuffle(cache)
468
- else:
469
- # Loader
470
- data_iterator = enumerate(train_loader)
471
-
472
- # Run steps
473
- epoch_recorder = EpochRecorder()
474
-
475
- for batch_idx, info in data_iterator:
476
- # Data
477
- ## Unpack
478
- if hps.if_f0 == 1:
479
- (
480
- phone,
481
- phone_lengths,
482
- pitch,
483
- pitchf,
484
- spec,
485
- spec_lengths,
486
- wave,
487
- wave_lengths,
488
- sid,
489
- ) = info
490
- else:
491
- phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
492
- ## Load on CUDA
493
- if not hps.if_cache_data_in_gpu and torch.cuda.is_available():
494
- phone = phone.cuda(rank, non_blocking=True)
495
- phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
496
- if hps.if_f0 == 1:
497
- pitch = pitch.cuda(rank, non_blocking=True)
498
- pitchf = pitchf.cuda(rank, non_blocking=True)
499
- sid = sid.cuda(rank, non_blocking=True)
500
- spec = spec.cuda(rank, non_blocking=True)
501
- spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
502
- wave = wave.cuda(rank, non_blocking=True)
503
- # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
504
-
505
- # Calculate
506
- with autocast(enabled=hps.train.fp16_run):
507
- if hps.if_f0 == 1:
508
- (
509
- y_hat,
510
- ids_slice,
511
- x_mask,
512
- z_mask,
513
- (z, z_p, m_p, logs_p, m_q, logs_q),
514
- ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
515
- else:
516
- (
517
- y_hat,
518
- ids_slice,
519
- x_mask,
520
- z_mask,
521
- (z, z_p, m_p, logs_p, m_q, logs_q),
522
- ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
523
- mel = spec_to_mel_torch(
524
- spec,
525
- hps.data.filter_length,
526
- hps.data.n_mel_channels,
527
- hps.data.sampling_rate,
528
- hps.data.mel_fmin,
529
- hps.data.mel_fmax,
530
- )
531
- y_mel = commons.slice_segments(
532
- mel, ids_slice, hps.train.segment_size // hps.data.hop_length
533
- )
534
- with autocast(enabled=False):
535
- y_hat_mel = mel_spectrogram_torch(
536
- y_hat.float().squeeze(1),
537
- hps.data.filter_length,
538
- hps.data.n_mel_channels,
539
- hps.data.sampling_rate,
540
- hps.data.hop_length,
541
- hps.data.win_length,
542
- hps.data.mel_fmin,
543
- hps.data.mel_fmax,
544
- )
545
- if hps.train.fp16_run:
546
- y_hat_mel = y_hat_mel.half()
547
- wave = commons.slice_segments(
548
- wave, ids_slice * hps.data.hop_length, hps.train.segment_size
549
- ) # slice
550
-
551
- # Discriminator
552
- y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
553
- with autocast(enabled=False):
554
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
555
- y_d_hat_r, y_d_hat_g
556
- )
557
- optim_d.zero_grad()
558
- scaler.scale(loss_disc).backward()
559
- scaler.unscale_(optim_d)
560
- grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
561
- scaler.step(optim_d)
562
-
563
- with autocast(enabled=hps.train.fp16_run):
564
- # Generator
565
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
566
- with autocast(enabled=False):
567
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
568
- loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
569
- loss_fm = feature_loss(fmap_r, fmap_g)
570
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
571
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
572
- optim_g.zero_grad()
573
- scaler.scale(loss_gen_all).backward()
574
- scaler.unscale_(optim_g)
575
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
576
- scaler.step(optim_g)
577
- scaler.update()
578
-
579
- if rank == 0:
580
- if global_step % hps.train.log_interval == 0:
581
- lr = optim_g.param_groups[0]["lr"]
582
- logger.info( ""
583
- #"Train Epoch: {} [{:.0f}%]".format(
584
- #epoch, 100.0 * batch_idx / len(train_loader)
585
- #)
586
- )
587
- # Amor For Tensorboard display
588
- if loss_mel > 75:
589
- loss_mel = 75
590
- if loss_kl > 9:
591
- loss_kl = 9
592
-
593
- logger.info([global_step, lr])
594
- logger.info(""
595
- #f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
596
- )
597
- scalar_dict = {
598
- "loss/g/total": loss_gen_all,
599
- "loss/d/total": loss_disc,
600
- "learning_rate": lr,
601
- "grad_norm_d": grad_norm_d,
602
- "grad_norm_g": grad_norm_g,
603
- }
604
- scalar_dict.update(
605
- {
606
- "loss/g/fm": loss_fm,
607
- "loss/g/mel": loss_mel,
608
- "loss/g/kl": loss_kl,
609
- }
610
- )
611
-
612
- scalar_dict.update(
613
- {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
614
- )
615
- scalar_dict.update(
616
- {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
617
- )
618
- scalar_dict.update(
619
- {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
620
- )
621
- image_dict = {
622
- "slice/mel_org": utils.plot_spectrogram_to_numpy(
623
- y_mel[0].data.cpu().numpy()
624
- ),
625
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(
626
- y_hat_mel[0].data.cpu().numpy()
627
- ),
628
- "all/mel": utils.plot_spectrogram_to_numpy(
629
- mel[0].data.cpu().numpy()
630
- ),
631
- }
632
- utils.summarize(
633
- writer=writer,
634
- global_step=global_step,
635
- images=image_dict,
636
- scalars=scalar_dict,
637
- )
638
- global_step += 1
639
- # /Run steps
640
-
641
- if epoch % hps.save_every_epoch == 0 and rank == 0:
642
- print(f"Saved: {hps.name}_e{epoch}_s{global_step}.pth")
643
- if hps.if_latest == 0:
644
- utils.save_checkpoint(
645
- net_g,
646
- optim_g,
647
- hps.train.learning_rate,
648
- epoch,
649
- os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
650
- )
651
- utils.save_checkpoint(
652
- net_d,
653
- optim_d,
654
- hps.train.learning_rate,
655
- epoch,
656
- os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
657
- )
658
- else:
659
- utils.save_checkpoint(
660
- net_g,
661
- optim_g,
662
- hps.train.learning_rate,
663
- epoch,
664
- os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
665
- )
666
- utils.save_checkpoint(
667
- net_d,
668
- optim_d,
669
- hps.train.learning_rate,
670
- epoch,
671
- os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
672
- )
673
- if rank == 0 and hps.save_every_weights == "1":
674
- if hasattr(net_g, "module"):
675
- ckpt = net_g.module.state_dict()
676
- else:
677
- ckpt = net_g.state_dict()
678
- logger.info(
679
- "saving ckpt %s_e%s:%s"
680
- % (
681
- hps.name,
682
- epoch,
683
- savee(
684
- ckpt,
685
- hps.sample_rate,
686
- hps.if_f0,
687
- hps.name + "_e%s_s%s" % (epoch, global_step),
688
- epoch,
689
- hps.version,
690
- hps,
691
- experiment_name,
692
- ),
693
- )
694
- )
695
-
696
- try:
697
- with open("csvdb/stop.csv") as CSVStop:
698
- csv_reader = list(csv.reader(CSVStop))
699
- stopbtn = (
700
- csv_reader[0][0]
701
- if csv_reader is not None
702
- else (lambda: exec('raise ValueError("No data")'))()
703
- )
704
- stopbtn = (
705
- lambda stopbtn: True
706
- if stopbtn.lower() == "true"
707
- else (False if stopbtn.lower() == "false" else stopbtn)
708
- )(stopbtn)
709
- except (ValueError, TypeError, IndexError):
710
- stopbtn = False
711
-
712
- if stopbtn:
713
- logger.info("Stop Button was pressed. The program is closed.")
714
- if hasattr(net_g, "module"):
715
- ckpt = net_g.module.state_dict()
716
- else:
717
- ckpt = net_g.state_dict()
718
- logger.info(
719
- "saving final ckpt:%s"
720
- % (
721
- savee(
722
- ckpt,
723
- hps.sample_rate,
724
- hps.if_f0,
725
- hps.name,
726
- epoch,
727
- hps.version,
728
- hps,
729
- experiment_name,
730
- )
731
- )
732
- )
733
- sleep(1)
734
- with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
735
- csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
736
- csv_writer.writerow(["False"])
737
- os._exit(2333333)
738
-
739
- if rank == 0:
740
- logger.info("") # "====> Epoch: {} {}".format(epoch, epoch_recorder.record())
741
- if epoch > hps.total_epoch and rank == 0:
742
- logger.info("Training is done. The program is closed.")
743
-
744
- if hasattr(net_g, "module"):
745
- ckpt = net_g.module.state_dict()
746
- else:
747
- ckpt = net_g.state_dict()
748
- logger.info(
749
- "saving final ckpt:%s"
750
- % (
751
- savee(
752
- ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
753
- )
754
- )
755
- )
756
- sleep(1)
757
- with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
758
- csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
759
- csv_writer.writerow(["False"])
760
- os._exit(2333333)
761
-
762
-
763
- if __name__ == "__main__":
764
- torch.multiprocessing.set_start_method("spawn")
765
- main()