Pj12 committed
Commit 90bbd28 · verified · 1 Parent(s): 22ac411

Delete train_nsf_sim_cache_sid_load_pretrain.py

train_nsf_sim_cache_sid_load_pretrain.py DELETED
@@ -1,771 +0,0 @@
-import sys, os
-import pickle as p
-now_dir = os.getcwd()
-sys.path.append(os.path.join(now_dir))
-sys.path.append(os.path.join(now_dir, "train"))
-import utils
-Loss_Gen_Per_Epoch = []
-Loss_Disc_Per_Epoch = []
-elapsed_time_record = []
-Lowest_lg = 0
-Lowest_ld = 0
-import datetime
-hps = utils.get_hparams()
-overtrain = hps.overtrain
-experiment_name = hps.name
-os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
-n_gpus = len(hps.gpus.split("-"))
-from random import shuffle, randint
-import traceback, json, argparse, itertools, math, torch, pdb
-
-torch.backends.cudnn.deterministic = False
-torch.backends.cudnn.benchmark = False
-from torch import nn, optim
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
-from lib.infer_pack import commons
-from time import sleep
-from time import time as ttime
-from data_utils import (
-    TextAudioLoaderMultiNSFsid,
-    TextAudioLoader,
-    TextAudioCollateMultiNSFsid,
-    TextAudioCollate,
-    DistributedBucketSampler,
-)
-
-import csv
-
-if hps.version == "v1":
-    from lib.infer_pack.models import (
-        SynthesizerTrnMs256NSFsid as RVC_Model_f0,
-        SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
-        MultiPeriodDiscriminator,
-    )
-elif hps.version == "v2" and hps.Large_HuBert == True:
-    from lib.infer_pack.models import (
-        SynthesizerTrnMs1024NSFsid as RVC_Model_f0,
-        SynthesizerTrnMs1024NSFsid_nono as RVC_Model_nof0,
-        MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
-    )
-else:
-    from lib.infer_pack.models import (
-        SynthesizerTrnMs768NSFsid as RVC_Model_f0,
-        SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
-        MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
-    )
-from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from process_ckpt import savee
-
-global global_step
-global_step = 0
-
-
-def Calculate_format_elapsed_time(elapsed_time):
-    h = int(elapsed_time / 3600)
-    m, s, ms = (
-        int(elapsed_time / 60 - h * 60),
-        int(elapsed_time % 60),
-        round((elapsed_time - int(elapsed_time)) * 10000),
-    )
-    return h, m, s, ms
-
-
-def right_index(List, Value):
-    index = len(List) - 1 - List[::-1].index(Value)
-    return index
-
-
-def formating_time(time):
-    time = time if time >= 10 else f"0{time}"
-    return time
-
-
-class EpochRecorder:
-    def __init__(self):
-        self.last_time = ttime()
-
-    def record(self):
-        now_time = ttime()
-        elapsed_time = now_time - self.last_time
-        self.last_time = now_time
-        elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
-        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        return f"[{current_time}] | ({elapsed_time_str})"
-
-
-def main():
-    n_gpus = torch.cuda.device_count()
-    if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True:
-        n_gpus = 1
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-    children = []
-    for i in range(n_gpus):
-        subproc = mp.Process(
-            target=run,
-            args=(
-                i,
-                n_gpus,
-                hps,
-            ),
-        )
-        children.append(subproc)
-        subproc.start()
-    for i in range(n_gpus):
-        children[i].join()
-
-
-def run(rank, n_gpus, hps):
-    global global_step, loss_disc, loss_gen_all, Loss_Disc_Per_Epoch, Loss_Gen_Per_Epoch, elapsed_time_record, best_epoch, best_global_step, Min_for_Single_epoch, prev_best_epoch
-    if rank == 0:
-        logger = utils.get_logger(hps.model_dir)
-        logger.info(hps)
-        # utils.check_git_hash(hps.model_dir)
-        writer = SummaryWriter(log_dir=hps.model_dir)
-        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
-
-    dist.init_process_group(
-        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
-    )
-    torch.manual_seed(hps.train.seed)
-    if torch.cuda.is_available():
-        torch.cuda.set_device(rank)
-
-    if hps.if_f0 == 1:
-        train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
-    else:
-        train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
-    train_sampler = DistributedBucketSampler(
-        train_dataset,
-        hps.train.batch_size * n_gpus,
-        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200, 1400],  # 16s
-        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
-        num_replicas=n_gpus,
-        rank=rank,
-        shuffle=True,
-    )
-    # It is possible that the dataloader's workers run out of shared memory. Please try to raise your shared memory limit.
-    # num_workers=8 -> num_workers=4
-    if hps.if_f0 == 1:
-        collate_fn = TextAudioCollateMultiNSFsid()
-    else:
-        collate_fn = TextAudioCollate()
-    train_loader = DataLoader(
-        train_dataset,
-        num_workers=4,
-        shuffle=False,
-        pin_memory=True,
-        collate_fn=collate_fn,
-        batch_sampler=train_sampler,
-        persistent_workers=True,
-        prefetch_factor=8,
-    )
-    if hps.if_f0 == 1:
-        net_g = RVC_Model_f0(
-            hps.data.filter_length // 2 + 1,
-            hps.train.segment_size // hps.data.hop_length,
-            **hps.model,
-            is_half=hps.train.fp16_run,
-            sr=hps.sample_rate,
-        )
-    else:
-        net_g = RVC_Model_nof0(
-            hps.data.filter_length // 2 + 1,
-            hps.train.segment_size // hps.data.hop_length,
-            **hps.model,
-            is_half=hps.train.fp16_run,
-        )
-    if torch.cuda.is_available():
-        net_g = net_g.cuda(rank)
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
-    if torch.cuda.is_available():
-        net_d = net_d.cuda(rank)
-    optim_g = torch.optim.AdamW(
-        net_g.parameters(),
-        hps.train.learning_rate,
-        betas=hps.train.betas,
-        eps=hps.train.eps,
-    )
-    optim_d = torch.optim.AdamW(
-        net_d.parameters(),
-        hps.train.learning_rate,
-        betas=hps.train.betas,
-        eps=hps.train.eps,
-    )
-    # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
-    if torch.cuda.is_available():
-        net_g = DDP(net_g, device_ids=[rank])
-        net_d = DDP(net_d, device_ids=[rank])
-    else:
-        net_g = DDP(net_g)
-        net_d = DDP(net_d)
-
-    try:  # Resume automatically if checkpoints can be loaded
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
-        )  # loading D usually works
-        if rank == 0:
-            logger.info("loaded D")
-        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g, load_opt=0)
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
-        )
-        global_step = (epoch_str - 1) * len(train_loader)
-        # epoch_str = 1
-        # global_step = 0
-    except:  # If nothing can be loaded (first run), load the pretrained models
-        # traceback.print_exc()
-        epoch_str = 1
-        global_step = 0
-        if hps.pretrainG != "":
-            if rank == 0:
-                logger.info("loaded pretrained %s" % (hps.pretrainG))
-            print(
-                net_g.module.load_state_dict(
-                    torch.load(hps.pretrainG, map_location="cpu")["model"]
-                )
-            )  ## testing: do not load the optimizer state
-        if hps.pretrainD != "":
-            if rank == 0:
-                logger.info("loaded pretrained %s" % (hps.pretrainD))
-            print(
-                net_d.module.load_state_dict(
-                    torch.load(hps.pretrainD, map_location="cpu")["model"]
-                )
-            )
-
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-    )
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-    )
-
-    scaler = GradScaler(enabled=hps.train.fp16_run)
-    #
-    # if hps.total_epoch < 100:
-    #     Min_for_Single_epoch = int(hps.total_epoch / 2)
-    # else:
-    #     Min_for_Single_epoch = 50
-    Min_for_Single_epoch = 1
-    #
-    # Restore the per-epoch loss history when resuming, trimming entries past the resume epoch
-    if os.path.exists(f"Loss_Gen_Per_Epoch_{hps.name}.p") and os.path.exists(f"Loss_Disc_Per_Epoch_{hps.name}.p"):
-        with open(f"Loss_Gen_Per_Epoch_{hps.name}.p", "rb") as Loss_Gen:
-            Loss_Gen_Per_Epoch = p.load(Loss_Gen)
-        for i in range(len(Loss_Gen_Per_Epoch) - epoch_str + 1):
-            Loss_Gen_Per_Epoch.pop()
-        with open(f"Loss_Disc_Per_Epoch_{hps.name}.p", "rb") as Loss_Disc:
-            Loss_Disc_Per_Epoch = p.load(Loss_Disc)
-        for i in range(len(Loss_Disc_Per_Epoch) - epoch_str + 1):
-            Loss_Disc_Per_Epoch.pop()
-    if os.path.exists(f"prev_best_epoch_{hps.name}.p"):
-        with open(f"prev_best_epoch_{hps.name}.p", "rb") as prev_best_epoch_f:
-            prev_best_epoch = p.load(prev_best_epoch_f)
-    #
-    cache = []
-    for epoch in range(epoch_str, hps.train.epochs + 1):
-        start_time = ttime()
-        if rank == 0:
-            train_and_evaluate(
-                rank,
-                epoch,
-                hps,
-                [net_g, net_d],
-                [optim_g, optim_d],
-                [scheduler_g, scheduler_d],
-                scaler,
-                [train_loader, None],
-                logger,
-                [writer, writer_eval],
-                cache,
-            )
-
-            # Printing and saving stuff
-            loss_gen_all = loss_gen_all.item()
-            loss_disc = loss_disc.item()
-            #
-            Loss_Gen_Per_Epoch.append(loss_gen_all)
-            Loss_Disc_Per_Epoch.append(loss_disc)
-            # print(hps.train.epochs, epoch_str)
-            #
-            with open(f"Loss_Gen_Per_Epoch_{hps.name}.p", "wb") as Loss_Gen:
-                p.dump(Loss_Gen_Per_Epoch, Loss_Gen)
-                Loss_Gen.close()
-            with open(f"Loss_Disc_Per_Epoch_{hps.name}.p", "wb") as Loss_Disc:
-                p.dump(Loss_Disc_Per_Epoch, Loss_Disc)
-                Loss_Disc.close()
-            #
-            Lowest_lg = f"{min(Loss_Gen_Per_Epoch):.5f}, epoch: {right_index(Loss_Gen_Per_Epoch, min(Loss_Gen_Per_Epoch)) + 1}"
-            Lowest_ld = f"{min(Loss_Disc_Per_Epoch):.5f}, epoch: {right_index(Loss_Disc_Per_Epoch, min(Loss_Disc_Per_Epoch)) + 1}"
-            print(f"{hps.name}_e{epoch}_s{global_step} | Loss gen total: {Loss_Gen_Per_Epoch[-1]:.5f} | Lowest loss G: {Lowest_lg}\n Loss disc: {Loss_Disc_Per_Epoch[-1]:.5f} | Lowest loss D: {Lowest_ld}")
-            print(f"Specific Value: loss gen={loss_gen:.3f}, loss fm={loss_fm:.3f}, loss mel={loss_mel:.3f}, loss kl={loss_kl:.3f}")
-            #
-            # Save a "Best Epoch" weight file whenever the latest generator loss is the lowest seen so far
-            if len(Loss_Gen_Per_Epoch) > Min_for_Single_epoch and epoch % hps.save_every_epoch != 0:
-                if min(Loss_Gen_Per_Epoch[Min_for_Single_epoch::1]) == Loss_Gen_Per_Epoch[-1]:
-                    if hasattr(net_g, "module"):
-                        ckpt = net_g.module.state_dict()
-                    else:
-                        ckpt = net_g.state_dict()
-                    savee(ckpt, hps.sample_rate, hps.if_f0, hps.name + "_e%s_s%s" % (epoch, global_step), epoch, hps.version, hps, experiment_name)
-                    os.rename(
-                        f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}.pth",
-                        f"logs/{hps.name}/weights/{hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth",
-                    )
-                    print(f"Saved: {hps.name}_e{epoch}_s{global_step}_Best_Epoch.pth")
-                    try:
-                        os.remove(prev_best_epoch)
-                    except:
-                        print("Nothing to remove; if there is, you may need to check again")
-                        pass
-                    else:
-                        print(f"{os.path.split(prev_best_epoch)[-1]} Removed")
-                    best_epoch = epoch
-                    best_global_step = global_step
-                    prev_best_epoch = f"logs/{hps.name}/weights/{hps.name}_e{best_epoch}_s{best_global_step}_Best_Epoch.pth"
-                    with open(f"prev_best_epoch_{hps.name}.p", "wb") as prev_best_epoch_f:
-                        p.dump(prev_best_epoch, prev_best_epoch_f)
-            #
-            elapsed_time = ttime() - start_time
-            elapsed_time_record.append(elapsed_time)
-            if epoch - 1 == epoch_str:
-                elapsed_time_record.pop(0)
-            elapsed_time_avg = sum(elapsed_time_record) / len(elapsed_time_record)
-            time_left = elapsed_time_avg * (hps.total_epoch - epoch)
-            hour, minute, second, millisec = Calculate_format_elapsed_time(elapsed_time)
-            hour_left, minute_left, second_left, millisec_left = Calculate_format_elapsed_time(time_left)
-            print(f"Time Elapsed: {hour}h:{formating_time(minute)}m:{formating_time(second)}s:{millisec}ms || Time left: {hour_left}h:{formating_time(minute_left)}m:{formating_time(second_left)}s:{millisec_left}ms\n")
-            #
-            # Early stop once the best generator loss has not improved for `overtrain` epochs (overtrain == -1 disables this)
-            if ((len(Loss_Gen_Per_Epoch) - right_index(Loss_Gen_Per_Epoch, min(Loss_Gen_Per_Epoch)) + 1) > overtrain and overtrain != -1):
-                logger.info("Over Train threshold reached. Training is done.")
-                print("Over Train threshold reached. Training is done.")
-
-                if hasattr(net_g, "module"):
-                    ckpt = net_g.module.state_dict()
-                else:
-                    ckpt = net_g.state_dict()
-                logger.info(
-                    "saving final ckpt:%s"
-                    % (
-                        savee(
-                            ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
-                        )
-                    )
-                )
-                sleep(1)
-                with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
-                    csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
-                    csv_writer.writerow(["False"])
-                os._exit(2333333)
-
-        else:
-            train_and_evaluate(
-                rank,
-                epoch,
-                hps,
-                [net_g, net_d],
-                [optim_g, optim_d],
-                [scheduler_g, scheduler_d],
-                scaler,
-                [train_loader, None],
-                None,
-                None,
-                cache,
-            )
-        scheduler_g.step()
-        scheduler_d.step()
-        # gathered_tensors_gen = [torch.zeros_like(loss_gen_all) for _ in range(n_gpus)]
-        # gathered_tensors_disc = [torch.zeros_like(loss_disc) for _ in range(n_gpus)]
-        # dist.all_gather(gathered_tensors_gen, loss_gen_all)
-        # dist.all_gather(gathered_tensors_disc, loss_disc)
-
-
-#######
-
-def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
-):
-    global loss_gen_all, loss_disc, ckpt, loss_kl, loss_fm, loss_gen, loss_mel
-    net_g, net_d = nets
-    optim_g, optim_d = optims
-    train_loader, eval_loader = loaders
-    if writers is not None:
-        writer, writer_eval = writers
-
-    train_loader.batch_sampler.set_epoch(epoch)
-    global global_step
-
-    net_g.train()
-    net_d.train()
-
-    # Prepare data iterator
-    if hps.if_cache_data_in_gpu == True:
-        # Use Cache
-        data_iterator = cache
-        if cache == []:
-            # Make new cache
-            for batch_idx, info in enumerate(train_loader):
-                # Unpack
-                if hps.if_f0 == 1:
-                    (
-                        phone,
-                        phone_lengths,
-                        pitch,
-                        pitchf,
-                        spec,
-                        spec_lengths,
-                        wave,
-                        wave_lengths,
-                        sid,
-                    ) = info
-                else:
-                    (
-                        phone,
-                        phone_lengths,
-                        spec,
-                        spec_lengths,
-                        wave,
-                        wave_lengths,
-                        sid,
-                    ) = info
-                # Load on CUDA
-                if torch.cuda.is_available():
-                    phone = phone.cuda(rank, non_blocking=True)
-                    phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
-                    if hps.if_f0 == 1:
-                        pitch = pitch.cuda(rank, non_blocking=True)
-                        pitchf = pitchf.cuda(rank, non_blocking=True)
-                    sid = sid.cuda(rank, non_blocking=True)
-                    spec = spec.cuda(rank, non_blocking=True)
-                    spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
-                    wave = wave.cuda(rank, non_blocking=True)
-                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
-                # Cache on list
-                if hps.if_f0 == 1:
-                    cache.append(
-                        (
-                            batch_idx,
-                            (
-                                phone,
-                                phone_lengths,
-                                pitch,
-                                pitchf,
-                                spec,
-                                spec_lengths,
-                                wave,
-                                wave_lengths,
-                                sid,
-                            ),
-                        )
-                    )
-                else:
-                    cache.append(
-                        (
-                            batch_idx,
-                            (
-                                phone,
-                                phone_lengths,
-                                spec,
-                                spec_lengths,
-                                wave,
-                                wave_lengths,
-                                sid,
-                            ),
-                        )
-                    )
-        else:
-            # Load shuffled cache
-            shuffle(cache)
-    else:
-        # Loader
-        data_iterator = enumerate(train_loader)
-
-    # Run steps
-    epoch_recorder = EpochRecorder()
-
-    for batch_idx, info in data_iterator:
-        # Data
-        ## Unpack
-        if hps.if_f0 == 1:
-            (
-                phone,
-                phone_lengths,
-                pitch,
-                pitchf,
-                spec,
-                spec_lengths,
-                wave,
-                wave_lengths,
-                sid,
-            ) = info
-        else:
-            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
-        ## Load on CUDA
-        if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
-            phone = phone.cuda(rank, non_blocking=True)
-            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
-            if hps.if_f0 == 1:
-                pitch = pitch.cuda(rank, non_blocking=True)
-                pitchf = pitchf.cuda(rank, non_blocking=True)
-            sid = sid.cuda(rank, non_blocking=True)
-            spec = spec.cuda(rank, non_blocking=True)
-            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
-            wave = wave.cuda(rank, non_blocking=True)
-            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
-
-        # Calculate
-        with autocast(enabled=hps.train.fp16_run):
-            if hps.if_f0 == 1:
-                (
-                    y_hat,
-                    ids_slice,
-                    x_mask,
-                    z_mask,
-                    (z, z_p, m_p, logs_p, m_q, logs_q),
-                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
-            else:
-                (
-                    y_hat,
-                    ids_slice,
-                    x_mask,
-                    z_mask,
-                    (z, z_p, m_p, logs_p, m_q, logs_q),
-                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
-            mel = spec_to_mel_torch(
-                spec,
-                hps.data.filter_length,
-                hps.data.n_mel_channels,
-                hps.data.sampling_rate,
-                hps.data.mel_fmin,
-                hps.data.mel_fmax,
-            )
-            y_mel = commons.slice_segments(
-                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
-            )
-            with autocast(enabled=False):
-                y_hat_mel = mel_spectrogram_torch(
-                    y_hat.float().squeeze(1),
-                    hps.data.filter_length,
-                    hps.data.n_mel_channels,
-                    hps.data.sampling_rate,
-                    hps.data.hop_length,
-                    hps.data.win_length,
-                    hps.data.mel_fmin,
-                    hps.data.mel_fmax,
-                )
-            if hps.train.fp16_run == True:
-                y_hat_mel = y_hat_mel.half()
-            wave = commons.slice_segments(
-                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
-            )  # slice
-
-            # Discriminator
-            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
-            with autocast(enabled=False):
-                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
-                    y_d_hat_r, y_d_hat_g
-                )
-        optim_d.zero_grad()
-        scaler.scale(loss_disc).backward()
-        scaler.unscale_(optim_d)
-        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
-        scaler.step(optim_d)
-
-        with autocast(enabled=hps.train.fp16_run):
-            # Generator
-            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
-            with autocast(enabled=False):
-                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
-                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
-                loss_fm = feature_loss(fmap_r, fmap_g)
-                loss_gen, losses_gen = generator_loss(y_d_hat_g)
-                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
-        optim_g.zero_grad()
-        scaler.scale(loss_gen_all).backward()
-        scaler.unscale_(optim_g)
-        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
-        scaler.step(optim_g)
-        scaler.update()
-
-        if rank == 0:
-            if global_step % hps.train.log_interval == 0:
-                lr = optim_g.param_groups[0]["lr"]
-                logger.info(
-                    ""
-                    # "Train Epoch: {} [{:.0f}%]".format(
-                    #     epoch, 100.0 * batch_idx / len(train_loader)
-                    # )
-                )
-                # Amor For Tensorboard display
-                if loss_mel > 75:
-                    loss_mel = 75
-                if loss_kl > 9:
-                    loss_kl = 9
-
-                logger.info([global_step, lr])
-                logger.info(
-                    ""
-                    # f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
-                )
-                scalar_dict = {
-                    "loss/g/total": loss_gen_all,
-                    "loss/d/total": loss_disc,
-                    "learning_rate": lr,
-                    "grad_norm_d": grad_norm_d,
-                    "grad_norm_g": grad_norm_g,
-                }
-                scalar_dict.update(
-                    {
-                        "loss/g/fm": loss_fm,
-                        "loss/g/mel": loss_mel,
-                        "loss/g/kl": loss_kl,
-                    }
-                )
-
-                scalar_dict.update(
-                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
-                )
-                scalar_dict.update(
-                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
-                )
-                scalar_dict.update(
-                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
-                )
-                image_dict = {
-                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
-                        y_mel[0].data.cpu().numpy()
-                    ),
-                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
-                        y_hat_mel[0].data.cpu().numpy()
-                    ),
-                    "all/mel": utils.plot_spectrogram_to_numpy(
-                        mel[0].data.cpu().numpy()
-                    ),
-                }
-                utils.summarize(
-                    writer=writer,
-                    global_step=global_step,
-                    images=image_dict,
-                    scalars=scalar_dict,
-                )
-        global_step += 1
-    # /Run steps
-
-    if epoch % hps.save_every_epoch == 0 and rank == 0:
-        print(f"Saved: {hps.name}_e{epoch}_s{global_step}.pth")
-        if hps.if_latest == 0:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
-            )
-        else:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
-            )
-        if rank == 0 and hps.save_every_weights == "1":
-            if hasattr(net_g, "module"):
-                ckpt = net_g.module.state_dict()
-            else:
-                ckpt = net_g.state_dict()
-            logger.info(
-                "saving ckpt %s_e%s:%s"
-                % (
-                    hps.name,
-                    epoch,
-                    savee(
-                        ckpt,
-                        hps.sample_rate,
-                        hps.if_f0,
-                        hps.name + "_e%s_s%s" % (epoch, global_step),
-                        epoch,
-                        hps.version,
-                        hps,
-                        experiment_name,
-                    ),
-                )
-            )
-
-    try:
-        with open("csvdb/stop.csv") as CSVStop:
-            csv_reader = list(csv.reader(CSVStop))
-            stopbtn = (
-                csv_reader[0][0]
-                if csv_reader is not None
-                else (lambda: exec('raise ValueError("No data")'))()
-            )
-            stopbtn = (
-                lambda stopbtn: True
-                if stopbtn.lower() == "true"
-                else (False if stopbtn.lower() == "false" else stopbtn)
-            )(stopbtn)
-    except (ValueError, TypeError, IndexError):
-        stopbtn = False
-
-    if stopbtn:
-        logger.info("Stop Button was pressed. The program is closed.")
-        if hasattr(net_g, "module"):
-            ckpt = net_g.module.state_dict()
-        else:
-            ckpt = net_g.state_dict()
-        logger.info(
-            "saving final ckpt:%s"
-            % (
-                savee(
-                    ckpt,
-                    hps.sample_rate,
-                    hps.if_f0,
-                    hps.name,
-                    epoch,
-                    hps.version,
-                    hps,
-                    experiment_name,
-                )
-            )
-        )
-        sleep(1)
-        with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
-            csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
-            csv_writer.writerow(["False"])
-        os._exit(2333333)
-
-    if rank == 0:
-        logger.info("")  # "====> Epoch: {} {}".format(epoch, epoch_recorder.record())
-    if epoch > hps.total_epoch and rank == 0:
-        logger.info("Training is done. The program is closed.")
-
-        if hasattr(net_g, "module"):
-            ckpt = net_g.module.state_dict()
-        else:
-            ckpt = net_g.state_dict()
-        logger.info(
-            "saving final ckpt:%s"
-            % (
-                savee(
-                    ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps, experiment_name
-                )
-            )
-        )
-        sleep(1)
-        with open("csvdb/stop.csv", "w+", newline="") as STOPCSVwrite:
-            csv_writer = csv.writer(STOPCSVwrite, delimiter=",")
-            csv_writer.writerow(["False"])
-        os._exit(2333333)
-
-
-if __name__ == "__main__":
-    torch.multiprocessing.set_start_method("spawn")
-    main()