update zipvoice demo
- app.py +91 -4
- infer_zipvoice.py +302 -0
- requirements.txt +44 -0
- utils.py +125 -0
- zipvoice/__init__.py +0 -0
- zipvoice/bin/compute_fbank.py +278 -0
- zipvoice/bin/generate_averaged_model.py +239 -0
- zipvoice/bin/infer_zipvoice.py +617 -0
- zipvoice/bin/infer_zipvoice_dialog.py +756 -0
- zipvoice/bin/infer_zipvoice_onnx.py +715 -0
- zipvoice/bin/onnx_export.py +404 -0
- zipvoice/bin/train_zipvoice.py +1110 -0
- zipvoice/bin/train_zipvoice_distill.py +1159 -0
- zipvoice/dataset/datamodule.py +319 -0
- zipvoice/dataset/dataset.py +105 -0
- zipvoice/eval/evaluate_sim.py +535 -0
- zipvoice/eval/evaluate_utmos.py +314 -0
- zipvoice/eval/evaluate_wer_hubert.py +192 -0
- zipvoice/eval/evaluate_wer_seedtts.py +200 -0
- zipvoice/models/modules/scaling.py +1563 -0
- zipvoice/models/modules/solver.py +281 -0
- zipvoice/models/modules/zipformer.py +1680 -0
- zipvoice/models/modules/zipformer_two_stream.py +264 -0
- zipvoice/models/zipvoice.py +534 -0
- zipvoice/models/zipvoice_dialog.py +358 -0
- zipvoice/models/zipvoice_distill.py +94 -0
- zipvoice/tokenizer/normalizer.py +170 -0
- zipvoice/tokenizer/tokenizer.py +618 -0
- zipvoice/utils/checkpoint.py +572 -0
- zipvoice/utils/common.py +604 -0
- zipvoice/utils/diagnostics.py +723 -0
- zipvoice/utils/feature.py +120 -0
- zipvoice/utils/hooks.py +111 -0
- zipvoice/utils/lr_scheduler.py +245 -0
- zipvoice/utils/optim.py +868 -0
- zipvoice/utils/scaling_converter.py +105 -0
app.py
CHANGED
@@ -1,7 +1,94 @@
import spaces
import os
from huggingface_hub import login
import gradio as gr
from cached_path import cached_path
import tempfile
from vinorm import TTSnorm
from infer_zipvoice import model, tokenizer, feature_extractor, device
from utils import preprocess_ref_audio_text, save_spectrogram

# Retrieve token from secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Log in to Hugging Face
if hf_token:
    login(token=hf_token)

def post_process(text):
    text = " " + text + " "
    text = text.replace(" . . ", " . ")
    text = " " + text + " "
    text = text.replace(" .. ", " . ")
    text = " " + text + " "
    text = text.replace(" , , ", " , ")
    text = " " + text + " "
    text = text.replace(" ,, ", " , ")
    text = " " + text + " "
    text = text.replace('"', "")
    return " ".join(text.split())

@spaces.GPU
def infer_tts(ref_audio_orig: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):

    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")
    if len(gen_text.split()) > 1000:
        raise gr.Error("Please enter text content with less than 1000 words.")

    try:
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, "")
        final_wave = generate_sentence(
            ref_text.lower(),
            ref_audio,
            post_process(TTSnorm(gen_text)).lower(),
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            device=device,
            speed=speed
        )
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
            spectrogram_path = tmp_spectrogram.name
            save_spectrogram(spectrogram, spectrogram_path)

        return (final_sample_rate, final_wave), spectrogram_path
    except Exception as e:
        raise gr.Error(f"Error generating voice: {e}")

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 ZipVoice: Vietnamese Text-to-Speech Synthesis.
    # The model was trained with approximately 150 hours of data on an RTX 3090 GPU.
    Enter text and upload a sample voice to generate natural speech.
    """)

    with gr.Row():
        ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)

    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    btn_synthesize = gr.Button("🔥 Generate Voice")

    with gr.Row():
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        output_spectrogram = gr.Image(label="📊 Spectrogram")

    model_limitations = gr.Textbox(
        value="""1. This model may not perform well with numerical characters, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audios may be inconsistent or choppy => It is recommended to select clearly pronounced sample audios with minimal pauses for better synthesis quality.
3. By default, the reference audio is transcribed with the pho-whisper-medium model, which may not always recognize Vietnamese accurately, resulting in poor voice synthesis quality.
4. Inference with overly long paragraphs may produce poor results.""",
        label="❗ Model Limitations",
        lines=4,
        interactive=False
    )

    btn_synthesize.click(infer_tts, inputs=[ref_audio, gen_text, speed], outputs=[output_audio, output_spectrogram])

# Run Gradio with share=True to get a gradio.live link
demo.queue().launch()
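Note that infer_tts above references generate_sentence, vocoder, spectrogram and final_sample_rate without importing or defining them; generate_sentence, vocoder and sampling_rate do exist at module level in infer_zipvoice.py below. A minimal sketch of how the handler could be wired up, assuming those objects are imported from infer_zipvoice and that the returned waveform is a (1, num_samples) CPU tensor; the synthesize helper and the shape assumption are illustrative, not part of the commit:

# Hypothetical wiring sketch, not part of the commit. Assumes infer_zipvoice
# also exports generate_sentence, vocoder and sampling_rate, and that
# generate_sentence returns a (1, num_samples) float tensor on CPU.
import tempfile

from infer_zipvoice import (
    model, tokenizer, feature_extractor, device,
    generate_sentence, vocoder, sampling_rate,
)
from utils import preprocess_ref_audio_text, save_spectrogram

def synthesize(ref_audio_path: str, text: str, speed: float = 1.0):
    # Clean/transcribe the reference audio, then synthesize the target text.
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_path, "")
    wav = generate_sentence(
        ref_text.lower(), ref_audio, text.lower(),
        model=model, vocoder=vocoder, tokenizer=tokenizer,
        feature_extractor=feature_extractor, device=device,
        speed=speed, sampling_rate=sampling_rate,
    )
    wav_np = wav.squeeze(0).numpy()            # (num_samples,) for gr.Audio
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        save_spectrogram(wav_np, tmp.name)     # utils computes the STFT itself
        spectrogram_path = tmp.name
    return (sampling_rate, wav_np), spectrogram_path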
infer_zipvoice.py
ADDED
@@ -0,0 +1,302 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. If no local model is specified,
required files will be automatically downloaded from HuggingFace.

Usage:

Note: If you are having trouble connecting to HuggingFace,
try switching the endpoint to the mirror site:
export HF_ENDPOINT=https://hf-mirror.com

(1) Inference of a single sentence:

python3 -m zipvoice.bin.infer_zipvoice \
    --model-name "zipvoice" \
    --prompt-wav prompt.wav \
    --prompt-text "I am a prompt." \
    --text "I am a sentence." \
    --res-wav-path result.wav

(2) Inference of a list of sentences:

python3 -m zipvoice.bin.infer_zipvoice \
    --model-name "zipvoice" \
    --test-list test.tsv \
    --res-dir results

`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.

Each line of `test.tsv` is in the format of
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
"""

import argparse
import datetime as dt
import json
import os
from typing import Optional

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
PRETRAINED_MODEL = {
    "zipvoice": "zipvoice/model.pt",
    "zipvoice_distill": "zipvoice_distill/model.pt",
}
TOKEN_FILE = {
    "zipvoice": "zipvoice/tokens.txt",
    "zipvoice_distill": "zipvoice_distill/tokens.txt",
}
MODEL_CONFIG = {
    "zipvoice": "zipvoice/zipvoice_base.json",
    "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
}

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

def get_vocoder(vocos_local_path: Optional[str] = None):
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder


def generate_sentence(
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate the waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.
    Returns:
        wav (torch.Tensor): The generated waveform, moved to CPU.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)

    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)

    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    # metrics = {
    #     "t": t,
    #     "t_no_vocoder": t_no_vocoder,
    #     "t_vocoder": t_vocoder,
    #     "wav_seconds": wav_seconds,
    #     "rtf": rtf,
    #     "rtf_no_vocoder": rtf_no_vocoder,
    #     "rtf_vocoder": rtf_vocoder,
    # }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    # torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
    # return metrics
    return wav.cpu()

model_defaults = {
    "zipvoice": {
        "num_step": 16,
        "guidance_scale": 1.0,
    },
    "zipvoice_distill": {
        "num_step": 8,
        "guidance_scale": 3.0,
    },
}

device = torch.device("cuda", 0)

print("Loading model...")
model_config = "ckpt/model.json"

with open(model_config, "r") as f:
    model_config = json.load(f)

token_file = "ckpt/tokens.txt"

tokenizer = EspeakTokenizer(token_file=token_file, lang="vi")

tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

model_ckpt = "ckpt/model.pt"

model = ZipVoice(
    **model_config["model"],
    **tokenizer_config,
)

load_checkpoint(filename=model_ckpt, model=model, strict=True)

model = model.to(device)
model.eval()

vocoder = get_vocoder(None)
vocoder = vocoder.to(device)
vocoder.eval()

if model_config["feature"]["type"] == "vocos":
    feature_extractor = VocosFbank()
else:
    raise NotImplementedError(
        f"Unsupported feature type: {model_config['feature']['type']}"
    )
sampling_rate = model_config["feature"]["sampling_rate"]

# generate_sentence(
#     save_path=res_wav_path,
#     prompt_text=prompt_text,
#     prompt_wav=prompt_wav,
#     text=text,
#     model=model,
#     vocoder=vocoder,
#     tokenizer=tokenizer,
#     feature_extractor=feature_extractor,
#     device=device,
#     num_step=16,
#     guidance_scale=1.0,
#     speed=speed,
#     t_shift=0.5,
#     target_rms=0.1,
#     feat_scale=0.1,
#     sampling_rate=sampling_rate,
# )

# print("Done")
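The module docstring above describes batch inference driven by a tab-separated test list. A short sketch of the same loop in Python, assuming the ckpt/ files and a CUDA device are available (this module loads the model at import time); test.tsv and the results directory are placeholders:

# Sketch: batch synthesis from a TSV list as described in the docstring.
# Each line: {wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}
import os
import torchaudio
from infer_zipvoice import (
    model, vocoder, tokenizer, feature_extractor, device,
    generate_sentence, sampling_rate,
)

res_dir = "results"                              # placeholder output directory
os.makedirs(res_dir, exist_ok=True)
with open("test.tsv", encoding="utf-8") as f:    # placeholder test list
    for line in f:
        wav_name, prompt_text, prompt_wav, text = line.rstrip("\n").split("\t")
        wav = generate_sentence(
            prompt_text, prompt_wav, text,
            model=model, vocoder=vocoder, tokenizer=tokenizer,
            feature_extractor=feature_extractor, device=device,
            sampling_rate=sampling_rate,
        )
        # Mirrors the commented-out torchaudio.save call in generate_sentence.
        torchaudio.save(os.path.join(res_dir, f"{wav_name}.wav"), wav, sample_rate=sampling_rate)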
requirements.txt
ADDED
@@ -0,0 +1,44 @@
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html

torch<=2.6.0
torchaudio<=2.6.0
lhotse
tensorboard
vocos

# Normalization
cn2an
inflect
unidecode

# Tokenization
piper_phonemize

k2==1.24.4.dev20250208+cuda12.4.torch2.5.1 --find-links https://k2-fsa.github.io/k2/cuda-cn.html

transformers
bitsandbytes>0.37.0
vinorm
cached_path
huggingface_hub
gradio
accelerate>=0.33.0
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torchdiffeq
tqdm>=4.65.0
transformers_stream_generator
wandb
x_transformers>=1.31.14
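torch and numpy are capped rather than floored here, so a newer environment can silently break the pins. An optional sketch (not part of the Space) that asserts the caps at startup; it assumes the packaging library is installed:

# Sketch: verify the capped requirements (torch<=2.6.0, numpy<=1.26.4) at runtime.
from importlib.metadata import version
from packaging.version import Version  # assumes `packaging` is available

caps = {"torch": "2.6.0", "numpy": "1.26.4"}
for name, cap in caps.items():
    installed = Version(version(name))
    assert installed <= Version(cap), f"{name} {installed} exceeds the pinned cap {cap}"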
utils.py
ADDED
@@ -0,0 +1,125 @@
from pydub import AudioSegment, silence
import tempfile
import hashlib
import matplotlib.pylab as plt
import librosa
from transformers import pipeline

def initialize_asr_pipeline(device: str = device, dtype=None):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )

# transcribe
def transcribe(ref_audio, language=None):
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device=device)
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()

def caculate_spec(audio):
    # Compute spectrogram (Short-Time Fourier Transform)
    stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
    spectrogram = np.abs(stft)
    # Convert to dB
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    return spectrogram_db

def save_spectrogram(audio, path):
    spectrogram = caculate_spec(audio)
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()

def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):

    show_info("Converting audio...")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:

        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached asr transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with a proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text
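utils.py relies on several names it neither defines nor imports (torch, np, device, asr_pipe, _ref_audio_cache, remove_silence_edges). A hedged sketch of the module-level definitions it appears to expect; the silence threshold and the exact edge-trimming behaviour are assumptions, not part of the commit:

# Sketch of the globals utils.py uses but does not define in this revision.
import numpy as np
import torch
from pydub import AudioSegment
from pydub.silence import detect_leading_silence

device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = None            # populated lazily by initialize_asr_pipeline()
_ref_audio_cache = {}      # maps audio hash -> cached transcription

def remove_silence_edges(audio: AudioSegment, silence_threshold: float = -42.0) -> AudioSegment:
    # Trim leading and trailing silence; the threshold here is a guess.
    start = detect_leading_silence(audio, silence_threshold=silence_threshold)
    end = len(audio) - detect_leading_silence(audio.reverse(), silence_threshold=silence_threshold)
    return audio[start:end]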
zipvoice/__init__.py
ADDED
File without changes
zipvoice/bin/compute_fbank.py
ADDED
@@ -0,0 +1,278 @@
#!/usr/bin/env python3
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
#                                            Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python3 -m zipvoice.bin.compute_fbank \
    --source-dir data/manifests \
    --dest-dir data/fbank \
    --dataset libritts \
    --subset dev-other \
    --sampling-rate 24000 \
    --num-jobs 20

The input would be data/manifests/libritts_cuts_dev-other.jsonl.gz or
(libritts_supervisions_dev-other.jsonl.gz and libritts_recordings_dev-other.jsonl.gz)

The output would be data/fbank/libritts_cuts_dev-other.jsonl.gz
"""


import argparse
import logging
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path

import lhotse
import torch
from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy

from zipvoice.utils.feature import VocosFbank

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The target sampling rate, the audio will be resampled to it.",
    )

    parser.add_argument(
        "--type",
        type=str,
        default="vocos",
        help="fbank type",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        help="Dataset name.",
    )

    parser.add_argument(
        "--subset",
        type=str,
        help="The subset of the dataset.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        default="data/fbank",
        help="The destination directory of manifest files.",
    )

    parser.add_argument(
        "--split-cuts",
        type=str2bool,
        default=False,
        help="Whether to use split cuts.",
    )

    parser.add_argument(
        "--split-begin",
        type=int,
        help="Start idx of split cuts.",
    )

    parser.add_argument(
        "--split-end",
        type=int,
        help="End idx of split cuts.",
    )

    parser.add_argument(
        "--batch-duration",
        type=int,
        default=1000,
        help="The batch duration when computing the features.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="The number of extractor workers.",
    )

    return parser.parse_args()


def compute_fbank_split_single(params, idx):
    lhotse.set_audio_duration_mismatch_tolerance(0.1)  # for emilia
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)

    if not src_dir.exists():
        logging.error(f"{src_dir} not exists")
        return

    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    num_digits = 8
    if params.type == "vocos":
        extractor = VocosFbank()
    else:
        raise NotImplementedError(f"{params.type} is not supported")

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    idx = f"{idx}".zfill(num_digits)
    cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"

    if (src_dir / cuts_filename).is_file():
        logging.info(f"Loading manifests {src_dir / cuts_filename}")
        cut_set = load_manifest_lazy(src_dir / cuts_filename)
    else:
        logging.warning(f"Raw {cuts_filename} not exists, skipping")
        return

    cut_set = cut_set.resample(params.sampling_rate)

    if (output_dir / cuts_filename).is_file():
        logging.info(f"{cuts_filename} already exists - skipping.")
        return

    logging.info(f"Processing {subset}.{idx} of {prefix}")

    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
        num_workers=4,
        batch_duration=params.batch_duration,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )
    cut_set.to_file(output_dir / cuts_filename)


def compute_fbank_split(params):
    if params.split_end < params.split_begin:
        logging.warning(
            f"Split begin should be smaller than split end, given "
            f"{params.split_begin} -> {params.split_end}."
        )

    with Pool(max_workers=params.num_jobs) as pool:
        futures = [
            pool.submit(compute_fbank_split_single, params, i)
            for i in range(params.split_begin, params.split_end)
        ]
        for f in futures:
            f.result()
            f.done()


def compute_fbank(params):
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)
    num_jobs = params.num_jobs
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"

    if (src_dir / cut_set_name).is_file():
        logging.info(f"Loading manifests {src_dir / cut_set_name}")
        cut_set = load_manifest_lazy(src_dir / cut_set_name)
    else:
        recordings = load_manifest_lazy(
            src_dir / f"{prefix}_recordings_{subset}.{suffix}"
        )
        supervisions = load_manifest_lazy(
            src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
        )
        cut_set = CutSet.from_manifests(
            recordings=recordings,
            supervisions=supervisions,
        )

    cut_set = cut_set.resample(params.sampling_rate)
    if params.type == "vocos":
        extractor = VocosFbank()
    else:
        raise NotImplementedError(f"{params.type} is not supported")

    cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
    if (output_dir / cuts_filename).is_file():
        logging.info(f"{prefix} {subset} already exists - skipping.")
        return
    logging.info(f"Processing {subset} of {prefix}")

    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}",
        num_jobs=num_jobs,
        storage_type=LilcomChunkyWriter,
    )
    cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
    if args.split_cuts:
        compute_fbank_split(params=args)
    else:
        compute_fbank(params=args)
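To confirm the output naming described in the docstring, a brief sketch that loads the resulting cuts manifest and inspects one feature matrix; the dataset and subset in the path are placeholders:

# Sketch: inspect the fbank cuts produced by compute_fbank.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/libritts_cuts_dev-other.jsonl.gz")  # placeholder path
cut = next(iter(cuts))
feats = cut.load_features()      # numpy array of shape (num_frames, num_mel_bins)
print(cut.id, cut.duration, feats.shape)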
zipvoice/bin/generate_averaged_model.py
ADDED
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads checkpoints and averages them.

python3 -m zipvoice.bin.generate_averaged_model \
    --epoch 11 \
    --avg 4 \
    --model_name zipvoice \
    --model-config conf/zipvoice_base.json \
    --token-file data/tokens_emilia.txt \
    --exp-dir exp/zipvoice

It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-11-avg-4.pt")`.
"""

import argparse
import json
from pathlib import Path

import torch

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import (
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)
from zipvoice.utils.common import AttributeDict


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=11,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=4,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' or --iter",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipvoice/exp_zipvoice",
        help="The experiment dir",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        default="zipvoice",
        choices=[
            "zipvoice",
            "zipvoice_distill",
            "zipvoice_dialog",
            "zipvoice_dialog_stereo",
        ],
        help="The model type to be averaged.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default="conf/zipvoice_base.json",
        help="The model configuration file.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default="data/tokens_emilia.txt",
        help="The file that contains information that maps tokens to ids,"
        "which is a text file with '{token}\t{token_id}' per line if type is"
        "char or phone, otherwise it is a bpe_model file.",
    )

    return parser


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    params = AttributeDict()
    params.update(vars(args))

    with open(params.model_config, "r") as f:
        model_config = json.load(f)

    tokenizer = SimpleTokenizer(token_file=params.token_file)
    if params.model_name in ["zipvoice", "zipvoice_distill"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
        }
    elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
        tokenizer_config = {
            "vocab_size": tokenizer.vocab_size,
            "pad_id": tokenizer.pad_id,
            "spk_a_id": tokenizer.spk_a_id,
            "spk_b_id": tokenizer.spk_a_id,
        }

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    print("Script started")

    params.device = torch.device("cpu")
    print(f"Device: {params.device}")

    print("About to create model")
    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_distill":
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_dialog":
        model = ZipVoiceDialog(
            **model_config["model"],
            **tokenizer_config,
        )
    elif params.model_name == "zipvoice_dialog_stereo":
        model = ZipVoiceDialogStereo(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        raise ValueError(f"Unknown model name: {params.model_name}")

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        print(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        print(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    if params.iter > 0:
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
    else:
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
    torch.save({"model": model.state_dict()}, filename)

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    print("Done!")


if __name__ == "__main__":
    main()
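As the docstring notes, the averaged file can be reloaded with torch.load. A short sketch of restoring it into a ZipVoice model using the script's default config and token paths, assuming those files exist:

# Sketch: reload the averaged checkpoint written by this script.
import json
import torch
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.tokenizer.tokenizer import SimpleTokenizer

with open("conf/zipvoice_base.json") as f:                        # default --model-config
    model_config = json.load(f)
tokenizer = SimpleTokenizer(token_file="data/tokens_emilia.txt")  # default --token-file

model = ZipVoice(
    **model_config["model"],
    vocab_size=tokenizer.vocab_size,
    pad_id=tokenizer.pad_id,
)
state = torch.load("exp/zipvoice/epoch-11-avg-4.pt", map_location="cpu")
model.load_state_dict(state["model"], strict=True)  # the script saves {"model": state_dict}
model.eval()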
zipvoice/bin/infer_zipvoice.py
ADDED
@@ -0,0 +1,617 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. If no local model is specified,
required files will be automatically downloaded from HuggingFace.

Usage:

Note: If you are having trouble connecting to HuggingFace,
try switching the endpoint to the mirror site:
export HF_ENDPOINT=https://hf-mirror.com

(1) Inference of a single sentence:

python3 -m zipvoice.bin.infer_zipvoice \
    --model-name "zipvoice" \
    --prompt-wav prompt.wav \
    --prompt-text "I am a prompt." \
    --text "I am a sentence." \
    --res-wav-path result.wav

(2) Inference of a list of sentences:

python3 -m zipvoice.bin.infer_zipvoice \
    --model-name "zipvoice" \
    --test-list test.tsv \
    --res-dir results

`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.

Each line of `test.tsv` is in the format of
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
"""

import argparse
import datetime as dt
import json
import os
from typing import Optional

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank

HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
PRETRAINED_MODEL = {
    "zipvoice": "zipvoice/model.pt",
    "zipvoice_distill": "zipvoice_distill/model.pt",
}
TOKEN_FILE = {
    "zipvoice": "zipvoice/tokens.txt",
    "zipvoice_distill": "zipvoice_distill/tokens.txt",
}
MODEL_CONFIG = {
    "zipvoice": "zipvoice/zipvoice_base.json",
    "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
}


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="The model checkpoint. "
        "Will download pre-trained checkpoint from huggingface if not specified.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default=None,
        help="The model configuration file. "
        "Will download zipvoice_base.json from huggingface if not specified.",
    )

    parser.add_argument(
        "--vocoder-path",
        type=str,
        default=None,
        help="The vocoder checkpoint. "
        "Will download pre-trained vocoder from huggingface if not specified.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default=None,
        help="The file that contains information that maps tokens to ids,"
        "which is a text file with '{token}\t{token_id}' per line. "
        "Will download tokens_emilia.txt from huggingface if not specified.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default="emilia",
        choices=["emilia", "libritts", "espeak", "simple"],
        help="Tokenizer type.",
    )

    parser.add_argument(
        "--lang",
        type=str,
        default="en-us",
        help="Language identifier, used when tokenizer type is espeak. See "
        "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
    )

    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt_transcription, "
        "and text to synthesize in the format of "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
    )

    parser.add_argument(
        "--prompt-wav",
        type=str,
        default=None,
        help="The prompt wav to mimic",
    )

    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The transcription of the prompt wav",
    )

    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="The text to synthesize",
    )

    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="""
        Path name of the generated wavs dir,
        used when test-list is not None
        """,
    )

    parser.add_argument(
        "--res-wav-path",
        type=str,
        default="result.wav",
        help="""
        Path name of the generated wav path,
        used when test-list is None
        """,
    )

    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=None,
| 209 |
+
help="The scale of classifier-free guidance during inference.",
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--num-step",
|
| 214 |
+
type=int,
|
| 215 |
+
default=None,
|
| 216 |
+
help="The number of sampling steps.",
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--feat-scale",
|
| 221 |
+
type=float,
|
| 222 |
+
default=0.1,
|
| 223 |
+
help="The scale factor of fbank feature",
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--speed",
|
| 228 |
+
type=float,
|
| 229 |
+
default=1.0,
|
| 230 |
+
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--t-shift",
|
| 235 |
+
type=float,
|
| 236 |
+
default=0.5,
|
| 237 |
+
help="Shift t to smaller ones if t_shift < 1.0",
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--target-rms",
|
| 242 |
+
type=float,
|
| 243 |
+
default=0.1,
|
| 244 |
+
help="Target speech normalization rms value, set to 0 to disable normalization",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
parser.add_argument(
|
| 248 |
+
"--seed",
|
| 249 |
+
type=int,
|
| 250 |
+
default=666,
|
| 251 |
+
help="Random seed",
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
return parser
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_vocoder(vocos_local_path: Optional[str] = None):
|
| 258 |
+
if vocos_local_path:
|
| 259 |
+
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
| 260 |
+
state_dict = torch.load(
|
| 261 |
+
f"{vocos_local_path}/pytorch_model.bin",
|
| 262 |
+
weights_only=True,
|
| 263 |
+
map_location="cpu",
|
| 264 |
+
)
|
| 265 |
+
vocoder.load_state_dict(state_dict)
|
| 266 |
+
else:
|
| 267 |
+
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 268 |
+
return vocoder
|
| 269 |
+
|
| 270 |
+
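# Illustrative sketch: decoding a mel-spectrogram with the vocoder returned by
# get_vocoder(). The random mel and the 100-bin layout are assumptions based
# on the charactr/vocos-mel-24khz checkpoint, not values produced by ZipVoice.
import torch

_vocoder = get_vocoder()                  # downloads charactr/vocos-mel-24khz
_vocoder.eval()
with torch.inference_mode():
    _dummy_mel = torch.randn(1, 100, 256)  # (batch, n_mels, frames)
    _wav = _vocoder.decode(_dummy_mel)     # (batch, samples) at 24 kHz
print(_wav.shape)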
|
| 271 |
+
def generate_sentence(
|
| 272 |
+
save_path: str,
|
| 273 |
+
prompt_text: str,
|
| 274 |
+
prompt_wav: str,
|
| 275 |
+
text: str,
|
| 276 |
+
model: torch.nn.Module,
|
| 277 |
+
vocoder: torch.nn.Module,
|
| 278 |
+
tokenizer: EmiliaTokenizer,
|
| 279 |
+
feature_extractor: VocosFbank,
|
| 280 |
+
device: torch.device,
|
| 281 |
+
num_step: int = 16,
|
| 282 |
+
guidance_scale: float = 1.0,
|
| 283 |
+
speed: float = 1.0,
|
| 284 |
+
t_shift: float = 0.5,
|
| 285 |
+
target_rms: float = 0.1,
|
| 286 |
+
feat_scale: float = 0.1,
|
| 287 |
+
sampling_rate: int = 24000,
|
| 288 |
+
):
|
| 289 |
+
"""
|
| 290 |
+
Generate waveform of a text based on a given prompt
|
| 291 |
+
waveform and its transcription.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
save_path (str): Path to save the generated wav.
|
| 295 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 296 |
+
prompt_wav (str): Path to the prompt wav file.
|
| 297 |
+
text (str): Text to be synthesized into a waveform.
|
| 298 |
+
model (torch.nn.Module): The model used for generation.
|
| 299 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 300 |
+
tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
|
| 301 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 302 |
+
extract acoustic features.
|
| 303 |
+
device (torch.device): The device on which computations are performed.
|
| 304 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 305 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 306 |
+
Defaults to 1.0.
|
| 307 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 308 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 309 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 310 |
+
Defaults to 0.1.
|
| 311 |
+
feat_scale (float, optional): Scale for features.
|
| 312 |
+
Defaults to 0.1.
|
| 313 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 314 |
+
Defaults to 24000.
|
| 315 |
+
Returns:
|
| 316 |
+
metrics (dict): Dictionary containing time and real-time
|
| 317 |
+
factor metrics for processing.
|
| 318 |
+
"""
|
| 319 |
+
# Convert text to tokens
|
| 320 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 321 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 322 |
+
|
| 323 |
+
# Load and preprocess prompt wav
|
| 324 |
+
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
|
| 325 |
+
|
| 326 |
+
if prompt_sampling_rate != sampling_rate:
|
| 327 |
+
resampler = torchaudio.transforms.Resample(
|
| 328 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 329 |
+
)
|
| 330 |
+
prompt_wav = resampler(prompt_wav)
|
| 331 |
+
|
| 332 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 333 |
+
if prompt_rms < target_rms:
|
| 334 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 335 |
+
|
| 336 |
+
# Extract features from prompt wav
|
| 337 |
+
prompt_features = feature_extractor.extract(
|
| 338 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 339 |
+
).to(device)
|
| 340 |
+
|
| 341 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 342 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 343 |
+
|
| 344 |
+
# Start timing
|
| 345 |
+
start_t = dt.datetime.now()
|
| 346 |
+
|
| 347 |
+
# Generate features
|
| 348 |
+
(
|
| 349 |
+
pred_features,
|
| 350 |
+
pred_features_lens,
|
| 351 |
+
pred_prompt_features,
|
| 352 |
+
pred_prompt_features_lens,
|
| 353 |
+
) = model.sample(
|
| 354 |
+
tokens=tokens,
|
| 355 |
+
prompt_tokens=prompt_tokens,
|
| 356 |
+
prompt_features=prompt_features,
|
| 357 |
+
prompt_features_lens=prompt_features_lens,
|
| 358 |
+
speed=speed,
|
| 359 |
+
t_shift=t_shift,
|
| 360 |
+
duration="predict",
|
| 361 |
+
num_step=num_step,
|
| 362 |
+
guidance_scale=guidance_scale,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Postprocess predicted features
|
| 366 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 367 |
+
|
| 368 |
+
# Start vocoder processing
|
| 369 |
+
start_vocoder_t = dt.datetime.now()
|
| 370 |
+
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
| 371 |
+
|
| 372 |
+
# Calculate processing times and real-time factors
|
| 373 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 374 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 375 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 376 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 377 |
+
rtf = t / wav_seconds
|
| 378 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 379 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 380 |
+
metrics = {
|
| 381 |
+
"t": t,
|
| 382 |
+
"t_no_vocoder": t_no_vocoder,
|
| 383 |
+
"t_vocoder": t_vocoder,
|
| 384 |
+
"wav_seconds": wav_seconds,
|
| 385 |
+
"rtf": rtf,
|
| 386 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 387 |
+
"rtf_vocoder": rtf_vocoder,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
# Adjust wav volume if necessary
|
| 391 |
+
if prompt_rms < target_rms:
|
| 392 |
+
wav = wav * prompt_rms / target_rms
|
| 393 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 394 |
+
|
| 395 |
+
return metrics
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def generate_list(
|
| 399 |
+
res_dir: str,
|
| 400 |
+
test_list: str,
|
| 401 |
+
model: torch.nn.Module,
|
| 402 |
+
vocoder: torch.nn.Module,
|
| 403 |
+
tokenizer: EmiliaTokenizer,
|
| 404 |
+
feature_extractor: VocosFbank,
|
| 405 |
+
device: torch.device,
|
| 406 |
+
num_step: int = 16,
|
| 407 |
+
guidance_scale: float = 1.0,
|
| 408 |
+
speed: float = 1.0,
|
| 409 |
+
t_shift: float = 0.5,
|
| 410 |
+
target_rms: float = 0.1,
|
| 411 |
+
feat_scale: float = 0.1,
|
| 412 |
+
sampling_rate: int = 24000,
|
| 413 |
+
):
|
| 414 |
+
total_t = []
|
| 415 |
+
total_t_no_vocoder = []
|
| 416 |
+
total_t_vocoder = []
|
| 417 |
+
total_wav_seconds = []
|
| 418 |
+
|
| 419 |
+
with open(test_list, "r") as fr:
|
| 420 |
+
lines = fr.readlines()
|
| 421 |
+
|
| 422 |
+
for i, line in enumerate(lines):
|
| 423 |
+
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
| 424 |
+
save_path = f"{res_dir}/{wav_name}.wav"
|
| 425 |
+
metrics = generate_sentence(
|
| 426 |
+
save_path=save_path,
|
| 427 |
+
prompt_text=prompt_text,
|
| 428 |
+
prompt_wav=prompt_wav,
|
| 429 |
+
text=text,
|
| 430 |
+
model=model,
|
| 431 |
+
vocoder=vocoder,
|
| 432 |
+
tokenizer=tokenizer,
|
| 433 |
+
feature_extractor=feature_extractor,
|
| 434 |
+
device=device,
|
| 435 |
+
num_step=num_step,
|
| 436 |
+
guidance_scale=guidance_scale,
|
| 437 |
+
speed=speed,
|
| 438 |
+
t_shift=t_shift,
|
| 439 |
+
target_rms=target_rms,
|
| 440 |
+
feat_scale=feat_scale,
|
| 441 |
+
sampling_rate=sampling_rate,
|
| 442 |
+
)
|
| 443 |
+
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
| 444 |
+
total_t.append(metrics["t"])
|
| 445 |
+
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
| 446 |
+
total_t_vocoder.append(metrics["t_vocoder"])
|
| 447 |
+
total_wav_seconds.append(metrics["wav_seconds"])
|
| 448 |
+
|
| 449 |
+
print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
|
| 450 |
+
print(
|
| 451 |
+
f"Average RTF w/o vocoder: "
|
| 452 |
+
f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 453 |
+
)
|
| 454 |
+
print(
|
| 455 |
+
f"Average RTF vocoder: "
|
| 456 |
+
f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
@torch.inference_mode()
|
| 461 |
+
def main():
|
| 462 |
+
parser = get_parser()
|
| 463 |
+
args = parser.parse_args()
|
| 464 |
+
|
| 465 |
+
params = AttributeDict()
|
| 466 |
+
params.update(vars(args))
|
| 467 |
+
fix_random_seed(params.seed)
|
| 468 |
+
|
| 469 |
+
model_defaults = {
|
| 470 |
+
"zipvoice": {
|
| 471 |
+
"num_step": 16,
|
| 472 |
+
"guidance_scale": 1.0,
|
| 473 |
+
},
|
| 474 |
+
"zipvoice_distill": {
|
| 475 |
+
"num_step": 8,
|
| 476 |
+
"guidance_scale": 3.0,
|
| 477 |
+
},
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
model_specific_defaults = model_defaults.get(params.model_name, {})
|
| 481 |
+
|
| 482 |
+
for param, value in model_specific_defaults.items():
|
| 483 |
+
if getattr(params, param) is None:
|
| 484 |
+
setattr(params, param, value)
|
| 485 |
+
print(f"Setting {param} to default value: {value}")
|
| 486 |
+
|
| 487 |
+
assert (params.test_list is not None) ^ (
|
| 488 |
+
(params.prompt_wav and params.prompt_text and params.text) is not None
|
| 489 |
+
), (
|
| 490 |
+
"For inference, please provide prompts and text with either '--test-list'"
|
| 491 |
+
" or '--prompt-wav, --prompt-text and --text'."
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
if torch.cuda.is_available():
|
| 495 |
+
params.device = torch.device("cuda", 0)
|
| 496 |
+
elif torch.backends.mps.is_available():
|
| 497 |
+
params.device = torch.device("mps")
|
| 498 |
+
else:
|
| 499 |
+
params.device = torch.device("cpu")
|
| 500 |
+
|
| 501 |
+
print("Loading model...")
|
| 502 |
+
if params.model_config is None:
|
| 503 |
+
model_config = hf_hub_download(
|
| 504 |
+
HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
|
| 505 |
+
)
|
| 506 |
+
else:
|
| 507 |
+
model_config = params.model_config
|
| 508 |
+
|
| 509 |
+
with open(model_config, "r") as f:
|
| 510 |
+
model_config = json.load(f)
|
| 511 |
+
|
| 512 |
+
if params.token_file is None:
|
| 513 |
+
token_file = hf_hub_download(
|
| 514 |
+
HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
|
| 515 |
+
)
|
| 516 |
+
else:
|
| 517 |
+
token_file = params.token_file
|
| 518 |
+
|
| 519 |
+
if params.tokenizer == "emilia":
|
| 520 |
+
tokenizer = EmiliaTokenizer(token_file=token_file)
|
| 521 |
+
elif params.tokenizer == "libritts":
|
| 522 |
+
tokenizer = LibriTTSTokenizer(token_file=token_file)
|
| 523 |
+
elif params.tokenizer == "espeak":
|
| 524 |
+
tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
|
| 525 |
+
else:
|
| 526 |
+
assert params.tokenizer == "simple"
|
| 527 |
+
tokenizer = SimpleTokenizer(token_file=token_file)
|
| 528 |
+
|
| 529 |
+
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
|
| 530 |
+
|
| 531 |
+
if params.checkpoint is None:
|
| 532 |
+
model_ckpt = hf_hub_download(
|
| 533 |
+
HUGGINGFACE_REPO,
|
| 534 |
+
filename=PRETRAINED_MODEL[params.model_name],
|
| 535 |
+
)
|
| 536 |
+
else:
|
| 537 |
+
model_ckpt = params.checkpoint
|
| 538 |
+
|
| 539 |
+
if params.model_name == "zipvoice":
|
| 540 |
+
model = ZipVoice(
|
| 541 |
+
**model_config["model"],
|
| 542 |
+
**tokenizer_config,
|
| 543 |
+
)
|
| 544 |
+
else:
|
| 545 |
+
assert params.model_name == "zipvoice_distill"
|
| 546 |
+
model = ZipVoiceDistill(
|
| 547 |
+
**model_config["model"],
|
| 548 |
+
**tokenizer_config,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if model_ckpt.endswith(".safetensors"):
|
| 552 |
+
safetensors.torch.load_model(model, model_ckpt)
|
| 553 |
+
elif model_ckpt.endswith(".pt"):
|
| 554 |
+
load_checkpoint(filename=model_ckpt, model=model, strict=True)
|
| 555 |
+
else:
|
| 556 |
+
raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
|
| 557 |
+
|
| 558 |
+
model = model.to(params.device)
|
| 559 |
+
model.eval()
|
| 560 |
+
|
| 561 |
+
vocoder = get_vocoder(params.vocoder_path)
|
| 562 |
+
vocoder = vocoder.to(params.device)
|
| 563 |
+
vocoder.eval()
|
| 564 |
+
|
| 565 |
+
if model_config["feature"]["type"] == "vocos":
|
| 566 |
+
feature_extractor = VocosFbank()
|
| 567 |
+
else:
|
| 568 |
+
raise NotImplementedError(
|
| 569 |
+
f"Unsupported feature type: {model_config['feature']['type']}"
|
| 570 |
+
)
|
| 571 |
+
params.sampling_rate = model_config["feature"]["sampling_rate"]
|
| 572 |
+
|
| 573 |
+
print("Start generating...")
|
| 574 |
+
if params.test_list:
|
| 575 |
+
os.makedirs(params.res_dir, exist_ok=True)
|
| 576 |
+
generate_list(
|
| 577 |
+
res_dir=params.res_dir,
|
| 578 |
+
test_list=params.test_list,
|
| 579 |
+
model=model,
|
| 580 |
+
vocoder=vocoder,
|
| 581 |
+
tokenizer=tokenizer,
|
| 582 |
+
feature_extractor=feature_extractor,
|
| 583 |
+
device=params.device,
|
| 584 |
+
num_step=params.num_step,
|
| 585 |
+
guidance_scale=params.guidance_scale,
|
| 586 |
+
speed=params.speed,
|
| 587 |
+
t_shift=params.t_shift,
|
| 588 |
+
target_rms=params.target_rms,
|
| 589 |
+
feat_scale=params.feat_scale,
|
| 590 |
+
sampling_rate=params.sampling_rate,
|
| 591 |
+
)
|
| 592 |
+
else:
|
| 593 |
+
generate_sentence(
|
| 594 |
+
save_path=params.res_wav_path,
|
| 595 |
+
prompt_text=params.prompt_text,
|
| 596 |
+
prompt_wav=params.prompt_wav,
|
| 597 |
+
text=params.text,
|
| 598 |
+
model=model,
|
| 599 |
+
vocoder=vocoder,
|
| 600 |
+
tokenizer=tokenizer,
|
| 601 |
+
feature_extractor=feature_extractor,
|
| 602 |
+
device=params.device,
|
| 603 |
+
num_step=params.num_step,
|
| 604 |
+
guidance_scale=params.guidance_scale,
|
| 605 |
+
speed=params.speed,
|
| 606 |
+
t_shift=params.t_shift,
|
| 607 |
+
target_rms=params.target_rms,
|
| 608 |
+
feat_scale=params.feat_scale,
|
| 609 |
+
sampling_rate=params.sampling_rate,
|
| 610 |
+
)
|
| 611 |
+
print("Done")
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
if __name__ == "__main__":
|
| 615 |
+
torch.set_num_threads(1)
|
| 616 |
+
torch.set_num_interop_threads(1)
|
| 617 |
+
main()
|
zipvoice/bin/infer_zipvoice_dialog.py
ADDED
|
@@ -0,0 +1,756 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
This script generates speech with our pre-trained ZipVoice-Dialog or
|
| 20 |
+
ZipVoice-Dialog-Stereo models. If no local model is specified,
|
| 21 |
+
Required files will be automatically downloaded from HuggingFace.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
|
| 25 |
+
Note: If you having trouble connecting to HuggingFace,
|
| 26 |
+
try switching endpoint to mirror site:
|
| 27 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 28 |
+
|
| 29 |
+
python3 -m zipvoice.bin.infer_zipvoice_dialog \
|
| 30 |
+
--model-name "zipvoice_dialog" \
|
| 31 |
+
--test-list test.tsv \
|
| 32 |
+
--res-dir results
|
| 33 |
+
|
| 34 |
+
`--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
|
| 35 |
+
which generate mono and stereo dialogues, respectively.
|
| 36 |
+
|
| 37 |
+
Each line of `test.tsv` is in the format of merged conversation:
|
| 38 |
+
'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
|
| 39 |
+
or splited conversation:
|
| 40 |
+
'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
|
| 41 |
+
\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
|
| 42 |
+
"""
|
| 43 |
+
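# Illustrative sketch (not part of the original script) of the two accepted
# `test.tsv` layouts described above. Every path and transcription is a
# placeholder, and tagging the merged prompt transcription with [S1]/[S2] is
# an assumption based on how the split prompts are merged in generate_list().
merged_line = "\t".join([
    "dialog_0001",
    "[S1]Hi there.[S2]Hello!",          # transcription of the merged prompt wav
    "prompts/merged_dialog.wav",        # one prompt wav containing both speakers
    "[S1]How are you?[S2]Great, thanks.",
])
split_line = "\t".join([
    "dialog_0002",
    "Hi there.",                        # speaker 1 prompt transcription
    "Hello!",                           # speaker 2 prompt transcription
    "prompts/spk1.wav",
    "prompts/spk2.wav",
    "[S1]How are you?[S2]Great, thanks.",
])
with open("test.tsv", "w", encoding="utf-8") as f:
    f.write(merged_line + "\n")
    f.write(split_line + "\n")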
|
| 44 |
+
import argparse
|
| 45 |
+
import datetime as dt
|
| 46 |
+
import json
|
| 47 |
+
import os
|
| 48 |
+
from typing import List, Optional, Union
|
| 49 |
+
|
| 50 |
+
import numpy as np
|
| 51 |
+
import safetensors.torch
|
| 52 |
+
import torch
|
| 53 |
+
import torchaudio
|
| 54 |
+
from huggingface_hub import hf_hub_download
|
| 55 |
+
from lhotse.utils import fix_random_seed
|
| 56 |
+
from vocos import Vocos
|
| 57 |
+
|
| 58 |
+
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
|
| 59 |
+
from zipvoice.tokenizer.tokenizer import DialogTokenizer
|
| 60 |
+
from zipvoice.utils.checkpoint import load_checkpoint
|
| 61 |
+
from zipvoice.utils.common import AttributeDict
|
| 62 |
+
from zipvoice.utils.feature import VocosFbank
|
| 63 |
+
|
| 64 |
+
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
|
| 65 |
+
PRETRAINED_MODEL = {
|
| 66 |
+
"zipvoice_dialog": "zipvoice_dialog/model.pt",
|
| 67 |
+
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.pt",
|
| 68 |
+
}
|
| 69 |
+
TOKEN_FILE = {
|
| 70 |
+
"zipvoice_dialog": "zipvoice_dialog/tokens.txt",
|
| 71 |
+
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/tokens.txt",
|
| 72 |
+
}
|
| 73 |
+
MODEL_CONFIG = {
|
| 74 |
+
"zipvoice_dialog": "zipvoice_dialog/model.json",
|
| 75 |
+
"zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.json",
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_parser():
|
| 80 |
+
parser = argparse.ArgumentParser(
|
| 81 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--model-name",
|
| 86 |
+
type=str,
|
| 87 |
+
default="zipvoice_dialog",
|
| 88 |
+
choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
|
| 89 |
+
help="The model used for inference",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--checkpoint",
|
| 94 |
+
type=str,
|
| 95 |
+
default=None,
|
| 96 |
+
help="The model checkpoint. "
|
| 97 |
+
"Will download pre-trained checkpoint from huggingface if not specified.",
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--model-config",
|
| 102 |
+
type=str,
|
| 103 |
+
default=None,
|
| 104 |
+
help="The model configuration file. "
|
| 105 |
+
"Will download model.json from huggingface if not specified.",
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--vocoder-path",
|
| 110 |
+
type=str,
|
| 111 |
+
default=None,
|
| 112 |
+
help="The vocoder checkpoint. "
|
| 113 |
+
"Will download pre-trained vocoder from huggingface if not specified.",
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--token-file",
|
| 118 |
+
type=str,
|
| 119 |
+
default=None,
|
| 120 |
+
help="The file that contains information that maps tokens to ids,"
|
| 121 |
+
"which is a text file with '{token}\t{token_id}' per line. "
|
| 122 |
+
"Will download tokens_emilia.txt from huggingface if not specified.",
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--test-list",
|
| 127 |
+
type=str,
|
| 128 |
+
default=None,
|
| 129 |
+
help="The list of prompt speech, prompt_transcription, "
|
| 130 |
+
"and text to synthesizein the format of merged conversation: "
|
| 131 |
+
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
|
| 132 |
+
"or splited conversation: "
|
| 133 |
+
"'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
|
| 134 |
+
"\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--res-dir",
|
| 139 |
+
type=str,
|
| 140 |
+
default="results",
|
| 141 |
+
help="""
|
| 142 |
+
Path name of the generated wavs dir,
|
| 143 |
+
used when test-list is not None
|
| 144 |
+
""",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
"--guidance-scale",
|
| 149 |
+
type=float,
|
| 150 |
+
default=1.5,
|
| 151 |
+
help="The scale of classifier-free guidance during inference.",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--num-step",
|
| 156 |
+
type=int,
|
| 157 |
+
default=16,
|
| 158 |
+
help="The number of sampling steps.",
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--feat-scale",
|
| 163 |
+
type=float,
|
| 164 |
+
default=0.1,
|
| 165 |
+
help="The scale factor of fbank feature",
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--speed",
|
| 170 |
+
type=float,
|
| 171 |
+
default=1.0,
|
| 172 |
+
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--t-shift",
|
| 177 |
+
type=float,
|
| 178 |
+
default=0.5,
|
| 179 |
+
help="Shift t to smaller ones if t_shift < 1.0",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--target-rms",
|
| 184 |
+
type=float,
|
| 185 |
+
default=0.1,
|
| 186 |
+
help="Target speech normalization rms value, set to 0 to disable normalization",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--seed",
|
| 191 |
+
type=int,
|
| 192 |
+
default=666,
|
| 193 |
+
help="Random seed",
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"--silence-wav",
|
| 198 |
+
type=str,
|
| 199 |
+
default="assets/silence.wav",
|
| 200 |
+
help="Path of the silence wav file, used in two-channel generation "
|
| 201 |
+
"with single-channel prompts",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return parser
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_vocoder(vocos_local_path: Optional[str] = None):
|
| 208 |
+
if vocos_local_path:
|
| 209 |
+
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
| 210 |
+
state_dict = torch.load(
|
| 211 |
+
f"{vocos_local_path}/pytorch_model.bin",
|
| 212 |
+
weights_only=True,
|
| 213 |
+
map_location="cpu",
|
| 214 |
+
)
|
| 215 |
+
vocoder.load_state_dict(state_dict)
|
| 216 |
+
else:
|
| 217 |
+
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
| 218 |
+
return vocoder
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def generate_sentence(
|
| 222 |
+
save_path: str,
|
| 223 |
+
prompt_text: str,
|
| 224 |
+
prompt_wav: Union[str, List[str]],
|
| 225 |
+
text: str,
|
| 226 |
+
model: torch.nn.Module,
|
| 227 |
+
vocoder: torch.nn.Module,
|
| 228 |
+
tokenizer: DialogTokenizer,
|
| 229 |
+
feature_extractor: VocosFbank,
|
| 230 |
+
device: torch.device,
|
| 231 |
+
num_step: int = 16,
|
| 232 |
+
guidance_scale: float = 1.0,
|
| 233 |
+
speed: float = 1.0,
|
| 234 |
+
t_shift: float = 0.5,
|
| 235 |
+
target_rms: float = 0.1,
|
| 236 |
+
feat_scale: float = 0.1,
|
| 237 |
+
sampling_rate: int = 24000,
|
| 238 |
+
):
|
| 239 |
+
"""
|
| 240 |
+
Generate waveform of a text based on a given prompt
|
| 241 |
+
waveform and its transcription.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
save_path (str): Path to save the generated wav.
|
| 245 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 246 |
+
prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
|
| 247 |
+
one or two wav files, corresponding to a merged conversational
|
| 248 |
+
speech or two separate speakers' speech.
|
| 249 |
+
text (str): Text to be synthesized into a waveform.
|
| 250 |
+
model (torch.nn.Module): The model used for generation.
|
| 251 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 252 |
+
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
|
| 253 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 254 |
+
extract acoustic features.
|
| 255 |
+
device (torch.device): The device on which computations are performed.
|
| 256 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 257 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 258 |
+
Defaults to 1.0.
|
| 259 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 260 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 261 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 262 |
+
Defaults to 0.1.
|
| 263 |
+
feat_scale (float, optional): Scale for features.
|
| 264 |
+
Defaults to 0.1.
|
| 265 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 266 |
+
Defaults to 24000.
|
| 267 |
+
Returns:
|
| 268 |
+
metrics (dict): Dictionary containing time and real-time
|
| 269 |
+
factor metrics for processing.
|
| 270 |
+
"""
|
| 271 |
+
# Convert text to tokens
|
| 272 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 273 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 274 |
+
|
| 275 |
+
# Load and preprocess prompt wav
|
| 276 |
+
if isinstance(prompt_wav, str):
|
| 277 |
+
prompt_wav = [
|
| 278 |
+
prompt_wav,
|
| 279 |
+
]
|
| 280 |
+
else:
|
| 281 |
+
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
|
| 282 |
+
|
| 283 |
+
loaded_prompt_wavs = prompt_wav
|
| 284 |
+
for i in range(len(prompt_wav)):
|
| 285 |
+
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
|
| 286 |
+
if prompt_sampling_rate != sampling_rate:
|
| 287 |
+
resampler = torchaudio.transforms.Resample(
|
| 288 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 289 |
+
)
|
| 290 |
+
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
|
| 291 |
+
|
| 292 |
+
if len(loaded_prompt_wavs) == 1:
|
| 293 |
+
prompt_wav = loaded_prompt_wavs[0]
|
| 294 |
+
else:
|
| 295 |
+
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
|
| 296 |
+
|
| 297 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 298 |
+
if prompt_rms < target_rms:
|
| 299 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 300 |
+
|
| 301 |
+
# Extract features from prompt wav
|
| 302 |
+
prompt_features = feature_extractor.extract(
|
| 303 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 304 |
+
).to(device)
|
| 305 |
+
|
| 306 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 307 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 308 |
+
|
| 309 |
+
# Start timing
|
| 310 |
+
start_t = dt.datetime.now()
|
| 311 |
+
|
| 312 |
+
# Generate features
|
| 313 |
+
(
|
| 314 |
+
pred_features,
|
| 315 |
+
pred_features_lens,
|
| 316 |
+
pred_prompt_features,
|
| 317 |
+
pred_prompt_features_lens,
|
| 318 |
+
) = model.sample(
|
| 319 |
+
tokens=tokens,
|
| 320 |
+
prompt_tokens=prompt_tokens,
|
| 321 |
+
prompt_features=prompt_features,
|
| 322 |
+
prompt_features_lens=prompt_features_lens,
|
| 323 |
+
speed=speed,
|
| 324 |
+
t_shift=t_shift,
|
| 325 |
+
duration="predict",
|
| 326 |
+
num_step=num_step,
|
| 327 |
+
guidance_scale=guidance_scale,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Postprocess predicted features
|
| 331 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 332 |
+
|
| 333 |
+
# Start vocoder processing
|
| 334 |
+
start_vocoder_t = dt.datetime.now()
|
| 335 |
+
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
| 336 |
+
|
| 337 |
+
# Calculate processing times and real-time factors
|
| 338 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 339 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 340 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 341 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 342 |
+
rtf = t / wav_seconds
|
| 343 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 344 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 345 |
+
metrics = {
|
| 346 |
+
"t": t,
|
| 347 |
+
"t_no_vocoder": t_no_vocoder,
|
| 348 |
+
"t_vocoder": t_vocoder,
|
| 349 |
+
"wav_seconds": wav_seconds,
|
| 350 |
+
"rtf": rtf,
|
| 351 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 352 |
+
"rtf_vocoder": rtf_vocoder,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
# Adjust wav volume if necessary
|
| 356 |
+
if prompt_rms < target_rms:
|
| 357 |
+
wav = wav * prompt_rms / target_rms
|
| 358 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 359 |
+
|
| 360 |
+
return metrics
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def generate_sentence_stereo(
|
| 364 |
+
save_path: str,
|
| 365 |
+
prompt_text: str,
|
| 366 |
+
prompt_wav: Union[str, List[str]],
|
| 367 |
+
text: str,
|
| 368 |
+
model: torch.nn.Module,
|
| 369 |
+
vocoder: torch.nn.Module,
|
| 370 |
+
tokenizer: DialogTokenizer,
|
| 371 |
+
feature_extractor: VocosFbank,
|
| 372 |
+
device: torch.device,
|
| 373 |
+
num_step: int = 16,
|
| 374 |
+
guidance_scale: float = 1.0,
|
| 375 |
+
speed: float = 1.0,
|
| 376 |
+
t_shift: float = 0.5,
|
| 377 |
+
target_rms: float = 0.1,
|
| 378 |
+
feat_scale: float = 0.1,
|
| 379 |
+
sampling_rate: int = 24000,
|
| 380 |
+
silence_wav: Optional[str] = None,
|
| 381 |
+
):
|
| 382 |
+
"""
|
| 383 |
+
Generate waveform of a text based on a given prompt
|
| 384 |
+
waveform and its transcription.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
save_path (str): Path to save the generated wav.
|
| 388 |
+
prompt_text (str): Transcription of the prompt wav.
|
| 389 |
+
prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
|
| 390 |
+
one or two wav files, corresponding to a merged conversational
|
| 391 |
+
speech or two separate speakers' speech.
|
| 392 |
+
text (str): Text to be synthesized into a waveform.
|
| 393 |
+
model (torch.nn.Module): The model used for generation.
|
| 394 |
+
vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
|
| 395 |
+
tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
|
| 396 |
+
feature_extractor (VocosFbank): The feature extractor used to
|
| 397 |
+
extract acoustic features.
|
| 398 |
+
device (torch.device): The device on which computations are performed.
|
| 399 |
+
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
| 400 |
+
guidance_scale (float, optional): Scale for classifier-free guidance.
|
| 401 |
+
Defaults to 1.0.
|
| 402 |
+
speed (float, optional): Speed control. Defaults to 1.0.
|
| 403 |
+
t_shift (float, optional): Time shift. Defaults to 0.5.
|
| 404 |
+
target_rms (float, optional): Target RMS for waveform normalization.
|
| 405 |
+
Defaults to 0.1.
|
| 406 |
+
feat_scale (float, optional): Scale for features.
|
| 407 |
+
Defaults to 0.1.
|
| 408 |
+
sampling_rate (int, optional): Sampling rate for the waveform.
|
| 409 |
+
Defaults to 24000.
|
| 410 |
+
silence_wav (str): Path of the silence wav file, used in two-channel
|
| 411 |
+
generation with single-channel prompts
|
| 412 |
+
Returns:
|
| 413 |
+
metrics (dict): Dictionary containing time and real-time
|
| 414 |
+
factor metrics for processing.
|
| 415 |
+
"""
|
| 416 |
+
# Convert text to tokens
|
| 417 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 418 |
+
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
| 419 |
+
|
| 420 |
+
# Load and preprocess prompt wav
|
| 421 |
+
if isinstance(prompt_wav, str):
|
| 422 |
+
prompt_wav = [
|
| 423 |
+
prompt_wav,
|
| 424 |
+
]
|
| 425 |
+
else:
|
| 426 |
+
assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
|
| 427 |
+
|
| 428 |
+
loaded_prompt_wavs = prompt_wav
|
| 429 |
+
for i in range(len(prompt_wav)):
|
| 430 |
+
loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
|
| 431 |
+
if prompt_sampling_rate != sampling_rate:
|
| 432 |
+
resampler = torchaudio.transforms.Resample(
|
| 433 |
+
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
| 434 |
+
)
|
| 435 |
+
loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
|
| 436 |
+
|
| 437 |
+
if len(loaded_prompt_wavs) == 1:
|
| 438 |
+
assert (
|
| 439 |
+
loaded_prompt_wavs[0].size(0) == 2
|
| 440 |
+
), "Merged prompt wav must be stereo for stereo dialogue generation"
|
| 441 |
+
prompt_wav = loaded_prompt_wavs[0]
|
| 442 |
+
|
| 443 |
+
else:
|
| 444 |
+
assert len(loaded_prompt_wavs) == 2
|
| 445 |
+
if loaded_prompt_wavs[0].size(0) == 2:
|
| 446 |
+
prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
|
| 447 |
+
else:
|
| 448 |
+
assert loaded_prompt_wavs[0].size(0) == 1
|
| 449 |
+
silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
|
| 450 |
+
assert silence_sampling_rate == sampling_rate
|
| 451 |
+
prompt_wav = silence_wav[
|
| 452 |
+
:, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
|
| 453 |
+
]
|
| 454 |
+
prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
|
| 455 |
+
prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
|
| 456 |
+
|
| 457 |
+
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
| 458 |
+
if prompt_rms < target_rms:
|
| 459 |
+
prompt_wav = prompt_wav * target_rms / prompt_rms
|
| 460 |
+
|
| 461 |
+
# Extract features from prompt wav
|
| 462 |
+
prompt_features = feature_extractor.extract(
|
| 463 |
+
prompt_wav, sampling_rate=sampling_rate
|
| 464 |
+
).to(device)
|
| 465 |
+
|
| 466 |
+
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
| 467 |
+
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
| 468 |
+
|
| 469 |
+
# Start timing
|
| 470 |
+
start_t = dt.datetime.now()
|
| 471 |
+
|
| 472 |
+
# Generate features
|
| 473 |
+
(
|
| 474 |
+
pred_features,
|
| 475 |
+
pred_features_lens,
|
| 476 |
+
pred_prompt_features,
|
| 477 |
+
pred_prompt_features_lens,
|
| 478 |
+
) = model.sample(
|
| 479 |
+
tokens=tokens,
|
| 480 |
+
prompt_tokens=prompt_tokens,
|
| 481 |
+
prompt_features=prompt_features,
|
| 482 |
+
prompt_features_lens=prompt_features_lens,
|
| 483 |
+
speed=speed,
|
| 484 |
+
t_shift=t_shift,
|
| 485 |
+
duration="predict",
|
| 486 |
+
num_step=num_step,
|
| 487 |
+
guidance_scale=guidance_scale,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Postprocess predicted features
|
| 491 |
+
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
| 492 |
+
|
| 493 |
+
# Start vocoder processing
|
| 494 |
+
start_vocoder_t = dt.datetime.now()
|
| 495 |
+
feat_dim = pred_features.size(1) // 2
|
| 496 |
+
wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
|
| 497 |
+
wav_right = (
|
| 498 |
+
vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
|
| 499 |
+
.squeeze(1)
|
| 500 |
+
.clamp(-1, 1)
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
wav = torch.cat([wav_left, wav_right], dim=0)
|
| 504 |
+
|
| 505 |
+
# Calculate processing times and real-time factors
|
| 506 |
+
t = (dt.datetime.now() - start_t).total_seconds()
|
| 507 |
+
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
| 508 |
+
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
| 509 |
+
wav_seconds = wav.shape[-1] / sampling_rate
|
| 510 |
+
rtf = t / wav_seconds
|
| 511 |
+
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
| 512 |
+
rtf_vocoder = t_vocoder / wav_seconds
|
| 513 |
+
metrics = {
|
| 514 |
+
"t": t,
|
| 515 |
+
"t_no_vocoder": t_no_vocoder,
|
| 516 |
+
"t_vocoder": t_vocoder,
|
| 517 |
+
"wav_seconds": wav_seconds,
|
| 518 |
+
"rtf": rtf,
|
| 519 |
+
"rtf_no_vocoder": rtf_no_vocoder,
|
| 520 |
+
"rtf_vocoder": rtf_vocoder,
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
# Adjust wav volume if necessary
|
| 524 |
+
if prompt_rms < target_rms:
|
| 525 |
+
wav = wav * prompt_rms / target_rms
|
| 526 |
+
torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
|
| 527 |
+
|
| 528 |
+
return metrics
|
| 529 |
+
|
| 530 |
+
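# Illustrative sketch of the stereo prompt layout built above from two mono
# prompts: speaker 1 occupies the left channel first, speaker 2 follows on the
# right channel, and the rest of each channel stays silent. Zeros stand in
# here for the assets/silence.wav buffer used by the real code.
import torch

_spk1 = torch.randn(1, 24000)                  # 1 s mono prompt, speaker 1
_spk2 = torch.randn(1, 36000)                  # 1.5 s mono prompt, speaker 2
_stereo = torch.zeros(2, _spk1.size(1) + _spk2.size(1))
_stereo[0, : _spk1.size(1)] = _spk1[0]         # left: speaker 1, then silence
_stereo[1, _spk1.size(1) :] = _spk2[0]         # right: silence, then speaker 2
print(_stereo.shape)                           # torch.Size([2, 60000])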
|
| 531 |
+
def generate_list(
|
| 532 |
+
model_name: str,
|
| 533 |
+
res_dir: str,
|
| 534 |
+
test_list: str,
|
| 535 |
+
model: torch.nn.Module,
|
| 536 |
+
vocoder: torch.nn.Module,
|
| 537 |
+
tokenizer: DialogTokenizer,
|
| 538 |
+
feature_extractor: VocosFbank,
|
| 539 |
+
device: torch.device,
|
| 540 |
+
num_step: int = 16,
|
| 541 |
+
guidance_scale: float = 1.5,
|
| 542 |
+
speed: float = 1.0,
|
| 543 |
+
t_shift: float = 0.5,
|
| 544 |
+
target_rms: float = 0.1,
|
| 545 |
+
feat_scale: float = 0.1,
|
| 546 |
+
sampling_rate: int = 24000,
|
| 547 |
+
silence_wav: Optional[str] = None,
|
| 548 |
+
):
|
| 549 |
+
total_t = []
|
| 550 |
+
total_t_no_vocoder = []
|
| 551 |
+
total_t_vocoder = []
|
| 552 |
+
total_wav_seconds = []
|
| 553 |
+
|
| 554 |
+
with open(test_list, "r") as fr:
|
| 555 |
+
lines = fr.readlines()
|
| 556 |
+
|
| 557 |
+
for i, line in enumerate(lines):
|
| 558 |
+
items = line.strip().split("\t")
|
| 559 |
+
if len(items) == 6:
|
| 560 |
+
(
|
| 561 |
+
wav_name,
|
| 562 |
+
prompt_text_1,
|
| 563 |
+
prompt_text_2,
|
| 564 |
+
prompt_wav_1,
|
| 565 |
+
prompt_wav_2,
|
| 566 |
+
text,
|
| 567 |
+
) = items
|
| 568 |
+
prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
|
| 569 |
+
prompt_wav = [prompt_wav_1, prompt_wav_2]
|
| 570 |
+
elif len(items) == 4:
|
| 571 |
+
wav_name, prompt_text, prompt_wav, text = items
|
| 572 |
+
else:
|
| 573 |
+
raise ValueError(f"Invalid line: {line}")
|
| 574 |
+
assert text.startswith("[S1]")
|
| 575 |
+
|
| 576 |
+
save_path = f"{res_dir}/{wav_name}.wav"
|
| 577 |
+
|
| 578 |
+
if model_name == "zipvoice_dialog":
|
| 579 |
+
|
| 580 |
+
metrics = generate_sentence(
|
| 581 |
+
save_path=save_path,
|
| 582 |
+
prompt_text=prompt_text,
|
| 583 |
+
prompt_wav=prompt_wav,
|
| 584 |
+
text=text,
|
| 585 |
+
model=model,
|
| 586 |
+
vocoder=vocoder,
|
| 587 |
+
tokenizer=tokenizer,
|
| 588 |
+
feature_extractor=feature_extractor,
|
| 589 |
+
device=device,
|
| 590 |
+
num_step=num_step,
|
| 591 |
+
guidance_scale=guidance_scale,
|
| 592 |
+
speed=speed,
|
| 593 |
+
t_shift=t_shift,
|
| 594 |
+
target_rms=target_rms,
|
| 595 |
+
feat_scale=feat_scale,
|
| 596 |
+
sampling_rate=sampling_rate,
|
| 597 |
+
)
|
| 598 |
+
else:
|
| 599 |
+
assert model_name == "zipvoice_dialog_stereo"
|
| 600 |
+
metrics = generate_sentence_stereo(
|
| 601 |
+
save_path=save_path,
|
| 602 |
+
prompt_text=prompt_text,
|
| 603 |
+
prompt_wav=prompt_wav,
|
| 604 |
+
text=text,
|
| 605 |
+
model=model,
|
| 606 |
+
vocoder=vocoder,
|
| 607 |
+
tokenizer=tokenizer,
|
| 608 |
+
feature_extractor=feature_extractor,
|
| 609 |
+
device=device,
|
| 610 |
+
num_step=num_step,
|
| 611 |
+
guidance_scale=guidance_scale,
|
| 612 |
+
speed=speed,
|
| 613 |
+
t_shift=t_shift,
|
| 614 |
+
target_rms=target_rms,
|
| 615 |
+
feat_scale=feat_scale,
|
| 616 |
+
sampling_rate=sampling_rate,
|
| 617 |
+
silence_wav=silence_wav,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
| 621 |
+
total_t.append(metrics["t"])
|
| 622 |
+
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
| 623 |
+
total_t_vocoder.append(metrics["t_vocoder"])
|
| 624 |
+
total_wav_seconds.append(metrics["wav_seconds"])
|
| 625 |
+
|
| 626 |
+
print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
|
| 627 |
+
print(
|
| 628 |
+
f"Average RTF w/o vocoder: "
|
| 629 |
+
f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 630 |
+
)
|
| 631 |
+
print(
|
| 632 |
+
f"Average RTF vocoder: "
|
| 633 |
+
f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
@torch.inference_mode()
|
| 638 |
+
def main():
|
| 639 |
+
parser = get_parser()
|
| 640 |
+
args = parser.parse_args()
|
| 641 |
+
|
| 642 |
+
params = AttributeDict()
|
| 643 |
+
params.update(vars(args))
|
| 644 |
+
fix_random_seed(params.seed)
|
| 645 |
+
|
| 646 |
+
assert (
|
| 647 |
+
params.test_list is not None
|
| 648 |
+
), "For inference, please provide prompts and text with '--test-list'"
|
| 649 |
+
|
| 650 |
+
if torch.cuda.is_available():
|
| 651 |
+
params.device = torch.device("cuda", 0)
|
| 652 |
+
elif torch.backends.mps.is_available():
|
| 653 |
+
params.device = torch.device("mps")
|
| 654 |
+
else:
|
| 655 |
+
params.device = torch.device("cpu")
|
| 656 |
+
|
| 657 |
+
print("Loading model...")
|
| 658 |
+
if params.model_config is None:
|
| 659 |
+
model_config = hf_hub_download(
|
| 660 |
+
HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
|
| 661 |
+
)
|
| 662 |
+
else:
|
| 663 |
+
model_config = params.model_config
|
| 664 |
+
|
| 665 |
+
with open(model_config, "r") as f:
|
| 666 |
+
model_config = json.load(f)
|
| 667 |
+
|
| 668 |
+
if params.token_file is None:
|
| 669 |
+
token_file = hf_hub_download(
|
| 670 |
+
HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
|
| 671 |
+
)
|
| 672 |
+
else:
|
| 673 |
+
token_file = params.token_file
|
| 674 |
+
|
| 675 |
+
tokenizer = DialogTokenizer(token_file=token_file)
|
| 676 |
+
|
| 677 |
+
tokenizer_config = {
|
| 678 |
+
"vocab_size": tokenizer.vocab_size,
|
| 679 |
+
"pad_id": tokenizer.pad_id,
|
| 680 |
+
"spk_a_id": tokenizer.spk_a_id,
|
| 681 |
+
"spk_b_id": tokenizer.spk_b_id,
|
| 682 |
+
}
|
| 683 |
+
if params.checkpoint is None:
|
| 684 |
+
model_ckpt = hf_hub_download(
|
| 685 |
+
HUGGINGFACE_REPO,
|
| 686 |
+
filename=PRETRAINED_MODEL[params.model_name],
|
| 687 |
+
)
|
| 688 |
+
else:
|
| 689 |
+
model_ckpt = params.checkpoint
|
| 690 |
+
|
| 691 |
+
if params.model_name == "zipvoice_dialog":
|
| 692 |
+
model = ZipVoiceDialog(
|
| 693 |
+
**model_config["model"],
|
| 694 |
+
**tokenizer_config,
|
| 695 |
+
)
|
| 696 |
+
else:
|
| 697 |
+
assert params.model_name == "zipvoice_dialog_stereo"
|
| 698 |
+
model = ZipVoiceDialogStereo(
|
| 699 |
+
**model_config["model"],
|
| 700 |
+
**tokenizer_config,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
if model_ckpt.endswith(".safetensors"):
|
| 704 |
+
safetensors.torch.load_model(model, model_ckpt)
|
| 705 |
+
elif model_ckpt.endswith(".pt"):
|
| 706 |
+
load_checkpoint(filename=model_ckpt, model=model, strict=True)
|
| 707 |
+
else:
|
| 708 |
+
raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
|
| 709 |
+
|
| 710 |
+
model = model.to(params.device)
|
| 711 |
+
model.eval()
|
| 712 |
+
|
| 713 |
+
vocoder = get_vocoder(params.vocoder_path)
|
| 714 |
+
vocoder = vocoder.to(params.device)
|
| 715 |
+
vocoder.eval()
|
| 716 |
+
|
| 717 |
+
if model_config["feature"]["type"] == "vocos":
|
| 718 |
+
if params.model_name == "zipvoice_dialog":
|
| 719 |
+
num_channels = 1
|
| 720 |
+
else:
|
| 721 |
+
assert params.model_name == "zipvoice_dialog_stereo"
|
| 722 |
+
num_channels = 2
|
| 723 |
+
feature_extractor = VocosFbank(num_channels=num_channels)
|
| 724 |
+
else:
|
| 725 |
+
raise NotImplementedError(
|
| 726 |
+
f"Unsupported feature type: {model_config['feature']['type']}"
|
| 727 |
+
)
|
| 728 |
+
params.sampling_rate = model_config["feature"]["sampling_rate"]
|
| 729 |
+
|
| 730 |
+
print("Start generating...")
|
| 731 |
+
os.makedirs(params.res_dir, exist_ok=True)
|
| 732 |
+
generate_list(
|
| 733 |
+
model_name=params.model_name,
|
| 734 |
+
res_dir=params.res_dir,
|
| 735 |
+
test_list=params.test_list,
|
| 736 |
+
model=model,
|
| 737 |
+
vocoder=vocoder,
|
| 738 |
+
tokenizer=tokenizer,
|
| 739 |
+
feature_extractor=feature_extractor,
|
| 740 |
+
device=params.device,
|
| 741 |
+
num_step=params.num_step,
|
| 742 |
+
guidance_scale=params.guidance_scale,
|
| 743 |
+
speed=params.speed,
|
| 744 |
+
t_shift=params.t_shift,
|
| 745 |
+
target_rms=params.target_rms,
|
| 746 |
+
feat_scale=params.feat_scale,
|
| 747 |
+
sampling_rate=params.sampling_rate,
|
| 748 |
+
silence_wav=params.silence_wav,
|
| 749 |
+
)
|
| 750 |
+
print("Done")
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
if __name__ == "__main__":
|
| 754 |
+
torch.set_num_threads(1)
|
| 755 |
+
torch.set_num_interop_threads(1)
|
| 756 |
+
main()
|
zipvoice/bin/infer_zipvoice_onnx.py
ADDED
|
@@ -0,0 +1,715 @@
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
#                                        Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
ONNX models. If no local model is specified,
required files will be automatically downloaded from HuggingFace.

Usage:

Note: If you have trouble connecting to HuggingFace,
    try switching the endpoint to the mirror site:
    export HF_ENDPOINT=https://hf-mirror.com

(1) Inference of a single sentence:

python3 -m zipvoice.bin.infer_zipvoice_onnx \
    --onnx-int8 False \
    --model-name "zipvoice" \
    --prompt-wav prompt.wav \
    --prompt-text "I am a prompt." \
    --text "I am a sentence." \
    --res-wav-path result.wav

(2) Inference of a list of sentences:
python3 -m zipvoice.bin.infer_zipvoice_onnx \
    --onnx-int8 False \
    --model-name "zipvoice" \
    --test-list test.tsv \
    --res-dir results

`--model-name` can be `zipvoice` or `zipvoice_distill`,
    which are the models before and after distillation, respectively.

Each line of `test.tsv` is in the format of
    `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.

Set `--onnx-int8 True` to use the int8 quantized ONNX model.
"""

import argparse
import datetime as dt
import json
import os
from typing import List, Tuple

import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from torch import Tensor, nn

from zipvoice.bin.infer_zipvoice import get_vocoder
from zipvoice.models.modules.solver import get_time_steps
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.common import AttributeDict, str2bool
from zipvoice.utils.feature import VocosFbank

HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
TOKEN_FILE = {
    "zipvoice": "zipvoice/tokens.txt",
    "zipvoice_distill": "zipvoice_distill/tokens.txt",
}
MODEL_CONFIG = {
    "zipvoice": "zipvoice/model.json",
    "zipvoice_distill": "zipvoice_distill/model.json",
}


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--onnx-int8",
        type=str2bool,
        default=False,
        help="Whether to use the int8 model",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--onnx-model-dir",
        type=str,
        default=None,
        help="The path to the local onnx model. "
        "Will download pre-trained checkpoint from huggingface if not specified.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default=None,
        help="The model configuration file. "
        "Will download model.json from huggingface if not specified.",
    )

    parser.add_argument(
        "--vocoder-path",
        type=str,
        default=None,
        help="The vocoder checkpoint. "
        "Will download pre-trained vocoder from huggingface if not specified.",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default=None,
        help="The file that contains information that maps tokens to ids,"
        "which is a text file with '{token}\t{token_id}' per line. "
        "Will download tokens_emilia.txt from huggingface if not specified.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default="emilia",
        choices=["emilia", "libritts", "espeak", "simple"],
        help="Tokenizer type.",
    )

    parser.add_argument(
        "--lang",
        type=str,
        default="en-us",
        help="Language identifier, used when tokenizer type is espeak. see"
        "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
    )

    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt_transcription, "
        "and text to synthesize, in the format of "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
    )

    parser.add_argument(
        "--prompt-wav",
        type=str,
        default=None,
        help="The prompt wav to mimic",
    )

    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The transcription of the prompt wav",
    )

    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="The text to synthesize",
    )

    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="""
        Path name of the generated wavs dir,
        used when test-list is not None
        """,
    )

    parser.add_argument(
        "--res-wav-path",
        type=str,
        default="result.wav",
        help="""
        Path name of the generated wav path,
        used when test-list is None
        """,
    )

    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=None,
        help="The scale of classifier-free guidance during inference.",
    )

    parser.add_argument(
        "--num-step",
        type=int,
        default=None,
        help="The number of sampling steps.",
    )

    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of fbank feature",
    )

    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Control speech speed, 1.0 means normal, >1.0 means speed up",
    )

    parser.add_argument(
        "--t-shift",
        type=float,
        default=0.5,
        help="Shift t to smaller ones if t_shift < 1.0",
    )

    parser.add_argument(
        "--target-rms",
        type=float,
        default=0.1,
        help="Target speech normalization rms value, set to 0 to disable normalization",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=666,
        help="Random seed",
    )

    return parser


class OnnxModel:
    def __init__(
        self,
        text_encoder_path: str,
        fm_decoder_path: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_text_encoder(text_encoder_path)
        self.init_fm_decoder(fm_decoder_path)

    def init_text_encoder(self, model_path: str):
        self.text_encoder = ort.InferenceSession(
            model_path,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_fm_decoder(self, model_path: str):
        self.fm_decoder = ort.InferenceSession(
            model_path,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.fm_decoder.get_modelmeta().custom_metadata_map
        self.feat_dim = int(meta["feat_dim"])

    def run_text_encoder(
        self,
        tokens: Tensor,
        prompt_tokens: Tensor,
        prompt_features_len: Tensor,
        speed: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        out = self.text_encoder.run(
            [
                self.text_encoder.get_outputs()[0].name,
            ],
            {
                self.text_encoder.get_inputs()[0].name: tokens.numpy(),
                self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
                self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
                self.text_encoder.get_inputs()[3].name: speed.numpy(),
            },
        )
        return torch.from_numpy(out[0])

    def run_fm_decoder(
        self,
        t: Tensor,
        x: Tensor,
        text_condition: Tensor,
        speech_condition: torch.Tensor,
        guidance_scale: Tensor,
    ) -> Tensor:
        out = self.fm_decoder.run(
            [
                self.fm_decoder.get_outputs()[0].name,
            ],
            {
                self.fm_decoder.get_inputs()[0].name: t.numpy(),
                self.fm_decoder.get_inputs()[1].name: x.numpy(),
                self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
                self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
                self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
            },
        )
        return torch.from_numpy(out[0])


def sample(
    model: OnnxModel,
    tokens: List[List[int]],
    prompt_tokens: List[List[int]],
    prompt_features: Tensor,
    speed: float = 1.0,
    t_shift: float = 0.5,
    guidance_scale: float = 1.0,
    num_step: int = 16,
) -> torch.Tensor:
    """
    Generate acoustic features, given text tokens, prompt features and the prompt
    transcription's text tokens.

    Args:
        tokens: a list of list of text tokens.
        prompt_tokens: a list of list of prompt tokens.
        prompt_features: the prompt feature with the shape
            (batch_size, seq_len, feat_dim).
        speed: speed control.
        t_shift: time shift.
        guidance_scale: the guidance scale for classifier-free guidance.
        num_step: the number of steps to use in the ODE solver.
    """
    # Run text encoder
    assert len(tokens) == len(prompt_tokens) == 1
    tokens = torch.tensor(tokens, dtype=torch.int64)
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
    prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
    speed = torch.tensor(speed, dtype=torch.float32)

    text_condition = model.run_text_encoder(
        tokens, prompt_tokens, prompt_features_len, speed
    )

    batch_size, num_frames, _ = text_condition.shape
    assert batch_size == 1
    feat_dim = model.feat_dim

    # Run flow matching model
    timesteps = get_time_steps(
        t_start=0.0,
        t_end=1.0,
        num_step=num_step,
        t_shift=t_shift,
    )
    x = torch.randn(batch_size, num_frames, feat_dim)
    speech_condition = torch.nn.functional.pad(
        prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
    )  # (B, T, F)
    guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)

    for step in range(num_step):
        v = model.run_fm_decoder(
            t=timesteps[step],
            x=x,
            text_condition=text_condition,
            speech_condition=speech_condition,
            guidance_scale=guidance_scale,
        )
        x = x + v * (timesteps[step + 1] - timesteps[step])

    x = x[:, prompt_features_len.item() :, :]
    return x


# Copied from zipvoice/infer/infer_zipvoice.py, but call an external sample function
def generate_sentence(
    save_path: str,
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: OnnxModel,
    vocoder: nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.
    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)

    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)

    prompt_features = prompt_features.unsqueeze(0) * feat_scale

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    pred_features = sample(
        model=model,
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        speed=speed,
        t_shift=t_shift,
        guidance_scale=guidance_scale,
        num_step=num_step,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics


def generate_list(
    res_dir: str,
    test_list: str,
    model: OnnxModel,
    vocoder: nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []

    with open(test_list, "r") as fr:
        lines = fr.readlines()

    for i, line in enumerate(lines):
        wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
        save_path = f"{res_dir}/{wav_name}.wav"
        metrics = generate_sentence(
            save_path=save_path,
            prompt_text=prompt_text,
            prompt_wav=prompt_wav,
            text=text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=num_step,
            guidance_scale=guidance_scale,
            speed=speed,
            t_shift=t_shift,
            target_rms=target_rms,
            feat_scale=feat_scale,
            sampling_rate=sampling_rate,
        )
        print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
        total_t.append(metrics["t"])
        total_t_no_vocoder.append(metrics["t_no_vocoder"])
        total_t_vocoder.append(metrics["t_vocoder"])
        total_wav_seconds.append(metrics["wav_seconds"])

    print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
    print(
        f"Average RTF w/o vocoder: "
        f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
    )
    print(
        f"Average RTF vocoder: "
        f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
    )


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))
    fix_random_seed(params.seed)

    model_defaults = {
        "zipvoice": {
            "num_step": 16,
            "guidance_scale": 1.0,
        },
        "zipvoice_distill": {
            "num_step": 8,
            "guidance_scale": 3.0,
        },
    }

    model_specific_defaults = model_defaults.get(params.model_name, {})

    for param, value in model_specific_defaults.items():
        if getattr(params, param) is None:
            setattr(params, param, value)
            print(f"Setting {param} to default value: {value}")

    assert (params.test_list is not None) ^ (
        (params.prompt_wav and params.prompt_text and params.text) is not None
    ), (
        "For inference, please provide prompts and text with either '--test-list'"
        " or '--prompt-wav, --prompt-text and --text'."
    )

    print("Loading model...")
    if params.model_config is None:
        model_config = hf_hub_download(
            HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
        )
    else:
        model_config = params.model_config

    with open(model_config, "r") as f:
        model_config = json.load(f)

    if params.token_file is None:
        token_file = hf_hub_download(
            HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
        )
    else:
        token_file = params.token_file

    if params.tokenizer == "emilia":
        tokenizer = EmiliaTokenizer(token_file=token_file)
    elif params.tokenizer == "libritts":
        tokenizer = LibriTTSTokenizer(token_file=token_file)
    elif params.tokenizer == "espeak":
        tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
    else:
        assert params.tokenizer == "simple"
        tokenizer = SimpleTokenizer(token_file=token_file)

    if params.onnx_model_dir is not None:
        dirname = params.onnx_model_dir
    else:
        if params.model_name == "zipvoice_distill":
            dirname = "zipvoice_distill"
        else:
            dirname = "zipvoice"

    if not params.onnx_int8:
        text_encoder_path = f"{dirname}/text_encoder.onnx"
        fm_decoder_path = f"{dirname}/fm_decoder.onnx"
    else:
        text_encoder_path = f"{dirname}/text_encoder_int8.onnx"
        fm_decoder_path = f"{dirname}/fm_decoder_int8.onnx"
    if params.onnx_model_dir is None:
        text_encoder_path = hf_hub_download(
            HUGGINGFACE_REPO, filename=text_encoder_path
        )
        fm_decoder_path = hf_hub_download(HUGGINGFACE_REPO, filename=fm_decoder_path)

    model = OnnxModel(text_encoder_path, fm_decoder_path)

    vocoder = get_vocoder(params.vocoder_path)
    vocoder.eval()

    if model_config["feature"]["type"] == "vocos":
        feature_extractor = VocosFbank()
    else:
        raise NotImplementedError(
            f"Unsupported feature type: {model_config['feature']['type']}"
        )
    params.sampling_rate = model_config["feature"]["sampling_rate"]

    print("Start generating...")
    if params.test_list:
        os.makedirs(params.res_dir, exist_ok=True)
        generate_list(
            res_dir=params.res_dir,
            test_list=params.test_list,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    else:
        generate_sentence(
            save_path=params.res_wav_path,
            prompt_text=params.prompt_text,
            prompt_wav=params.prompt_wav,
            text=params.text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    print("Done")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
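A minimal sketch of driving the ONNX inference path above from Python instead of the CLI, reusing the functions defined in this file. The model, token, and prompt paths here are placeholders, not files shipped with this demo:

    from zipvoice.bin.infer_zipvoice import get_vocoder
    from zipvoice.bin.infer_zipvoice_onnx import OnnxModel, generate_sentence
    from zipvoice.tokenizer.tokenizer import EmiliaTokenizer
    from zipvoice.utils.feature import VocosFbank

    # Assumed local paths; point these at the exported ONNX models and token file.
    model = OnnxModel("zipvoice/text_encoder.onnx", "zipvoice/fm_decoder.onnx")
    tokenizer = EmiliaTokenizer(token_file="zipvoice/tokens.txt")
    vocoder = get_vocoder(None)  # None falls back to the default pre-trained vocoder
    vocoder.eval()

    metrics = generate_sentence(
        save_path="result.wav",
        prompt_text="I am a prompt.",
        prompt_wav="prompt.wav",
        text="I am a sentence.",
        model=model,
        vocoder=vocoder,
        tokenizer=tokenizer,
        feature_extractor=VocosFbank(),
        num_step=16,
        guidance_scale=1.0,
    )
    print(f"RTF: {metrics['rtf']:.4f}")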
zipvoice/bin/onnx_export.py
ADDED
@@ -0,0 +1,404 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
ONNX.

Usage:

python3 -m zipvoice.bin.onnx_export \
    --model-name zipvoice \
    --token-file data/tokens_emilia.txt \
    --checkpoint exp/zipvoice/epoch-11-avg-4.pt \
    --model-config conf/zipvoice_base.json \
    --onnx-model-dir exp/zipvoice_onnx

`--model-name` can be `zipvoice` or `zipvoice_distill`,
    which are the models before and after distillation, respectively.
"""


import argparse
import json
import os
from typing import Dict

import onnx
import safetensors.torch
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import SimpleTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--onnx-model-dir",
        type=str,
        default="exp",
        help="Dir to the exported models",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--token-file",
        type=str,
        default="data/tokens_emilia.txt",
        help="The file that contains information that maps tokens to ids,"
        "which is a text file with '{token}\t{token_id}' per line.",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default="exp_zipvoice/epoch-11-avg-4.pt",
        help="The model checkpoint.",
    )

    parser.add_argument(
        "--model-config",
        type=str,
        default="conf/zipvoice_base.json",
        help="The model configuration file.",
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxTextModel(nn.Module):
    def __init__(self, model: nn.Module):
        """A wrapper for ZipVoice text encoder."""
        super().__init__()
        self.embed = model.embed
        self.text_encoder = model.text_encoder
        self.pad_id = model.pad_id

    def forward(
        self,
        tokens: Tensor,
        prompt_tokens: Tensor,
        prompt_features_len: Tensor,
        speed: Tensor,
    ) -> Tensor:
        cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
        cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
        tokens_len = cat_tokens.shape[1] - 1
        padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)

        embed = self.embed(cat_tokens)
        embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)

        features_len = torch.ceil(
            (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
        ).to(dtype=torch.int64)

        token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
            dtype=torch.int64
        )

        text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
        text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])

        text_condition = torch.cat(
            [
                text_condition,
                embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
            ],
            dim=1,
        )

        return text_condition


class OnnxFlowMatchingModel(nn.Module):
    def __init__(self, model: nn.Module):
        """A wrapper for ZipVoice flow-matching decoder."""
        super().__init__()
        self.distill = model.distill
        self.fm_decoder = model.fm_decoder
        self.model_func = getattr(model, "forward_fm_decoder")
        self.feat_dim = model.feat_dim

    def forward(
        self,
        t: Tensor,
        x: Tensor,
        text_condition: Tensor,
        speech_condition: torch.Tensor,
        guidance_scale: Tensor,
    ) -> Tensor:
        if self.distill:
            return self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
                guidance_scale=guidance_scale,
            )
        else:
            x = x.repeat(2, 1, 1)
            text_condition = torch.cat(
                [torch.zeros_like(text_condition), text_condition], dim=0
            )
            speech_condition = torch.cat(
                [
                    torch.where(
                        t > 0.5, torch.zeros_like(speech_condition), speech_condition
                    ),
                    speech_condition,
                ],
                dim=0,
            )
            guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
            data_uncond, data_cond = self.model_func(
                t=t,
                xt=x,
                text_condition=text_condition,
                speech_condition=speech_condition,
            ).chunk(2, dim=0)
            v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
            return v


def export_text_encoder(
    model: OnnxTextModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the text encoder model to ONNX format.

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
    prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
    prompt_features_len = torch.tensor(10, dtype=torch.int64)
    speed = torch.tensor(1.0, dtype=torch.float32)

    model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))

    torch.onnx.export(
        model,
        (tokens, prompt_tokens, prompt_features_len, speed),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
        output_names=["text_condition"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "prompt_tokens": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice text encoder",
    }
    print(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    print(f"Exported to {filename}")


def export_fm_decoder(
    model: OnnxFlowMatchingModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the flow matching decoder model to ONNX format.

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    feat_dim = model.feat_dim
    seq_len = 200
    t = torch.tensor(0.5, dtype=torch.float32)
    x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
    guidance_scale = torch.tensor(1.0, dtype=torch.float32)

    model = torch.jit.trace(
        model, (t, x, text_condition, speech_condition, guidance_scale)
    )

    torch.onnx.export(
        model,
        (t, x, text_condition, speech_condition, guidance_scale),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
        output_names=["v"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "text_condition": {0: "N", 1: "T"},
            "speech_condition": {0: "N", 1: "T"},
            "v": {0: "N", 1: "T"},
        },
    )

    meta_data = {
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "ZipVoice flow-matching decoder",
        "feat_dim": str(feat_dim),
    }
    print(f"meta_data: {meta_data}")
    add_meta_data(filename=filename, meta_data=meta_data)

    print(f"Exported to {filename}")


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = AttributeDict()
    params.update(vars(args))

    model_config = params.model_config
    with open(model_config, "r") as f:
        model_config = json.load(f)
    for key, value in model_config["model"].items():
        setattr(params, key, value)
    for key, value in model_config["feature"].items():
        setattr(params, key, value)

    token_file = params.token_file
    tokenizer = SimpleTokenizer(token_file)
    tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

    if params.model_name == "zipvoice":
        model = ZipVoice(
            **model_config["model"],
            **tokenizer_config,
        )
    else:
        assert params.model_name == "zipvoice_distill"
        model = ZipVoiceDistill(
            **model_config["model"],
            **tokenizer_config,
        )
    model_ckpt = params.checkpoint

    if model_ckpt.endswith(".safetensors"):
        safetensors.torch.load_model(model, model_ckpt)
    elif model_ckpt.endswith(".pt"):
        load_checkpoint(filename=model_ckpt, model=model, strict=True)
    else:
        raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")

    device = torch.device("cpu")
    model = model.to(device)
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    print("Exporting model")
    os.makedirs(params.onnx_model_dir, exist_ok=True)
    opset_version = 11

    text_encoder = OnnxTextModel(model=model)
    text_encoder_file = f"{params.onnx_model_dir}/text_encoder.onnx"
    export_text_encoder(
        model=text_encoder,
        filename=text_encoder_file,
        opset_version=opset_version,
    )

    fm_decoder = OnnxFlowMatchingModel(model=model)
    fm_decoder_file = f"{params.onnx_model_dir}/fm_decoder.onnx"
    export_fm_decoder(
        model=fm_decoder,
        filename=fm_decoder_file,
        opset_version=opset_version,
    )

    print("Generate int8 quantization models")

    text_encoder_int8_file = f"{params.onnx_model_dir}/text_encoder_int8.onnx"
    quantize_dynamic(
        model_input=text_encoder_file,
        model_output=text_encoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    fm_decoder_int8_file = f"{params.onnx_model_dir}/fm_decoder_int8.onnx"
    quantize_dynamic(
        model_input=fm_decoder_file,
        model_output=fm_decoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    print("Done!")


if __name__ == "__main__":
    main()
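A small sanity check one might run after the export above, to confirm the metadata written by add_meta_data and the expected input names. The file path reflects the default --onnx-model-dir and is an assumption:

    import onnxruntime as ort

    sess = ort.InferenceSession("exp/fm_decoder.onnx", providers=["CPUExecutionProvider"])
    # add_meta_data stored feat_dim on the flow-matching decoder;
    # infer_zipvoice_onnx.py reads it back from custom_metadata_map.
    print(sess.get_modelmeta().custom_metadata_map)
    # Expected input names: ['t', 'x', 'text_condition', 'speech_condition', 'guidance_scale']
    print([i.name for i in sess.get_inputs()])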
zipvoice/bin/train_zipvoice.py
ADDED
@@ -0,0 +1,1110 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
|
| 3 |
+
# Han Zhu)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
This script trains a ZipVoice model with the flow-matching loss.
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
|
| 24 |
+
python3 -m zipvoice.bin.train_zipvoice \
|
| 25 |
+
--world-size 8 \
|
| 26 |
+
--use-fp16 1 \
|
| 27 |
+
--num-epochs 11 \
|
| 28 |
+
--max-duration 500 \
|
| 29 |
+
--lr-hours 30000 \
|
| 30 |
+
--model-config conf/zipvoice_base.json \
|
| 31 |
+
--tokenizer emilia \
|
| 32 |
+
--token-file "data/tokens_emilia.txt" \
|
| 33 |
+
--dataset emilia \
|
| 34 |
+
--manifest-dir data/fbank \
|
| 35 |
+
--exp-dir exp/zipvoice
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import copy
|
| 40 |
+
import json
|
| 41 |
+
import logging
|
| 42 |
+
import os
|
| 43 |
+
from functools import partial
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
from shutil import copyfile
|
| 46 |
+
from typing import List, Optional, Tuple, Union
|
| 47 |
+
|
| 48 |
+
import torch
|
| 49 |
+
import torch.multiprocessing as mp
|
| 50 |
+
import torch.nn as nn
|
| 51 |
+
from lhotse.cut import Cut, CutSet
|
| 52 |
+
from lhotse.utils import fix_random_seed
|
| 53 |
+
from torch import Tensor
|
| 54 |
+
from torch.amp import GradScaler, autocast
|
| 55 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 56 |
+
from torch.optim import Optimizer
|
| 57 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 58 |
+
|
| 59 |
+
import zipvoice.utils.diagnostics as diagnostics
|
| 60 |
+
from zipvoice.dataset.datamodule import TtsDataModule
|
| 61 |
+
from zipvoice.models.zipvoice import ZipVoice
|
| 62 |
+
from zipvoice.tokenizer.tokenizer import (
|
| 63 |
+
EmiliaTokenizer,
|
| 64 |
+
EspeakTokenizer,
|
| 65 |
+
LibriTTSTokenizer,
|
| 66 |
+
SimpleTokenizer,
|
| 67 |
+
)
|
| 68 |
+
from zipvoice.utils.checkpoint import (
|
| 69 |
+
load_checkpoint,
|
| 70 |
+
remove_checkpoints,
|
| 71 |
+
resume_checkpoint,
|
| 72 |
+
save_checkpoint,
|
| 73 |
+
save_checkpoint_with_global_batch_idx,
|
| 74 |
+
update_averaged_model,
|
| 75 |
+
)
|
| 76 |
+
from zipvoice.utils.common import (
|
| 77 |
+
AttributeDict,
|
| 78 |
+
MetricsTracker,
|
| 79 |
+
cleanup_dist,
|
| 80 |
+
get_adjusted_batch_count,
|
| 81 |
+
get_env_info,
|
| 82 |
+
get_parameter_groups_with_lrs,
|
| 83 |
+
prepare_input,
|
| 84 |
+
set_batch_count,
|
| 85 |
+
setup_dist,
|
| 86 |
+
setup_logger,
|
| 87 |
+
str2bool,
|
| 88 |
+
)
|
| 89 |
+
from zipvoice.utils.hooks import register_inf_check_hooks
|
| 90 |
+
from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
|
| 91 |
+
from zipvoice.utils.optim import ScaledAdam
|
| 92 |
+
|
| 93 |
+
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_parser():
|
| 97 |
+
parser = argparse.ArgumentParser(
|
| 98 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--world-size",
|
| 103 |
+
type=int,
|
| 104 |
+
default=1,
|
| 105 |
+
help="Number of GPUs for DDP training.",
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--master-port",
|
| 110 |
+
type=int,
|
| 111 |
+
default=12356,
|
| 112 |
+
help="Master port to use for DDP training.",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--tensorboard",
|
| 117 |
+
type=str2bool,
|
| 118 |
+
default=True,
|
| 119 |
+
help="Should various information be logged in tensorboard.",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--num-epochs",
|
| 124 |
+
type=int,
|
| 125 |
+
default=11,
|
| 126 |
+
help="Number of epochs to train.",
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--num-iters",
|
| 131 |
+
type=int,
|
| 132 |
+
default=0,
|
| 133 |
+
help="Number of iter to train, will ignore num_epochs if > 0.",
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--start-epoch",
|
| 138 |
+
type=int,
|
| 139 |
+
default=1,
|
| 140 |
+
help="""Resume training from this epoch. It should be positive.
|
| 141 |
+
If larger than 1, it will load checkpoint from
|
| 142 |
+
exp-dir/epoch-{start_epoch-1}.pt
|
| 143 |
+
""",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--checkpoint",
|
| 148 |
+
type=str,
|
| 149 |
+
default=None,
|
| 150 |
+
help="""Checkpoints of pre-trained models, will load it if not None
|
| 151 |
+
""",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--exp-dir",
|
| 156 |
+
type=str,
|
| 157 |
+
default="exp/zipvoice",
|
| 158 |
+
help="""The experiment dir.
|
| 159 |
+
It specifies the directory where all training related
|
| 160 |
+
files, e.g., checkpoints, log, etc, are saved
|
| 161 |
+
""",
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--base-lr", type=float, default=0.02, help="The base learning rate."
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--lr-batches",
|
| 170 |
+
type=float,
|
| 171 |
+
default=7500,
|
| 172 |
+
help="""Number of steps that affects how rapidly the learning rate
|
| 173 |
+
decreases. We suggest not to change this.""",
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--lr-epochs",
|
| 178 |
+
type=float,
|
| 179 |
+
default=10,
|
| 180 |
+
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
| 181 |
+
""",
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--lr-hours",
|
| 186 |
+
type=float,
|
| 187 |
+
default=0,
|
| 188 |
+
help="""If positive, --epoch is ignored and it specifies the number of hours
|
| 189 |
+
that affects how rapidly the learning rate decreases.
|
| 190 |
+
""",
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--ref-duration",
|
| 195 |
+
type=float,
|
| 196 |
+
default=50,
|
| 197 |
+
help="""Reference batch duration for purposes of adjusting batch counts for"
|
| 198 |
+
setting various schedules inside the model".
|
| 199 |
+
""",
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--finetune",
|
| 204 |
+
type=str2bool,
|
| 205 |
+
default=False,
|
| 206 |
+
help="Whether to use the fine-tuning mode, will used a fixed learning rate "
|
| 207 |
+
"schedule and skip the large dropout phase.",
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--seed",
|
| 212 |
+
type=int,
|
| 213 |
+
default=42,
|
| 214 |
+
help="The seed for random generators intended for reproducibility",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--print-diagnostics",
|
| 219 |
+
type=str2bool,
|
| 220 |
+
default=False,
|
| 221 |
+
help="Accumulate stats on activations, print them and exit.",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--scan-oom",
|
| 226 |
+
type=str2bool,
|
| 227 |
+
default=False,
|
| 228 |
+
help="Scan pessimistic batches to see whether they cause OOMs.",
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--inf-check",
|
| 233 |
+
type=str2bool,
|
| 234 |
+
default=False,
|
| 235 |
+
help="Add hooks to check for infinite module outputs and gradients.",
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--save-every-n",
|
| 240 |
+
type=int,
|
| 241 |
+
default=5000,
|
| 242 |
+
help="""Save checkpoint after processing this number of batches"
|
| 243 |
+
periodically. We save checkpoint to exp-dir/ whenever
|
| 244 |
+
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
| 245 |
+
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
| 246 |
+
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
| 247 |
+
end of each epoch where `xxx` is the epoch number counting from 1.
|
| 248 |
+
""",
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--keep-last-k",
|
| 253 |
+
type=int,
|
| 254 |
+
default=30,
|
| 255 |
+
help="""Only keep this number of checkpoints on disk.
|
| 256 |
+
For instance, if it is 3, there are only 3 checkpoints
|
| 257 |
+
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
| 258 |
+
It does not affect checkpoints with name `epoch-xxx.pt`.
|
| 259 |
+
""",
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
"--average-period",
|
| 264 |
+
type=int,
|
| 265 |
+
default=200,
|
| 266 |
+
help="""Update the averaged model, namely `model_avg`, after processing
|
| 267 |
+
this number of batches. `model_avg` is a separate version of model,
|
| 268 |
+
in which each floating-point parameter is the average of all the
|
| 269 |
+
parameters from the start of training. Each time we take the average,
|
| 270 |
+
we do: `model_avg = model * (average_period / batch_idx_train) +
|
| 271 |
+
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
| 272 |
+
""",
|
| 273 |
+
)
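For intuition, the running average described in the help text above amounts to the following update. This is only a minimal sketch with a hypothetical helper name; the real implementation lives in zipvoice.utils.checkpoint.update_averaged_model.

import torch

def running_average_update(model_avg, model, average_period, batch_idx_train):
    # Hypothetical helper mirroring the formula in the help text above.
    w = average_period / batch_idx_train  # weight of the current model
    with torch.no_grad():
        for p_avg, p in zip(model_avg.parameters(), model.parameters()):
            p_avg.mul_(1.0 - w).add_(p.to(p_avg.dtype), alpha=w)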
|
| 274 |
+
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--use-fp16",
|
| 277 |
+
type=str2bool,
|
| 278 |
+
default=True,
|
| 279 |
+
help="Whether to use half precision training.",
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--feat-scale",
|
| 284 |
+
type=float,
|
| 285 |
+
default=0.1,
|
| 286 |
+
help="The scale factor of fbank feature",
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
parser.add_argument(
|
| 290 |
+
"--condition-drop-ratio",
|
| 291 |
+
type=float,
|
| 292 |
+
default=0.2,
|
| 293 |
+
help="The drop rate of text condition during training.",
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
parser.add_argument(
|
| 297 |
+
"--dataset",
|
| 298 |
+
type=str,
|
| 299 |
+
default="emilia",
|
| 300 |
+
choices=["emilia", "libritts", "custom"],
|
| 301 |
+
help="The used training dataset",
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
"--train-manifest",
|
| 306 |
+
type=str,
|
| 307 |
+
help="Path of the training manifest",
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--dev-manifest",
|
| 312 |
+
type=str,
|
| 313 |
+
help="Path of the validation manifest",
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
parser.add_argument(
|
| 317 |
+
"--min-len",
|
| 318 |
+
type=float,
|
| 319 |
+
default=1.0,
|
| 320 |
+
help="The minimum audio length used for training",
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--max-len",
|
| 325 |
+
type=float,
|
| 326 |
+
default=30.0,
|
| 327 |
+
help="The maximum audio length used for training",
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--model-config",
|
| 332 |
+
type=str,
|
| 333 |
+
default="conf/zipvoice_base.json",
|
| 334 |
+
help="The model configuration file.",
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--tokenizer",
|
| 339 |
+
type=str,
|
| 340 |
+
default="emilia",
|
| 341 |
+
choices=["emilia", "libritts", "espeak", "simple"],
|
| 342 |
+
help="Tokenizer type.",
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--lang",
|
| 347 |
+
type=str,
|
| 348 |
+
default="en-us",
|
| 349 |
+
help="Language identifier, used when tokenizer type is espeak. see"
|
| 350 |
+
"https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--token-file",
|
| 355 |
+
type=str,
|
| 356 |
+
default="data/tokens_emilia.txt",
|
| 357 |
+
help="The file that contains information that maps tokens to ids,"
|
| 358 |
+
"which is a text file with '{token}\t{token_id}' per line.",
|
| 359 |
+
)
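The token file format mentioned above ('{token}\t{token_id}' per line) can be read with a few lines. This is a sketch only; the actual mapping is built inside the tokenizer classes in zipvoice.tokenizer.tokenizer.

def load_token_map(token_file: str) -> dict:
    # One "{token}\t{token_id}" pair per line.
    token2id = {}
    with open(token_file, encoding="utf-8") as f:
        for line in f:
            token, token_id = line.rstrip("\n").split("\t")
            token2id[token] = int(token_id)
    return token2id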
|
| 360 |
+
|
| 361 |
+
return parser
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def get_params() -> AttributeDict:
|
| 365 |
+
"""Return a dict containing training parameters.
|
| 366 |
+
|
| 367 |
+
All training related parameters that are not passed from the commandline
|
| 368 |
+
are saved in the variable `params`.
|
| 369 |
+
|
| 370 |
+
Commandline options are merged into `params` after they are parsed, so
|
| 371 |
+
you can also access them via `params`.
|
| 372 |
+
|
| 373 |
+
Explanation of options saved in `params`:
|
| 374 |
+
|
| 375 |
+
- best_train_loss: Best training loss so far. It is used to select
|
| 376 |
+
the model that has the lowest training loss. It is
|
| 377 |
+
updated during the training.
|
| 378 |
+
|
| 379 |
+
- best_valid_loss: Best validation loss so far. It is used to select
|
| 380 |
+
the model that has the lowest validation loss. It is
|
| 381 |
+
updated during the training.
|
| 382 |
+
|
| 383 |
+
- best_train_epoch: It is the epoch that has the best training loss.
|
| 384 |
+
|
| 385 |
+
- best_valid_epoch: It is the epoch that has the best validation loss.
|
| 386 |
+
|
| 387 |
+
- batch_idx_train: Used for writing statistics to tensorboard. It
|
| 388 |
+
contains number of batches trained so far across
|
| 389 |
+
epochs.
|
| 390 |
+
|
| 391 |
+
- log_interval: Print training loss if batch_idx % log_interval is 0
|
| 392 |
+
|
| 393 |
+
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
| 394 |
+
|
| 395 |
+
- env_info: A dict containing information about the environment.
|
| 396 |
+
|
| 397 |
+
"""
|
| 398 |
+
params = AttributeDict(
|
| 399 |
+
{
|
| 400 |
+
"best_train_loss": float("inf"),
|
| 401 |
+
"best_valid_loss": float("inf"),
|
| 402 |
+
"best_train_epoch": -1,
|
| 403 |
+
"best_valid_epoch": -1,
|
| 404 |
+
"batch_idx_train": 0,
|
| 405 |
+
"log_interval": 50,
|
| 406 |
+
"reset_interval": 200,
|
| 407 |
+
"env_info": get_env_info(),
|
| 408 |
+
}
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
return params
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def compute_fbank_loss(
|
| 415 |
+
params: AttributeDict,
|
| 416 |
+
model: Union[nn.Module, DDP],
|
| 417 |
+
features: Tensor,
|
| 418 |
+
features_lens: Tensor,
|
| 419 |
+
tokens: List[List[int]],
|
| 420 |
+
is_training: bool,
|
| 421 |
+
) -> Tuple[Tensor, MetricsTracker]:
|
| 422 |
+
"""
|
| 423 |
+
Compute loss given the model and its inputs.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
params:
|
| 427 |
+
Parameters for training. See :func:`get_params`.
|
| 428 |
+
model:
|
| 429 |
+
The model for training.
|
| 430 |
+
features:
|
| 431 |
+
The target acoustic feature.
|
| 432 |
+
features_lens:
|
| 433 |
+
The number of frames of each utterance.
|
| 434 |
+
tokens:
|
| 435 |
+
Input tokens representing the transcripts.
|
| 436 |
+
is_training:
|
| 437 |
+
True for training. False for validation. When it is True, this
|
| 438 |
+
function enables autograd during computation; when it is False, it
|
| 439 |
+
disables autograd.
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 443 |
+
|
| 444 |
+
batch_size, num_frames, _ = features.shape
|
| 445 |
+
|
| 446 |
+
features = torch.nn.functional.pad(
|
| 447 |
+
features, (0, 0, 0, num_frames - features.size(1))
|
| 448 |
+
) # (B, T, F)
|
| 449 |
+
noise = torch.randn_like(features) # (B, T, F)
|
| 450 |
+
|
| 451 |
+
# Sampling t from uniform distribution
|
| 452 |
+
if is_training:
|
| 453 |
+
t = torch.rand(batch_size, 1, 1, device=device)
|
| 454 |
+
else:
|
| 455 |
+
t = (
|
| 456 |
+
(torch.arange(batch_size, device=device) / batch_size)
|
| 457 |
+
.unsqueeze(1)
|
| 458 |
+
.unsqueeze(2)
|
| 459 |
+
)
|
| 460 |
+
with torch.set_grad_enabled(is_training):
|
| 461 |
+
|
| 462 |
+
loss = model(
|
| 463 |
+
tokens=tokens,
|
| 464 |
+
features=features,
|
| 465 |
+
features_lens=features_lens,
|
| 466 |
+
noise=noise,
|
| 467 |
+
t=t,
|
| 468 |
+
condition_drop_ratio=params.condition_drop_ratio,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
assert loss.requires_grad == is_training
|
| 472 |
+
info = MetricsTracker()
|
| 473 |
+
num_frames = features_lens.sum().item()
|
| 474 |
+
info["frames"] = num_frames
|
| 475 |
+
info["loss"] = loss.detach().cpu().item() * num_frames
|
| 476 |
+
|
| 477 |
+
return loss, info
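The objective itself is computed inside ZipVoice.forward(). As a rough, self-contained sketch of a flow-matching loss of this kind, assuming the rectified-flow convention used in train_zipvoice_distill.py below (xt = t * x1 + (1 - t) * noise, velocity target x1 - noise); the velocity_model callable is purely illustrative and not part of this repository:

import torch

def toy_flow_matching_loss(features, noise, t, velocity_model):
    # Illustrative only; the real objective also handles text/speech conditions and masking.
    xt = t * features + (1.0 - t) * noise   # interpolate clean features and noise
    target_v = features - noise             # constant velocity along the straight path
    pred_v = velocity_model(xt, t)          # hypothetical callable predicting the velocity
    return torch.mean((pred_v - target_v) ** 2)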
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def train_one_epoch(
|
| 481 |
+
params: AttributeDict,
|
| 482 |
+
model: Union[nn.Module, DDP],
|
| 483 |
+
optimizer: Optimizer,
|
| 484 |
+
scheduler: LRSchedulerType,
|
| 485 |
+
train_dl: torch.utils.data.DataLoader,
|
| 486 |
+
valid_dl: torch.utils.data.DataLoader,
|
| 487 |
+
scaler: GradScaler,
|
| 488 |
+
model_avg: Optional[nn.Module] = None,
|
| 489 |
+
tb_writer: Optional[SummaryWriter] = None,
|
| 490 |
+
world_size: int = 1,
|
| 491 |
+
rank: int = 0,
|
| 492 |
+
) -> None:
|
| 493 |
+
"""Train the model for one epoch.
|
| 494 |
+
|
| 495 |
+
The training loss from the mean of all frames is saved in
|
| 496 |
+
`params.train_loss`. It runs the validation process every
|
| 497 |
+
`params.valid_interval` batches.
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
params:
|
| 501 |
+
It is returned by :func:`get_params`.
|
| 502 |
+
model:
|
| 503 |
+
The model for training.
|
| 504 |
+
optimizer:
|
| 505 |
+
The optimizer.
|
| 506 |
+
scheduler:
|
| 507 |
+
The learning rate scheduler, we call step() every epoch.
|
| 508 |
+
train_dl:
|
| 509 |
+
Dataloader for the training dataset.
|
| 510 |
+
valid_dl:
|
| 511 |
+
Dataloader for the validation dataset.
|
| 512 |
+
scaler:
|
| 513 |
+
The scaler used for mixed precision training.
|
| 514 |
+
tb_writer:
|
| 515 |
+
Writer to write log messages to tensorboard.
|
| 516 |
+
world_size:
|
| 517 |
+
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
| 518 |
+
rank:
|
| 519 |
+
The rank of the node in DDP training. If no DDP is used, it should
|
| 520 |
+
be set to 0.
|
| 521 |
+
"""
|
| 522 |
+
model.train()
|
| 523 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 524 |
+
|
| 525 |
+
# used to track the stats over iterations in one epoch
|
| 526 |
+
tot_loss = MetricsTracker()
|
| 527 |
+
|
| 528 |
+
saved_bad_model = False
|
| 529 |
+
|
| 530 |
+
def save_bad_model(suffix: str = ""):
|
| 531 |
+
save_checkpoint(
|
| 532 |
+
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
| 533 |
+
model=model,
|
| 534 |
+
model_avg=model_avg,
|
| 535 |
+
params=params,
|
| 536 |
+
optimizer=optimizer,
|
| 537 |
+
scheduler=scheduler,
|
| 538 |
+
sampler=train_dl.sampler,
|
| 539 |
+
scaler=scaler,
|
| 540 |
+
rank=0,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
for batch_idx, batch in enumerate(train_dl):
|
| 544 |
+
|
| 545 |
+
if batch_idx % 10 == 0:
|
| 546 |
+
if params.finetune:
|
| 547 |
+
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
|
| 548 |
+
else:
|
| 549 |
+
set_batch_count(model, get_adjusted_batch_count(params))
|
| 550 |
+
|
| 551 |
+
if (
|
| 552 |
+
params.batch_idx_train > 0
|
| 553 |
+
and params.batch_idx_train % params.valid_interval == 0
|
| 554 |
+
and not params.print_diagnostics
|
| 555 |
+
):
|
| 556 |
+
logging.info("Computing validation loss")
|
| 557 |
+
valid_info = compute_validation_loss(
|
| 558 |
+
params=params,
|
| 559 |
+
model=model,
|
| 560 |
+
valid_dl=valid_dl,
|
| 561 |
+
world_size=world_size,
|
| 562 |
+
)
|
| 563 |
+
model.train()
|
| 564 |
+
logging.info(
|
| 565 |
+
f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
|
| 566 |
+
f" validation: {valid_info}"
|
| 567 |
+
)
|
| 568 |
+
logging.info(
|
| 569 |
+
f"Maximum memory allocated so far is "
|
| 570 |
+
f"{torch.cuda.max_memory_allocated() // 1000000}MB"
|
| 571 |
+
)
|
| 572 |
+
if tb_writer is not None:
|
| 573 |
+
valid_info.write_summary(
|
| 574 |
+
tb_writer, "train/valid_", params.batch_idx_train
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
params.batch_idx_train += 1
|
| 578 |
+
|
| 579 |
+
batch_size = len(batch["text"])
|
| 580 |
+
|
| 581 |
+
tokens, features, features_lens = prepare_input(
|
| 582 |
+
params=params,
|
| 583 |
+
batch=batch,
|
| 584 |
+
device=device,
|
| 585 |
+
return_tokens=True,
|
| 586 |
+
return_feature=True,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
try:
|
| 590 |
+
with autocast("cuda", enabled=params.use_fp16):
|
| 591 |
+
loss, loss_info = compute_fbank_loss(
|
| 592 |
+
params=params,
|
| 593 |
+
model=model,
|
| 594 |
+
features=features,
|
| 595 |
+
features_lens=features_lens,
|
| 596 |
+
tokens=tokens,
|
| 597 |
+
is_training=True,
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
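# The factor (1 - 1/reset_interval) decays old statistics, so tot_loss is
# effectively an exponential moving average over roughly reset_interval batches.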
|
| 601 |
+
|
| 602 |
+
scaler.scale(loss).backward()
|
| 603 |
+
|
| 604 |
+
scheduler.step_batch(params.batch_idx_train)
|
| 605 |
+
# Use the number of hours of speech to adjust the learning rate
|
| 606 |
+
if params.lr_hours > 0:
|
| 607 |
+
scheduler.step_epoch(
|
| 608 |
+
params.batch_idx_train
|
| 609 |
+
* params.max_duration
|
| 610 |
+
* params.world_size
|
| 611 |
+
/ 3600
|
| 612 |
+
)
|
| 613 |
+
scaler.step(optimizer)
|
| 614 |
+
scaler.update()
|
| 615 |
+
optimizer.zero_grad()
|
| 616 |
+
except Exception as e:
|
| 617 |
+
logging.info(f"Caught exception : {e}.")
|
| 618 |
+
save_bad_model()
|
| 619 |
+
raise
|
| 620 |
+
|
| 621 |
+
if params.print_diagnostics and batch_idx == 5:
|
| 622 |
+
return
|
| 623 |
+
|
| 624 |
+
if (
|
| 625 |
+
rank == 0
|
| 626 |
+
and params.batch_idx_train > 0
|
| 627 |
+
and params.batch_idx_train % params.average_period == 0
|
| 628 |
+
):
|
| 629 |
+
update_averaged_model(
|
| 630 |
+
params=params,
|
| 631 |
+
model_cur=model,
|
| 632 |
+
model_avg=model_avg,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
if (
|
| 636 |
+
params.batch_idx_train > 0
|
| 637 |
+
and params.batch_idx_train % params.save_every_n == 0
|
| 638 |
+
):
|
| 639 |
+
save_checkpoint_with_global_batch_idx(
|
| 640 |
+
out_dir=params.exp_dir,
|
| 641 |
+
global_batch_idx=params.batch_idx_train,
|
| 642 |
+
model=model,
|
| 643 |
+
model_avg=model_avg,
|
| 644 |
+
params=params,
|
| 645 |
+
optimizer=optimizer,
|
| 646 |
+
scheduler=scheduler,
|
| 647 |
+
sampler=train_dl.sampler,
|
| 648 |
+
scaler=scaler,
|
| 649 |
+
rank=rank,
|
| 650 |
+
)
|
| 651 |
+
remove_checkpoints(
|
| 652 |
+
out_dir=params.exp_dir,
|
| 653 |
+
topk=params.keep_last_k,
|
| 654 |
+
rank=rank,
|
| 655 |
+
)
|
| 656 |
+
if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
|
| 657 |
+
break
|
| 658 |
+
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
| 659 |
+
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
| 660 |
+
# of the grad scaler is configurable, but we can't configure it to have
|
| 661 |
+
# different behavior depending on the current grad scale.
|
| 662 |
+
cur_grad_scale = scaler._scale.item()
|
| 663 |
+
|
| 664 |
+
if cur_grad_scale < 1024.0 or (
|
| 665 |
+
cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
|
| 666 |
+
):
|
| 667 |
+
scaler.update(cur_grad_scale * 2.0)
|
| 668 |
+
if cur_grad_scale < 0.01:
|
| 669 |
+
if not saved_bad_model:
|
| 670 |
+
save_bad_model(suffix="-first-warning")
|
| 671 |
+
saved_bad_model = True
|
| 672 |
+
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
| 673 |
+
if cur_grad_scale < 1.0e-05:
|
| 674 |
+
save_bad_model()
|
| 675 |
+
raise RuntimeError(
|
| 676 |
+
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
if params.batch_idx_train % params.log_interval == 0:
|
| 680 |
+
cur_lr = max(scheduler.get_last_lr())
|
| 681 |
+
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
| 682 |
+
|
| 683 |
+
logging.info(
|
| 684 |
+
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
| 685 |
+
f"global_batch_idx: {params.batch_idx_train}, "
|
| 686 |
+
f"batch size: {batch_size}, "
|
| 687 |
+
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
| 688 |
+
f"cur_lr: {cur_lr:.2e}, "
|
| 689 |
+
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
if tb_writer is not None:
|
| 693 |
+
tb_writer.add_scalar(
|
| 694 |
+
"train/learning_rate", cur_lr, params.batch_idx_train
|
| 695 |
+
)
|
| 696 |
+
loss_info.write_summary(
|
| 697 |
+
tb_writer, "train/current_", params.batch_idx_train
|
| 698 |
+
)
|
| 699 |
+
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
| 700 |
+
if params.use_fp16:
|
| 701 |
+
tb_writer.add_scalar(
|
| 702 |
+
"train/grad_scale",
|
| 703 |
+
cur_grad_scale,
|
| 704 |
+
params.batch_idx_train,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
loss_value = tot_loss["loss"]
|
| 708 |
+
params.train_loss = loss_value
|
| 709 |
+
if params.train_loss < params.best_train_loss:
|
| 710 |
+
params.best_train_epoch = params.cur_epoch
|
| 711 |
+
params.best_train_loss = params.train_loss
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def compute_validation_loss(
|
| 715 |
+
params: AttributeDict,
|
| 716 |
+
model: Union[nn.Module, DDP],
|
| 717 |
+
valid_dl: torch.utils.data.DataLoader,
|
| 718 |
+
world_size: int = 1,
|
| 719 |
+
) -> MetricsTracker:
|
| 720 |
+
"""Run the validation process."""
|
| 721 |
+
|
| 722 |
+
model.eval()
|
| 723 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 724 |
+
|
| 725 |
+
# used to summarize the stats over iterations
|
| 726 |
+
tot_loss = MetricsTracker()
|
| 727 |
+
|
| 728 |
+
for batch_idx, batch in enumerate(valid_dl):
|
| 729 |
+
tokens, features, features_lens = prepare_input(
|
| 730 |
+
params=params,
|
| 731 |
+
batch=batch,
|
| 732 |
+
device=device,
|
| 733 |
+
return_tokens=True,
|
| 734 |
+
return_feature=True,
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
loss, loss_info = compute_fbank_loss(
|
| 738 |
+
params=params,
|
| 739 |
+
model=model,
|
| 740 |
+
features=features,
|
| 741 |
+
features_lens=features_lens,
|
| 742 |
+
tokens=tokens,
|
| 743 |
+
is_training=False,
|
| 744 |
+
)
|
| 745 |
+
assert loss.requires_grad is False
|
| 746 |
+
tot_loss = tot_loss + loss_info
|
| 747 |
+
|
| 748 |
+
if world_size > 1:
|
| 749 |
+
tot_loss.reduce(loss.device)
|
| 750 |
+
|
| 751 |
+
loss_value = tot_loss["loss"]
|
| 752 |
+
if loss_value < params.best_valid_loss:
|
| 753 |
+
params.best_valid_epoch = params.cur_epoch
|
| 754 |
+
params.best_valid_loss = loss_value
|
| 755 |
+
|
| 756 |
+
return tot_loss
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def display_and_save_batch(
|
| 760 |
+
batch: dict,
|
| 761 |
+
params: AttributeDict,
|
| 762 |
+
) -> None:
|
| 763 |
+
"""Display the batch statistics and save the batch into disk.
|
| 764 |
+
|
| 765 |
+
Args:
|
| 766 |
+
batch:
|
| 767 |
+
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
| 768 |
+
for the content in it.
|
| 769 |
+
params:
|
| 770 |
+
Parameters for training. See :func:`get_params`.
|
| 773 |
+
"""
|
| 774 |
+
from lhotse.utils import uuid4
|
| 775 |
+
|
| 776 |
+
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
|
| 777 |
+
logging.info(f"Saving batch to {filename}")
|
| 778 |
+
torch.save(batch, filename)
|
| 779 |
+
|
| 780 |
+
features = batch["features"]
|
| 781 |
+
tokens = batch["tokens"]
|
| 782 |
+
|
| 783 |
+
logging.info(f"features shape: {features.shape}")
|
| 784 |
+
num_tokens = sum(len(i) for i in tokens)
|
| 785 |
+
logging.info(f"num tokens: {num_tokens}")
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def scan_pessimistic_batches_for_oom(
|
| 789 |
+
model: Union[nn.Module, DDP],
|
| 790 |
+
train_dl: torch.utils.data.DataLoader,
|
| 791 |
+
optimizer: torch.optim.Optimizer,
|
| 792 |
+
params: AttributeDict,
|
| 793 |
+
):
|
| 794 |
+
from lhotse.dataset import find_pessimistic_batches
|
| 795 |
+
|
| 796 |
+
logging.info(
|
| 797 |
+
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
| 798 |
+
)
|
| 799 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 800 |
+
|
| 801 |
+
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
| 802 |
+
for criterion, cuts in batches.items():
|
| 803 |
+
batch = train_dl.dataset[cuts]
|
| 804 |
+
tokens, features, features_lens = prepare_input(
|
| 805 |
+
params=params,
|
| 806 |
+
batch=batch,
|
| 807 |
+
device=device,
|
| 808 |
+
return_tokens=True,
|
| 809 |
+
return_feature=True,
|
| 810 |
+
)
|
| 811 |
+
try:
|
| 812 |
+
with autocast("cuda", enabled=params.use_fp16):
|
| 813 |
+
|
| 814 |
+
loss, loss_info = compute_fbank_loss(
|
| 815 |
+
params=params,
|
| 816 |
+
model=model,
|
| 817 |
+
features=features,
|
| 818 |
+
features_lens=features_lens,
|
| 819 |
+
tokens=tokens,
|
| 820 |
+
is_training=True,
|
| 821 |
+
)
|
| 822 |
+
loss.backward()
|
| 823 |
+
optimizer.zero_grad()
|
| 824 |
+
except Exception as e:
|
| 825 |
+
if "CUDA out of memory" in str(e):
|
| 826 |
+
logging.error(
|
| 827 |
+
"Your GPU ran out of memory with the current "
|
| 828 |
+
"max_duration setting. We recommend decreasing "
|
| 829 |
+
"max_duration and trying again.\n"
|
| 830 |
+
f"Failing criterion: {criterion} "
|
| 831 |
+
f"(={crit_values[criterion]}) ..."
|
| 832 |
+
)
|
| 833 |
+
display_and_save_batch(batch, params=params)
|
| 834 |
+
raise
|
| 835 |
+
logging.info(
|
| 836 |
+
f"Maximum memory allocated so far is "
|
| 837 |
+
f"{torch.cuda.max_memory_allocated() // 1000000}MB"
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def tokenize_text(c: Cut, tokenizer):
|
| 842 |
+
text = c.supervisions[0].text
|
| 843 |
+
tokens = tokenizer.texts_to_token_ids([text])
|
| 844 |
+
c.supervisions[0].tokens = tokens[0]
|
| 845 |
+
return c
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def run(rank, world_size, args):
|
| 849 |
+
"""
|
| 850 |
+
Args:
|
| 851 |
+
rank:
|
| 852 |
+
It is a value between 0 and `world_size-1`, which is
|
| 853 |
+
passed automatically by `mp.spawn()` in :func:`main`.
|
| 854 |
+
The node with rank 0 is responsible for saving checkpoint.
|
| 855 |
+
world_size:
|
| 856 |
+
Number of GPUs for DDP training.
|
| 857 |
+
args:
|
| 858 |
+
The return value of get_parser().parse_args()
|
| 859 |
+
"""
|
| 860 |
+
params = get_params()
|
| 861 |
+
params.update(vars(args))
|
| 862 |
+
params.valid_interval = params.save_every_n
|
| 863 |
+
# Set epoch to a large number to ignore it.
|
| 864 |
+
if params.num_iters > 0:
|
| 865 |
+
params.num_epochs = 1000000
|
| 866 |
+
with open(params.model_config, "r") as f:
|
| 867 |
+
model_config = json.load(f)
|
| 868 |
+
params.update(model_config["model"])
|
| 869 |
+
params.update(model_config["feature"])
|
| 870 |
+
|
| 871 |
+
fix_random_seed(params.seed)
|
| 872 |
+
if world_size > 1:
|
| 873 |
+
setup_dist(rank, world_size, params.master_port)
|
| 874 |
+
|
| 875 |
+
os.makedirs(f"{params.exp_dir}", exist_ok=True)
|
| 876 |
+
copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
|
| 877 |
+
copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
|
| 878 |
+
setup_logger(f"{params.exp_dir}/log/log-train")
|
| 879 |
+
|
| 880 |
+
if args.tensorboard and rank == 0:
|
| 881 |
+
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
| 882 |
+
else:
|
| 883 |
+
tb_writer = None
|
| 884 |
+
|
| 885 |
+
if torch.cuda.is_available():
|
| 886 |
+
params.device = torch.device("cuda", rank)
|
| 887 |
+
else:
|
| 888 |
+
params.device = torch.device("cpu")
|
| 889 |
+
logging.info(f"Device: {params.device}")
|
| 890 |
+
|
| 891 |
+
if params.tokenizer == "emilia":
|
| 892 |
+
tokenizer = EmiliaTokenizer(token_file=params.token_file)
|
| 893 |
+
elif params.tokenizer == "libritts":
|
| 894 |
+
tokenizer = LibriTTSTokenizer(token_file=params.token_file)
|
| 895 |
+
elif params.tokenizer == "espeak":
|
| 896 |
+
tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
|
| 897 |
+
else:
|
| 898 |
+
assert params.tokenizer == "simple"
|
| 899 |
+
tokenizer = SimpleTokenizer(token_file=params.token_file)
|
| 900 |
+
|
| 901 |
+
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
|
| 902 |
+
params.update(tokenizer_config)
|
| 903 |
+
|
| 904 |
+
logging.info(params)
|
| 905 |
+
|
| 906 |
+
logging.info("About to create model")
|
| 907 |
+
|
| 908 |
+
model = ZipVoice(
|
| 909 |
+
**model_config["model"],
|
| 910 |
+
**tokenizer_config,
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
if params.checkpoint is not None:
|
| 914 |
+
logging.info(f"Loading pre-trained model from {params.checkpoint}")
|
| 915 |
+
_ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
|
| 916 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
| 917 |
+
logging.info(f"Number of parameters : {num_param}")
|
| 918 |
+
|
| 919 |
+
model_avg: Optional[nn.Module] = None
|
| 920 |
+
if rank == 0:
|
| 921 |
+
# model_avg is only used with rank 0
|
| 922 |
+
model_avg = copy.deepcopy(model).to(torch.float64)
|
| 923 |
+
|
| 924 |
+
assert params.start_epoch > 0, params.start_epoch
|
| 925 |
+
if params.start_epoch > 1:
|
| 926 |
+
checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
|
| 927 |
+
|
| 928 |
+
model = model.to(params.device)
|
| 929 |
+
if world_size > 1:
|
| 930 |
+
logging.info("Using DDP")
|
| 931 |
+
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
| 932 |
+
|
| 933 |
+
optimizer = ScaledAdam(
|
| 934 |
+
get_parameter_groups_with_lrs(
|
| 935 |
+
model,
|
| 936 |
+
lr=params.base_lr,
|
| 937 |
+
include_names=True,
|
| 938 |
+
),
|
| 939 |
+
lr=params.base_lr, # should have no effect
|
| 940 |
+
clipping_scale=2.0,
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
assert params.lr_hours >= 0
|
| 944 |
+
|
| 945 |
+
if params.finetune:
|
| 946 |
+
scheduler = FixedLRScheduler(optimizer)
|
| 947 |
+
elif params.lr_hours > 0:
|
| 948 |
+
scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
|
| 949 |
+
else:
|
| 950 |
+
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
| 951 |
+
|
| 952 |
+
scaler = GradScaler("cuda", enabled=params.use_fp16)
|
| 953 |
+
|
| 954 |
+
if params.start_epoch > 1 and checkpoints is not None:
|
| 955 |
+
# load state_dict for optimizers
|
| 956 |
+
if "optimizer" in checkpoints:
|
| 957 |
+
logging.info("Loading optimizer state dict")
|
| 958 |
+
optimizer.load_state_dict(checkpoints["optimizer"])
|
| 959 |
+
|
| 960 |
+
# load state_dict for schedulers
|
| 961 |
+
if "scheduler" in checkpoints:
|
| 962 |
+
logging.info("Loading scheduler state dict")
|
| 963 |
+
scheduler.load_state_dict(checkpoints["scheduler"])
|
| 964 |
+
|
| 965 |
+
if "grad_scaler" in checkpoints:
|
| 966 |
+
logging.info("Loading grad scaler state dict")
|
| 967 |
+
scaler.load_state_dict(checkpoints["grad_scaler"])
|
| 968 |
+
|
| 969 |
+
if params.print_diagnostics:
|
| 970 |
+
opts = diagnostics.TensorDiagnosticOptions(
|
| 971 |
+
512
|
| 972 |
+
) # allow 4 megabytes per sub-module
|
| 973 |
+
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
| 974 |
+
|
| 975 |
+
if params.inf_check:
|
| 976 |
+
register_inf_check_hooks(model)
|
| 977 |
+
|
| 978 |
+
def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
|
| 979 |
+
if c.duration < min_len or c.duration > max_len:
|
| 980 |
+
return False
|
| 981 |
+
return True
|
| 982 |
+
|
| 983 |
+
_remove_short_and_long_utt = partial(
|
| 984 |
+
remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
datamodule = TtsDataModule(args)
|
| 988 |
+
if params.dataset == "emilia":
|
| 989 |
+
train_cuts = CutSet.mux(
|
| 990 |
+
datamodule.train_emilia_EN_cuts(),
|
| 991 |
+
datamodule.train_emilia_ZH_cuts(),
|
| 992 |
+
weights=[46000, 49000],
|
| 993 |
+
)
|
| 994 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 995 |
+
dev_cuts = CutSet.mux(
|
| 996 |
+
datamodule.dev_emilia_EN_cuts(),
|
| 997 |
+
datamodule.dev_emilia_ZH_cuts(),
|
| 998 |
+
weights=[0.5, 0.5],
|
| 999 |
+
)
|
| 1000 |
+
elif params.dataset == "libritts":
|
| 1001 |
+
train_cuts = datamodule.train_libritts_cuts()
|
| 1002 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 1003 |
+
dev_cuts = datamodule.dev_libritts_cuts()
|
| 1004 |
+
else:
|
| 1005 |
+
assert params.dataset == "custom"
|
| 1006 |
+
train_cuts = datamodule.train_custom_cuts(params.train_manifest)
|
| 1007 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 1008 |
+
dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
|
| 1009 |
+
# To avoid OOM issues due to too long dev cuts
|
| 1010 |
+
dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
|
| 1011 |
+
|
| 1012 |
+
_tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
|
| 1013 |
+
train_cuts = train_cuts.map(_tokenize_text)
|
| 1014 |
+
dev_cuts = dev_cuts.map(_tokenize_text)
|
| 1015 |
+
|
| 1016 |
+
train_dl = datamodule.train_dataloaders(train_cuts)
|
| 1017 |
+
|
| 1018 |
+
valid_dl = datamodule.dev_dataloaders(dev_cuts)
|
| 1019 |
+
|
| 1020 |
+
if params.scan_oom:
|
| 1021 |
+
scan_pessimistic_batches_for_oom(
|
| 1022 |
+
model=model,
|
| 1023 |
+
train_dl=train_dl,
|
| 1024 |
+
optimizer=optimizer,
|
| 1025 |
+
params=params,
|
| 1026 |
+
)
|
| 1027 |
+
|
| 1028 |
+
logging.info("Training started")
|
| 1029 |
+
|
| 1030 |
+
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
| 1031 |
+
logging.info(f"Start epoch {epoch}")
|
| 1032 |
+
|
| 1033 |
+
if params.lr_hours == 0:
|
| 1034 |
+
scheduler.step_epoch(epoch - 1)
|
| 1035 |
+
fix_random_seed(params.seed + epoch - 1)
|
| 1036 |
+
train_dl.sampler.set_epoch(epoch - 1)
|
| 1037 |
+
|
| 1038 |
+
params.cur_epoch = epoch
|
| 1039 |
+
|
| 1040 |
+
if tb_writer is not None:
|
| 1041 |
+
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
| 1042 |
+
|
| 1043 |
+
train_one_epoch(
|
| 1044 |
+
params=params,
|
| 1045 |
+
model=model,
|
| 1046 |
+
model_avg=model_avg,
|
| 1047 |
+
optimizer=optimizer,
|
| 1048 |
+
scheduler=scheduler,
|
| 1049 |
+
train_dl=train_dl,
|
| 1050 |
+
valid_dl=valid_dl,
|
| 1051 |
+
scaler=scaler,
|
| 1052 |
+
tb_writer=tb_writer,
|
| 1053 |
+
world_size=world_size,
|
| 1054 |
+
rank=rank,
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
|
| 1058 |
+
break
|
| 1059 |
+
|
| 1060 |
+
if params.print_diagnostics:
|
| 1061 |
+
diagnostic.print_diagnostics()
|
| 1062 |
+
break
|
| 1063 |
+
|
| 1064 |
+
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
| 1065 |
+
save_checkpoint(
|
| 1066 |
+
filename=filename,
|
| 1067 |
+
params=params,
|
| 1068 |
+
model=model,
|
| 1069 |
+
model_avg=model_avg,
|
| 1070 |
+
optimizer=optimizer,
|
| 1071 |
+
scheduler=scheduler,
|
| 1072 |
+
sampler=train_dl.sampler,
|
| 1073 |
+
scaler=scaler,
|
| 1074 |
+
rank=rank,
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
if rank == 0:
|
| 1078 |
+
if params.best_train_epoch == params.cur_epoch:
|
| 1079 |
+
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
| 1080 |
+
copyfile(src=filename, dst=best_train_filename)
|
| 1081 |
+
|
| 1082 |
+
if params.best_valid_epoch == params.cur_epoch:
|
| 1083 |
+
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
| 1084 |
+
copyfile(src=filename, dst=best_valid_filename)
|
| 1085 |
+
|
| 1086 |
+
logging.info("Done!")
|
| 1087 |
+
|
| 1088 |
+
if world_size > 1:
|
| 1089 |
+
torch.distributed.barrier()
|
| 1090 |
+
cleanup_dist()
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
def main():
|
| 1094 |
+
parser = get_parser()
|
| 1095 |
+
TtsDataModule.add_arguments(parser)
|
| 1096 |
+
args = parser.parse_args()
|
| 1097 |
+
args.exp_dir = Path(args.exp_dir)
|
| 1098 |
+
|
| 1099 |
+
world_size = args.world_size
|
| 1100 |
+
assert world_size >= 1
|
| 1101 |
+
if world_size > 1:
|
| 1102 |
+
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
| 1103 |
+
else:
|
| 1104 |
+
run(rank=0, world_size=1, args=args)
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
if __name__ == "__main__":
|
| 1108 |
+
torch.set_num_threads(1)
|
| 1109 |
+
torch.set_num_interop_threads(1)
|
| 1110 |
+
main()
|
zipvoice/bin/train_zipvoice_distill.py
ADDED
|
@@ -0,0 +1,1159 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
"""
|
| 20 |
+
This script trains a ZipVoice-Distill model starting from a ZipVoice model.
|
| 21 |
+
It has two distillation stages.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
|
| 25 |
+
(1) The first distillation stage with a fixed ZipVoice model as the teacher.
|
| 26 |
+
|
| 27 |
+
python3 -m zipvoice.bin.train_zipvoice_distill \
|
| 28 |
+
--world-size 8 \
|
| 29 |
+
--use-fp16 1 \
|
| 30 |
+
--num-iters 60000 \
|
| 31 |
+
--max-duration 500 \
|
| 32 |
+
--base-lr 0.0005 \
|
| 33 |
+
--tokenizer emilia \
|
| 34 |
+
--token-file data/tokens_emilia.txt \
|
| 35 |
+
--dataset emilia \
|
| 36 |
+
--manifest-dir data/fbank \
|
| 37 |
+
--teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
|
| 38 |
+
--distill-stage first \
|
| 39 |
+
--exp-dir exp/zipvoice_distill_1stage
|
| 40 |
+
|
| 41 |
+
(2) The second distillation stage with an EMA model as the teacher.
|
| 42 |
+
python3 -m zipvoice.bin.train_zipvoice_distill \
|
| 43 |
+
--world-size 8 \
|
| 44 |
+
--use-fp16 1 \
|
| 45 |
+
--num-iters 2000 \
|
| 46 |
+
--save-every-n 1000 \
|
| 47 |
+
--max-duration 500 \
|
| 48 |
+
--base-lr 0.0001 \
|
| 49 |
+
--model-config conf/zipvoice_base.json \
|
| 50 |
+
--tokenizer emilia \
|
| 51 |
+
--token-file data/tokens_emilia.txt \
|
| 52 |
+
--dataset emilia \
|
| 53 |
+
--manifest-dir data/fbank \
|
| 54 |
+
--teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
|
| 55 |
+
--distill-stage second \
|
| 56 |
+
--exp-dir zipvoice/exp_zipvoice_distill
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
import argparse
|
| 60 |
+
import copy
|
| 61 |
+
import json
|
| 62 |
+
import logging
|
| 63 |
+
import os
|
| 64 |
+
import random
|
| 65 |
+
from functools import partial
|
| 66 |
+
from pathlib import Path
|
| 67 |
+
from shutil import copyfile
|
| 68 |
+
from typing import List, Optional, Tuple, Union
|
| 69 |
+
|
| 70 |
+
import torch
|
| 71 |
+
import torch.multiprocessing as mp
|
| 72 |
+
import torch.nn as nn
|
| 73 |
+
from lhotse.cut import Cut, CutSet
|
| 74 |
+
from lhotse.utils import fix_random_seed
|
| 75 |
+
from torch import Tensor
|
| 76 |
+
from torch.amp import GradScaler, autocast
|
| 77 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 78 |
+
from torch.optim import Optimizer
|
| 79 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 80 |
+
|
| 81 |
+
import zipvoice.utils.diagnostics as diagnostics
|
| 82 |
+
from zipvoice.bin.train_zipvoice import (
|
| 83 |
+
display_and_save_batch,
|
| 84 |
+
get_params,
|
| 85 |
+
tokenize_text,
|
| 86 |
+
)
|
| 87 |
+
from zipvoice.dataset.datamodule import TtsDataModule
|
| 88 |
+
from zipvoice.models.zipvoice import ZipVoice
|
| 89 |
+
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
|
| 90 |
+
from zipvoice.tokenizer.tokenizer import (
|
| 91 |
+
EmiliaTokenizer,
|
| 92 |
+
EspeakTokenizer,
|
| 93 |
+
LibriTTSTokenizer,
|
| 94 |
+
SimpleTokenizer,
|
| 95 |
+
)
|
| 96 |
+
from zipvoice.utils.checkpoint import (
|
| 97 |
+
load_checkpoint,
|
| 98 |
+
remove_checkpoints,
|
| 99 |
+
resume_checkpoint,
|
| 100 |
+
save_checkpoint,
|
| 101 |
+
save_checkpoint_with_global_batch_idx,
|
| 102 |
+
update_averaged_model,
|
| 103 |
+
)
|
| 104 |
+
from zipvoice.utils.common import (
|
| 105 |
+
AttributeDict,
|
| 106 |
+
MetricsTracker,
|
| 107 |
+
cleanup_dist,
|
| 108 |
+
condition_time_mask,
|
| 109 |
+
get_adjusted_batch_count,
|
| 110 |
+
get_parameter_groups_with_lrs,
|
| 111 |
+
make_pad_mask,
|
| 112 |
+
prepare_input,
|
| 113 |
+
set_batch_count,
|
| 114 |
+
setup_dist,
|
| 115 |
+
setup_logger,
|
| 116 |
+
str2bool,
|
| 117 |
+
)
|
| 118 |
+
from zipvoice.utils.hooks import register_inf_check_hooks
|
| 119 |
+
from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
|
| 120 |
+
from zipvoice.utils.optim import ScaledAdam
|
| 121 |
+
|
| 122 |
+
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_parser():
|
| 126 |
+
parser = argparse.ArgumentParser(
|
| 127 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--world-size",
|
| 132 |
+
type=int,
|
| 133 |
+
default=1,
|
| 134 |
+
help="Number of GPUs for DDP training.",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--master-port",
|
| 139 |
+
type=int,
|
| 140 |
+
default=12356,
|
| 141 |
+
help="Master port to use for DDP training.",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--tensorboard",
|
| 146 |
+
type=str2bool,
|
| 147 |
+
default=True,
|
| 148 |
+
help="Should various information be logged in tensorboard.",
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--num-epochs",
|
| 153 |
+
type=int,
|
| 154 |
+
default=1,
|
| 155 |
+
help="Number of epochs to train.",
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--num-iters",
|
| 160 |
+
type=int,
|
| 161 |
+
default=0,
|
| 162 |
+
help="Number of iter to train, will ignore num_epochs if > 0.",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--start-epoch",
|
| 167 |
+
type=int,
|
| 168 |
+
default=1,
|
| 169 |
+
help="""Resume training from this epoch. It should be positive.
|
| 170 |
+
If larger than 1, it will load checkpoint from
|
| 171 |
+
exp-dir/epoch-{start_epoch-1}.pt
|
| 172 |
+
""",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--teacher-model",
|
| 177 |
+
type=str,
|
| 178 |
+
help="""Checkpoints of pre-trained teacher model""",
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--exp-dir",
|
| 183 |
+
type=str,
|
| 184 |
+
default="exp/zipvoice_distill",
|
| 185 |
+
help="""The experiment dir.
|
| 186 |
+
It specifies the directory where all training related
|
| 187 |
+
files, e.g., checkpoints, log, etc, are saved
|
| 188 |
+
""",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--base-lr", type=float, default=0.001, help="The base learning rate."
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--ref-duration",
|
| 197 |
+
type=float,
|
| 198 |
+
default=50,
|
| 199 |
+
help="Reference batch duration for purposes of adjusting batch counts for "
|
| 200 |
+
"setting various schedules inside the model",
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--seed",
|
| 205 |
+
type=int,
|
| 206 |
+
default=42,
|
| 207 |
+
help="The seed for random generators intended for reproducibility",
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--print-diagnostics",
|
| 212 |
+
type=str2bool,
|
| 213 |
+
default=False,
|
| 214 |
+
help="Accumulate stats on activations, print them and exit.",
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--scan-oom",
|
| 219 |
+
type=str2bool,
|
| 220 |
+
default=False,
|
| 221 |
+
help="Scan pessimistic batches to see whether they cause OOMs.",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--inf-check",
|
| 226 |
+
type=str2bool,
|
| 227 |
+
default=False,
|
| 228 |
+
help="Add hooks to check for infinite module outputs and gradients.",
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--save-every-n",
|
| 233 |
+
type=int,
|
| 234 |
+
default=1000,
|
| 235 |
+
help="""Save checkpoint after processing this number of batches"
|
| 236 |
+
periodically. We save checkpoint to exp-dir/ whenever
|
| 237 |
+
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
| 238 |
+
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
| 239 |
+
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
| 240 |
+
end of each epoch where `xxx` is the epoch number counting from 1.
|
| 241 |
+
""",
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--keep-last-k",
|
| 246 |
+
type=int,
|
| 247 |
+
default=30,
|
| 248 |
+
help="""Only keep this number of checkpoints on disk.
|
| 249 |
+
For instance, if it is 3, there are only 3 checkpoints
|
| 250 |
+
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
| 251 |
+
It does not affect checkpoints with name `epoch-xxx.pt`.
|
| 252 |
+
""",
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--average-period",
|
| 257 |
+
type=int,
|
| 258 |
+
default=200,
|
| 259 |
+
help="""Update the averaged model, namely `model_avg`, after processing
|
| 260 |
+
this number of batches. `model_avg` is a separate version of model,
|
| 261 |
+
in which each floating-point parameter is the average of all the
|
| 262 |
+
parameters from the start of training. Each time we take the average,
|
| 263 |
+
we do: `model_avg = model * (average_period / batch_idx_train) +
|
| 264 |
+
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
| 265 |
+
""",
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
"--use-fp16",
|
| 270 |
+
type=str2bool,
|
| 271 |
+
default=True,
|
| 272 |
+
help="Whether to use half precision training.",
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--feat-scale",
|
| 277 |
+
type=float,
|
| 278 |
+
default=0.1,
|
| 279 |
+
help="The scale factor of fbank feature",
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
"--ema-decay",
|
| 284 |
+
type=float,
|
| 285 |
+
default=0.9999,
|
| 286 |
+
help="The EMA decay factor of target model in distillation.",
|
| 287 |
+
)
|
| 288 |
+
parser.add_argument(
|
| 289 |
+
"--distill-stage",
|
| 290 |
+
type=str,
|
| 291 |
+
choices=["first", "second"],
|
| 292 |
+
help="The stage of distillation.",
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
parser.add_argument(
|
| 296 |
+
"--dataset",
|
| 297 |
+
type=str,
|
| 298 |
+
default="emilia",
|
| 299 |
+
choices=["emilia", "libritts", "custom"],
|
| 300 |
+
help="The used training dataset",
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
parser.add_argument(
|
| 304 |
+
"--train-manifest",
|
| 305 |
+
type=str,
|
| 306 |
+
help="Path of the training manifest",
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--dev-manifest",
|
| 311 |
+
type=str,
|
| 312 |
+
help="Path of the validation manifest",
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
"--min-len",
|
| 317 |
+
type=float,
|
| 318 |
+
default=1.0,
|
| 319 |
+
help="The minimum audio length used for training",
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--max-len",
|
| 324 |
+
type=float,
|
| 325 |
+
default=30.0,
|
| 326 |
+
help="The maximum audio length used for training",
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
parser.add_argument(
|
| 330 |
+
"--model-config",
|
| 331 |
+
type=str,
|
| 332 |
+
default="conf/zipvoice_base.json",
|
| 333 |
+
help="The model configuration file.",
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
parser.add_argument(
|
| 337 |
+
"--tokenizer",
|
| 338 |
+
type=str,
|
| 339 |
+
default="emilia",
|
| 340 |
+
choices=["emilia", "libritts", "espeak", "simple"],
|
| 341 |
+
help="Tokenizer type.",
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
parser.add_argument(
|
| 345 |
+
"--lang",
|
| 346 |
+
type=str,
|
| 347 |
+
default="en-us",
|
| 348 |
+
help="Language identifier, used when tokenizer type is espeak. see"
|
| 349 |
+
"https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
|
| 359 |
+
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--token-file",
|
| 362 |
+
type=str,
|
| 363 |
+
default="data/tokens_emilia.txt",
|
| 364 |
+
help="The file that contains information that maps tokens to ids,"
|
| 365 |
+
"which is a text file with '{token}\t{token_id}' per line.",
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
return parser
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def ema(new_model, ema_model, decay):
|
| 372 |
+
if isinstance(new_model, DDP):
|
| 373 |
+
new_model = new_model.module
|
| 374 |
+
if isinstance(ema_model, DDP):
|
| 375 |
+
ema_model = ema_model.module
|
| 376 |
+
new_model_dict = new_model.state_dict()
|
| 377 |
+
ema_model_dict = ema_model.state_dict()
|
| 378 |
+
for key in new_model_dict.keys():
|
| 379 |
+
ema_model_dict[key].data.copy_(
|
| 380 |
+
ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
|
| 381 |
+
)
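A quick way to read the decay used by ema() above: with decay d, the teacher tracks an exponential moving average of the student with an effective averaging window of roughly 1 / (1 - d) updates. The numbers below are illustrative only:

for d in (0.999, 0.9999):
    print(f"decay={d}: effective averaging window ~ {1.0 / (1.0 - d):.0f} updates")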
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def compute_fbank_loss(
|
| 385 |
+
params: AttributeDict,
|
| 386 |
+
model: Union[nn.Module, DDP],
|
| 387 |
+
teacher_model: Union[nn.Module, DDP],
|
| 388 |
+
features: Tensor,
|
| 389 |
+
features_lens: Tensor,
|
| 390 |
+
tokens: List[List[int]],
|
| 391 |
+
is_training: bool,
|
| 392 |
+
) -> Tuple[Tensor, MetricsTracker]:
|
| 393 |
+
"""
|
| 394 |
+
Compute loss given the model and its inputs.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
params:
|
| 398 |
+
Parameters for training. See :func:`get_params`.
|
| 399 |
+
model:
|
| 400 |
+
The model for training.
|
| 401 |
+
teacher_model:
|
| 402 |
+
The teacher model for distillation.
|
| 403 |
+
features:
|
| 404 |
+
The target acoustic feature.
|
| 405 |
+
features_lens:
|
| 406 |
+
The number of frames of each utterance.
|
| 407 |
+
tokens:
|
| 408 |
+
Input tokens representing the transcripts.
|
| 409 |
+
is_training:
|
| 410 |
+
True for training. False for validation. When it is True, this
|
| 411 |
+
function enables autograd during computation; when it is False, it
|
| 412 |
+
disables autograd.
|
| 413 |
+
"""
|
| 414 |
+
|
| 415 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 416 |
+
|
| 417 |
+
batch_size, num_frames, _ = features.shape
|
| 418 |
+
|
| 419 |
+
features = torch.nn.functional.pad(
|
| 420 |
+
features, (0, 0, 0, num_frames - features.size(1))
|
| 421 |
+
) # (B, T, F)
|
| 422 |
+
noise = torch.randn_like(features) # (B, T, F)
|
| 423 |
+
|
| 424 |
+
# Sampling t and guidance_scale from uniform distribution
|
| 425 |
+
|
| 426 |
+
t_value = random.random()
|
| 427 |
+
t = torch.ones(batch_size, 1, 1, device=device) * t_value
|
| 428 |
+
if params.distill_stage == "first":
|
| 429 |
+
guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
|
| 430 |
+
else:
|
| 431 |
+
guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
|
| 432 |
+
xt = features * t + noise * (1 - t)
|
| 433 |
+
t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
|
| 434 |
+
t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
|
| 435 |
+
t_dest = t_value + t_delta_fix + t_delta_ema
|
| 436 |
+
|
| 437 |
+
with torch.no_grad():
|
| 438 |
+
speech_condition_mask = condition_time_mask(
|
| 439 |
+
features_lens=features_lens,
|
| 440 |
+
mask_percent=(0.7, 1.0),
|
| 441 |
+
max_len=features.size(1),
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if params.distill_stage == "first":
|
| 445 |
+
teacher_x_t_mid, _ = teacher_model.sample_intermediate(
|
| 446 |
+
tokens=tokens,
|
| 447 |
+
features=features,
|
| 448 |
+
features_lens=features_lens,
|
| 449 |
+
noise=xt,
|
| 450 |
+
speech_condition_mask=speech_condition_mask,
|
| 451 |
+
t_start=t_value,
|
| 452 |
+
t_end=t_value + t_delta_fix,
|
| 453 |
+
num_step=1,
|
| 454 |
+
guidance_scale=guidance_scale,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
target_x1, _ = teacher_model.sample_intermediate(
|
| 458 |
+
tokens=tokens,
|
| 459 |
+
features=features,
|
| 460 |
+
features_lens=features_lens,
|
| 461 |
+
noise=teacher_x_t_mid,
|
| 462 |
+
speech_condition_mask=speech_condition_mask,
|
| 463 |
+
t_start=t_value + t_delta_fix,
|
| 464 |
+
t_end=t_dest,
|
| 465 |
+
num_step=1,
|
| 466 |
+
guidance_scale=guidance_scale,
|
| 467 |
+
)
|
| 468 |
+
else:
|
| 469 |
+
teacher_x_t_mid, _ = teacher_model(
|
| 470 |
+
tokens=tokens,
|
| 471 |
+
features=features,
|
| 472 |
+
features_lens=features_lens,
|
| 473 |
+
noise=xt,
|
| 474 |
+
speech_condition_mask=speech_condition_mask,
|
| 475 |
+
t_start=t_value,
|
| 476 |
+
t_end=t_value + t_delta_fix,
|
| 477 |
+
num_step=1,
|
| 478 |
+
guidance_scale=guidance_scale,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
target_x1, _ = teacher_model(
|
| 482 |
+
tokens=tokens,
|
| 483 |
+
features=features,
|
| 484 |
+
features_lens=features_lens,
|
| 485 |
+
noise=teacher_x_t_mid,
|
| 486 |
+
speech_condition_mask=speech_condition_mask,
|
| 487 |
+
t_start=t_value + t_delta_fix,
|
| 488 |
+
t_end=t_dest,
|
| 489 |
+
num_step=1,
|
| 490 |
+
guidance_scale=guidance_scale,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
with torch.set_grad_enabled(is_training):
|
| 494 |
+
|
| 495 |
+
pred_x1, _ = model(
|
| 496 |
+
tokens=tokens,
|
| 497 |
+
features=features,
|
| 498 |
+
features_lens=features_lens,
|
| 499 |
+
noise=xt,
|
| 500 |
+
speech_condition_mask=speech_condition_mask,
|
| 501 |
+
t_start=t,
|
| 502 |
+
t_end=t_dest,
|
| 503 |
+
num_step=1,
|
| 504 |
+
guidance_scale=guidance_scale,
|
| 505 |
+
)
|
| 506 |
+
pred_v = (pred_x1 - xt) / (t_dest - t)
|
| 507 |
+
|
| 508 |
+
padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
|
| 509 |
+
loss_mask = speech_condition_mask & (~padding_mask)
|
| 510 |
+
|
| 511 |
+
target_v = (target_x1 - xt) / (t_dest - t)
|
| 512 |
+
loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
|
| 513 |
+
|
| 514 |
+
ut = features - noise # (B, T, F)
|
| 515 |
+
|
| 516 |
+
ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
|
| 517 |
+
|
| 518 |
+
assert loss.requires_grad == is_training
|
| 519 |
+
info = MetricsTracker()
|
| 520 |
+
num_frames = features_lens.sum().item()
|
| 521 |
+
info["frames"] = num_frames
|
| 522 |
+
info["loss"] = loss.detach().cpu().item() * num_frames
|
| 523 |
+
info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
|
| 524 |
+
return loss, info
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def train_one_epoch(
|
| 528 |
+
params: AttributeDict,
|
| 529 |
+
model: Union[nn.Module, DDP],
|
| 530 |
+
teacher_model: Union[nn.Module, DDP],
|
| 531 |
+
optimizer: Optimizer,
|
| 532 |
+
scheduler: LRSchedulerType,
|
| 533 |
+
train_dl: torch.utils.data.DataLoader,
|
| 534 |
+
valid_dl: torch.utils.data.DataLoader,
|
| 535 |
+
scaler: GradScaler,
|
| 536 |
+
model_avg: Optional[nn.Module] = None,
|
| 537 |
+
tb_writer: Optional[SummaryWriter] = None,
|
| 538 |
+
world_size: int = 1,
|
| 539 |
+
rank: int = 0,
|
| 540 |
+
) -> None:
|
| 541 |
+
"""Train the model for one epoch.
|
| 542 |
+
|
| 543 |
+
The training loss from the mean of all frames is saved in
|
| 544 |
+
`params.train_loss`. It runs the validation process every
|
| 545 |
+
`params.valid_interval` batches.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
params:
|
| 549 |
+
It is returned by :func:`get_params`.
|
| 550 |
+
model:
|
| 551 |
+
The model for training.
|
| 552 |
+
teacher_model:
|
| 553 |
+
The model for distillation.
|
| 554 |
+
Used to produce the distillation targets.
|
| 555 |
+
optimizer:
|
| 556 |
+
The optimizer.
|
| 557 |
+
scheduler:
|
| 558 |
+
The learning rate scheduler, we call step() every epoch.
|
| 559 |
+
train_dl:
|
| 560 |
+
Dataloader for the training dataset.
|
| 561 |
+
valid_dl:
|
| 562 |
+
Dataloader for the validation dataset.
|
| 563 |
+
scaler:
|
| 564 |
+
The scaler used for mixed precision training.
|
| 565 |
+
tb_writer:
|
| 566 |
+
Writer to write log messages to tensorboard.
|
| 567 |
+
world_size:
|
| 568 |
+
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
| 569 |
+
rank:
|
| 570 |
+
The rank of the node in DDP training. If no DDP is used, it should
|
| 571 |
+
be set to 0.
|
| 572 |
+
"""
|
| 573 |
+
model.train()
|
| 574 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 575 |
+
|
| 576 |
+
# used to track the stats over iterations in one epoch
|
| 577 |
+
tot_loss = MetricsTracker()
|
| 578 |
+
|
| 579 |
+
saved_bad_model = False
|
| 580 |
+
|
| 581 |
+
def save_bad_model(suffix: str = ""):
|
| 582 |
+
save_checkpoint(
|
| 583 |
+
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
| 584 |
+
model=model,
|
| 585 |
+
model_avg=model_avg,
|
| 586 |
+
model_ema=teacher_model,
|
| 587 |
+
params=params,
|
| 588 |
+
optimizer=optimizer,
|
| 589 |
+
scheduler=scheduler,
|
| 590 |
+
sampler=train_dl.sampler,
|
| 591 |
+
scaler=scaler,
|
| 592 |
+
rank=0,
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
for batch_idx, batch in enumerate(train_dl):
|
| 596 |
+
|
| 597 |
+
if batch_idx % 10 == 0:
|
| 598 |
+
set_batch_count(model, get_adjusted_batch_count(params) + 100000)
|
| 599 |
+
|
| 600 |
+
if (
|
| 601 |
+
params.batch_idx_train % params.valid_interval == 0
|
| 602 |
+
and not params.print_diagnostics
|
| 603 |
+
):
|
| 604 |
+
logging.info("Computing validation loss")
|
| 605 |
+
valid_info = compute_validation_loss(
|
| 606 |
+
params=params,
|
| 607 |
+
model=model,
|
| 608 |
+
teacher_model=teacher_model,
|
| 609 |
+
valid_dl=valid_dl,
|
| 610 |
+
world_size=world_size,
|
| 611 |
+
)
|
| 612 |
+
model.train()
|
| 613 |
+
logging.info(
|
| 614 |
+
f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
|
| 615 |
+
f" validation: {valid_info}"
|
| 616 |
+
)
|
| 617 |
+
logging.info(
|
| 618 |
+
f"Maximum memory allocated so far is "
|
| 619 |
+
f"{torch.cuda.max_memory_allocated() // 1000000}MB"
|
| 620 |
+
)
|
| 621 |
+
if tb_writer is not None:
|
| 622 |
+
valid_info.write_summary(
|
| 623 |
+
tb_writer, "train/valid_", params.batch_idx_train
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
params.batch_idx_train += 1
|
| 627 |
+
|
| 628 |
+
batch_size = len(batch["text"])
|
| 629 |
+
|
| 630 |
+
tokens, features, features_lens = prepare_input(
|
| 631 |
+
params=params,
|
| 632 |
+
batch=batch,
|
| 633 |
+
device=device,
|
| 634 |
+
return_tokens=True,
|
| 635 |
+
return_feature=True,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
try:
|
| 639 |
+
with autocast("cuda", enabled=params.use_fp16):
|
| 640 |
+
loss, loss_info = compute_fbank_loss(
|
| 641 |
+
params=params,
|
| 642 |
+
model=model,
|
| 643 |
+
teacher_model=teacher_model,
|
| 644 |
+
features=features,
|
| 645 |
+
features_lens=features_lens,
|
| 646 |
+
tokens=tokens,
|
| 647 |
+
is_training=True,
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
| 651 |
+
|
| 652 |
+
scaler.scale(loss).backward()
|
| 653 |
+
|
| 654 |
+
scheduler.step_batch(params.batch_idx_train)
|
| 655 |
+
scaler.step(optimizer)
|
| 656 |
+
scaler.update()
|
| 657 |
+
optimizer.zero_grad()
|
| 658 |
+
if params.distill_stage == "second":
|
| 659 |
+
ema(model, teacher_model, params.ema_decay)
|
| 660 |
+
except Exception as e:
|
| 661 |
+
logging.info(f"Caught exception : {e}.")
|
| 662 |
+
save_bad_model()
|
| 663 |
+
raise
|
| 664 |
+
|
| 665 |
+
if params.print_diagnostics and batch_idx == 5:
|
| 666 |
+
return
|
| 667 |
+
|
| 668 |
+
if (
|
| 669 |
+
rank == 0
|
| 670 |
+
and params.batch_idx_train > 0
|
| 671 |
+
and params.batch_idx_train % params.average_period == 0
|
| 672 |
+
):
|
| 673 |
+
update_averaged_model(
|
| 674 |
+
params=params,
|
| 675 |
+
model_cur=model,
|
| 676 |
+
model_avg=model_avg,
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
if (
|
| 680 |
+
params.batch_idx_train > 0
|
| 681 |
+
and params.batch_idx_train % params.save_every_n == 0
|
| 682 |
+
):
|
| 683 |
+
save_checkpoint_with_global_batch_idx(
|
| 684 |
+
out_dir=params.exp_dir,
|
| 685 |
+
global_batch_idx=params.batch_idx_train,
|
| 686 |
+
model=model,
|
| 687 |
+
model_avg=model_avg,
|
| 688 |
+
params=params,
|
| 689 |
+
optimizer=optimizer,
|
| 690 |
+
scheduler=scheduler,
|
| 691 |
+
sampler=train_dl.sampler,
|
| 692 |
+
scaler=scaler,
|
| 693 |
+
rank=rank,
|
| 694 |
+
)
|
| 695 |
+
remove_checkpoints(
|
| 696 |
+
out_dir=params.exp_dir,
|
| 697 |
+
topk=params.keep_last_k,
|
| 698 |
+
rank=rank,
|
| 699 |
+
)
|
| 700 |
+
if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
|
| 701 |
+
break
|
| 702 |
+
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
| 703 |
+
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
| 704 |
+
# of the grad scaler is configurable, but we can't configure it to have
|
| 705 |
+
# different behavior depending on the current grad scale.
|
| 706 |
+
cur_grad_scale = scaler._scale.item()
|
| 707 |
+
|
| 708 |
+
if cur_grad_scale < 1024.0 or (
|
| 709 |
+
cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
|
| 710 |
+
):
|
| 711 |
+
scaler.update(cur_grad_scale * 2.0)
|
| 712 |
+
if cur_grad_scale < 0.01:
|
| 713 |
+
if not saved_bad_model:
|
| 714 |
+
save_bad_model(suffix="-first-warning")
|
| 715 |
+
saved_bad_model = True
|
| 716 |
+
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
| 717 |
+
if cur_grad_scale < 1.0e-05:
|
| 718 |
+
save_bad_model()
|
| 719 |
+
raise RuntimeError(
|
| 720 |
+
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
if params.batch_idx_train % params.log_interval == 0:
|
| 724 |
+
cur_lr = max(scheduler.get_last_lr())
|
| 725 |
+
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
| 726 |
+
|
| 727 |
+
logging.info(
|
| 728 |
+
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
| 729 |
+
f"global_batch_idx: {params.batch_idx_train}, "
|
| 730 |
+
f"batch size: {batch_size}, "
|
| 731 |
+
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
| 732 |
+
f"cur_lr: {cur_lr:.2e}, "
|
| 733 |
+
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
if tb_writer is not None:
|
| 737 |
+
tb_writer.add_scalar(
|
| 738 |
+
"train/learning_rate", cur_lr, params.batch_idx_train
|
| 739 |
+
)
|
| 740 |
+
loss_info.write_summary(
|
| 741 |
+
tb_writer, "train/current_", params.batch_idx_train
|
| 742 |
+
)
|
| 743 |
+
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
| 744 |
+
if params.use_fp16:
|
| 745 |
+
tb_writer.add_scalar(
|
| 746 |
+
"train/grad_scale",
|
| 747 |
+
cur_grad_scale,
|
| 748 |
+
params.batch_idx_train,
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
loss_value = tot_loss["loss"]
|
| 752 |
+
params.train_loss = loss_value
|
| 753 |
+
if params.train_loss < params.best_train_loss:
|
| 754 |
+
params.best_train_epoch = params.cur_epoch
|
| 755 |
+
params.best_train_loss = params.train_loss
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def compute_validation_loss(
|
| 759 |
+
params: AttributeDict,
|
| 760 |
+
model: Union[nn.Module, DDP],
|
| 761 |
+
teacher_model: Optional[nn.Module],
|
| 762 |
+
valid_dl: torch.utils.data.DataLoader,
|
| 763 |
+
world_size: int = 1,
|
| 764 |
+
) -> MetricsTracker:
|
| 765 |
+
"""Run the validation process."""
|
| 766 |
+
|
| 767 |
+
model.eval()
|
| 768 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 769 |
+
|
| 770 |
+
# used to summary the stats over iterations
|
| 771 |
+
tot_loss = MetricsTracker()
|
| 772 |
+
|
| 773 |
+
for batch_idx, batch in enumerate(valid_dl):
|
| 774 |
+
tokens, features, features_lens = prepare_input(
|
| 775 |
+
params=params,
|
| 776 |
+
batch=batch,
|
| 777 |
+
device=device,
|
| 778 |
+
return_tokens=True,
|
| 779 |
+
return_feature=True,
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
loss, loss_info = compute_fbank_loss(
|
| 783 |
+
params=params,
|
| 784 |
+
model=model,
|
| 785 |
+
teacher_model=teacher_model,
|
| 786 |
+
features=features,
|
| 787 |
+
features_lens=features_lens,
|
| 788 |
+
tokens=tokens,
|
| 789 |
+
is_training=False,
|
| 790 |
+
)
|
| 791 |
+
assert loss.requires_grad is False
|
| 792 |
+
tot_loss = tot_loss + loss_info
|
| 793 |
+
|
| 794 |
+
if world_size > 1:
|
| 795 |
+
tot_loss.reduce(loss.device)
|
| 796 |
+
|
| 797 |
+
loss_value = tot_loss["loss"]
|
| 798 |
+
if loss_value < params.best_valid_loss:
|
| 799 |
+
params.best_valid_epoch = params.cur_epoch
|
| 800 |
+
params.best_valid_loss = loss_value
|
| 801 |
+
|
| 802 |
+
return tot_loss
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
def scan_pessimistic_batches_for_oom(
|
| 806 |
+
model: Union[nn.Module, DDP],
|
| 807 |
+
teacher_model: nn.Module,
|
| 808 |
+
train_dl: torch.utils.data.DataLoader,
|
| 809 |
+
optimizer: torch.optim.Optimizer,
|
| 810 |
+
params: AttributeDict,
|
| 811 |
+
):
|
| 812 |
+
from lhotse.dataset import find_pessimistic_batches
|
| 813 |
+
|
| 814 |
+
logging.info(
|
| 815 |
+
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
| 816 |
+
)
|
| 817 |
+
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
| 818 |
+
|
| 819 |
+
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
| 820 |
+
for criterion, cuts in batches.items():
|
| 821 |
+
batch = train_dl.dataset[cuts]
|
| 822 |
+
tokens, features, features_lens = prepare_input(
|
| 823 |
+
params=params,
|
| 824 |
+
batch=batch,
|
| 825 |
+
device=device,
|
| 826 |
+
return_tokens=True,
|
| 827 |
+
return_feature=True,
|
| 828 |
+
)
|
| 829 |
+
try:
|
| 830 |
+
with autocast("cuda", enabled=params.use_fp16):
|
| 831 |
+
|
| 832 |
+
loss, loss_info = compute_fbank_loss(
|
| 833 |
+
params=params,
|
| 834 |
+
model=model,
|
| 835 |
+
teacher_model=teacher_model,
|
| 836 |
+
features=features,
|
| 837 |
+
features_lens=features_lens,
|
| 838 |
+
tokens=tokens,
|
| 839 |
+
is_training=True,
|
| 840 |
+
)
|
| 841 |
+
loss.backward()
|
| 842 |
+
optimizer.zero_grad()
|
| 843 |
+
except Exception as e:
|
| 844 |
+
if "CUDA out of memory" in str(e):
|
| 845 |
+
logging.error(
|
| 846 |
+
"Your GPU ran out of memory with the current "
|
| 847 |
+
"max_duration setting. We recommend decreasing "
|
| 848 |
+
"max_duration and trying again.\n"
|
| 849 |
+
f"Failing criterion: {criterion} "
|
| 850 |
+
f"(={crit_values[criterion]}) ..."
|
| 851 |
+
)
|
| 852 |
+
display_and_save_batch(batch, params=params)
|
| 853 |
+
raise
|
| 854 |
+
logging.info(
|
| 855 |
+
f"Maximum memory allocated so far is "
|
| 856 |
+
f"{torch.cuda.max_memory_allocated() // 1000000}MB"
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def run(rank, world_size, args):
|
| 861 |
+
"""
|
| 862 |
+
Args:
|
| 863 |
+
rank:
|
| 864 |
+
It is a value between 0 and `world_size-1`, which is
|
| 865 |
+
passed automatically by `mp.spawn()` in :func:`main`.
|
| 866 |
+
The node with rank 0 is responsible for saving checkpoint.
|
| 867 |
+
world_size:
|
| 868 |
+
Number of GPUs for DDP training.
|
| 869 |
+
args:
|
| 870 |
+
The return value of get_parser().parse_args()
|
| 871 |
+
"""
|
| 872 |
+
params = get_params()
|
| 873 |
+
params.update(vars(args))
|
| 874 |
+
params.valid_interval = params.save_every_n
|
| 875 |
+
# Set epoch to a large number to ignore it.
|
| 876 |
+
if params.num_iters > 0:
|
| 877 |
+
params.num_epochs = 1000000
|
| 878 |
+
with open(params.model_config, "r") as f:
|
| 879 |
+
model_config = json.load(f)
|
| 880 |
+
params.update(model_config["model"])
|
| 881 |
+
params.update(model_config["feature"])
|
| 882 |
+
|
| 883 |
+
fix_random_seed(params.seed)
|
| 884 |
+
if world_size > 1:
|
| 885 |
+
setup_dist(rank, world_size, params.master_port)
|
| 886 |
+
|
| 887 |
+
os.makedirs(f"{params.exp_dir}", exist_ok=True)
|
| 888 |
+
copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
|
| 889 |
+
copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
|
| 890 |
+
setup_logger(f"{params.exp_dir}/log/log-train")
|
| 891 |
+
|
| 892 |
+
if args.tensorboard and rank == 0:
|
| 893 |
+
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
| 894 |
+
else:
|
| 895 |
+
tb_writer = None
|
| 896 |
+
|
| 897 |
+
if torch.cuda.is_available():
|
| 898 |
+
params.device = torch.device("cuda", rank)
|
| 899 |
+
else:
|
| 900 |
+
params.device = torch.device("cpu")
|
| 901 |
+
logging.info(f"Device: {params.device}")
|
| 902 |
+
|
| 903 |
+
if params.tokenizer == "emilia":
|
| 904 |
+
tokenizer = EmiliaTokenizer(token_file=params.token_file)
|
| 905 |
+
elif params.tokenizer == "libritts":
|
| 906 |
+
tokenizer = LibriTTSTokenizer(token_file=params.token_file)
|
| 907 |
+
elif params.tokenizer == "espeak":
|
| 908 |
+
tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
|
| 909 |
+
else:
|
| 910 |
+
assert params.tokenizer == "simple"
|
| 911 |
+
tokenizer = SimpleTokenizer(token_file=params.token_file)
|
| 912 |
+
|
| 913 |
+
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
|
| 914 |
+
params.update(tokenizer_config)
|
| 915 |
+
|
| 916 |
+
logging.info(params)
|
| 917 |
+
|
| 918 |
+
logging.info("About to create model")
|
| 919 |
+
|
| 920 |
+
assert params.teacher_model is not None
|
| 921 |
+
logging.info(f"Loading pre-trained model from {params.teacher_model}")
|
| 922 |
+
model = ZipVoiceDistill(
|
| 923 |
+
**model_config["model"],
|
| 924 |
+
**tokenizer_config,
|
| 925 |
+
)
|
| 926 |
+
_ = load_checkpoint(
|
| 927 |
+
filename=params.teacher_model,
|
| 928 |
+
model=model,
|
| 929 |
+
strict=(params.distill_stage == "second"),
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
if params.distill_stage == "first":
|
| 933 |
+
teacher_model = ZipVoice(
|
| 934 |
+
**model_config["model"],
|
| 935 |
+
**tokenizer_config,
|
| 936 |
+
)
|
| 937 |
+
_ = load_checkpoint(
|
| 938 |
+
filename=params.teacher_model, model=teacher_model, strict=True
|
| 939 |
+
)
|
| 940 |
+
else:
|
| 941 |
+
teacher_model = copy.deepcopy(model)
|
| 942 |
+
|
| 943 |
+
num_param = sum([p.numel() for p in model.parameters()])
|
| 944 |
+
logging.info(f"Number of parameters : {num_param}")
|
| 945 |
+
|
| 946 |
+
model_avg: Optional[nn.Module] = None
|
| 947 |
+
if rank == 0:
|
| 948 |
+
# model_avg is only used with rank 0
|
| 949 |
+
model_avg = copy.deepcopy(model).to(torch.float64)
|
| 950 |
+
assert params.start_epoch > 0, params.start_epoch
|
| 951 |
+
if params.start_epoch > 1:
|
| 952 |
+
logging.info(f"Resuming from epoch {params.start_epoch}")
|
| 953 |
+
if params.distill_stage == "first":
|
| 954 |
+
checkpoints = resume_checkpoint(
|
| 955 |
+
params=params, model=model, model_avg=model_avg
|
| 956 |
+
)
|
| 957 |
+
else:
|
| 958 |
+
checkpoints = resume_checkpoint(
|
| 959 |
+
params=params,
|
| 960 |
+
model=model,
|
| 961 |
+
model_avg=model_avg,
|
| 962 |
+
model_ema=teacher_model,
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
model = model.to(params.device)
|
| 966 |
+
teacher_model.to(params.device)
|
| 967 |
+
teacher_model.eval()
|
| 968 |
+
|
| 969 |
+
if world_size > 1:
|
| 970 |
+
logging.info("Using DDP")
|
| 971 |
+
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
| 972 |
+
|
| 973 |
+
# only update the fm_decoder
|
| 974 |
+
num_trainable = 0
|
| 975 |
+
for name, p in model.named_parameters():
|
| 976 |
+
if "fm_decoder" in name:
|
| 977 |
+
p.requires_grad = True
|
| 978 |
+
num_trainable += p.numel()
|
| 979 |
+
else:
|
| 980 |
+
p.requires_grad = False
|
| 981 |
+
|
| 982 |
+
logging.info(
|
| 983 |
+
"A total of {} trainable parameters ({:.3f}% of the whole model)".format(
|
| 984 |
+
num_trainable, num_trainable / num_param * 100
|
| 985 |
+
)
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
optimizer = ScaledAdam(
|
| 989 |
+
get_parameter_groups_with_lrs(
|
| 990 |
+
model,
|
| 991 |
+
lr=params.base_lr,
|
| 992 |
+
include_names=True,
|
| 993 |
+
),
|
| 994 |
+
lr=params.base_lr, # should have no effect
|
| 995 |
+
clipping_scale=2.0,
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
scheduler = FixedLRScheduler(optimizer)
|
| 999 |
+
|
| 1000 |
+
scaler = GradScaler("cuda", enabled=params.use_fp16)
|
| 1001 |
+
|
| 1002 |
+
if params.start_epoch > 1 and checkpoints is not None:
|
| 1003 |
+
# load state_dict for optimizers
|
| 1004 |
+
if "optimizer" in checkpoints:
|
| 1005 |
+
logging.info("Loading optimizer state dict")
|
| 1006 |
+
optimizer.load_state_dict(checkpoints["optimizer"])
|
| 1007 |
+
|
| 1008 |
+
# load state_dict for schedulers
|
| 1009 |
+
if "scheduler" in checkpoints:
|
| 1010 |
+
logging.info("Loading scheduler state dict")
|
| 1011 |
+
scheduler.load_state_dict(checkpoints["scheduler"])
|
| 1012 |
+
|
| 1013 |
+
if "grad_scaler" in checkpoints:
|
| 1014 |
+
logging.info("Loading grad scaler state dict")
|
| 1015 |
+
scaler.load_state_dict(checkpoints["grad_scaler"])
|
| 1016 |
+
|
| 1017 |
+
if params.print_diagnostics:
|
| 1018 |
+
opts = diagnostics.TensorDiagnosticOptions(
|
| 1019 |
+
512
|
| 1020 |
+
) # allow 4 megabytes per sub-module
|
| 1021 |
+
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
| 1022 |
+
|
| 1023 |
+
if params.inf_check:
|
| 1024 |
+
register_inf_check_hooks(model)
|
| 1025 |
+
|
| 1026 |
+
def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
|
| 1027 |
+
if c.duration < min_len or c.duration > max_len:
|
| 1028 |
+
return False
|
| 1029 |
+
return True
|
| 1030 |
+
|
| 1031 |
+
_remove_short_and_long_utt = partial(
|
| 1032 |
+
remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
datamodule = TtsDataModule(args)
|
| 1036 |
+
if params.dataset == "emilia":
|
| 1037 |
+
train_cuts = CutSet.mux(
|
| 1038 |
+
datamodule.train_emilia_EN_cuts(),
|
| 1039 |
+
datamodule.train_emilia_ZH_cuts(),
|
| 1040 |
+
weights=[46000, 49000],
|
| 1041 |
+
)
|
| 1042 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 1043 |
+
dev_cuts = CutSet.mux(
|
| 1044 |
+
datamodule.dev_emilia_EN_cuts(),
|
| 1045 |
+
datamodule.dev_emilia_ZH_cuts(),
|
| 1046 |
+
weights=[0.5, 0.5],
|
| 1047 |
+
)
|
| 1048 |
+
elif params.dataset == "libritts":
|
| 1049 |
+
train_cuts = datamodule.train_libritts_cuts()
|
| 1050 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 1051 |
+
dev_cuts = datamodule.dev_libritts_cuts()
|
| 1052 |
+
else:
|
| 1053 |
+
assert params.dataset == "custom"
|
| 1054 |
+
train_cuts = datamodule.train_custom_cuts(params.train_manifest)
|
| 1055 |
+
train_cuts = train_cuts.filter(_remove_short_and_long_utt)
|
| 1056 |
+
dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
|
| 1057 |
+
# To avoid OOM issues due to too long dev cuts
|
| 1058 |
+
dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
|
| 1059 |
+
|
| 1060 |
+
_tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
|
| 1061 |
+
train_cuts = train_cuts.map(_tokenize_text)
|
| 1062 |
+
dev_cuts = dev_cuts.map(_tokenize_text)
|
| 1063 |
+
|
| 1064 |
+
train_dl = datamodule.train_dataloaders(train_cuts)
|
| 1065 |
+
|
| 1066 |
+
valid_dl = datamodule.dev_dataloaders(dev_cuts)
|
| 1067 |
+
|
| 1068 |
+
if params.scan_oom:
|
| 1069 |
+
scan_pessimistic_batches_for_oom(
|
| 1070 |
+
model=model,
|
| 1071 |
+
teacher_model=teacher_model,
|
| 1072 |
+
train_dl=train_dl,
|
| 1073 |
+
optimizer=optimizer,
|
| 1074 |
+
params=params,
|
| 1075 |
+
)
|
| 1076 |
+
logging.info("Training started")
|
| 1077 |
+
|
| 1078 |
+
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
| 1079 |
+
logging.info(f"Start epoch {epoch}")
|
| 1080 |
+
|
| 1081 |
+
scheduler.step_epoch(epoch - 1)
|
| 1082 |
+
fix_random_seed(params.seed + epoch - 1)
|
| 1083 |
+
train_dl.sampler.set_epoch(epoch - 1)
|
| 1084 |
+
|
| 1085 |
+
params.cur_epoch = epoch
|
| 1086 |
+
|
| 1087 |
+
if tb_writer is not None:
|
| 1088 |
+
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
| 1089 |
+
|
| 1090 |
+
train_one_epoch(
|
| 1091 |
+
params=params,
|
| 1092 |
+
model=model,
|
| 1093 |
+
model_avg=model_avg,
|
| 1094 |
+
teacher_model=teacher_model,
|
| 1095 |
+
optimizer=optimizer,
|
| 1096 |
+
scheduler=scheduler,
|
| 1097 |
+
train_dl=train_dl,
|
| 1098 |
+
valid_dl=valid_dl,
|
| 1099 |
+
scaler=scaler,
|
| 1100 |
+
tb_writer=tb_writer,
|
| 1101 |
+
world_size=world_size,
|
| 1102 |
+
rank=rank,
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
|
| 1106 |
+
break
|
| 1107 |
+
|
| 1108 |
+
if params.print_diagnostics:
|
| 1109 |
+
diagnostic.print_diagnostics()
|
| 1110 |
+
break
|
| 1111 |
+
|
| 1112 |
+
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
| 1113 |
+
save_checkpoint(
|
| 1114 |
+
filename=filename,
|
| 1115 |
+
params=params,
|
| 1116 |
+
model=model,
|
| 1117 |
+
model_avg=model_avg,
|
| 1118 |
+
model_ema=teacher_model,
|
| 1119 |
+
optimizer=optimizer,
|
| 1120 |
+
scheduler=scheduler,
|
| 1121 |
+
sampler=train_dl.sampler,
|
| 1122 |
+
scaler=scaler,
|
| 1123 |
+
rank=rank,
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
if rank == 0:
|
| 1127 |
+
if params.best_train_epoch == params.cur_epoch:
|
| 1128 |
+
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
| 1129 |
+
copyfile(src=filename, dst=best_train_filename)
|
| 1130 |
+
|
| 1131 |
+
if params.best_valid_epoch == params.cur_epoch:
|
| 1132 |
+
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
| 1133 |
+
copyfile(src=filename, dst=best_valid_filename)
|
| 1134 |
+
|
| 1135 |
+
logging.info("Done!")
|
| 1136 |
+
|
| 1137 |
+
if world_size > 1:
|
| 1138 |
+
torch.distributed.barrier()
|
| 1139 |
+
cleanup_dist()
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def main():
|
| 1143 |
+
parser = get_parser()
|
| 1144 |
+
TtsDataModule.add_arguments(parser)
|
| 1145 |
+
args = parser.parse_args()
|
| 1146 |
+
args.exp_dir = Path(args.exp_dir)
|
| 1147 |
+
|
| 1148 |
+
world_size = args.world_size
|
| 1149 |
+
assert world_size >= 1
|
| 1150 |
+
if world_size > 1:
|
| 1151 |
+
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
| 1152 |
+
else:
|
| 1153 |
+
run(rank=0, world_size=1, args=args)
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
if __name__ == "__main__":
|
| 1157 |
+
torch.set_num_threads(1)
|
| 1158 |
+
torch.set_num_interop_threads(1)
|
| 1159 |
+
main()
|
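For reference, the ema() helper defined in this script is a standard exponential-moving-average update that keeps the teacher close to the student during the second distillation stage. Below is a minimal, self-contained sketch of that update rule; the tiny nn.Linear model, the decay value, and the stand-in training step are illustrative only and are not taken from the training config.

import copy

import torch
import torch.nn as nn


def ema_update(new_model: nn.Module, ema_model: nn.Module, decay: float) -> None:
    # In-place EMA: ema = ema * decay + new * (1 - decay), applied per tensor,
    # mirroring the ema() function in train_zipvoice_distill.py above.
    new_state = new_model.state_dict()
    ema_state = ema_model.state_dict()
    for key in new_state.keys():
        ema_state[key].data.copy_(
            ema_state[key].data * decay + new_state[key].data * (1 - decay)
        )


if __name__ == "__main__":
    student = nn.Linear(4, 4)
    teacher = copy.deepcopy(student)  # the teacher starts as a copy of the student
    with torch.no_grad():
        for p in student.parameters():
            p.add_(0.1)  # stand-in for one optimizer step on the student
    ema_update(student, teacher, decay=0.999)  # illustrative decay value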
zipvoice/dataset/datamodule.py
ADDED
|
@@ -0,0 +1,319 @@
| 1 |
+
# Copyright 2021 Piotr Żelasko
|
| 2 |
+
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
| 3 |
+
# Zengwei Yao,
|
| 4 |
+
# Zengrui Jin,
|
| 5 |
+
# Han Zhu,
|
| 6 |
+
# Wei Kang)
|
| 7 |
+
#
|
| 8 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import logging
|
| 25 |
+
from functools import lru_cache
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from lhotse import CutSet, load_manifest_lazy
|
| 31 |
+
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
|
| 32 |
+
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
|
| 33 |
+
from lhotse.utils import fix_random_seed
|
| 34 |
+
from torch.utils.data import DataLoader
|
| 35 |
+
|
| 36 |
+
from zipvoice.dataset.dataset import SpeechSynthesisDataset
|
| 37 |
+
from zipvoice.utils.common import str2bool
|
| 38 |
+
from zipvoice.utils.feature import VocosFbank
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class _SeedWorkers:
|
| 42 |
+
def __init__(self, seed: int):
|
| 43 |
+
self.seed = seed
|
| 44 |
+
|
| 45 |
+
def __call__(self, worker_id: int):
|
| 46 |
+
fix_random_seed(self.seed + worker_id)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
SAMPLING_RATE = 24000
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TtsDataModule:
|
| 53 |
+
"""
|
| 54 |
+
DataModule for tts experiments.
|
| 55 |
+
It assumes there is always one train and valid dataloader,
|
| 56 |
+
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
| 57 |
+
and test-other).
|
| 58 |
+
|
| 59 |
+
It contains all the common data pipeline modules used in ASR
|
| 60 |
+
experiments, e.g.:
|
| 61 |
+
- dynamic batch size,
|
| 62 |
+
- bucketing samplers,
|
| 63 |
+
- cut concatenation,
|
| 64 |
+
- on-the-fly feature extraction
|
| 65 |
+
|
| 66 |
+
This class should be derived for specific corpora used in ASR tasks.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, args: argparse.Namespace):
|
| 70 |
+
self.args = args
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def add_arguments(cls, parser: argparse.ArgumentParser):
|
| 74 |
+
group = parser.add_argument_group(
|
| 75 |
+
title="TTS data related options",
|
| 76 |
+
description="These options are used for the preparation of "
|
| 77 |
+
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
| 78 |
+
"effective batch sizes, sampling strategies, applied data "
|
| 79 |
+
"augmentations, etc.",
|
| 80 |
+
)
|
| 81 |
+
group.add_argument(
|
| 82 |
+
"--manifest-dir",
|
| 83 |
+
type=Path,
|
| 84 |
+
default=Path("data/fbank"),
|
| 85 |
+
help="Path to directory with train/valid/test cuts.",
|
| 86 |
+
)
|
| 87 |
+
group.add_argument(
|
| 88 |
+
"--max-duration",
|
| 89 |
+
type=int,
|
| 90 |
+
default=200.0,
|
| 91 |
+
help="Maximum pooled recordings duration (seconds) in a "
|
| 92 |
+
"single batch. You can reduce it if it causes CUDA OOM.",
|
| 93 |
+
)
|
| 94 |
+
group.add_argument(
|
| 95 |
+
"--bucketing-sampler",
|
| 96 |
+
type=str2bool,
|
| 97 |
+
default=True,
|
| 98 |
+
help="When enabled, the batches will come from buckets of "
|
| 99 |
+
"similar duration (saves padding frames).",
|
| 100 |
+
)
|
| 101 |
+
group.add_argument(
|
| 102 |
+
"--num-buckets",
|
| 103 |
+
type=int,
|
| 104 |
+
default=30,
|
| 105 |
+
help="The number of buckets for the DynamicBucketingSampler"
|
| 106 |
+
"(you might want to increase it for larger datasets).",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
group.add_argument(
|
| 110 |
+
"--on-the-fly-feats",
|
| 111 |
+
type=str2bool,
|
| 112 |
+
default=False,
|
| 113 |
+
help="When enabled, use on-the-fly cut mixing and feature "
|
| 114 |
+
"extraction. Will drop existing precomputed feature manifests "
|
| 115 |
+
"if available.",
|
| 116 |
+
)
|
| 117 |
+
group.add_argument(
|
| 118 |
+
"--shuffle",
|
| 119 |
+
type=str2bool,
|
| 120 |
+
default=True,
|
| 121 |
+
help="When enabled (=default), the examples will be "
|
| 122 |
+
"shuffled for each epoch.",
|
| 123 |
+
)
|
| 124 |
+
group.add_argument(
|
| 125 |
+
"--drop-last",
|
| 126 |
+
type=str2bool,
|
| 127 |
+
default=True,
|
| 128 |
+
help="Whether to drop last batch. Used by sampler.",
|
| 129 |
+
)
|
| 130 |
+
group.add_argument(
|
| 131 |
+
"--return-cuts",
|
| 132 |
+
type=str2bool,
|
| 133 |
+
default=False,
|
| 134 |
+
help="When enabled, each batch will have the "
|
| 135 |
+
"field: batch['cut'] with the cuts that "
|
| 136 |
+
"were used to construct it.",
|
| 137 |
+
)
|
| 138 |
+
group.add_argument(
|
| 139 |
+
"--num-workers",
|
| 140 |
+
type=int,
|
| 141 |
+
default=8,
|
| 142 |
+
help="The number of training dataloader workers that "
|
| 143 |
+
"collect the batches.",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
group.add_argument(
|
| 147 |
+
"--input-strategy",
|
| 148 |
+
type=str,
|
| 149 |
+
default="PrecomputedFeatures",
|
| 150 |
+
help="AudioSamples or PrecomputedFeatures",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def train_dataloaders(
|
| 154 |
+
self,
|
| 155 |
+
cuts_train: CutSet,
|
| 156 |
+
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
| 157 |
+
) -> DataLoader:
|
| 158 |
+
"""
|
| 159 |
+
Args:
|
| 160 |
+
cuts_train:
|
| 161 |
+
CutSet for training.
|
| 162 |
+
sampler_state_dict:
|
| 163 |
+
The state dict for the training sampler.
|
| 164 |
+
"""
|
| 165 |
+
logging.info("About to create train dataset")
|
| 166 |
+
|
| 167 |
+
train = SpeechSynthesisDataset(
|
| 168 |
+
return_text=True,
|
| 169 |
+
return_tokens=True,
|
| 170 |
+
return_spk_ids=True,
|
| 171 |
+
feature_input_strategy=OnTheFlyFeatures(VocosFbank())
|
| 172 |
+
if self.args.on_the_fly_feats
|
| 173 |
+
else PrecomputedFeatures(),
|
| 174 |
+
return_cuts=self.args.return_cuts,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
if self.args.bucketing_sampler:
|
| 178 |
+
logging.info("Using DynamicBucketingSampler.")
|
| 179 |
+
train_sampler = DynamicBucketingSampler(
|
| 180 |
+
cuts_train,
|
| 181 |
+
max_duration=self.args.max_duration,
|
| 182 |
+
shuffle=self.args.shuffle,
|
| 183 |
+
num_buckets=self.args.num_buckets,
|
| 184 |
+
buffer_size=self.args.num_buckets * 2000,
|
| 185 |
+
shuffle_buffer_size=self.args.num_buckets * 5000,
|
| 186 |
+
drop_last=self.args.drop_last,
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
logging.info("Using SimpleCutSampler.")
|
| 190 |
+
train_sampler = SimpleCutSampler(
|
| 191 |
+
cuts_train,
|
| 192 |
+
max_duration=self.args.max_duration,
|
| 193 |
+
shuffle=self.args.shuffle,
|
| 194 |
+
)
|
| 195 |
+
logging.info("About to create train dataloader")
|
| 196 |
+
|
| 197 |
+
if sampler_state_dict is not None:
|
| 198 |
+
logging.info("Loading sampler state dict")
|
| 199 |
+
train_sampler.load_state_dict(sampler_state_dict)
|
| 200 |
+
|
| 201 |
+
# 'seed' is derived from the current random state, which will have
|
| 202 |
+
# previously been set in the main process.
|
| 203 |
+
seed = torch.randint(0, 100000, ()).item()
|
| 204 |
+
worker_init_fn = _SeedWorkers(seed)
|
| 205 |
+
|
| 206 |
+
train_dl = DataLoader(
|
| 207 |
+
train,
|
| 208 |
+
sampler=train_sampler,
|
| 209 |
+
batch_size=None,
|
| 210 |
+
num_workers=self.args.num_workers,
|
| 211 |
+
persistent_workers=False,
|
| 212 |
+
worker_init_fn=worker_init_fn,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
return train_dl
|
| 216 |
+
|
| 217 |
+
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
| 218 |
+
logging.info("About to create dev dataset")
|
| 219 |
+
validate = SpeechSynthesisDataset(
|
| 220 |
+
return_text=True,
|
| 221 |
+
return_tokens=True,
|
| 222 |
+
return_spk_ids=True,
|
| 223 |
+
feature_input_strategy=OnTheFlyFeatures(VocosFbank())
|
| 224 |
+
if self.args.on_the_fly_feats
|
| 225 |
+
else PrecomputedFeatures(),
|
| 226 |
+
return_cuts=self.args.return_cuts,
|
| 227 |
+
)
|
| 228 |
+
dev_sampler = DynamicBucketingSampler(
|
| 229 |
+
cuts_valid,
|
| 230 |
+
max_duration=self.args.max_duration,
|
| 231 |
+
shuffle=False,
|
| 232 |
+
)
|
| 233 |
+
logging.info("About to create valid dataloader")
|
| 234 |
+
dev_dl = DataLoader(
|
| 235 |
+
validate,
|
| 236 |
+
sampler=dev_sampler,
|
| 237 |
+
batch_size=None,
|
| 238 |
+
num_workers=2,
|
| 239 |
+
persistent_workers=False,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
return dev_dl
|
| 243 |
+
|
| 244 |
+
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
| 245 |
+
logging.info("About to create test dataset")
|
| 246 |
+
test = SpeechSynthesisDataset(
|
| 247 |
+
return_text=True,
|
| 248 |
+
return_tokens=True,
|
| 249 |
+
return_spk_ids=True,
|
| 250 |
+
feature_input_strategy=OnTheFlyFeatures(VocosFbank())
|
| 251 |
+
if self.args.on_the_fly_feats
|
| 252 |
+
else PrecomputedFeatures(),
|
| 253 |
+
return_cuts=self.args.return_cuts,
|
| 254 |
+
return_audio=True,
|
| 255 |
+
)
|
| 256 |
+
test_sampler = DynamicBucketingSampler(
|
| 257 |
+
cuts,
|
| 258 |
+
max_duration=self.args.max_duration,
|
| 259 |
+
shuffle=False,
|
| 260 |
+
)
|
| 261 |
+
logging.info("About to create test dataloader")
|
| 262 |
+
test_dl = DataLoader(
|
| 263 |
+
test,
|
| 264 |
+
batch_size=None,
|
| 265 |
+
sampler=test_sampler,
|
| 266 |
+
num_workers=self.args.num_workers,
|
| 267 |
+
)
|
| 268 |
+
return test_dl
|
| 269 |
+
|
| 270 |
+
@lru_cache()
|
| 271 |
+
def train_custom_cuts(self, manifest_file) -> CutSet:
|
| 272 |
+
logging.info(f"About to get the custom training cuts {manifest_file}")
|
| 273 |
+
return load_manifest_lazy(manifest_file)
|
| 274 |
+
|
| 275 |
+
@lru_cache()
|
| 276 |
+
def dev_custom_cuts(self, manifest_file) -> CutSet:
|
| 277 |
+
logging.info(f"About to get the custom validation cuts {manifest_file}")
|
| 278 |
+
return load_manifest_lazy(manifest_file)
|
| 279 |
+
|
| 280 |
+
@lru_cache()
|
| 281 |
+
def train_emilia_EN_cuts(self) -> CutSet:
|
| 282 |
+
logging.info("About to get train the EN subset")
|
| 283 |
+
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
|
| 284 |
+
|
| 285 |
+
@lru_cache()
|
| 286 |
+
def train_emilia_ZH_cuts(self) -> CutSet:
|
| 287 |
+
logging.info("About to get train the ZH subset")
|
| 288 |
+
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
|
| 289 |
+
|
| 290 |
+
@lru_cache()
|
| 291 |
+
def dev_emilia_EN_cuts(self) -> CutSet:
|
| 292 |
+
logging.info("About to get dev the EN subset")
|
| 293 |
+
return load_manifest_lazy(
|
| 294 |
+
self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
@lru_cache()
|
| 298 |
+
def dev_emilia_ZH_cuts(self) -> CutSet:
|
| 299 |
+
logging.info("About to get dev the ZH subset")
|
| 300 |
+
return load_manifest_lazy(
|
| 301 |
+
self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
@lru_cache()
|
| 305 |
+
def train_libritts_cuts(self) -> CutSet:
|
| 306 |
+
logging.info(
|
| 307 |
+
"About to get the shuffled train-clean-100, \
|
| 308 |
+
train-clean-360 and train-other-500 cuts"
|
| 309 |
+
)
|
| 310 |
+
return load_manifest_lazy(
|
| 311 |
+
self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
@lru_cache()
|
| 315 |
+
def dev_libritts_cuts(self) -> CutSet:
|
| 316 |
+
logging.info("About to get dev-clean cuts")
|
| 317 |
+
return load_manifest_lazy(
|
| 318 |
+
self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
|
| 319 |
+
)
|
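Before the next file: a hedged sketch of how this data module is typically driven from a training script such as the one above. The custom manifest path below is a placeholder; the real scripts pick Emilia, LibriTTS, or custom cuts via the --dataset flag and tokenize the cuts before building the dataloader.

import argparse

from zipvoice.dataset.datamodule import TtsDataModule

parser = argparse.ArgumentParser()
TtsDataModule.add_arguments(parser)
args = parser.parse_args([])  # defaults; override e.g. --max-duration as needed

datamodule = TtsDataModule(args)
# "data/fbank/custom_cuts_train.jsonl.gz" is a hypothetical manifest path.
train_cuts = datamodule.train_custom_cuts("data/fbank/custom_cuts_train.jsonl.gz")
train_dl = datamodule.train_dataloaders(train_cuts)

for batch in train_dl:
    print(batch["features"].shape, batch["text"][:2])
    break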
zipvoice/dataset/dataset.py
ADDED
|
@@ -0,0 +1,105 @@
| 1 |
+
from typing import Callable, Dict, List, Sequence, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from lhotse import CutSet, validate
|
| 5 |
+
from lhotse.dataset import PrecomputedFeatures
|
| 6 |
+
from lhotse.dataset.collation import collate_audio
|
| 7 |
+
from lhotse.dataset.input_strategies import BatchIO
|
| 8 |
+
from lhotse.utils import ifnone
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
| 12 |
+
"""
|
| 13 |
+
The PyTorch Dataset for the speech synthesis task.
|
| 14 |
+
Each item in this dataset is a dict of:
|
| 15 |
+
|
| 16 |
+
.. code-block::
|
| 17 |
+
|
| 18 |
+
{
|
| 19 |
+
'audio': (B x NumSamples) float tensor
|
| 20 |
+
'features': (B x NumFrames x NumFeatures) float tensor
|
| 21 |
+
'audio_lens': (B, ) int tensor
|
| 22 |
+
'features_lens': (B, ) int tensor
|
| 23 |
+
'text': List[str] of len B # when return_text=True
|
| 24 |
+
'tokens': List[List[str]] # when return_tokens=True
|
| 25 |
+
'speakers': List[str] of len B # when return_spk_ids=True
|
| 26 |
+
'cut': List of Cuts # when return_cuts=True
|
| 27 |
+
}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
| 33 |
+
feature_input_strategy: BatchIO = PrecomputedFeatures(),
|
| 34 |
+
feature_transforms: Union[Sequence[Callable], Callable] = None,
|
| 35 |
+
return_text: bool = True,
|
| 36 |
+
return_tokens: bool = False,
|
| 37 |
+
return_spk_ids: bool = False,
|
| 38 |
+
return_cuts: bool = False,
|
| 39 |
+
return_audio: bool = False,
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
self.cut_transforms = ifnone(cut_transforms, [])
|
| 44 |
+
self.feature_input_strategy = feature_input_strategy
|
| 45 |
+
|
| 46 |
+
self.return_text = return_text
|
| 47 |
+
self.return_tokens = return_tokens
|
| 48 |
+
self.return_spk_ids = return_spk_ids
|
| 49 |
+
self.return_cuts = return_cuts
|
| 50 |
+
self.return_audio = return_audio
|
| 51 |
+
|
| 52 |
+
if feature_transforms is None:
|
| 53 |
+
feature_transforms = []
|
| 54 |
+
elif not isinstance(feature_transforms, Sequence):
|
| 55 |
+
feature_transforms = [feature_transforms]
|
| 56 |
+
|
| 57 |
+
assert all(
|
| 58 |
+
isinstance(transform, Callable) for transform in feature_transforms
|
| 59 |
+
), "Feature transforms must be Callable"
|
| 60 |
+
self.feature_transforms = feature_transforms
|
| 61 |
+
|
| 62 |
+
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
|
| 63 |
+
validate_for_tts(cuts)
|
| 64 |
+
|
| 65 |
+
for transform in self.cut_transforms:
|
| 66 |
+
cuts = transform(cuts)
|
| 67 |
+
|
| 68 |
+
features, features_lens = self.feature_input_strategy(cuts)
|
| 69 |
+
|
| 70 |
+
for transform in self.feature_transforms:
|
| 71 |
+
features = transform(features)
|
| 72 |
+
|
| 73 |
+
batch = {
|
| 74 |
+
"features": features,
|
| 75 |
+
"features_lens": features_lens,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
if self.return_audio:
|
| 79 |
+
audio, audio_lens = collate_audio(cuts)
|
| 80 |
+
batch["audio"] = audio
|
| 81 |
+
batch["audio_lens"] = audio_lens
|
| 82 |
+
|
| 83 |
+
if self.return_text:
|
| 84 |
+
text = [cut.supervisions[0].text for cut in cuts]
|
| 85 |
+
batch["text"] = text
|
| 86 |
+
|
| 87 |
+
if self.return_tokens:
|
| 88 |
+
tokens = [cut.supervisions[0].tokens for cut in cuts]
|
| 89 |
+
batch["tokens"] = tokens
|
| 90 |
+
|
| 91 |
+
if self.return_spk_ids:
|
| 92 |
+
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
|
| 93 |
+
|
| 94 |
+
if self.return_cuts:
|
| 95 |
+
batch["cut"] = [cut for cut in cuts]
|
| 96 |
+
|
| 97 |
+
return batch
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def validate_for_tts(cuts: CutSet) -> None:
|
| 101 |
+
validate(cuts)
|
| 102 |
+
for cut in cuts:
|
| 103 |
+
assert (
|
| 104 |
+
len(cut.supervisions) == 1
|
| 105 |
+
), "Only the Cuts with single supervision are supported."
|
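For context before the evaluation code: a hedged sketch of driving SpeechSynthesisDataset directly with a Lhotse bucketing sampler, following the same pattern TtsDataModule uses internally. The dev manifest path is a placeholder, and tokens are only available if the cuts were tokenized beforehand (as the training scripts do via tokenize_text), so return_tokens is left off here.

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler
from torch.utils.data import DataLoader

from zipvoice.dataset.dataset import SpeechSynthesisDataset

cuts = load_manifest_lazy("data/fbank/custom_cuts_dev.jsonl.gz")  # placeholder path
dataset = SpeechSynthesisDataset(return_text=True, return_tokens=False)
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=False)

# batch_size=None because the sampler already yields whole batches of cuts.
dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

for batch in dl:
    # keys per the docstring above: "features", "features_lens", "text", ...
    print(batch["features"].shape, batch["features_lens"][:4])
    break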
zipvoice/eval/evaluate_sim.py
ADDED
|
@@ -0,0 +1,535 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
|
| 3 |
+
# Wei Kang)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Calculate pairwise Speaker Similarity between two speech directories.
|
| 22 |
+
SV model wavlm_large_finetune.pth is downloaded from
|
| 23 |
+
https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
|
| 24 |
+
SSL model wavlm_large.pt is downloaded from
|
| 25 |
+
https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
|
| 26 |
+
"""
|
| 27 |
+
import argparse
|
| 28 |
+
import logging
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
import librosa
|
| 32 |
+
import numpy as np
|
| 33 |
+
import soundfile as sf
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
logging.basicConfig(level=logging.INFO)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_parser():
|
| 43 |
+
parser = argparse.ArgumentParser()
|
| 44 |
+
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--eval-path", type=str, help="path of the evaluated speech directory"
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--test-list",
|
| 50 |
+
type=str,
|
| 51 |
+
help="path of the file list that contains the corresponding "
|
| 52 |
+
"relationship between the prompt and evaluated speech. "
|
| 53 |
+
"The first column is the wav name and the third column is the prompt speech",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--sv-model-path",
|
| 57 |
+
type=str,
|
| 58 |
+
default="model/UniSpeech/wavlm_large_finetune.pth",
|
| 59 |
+
help="path of the wavlm-based ECAPA-TDNN model",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--ssl-model-path",
|
| 63 |
+
type=str,
|
| 64 |
+
default="model/s3prl/wavlm_large.pt",
|
| 65 |
+
help="path of the wavlm SSL model",
|
| 66 |
+
)
|
| 67 |
+
return parser
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class SpeakerSimilarity:
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
|
| 74 |
+
ssl_model_path="model/s3prl/wavlm_large.pt",
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Initialize
|
| 78 |
+
"""
|
| 79 |
+
self.sample_rate = 16000
|
| 80 |
+
self.channels = 1
|
| 81 |
+
self.device = (
|
| 82 |
+
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 83 |
+
)
|
| 84 |
+
logging.info("[Speaker Similarity] Using device: {}".format(self.device))
|
| 85 |
+
self.model = ECAPA_TDNN_WAVLLM(
|
| 86 |
+
feat_dim=1024,
|
| 87 |
+
channels=512,
|
| 88 |
+
emb_dim=256,
|
| 89 |
+
sr=16000,
|
| 90 |
+
ssl_model_path=ssl_model_path,
|
| 91 |
+
)
|
| 92 |
+
state_dict = torch.load(
|
| 93 |
+
sv_model_path, map_location=lambda storage, loc: storage
|
| 94 |
+
)
|
| 95 |
+
self.model.load_state_dict(state_dict["model"], strict=False)
|
| 96 |
+
self.model.to(self.device)
|
| 97 |
+
self.model.eval()
|
| 98 |
+
|
| 99 |
+
def get_embeddings(self, wav_list, dtype="float32"):
|
| 100 |
+
"""
|
| 101 |
+
Get embeddings
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def _load_speech_task(fname, sample_rate):
|
| 105 |
+
|
| 106 |
+
wav_data, sr = sf.read(fname, dtype=dtype)
|
| 107 |
+
if sr != sample_rate:
|
| 108 |
+
wav_data = librosa.resample(
|
| 109 |
+
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
| 110 |
+
)
|
| 111 |
+
wav_data = torch.from_numpy(wav_data)
|
| 112 |
+
|
| 113 |
+
return wav_data
|
| 114 |
+
|
| 115 |
+
embd_lst = []
|
| 116 |
+
for file_path in tqdm(wav_list):
|
| 117 |
+
speech = _load_speech_task(file_path, self.sample_rate)
|
| 118 |
+
speech = speech.to(self.device)
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
embd = self.model([speech])
|
| 121 |
+
embd_lst.append(embd)
|
| 122 |
+
|
| 123 |
+
return embd_lst
|
| 124 |
+
|
| 125 |
+
def score(
|
| 126 |
+
self,
|
| 127 |
+
eval_path,
|
| 128 |
+
test_list,
|
| 129 |
+
dtype="float32",
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Computes the Speaker Similarity (SIM-o) between two directories of speech files.
|
| 133 |
+
|
| 134 |
+
Parameters:
|
| 135 |
+
- eval_path (str): Path to the directory containing evaluation speech files.
|
| 136 |
+
- test_list (str): Path to the file containing the corresponding relationship
|
| 137 |
+
between prompt and evaluated speech.
|
| 138 |
+
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
- float: The Speaker Similarity (SIM-o) score between the two directories
|
| 142 |
+
of speech files.
|
| 143 |
+
"""
|
| 144 |
+
prompt_wavs = []
|
| 145 |
+
eval_wavs = []
|
| 146 |
+
with open(test_list, "r") as fr:
|
| 147 |
+
lines = fr.readlines()
|
| 148 |
+
for line in lines:
|
| 149 |
+
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
| 150 |
+
prompt_wavs.append(prompt_wav)
|
| 151 |
+
eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
|
| 152 |
+
embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
|
| 153 |
+
|
| 154 |
+
embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
|
| 155 |
+
|
| 156 |
+
# Check if embeddings are empty
|
| 157 |
+
if len(embds_prompt) == 0:
|
| 158 |
+
logging.info("[Speaker Similarity] real set dir is empty, exiting...")
|
| 159 |
+
return -1
|
| 160 |
+
if len(embds_eval) == 0:
|
| 161 |
+
logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
|
| 162 |
+
return -1
|
| 163 |
+
|
| 164 |
+
scores = []
|
| 165 |
+
for real_embd, eval_embd in zip(embds_prompt, embds_eval):
|
| 166 |
+
scores.append(
|
| 167 |
+
torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
|
| 168 |
+
.detach()
|
| 169 |
+
.cpu()
|
| 170 |
+
.numpy()
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return np.mean(scores)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
| 177 |
+
|
| 178 |
+
""" Res2Conv1d + BatchNorm1d + ReLU
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Res2Conv1dReluBn(nn.Module):
|
| 183 |
+
"""
|
| 184 |
+
in_channels == out_channels == channels
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
channels,
|
| 190 |
+
kernel_size=1,
|
| 191 |
+
stride=1,
|
| 192 |
+
padding=0,
|
| 193 |
+
dilation=1,
|
| 194 |
+
bias=True,
|
| 195 |
+
scale=4,
|
| 196 |
+
):
|
| 197 |
+
super().__init__()
|
| 198 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
| 199 |
+
self.scale = scale
|
| 200 |
+
self.width = channels // scale
|
| 201 |
+
self.nums = scale if scale == 1 else scale - 1
|
| 202 |
+
|
| 203 |
+
self.convs = []
|
| 204 |
+
self.bns = []
|
| 205 |
+
for i in range(self.nums):
|
| 206 |
+
self.convs.append(
|
| 207 |
+
nn.Conv1d(
|
| 208 |
+
self.width,
|
| 209 |
+
self.width,
|
| 210 |
+
kernel_size,
|
| 211 |
+
stride,
|
| 212 |
+
padding,
|
| 213 |
+
dilation,
|
| 214 |
+
bias=bias,
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
| 218 |
+
self.convs = nn.ModuleList(self.convs)
|
| 219 |
+
self.bns = nn.ModuleList(self.bns)
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
out = []
|
| 223 |
+
spx = torch.split(x, self.width, 1)
|
| 224 |
+
for i in range(self.nums):
|
| 225 |
+
if i == 0:
|
| 226 |
+
sp = spx[i]
|
| 227 |
+
else:
|
| 228 |
+
sp = sp + spx[i]
|
| 229 |
+
# Order: conv -> relu -> bn
|
| 230 |
+
sp = self.convs[i](sp)
|
| 231 |
+
sp = self.bns[i](F.relu(sp))
|
| 232 |
+
out.append(sp)
|
| 233 |
+
if self.scale != 1:
|
| 234 |
+
out.append(spx[self.nums])
|
| 235 |
+
out = torch.cat(out, dim=1)
|
| 236 |
+
|
| 237 |
+
return out
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
""" Conv1d + BatchNorm1d + ReLU
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Conv1dReluBn(nn.Module):
|
| 245 |
+
def __init__(
|
| 246 |
+
self,
|
| 247 |
+
in_channels,
|
| 248 |
+
out_channels,
|
| 249 |
+
kernel_size=1,
|
| 250 |
+
stride=1,
|
| 251 |
+
padding=0,
|
| 252 |
+
dilation=1,
|
| 253 |
+
bias=True,
|
| 254 |
+
):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.conv = nn.Conv1d(
|
| 257 |
+
in_channels,
|
| 258 |
+
out_channels,
|
| 259 |
+
kernel_size,
|
| 260 |
+
stride,
|
| 261 |
+
padding,
|
| 262 |
+
dilation,
|
| 263 |
+
bias=bias,
|
| 264 |
+
)
|
| 265 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
| 266 |
+
|
| 267 |
+
def forward(self, x):
|
| 268 |
+
return self.bn(F.relu(self.conv(x)))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
""" The SE connection of 1D case.
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class SE_Connect(nn.Module):
|
| 276 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
| 277 |
+
super().__init__()
|
| 278 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
| 279 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
| 280 |
+
|
| 281 |
+
def forward(self, x):
|
| 282 |
+
out = x.mean(dim=2)
|
| 283 |
+
out = F.relu(self.linear1(out))
|
| 284 |
+
out = torch.sigmoid(self.linear2(out))
|
| 285 |
+
out = x * out.unsqueeze(2)
|
| 286 |
+
|
| 287 |
+
return out
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
| 295 |
+
# return nn.Sequential(
|
| 296 |
+
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
| 297 |
+
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
| 298 |
+
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
| 299 |
+
# SE_Connect(channels)
|
| 300 |
+
# )
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class SE_Res2Block(nn.Module):
|
| 304 |
+
def __init__(
|
| 305 |
+
self,
|
| 306 |
+
in_channels,
|
| 307 |
+
out_channels,
|
| 308 |
+
kernel_size,
|
| 309 |
+
stride,
|
| 310 |
+
padding,
|
| 311 |
+
dilation,
|
| 312 |
+
scale,
|
| 313 |
+
se_bottleneck_dim,
|
| 314 |
+
):
|
| 315 |
+
super().__init__()
|
| 316 |
+
self.Conv1dReluBn1 = Conv1dReluBn(
|
| 317 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 318 |
+
)
|
| 319 |
+
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
|
| 320 |
+
out_channels, kernel_size, stride, padding, dilation, scale=scale
|
| 321 |
+
)
|
| 322 |
+
self.Conv1dReluBn2 = Conv1dReluBn(
|
| 323 |
+
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 324 |
+
)
|
| 325 |
+
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
| 326 |
+
|
| 327 |
+
self.shortcut = None
|
| 328 |
+
if in_channels != out_channels:
|
| 329 |
+
self.shortcut = nn.Conv1d(
|
| 330 |
+
in_channels=in_channels,
|
| 331 |
+
out_channels=out_channels,
|
| 332 |
+
kernel_size=1,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def forward(self, x):
|
| 336 |
+
residual = x
|
| 337 |
+
if self.shortcut:
|
| 338 |
+
residual = self.shortcut(x)
|
| 339 |
+
|
| 340 |
+
x = self.Conv1dReluBn1(x)
|
| 341 |
+
x = self.Res2Conv1dReluBn(x)
|
| 342 |
+
x = self.Conv1dReluBn2(x)
|
| 343 |
+
x = self.SE_Connect(x)
|
| 344 |
+
|
| 345 |
+
return x + residual
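
A similar check for the full SE-Res2 block, using the hyperparameters of layer2 further below (batch and time sizes are hypothetical):

import torch
block = SE_Res2Block(
    in_channels=512, out_channels=512, kernel_size=3, stride=1,
    padding=2, dilation=2, scale=8, se_bottleneck_dim=128,
)
x = torch.randn(2, 512, 100)
print(block(x).shape)  # torch.Size([2, 512, 100]); padding matches the dilation, so time is preserved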
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
""" Attentive weighted mean and standard deviation pooling.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class AttentiveStatsPool(nn.Module):
|
| 353 |
+
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
| 354 |
+
super().__init__()
|
| 355 |
+
self.global_context_att = global_context_att
|
| 356 |
+
|
| 357 |
+
# Use Conv1d with stride == 1 rather than Linear,
|
| 358 |
+
# so that we don't need to transpose inputs.
|
| 359 |
+
if global_context_att:
|
| 360 |
+
self.linear1 = nn.Conv1d(
|
| 361 |
+
in_dim * 3, attention_channels, kernel_size=1
|
| 362 |
+
) # equals W and b in the paper
|
| 363 |
+
else:
|
| 364 |
+
self.linear1 = nn.Conv1d(
|
| 365 |
+
in_dim, attention_channels, kernel_size=1
|
| 366 |
+
) # equals W and b in the paper
|
| 367 |
+
self.linear2 = nn.Conv1d(
|
| 368 |
+
attention_channels, in_dim, kernel_size=1
|
| 369 |
+
) # equals V and k in the paper
|
| 370 |
+
|
| 371 |
+
def forward(self, x):
|
| 372 |
+
|
| 373 |
+
if self.global_context_att:
|
| 374 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
| 375 |
+
context_std = torch.sqrt(
|
| 376 |
+
torch.var(x, dim=-1, keepdim=True) + 1e-10
|
| 377 |
+
).expand_as(x)
|
| 378 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
| 379 |
+
else:
|
| 380 |
+
x_in = x
|
| 381 |
+
|
| 382 |
+
# Don't use ReLU here: in experiments, ReLU made training hard to converge.
|
| 383 |
+
alpha = torch.tanh(self.linear1(x_in))
|
| 384 |
+
# alpha = F.relu(self.linear1(x_in))
|
| 385 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
| 386 |
+
mean = torch.sum(alpha * x, dim=2)
|
| 387 |
+
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
| 388 |
+
std = torch.sqrt(residuals.clamp(min=1e-9))
|
| 389 |
+
return torch.cat([mean, std], dim=1)
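
Attentive statistics pooling turns a variable-length (batch, channels, time) input into a fixed (batch, 2 * channels) vector of per-channel weighted means and standard deviations; a sketch with hypothetical sizes:

import torch
pool = AttentiveStatsPool(in_dim=1536, attention_channels=128)
x = torch.randn(2, 1536, 120)
print(pool(x).shape)  # torch.Size([2, 3072])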
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class ECAPA_TDNN_WAVLLM(nn.Module):
|
| 393 |
+
def __init__(
|
| 394 |
+
self,
|
| 395 |
+
feat_dim=80,
|
| 396 |
+
channels=512,
|
| 397 |
+
emb_dim=192,
|
| 398 |
+
global_context_att=False,
|
| 399 |
+
sr=16000,
|
| 400 |
+
ssl_model_path=None,
|
| 401 |
+
):
|
| 402 |
+
super().__init__()
|
| 403 |
+
self.sr = sr
|
| 404 |
+
|
| 405 |
+
if ssl_model_path is None:
|
| 406 |
+
self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
|
| 407 |
+
else:
|
| 408 |
+
self.feature_extract = torch.hub.load(
|
| 409 |
+
os.path.dirname(ssl_model_path),
|
| 410 |
+
"wavlm_local",
|
| 411 |
+
source="local",
|
| 412 |
+
ckpt=ssl_model_path,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
| 416 |
+
self.feature_extract.model.encoder.layers[23].self_attn,
|
| 417 |
+
"fp32_attention",
|
| 418 |
+
):
|
| 419 |
+
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = (
|
| 420 |
+
False
|
| 421 |
+
)
|
| 422 |
+
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
| 423 |
+
self.feature_extract.model.encoder.layers[11].self_attn,
|
| 424 |
+
"fp32_attention",
|
| 425 |
+
):
|
| 426 |
+
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = (
|
| 427 |
+
False
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.feat_num = self.get_feat_num()
|
| 431 |
+
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
| 432 |
+
|
| 433 |
+
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
| 434 |
+
# self.channels = [channels] * 4 + [channels * 3]
|
| 435 |
+
self.channels = [channels] * 4 + [1536]
|
| 436 |
+
|
| 437 |
+
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
| 438 |
+
self.layer2 = SE_Res2Block(
|
| 439 |
+
self.channels[0],
|
| 440 |
+
self.channels[1],
|
| 441 |
+
kernel_size=3,
|
| 442 |
+
stride=1,
|
| 443 |
+
padding=2,
|
| 444 |
+
dilation=2,
|
| 445 |
+
scale=8,
|
| 446 |
+
se_bottleneck_dim=128,
|
| 447 |
+
)
|
| 448 |
+
self.layer3 = SE_Res2Block(
|
| 449 |
+
self.channels[1],
|
| 450 |
+
self.channels[2],
|
| 451 |
+
kernel_size=3,
|
| 452 |
+
stride=1,
|
| 453 |
+
padding=3,
|
| 454 |
+
dilation=3,
|
| 455 |
+
scale=8,
|
| 456 |
+
se_bottleneck_dim=128,
|
| 457 |
+
)
|
| 458 |
+
self.layer4 = SE_Res2Block(
|
| 459 |
+
self.channels[2],
|
| 460 |
+
self.channels[3],
|
| 461 |
+
kernel_size=3,
|
| 462 |
+
stride=1,
|
| 463 |
+
padding=4,
|
| 464 |
+
dilation=4,
|
| 465 |
+
scale=8,
|
| 466 |
+
se_bottleneck_dim=128,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
| 470 |
+
cat_channels = channels * 3
|
| 471 |
+
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
| 472 |
+
self.pooling = AttentiveStatsPool(
|
| 473 |
+
self.channels[-1],
|
| 474 |
+
attention_channels=128,
|
| 475 |
+
global_context_att=global_context_att,
|
| 476 |
+
)
|
| 477 |
+
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
| 478 |
+
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
| 479 |
+
|
| 480 |
+
def get_feat_num(self):
|
| 481 |
+
self.feature_extract.eval()
|
| 482 |
+
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
| 483 |
+
with torch.no_grad():
|
| 484 |
+
features = self.feature_extract(wav)
|
| 485 |
+
select_feature = features["hidden_states"]
|
| 486 |
+
if isinstance(select_feature, (list, tuple)):
|
| 487 |
+
return len(select_feature)
|
| 488 |
+
else:
|
| 489 |
+
return 1
|
| 490 |
+
|
| 491 |
+
def get_feat(self, x):
|
| 492 |
+
with torch.no_grad():
|
| 493 |
+
x = self.feature_extract([sample for sample in x])
|
| 494 |
+
|
| 495 |
+
x = x["hidden_states"]
|
| 496 |
+
if isinstance(x, (list, tuple)):
|
| 497 |
+
x = torch.stack(x, dim=0)
|
| 498 |
+
else:
|
| 499 |
+
x = x.unsqueeze(0)
|
| 500 |
+
norm_weights = (
|
| 501 |
+
F.softmax(self.feature_weight, dim=-1)
|
| 502 |
+
.unsqueeze(-1)
|
| 503 |
+
.unsqueeze(-1)
|
| 504 |
+
.unsqueeze(-1)
|
| 505 |
+
)
|
| 506 |
+
x = (norm_weights * x).sum(dim=0)
|
| 507 |
+
x = torch.transpose(x, 1, 2) + 1e-6
|
| 508 |
+
|
| 509 |
+
x = self.instance_norm(x)
|
| 510 |
+
return x
|
| 511 |
+
|
| 512 |
+
def forward(self, x):
|
| 513 |
+
x = self.get_feat(x)
|
| 514 |
+
|
| 515 |
+
out1 = self.layer1(x)
|
| 516 |
+
out2 = self.layer2(out1)
|
| 517 |
+
out3 = self.layer3(out2)
|
| 518 |
+
out4 = self.layer4(out3)
|
| 519 |
+
|
| 520 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
| 521 |
+
out = F.relu(self.conv(out))
|
| 522 |
+
out = self.bn(self.pooling(out))
|
| 523 |
+
out = self.linear(out)
|
| 524 |
+
|
| 525 |
+
return out
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
if __name__ == "__main__":
|
| 529 |
+
parser = get_parser()
|
| 530 |
+
args = parser.parse_args()
|
| 531 |
+
SIM = SpeakerSimilarity(
|
| 532 |
+
sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
|
| 533 |
+
)
|
| 534 |
+
score = SIM.score(args.eval_path, args.test_list)
|
| 535 |
+
logging.info(f"SIM-o score: {score:.3f}")
|
zipvoice/eval/evaluate_utmos.py
ADDED
|
@@ -0,0 +1,314 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
|
| 3 |
+
# Wei Kang)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Calculate UTMOS scores with the automatic Mean Opinion Score (MOS) prediction system
|
| 22 |
+
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
|
| 23 |
+
|
| 24 |
+
# Download model checkpoints
|
| 25 |
+
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import logging
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
import fairseq
|
| 34 |
+
import librosa
|
| 35 |
+
import numpy as np
|
| 36 |
+
import pytorch_lightning as pl
|
| 37 |
+
import soundfile as sf
|
| 38 |
+
import torch
|
| 39 |
+
import torch.nn as nn
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
logging.basicConfig(level=logging.INFO)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_parser():
|
| 46 |
+
parser = argparse.ArgumentParser()
|
| 47 |
+
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--wav-path", type=str, help="path of the evaluated speech directory"
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--utmos-model-path",
|
| 53 |
+
type=str,
|
| 54 |
+
default="model/huggingface/utmos/utmos.pt",
|
| 55 |
+
help="path of the UTMOS model",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--ssl-model-path",
|
| 59 |
+
type=str,
|
| 60 |
+
default="model/huggingface/utmos/wav2vec_small.pt",
|
| 61 |
+
help="path of the wav2vec SSL model",
|
| 62 |
+
)
|
| 63 |
+
return parser
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class UTMOSScore:
|
| 67 |
+
"""Predicting score for each audio clip."""
|
| 68 |
+
|
| 69 |
+
def __init__(self, utmos_model_path, ssl_model_path):
|
| 70 |
+
self.sample_rate = 16000
|
| 71 |
+
self.device = (
|
| 72 |
+
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 73 |
+
)
|
| 74 |
+
self.model = (
|
| 75 |
+
BaselineLightningModule.load_from_checkpoint(
|
| 76 |
+
utmos_model_path, ssl_model_path=ssl_model_path
|
| 77 |
+
)
|
| 78 |
+
.eval()
|
| 79 |
+
.to(self.device)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def score(self, wavs: torch.Tensor) -> torch.Tensor:
|
| 83 |
+
"""
|
| 84 |
+
Args:
|
| 85 |
+
wavs: waveforms to be evaluated. A 1-D or 2-D tensor is treated as a
single audio clip; a 3-D tensor of shape (batch, channel, samples) is
processed as a batch.
|
| 88 |
+
"""
|
| 89 |
+
if len(wavs.shape) == 1:
|
| 90 |
+
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
|
| 91 |
+
elif len(wavs.shape) == 2:
|
| 92 |
+
out_wavs = wavs.unsqueeze(0)
|
| 93 |
+
elif len(wavs.shape) == 3:
|
| 94 |
+
out_wavs = wavs
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError("Dimension of input tensor needs to be <= 3.")
|
| 97 |
+
bs = out_wavs.shape[0]
|
| 98 |
+
batch = {
|
| 99 |
+
"wav": out_wavs,
|
| 100 |
+
"domains": torch.zeros(bs, dtype=torch.int).to(self.device),
|
| 101 |
+
"judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
|
| 102 |
+
}
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
output = self.model(batch)
|
| 105 |
+
|
| 106 |
+
return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
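
The trailing * 2 + 3 maps the model's normalized prediction back to the 1-5 MOS scale (assuming, as in the UTMOS recipe, that training targets were rescaled as (MOS - 3) / 2). For example:

raw = 0.75          # hypothetical normalized model output
print(raw * 2 + 3)  # 4.5 on the 1-5 MOS scale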
|
| 107 |
+
|
| 108 |
+
def score_dir(self, dir, dtype="float32"):
|
| 109 |
+
def _load_speech_task(fname, sample_rate):
|
| 110 |
+
|
| 111 |
+
wav_data, sr = sf.read(fname, dtype=dtype)
|
| 112 |
+
if sr != sample_rate:
|
| 113 |
+
wav_data = librosa.resample(
|
| 114 |
+
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
| 115 |
+
)
|
| 116 |
+
wav_data = torch.from_numpy(wav_data)
|
| 117 |
+
|
| 118 |
+
return wav_data
|
| 119 |
+
|
| 120 |
+
score_lst = []
|
| 121 |
+
for fname in tqdm(os.listdir(dir)):
|
| 122 |
+
speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
|
| 123 |
+
speech = speech.to(self.device)
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
score = self.score(speech)
|
| 126 |
+
score_lst.append(score.item())
|
| 127 |
+
return np.mean(score_lst)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
|
| 131 |
+
SSL_OUT_DIM = 768
|
| 132 |
+
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
| 133 |
+
[ckpt_path]
|
| 134 |
+
)
|
| 135 |
+
ssl_model = model[0]
|
| 136 |
+
ssl_model.remove_pretraining_modules()
|
| 137 |
+
return SSL_model(ssl_model, SSL_OUT_DIM)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class BaselineLightningModule(pl.LightningModule):
|
| 141 |
+
def __init__(self, ssl_model_path):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.construct_model(ssl_model_path)
|
| 144 |
+
self.save_hyperparameters()
|
| 145 |
+
|
| 146 |
+
def construct_model(self, ssl_model_path):
|
| 147 |
+
self.feature_extractors = nn.ModuleList(
|
| 148 |
+
[
|
| 149 |
+
load_ssl_model(ckpt_path=ssl_model_path),
|
| 150 |
+
DomainEmbedding(3, 128),
|
| 151 |
+
]
|
| 152 |
+
)
|
| 153 |
+
output_dim = sum(
|
| 154 |
+
[
|
| 155 |
+
feature_extractor.get_output_dim()
|
| 156 |
+
for feature_extractor in self.feature_extractors
|
| 157 |
+
]
|
| 158 |
+
)
|
| 159 |
+
output_layers = [
|
| 160 |
+
LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
|
| 161 |
+
]
|
| 162 |
+
output_dim = output_layers[-1].get_output_dim()
|
| 163 |
+
output_layers.append(
|
| 164 |
+
Projection(
|
| 165 |
+
hidden_dim=2048,
|
| 166 |
+
activation=torch.nn.ReLU(),
|
| 167 |
+
range_clipping=False,
|
| 168 |
+
input_dim=output_dim,
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
self.output_layers = nn.ModuleList(output_layers)
|
| 173 |
+
|
| 174 |
+
def forward(self, inputs):
|
| 175 |
+
outputs = {}
|
| 176 |
+
for feature_extractor in self.feature_extractors:
|
| 177 |
+
outputs.update(feature_extractor(inputs))
|
| 178 |
+
x = outputs
|
| 179 |
+
for output_layer in self.output_layers:
|
| 180 |
+
x = output_layer(x, inputs)
|
| 181 |
+
return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class SSL_model(nn.Module):
|
| 185 |
+
def __init__(self, ssl_model, ssl_out_dim) -> None:
|
| 186 |
+
super(SSL_model, self).__init__()
|
| 187 |
+
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
|
| 188 |
+
|
| 189 |
+
def forward(self, batch):
|
| 190 |
+
wav = batch["wav"]
|
| 191 |
+
wav = wav.squeeze(1) # [batches, wav_len]
|
| 192 |
+
res = self.ssl_model(wav, mask=False, features_only=True)
|
| 193 |
+
x = res["x"]
|
| 194 |
+
return {"ssl-feature": x}
|
| 195 |
+
|
| 196 |
+
def get_output_dim(self):
|
| 197 |
+
return self.ssl_out_dim
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class DomainEmbedding(nn.Module):
|
| 201 |
+
def __init__(self, n_domains, domain_dim) -> None:
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.embedding = nn.Embedding(n_domains, domain_dim)
|
| 204 |
+
self.output_dim = domain_dim
|
| 205 |
+
|
| 206 |
+
def forward(self, batch):
|
| 207 |
+
return {"domain-feature": self.embedding(batch["domains"])}
|
| 208 |
+
|
| 209 |
+
def get_output_dim(self):
|
| 210 |
+
return self.output_dim
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class LDConditioner(nn.Module):
|
| 214 |
+
"""
|
| 215 |
+
Conditions ssl output by listener embedding
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def __init__(self, input_dim, judge_dim, num_judges=None):
|
| 219 |
+
super().__init__()
|
| 220 |
+
self.input_dim = input_dim
|
| 221 |
+
self.judge_dim = judge_dim
|
| 222 |
+
self.num_judges = num_judges
|
| 223 |
+
assert num_judges is not None
|
| 224 |
+
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
|
| 225 |
+
# concat [self.output_layer, phoneme features]
|
| 226 |
+
|
| 227 |
+
self.decoder_rnn = nn.LSTM(
|
| 228 |
+
input_size=self.input_dim + self.judge_dim,
|
| 229 |
+
hidden_size=512,
|
| 230 |
+
num_layers=1,
|
| 231 |
+
batch_first=True,
|
| 232 |
+
bidirectional=True,
|
| 233 |
+
) # linear?
|
| 234 |
+
self.out_dim = self.decoder_rnn.hidden_size * 2
|
| 235 |
+
|
| 236 |
+
def get_output_dim(self):
|
| 237 |
+
return self.out_dim
|
| 238 |
+
|
| 239 |
+
def forward(self, x, batch):
|
| 240 |
+
judge_ids = batch["judge_id"]
|
| 241 |
+
if "phoneme-feature" in x.keys():
|
| 242 |
+
concatenated_feature = torch.cat(
|
| 243 |
+
(
|
| 244 |
+
x["ssl-feature"],
|
| 245 |
+
x["phoneme-feature"]
|
| 246 |
+
.unsqueeze(1)
|
| 247 |
+
.expand(-1, x["ssl-feature"].size(1), -1),
|
| 248 |
+
),
|
| 249 |
+
dim=2,
|
| 250 |
+
)
|
| 251 |
+
else:
|
| 252 |
+
concatenated_feature = x["ssl-feature"]
|
| 253 |
+
if "domain-feature" in x.keys():
|
| 254 |
+
concatenated_feature = torch.cat(
|
| 255 |
+
(
|
| 256 |
+
concatenated_feature,
|
| 257 |
+
x["domain-feature"]
|
| 258 |
+
.unsqueeze(1)
|
| 259 |
+
.expand(-1, concatenated_feature.size(1), -1),
|
| 260 |
+
),
|
| 261 |
+
dim=2,
|
| 262 |
+
)
|
| 263 |
+
if judge_ids is not None:
|
| 264 |
+
concatenated_feature = torch.cat(
|
| 265 |
+
(
|
| 266 |
+
concatenated_feature,
|
| 267 |
+
self.judge_embedding(judge_ids)
|
| 268 |
+
.unsqueeze(1)
|
| 269 |
+
.expand(-1, concatenated_feature.size(1), -1),
|
| 270 |
+
),
|
| 271 |
+
dim=2,
|
| 272 |
+
)
|
| 273 |
+
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
|
| 274 |
+
return decoder_output
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Projection(nn.Module):
|
| 278 |
+
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
|
| 279 |
+
super(Projection, self).__init__()
|
| 280 |
+
self.range_clipping = range_clipping
|
| 281 |
+
output_dim = 1
|
| 282 |
+
if range_clipping:
|
| 283 |
+
self.proj = nn.Tanh()
|
| 284 |
+
|
| 285 |
+
self.net = nn.Sequential(
|
| 286 |
+
nn.Linear(input_dim, hidden_dim),
|
| 287 |
+
activation,
|
| 288 |
+
nn.Dropout(0.3),
|
| 289 |
+
nn.Linear(hidden_dim, output_dim),
|
| 290 |
+
)
|
| 291 |
+
self.output_dim = output_dim
|
| 292 |
+
|
| 293 |
+
def forward(self, x, batch):
|
| 294 |
+
output = self.net(x)
|
| 295 |
+
|
| 296 |
+
# range clipping
|
| 297 |
+
if self.range_clipping:
|
| 298 |
+
return self.proj(output) * 2.0 + 3
|
| 299 |
+
else:
|
| 300 |
+
return output
|
| 301 |
+
|
| 302 |
+
def get_output_dim(self):
|
| 303 |
+
return self.output_dim
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if __name__ == "__main__":
|
| 307 |
+
parser = get_parser()
|
| 308 |
+
args = parser.parse_args()
|
| 309 |
+
UTMOS = UTMOSScore(
|
| 310 |
+
utmos_model_path=args.utmos_model_path,
|
| 311 |
+
ssl_model_path=args.ssl_model_path,
|
| 312 |
+
)
|
| 313 |
+
score = UTMOS.score_dir(args.wav_path)
|
| 314 |
+
logging.info(f"UTMOS score: {score:.2f}")
|
zipvoice/eval/evaluate_wer_hubert.py
ADDED
|
@@ -0,0 +1,192 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
|
| 3 |
+
# Wei Kang)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Calculate WER with Hubert models.
|
| 22 |
+
"""
|
| 23 |
+
import argparse
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import librosa
|
| 29 |
+
import numpy as np
|
| 30 |
+
import soundfile as sf
|
| 31 |
+
import torch
|
| 32 |
+
from jiwer import compute_measures
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
from transformers import pipeline
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_parser():
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
|
| 40 |
+
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--decode-path",
|
| 43 |
+
type=str,
|
| 44 |
+
default=None,
|
| 45 |
+
help="path of the output file of WER information",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--model-path",
|
| 49 |
+
type=str,
|
| 50 |
+
default=None,
|
| 51 |
+
help="path of the local hubert model, e.g., "
|
| 52 |
+
"model/huggingface/hubert-large-ls960-ft",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--test-list",
|
| 56 |
+
type=str,
|
| 57 |
+
default="test.tsv",
|
| 58 |
+
help="path of the transcript tsv file, where the first column "
|
| 59 |
+
"is the wav name and the last column is the transcript",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--batch-size", type=int, default=16, help="decoding batch size"
|
| 63 |
+
)
|
| 64 |
+
return parser
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def post_process(text: str):
|
| 68 |
+
text = text.replace("‘", "'")
|
| 69 |
+
text = text.replace("’", "'")
|
| 70 |
+
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
| 71 |
+
text = re.sub(r"\s+", " ", text)
|
| 72 |
+
text = text.strip()
|
| 73 |
+
return text
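
A quick example of this normalization (assuming post_process from this file is in scope):

print(post_process("It's 3 o'clock -- let's GO!"))
# -> "it's 3 o'clock let's go"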
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def process_one(hypo, truth):
|
| 77 |
+
truth = post_process(truth)
|
| 78 |
+
hypo = post_process(hypo)
|
| 79 |
+
|
| 80 |
+
measures = compute_measures(truth, hypo)
|
| 81 |
+
word_num = len(truth.split(" "))
|
| 82 |
+
wer = measures["wer"]
|
| 83 |
+
subs = measures["substitutions"]
|
| 84 |
+
dele = measures["deletions"]
|
| 85 |
+
inse = measures["insertions"]
|
| 86 |
+
return (truth, hypo, wer, subs, dele, inse, word_num)
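
For instance, a hypothesis with one substituted word against a four-word reference gives WER 0.25; a sketch, assuming process_one from this file is in scope and jiwer is installed:

truth, hypo, wer, subs, dele, inse, n = process_one(
    hypo="the cat sat down", truth="the cat sat up"
)
print(wer, subs, dele, inse, n)  # 0.25 1 0 0 4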
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class SpeechEvalDataset(torch.utils.data.Dataset):
|
| 90 |
+
def __init__(self, wav_path: str, test_list: str):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.wav_name = []
|
| 93 |
+
self.wav_paths = []
|
| 94 |
+
self.transcripts = []
|
| 95 |
+
with Path(test_list).open("r", encoding="utf8") as f:
|
| 96 |
+
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
| 97 |
+
for item in meta:
|
| 98 |
+
self.wav_name.append(item[0])
|
| 99 |
+
self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
|
| 100 |
+
self.transcripts.append(item[-1])
|
| 101 |
+
|
| 102 |
+
def __len__(self):
|
| 103 |
+
return len(self.wav_paths)
|
| 104 |
+
|
| 105 |
+
def __getitem__(self, index: int):
|
| 106 |
+
wav, sampling_rate = sf.read(self.wav_paths[index])
|
| 107 |
+
item = {
|
| 108 |
+
"array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
|
| 109 |
+
"sampling_rate": 16000,
|
| 110 |
+
"reference": self.transcripts[index],
|
| 111 |
+
"wav_name": self.wav_name[index],
|
| 112 |
+
}
|
| 113 |
+
return item
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def main(test_list, wav_path, model_path, decode_path, batch_size, device):
|
| 117 |
+
|
| 118 |
+
if model_path is not None:
|
| 119 |
+
pipe = pipeline(
|
| 120 |
+
"automatic-speech-recognition",
|
| 121 |
+
model=model_path,
|
| 122 |
+
device=device,
|
| 123 |
+
tokenizer=model_path,
|
| 124 |
+
)
|
| 125 |
+
else:
|
| 126 |
+
pipe = pipeline(
|
| 127 |
+
"automatic-speech-recognition",
|
| 128 |
+
model="facebook/hubert-large-ls960-ft",
|
| 129 |
+
device=device,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
dataset = SpeechEvalDataset(wav_path, test_list)
|
| 133 |
+
|
| 134 |
+
bar = tqdm(
|
| 135 |
+
pipe(
|
| 136 |
+
dataset,
|
| 137 |
+
generate_kwargs={"language": "english", "task": "transcribe"},
|
| 138 |
+
batch_size=batch_size,
|
| 139 |
+
),
|
| 140 |
+
total=len(dataset),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
wers = []
|
| 144 |
+
inses = []
|
| 145 |
+
deles = []
|
| 146 |
+
subses = []
|
| 147 |
+
word_nums = 0
|
| 148 |
+
if decode_path:
|
| 149 |
+
decode_dir = os.path.dirname(decode_path)
|
| 150 |
+
if not os.path.exists(decode_dir):
|
| 151 |
+
os.makedirs(decode_dir)
|
| 152 |
+
fout = open(decode_path, "w")
|
| 153 |
+
for out in bar:
|
| 154 |
+
wav_name = out["wav_name"][0]
|
| 155 |
+
transcription = post_process(out["text"].strip())
|
| 156 |
+
text_ref = post_process(out["reference"][0].strip())
|
| 157 |
+
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
| 158 |
+
transcription, text_ref
|
| 159 |
+
)
|
| 160 |
+
if decode_path:
|
| 161 |
+
fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
| 162 |
+
wers.append(float(wer))
|
| 163 |
+
inses.append(float(inse))
|
| 164 |
+
deles.append(float(dele))
|
| 165 |
+
subses.append(float(subs))
|
| 166 |
+
word_nums += word_num
|
| 167 |
+
|
| 168 |
+
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
| 169 |
+
subs = round(np.mean(subses) * 100, 3)
|
| 170 |
+
dele = round(np.mean(deles) * 100, 3)
|
| 171 |
+
inse = round(np.mean(inses) * 100, 3)
|
| 172 |
+
print(f"WER: {wer}%\n")
|
| 173 |
+
if decode_path:
|
| 174 |
+
fout.write(f"WER: {wer}%\n")
|
| 175 |
+
fout.flush()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
parser = get_parser()
|
| 180 |
+
args = parser.parse_args()
|
| 181 |
+
if torch.cuda.is_available():
|
| 182 |
+
device = torch.device("cuda", 0)
|
| 183 |
+
else:
|
| 184 |
+
device = torch.device("cpu")
|
| 185 |
+
main(
|
| 186 |
+
args.test_list,
|
| 187 |
+
args.wav_path,
|
| 188 |
+
args.model_path,
|
| 189 |
+
args.decode_path,
|
| 190 |
+
args.batch_size,
|
| 191 |
+
device,
|
| 192 |
+
)
|
zipvoice/eval/evaluate_wer_seedtts.py
ADDED
|
@@ -0,0 +1,200 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu
|
| 3 |
+
# Wei Kang)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Calculate WER with Whisper-large-v3 or Paraformer models,
|
| 22 |
+
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import os
|
| 27 |
+
import string
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import scipy
|
| 31 |
+
import soundfile as sf
|
| 32 |
+
import torch
|
| 33 |
+
import zhconv
|
| 34 |
+
from funasr import AutoModel
|
| 35 |
+
from jiwer import compute_measures
|
| 36 |
+
from tqdm import tqdm
|
| 37 |
+
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
| 38 |
+
from zhon.hanzi import punctuation
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_parser():
|
| 42 |
+
parser = argparse.ArgumentParser()
|
| 43 |
+
|
| 44 |
+
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--decode-path",
|
| 47 |
+
type=str,
|
| 48 |
+
default=None,
|
| 49 |
+
help="path of the output file of WER information",
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--model-path",
|
| 53 |
+
type=str,
|
| 54 |
+
default=None,
|
| 55 |
+
help="path of the local whisper and paraformer model, "
|
| 56 |
+
"e.g., whisper: model/huggingface/whisper-large-v3/, "
|
| 57 |
+
"paraformer: model/huggingface/paraformer-zh/",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--test-list",
|
| 61 |
+
type=str,
|
| 62 |
+
default="test.tsv",
|
| 63 |
+
help="path of the transcript tsv file, where the first column "
|
| 64 |
+
"is the wav name and the last column is the transcript",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument("--lang", type=str, help="decoded language, zh or en")
|
| 67 |
+
return parser
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_en_model(model_path):
|
| 71 |
+
if model_path is None:
|
| 72 |
+
model_path = "openai/whisper-large-v3"
|
| 73 |
+
processor = WhisperProcessor.from_pretrained(model_path)
|
| 74 |
+
model = WhisperForConditionalGeneration.from_pretrained(model_path)
|
| 75 |
+
return processor, model
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_zh_model(model_path):
|
| 79 |
+
if model_path is None:
|
| 80 |
+
model_path = "paraformer-zh"
|
| 81 |
+
model = AutoModel(model=model_path)
|
| 82 |
+
return model
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def process_one(hypo, truth, lang):
|
| 86 |
+
punctuation_all = punctuation + string.punctuation
|
| 87 |
+
for x in punctuation_all:
|
| 88 |
+
if x == "'":
|
| 89 |
+
continue
|
| 90 |
+
truth = truth.replace(x, "")
|
| 91 |
+
hypo = hypo.replace(x, "")
|
| 92 |
+
|
| 93 |
+
truth = truth.replace("  ", " ")
|
| 94 |
+
hypo = hypo.replace("  ", " ")
|
| 95 |
+
|
| 96 |
+
if lang == "zh":
|
| 97 |
+
truth = " ".join([x for x in truth])
|
| 98 |
+
hypo = " ".join([x for x in hypo])
|
| 99 |
+
elif lang == "en":
|
| 100 |
+
truth = truth.lower()
|
| 101 |
+
hypo = hypo.lower()
|
| 102 |
+
else:
|
| 103 |
+
raise NotImplementedError
|
| 104 |
+
|
| 105 |
+
measures = compute_measures(truth, hypo)
|
| 106 |
+
word_num = len(truth.split(" "))
|
| 107 |
+
wer = measures["wer"]
|
| 108 |
+
subs = measures["substitutions"]
|
| 109 |
+
dele = measures["deletions"]
|
| 110 |
+
inse = measures["insertions"]
|
| 111 |
+
return (truth, hypo, wer, subs, dele, inse, word_num)
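
For Chinese, the text is split into characters before scoring, so the metric is effectively a character error rate; a sketch, assuming process_one from this file is in scope and jiwer is installed:

out = process_one(hypo="今天天气不错", truth="今天天气很好", lang="zh")
print(out[2], out[-1])  # about 0.333 and 6: two substituted characters out of six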
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def main(test_list, wav_path, model_path, decode_path, lang, device):
|
| 115 |
+
if lang == "en":
|
| 116 |
+
processor, model = load_en_model(model_path)
|
| 117 |
+
model.to(device)
|
| 118 |
+
elif lang == "zh":
|
| 119 |
+
model = load_zh_model(model_path)
|
| 120 |
+
params = []
|
| 121 |
+
for line in open(test_list).readlines():
|
| 122 |
+
line = line.strip()
|
| 123 |
+
items = line.split("\t")
|
| 124 |
+
wav_name, text_ref = items[0], items[-1]
|
| 125 |
+
file_path = os.path.join(wav_path, wav_name + ".wav")
|
| 126 |
+
assert os.path.exists(file_path), f"{file_path}"
|
| 127 |
+
|
| 128 |
+
params.append((file_path, text_ref))
|
| 129 |
+
wers = []
|
| 130 |
+
inses = []
|
| 131 |
+
deles = []
|
| 132 |
+
subses = []
|
| 133 |
+
word_nums = 0
|
| 134 |
+
if decode_path:
|
| 135 |
+
decode_dir = os.path.dirname(decode_path)
|
| 136 |
+
if not os.path.exists(decode_dir):
|
| 137 |
+
os.makedirs(decode_dir)
|
| 138 |
+
fout = open(decode_path, "w")
|
| 139 |
+
for wav_path, text_ref in tqdm(params):
|
| 140 |
+
if lang == "en":
|
| 141 |
+
wav, sr = sf.read(wav_path)
|
| 142 |
+
if sr != 16000:
|
| 143 |
+
wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
|
| 144 |
+
input_features = processor(
|
| 145 |
+
wav, sampling_rate=16000, return_tensors="pt"
|
| 146 |
+
).input_features
|
| 147 |
+
input_features = input_features.to(device)
|
| 148 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(
|
| 149 |
+
language="english", task="transcribe"
|
| 150 |
+
)
|
| 151 |
+
predicted_ids = model.generate(
|
| 152 |
+
input_features, forced_decoder_ids=forced_decoder_ids
|
| 153 |
+
)
|
| 154 |
+
transcription = processor.batch_decode(
|
| 155 |
+
predicted_ids, skip_special_tokens=True
|
| 156 |
+
)[0]
|
| 157 |
+
elif lang == "zh":
|
| 158 |
+
res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
|
| 159 |
+
transcription = res[0]["text"]
|
| 160 |
+
transcription = zhconv.convert(transcription, "zh-cn")
|
| 161 |
+
|
| 162 |
+
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
| 163 |
+
transcription, text_ref, lang
|
| 164 |
+
)
|
| 165 |
+
if decode_path:
|
| 166 |
+
fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
| 167 |
+
wers.append(float(wer))
|
| 168 |
+
inses.append(float(inse))
|
| 169 |
+
deles.append(float(dele))
|
| 170 |
+
subses.append(float(subs))
|
| 171 |
+
word_nums += word_num
|
| 172 |
+
|
| 173 |
+
wer_avg = round(np.mean(wers) * 100, 3)
|
| 174 |
+
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
| 175 |
+
subs = round(np.mean(subses) * 100, 3)
|
| 176 |
+
dele = round(np.mean(deles) * 100, 3)
|
| 177 |
+
inse = round(np.mean(inses) * 100, 3)
|
| 178 |
+
print(f"Seed-TTS WER: {wer_avg}%\n")
|
| 179 |
+
print(f"WER: {wer}%\n")
|
| 180 |
+
if decode_path:
|
| 181 |
+
fout.write(f"SeedTTS WER: {wer_avg}%\n")
|
| 182 |
+
fout.write(f"WER: {wer}%\n")
|
| 183 |
+
fout.flush()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
parser = get_parser()
|
| 188 |
+
args = parser.parse_args()
|
| 189 |
+
if torch.cuda.is_available():
|
| 190 |
+
device = torch.device("cuda", 0)
|
| 191 |
+
else:
|
| 192 |
+
device = torch.device("cpu")
|
| 193 |
+
main(
|
| 194 |
+
args.test_list,
|
| 195 |
+
args.wav_path,
|
| 196 |
+
args.model_path,
|
| 197 |
+
args.decode_path,
|
| 198 |
+
args.lang,
|
| 199 |
+
device,
|
| 200 |
+
)
|
zipvoice/models/modules/scaling.py
ADDED
|
@@ -0,0 +1,1563 @@
|
| 1 |
+
# Copyright 2022-2025 Xiaomi Corp. (authors: Daniel Povey
|
| 2 |
+
# Wei Kang)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import random
|
| 22 |
+
import sys
|
| 23 |
+
from typing import Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import k2
|
| 27 |
+
except Exception as e:
|
| 28 |
+
logging.warning(
|
| 29 |
+
f"Failed import k2 with error {e}. Swoosh functions will fallback to PyTorch"
|
| 30 |
+
f" implementation, leading to slower speed and higher memory consumption."
|
| 31 |
+
)
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
from torch import Tensor
|
| 35 |
+
|
| 36 |
+
custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda")
|
| 37 |
+
custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
| 41 |
+
max_value = torch.max(x, y)
|
| 42 |
+
diff = torch.abs(x - y)
|
| 43 |
+
return max_value + torch.log1p(torch.exp(-diff))
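
A quick numerical check that this ONNX-friendly form matches torch.logaddexp, using the identity log(e^x + e^y) = max(x, y) + log1p(e^(-|x - y|)):

import torch
x = torch.tensor([0.0, 5.0, -3.0])
y = torch.tensor([1.0, 5.0, 10.0])
print(torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y)))  # True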
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
| 47 |
+
# 14 is not supported. Please feel free to request support or submit
|
| 48 |
+
# a pull request on PyTorch GitHub.
|
| 49 |
+
#
|
| 50 |
+
# The following function is to solve the above error when exporting
|
| 51 |
+
# models to ONNX via torch.jit.trace()
|
| 52 |
+
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
| 53 |
+
# Caution(fangjun): Put torch.jit.is_scripting() before
|
| 54 |
+
# torch.onnx.is_in_onnx_export();
|
| 55 |
+
# otherwise, it will cause errors for torch.jit.script().
|
| 56 |
+
#
|
| 57 |
+
# torch.logaddexp() works for both torch.jit.script() and
|
| 58 |
+
# torch.jit.trace() but it causes errors for ONNX export.
|
| 59 |
+
#
|
| 60 |
+
if torch.jit.is_scripting():
|
| 61 |
+
# Note: We cannot use torch.jit.is_tracing() here as it also
|
| 62 |
+
# matches torch.onnx.export().
|
| 63 |
+
return torch.logaddexp(x, y)
|
| 64 |
+
elif torch.onnx.is_in_onnx_export():
|
| 65 |
+
return logaddexp_onnx(x, y)
|
| 66 |
+
else:
|
| 67 |
+
# for torch.jit.trace()
|
| 68 |
+
return torch.logaddexp(x, y)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class PiecewiseLinear(object):
|
| 72 |
+
"""
|
| 73 |
+
Piecewise linear function, from float to float, specified as nonempty list of (x,y)
|
| 74 |
+
pairs with the x values in order. x values <[initial x] or >[final x] are mapped to
|
| 75 |
+
[initial y], [final y] respectively.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, *args):
|
| 79 |
+
assert len(args) >= 1, len(args)
|
| 80 |
+
if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
|
| 81 |
+
self.pairs = list(args[0].pairs)
|
| 82 |
+
else:
|
| 83 |
+
self.pairs = [(float(x), float(y)) for x, y in args]
|
| 84 |
+
for x, y in self.pairs:
|
| 85 |
+
assert isinstance(x, (float, int)), type(x)
|
| 86 |
+
assert isinstance(y, (float, int)), type(y)
|
| 87 |
+
|
| 88 |
+
for i in range(len(self.pairs) - 1):
|
| 89 |
+
assert self.pairs[i + 1][0] > self.pairs[i][0], (
|
| 90 |
+
i,
|
| 91 |
+
self.pairs[i],
|
| 92 |
+
self.pairs[i + 1],
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def __str__(self):
|
| 96 |
+
# e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
|
| 97 |
+
return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
|
| 98 |
+
|
| 99 |
+
def __call__(self, x):
|
| 100 |
+
if x <= self.pairs[0][0]:
|
| 101 |
+
return self.pairs[0][1]
|
| 102 |
+
elif x >= self.pairs[-1][0]:
|
| 103 |
+
return self.pairs[-1][1]
|
| 104 |
+
else:
|
| 105 |
+
cur_x, cur_y = self.pairs[0]
|
| 106 |
+
for i in range(1, len(self.pairs)):
|
| 107 |
+
next_x, next_y = self.pairs[i]
|
| 108 |
+
if x >= cur_x and x <= next_x:
|
| 109 |
+
return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
|
| 110 |
+
cur_x, cur_y = next_x, next_y
|
| 111 |
+
assert False
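
For example, a schedule that decays linearly from 10 at x=0 to 0 at x=100 behaves as follows (a usage sketch):

sched = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
print(sched(-5))   # 10.0, clamped to the first y value
print(sched(25))   # 7.5, linear interpolation
print(sched(200))  # 0.0, clamped to the last y value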
|
| 112 |
+
|
| 113 |
+
def __mul__(self, alpha):
|
| 114 |
+
return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
|
| 115 |
+
|
| 116 |
+
def __add__(self, x):
|
| 117 |
+
if isinstance(x, (float, int)):
|
| 118 |
+
return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
|
| 119 |
+
s, x = self.get_common_basis(x)
|
| 120 |
+
return PiecewiseLinear(
|
| 121 |
+
*[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def max(self, x):
|
| 125 |
+
if isinstance(x, (float, int)):
|
| 126 |
+
x = PiecewiseLinear((0, x))
|
| 127 |
+
s, x = self.get_common_basis(x, include_crossings=True)
|
| 128 |
+
return PiecewiseLinear(
|
| 129 |
+
*[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def min(self, x):
|
| 133 |
+
if isinstance(x, float) or isinstance(x, int):
|
| 134 |
+
x = PiecewiseLinear((0, x))
|
| 135 |
+
s, x = self.get_common_basis(x, include_crossings=True)
|
| 136 |
+
return PiecewiseLinear(
|
| 137 |
+
*[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def __eq__(self, other):
|
| 141 |
+
return self.pairs == other.pairs
|
| 142 |
+
|
| 143 |
+
def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
|
| 144 |
+
"""
|
| 145 |
+
Returns (self_mod, p_mod) which are equivalent piecewise linear
|
| 146 |
+
functions to self and p, but with the same x values.
|
| 147 |
+
|
| 148 |
+
p: the other piecewise linear function
|
| 149 |
+
include_crossings: if true, include in the x values positions
|
| 150 |
+
where the functions indicated by this and p cross.
|
| 151 |
+
"""
|
| 152 |
+
assert isinstance(p, PiecewiseLinear), type(p)
|
| 153 |
+
|
| 154 |
+
# get sorted x-values without repetition.
|
| 155 |
+
x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
|
| 156 |
+
y_vals1 = [self(x) for x in x_vals]
|
| 157 |
+
y_vals2 = [p(x) for x in x_vals]
|
| 158 |
+
|
| 159 |
+
if include_crossings:
|
| 160 |
+
extra_x_vals = []
|
| 161 |
+
for i in range(len(x_vals) - 1):
|
| 162 |
+
if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
|
| 163 |
+
# if the two lines in this subsegment potentially cross each other..
|
| 164 |
+
diff_cur = abs(y_vals1[i] - y_vals2[i])
|
| 165 |
+
diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
|
| 166 |
+
# `pos`, between 0 and 1, gives the relative x position,
|
| 167 |
+
# with 0 being x_vals[i] and 1 being x_vals[i+1].
|
| 168 |
+
pos = diff_cur / (diff_cur + diff_next)
|
| 169 |
+
extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
|
| 170 |
+
extra_x_vals.append(extra_x_val)
|
| 171 |
+
if len(extra_x_vals) > 0:
|
| 172 |
+
x_vals = sorted(set(x_vals + extra_x_vals))
|
| 173 |
+
y_vals1 = [self(x) for x in x_vals]
|
| 174 |
+
y_vals2 = [p(x) for x in x_vals]
|
| 175 |
+
return (
|
| 176 |
+
PiecewiseLinear(*zip(x_vals, y_vals1)),
|
| 177 |
+
PiecewiseLinear(*zip(x_vals, y_vals2)),
|
| 178 |
+
)
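

# A minimal usage sketch of PiecewiseLinear as defined above; the
# _example_piecewise_linear helper is illustrative only and not part of the
# upstream API.
def _example_piecewise_linear() -> None:
    p = PiecewiseLinear((0.0, 1.0), (1000.0, 0.1))
    assert p(-10) == 1.0                  # before the first x: first y value
    assert abs(p(500) - 0.55) < 1.0e-06   # linear interpolation in between
    assert p(2000) == 0.1                 # after the last x: last y value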
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class ScheduledFloat(torch.nn.Module):
|
| 182 |
+
"""
|
| 183 |
+
This object is a torch.nn.Module only because we want it to show up in
|
| 184 |
+
[top_level module].modules(); it does not have a working forward() function.
|
| 185 |
+
You are supposed to cast it to float, as in, float(parent_module.whatever), and use
|
| 186 |
+
it as something like a dropout prob.
|
| 187 |
+
|
| 188 |
+
It is a floating point value whose value changes depending on the batch count of the
|
| 189 |
+
training loop. It is a piecewise linear function where you specify the (x,y) pairs
|
| 190 |
+
in sorted order on x; x corresponds to the batch index. For batch-index values
|
| 191 |
+
before the first x or after the last x, we just use the first or last y value.
|
| 192 |
+
|
| 193 |
+
Example:
|
| 194 |
+
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
|
| 195 |
+
|
| 196 |
+
`default` is used when self.batch_count is not set or not in training mode or in
|
| 197 |
+
torch.jit scripting mode.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, *args, default: float = 0.0):
|
| 201 |
+
super().__init__()
|
| 202 |
+
# self.batch_count and self.name will be written to in the training loop.
|
| 203 |
+
self.batch_count = None
|
| 204 |
+
self.name = None
|
| 205 |
+
self.default = default
|
| 206 |
+
self.schedule = PiecewiseLinear(*args)
|
| 207 |
+
|
| 208 |
+
def extra_repr(self) -> str:
|
| 209 |
+
return (
|
| 210 |
+
f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def __float__(self):
|
| 214 |
+
batch_count = self.batch_count
|
| 215 |
+
if (
|
| 216 |
+
batch_count is None
|
| 217 |
+
or not self.training
|
| 218 |
+
or torch.jit.is_scripting()
|
| 219 |
+
or torch.jit.is_tracing()
|
| 220 |
+
):
|
| 221 |
+
return float(self.default)
|
| 222 |
+
else:
|
| 223 |
+
ans = self.schedule(self.batch_count)
|
| 224 |
+
if random.random() < 0.0002:
|
| 225 |
+
logging.debug(
|
| 226 |
+
f"ScheduledFloat: name={self.name}, "
|
| 227 |
+
f"batch_count={self.batch_count}, ans={ans}"
|
| 228 |
+
)
|
| 229 |
+
return ans
|
| 230 |
+
|
| 231 |
+
def __add__(self, x):
|
| 232 |
+
if isinstance(x, float) or isinstance(x, int):
|
| 233 |
+
return ScheduledFloat(self.schedule + x, default=self.default)
|
| 234 |
+
else:
|
| 235 |
+
return ScheduledFloat(
|
| 236 |
+
self.schedule + x.schedule, default=self.default + x.default
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def max(self, x):
|
| 240 |
+
if isinstance(x, float) or isinstance(x, int):
|
| 241 |
+
return ScheduledFloat(self.schedule.max(x), default=self.default)
|
| 242 |
+
else:
|
| 243 |
+
return ScheduledFloat(
|
| 244 |
+
self.schedule.max(x.schedule),
|
| 245 |
+
default=max(self.default, x.default),
|
| 246 |
+
)
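

# A minimal usage sketch of ScheduledFloat as defined above; the training loop
# is expected to write `batch_count` into each ScheduledFloat it finds via
# .modules(). The _example_scheduled_float helper is illustrative only.
def _example_scheduled_float() -> None:
    dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1), default=0.0)
    dropout.batch_count = 10000.0         # normally written by the training loop
    assert abs(float(dropout) - 0.2) < 1.0e-06   # halfway along the schedule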
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
FloatLike = Union[float, ScheduledFloat]
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class CutoffEstimator:
|
| 253 |
+
"""
|
| 254 |
+
Estimates cutoffs of an arbitrary numerical quantity such that a specified
|
| 255 |
+
proportion of items will be above the cutoff on average.
|
| 256 |
+
|
| 257 |
+
p is the proportion of items that should be above the cutoff.
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, p: float):
|
| 261 |
+
self.p = p
|
| 262 |
+
# total count of items
|
| 263 |
+
self.count = 0
|
| 264 |
+
# total count of items that were above the cutoff
|
| 265 |
+
self.count_above = 0
|
| 266 |
+
# initial cutoff value
|
| 267 |
+
self.cutoff = 0
|
| 268 |
+
|
| 269 |
+
def __call__(self, x: float) -> bool:
|
| 270 |
+
"""
|
| 271 |
+
Returns true if x is above the cutoff.
|
| 272 |
+
"""
|
| 273 |
+
ans = x > self.cutoff
|
| 274 |
+
self.count += 1
|
| 275 |
+
if ans:
|
| 276 |
+
self.count_above += 1
|
| 277 |
+
cur_p = self.count_above / self.count
|
| 278 |
+
delta_p = cur_p - self.p
|
| 279 |
+
if (delta_p > 0) == ans:
|
| 280 |
+
q = abs(delta_p)
|
| 281 |
+
self.cutoff = x * q + self.cutoff * (1 - q)
|
| 282 |
+
return ans
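

# A minimal usage sketch of CutoffEstimator as defined above (illustrative
# only): the cutoff adapts online so that roughly the proportion `p` of the
# observed values ends up above it.
def _example_cutoff_estimator() -> None:
    est = CutoffEstimator(p=0.25)
    num_above = sum(est(random.random()) for _ in range(10000))
    # num_above should be roughly 2500; the exact value depends on the RNG.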
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class SoftmaxFunction(torch.autograd.Function):
|
| 286 |
+
"""
|
| 287 |
+
Tries to handle half-precision derivatives in a randomized way that should
|
| 288 |
+
be more accurate for training than the default behavior.
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
@staticmethod
|
| 292 |
+
def forward(ctx, x: Tensor, dim: int):
|
| 293 |
+
ans = x.softmax(dim=dim)
|
| 294 |
+
# if x dtype is float16, x.softmax() returns a float32 because
|
| 295 |
+
# (presumably) that op does not support float16, and autocast
|
| 296 |
+
# is enabled.
|
| 297 |
+
if torch.is_autocast_enabled():
|
| 298 |
+
ans = ans.to(torch.float16)
|
| 299 |
+
ctx.save_for_backward(ans)
|
| 300 |
+
ctx.x_dtype = x.dtype
|
| 301 |
+
ctx.dim = dim
|
| 302 |
+
return ans
|
| 303 |
+
|
| 304 |
+
@staticmethod
|
| 305 |
+
def backward(ctx, ans_grad: Tensor):
|
| 306 |
+
(ans,) = ctx.saved_tensors
|
| 307 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 308 |
+
ans_grad = ans_grad.to(torch.float32)
|
| 309 |
+
ans = ans.to(torch.float32)
|
| 310 |
+
x_grad = ans_grad * ans
|
| 311 |
+
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
| 312 |
+
return x_grad, None
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def softmax(x: Tensor, dim: int):
|
| 316 |
+
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 317 |
+
return x.softmax(dim=dim)
|
| 318 |
+
|
| 319 |
+
return SoftmaxFunction.apply(x, dim)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class BiasNormFunction(torch.autograd.Function):
|
| 323 |
+
# This computes:
|
| 324 |
+
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
|
| 325 |
+
# return x * scales
|
| 326 |
+
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
|
| 327 |
+
# it can just store the returned value (chances are, this will also be needed for
|
| 328 |
+
# some other reason, related to the next operation, so we can save memory).
|
| 329 |
+
@staticmethod
|
| 330 |
+
def forward(
|
| 331 |
+
ctx,
|
| 332 |
+
x: Tensor,
|
| 333 |
+
bias: Tensor,
|
| 334 |
+
log_scale: Tensor,
|
| 335 |
+
channel_dim: int,
|
| 336 |
+
store_output_for_backprop: bool,
|
| 337 |
+
) -> Tensor:
|
| 338 |
+
assert bias.ndim == 1
|
| 339 |
+
if channel_dim < 0:
|
| 340 |
+
channel_dim = channel_dim + x.ndim
|
| 341 |
+
ctx.store_output_for_backprop = store_output_for_backprop
|
| 342 |
+
ctx.channel_dim = channel_dim
|
| 343 |
+
for _ in range(channel_dim + 1, x.ndim):
|
| 344 |
+
bias = bias.unsqueeze(-1)
|
| 345 |
+
scales = (
|
| 346 |
+
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
|
| 347 |
+
) * log_scale.exp()
|
| 348 |
+
ans = x * scales
|
| 349 |
+
ctx.save_for_backward(
|
| 350 |
+
ans.detach() if store_output_for_backprop else x,
|
| 351 |
+
scales.detach(),
|
| 352 |
+
bias.detach(),
|
| 353 |
+
log_scale.detach(),
|
| 354 |
+
)
|
| 355 |
+
return ans
|
| 356 |
+
|
| 357 |
+
@staticmethod
|
| 358 |
+
def backward(ctx, ans_grad: Tensor) -> Tensor:
|
| 359 |
+
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
|
| 360 |
+
if ctx.store_output_for_backprop:
|
| 361 |
+
x = ans_or_x / scales
|
| 362 |
+
else:
|
| 363 |
+
x = ans_or_x
|
| 364 |
+
x = x.detach()
|
| 365 |
+
x.requires_grad = True
|
| 366 |
+
bias.requires_grad = True
|
| 367 |
+
log_scale.requires_grad = True
|
| 368 |
+
with torch.enable_grad():
|
| 369 |
+
# recompute scales from x, bias and log_scale.
|
| 370 |
+
scales = (
|
| 371 |
+
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
|
| 372 |
+
) * log_scale.exp()
|
| 373 |
+
ans = x * scales
|
| 374 |
+
ans.backward(gradient=ans_grad)
|
| 375 |
+
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class BiasNorm(torch.nn.Module):
|
| 379 |
+
"""
|
| 380 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
| 381 |
+
LayerNorm. The observation this is based on, is that Transformer-type
|
| 382 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
| 383 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
| 384 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
| 385 |
+
on the other (useful) features. Presumably the weight and bias of the
|
| 386 |
+
LayerNorm are required to allow it to do this.
|
| 387 |
+
|
| 388 |
+
Instead, we give the BiasNorm a trainable bias that it can use when
|
| 389 |
+
computing the scale for normalization. We also give it a (scalar)
|
| 390 |
+
trainable scale on the output.
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
num_channels: the number of channels, e.g. 512.
|
| 395 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
| 396 |
+
interpreted as an offset from the input's ndim if negative.
|
| 397 |
+
This is NOT the num_channels; it should typically be one of
|
| 398 |
+
{-2, -1, 0, 1, 2, 3}.
|
| 399 |
+
log_scale: the initial log-scale that we multiply the output by; this
|
| 400 |
+
is learnable.
|
| 401 |
+
log_scale_min: FloatLike, minimum allowed value of log_scale
|
| 402 |
+
log_scale_max: FloatLike, maximum allowed value of log_scale
|
| 403 |
+
store_output_for_backprop: only possibly affects memory use; recommend
|
| 404 |
+
to set to True if you think the output of this module is more likely
|
| 405 |
+
than the input of this module to be required to be stored for the
|
| 406 |
+
backprop.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
def __init__(
|
| 410 |
+
self,
|
| 411 |
+
num_channels: int,
|
| 412 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
| 413 |
+
log_scale: float = 1.0,
|
| 414 |
+
log_scale_min: float = -1.5,
|
| 415 |
+
log_scale_max: float = 1.5,
|
| 416 |
+
store_output_for_backprop: bool = False,
|
| 417 |
+
) -> None:
|
| 418 |
+
super(BiasNorm, self).__init__()
|
| 419 |
+
self.num_channels = num_channels
|
| 420 |
+
self.channel_dim = channel_dim
|
| 421 |
+
self.log_scale = nn.Parameter(torch.tensor(log_scale))
|
| 422 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
| 423 |
+
|
| 424 |
+
self.log_scale_min = log_scale_min
|
| 425 |
+
self.log_scale_max = log_scale_max
|
| 426 |
+
|
| 427 |
+
self.store_output_for_backprop = store_output_for_backprop
|
| 428 |
+
|
| 429 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 430 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
| 431 |
+
|
| 432 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 433 |
+
channel_dim = self.channel_dim
|
| 434 |
+
if channel_dim < 0:
|
| 435 |
+
channel_dim += x.ndim
|
| 436 |
+
bias = self.bias
|
| 437 |
+
for _ in range(channel_dim + 1, x.ndim):
|
| 438 |
+
bias = bias.unsqueeze(-1)
|
| 439 |
+
scales = (
|
| 440 |
+
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
|
| 441 |
+
) * self.log_scale.exp()
|
| 442 |
+
return x * scales
|
| 443 |
+
|
| 444 |
+
log_scale = limit_param_value(
|
| 445 |
+
self.log_scale,
|
| 446 |
+
min=float(self.log_scale_min),
|
| 447 |
+
max=float(self.log_scale_max),
|
| 448 |
+
training=self.training,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
return BiasNormFunction.apply(
|
| 452 |
+
x,
|
| 453 |
+
self.bias,
|
| 454 |
+
log_scale,
|
| 455 |
+
self.channel_dim,
|
| 456 |
+
self.store_output_for_backprop,
|
| 457 |
+
)
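

# A minimal usage sketch of BiasNorm as defined above (illustrative only): it
# normalizes over the channel dimension like LayerNorm, but its only
# parameters are the bias used inside the scale computation and a scalar
# log-scale.
def _example_bias_norm() -> None:
    norm = BiasNorm(num_channels=256)     # channel_dim defaults to -1
    x = torch.randn(4, 100, 256)
    y = norm(x)
    assert y.shape == x.shape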
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
| 461 |
+
"""
|
| 462 |
+
Behaves like a constructor of a modified version of nn.Linear
|
| 463 |
+
that gives an easy way to set the default initial parameter scale.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
| 467 |
+
e.g. in_features, out_features, bias=False.
|
| 468 |
+
|
| 469 |
+
initial_scale: you can override this if you want to increase
|
| 470 |
+
or decrease the initial magnitude of the module's output
|
| 471 |
+
(affects the initialization of weight_scale and bias_scale).
|
| 472 |
+
Another option, if you want to do something like this, is
|
| 473 |
+
to re-initialize the parameters.
|
| 474 |
+
"""
|
| 475 |
+
ans = nn.Linear(*args, **kwargs)
|
| 476 |
+
with torch.no_grad():
|
| 477 |
+
ans.weight[:] *= initial_scale
|
| 478 |
+
if ans.bias is not None:
|
| 479 |
+
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
|
| 480 |
+
return ans
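

# A minimal usage sketch of ScaledLinear as defined above (illustrative only):
# it has the same interface as nn.Linear, but the initial weight (and bias)
# magnitudes are scaled by initial_scale.
def _example_scaled_linear() -> None:
    lin = ScaledLinear(256, 512, bias=True, initial_scale=0.25)
    x = torch.randn(8, 256)
    assert lin(x).shape == (8, 512)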
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class BalancerFunction(torch.autograd.Function):
|
| 484 |
+
@staticmethod
|
| 485 |
+
def forward(
|
| 486 |
+
ctx,
|
| 487 |
+
x: Tensor,
|
| 488 |
+
min_mean: float,
|
| 489 |
+
max_mean: float,
|
| 490 |
+
min_rms: float,
|
| 491 |
+
max_rms: float,
|
| 492 |
+
grad_scale: float,
|
| 493 |
+
channel_dim: int,
|
| 494 |
+
) -> Tensor:
|
| 495 |
+
if channel_dim < 0:
|
| 496 |
+
channel_dim += x.ndim
|
| 497 |
+
ctx.channel_dim = channel_dim
|
| 498 |
+
ctx.save_for_backward(x)
|
| 499 |
+
ctx.config = (
|
| 500 |
+
min_mean,
|
| 501 |
+
max_mean,
|
| 502 |
+
min_rms,
|
| 503 |
+
max_rms,
|
| 504 |
+
grad_scale,
|
| 505 |
+
channel_dim,
|
| 506 |
+
)
|
| 507 |
+
return x
|
| 508 |
+
|
| 509 |
+
@staticmethod
|
| 510 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
|
| 511 |
+
(x,) = ctx.saved_tensors
|
| 512 |
+
(
|
| 513 |
+
min_mean,
|
| 514 |
+
max_mean,
|
| 515 |
+
min_rms,
|
| 516 |
+
max_rms,
|
| 517 |
+
grad_scale,
|
| 518 |
+
channel_dim,
|
| 519 |
+
) = ctx.config
|
| 520 |
+
|
| 521 |
+
try:
|
| 522 |
+
with torch.enable_grad():
|
| 523 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 524 |
+
x = x.to(torch.float32)
|
| 525 |
+
x = x.detach()
|
| 526 |
+
x.requires_grad = True
|
| 527 |
+
mean_dims = [i for i in range(x.ndim) if i != channel_dim]
|
| 528 |
+
uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
|
| 529 |
+
mean = x.mean(dim=mean_dims, keepdim=True)
|
| 530 |
+
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
|
| 531 |
+
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
|
| 532 |
+
|
| 533 |
+
m = mean / stddev
|
| 534 |
+
# part of loss that relates to mean / stddev
|
| 535 |
+
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
|
| 536 |
+
|
| 537 |
+
# put a much larger scale on the RMS-max-limit loss, so that if both
|
| 538 |
+
# it and the m_loss are violated we fix the RMS loss first.
|
| 539 |
+
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
|
| 540 |
+
r_loss = (rms_clamped / rms).log().abs()
|
| 541 |
+
|
| 542 |
+
loss = m_loss + r_loss
|
| 543 |
+
|
| 544 |
+
loss.backward(gradient=torch.ones_like(loss))
|
| 545 |
+
loss_grad = x.grad
|
| 546 |
+
loss_grad_rms = (
|
| 547 |
+
(loss_grad**2)
|
| 548 |
+
.mean(dim=mean_dims, keepdim=True)
|
| 549 |
+
.sqrt()
|
| 550 |
+
.clamp(min=1.0e-20)
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
|
| 554 |
+
|
| 555 |
+
x_grad_float = x_grad.to(torch.float32)
|
| 556 |
+
# scale each element of loss_grad by the absolute value of the
|
| 557 |
+
# corresponding element of x_grad, which we view as a noisy estimate
|
| 558 |
+
# of its magnitude for that (frame and dimension). later we can
|
| 559 |
+
# consider factored versions.
|
| 560 |
+
x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
|
| 561 |
+
x_grad = x_grad_mod.to(x_grad.dtype)
|
| 562 |
+
except Exception as e:
|
| 563 |
+
logging.info(
|
| 564 |
+
f"Caught exception in Balancer backward: {e}, "
|
| 565 |
+
f"size={list(x_grad.shape)}, will continue."
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
return x_grad, None, None, None, None, None, None
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
class Balancer(torch.nn.Module):
|
| 572 |
+
"""
|
| 573 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
| 574 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
| 575 |
+
time. It does this by multiplying negative derivative values by up to
|
| 576 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
| 577 |
+
interpolated from 1 at the threshold to those extremal values when none
|
| 578 |
+
of the inputs are positive.
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
num_channels: the number of channels
|
| 582 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
| 583 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
| 584 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
| 585 |
+
that (x > 0), below which we start to modify the derivatives.
|
| 586 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
| 587 |
+
that (x > 0), above which we start to modify the derivatives.
|
| 588 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
| 589 |
+
change in gradient once the constraints on min_abs and max_abs
|
| 590 |
+
are violated.
|
| 591 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
| 592 |
+
value per channel, which we allow, before we start to modify
|
| 593 |
+
the derivatives to prevent this.
|
| 594 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
| 595 |
+
value per channel, which we allow, before we start to modify
|
| 596 |
+
the derivatives to prevent this.
|
| 597 |
+
prob: determines the minimum probability with which we modify the
|
| 598 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
| 599 |
+
on each forward(). This is done randomly to prevent all layers
|
| 600 |
+
from doing it at the same time.
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
def __init__(
|
| 604 |
+
self,
|
| 605 |
+
num_channels: int,
|
| 606 |
+
channel_dim: int,
|
| 607 |
+
min_positive: FloatLike = 0.05,
|
| 608 |
+
max_positive: FloatLike = 0.95,
|
| 609 |
+
min_abs: FloatLike = 0.2,
|
| 610 |
+
max_abs: FloatLike = 100.0,
|
| 611 |
+
grad_scale: FloatLike = 0.04,
|
| 612 |
+
prob: Optional[FloatLike] = None,
|
| 613 |
+
):
|
| 614 |
+
super().__init__()
|
| 615 |
+
|
| 616 |
+
if prob is None:
|
| 617 |
+
prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
|
| 618 |
+
self.prob = prob
|
| 619 |
+
# 5% of the time we will return and do nothing because memory usage is
|
| 620 |
+
# too high.
|
| 621 |
+
self.mem_cutoff = CutoffEstimator(0.05)
|
| 622 |
+
|
| 623 |
+
# actually self.num_channels is no longer needed except for an assertion.
|
| 624 |
+
self.num_channels = num_channels
|
| 625 |
+
self.channel_dim = channel_dim
|
| 626 |
+
self.min_positive = min_positive
|
| 627 |
+
self.max_positive = max_positive
|
| 628 |
+
self.min_abs = min_abs
|
| 629 |
+
self.max_abs = max_abs
|
| 630 |
+
self.grad_scale = grad_scale
|
| 631 |
+
|
| 632 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 633 |
+
if (
|
| 634 |
+
torch.jit.is_scripting()
|
| 635 |
+
or not x.requires_grad
|
| 636 |
+
or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
|
| 637 |
+
):
|
| 638 |
+
return _no_op(x)
|
| 639 |
+
|
| 640 |
+
prob = float(self.prob)
|
| 641 |
+
if random.random() < prob:
|
| 642 |
+
# The following inner-functions convert from the way we historically
|
| 643 |
+
# specified these limitations, as limits on the absolute value and the
|
| 644 |
+
# proportion of positive values, to limits on the RMS value and
|
| 645 |
+
# the (mean / stddev).
|
| 646 |
+
def _abs_to_rms(x):
|
| 647 |
+
# for normally distributed data, if the expected absolute value is x,
|
| 648 |
+
# the expected rms value will be sqrt(pi/2) * x.
|
| 649 |
+
return 1.25331413732 * x
|
| 650 |
+
|
| 651 |
+
def _proportion_positive_to_mean(x):
|
| 652 |
+
def _atanh(x):
|
| 653 |
+
eps = 1.0e-10
|
| 654 |
+
# eps is to prevent crashes if x is exactly 0 or 1.
|
| 655 |
+
# we'll just end up returning a fairly large value.
|
| 656 |
+
return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
|
| 657 |
+
|
| 658 |
+
def _approx_inverse_erf(x):
|
| 659 |
+
# 1 / (sqrt(pi) * ln(2)),
|
| 660 |
+
# see https://math.stackexchange.com/questions/321569/
|
| 661 |
+
# approximating-the-error-function-erf-by-analytical-functions
|
| 662 |
+
# this approximation is extremely crude and gets progressively worse
|
| 663 |
+
# for x very close to -1 or +1, but we mostly care about the
|
| 664 |
+
# "middle" region
|
| 665 |
+
# e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
|
| 666 |
+
# and math.erf(0.0407316414078772) = 0.045935330944660666,
|
| 667 |
+
# which is pretty close to 0.05.
|
| 668 |
+
return 0.8139535143 * _atanh(x)
|
| 669 |
+
|
| 670 |
+
# first convert x from the range 0..1 to the range -1..1 which the error
|
| 671 |
+
# function returns
|
| 672 |
+
x = -1 + (2 * x)
|
| 673 |
+
return _approx_inverse_erf(x)
|
| 674 |
+
|
| 675 |
+
min_mean = _proportion_positive_to_mean(float(self.min_positive))
|
| 676 |
+
max_mean = _proportion_positive_to_mean(float(self.max_positive))
|
| 677 |
+
min_rms = _abs_to_rms(float(self.min_abs))
|
| 678 |
+
max_rms = _abs_to_rms(float(self.max_abs))
|
| 679 |
+
grad_scale = float(self.grad_scale)
|
| 680 |
+
|
| 681 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
| 682 |
+
|
| 683 |
+
return BalancerFunction.apply(
|
| 684 |
+
x,
|
| 685 |
+
min_mean,
|
| 686 |
+
max_mean,
|
| 687 |
+
min_rms,
|
| 688 |
+
max_rms,
|
| 689 |
+
grad_scale,
|
| 690 |
+
self.channel_dim,
|
| 691 |
+
)
|
| 692 |
+
else:
|
| 693 |
+
return _no_op(x)
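

# A minimal usage sketch of Balancer as defined above (illustrative only): the
# forward pass returns the input values unchanged; only the gradients are
# (with probability `prob`) modified to push per-channel statistics into the
# configured ranges.
def _example_balancer() -> None:
    balancer = Balancer(num_channels=192, channel_dim=-1, min_abs=0.2, max_abs=4.0)
    x = torch.randn(2, 50, 192, requires_grad=True)
    y = balancer(x)
    assert torch.equal(y, x)      # values pass through unchanged
    y.sum().backward()            # any gradient shaping happens here
    assert x.grad is not None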
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def penalize_abs_values_gt(
|
| 697 |
+
x: Tensor, limit: float, penalty: float, name: str = None
|
| 698 |
+
) -> Tensor:
|
| 699 |
+
"""
|
| 700 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
| 701 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
| 702 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
| 703 |
+
|
| 704 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
| 705 |
+
in automatic mixed precision training. For the way we use this,
|
| 706 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
| 707 |
+
to disallow really implausible values of scores to be given to softmax.
|
| 708 |
+
|
| 709 |
+
The name is for randomly printed debug info.
|
| 710 |
+
"""
|
| 711 |
+
x_sign = x.sign()
|
| 712 |
+
over_limit = (x.abs() - limit) > 0
|
| 713 |
+
# The following is a memory efficient way to penalize the absolute values of
|
| 714 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
| 715 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
| 716 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
| 717 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
| 718 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
| 719 |
+
# limit).relu().
|
| 720 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
| 721 |
+
# note: we don't do sum() here on aux_loss, but it's as if we had done
|
| 722 |
+
# sum() due to how with_loss() works.
|
| 723 |
+
x = with_loss(x, aux_loss, name)
|
| 724 |
+
# you must use x for something, or this will be ineffective.
|
| 725 |
+
return x
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
| 729 |
+
if x.ndim == 2:
|
| 730 |
+
return x.diag()
|
| 731 |
+
else:
|
| 732 |
+
(batch, dim, dim) = x.shape
|
| 733 |
+
x = x.reshape(batch, dim * dim)
|
| 734 |
+
x = x[:, :: dim + 1]
|
| 735 |
+
assert x.shape == (batch, dim)
|
| 736 |
+
return x
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
| 740 |
+
"""
|
| 741 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
|
| 742 |
+
of the centered feature covariance are the same within each group's covariance
|
| 743 |
+
matrix and also between groups.
|
| 744 |
+
Args:
|
| 745 |
+
x: a Tensor of shape (*, num_channels)
|
| 746 |
+
num_groups: the number of groups of channels, a number >=1 that divides
|
| 747 |
+
num_channels
|
| 748 |
+
Returns:
|
| 749 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
| 750 |
+
greater than 1.0 otherwise.
|
| 751 |
+
"""
|
| 752 |
+
assert x.dtype != torch.float16
|
| 753 |
+
x = x.reshape(-1, x.shape[-1])
|
| 754 |
+
(num_frames, num_channels) = x.shape
|
| 755 |
+
assert num_channels % num_groups == 0
|
| 756 |
+
channels_per_group = num_channels // num_groups
|
| 757 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
| 758 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
| 759 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
| 760 |
+
# My experience has been that when we "mess with the gradients" like this,
|
| 761 |
+
# it's better not do anything that tries to move the mean around, because
|
| 762 |
+
# that can easily cause instability.
|
| 763 |
+
x = x - x.mean(dim=1, keepdim=True)
|
| 764 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
| 765 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
| 766 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
| 767 |
+
# the following expression is what we'd get if we took the matrix product
|
| 768 |
+
# of each covariance and measured the mean of its trace, i.e.
|
| 769 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
| 770 |
+
x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
|
| 771 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
| 772 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
|
| 773 |
+
return metric
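

# A minimal numeric sketch for _whitening_metric as defined above
# (illustrative only): i.i.d. Gaussian features are already close to "white",
# so the metric should come out near, and slightly above, 1.0.
def _example_whitening_metric() -> None:
    x = torch.randn(1000, 64)
    metric = _whitening_metric(x, num_groups=4)
    assert metric.item() > 0.9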
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
| 777 |
+
@staticmethod
|
| 778 |
+
def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
|
| 779 |
+
ctx.save_for_backward(x)
|
| 780 |
+
ctx.module = module
|
| 781 |
+
return x
|
| 782 |
+
|
| 783 |
+
@staticmethod
|
| 784 |
+
def backward(ctx, x_grad: Tensor):
|
| 785 |
+
(x_orig,) = ctx.saved_tensors
|
| 786 |
+
w = ctx.module
|
| 787 |
+
|
| 788 |
+
try:
|
| 789 |
+
with torch.enable_grad():
|
| 790 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 791 |
+
x_detached = x_orig.to(torch.float32).detach()
|
| 792 |
+
x_detached.requires_grad = True
|
| 793 |
+
|
| 794 |
+
metric = _whitening_metric(x_detached, w.num_groups)
|
| 795 |
+
|
| 796 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
| 797 |
+
logging.debug(
|
| 798 |
+
f"Whitening: name={w.name}, num_groups={w.num_groups},"
|
| 799 |
+
f"num_channels={x_orig.shape[-1]}, "
|
| 800 |
+
f"metric={metric.item():.2f}"
|
| 801 |
+
f" vs. limit={float(w.whitening_limit)}"
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
if metric < float(w.whitening_limit):
|
| 805 |
+
w.prob = w.min_prob
|
| 806 |
+
return x_grad, None
|
| 807 |
+
else:
|
| 808 |
+
w.prob = w.max_prob
|
| 809 |
+
metric.backward()
|
| 810 |
+
penalty_grad = x_detached.grad
|
| 811 |
+
scale = w.grad_scale * (
|
| 812 |
+
x_grad.to(torch.float32).norm()
|
| 813 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
| 814 |
+
)
|
| 815 |
+
penalty_grad = penalty_grad * scale
|
| 816 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None
|
| 817 |
+
except Exception as e:
|
| 818 |
+
logging.info(
|
| 819 |
+
f"Caught exception in Whiten backward: {e}, "
|
| 820 |
+
f"size={list(x_grad.shape)}, will continue."
|
| 821 |
+
)
|
| 822 |
+
return x_grad, None
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
class Whiten(nn.Module):
|
| 826 |
+
def __init__(
|
| 827 |
+
self,
|
| 828 |
+
num_groups: int,
|
| 829 |
+
whitening_limit: FloatLike,
|
| 830 |
+
prob: Union[float, Tuple[float, float]],
|
| 831 |
+
grad_scale: FloatLike,
|
| 832 |
+
):
|
| 833 |
+
"""
|
| 834 |
+
Args:
|
| 835 |
+
num_groups: the number of groups to divide the channel dim into before
|
| 836 |
+
whitening. We will attempt to make the feature covariance
|
| 837 |
+
within each group, after mean subtraction, as "white" as possible,
|
| 838 |
+
while having the same trace across all groups.
|
| 839 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
| 840 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
| 841 |
+
white, with exactly the same trace across groups; larger values
|
| 842 |
+
give more freedom. E.g. 2.0.
|
| 843 |
+
prob: the probability with which we apply the gradient modification
|
| 844 |
+
(also affects the grad scale). May be supplied as a float,
|
| 845 |
+
or as a pair (min_prob, max_prob)
|
| 846 |
+
|
| 847 |
+
grad_scale: determines the scale on the gradient term from this object,
|
| 848 |
+
relative to the rest of the gradient on the attention weights.
|
| 849 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
| 850 |
+
"""
|
| 851 |
+
super(Whiten, self).__init__()
|
| 852 |
+
assert num_groups >= 1
|
| 853 |
+
assert float(whitening_limit) >= 1
|
| 854 |
+
assert grad_scale >= 0
|
| 855 |
+
self.num_groups = num_groups
|
| 856 |
+
self.whitening_limit = whitening_limit
|
| 857 |
+
self.grad_scale = grad_scale
|
| 858 |
+
|
| 859 |
+
if isinstance(prob, float):
|
| 860 |
+
prob = (prob, prob)
|
| 861 |
+
(self.min_prob, self.max_prob) = prob
|
| 862 |
+
assert 0 < self.min_prob <= self.max_prob <= 1
|
| 863 |
+
self.prob = self.max_prob
|
| 864 |
+
self.name = None # will be set in training loop
|
| 865 |
+
|
| 866 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 867 |
+
"""
|
| 868 |
+
In the forward pass, this function just returns the input unmodified.
|
| 869 |
+
In the backward pass, it will modify the gradients to ensure that the
|
| 870 |
+
distribution in each group has close to (lambda times I) as the covariance
|
| 871 |
+
after mean subtraction, with the same lambda across groups.
|
| 872 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
| 873 |
+
constraint.
|
| 874 |
+
|
| 875 |
+
Args:
|
| 876 |
+
x: the input of shape (*, num_channels)
|
| 877 |
+
|
| 878 |
+
Returns:
|
| 879 |
+
x, unmodified. You should make sure
|
| 880 |
+
you use the returned value, or the graph will be freed
|
| 881 |
+
and nothing will happen in backprop.
|
| 882 |
+
"""
|
| 883 |
+
grad_scale = float(self.grad_scale)
|
| 884 |
+
if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
|
| 885 |
+
return _no_op(x)
|
| 886 |
+
else:
|
| 887 |
+
return WhiteningPenaltyFunction.apply(x, self)
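

# A minimal usage sketch of Whiten as defined above (illustrative only): like
# Balancer it is an identity in the forward pass; the gradient term that
# discourages non-white feature covariances is added in backprop.
def _example_whiten() -> None:
    whiten = Whiten(num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.02)
    x = torch.randn(100, 256, requires_grad=True)
    y = whiten(x)
    y.backward(gradient=torch.randn_like(y))
    assert x.grad is not None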
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
class WithLoss(torch.autograd.Function):
|
| 891 |
+
@staticmethod
|
| 892 |
+
def forward(ctx, x: Tensor, y: Tensor, name: str):
|
| 893 |
+
ctx.y_shape = y.shape
|
| 894 |
+
if random.random() < 0.002 and name is not None:
|
| 895 |
+
loss_sum = y.sum().item()
|
| 896 |
+
logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
|
| 897 |
+
return x
|
| 898 |
+
|
| 899 |
+
@staticmethod
|
| 900 |
+
def backward(ctx, ans_grad: Tensor):
|
| 901 |
+
return (
|
| 902 |
+
ans_grad,
|
| 903 |
+
torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
|
| 904 |
+
None,
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def with_loss(x, y, name):
|
| 909 |
+
# returns x but adds y.sum() to the loss function.
|
| 910 |
+
return WithLoss.apply(x, y, name)
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
class LimitParamValue(torch.autograd.Function):
|
| 914 |
+
@staticmethod
|
| 915 |
+
def forward(ctx, x: Tensor, min: float, max: float):
|
| 916 |
+
ctx.save_for_backward(x)
|
| 917 |
+
assert max >= min
|
| 918 |
+
ctx.min = min
|
| 919 |
+
ctx.max = max
|
| 920 |
+
return x
|
| 921 |
+
|
| 922 |
+
@staticmethod
|
| 923 |
+
def backward(ctx, x_grad: Tensor):
|
| 924 |
+
(x,) = ctx.saved_tensors
|
| 925 |
+
# where x < ctx.min, ensure all grads are negative (this will tend to make
|
| 926 |
+
# x more positive).
|
| 927 |
+
x_grad = x_grad * torch.where(
|
| 928 |
+
torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
|
| 929 |
+
)
|
| 930 |
+
# where x > ctx.max, ensure all grads are positive (this will tend to make
|
| 931 |
+
# x more negative).
|
| 932 |
+
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
|
| 933 |
+
return x_grad, None, None
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def limit_param_value(
|
| 937 |
+
x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
|
| 938 |
+
):
|
| 939 |
+
# You apply this to (typically) an nn.Parameter during training to ensure that its
|
| 940 |
+
# (elements mostly) stays within a supplied range. This is done by modifying the
|
| 941 |
+
# gradients in backprop.
|
| 942 |
+
# It's not necessary to do this on every batch: do it only some of the time,
|
| 943 |
+
# to save a little time.
|
| 944 |
+
if training and random.random() < prob:
|
| 945 |
+
return LimitParamValue.apply(x, min, max)
|
| 946 |
+
else:
|
| 947 |
+
return x
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def _no_op(x: Tensor) -> Tensor:
|
| 951 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 952 |
+
return x
|
| 953 |
+
else:
|
| 954 |
+
# a no-op function that will have a node in the autograd graph,
|
| 955 |
+
# to avoid certain bugs relating to backward hooks
|
| 956 |
+
return x.chunk(1, dim=-1)[0]
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
# Identity more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
| 960 |
+
class Identity(torch.nn.Module):
|
| 961 |
+
def __init__(self):
|
| 962 |
+
super(Identity, self).__init__()
|
| 963 |
+
|
| 964 |
+
def forward(self, x):
|
| 965 |
+
return _no_op(x)
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
# Dropout2 is just like normal dropout, except it supports schedules
|
| 969 |
+
# on the dropout rates.
|
| 970 |
+
class Dropout2(nn.Module):
|
| 971 |
+
def __init__(self, p: FloatLike):
|
| 972 |
+
super().__init__()
|
| 973 |
+
self.p = p
|
| 974 |
+
|
| 975 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 976 |
+
return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
class MulForDropout3(torch.autograd.Function):
|
| 980 |
+
# returns (x * y * alpha) where alpha is a float and y doesn't require
|
| 981 |
+
# grad and is zero-or-one.
|
| 982 |
+
@staticmethod
|
| 983 |
+
@custom_fwd
|
| 984 |
+
def forward(ctx, x, y, alpha):
|
| 985 |
+
assert not y.requires_grad
|
| 986 |
+
ans = x * y * alpha
|
| 987 |
+
ctx.save_for_backward(ans)
|
| 988 |
+
ctx.alpha = alpha
|
| 989 |
+
return ans
|
| 990 |
+
|
| 991 |
+
@staticmethod
|
| 992 |
+
@custom_bwd
|
| 993 |
+
def backward(ctx, ans_grad):
|
| 994 |
+
(ans,) = ctx.saved_tensors
|
| 995 |
+
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
| 996 |
+
return x_grad, None, None
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
# Dropout3 is just like normal dropout, except it supports schedules on the dropout
|
| 1000 |
+
# rates, and it lets you choose one dimension to share the dropout mask over
|
| 1001 |
+
class Dropout3(nn.Module):
|
| 1002 |
+
def __init__(self, p: FloatLike, shared_dim: int):
|
| 1003 |
+
super().__init__()
|
| 1004 |
+
self.p = p
|
| 1005 |
+
self.shared_dim = shared_dim
|
| 1006 |
+
|
| 1007 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1008 |
+
p = float(self.p)
|
| 1009 |
+
if not self.training or p == 0:
|
| 1010 |
+
return _no_op(x)
|
| 1011 |
+
scale = 1.0 / (1 - p)
|
| 1012 |
+
rand_shape = list(x.shape)
|
| 1013 |
+
rand_shape[self.shared_dim] = 1
|
| 1014 |
+
mask = torch.rand(*rand_shape, device=x.device) > p
|
| 1015 |
+
ans = MulForDropout3.apply(x, mask, scale)
|
| 1016 |
+
return ans
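

# A minimal usage sketch of Dropout3 as defined above (illustrative only): one
# dropout mask is shared across `shared_dim` (here dim 0, e.g. a time axis),
# so the stored mask is much smaller than the input.
def _example_dropout3() -> None:
    drop = Dropout3(p=0.1, shared_dim=0)   # p may also be a ScheduledFloat
    x = torch.randn(50, 4, 256)
    y = drop(x)                            # modules default to training mode
    assert y.shape == x.shape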
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
class SwooshLFunction(torch.autograd.Function):
|
| 1020 |
+
"""
|
| 1021 |
+
swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
|
| 1022 |
+
"""
|
| 1023 |
+
|
| 1024 |
+
@staticmethod
|
| 1025 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
| 1026 |
+
requires_grad = x.requires_grad
|
| 1027 |
+
if x.dtype == torch.float16:
|
| 1028 |
+
x = x.to(torch.float32)
|
| 1029 |
+
|
| 1030 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1031 |
+
|
| 1032 |
+
coeff = -0.08
|
| 1033 |
+
|
| 1034 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 1035 |
+
with torch.enable_grad():
|
| 1036 |
+
x = x.detach()
|
| 1037 |
+
x.requires_grad = True
|
| 1038 |
+
y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
|
| 1039 |
+
|
| 1040 |
+
if not requires_grad:
|
| 1041 |
+
return y
|
| 1042 |
+
|
| 1043 |
+
y.backward(gradient=torch.ones_like(y))
|
| 1044 |
+
|
| 1045 |
+
grad = x.grad
|
| 1046 |
+
floor = coeff
|
| 1047 |
+
ceil = 1.0 + coeff + 0.005
|
| 1048 |
+
|
| 1049 |
+
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
|
| 1050 |
+
grad
|
| 1051 |
+
)
|
| 1052 |
+
if __name__ == "__main__":
|
| 1053 |
+
# for self-testing only.
|
| 1054 |
+
assert d_scaled.min() >= 0.0
|
| 1055 |
+
assert d_scaled.max() < 256.0
|
| 1056 |
+
|
| 1057 |
+
d_int = d_scaled.to(torch.uint8)
|
| 1058 |
+
ctx.save_for_backward(d_int)
|
| 1059 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
| 1060 |
+
y = y.to(torch.float16)
|
| 1061 |
+
return y
|
| 1062 |
+
|
| 1063 |
+
@staticmethod
|
| 1064 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
| 1065 |
+
(d,) = ctx.saved_tensors
|
| 1066 |
+
# the same constants as used in forward pass.
|
| 1067 |
+
coeff = -0.08
|
| 1068 |
+
floor = coeff
|
| 1069 |
+
ceil = 1.0 + coeff + 0.005
|
| 1070 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
| 1071 |
+
return y_grad * d
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
class SwooshL(torch.nn.Module):
|
| 1075 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1076 |
+
"""Return Swoosh-L activation."""
|
| 1077 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1078 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1079 |
+
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
| 1080 |
+
return SwooshLFunction.apply(x)
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
class SwooshLOnnx(torch.nn.Module):
|
| 1084 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1085 |
+
"""Return Swoosh-L activation."""
|
| 1086 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1087 |
+
return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
class SwooshRFunction(torch.autograd.Function):
|
| 1091 |
+
"""
|
| 1092 |
+
swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
| 1093 |
+
|
| 1094 |
+
derivatives are between -0.08 and 0.92.
|
| 1095 |
+
"""
|
| 1096 |
+
|
| 1097 |
+
@staticmethod
|
| 1098 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
| 1099 |
+
requires_grad = x.requires_grad
|
| 1100 |
+
|
| 1101 |
+
if x.dtype == torch.float16:
|
| 1102 |
+
x = x.to(torch.float32)
|
| 1103 |
+
|
| 1104 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1105 |
+
|
| 1106 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 1107 |
+
with torch.enable_grad():
|
| 1108 |
+
x = x.detach()
|
| 1109 |
+
x.requires_grad = True
|
| 1110 |
+
y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
|
| 1111 |
+
|
| 1112 |
+
if not requires_grad:
|
| 1113 |
+
return y
|
| 1114 |
+
y.backward(gradient=torch.ones_like(y))
|
| 1115 |
+
|
| 1116 |
+
grad = x.grad
|
| 1117 |
+
floor = -0.08
|
| 1118 |
+
ceil = 0.925
|
| 1119 |
+
|
| 1120 |
+
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
|
| 1121 |
+
grad
|
| 1122 |
+
)
|
| 1123 |
+
if __name__ == "__main__":
|
| 1124 |
+
# for self-testing only.
|
| 1125 |
+
assert d_scaled.min() >= 0.0
|
| 1126 |
+
assert d_scaled.max() < 256.0
|
| 1127 |
+
|
| 1128 |
+
d_int = d_scaled.to(torch.uint8)
|
| 1129 |
+
ctx.save_for_backward(d_int)
|
| 1130 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
| 1131 |
+
y = y.to(torch.float16)
|
| 1132 |
+
return y
|
| 1133 |
+
|
| 1134 |
+
@staticmethod
|
| 1135 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
| 1136 |
+
(d,) = ctx.saved_tensors
|
| 1137 |
+
# the same constants as used in forward pass.
|
| 1138 |
+
floor = -0.08
|
| 1139 |
+
ceil = 0.925
|
| 1140 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
| 1141 |
+
return y_grad * d
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
class SwooshR(torch.nn.Module):
|
| 1145 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1146 |
+
"""Return Swoosh-R activation."""
|
| 1147 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1148 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1149 |
+
return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
|
| 1150 |
+
return SwooshRFunction.apply(x)
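

# A minimal numeric sketch of the Swoosh activations defined above
# (illustrative only): both are smooth, pass close to zero near the origin and
# are roughly linear for large |x|.
def _example_swoosh() -> None:
    x = torch.tensor([-4.0, 0.0, 4.0])
    y_l = SwooshL()(x)   # log(1 + exp(x - 4)) - 0.08 * x - 0.035
    y_r = SwooshR()(x)   # log(1 + exp(x - 1)) - 0.08 * x - 0.313261687
    assert y_l.shape == x.shape and y_r.shape == x.shape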
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
class SwooshROnnx(torch.nn.Module):
|
| 1154 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1155 |
+
"""Return Swoosh-R activation."""
|
| 1156 |
+
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
| 1157 |
+
return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
# simple version of SwooshL that does not redefine the backprop, used in
|
| 1161 |
+
# ActivationDropoutAndLinearFunction.
|
| 1162 |
+
def SwooshLForward(x: Tensor):
|
| 1163 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 1164 |
+
x = x.to(torch.float32)
|
| 1165 |
+
x_offset = x - 4.0
|
| 1166 |
+
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
| 1167 |
+
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1168 |
+
return log_sum - 0.08 * x - 0.035
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
# simple version of SwooshR that does not redefine the backprop, used in
|
| 1172 |
+
# ActivationDropoutAndLinearFunction.
|
| 1173 |
+
def SwooshRForward(x: Tensor):
|
| 1174 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 1175 |
+
x = x.to(torch.float32)
|
| 1176 |
+
x_offset = x - 1.0
|
| 1177 |
+
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
| 1178 |
+
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1179 |
+
return log_sum - 0.08 * x - 0.313261687
|
| 1180 |
+
|
| 1181 |
+
|
| 1182 |
+
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
|
| 1183 |
+
@staticmethod
|
| 1184 |
+
@custom_fwd
|
| 1185 |
+
def forward(
|
| 1186 |
+
ctx,
|
| 1187 |
+
x: Tensor,
|
| 1188 |
+
weight: Tensor,
|
| 1189 |
+
bias: Optional[Tensor],
|
| 1190 |
+
activation: str,
|
| 1191 |
+
dropout_p: float,
|
| 1192 |
+
dropout_shared_dim: Optional[int],
|
| 1193 |
+
):
|
| 1194 |
+
if dropout_p != 0.0:
|
| 1195 |
+
dropout_shape = list(x.shape)
|
| 1196 |
+
if dropout_shared_dim is not None:
|
| 1197 |
+
dropout_shape[dropout_shared_dim] = 1
|
| 1198 |
+
# else it won't be very memory efficient.
|
| 1199 |
+
dropout_mask = (1.0 / (1.0 - dropout_p)) * (
|
| 1200 |
+
torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
|
| 1201 |
+
)
|
| 1202 |
+
else:
|
| 1203 |
+
dropout_mask = None
|
| 1204 |
+
|
| 1205 |
+
ctx.save_for_backward(x, weight, bias, dropout_mask)
|
| 1206 |
+
|
| 1207 |
+
ctx.activation = activation
|
| 1208 |
+
|
| 1209 |
+
forward_activation_dict = {
|
| 1210 |
+
"SwooshL": k2.swoosh_l_forward,
|
| 1211 |
+
"SwooshR": k2.swoosh_r_forward,
|
| 1212 |
+
}
|
| 1213 |
+
# it will raise a KeyError if this fails. This will be an error. We let it
|
| 1214 |
+
# propagate to the user.
|
| 1215 |
+
activation_func = forward_activation_dict[activation]
|
| 1216 |
+
x = activation_func(x)
|
| 1217 |
+
if dropout_mask is not None:
|
| 1218 |
+
x = x * dropout_mask
|
| 1219 |
+
x = torch.nn.functional.linear(x, weight, bias)
|
| 1220 |
+
return x
|
| 1221 |
+
|
| 1222 |
+
@staticmethod
|
| 1223 |
+
@custom_bwd
|
| 1224 |
+
def backward(ctx, ans_grad: Tensor):
|
| 1225 |
+
saved = ctx.saved_tensors
|
| 1226 |
+
(x, weight, bias, dropout_mask) = saved
|
| 1227 |
+
|
| 1228 |
+
forward_and_deriv_activation_dict = {
|
| 1229 |
+
"SwooshL": k2.swoosh_l_forward_and_deriv,
|
| 1230 |
+
"SwooshR": k2.swoosh_r_forward_and_deriv,
|
| 1231 |
+
}
|
| 1232 |
+
# the following line raises a KeyError if the activation is unrecognized.
|
| 1233 |
+
# This will be an error. We let it propagate to the user.
|
| 1234 |
+
func = forward_and_deriv_activation_dict[ctx.activation]
|
| 1235 |
+
|
| 1236 |
+
y, func_deriv = func(x)
|
| 1237 |
+
if dropout_mask is not None:
|
| 1238 |
+
y = y * dropout_mask
|
| 1239 |
+
# now compute derivative of y w.r.t. weight and bias..
|
| 1240 |
+
# y: (..., in_channels), ans_grad: (..., out_channels),
|
| 1241 |
+
(out_channels, in_channels) = weight.shape
|
| 1242 |
+
|
| 1243 |
+
in_channels = y.shape[-1]
|
| 1244 |
+
g = ans_grad.reshape(-1, out_channels)
|
| 1245 |
+
weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
|
| 1246 |
+
y_deriv = torch.matmul(ans_grad, weight)
|
| 1247 |
+
bias_deriv = None if bias is None else g.sum(dim=0)
|
| 1248 |
+
x_deriv = y_deriv * func_deriv
|
| 1249 |
+
if dropout_mask is not None:
|
| 1250 |
+
# order versus func_deriv does not matter
|
| 1251 |
+
x_deriv = x_deriv * dropout_mask
|
| 1252 |
+
|
| 1253 |
+
return x_deriv, weight_deriv, bias_deriv, None, None, None
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
class ActivationDropoutAndLinear(torch.nn.Module):
|
| 1257 |
+
"""
|
| 1258 |
+
This merges an activation function followed by dropout and then a nn.Linear module;
|
| 1259 |
+
it does so in a memory efficient way so that it only stores the input to the whole
|
| 1260 |
+
module. If activation == SwooshL and dropout_shared_dim != None, this will be
|
| 1261 |
+
equivalent to:
|
| 1262 |
+
nn.Sequential(SwooshL(),
|
| 1263 |
+
Dropout3(dropout_p, shared_dim=dropout_shared_dim),
|
| 1264 |
+
ScaledLinear(in_channels, out_channels, bias=bias,
|
| 1265 |
+
initial_scale=initial_scale))
|
| 1266 |
+
If dropout_shared_dim is None, the dropout would be equivalent to
|
| 1267 |
+
Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
|
| 1268 |
+
mask is smaller.
|
| 1269 |
+
|
| 1270 |
+
Args:
|
| 1271 |
+
in_channels: number of input channels, e.g. 256
|
| 1272 |
+
out_channels: number of output channels, e.g. 256
|
| 1273 |
+
bias: if true, have a bias
|
| 1274 |
+
activation: the activation function, for now just support SwooshL.
|
| 1275 |
+
dropout_p: the dropout probability or schedule (happens after nonlinearity).
|
| 1276 |
+
dropout_shared_dim: the dimension, if any, across which the dropout mask is
|
| 1277 |
+
shared (e.g. the time dimension). If None, this may be less memory
|
| 1278 |
+
efficient if there are modules before this one that cache the input
|
| 1279 |
+
for their backprop (e.g. Balancer or Whiten).
|
| 1280 |
+
"""
|
| 1281 |
+
|
| 1282 |
+
def __init__(
|
| 1283 |
+
self,
|
| 1284 |
+
in_channels: int,
|
| 1285 |
+
out_channels: int,
|
| 1286 |
+
bias: bool = True,
|
| 1287 |
+
activation: str = "SwooshL",
|
| 1288 |
+
dropout_p: FloatLike = 0.0,
|
| 1289 |
+
dropout_shared_dim: Optional[int] = -1,
|
| 1290 |
+
initial_scale: float = 1.0,
|
| 1291 |
+
):
|
| 1292 |
+
super().__init__()
|
| 1293 |
+
# create a temporary module of nn.Linear that we'll steal the
|
| 1294 |
+
# weights and bias from
|
| 1295 |
+
l = ScaledLinear(
|
| 1296 |
+
in_channels, out_channels, bias=bias, initial_scale=initial_scale
|
| 1297 |
+
)
|
| 1298 |
+
|
| 1299 |
+
self.weight = l.weight
|
| 1300 |
+
# register_parameter properly handles making it a parameter when l.bias
|
| 1301 |
+
# is None. I think there is some reason for doing it this way rather
|
| 1302 |
+
# than just setting it to None but I don't know what it is, maybe
|
| 1303 |
+
# something to do with exporting the module..
|
| 1304 |
+
self.register_parameter("bias", l.bias)
|
| 1305 |
+
|
| 1306 |
+
self.activation = activation
|
| 1307 |
+
self.dropout_p = dropout_p
|
| 1308 |
+
self.dropout_shared_dim = dropout_shared_dim
|
| 1309 |
+
|
| 1310 |
+
def forward(self, x: Tensor):
|
| 1311 |
+
if (
|
| 1312 |
+
torch.jit.is_scripting()
|
| 1313 |
+
or torch.jit.is_tracing()
|
| 1314 |
+
or "k2" not in sys.modules
|
| 1315 |
+
):
|
| 1316 |
+
if self.activation == "SwooshL":
|
| 1317 |
+
x = SwooshLForward(x)
|
| 1318 |
+
elif self.activation == "SwooshR":
|
| 1319 |
+
x = SwooshRForward(x)
|
| 1320 |
+
else:
|
| 1321 |
+
assert False, self.activation
|
| 1322 |
+
return torch.nn.functional.linear(x, self.weight, self.bias)
|
| 1323 |
+
|
| 1324 |
+
return ActivationDropoutAndLinearFunction.apply(
|
| 1325 |
+
x,
|
| 1326 |
+
self.weight,
|
| 1327 |
+
self.bias,
|
| 1328 |
+
self.activation,
|
| 1329 |
+
float(self.dropout_p),
|
| 1330 |
+
self.dropout_shared_dim,
|
| 1331 |
+
)
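

# A minimal usage sketch of ActivationDropoutAndLinear as defined above
# (illustrative only): it fuses SwooshL/SwooshR -> dropout -> linear, and
# falls back to the unfused activation + linear path when the k2 extension has
# not been loaded.
def _example_activation_dropout_and_linear() -> None:
    m = ActivationDropoutAndLinear(
        256, 512, activation="SwooshL", dropout_p=0.1, dropout_shared_dim=-1
    )
    x = torch.randn(4, 30, 256)
    assert m(x).shape == (4, 30, 512)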
|
| 1332 |
+
|
| 1333 |
+
|
| 1334 |
+
def _test_whiten():
|
| 1335 |
+
for proportion in [0.1, 0.5, 10.0]:
|
| 1336 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
| 1337 |
+
x = torch.randn(100, 128)
|
| 1338 |
+
direction = torch.randn(128)
|
| 1339 |
+
coeffs = torch.randn(100, 1)
|
| 1340 |
+
x += proportion * direction * coeffs
|
| 1341 |
+
|
| 1342 |
+
x.requires_grad = True
|
| 1343 |
+
|
| 1344 |
+
m = Whiten(
|
| 1345 |
+
1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
|
| 1346 |
+
) # grad_scale
|
| 1347 |
+
|
| 1348 |
+
for _ in range(4):
|
| 1349 |
+
y = m(x)
|
| 1350 |
+
|
| 1351 |
+
y_grad = torch.randn_like(x)
|
| 1352 |
+
y.backward(gradient=y_grad)
|
| 1353 |
+
|
| 1354 |
+
if proportion < 0.2:
|
| 1355 |
+
assert torch.allclose(x.grad, y_grad)
|
| 1356 |
+
elif proportion > 1.0:
|
| 1357 |
+
assert not torch.allclose(x.grad, y_grad)
|
| 1358 |
+
|
| 1359 |
+
|
| 1360 |
+
def _test_balancer_sign():
|
| 1361 |
+
probs = torch.arange(0, 1, 0.01)
|
| 1362 |
+
N = 1000
|
| 1363 |
+
x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
|
| 1364 |
+
x = x.detach()
|
| 1365 |
+
x.requires_grad = True
|
| 1366 |
+
m = Balancer(
|
| 1367 |
+
probs.numel(),
|
| 1368 |
+
channel_dim=0,
|
| 1369 |
+
min_positive=0.05,
|
| 1370 |
+
max_positive=0.95,
|
| 1371 |
+
min_abs=0.0,
|
| 1372 |
+
prob=1.0,
|
| 1373 |
+
)
|
| 1374 |
+
|
| 1375 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
| 1376 |
+
|
| 1377 |
+
y = m(x)
|
| 1378 |
+
y.backward(gradient=y_grad)
|
| 1379 |
+
print("_test_balancer_sign: x = ", x)
|
| 1380 |
+
print("_test_balancer_sign: y grad = ", y_grad)
|
| 1381 |
+
print("_test_balancer_sign: x grad = ", x.grad)
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
def _test_balancer_magnitude():
|
| 1385 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
| 1386 |
+
N = 1000
|
| 1387 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
|
| 1388 |
+
x = x.detach()
|
| 1389 |
+
x.requires_grad = True
|
| 1390 |
+
m = Balancer(
|
| 1391 |
+
magnitudes.numel(),
|
| 1392 |
+
channel_dim=0,
|
| 1393 |
+
min_positive=0.0,
|
| 1394 |
+
max_positive=1.0,
|
| 1395 |
+
min_abs=0.2,
|
| 1396 |
+
max_abs=0.7,
|
| 1397 |
+
prob=1.0,
|
| 1398 |
+
)
|
| 1399 |
+
|
| 1400 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
| 1401 |
+
|
| 1402 |
+
y = m(x)
|
| 1403 |
+
y.backward(gradient=y_grad)
|
| 1404 |
+
print("_test_balancer_magnitude: x = ", x)
|
| 1405 |
+
print("_test_balancer_magnitude: y grad = ", y_grad)
|
| 1406 |
+
print("_test_balancer_magnitude: x grad = ", x.grad)
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
def _test_swooshl_deriv():
|
| 1410 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
| 1411 |
+
x.requires_grad = True
|
| 1412 |
+
m = SwooshL()
|
| 1413 |
+
|
| 1414 |
+
tol = 1.0 / 255.0
|
| 1415 |
+
torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
|
| 1416 |
+
|
| 1417 |
+
# for self-test.
|
| 1418 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
| 1419 |
+
x.requires_grad = True
|
| 1420 |
+
y = m(x)
|
| 1421 |
+
return y
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
def _test_swooshr_deriv():
|
| 1425 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
| 1426 |
+
x.requires_grad = True
|
| 1427 |
+
m = SwooshR()
|
| 1428 |
+
|
| 1429 |
+
tol = 1.0 / 255.0
|
| 1430 |
+
torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
|
| 1431 |
+
|
| 1432 |
+
# for self-test.
|
| 1433 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
| 1434 |
+
x.requires_grad = True
|
| 1435 |
+
y = m(x)
|
| 1436 |
+
return y
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
def _test_softmax():
|
| 1440 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
| 1441 |
+
b = a.clone()
|
| 1442 |
+
a.requires_grad = True
|
| 1443 |
+
b.requires_grad = True
|
| 1444 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
| 1445 |
+
print("a grad = ", a.grad)
|
| 1446 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
| 1447 |
+
print("b grad = ", b.grad)
|
| 1448 |
+
assert torch.allclose(a.grad, b.grad)
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
def _test_piecewise_linear():
|
| 1452 |
+
p = PiecewiseLinear((0, 10.0))
|
| 1453 |
+
for x in [-100, 0, 100]:
|
| 1454 |
+
assert p(x) == 10.0
|
| 1455 |
+
p = PiecewiseLinear((0, 10.0), (1, 0.0))
|
| 1456 |
+
for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
|
| 1457 |
+
print("x, y = ", x, y)
|
| 1458 |
+
assert p(x) == y, (x, p(x), y)
|
| 1459 |
+
|
| 1460 |
+
q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
|
| 1461 |
+
x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
|
| 1462 |
+
pq = p.max(q)
|
| 1463 |
+
for x in x_vals:
|
| 1464 |
+
y1 = max(p(x), q(x))
|
| 1465 |
+
y2 = pq(x)
|
| 1466 |
+
assert abs(y1 - y2) < 0.001
|
| 1467 |
+
pq = p.min(q)
|
| 1468 |
+
for x in x_vals:
|
| 1469 |
+
y1 = min(p(x), q(x))
|
| 1470 |
+
y2 = pq(x)
|
| 1471 |
+
assert abs(y1 - y2) < 0.001
|
| 1472 |
+
pq = p + q
|
| 1473 |
+
for x in x_vals:
|
| 1474 |
+
y1 = p(x) + q(x)
|
| 1475 |
+
y2 = pq(x)
|
| 1476 |
+
assert abs(y1 - y2) < 0.001
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
def _test_activation_dropout_and_linear():
|
| 1480 |
+
in_channels = 20
|
| 1481 |
+
out_channels = 30
|
| 1482 |
+
|
| 1483 |
+
for bias in [True, False]:
|
| 1484 |
+
# actually we don't test for dropout_p != 0.0 because the forward functions give
|
| 1485 |
+
# different answers. This is because we are using the k2 implementation of
|
| 1486 |
+
# swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
|
| 1487 |
+
# internally, messing up the random state.
|
| 1488 |
+
for dropout_p in [0.0]:
|
| 1489 |
+
for activation in ["SwooshL", "SwooshR"]:
|
| 1490 |
+
m1 = nn.Sequential(
|
| 1491 |
+
SwooshL() if activation == "SwooshL" else SwooshR(),
|
| 1492 |
+
Dropout3(p=dropout_p, shared_dim=-1),
|
| 1493 |
+
ScaledLinear(
|
| 1494 |
+
in_channels, out_channels, bias=bias, initial_scale=0.5
|
| 1495 |
+
),
|
| 1496 |
+
)
|
| 1497 |
+
m2 = ActivationDropoutAndLinear(
|
| 1498 |
+
in_channels,
|
| 1499 |
+
out_channels,
|
| 1500 |
+
bias=bias,
|
| 1501 |
+
initial_scale=0.5,
|
| 1502 |
+
activation=activation,
|
| 1503 |
+
dropout_p=dropout_p,
|
| 1504 |
+
)
|
| 1505 |
+
with torch.no_grad():
|
| 1506 |
+
m2.weight[:] = m1[2].weight
|
| 1507 |
+
if bias:
|
| 1508 |
+
m2.bias[:] = m1[2].bias
|
| 1509 |
+
# make sure forward gives same result.
|
| 1510 |
+
x1 = torch.randn(10, in_channels)
|
| 1511 |
+
x1.requires_grad = True
|
| 1512 |
+
|
| 1513 |
+
# TEMP.
|
| 1514 |
+
assert torch.allclose(
|
| 1515 |
+
SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
|
| 1516 |
+
)
|
| 1517 |
+
|
| 1518 |
+
x2 = x1.clone().detach()
|
| 1519 |
+
x2.requires_grad = True
|
| 1520 |
+
seed = 10
|
| 1521 |
+
torch.manual_seed(seed)
|
| 1522 |
+
y1 = m1(x1)
|
| 1523 |
+
y_grad = torch.randn_like(y1)
|
| 1524 |
+
y1.backward(gradient=y_grad)
|
| 1525 |
+
torch.manual_seed(seed)
|
| 1526 |
+
y2 = m2(x2)
|
| 1527 |
+
y2.backward(gradient=y_grad)
|
| 1528 |
+
|
| 1529 |
+
print(
|
| 1530 |
+
f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
|
| 1531 |
+
)
|
| 1532 |
+
print("y1 = ", y1)
|
| 1533 |
+
print("y2 = ", y2)
|
| 1534 |
+
assert torch.allclose(y1, y2, atol=0.02)
|
| 1535 |
+
assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
|
| 1536 |
+
if bias:
|
| 1537 |
+
assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
|
| 1538 |
+
print("x1.grad = ", x1.grad)
|
| 1539 |
+
print("x2.grad = ", x2.grad)
|
| 1540 |
+
|
| 1541 |
+
def isclose(a, b):
|
| 1542 |
+
# return true if cosine similarity is > 0.9.
|
| 1543 |
+
return (a * b).sum() > 0.9 * (
|
| 1544 |
+
(a**2).sum() * (b**2).sum()
|
| 1545 |
+
).sqrt()
|
| 1546 |
+
|
| 1547 |
+
# the SwooshL() implementation has a noisy gradient due to 1-byte
|
| 1548 |
+
# storage of it.
|
| 1549 |
+
assert isclose(x1.grad, x2.grad)
|
| 1550 |
+
|
| 1551 |
+
|
| 1552 |
+
if __name__ == "__main__":
|
| 1553 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
| 1554 |
+
torch.set_num_threads(1)
|
| 1555 |
+
torch.set_num_interop_threads(1)
|
| 1556 |
+
_test_piecewise_linear()
|
| 1557 |
+
_test_softmax()
|
| 1558 |
+
_test_whiten()
|
| 1559 |
+
_test_balancer_sign()
|
| 1560 |
+
_test_balancer_magnitude()
|
| 1561 |
+
_test_swooshr_deriv()
|
| 1562 |
+
_test_swooshl_deriv()
|
| 1563 |
+
_test_activation_dropout_and_linear()
|
zipvoice/models/modules/solver.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from typing import Optional, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiffusionModel(torch.nn.Module):
|
| 24 |
+
"""A wrapper of diffusion models for inference.
|
| 25 |
+
Args:
|
| 26 |
+
model: The diffusion model.
|
| 27 |
+
func_name: The function name to call.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
model: torch.nn.Module,
|
| 33 |
+
func_name: str = "forward_fm_decoder",
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.model = model
|
| 37 |
+
self.func_name = func_name
|
| 38 |
+
self.model_func = getattr(self.model, func_name)
|
| 39 |
+
|
| 40 |
+
def forward(
|
| 41 |
+
self,
|
| 42 |
+
t: torch.Tensor,
|
| 43 |
+
x: torch.Tensor,
|
| 44 |
+
text_condition: torch.Tensor,
|
| 45 |
+
speech_condition: torch.Tensor,
|
| 46 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 47 |
+
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
| 48 |
+
**kwargs
|
| 49 |
+
) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Forward function that Handles the classifier-free guidance.
|
| 52 |
+
Args:
|
| 53 |
+
t: The current timestep, a tensor of a tensor of a single float.
|
| 54 |
+
x: The initial value, with the shape (batch, seq_len, emb_dim).
|
| 55 |
+
text_condition: The text_condition of the diffision model, with
|
| 56 |
+
the shape (batch, seq_len, emb_dim).
|
| 57 |
+
speech_condition: The speech_condition of the diffision model, with the
|
| 58 |
+
shape (batch, seq_len, emb_dim).
|
| 59 |
+
padding_mask: The mask for padding; True means masked position, with the
|
| 60 |
+
shape (batch, seq_len).
|
| 61 |
+
guidance_scale: The scale of classifier-free guidance, a float or a tensor
|
| 62 |
+
of shape (batch, 1, 1).
|
| 63 |
+
Retrun:
|
| 64 |
+
The prediction with the shape (batch, seq_len, emb_dim).
|
| 65 |
+
"""
|
| 66 |
+
if not torch.is_tensor(guidance_scale):
|
| 67 |
+
guidance_scale = torch.tensor(
|
| 68 |
+
guidance_scale, dtype=t.dtype, device=t.device
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
if (guidance_scale == 0.0).all():
|
| 72 |
+
return self.model_func(
|
| 73 |
+
t=t,
|
| 74 |
+
xt=x,
|
| 75 |
+
text_condition=text_condition,
|
| 76 |
+
speech_condition=speech_condition,
|
| 77 |
+
padding_mask=padding_mask,
|
| 78 |
+
**kwargs
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
assert t.dim() == 0
|
| 82 |
+
|
| 83 |
+
x = torch.cat([x] * 2, dim=0)
|
| 84 |
+
padding_mask = torch.cat([padding_mask] * 2, dim=0)
|
| 85 |
+
|
| 86 |
+
text_condition = torch.cat(
|
| 87 |
+
[torch.zeros_like(text_condition), text_condition], dim=0
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if t > 0.5:
|
| 91 |
+
speech_condition = torch.cat(
|
| 92 |
+
[torch.zeros_like(speech_condition), speech_condition], dim=0
|
| 93 |
+
)
|
| 94 |
+
else:
|
| 95 |
+
guidance_scale = guidance_scale * 2
|
| 96 |
+
speech_condition = torch.cat(
|
| 97 |
+
[speech_condition, speech_condition], dim=0
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
data_uncond, data_cond = self.model_func(
|
| 101 |
+
t=t,
|
| 102 |
+
xt=x,
|
| 103 |
+
text_condition=text_condition,
|
| 104 |
+
speech_condition=speech_condition,
|
| 105 |
+
padding_mask=padding_mask,
|
| 106 |
+
**kwargs
|
| 107 |
+
).chunk(2, dim=0)
|
| 108 |
+
|
| 109 |
+
res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
|
| 110 |
+
return res
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class DistillDiffusionModel(DiffusionModel):
|
| 114 |
+
"""A wrapper of distilled diffusion models for inference.
|
| 115 |
+
Args:
|
| 116 |
+
model: The distilled diffusion model.
|
| 117 |
+
func_name: The function name to call.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
model: torch.nn.Module,
|
| 123 |
+
func_name: str = "forward_fm_decoder",
|
| 124 |
+
):
|
| 125 |
+
super().__init__(model=model, func_name=func_name)
|
| 126 |
+
|
| 127 |
+
def forward(
|
| 128 |
+
self,
|
| 129 |
+
t: torch.Tensor,
|
| 130 |
+
x: torch.Tensor,
|
| 131 |
+
text_condition: torch.Tensor,
|
| 132 |
+
speech_condition: torch.Tensor,
|
| 133 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 134 |
+
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
| 135 |
+
**kwargs
|
| 136 |
+
) -> torch.Tensor:
|
| 137 |
+
"""
|
| 138 |
+
Forward function that Handles the classifier-free guidance.
|
| 139 |
+
Args:
|
| 140 |
+
t: The current timestep, a tensor of a single float.
|
| 141 |
+
x: The initial value, with the shape (batch, seq_len, emb_dim).
|
| 142 |
+
text_condition: The text_condition of the diffision model, with
|
| 143 |
+
the shape (batch, seq_len, emb_dim).
|
| 144 |
+
speech_condition: The speech_condition of the diffision model, with the
|
| 145 |
+
shape (batch, seq_len, emb_dim).
|
| 146 |
+
padding_mask: The mask for padding; True means masked position, with the
|
| 147 |
+
shape (batch, seq_len).
|
| 148 |
+
guidance_scale: The scale of classifier-free guidance, a float or a tensor
|
| 149 |
+
of shape (batch, 1, 1).
|
| 150 |
+
Retrun:
|
| 151 |
+
The prediction with the shape (batch, seq_len, emb_dim).
|
| 152 |
+
"""
|
| 153 |
+
if not torch.is_tensor(guidance_scale):
|
| 154 |
+
guidance_scale = torch.tensor(
|
| 155 |
+
guidance_scale, dtype=t.dtype, device=t.device
|
| 156 |
+
)
|
| 157 |
+
return self.model_func(
|
| 158 |
+
t=t,
|
| 159 |
+
xt=x,
|
| 160 |
+
text_condition=text_condition,
|
| 161 |
+
speech_condition=speech_condition,
|
| 162 |
+
padding_mask=padding_mask,
|
| 163 |
+
guidance_scale=guidance_scale,
|
| 164 |
+
**kwargs
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class EulerSolver:
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
model: torch.nn.Module,
|
| 172 |
+
func_name: str = "forward_fm_decoder",
|
| 173 |
+
):
|
| 174 |
+
"""Construct a Euler Solver
|
| 175 |
+
Args:
|
| 176 |
+
model: The diffusion model.
|
| 177 |
+
func_name: The function name to call.
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
self.model = DiffusionModel(model, func_name=func_name)
|
| 181 |
+
|
| 182 |
+
def sample(
|
| 183 |
+
self,
|
| 184 |
+
x: torch.Tensor,
|
| 185 |
+
text_condition: torch.Tensor,
|
| 186 |
+
speech_condition: torch.Tensor,
|
| 187 |
+
padding_mask: torch.Tensor,
|
| 188 |
+
num_step: int = 10,
|
| 189 |
+
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
| 190 |
+
t_start: float = 0.0,
|
| 191 |
+
t_end: float = 1.0,
|
| 192 |
+
t_shift: float = 1.0,
|
| 193 |
+
**kwargs
|
| 194 |
+
) -> torch.Tensor:
|
| 195 |
+
"""
|
| 196 |
+
Compute the sample at time `t_end` by Euler Solver.
|
| 197 |
+
Args:
|
| 198 |
+
x: The initial value at time `t_start`, with the shape (batch, seq_len,
|
| 199 |
+
emb_dim).
|
| 200 |
+
text_condition: The text condition of the diffision mode, with the
|
| 201 |
+
shape (batch, seq_len, emb_dim).
|
| 202 |
+
speech_condition: The speech condition of the diffision model, with the
|
| 203 |
+
shape (batch, seq_len, emb_dim).
|
| 204 |
+
padding_mask: The mask for padding; True means masked position, with the
|
| 205 |
+
shape (batch, seq_len).
|
| 206 |
+
num_step: The number of ODE steps.
|
| 207 |
+
guidance_scale: The scale for classifier-free guidance, which is
|
| 208 |
+
a float or a tensor with the shape (batch, 1, 1).
|
| 209 |
+
t_start: the start timestep in the range of [0, 1].
|
| 210 |
+
t_end: the end time_step in the range of [0, 1].
|
| 211 |
+
t_shift: shift the t toward smaller numbers so that the sampling
|
| 212 |
+
will emphasize low SNR region. Should be in the range of (0, 1].
|
| 213 |
+
The shifting will be more significant when the number is smaller.
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
The approximated solution at time `t_end`.
|
| 217 |
+
"""
|
| 218 |
+
device = x.device
|
| 219 |
+
assert isinstance(t_start, float) and isinstance(t_end, float)
|
| 220 |
+
|
| 221 |
+
timesteps = get_time_steps(
|
| 222 |
+
t_start=t_start,
|
| 223 |
+
t_end=t_end,
|
| 224 |
+
num_step=num_step,
|
| 225 |
+
t_shift=t_shift,
|
| 226 |
+
device=device,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
for step in range(num_step):
|
| 230 |
+
v = self.model(
|
| 231 |
+
t=timesteps[step],
|
| 232 |
+
x=x,
|
| 233 |
+
text_condition=text_condition,
|
| 234 |
+
speech_condition=speech_condition,
|
| 235 |
+
padding_mask=padding_mask,
|
| 236 |
+
guidance_scale=guidance_scale,
|
| 237 |
+
**kwargs
|
| 238 |
+
)
|
| 239 |
+
x = x + v * (timesteps[step + 1] - timesteps[step])
|
| 240 |
+
return x
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class DistillEulerSolver(EulerSolver):
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
model: torch.nn.Module,
|
| 247 |
+
func_name: str = "forward_fm_decoder",
|
| 248 |
+
):
|
| 249 |
+
"""Construct a Euler Solver for distilled diffusion models.
|
| 250 |
+
Args:
|
| 251 |
+
model: The diffusion model.
|
| 252 |
+
"""
|
| 253 |
+
self.model = DistillDiffusionModel(model, func_name=func_name)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def get_time_steps(
|
| 257 |
+
t_start: float = 0.0,
|
| 258 |
+
t_end: float = 1.0,
|
| 259 |
+
num_step: int = 10,
|
| 260 |
+
t_shift: float = 1.0,
|
| 261 |
+
device: torch.device = torch.device("cpu"),
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
"""Compute the intermediate time steps for sampling.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
t_start: The starting time of the sampling (default is 0).
|
| 267 |
+
t_end: The starting time of the sampling (default is 1).
|
| 268 |
+
num_step: The number of sampling.
|
| 269 |
+
t_shift: shift the t toward smaller numbers so that the sampling
|
| 270 |
+
will emphasize low SNR region. Should be in the range of (0, 1].
|
| 271 |
+
The shifting will be more significant when the number is smaller.
|
| 272 |
+
device: A torch device.
|
| 273 |
+
Returns:
|
| 274 |
+
The time step with the shape (num_step + 1,).
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
|
| 278 |
+
|
| 279 |
+
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
| 280 |
+
|
| 281 |
+
return timesteps
|
zipvoice/models/modules/zipformer.py
ADDED
|
@@ -0,0 +1,1680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
|
| 3 |
+
# Zengwei Yao,
|
| 4 |
+
# Wei Kang
|
| 5 |
+
# Han Zhu)
|
| 6 |
+
#
|
| 7 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import copy
|
| 22 |
+
import logging
|
| 23 |
+
import math
|
| 24 |
+
import random
|
| 25 |
+
from typing import Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import Tensor, nn
|
| 29 |
+
|
| 30 |
+
from zipvoice.models.modules.scaling import (
|
| 31 |
+
ActivationDropoutAndLinear,
|
| 32 |
+
Balancer,
|
| 33 |
+
BiasNorm,
|
| 34 |
+
Dropout2,
|
| 35 |
+
FloatLike,
|
| 36 |
+
Identity,
|
| 37 |
+
ScaledLinear,
|
| 38 |
+
ScheduledFloat,
|
| 39 |
+
SwooshR,
|
| 40 |
+
Whiten,
|
| 41 |
+
limit_param_value,
|
| 42 |
+
penalize_abs_values_gt,
|
| 43 |
+
softmax,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
| 48 |
+
"""Create sinusoidal timestep embeddings.
|
| 49 |
+
|
| 50 |
+
:param timesteps: shape of (N) or (N, T)
|
| 51 |
+
:param dim: the dimension of the output.
|
| 52 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 53 |
+
:return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
|
| 54 |
+
"""
|
| 55 |
+
half = dim // 2
|
| 56 |
+
freqs = torch.exp(
|
| 57 |
+
-math.log(max_period)
|
| 58 |
+
* torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
|
| 59 |
+
/ half
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if timesteps.dim() == 2:
|
| 63 |
+
timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
|
| 64 |
+
|
| 65 |
+
args = timesteps[..., None].float() * freqs[None]
|
| 66 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 67 |
+
if dim % 2:
|
| 68 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
|
| 69 |
+
return embedding
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TTSZipformer(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
Args:
|
| 75 |
+
|
| 76 |
+
Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
|
| 77 |
+
length as downsampling_factor if they are single ints or one-element tuples.
|
| 78 |
+
The length of downsampling_factor defines the number of stacks.
|
| 79 |
+
|
| 80 |
+
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
|
| 81 |
+
Note: this is in addition to the downsampling factor of 2 that is applied in
|
| 82 |
+
the frontend (self.encoder_embed).
|
| 83 |
+
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
|
| 84 |
+
one per encoder stack.
|
| 85 |
+
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
| 86 |
+
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
| 87 |
+
head: per stack, if a tuple..
|
| 88 |
+
pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
|
| 89 |
+
per attention head
|
| 90 |
+
value_head_dim (int or Tuple[int]): dimension of value in each attention head
|
| 91 |
+
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
|
| 92 |
+
Must be at least 4.
|
| 93 |
+
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
|
| 94 |
+
cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
|
| 95 |
+
|
| 96 |
+
pos_dim (int): the dimension of each positional-encoding vector prior to
|
| 97 |
+
projection, e.g. 128.
|
| 98 |
+
|
| 99 |
+
dropout (float): dropout rate
|
| 100 |
+
warmup_batches (float): number of batches to warm up over; this controls
|
| 101 |
+
dropout of encoder layers.
|
| 102 |
+
use_time_embed: (bool): if True, take time embedding as an additional input.
|
| 103 |
+
time_embed_dim: (int): the dimension of the time embedding.
|
| 104 |
+
use_guidance_scale_embed (bool): if True, take guidance scale embedding as
|
| 105 |
+
an additional input.
|
| 106 |
+
guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
in_dim: int,
|
| 112 |
+
out_dim: int,
|
| 113 |
+
downsampling_factor: Union[int, Tuple[int]] = (2, 4),
|
| 114 |
+
num_encoder_layers: Union[int, Tuple[int]] = 4,
|
| 115 |
+
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
| 116 |
+
encoder_dim: int = 384,
|
| 117 |
+
query_head_dim: int = 24,
|
| 118 |
+
pos_head_dim: int = 4,
|
| 119 |
+
value_head_dim: int = 12,
|
| 120 |
+
num_heads: int = 8,
|
| 121 |
+
feedforward_dim: int = 1536,
|
| 122 |
+
pos_dim: int = 192,
|
| 123 |
+
dropout: FloatLike = None, # see code below for default
|
| 124 |
+
warmup_batches: float = 4000.0,
|
| 125 |
+
use_time_embed: bool = True,
|
| 126 |
+
time_embed_dim: int = 192,
|
| 127 |
+
use_guidance_scale_embed: bool = False,
|
| 128 |
+
guidance_scale_embed_dim: int = 192,
|
| 129 |
+
use_conv: bool = True,
|
| 130 |
+
) -> None:
|
| 131 |
+
super(TTSZipformer, self).__init__()
|
| 132 |
+
|
| 133 |
+
if dropout is None:
|
| 134 |
+
dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
|
| 135 |
+
if isinstance(downsampling_factor, int):
|
| 136 |
+
downsampling_factor = (downsampling_factor,)
|
| 137 |
+
|
| 138 |
+
def _to_tuple(x):
|
| 139 |
+
"""Converts a single int or a 1-tuple of an int to a tuple with the same
|
| 140 |
+
length as downsampling_factor"""
|
| 141 |
+
if isinstance(x, int):
|
| 142 |
+
x = (x,)
|
| 143 |
+
if len(x) == 1:
|
| 144 |
+
x = x * len(downsampling_factor)
|
| 145 |
+
else:
|
| 146 |
+
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
def _assert_downsampling_factor(factors):
|
| 150 |
+
"""assert downsampling_factor follows u-net style"""
|
| 151 |
+
assert factors[0] == 1 and factors[-1] == 1
|
| 152 |
+
|
| 153 |
+
for i in range(1, len(factors) // 2 + 1):
|
| 154 |
+
assert factors[i] == factors[i - 1] * 2
|
| 155 |
+
|
| 156 |
+
for i in range(len(factors) // 2 + 1, len(factors)):
|
| 157 |
+
assert factors[i] * 2 == factors[i - 1]
|
| 158 |
+
|
| 159 |
+
_assert_downsampling_factor(downsampling_factor)
|
| 160 |
+
self.downsampling_factor = downsampling_factor # tuple
|
| 161 |
+
num_encoder_layers = _to_tuple(num_encoder_layers)
|
| 162 |
+
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
| 163 |
+
self.encoder_dim = encoder_dim
|
| 164 |
+
self.num_encoder_layers = num_encoder_layers
|
| 165 |
+
self.query_head_dim = query_head_dim
|
| 166 |
+
self.value_head_dim = value_head_dim
|
| 167 |
+
self.num_heads = num_heads
|
| 168 |
+
|
| 169 |
+
self.use_time_embed = use_time_embed
|
| 170 |
+
self.use_guidance_scale_embed = use_guidance_scale_embed
|
| 171 |
+
|
| 172 |
+
self.time_embed_dim = time_embed_dim
|
| 173 |
+
if self.use_time_embed:
|
| 174 |
+
assert time_embed_dim != -1
|
| 175 |
+
else:
|
| 176 |
+
time_embed_dim = -1
|
| 177 |
+
self.guidance_scale_embed_dim = guidance_scale_embed_dim
|
| 178 |
+
|
| 179 |
+
self.in_proj = nn.Linear(in_dim, encoder_dim)
|
| 180 |
+
self.out_proj = nn.Linear(encoder_dim, out_dim)
|
| 181 |
+
|
| 182 |
+
# each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
|
| 183 |
+
encoders = []
|
| 184 |
+
|
| 185 |
+
num_encoders = len(downsampling_factor)
|
| 186 |
+
for i in range(num_encoders):
|
| 187 |
+
encoder_layer = Zipformer2EncoderLayer(
|
| 188 |
+
embed_dim=encoder_dim,
|
| 189 |
+
pos_dim=pos_dim,
|
| 190 |
+
num_heads=num_heads,
|
| 191 |
+
query_head_dim=query_head_dim,
|
| 192 |
+
pos_head_dim=pos_head_dim,
|
| 193 |
+
value_head_dim=value_head_dim,
|
| 194 |
+
feedforward_dim=feedforward_dim,
|
| 195 |
+
use_conv=use_conv,
|
| 196 |
+
cnn_module_kernel=cnn_module_kernel[i],
|
| 197 |
+
dropout=dropout,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# For the segment of the warmup period, we let the Conv2dSubsampling
|
| 201 |
+
# layer learn something. Then we start to warm up the other encoders.
|
| 202 |
+
encoder = Zipformer2Encoder(
|
| 203 |
+
encoder_layer,
|
| 204 |
+
num_encoder_layers[i],
|
| 205 |
+
embed_dim=encoder_dim,
|
| 206 |
+
time_embed_dim=time_embed_dim,
|
| 207 |
+
pos_dim=pos_dim,
|
| 208 |
+
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
| 209 |
+
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
| 210 |
+
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
if downsampling_factor[i] != 1:
|
| 214 |
+
encoder = DownsampledZipformer2Encoder(
|
| 215 |
+
encoder,
|
| 216 |
+
dim=encoder_dim,
|
| 217 |
+
downsample=downsampling_factor[i],
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
encoders.append(encoder)
|
| 221 |
+
|
| 222 |
+
self.encoders = nn.ModuleList(encoders)
|
| 223 |
+
if self.use_time_embed:
|
| 224 |
+
self.time_embed = nn.Sequential(
|
| 225 |
+
nn.Linear(time_embed_dim, time_embed_dim * 2),
|
| 226 |
+
SwooshR(),
|
| 227 |
+
nn.Linear(time_embed_dim * 2, time_embed_dim),
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
self.time_embed = None
|
| 231 |
+
|
| 232 |
+
if self.use_guidance_scale_embed:
|
| 233 |
+
self.guidance_scale_embed = ScaledLinear(
|
| 234 |
+
guidance_scale_embed_dim,
|
| 235 |
+
time_embed_dim,
|
| 236 |
+
bias=False,
|
| 237 |
+
initial_scale=0.1,
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
self.guidance_scale_embed = None
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
x: Tensor,
|
| 245 |
+
t: Optional[Tensor] = None,
|
| 246 |
+
padding_mask: Optional[Tensor] = None,
|
| 247 |
+
guidance_scale: Optional[Tensor] = None,
|
| 248 |
+
) -> Tuple[Tensor, Tensor]:
|
| 249 |
+
"""
|
| 250 |
+
Args:
|
| 251 |
+
x:
|
| 252 |
+
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
| 253 |
+
t:
|
| 254 |
+
A t tensor of shape (batch_size,) or (batch_size, seq_len)
|
| 255 |
+
padding_mask:
|
| 256 |
+
The mask for padding, of shape (batch_size, seq_len); True means
|
| 257 |
+
masked position. May be None.
|
| 258 |
+
guidance_scale:
|
| 259 |
+
The guidance scale in classifier-free guidance of distillation model.
|
| 260 |
+
Returns:
|
| 261 |
+
Return the output embeddings. its shape is
|
| 262 |
+
(batch_size, output_seq_len, encoder_dim)
|
| 263 |
+
"""
|
| 264 |
+
x = x.permute(1, 0, 2)
|
| 265 |
+
x = self.in_proj(x)
|
| 266 |
+
|
| 267 |
+
if t is not None:
|
| 268 |
+
assert t.dim() == 1 or t.dim() == 2, t.shape
|
| 269 |
+
time_emb = timestep_embedding(t, self.time_embed_dim)
|
| 270 |
+
if guidance_scale is not None:
|
| 271 |
+
assert (
|
| 272 |
+
guidance_scale.dim() == 1 or guidance_scale.dim() == 2
|
| 273 |
+
), guidance_scale.shape
|
| 274 |
+
guidance_scale_emb = self.guidance_scale_embed(
|
| 275 |
+
timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
|
| 276 |
+
)
|
| 277 |
+
time_emb = time_emb + guidance_scale_emb
|
| 278 |
+
time_emb = self.time_embed(time_emb)
|
| 279 |
+
else:
|
| 280 |
+
time_emb = None
|
| 281 |
+
|
| 282 |
+
attn_mask = None
|
| 283 |
+
|
| 284 |
+
for i, module in enumerate(self.encoders):
|
| 285 |
+
x = module(
|
| 286 |
+
x,
|
| 287 |
+
time_emb=time_emb,
|
| 288 |
+
src_key_padding_mask=padding_mask,
|
| 289 |
+
attn_mask=attn_mask,
|
| 290 |
+
)
|
| 291 |
+
x = self.out_proj(x)
|
| 292 |
+
x = x.permute(1, 0, 2)
|
| 293 |
+
return x
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
|
| 297 |
+
return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class Zipformer2EncoderLayer(nn.Module):
|
| 301 |
+
"""
|
| 302 |
+
Args:
|
| 303 |
+
embed_dim: the number of expected features in the input (required).
|
| 304 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 305 |
+
feedforward_dim: the dimension of the feedforward network model (required).
|
| 306 |
+
dropout: the dropout value (default=0.1).
|
| 307 |
+
cnn_module_kernel (int): Kernel size of convolution module (default=31).
|
| 308 |
+
|
| 309 |
+
Examples::
|
| 310 |
+
>>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
|
| 311 |
+
>>> src = torch.rand(10, 32, 512)
|
| 312 |
+
>>> pos_emb = torch.rand(32, 19, 512)
|
| 313 |
+
>>> out = encoder_layer(src, pos_emb)
|
| 314 |
+
"""
|
| 315 |
+
|
| 316 |
+
def __init__(
|
| 317 |
+
self,
|
| 318 |
+
embed_dim: int,
|
| 319 |
+
pos_dim: int,
|
| 320 |
+
num_heads: int,
|
| 321 |
+
query_head_dim: int,
|
| 322 |
+
pos_head_dim: int,
|
| 323 |
+
value_head_dim: int,
|
| 324 |
+
feedforward_dim: int,
|
| 325 |
+
dropout: FloatLike = 0.1,
|
| 326 |
+
cnn_module_kernel: int = 31,
|
| 327 |
+
use_conv: bool = True,
|
| 328 |
+
attention_skip_rate: FloatLike = ScheduledFloat(
|
| 329 |
+
(0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
|
| 330 |
+
),
|
| 331 |
+
conv_skip_rate: FloatLike = ScheduledFloat(
|
| 332 |
+
(0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
|
| 333 |
+
),
|
| 334 |
+
const_attention_rate: FloatLike = ScheduledFloat(
|
| 335 |
+
(0.0, 0.25), (4000.0, 0.025), default=0
|
| 336 |
+
),
|
| 337 |
+
ff2_skip_rate: FloatLike = ScheduledFloat(
|
| 338 |
+
(0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
|
| 339 |
+
),
|
| 340 |
+
ff3_skip_rate: FloatLike = ScheduledFloat(
|
| 341 |
+
(0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
|
| 342 |
+
),
|
| 343 |
+
bypass_skip_rate: FloatLike = ScheduledFloat(
|
| 344 |
+
(0.0, 0.5), (4000.0, 0.02), default=0
|
| 345 |
+
),
|
| 346 |
+
) -> None:
|
| 347 |
+
super(Zipformer2EncoderLayer, self).__init__()
|
| 348 |
+
self.embed_dim = embed_dim
|
| 349 |
+
|
| 350 |
+
# self.bypass implements layer skipping as well as bypass.
|
| 351 |
+
self.bypass = BypassModule(
|
| 352 |
+
embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
|
| 353 |
+
)
|
| 354 |
+
# bypass_mid is bypass used in the middle of the layer.
|
| 355 |
+
self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
|
| 356 |
+
|
| 357 |
+
# skip probability for dynamic modules (meaning: anything but feedforward).
|
| 358 |
+
self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
|
| 359 |
+
# an additional skip probability that applies to ConvModule to stop it from
|
| 360 |
+
# contributing too much early on.
|
| 361 |
+
self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
|
| 362 |
+
|
| 363 |
+
# ff2_skip_rate is to prevent the ff2 module from having output that's too big
|
| 364 |
+
# compared to its residual.
|
| 365 |
+
self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
|
| 366 |
+
self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
|
| 367 |
+
|
| 368 |
+
self.const_attention_rate = copy.deepcopy(const_attention_rate)
|
| 369 |
+
|
| 370 |
+
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
|
| 371 |
+
embed_dim,
|
| 372 |
+
pos_dim=pos_dim,
|
| 373 |
+
num_heads=num_heads,
|
| 374 |
+
query_head_dim=query_head_dim,
|
| 375 |
+
pos_head_dim=pos_head_dim,
|
| 376 |
+
dropout=0.0,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
|
| 380 |
+
|
| 381 |
+
self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
|
| 382 |
+
|
| 383 |
+
self.feed_forward1 = FeedforwardModule(
|
| 384 |
+
embed_dim, (feedforward_dim * 3) // 4, dropout
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
|
| 388 |
+
|
| 389 |
+
self.feed_forward3 = FeedforwardModule(
|
| 390 |
+
embed_dim, (feedforward_dim * 5) // 4, dropout
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
self.nonlin_attention = NonlinAttention(
|
| 394 |
+
embed_dim, hidden_channels=3 * embed_dim // 4
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
self.use_conv = use_conv
|
| 398 |
+
|
| 399 |
+
if self.use_conv:
|
| 400 |
+
self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
|
| 401 |
+
|
| 402 |
+
self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
|
| 403 |
+
|
| 404 |
+
self.norm = BiasNorm(embed_dim)
|
| 405 |
+
|
| 406 |
+
self.balancer1 = Balancer(
|
| 407 |
+
embed_dim,
|
| 408 |
+
channel_dim=-1,
|
| 409 |
+
min_positive=0.45,
|
| 410 |
+
max_positive=0.55,
|
| 411 |
+
min_abs=0.2,
|
| 412 |
+
max_abs=4.0,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# balancer for output of NonlinAttentionModule
|
| 416 |
+
self.balancer_na = Balancer(
|
| 417 |
+
embed_dim,
|
| 418 |
+
channel_dim=-1,
|
| 419 |
+
min_positive=0.3,
|
| 420 |
+
max_positive=0.7,
|
| 421 |
+
min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
|
| 422 |
+
prob=0.05, # out of concern for memory usage
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# balancer for output of feedforward2, prevent it from staying too
|
| 426 |
+
# small. give this a very small probability, even at the start of
|
| 427 |
+
# training, it's to fix a rare problem and it's OK to fix it slowly.
|
| 428 |
+
self.balancer_ff2 = Balancer(
|
| 429 |
+
embed_dim,
|
| 430 |
+
channel_dim=-1,
|
| 431 |
+
min_positive=0.3,
|
| 432 |
+
max_positive=0.7,
|
| 433 |
+
min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
|
| 434 |
+
max_abs=2.0,
|
| 435 |
+
prob=0.05,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
self.balancer_ff3 = Balancer(
|
| 439 |
+
embed_dim,
|
| 440 |
+
channel_dim=-1,
|
| 441 |
+
min_positive=0.3,
|
| 442 |
+
max_positive=0.7,
|
| 443 |
+
min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
|
| 444 |
+
max_abs=4.0,
|
| 445 |
+
prob=0.05,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
self.whiten = Whiten(
|
| 449 |
+
num_groups=1,
|
| 450 |
+
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
|
| 451 |
+
prob=(0.025, 0.25),
|
| 452 |
+
grad_scale=0.01,
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
self.balancer2 = Balancer(
|
| 456 |
+
embed_dim,
|
| 457 |
+
channel_dim=-1,
|
| 458 |
+
min_positive=0.45,
|
| 459 |
+
max_positive=0.55,
|
| 460 |
+
min_abs=0.1,
|
| 461 |
+
max_abs=4.0,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
def get_sequence_dropout_mask(
|
| 465 |
+
self, x: Tensor, dropout_rate: float
|
| 466 |
+
) -> Optional[Tensor]:
|
| 467 |
+
if (
|
| 468 |
+
dropout_rate == 0.0
|
| 469 |
+
or not self.training
|
| 470 |
+
or torch.jit.is_scripting()
|
| 471 |
+
or torch.jit.is_tracing()
|
| 472 |
+
):
|
| 473 |
+
return None
|
| 474 |
+
batch_size = x.shape[1]
|
| 475 |
+
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
|
| 476 |
+
return mask
|
| 477 |
+
|
| 478 |
+
def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
|
| 479 |
+
"""
|
| 480 |
+
Apply sequence-level dropout to x.
|
| 481 |
+
x shape: (seq_len, batch_size, embed_dim)
|
| 482 |
+
"""
|
| 483 |
+
dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
|
| 484 |
+
if dropout_mask is None:
|
| 485 |
+
return x
|
| 486 |
+
else:
|
| 487 |
+
return x * dropout_mask
|
| 488 |
+
|
| 489 |
+
def forward(
|
| 490 |
+
self,
|
| 491 |
+
src: Tensor,
|
| 492 |
+
pos_emb: Tensor,
|
| 493 |
+
time_emb: Optional[Tensor] = None,
|
| 494 |
+
attn_mask: Optional[Tensor] = None,
|
| 495 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 496 |
+
) -> Tensor:
|
| 497 |
+
"""
|
| 498 |
+
Pass the input through the encoder layer.
|
| 499 |
+
Args:
|
| 500 |
+
src: the sequence to the encoder (required):
|
| 501 |
+
shape (seq_len, batch_size, embedding_dim).
|
| 502 |
+
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or
|
| 503 |
+
(batch_size, 2*seq_len-1, pos_emb_dim)
|
| 504 |
+
time_emb: the embedding representing the current timestep
|
| 505 |
+
shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim).
|
| 506 |
+
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
|
| 507 |
+
or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len)
|
| 508 |
+
or (tgt_seq_len, src_seq_len). True means masked position. May be None.
|
| 509 |
+
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
|
| 510 |
+
True means masked position. May be None.
|
| 511 |
+
|
| 512 |
+
Returns:
|
| 513 |
+
A tensor which has the same shape as src
|
| 514 |
+
"""
|
| 515 |
+
src_orig = src
|
| 516 |
+
|
| 517 |
+
# dropout rate for non-feedforward submodules
|
| 518 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 519 |
+
attention_skip_rate = 0.0
|
| 520 |
+
else:
|
| 521 |
+
attention_skip_rate = (
|
| 522 |
+
float(self.attention_skip_rate) if self.training else 0.0
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
|
| 526 |
+
attn_weights = self.self_attn_weights(
|
| 527 |
+
src,
|
| 528 |
+
pos_emb=pos_emb,
|
| 529 |
+
attn_mask=attn_mask,
|
| 530 |
+
key_padding_mask=src_key_padding_mask,
|
| 531 |
+
)
|
| 532 |
+
if time_emb is not None:
|
| 533 |
+
|
| 534 |
+
src = src + time_emb
|
| 535 |
+
|
| 536 |
+
src = src + self.feed_forward1(src)
|
| 537 |
+
|
| 538 |
+
self_attn_dropout_mask = self.get_sequence_dropout_mask(
|
| 539 |
+
src, attention_skip_rate
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
selected_attn_weights = attn_weights[0:1]
|
| 543 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 544 |
+
pass
|
| 545 |
+
elif self.training and random.random() < float(self.const_attention_rate):
|
| 546 |
+
# Make attention weights constant. The intention is to
|
| 547 |
+
# encourage these modules to do something similar to an
|
| 548 |
+
# averaging-over-time operation.
|
| 549 |
+
# only need the mask, can just use the 1st one and expand later
|
| 550 |
+
selected_attn_weights = selected_attn_weights[0:1]
|
| 551 |
+
selected_attn_weights = (selected_attn_weights > 0.0).to(
|
| 552 |
+
selected_attn_weights.dtype
|
| 553 |
+
)
|
| 554 |
+
selected_attn_weights = selected_attn_weights * (
|
| 555 |
+
1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
|
| 559 |
+
|
| 560 |
+
src = src + (
|
| 561 |
+
na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
self_attn = self.self_attn1(src, attn_weights)
|
| 565 |
+
|
| 566 |
+
src = src + (
|
| 567 |
+
self_attn
|
| 568 |
+
if self_attn_dropout_mask is None
|
| 569 |
+
else self_attn * self_attn_dropout_mask
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
if self.use_conv:
|
| 573 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 574 |
+
conv_skip_rate = 0.0
|
| 575 |
+
else:
|
| 576 |
+
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
| 577 |
+
|
| 578 |
+
if time_emb is not None:
|
| 579 |
+
src = src + time_emb
|
| 580 |
+
|
| 581 |
+
src = src + self.sequence_dropout(
|
| 582 |
+
self.conv_module1(
|
| 583 |
+
src,
|
| 584 |
+
src_key_padding_mask=src_key_padding_mask,
|
| 585 |
+
),
|
| 586 |
+
conv_skip_rate,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 590 |
+
ff2_skip_rate = 0.0
|
| 591 |
+
else:
|
| 592 |
+
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
|
| 593 |
+
src = src + self.sequence_dropout(
|
| 594 |
+
self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# bypass in the middle of the layer.
|
| 598 |
+
src = self.bypass_mid(src_orig, src)
|
| 599 |
+
|
| 600 |
+
self_attn = self.self_attn2(src, attn_weights)
|
| 601 |
+
|
| 602 |
+
src = src + (
|
| 603 |
+
self_attn
|
| 604 |
+
if self_attn_dropout_mask is None
|
| 605 |
+
else self_attn * self_attn_dropout_mask
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
if self.use_conv:
|
| 609 |
+
|
| 610 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 611 |
+
conv_skip_rate = 0.0
|
| 612 |
+
else:
|
| 613 |
+
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
|
| 614 |
+
|
| 615 |
+
if time_emb is not None:
|
| 616 |
+
src = src + time_emb
|
| 617 |
+
|
| 618 |
+
src = src + self.sequence_dropout(
|
| 619 |
+
self.conv_module2(
|
| 620 |
+
src,
|
| 621 |
+
src_key_padding_mask=src_key_padding_mask,
|
| 622 |
+
),
|
| 623 |
+
conv_skip_rate,
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 627 |
+
ff3_skip_rate = 0.0
|
| 628 |
+
else:
|
| 629 |
+
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
|
| 630 |
+
src = src + self.sequence_dropout(
|
| 631 |
+
self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
src = self.balancer1(src)
|
| 635 |
+
src = self.norm(src)
|
| 636 |
+
|
| 637 |
+
src = self.bypass(src_orig, src)
|
| 638 |
+
|
| 639 |
+
src = self.balancer2(src)
|
| 640 |
+
src = self.whiten(src)
|
| 641 |
+
|
| 642 |
+
return src
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class Zipformer2Encoder(nn.Module):
|
| 646 |
+
r"""Zipformer2Encoder is a stack of N encoder layers
|
| 647 |
+
|
| 648 |
+
Args:
|
| 649 |
+
encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
|
| 650 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
| 651 |
+
pos_dim: the dimension for the relative positional encoding
|
| 652 |
+
|
| 653 |
+
Examples::
|
| 654 |
+
>>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
|
| 655 |
+
>>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
|
| 656 |
+
>>> src = torch.rand(10, 32, 512)
|
| 657 |
+
>>> out = zipformer_encoder(src)
|
| 658 |
+
"""
|
| 659 |
+
|
| 660 |
+
def __init__(
|
| 661 |
+
self,
|
| 662 |
+
encoder_layer: nn.Module,
|
| 663 |
+
num_layers: int,
|
| 664 |
+
embed_dim: int,
|
| 665 |
+
time_embed_dim: int,
|
| 666 |
+
pos_dim: int,
|
| 667 |
+
warmup_begin: float,
|
| 668 |
+
warmup_end: float,
|
| 669 |
+
initial_layerdrop_rate: float = 0.5,
|
| 670 |
+
final_layerdrop_rate: float = 0.05,
|
| 671 |
+
) -> None:
|
| 672 |
+
super().__init__()
|
| 673 |
+
self.encoder_pos = CompactRelPositionalEncoding(
|
| 674 |
+
pos_dim, dropout_rate=0.15, length_factor=1.0
|
| 675 |
+
)
|
| 676 |
+
if time_embed_dim != -1:
|
| 677 |
+
self.time_emb = nn.Sequential(
|
| 678 |
+
SwooshR(),
|
| 679 |
+
nn.Linear(time_embed_dim, embed_dim),
|
| 680 |
+
)
|
| 681 |
+
else:
|
| 682 |
+
self.time_emb = None
|
| 683 |
+
|
| 684 |
+
self.layers = nn.ModuleList(
|
| 685 |
+
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
| 686 |
+
)
|
| 687 |
+
self.num_layers = num_layers
|
| 688 |
+
|
| 689 |
+
assert 0 <= warmup_begin <= warmup_end
|
| 690 |
+
|
| 691 |
+
delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
|
| 692 |
+
cur_begin = warmup_begin # interpreted as a training batch index
|
| 693 |
+
for i in range(num_layers):
|
| 694 |
+
cur_end = cur_begin + delta
|
| 695 |
+
self.layers[i].bypass.skip_rate = ScheduledFloat(
|
| 696 |
+
(cur_begin, initial_layerdrop_rate),
|
| 697 |
+
(cur_end, final_layerdrop_rate),
|
| 698 |
+
default=0.0,
|
| 699 |
+
)
|
| 700 |
+
cur_begin = cur_end
|
| 701 |
+
|
| 702 |
+
def forward(
|
| 703 |
+
self,
|
| 704 |
+
src: Tensor,
|
| 705 |
+
time_emb: Optional[Tensor] = None,
|
| 706 |
+
attn_mask: Optional[Tensor] = None,
|
| 707 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 708 |
+
) -> Tensor:
|
| 709 |
+
r"""Pass the input through the encoder layers in turn.
|
| 710 |
+
|
| 711 |
+
Args:
|
| 712 |
+
src: the sequence to the encoder (required):
|
| 713 |
+
shape (seq_len, batch_size, embedding_dim).
|
| 714 |
+
time_emb: the embedding representing the current timestep:
|
| 715 |
+
shape (batch_size, embedding_dim)
|
| 716 |
+
or (seq_len, batch_size, embedding_dim) .
|
| 717 |
+
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
|
| 718 |
+
or (seq_len, seq_len), interpreted as
|
| 719 |
+
(batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
|
| 720 |
+
True means masked position. May be None.
|
| 721 |
+
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
|
| 722 |
+
True means masked position. May be None.
|
| 723 |
+
|
| 724 |
+
Returns: a Tensor with the same shape as src.
|
| 725 |
+
"""
|
| 726 |
+
pos_emb = self.encoder_pos(src)
|
| 727 |
+
if self.time_emb is not None:
|
| 728 |
+
assert time_emb is not None
|
| 729 |
+
time_emb = self.time_emb(time_emb)
|
| 730 |
+
else:
|
| 731 |
+
assert time_emb is None
|
| 732 |
+
|
| 733 |
+
output = src
|
| 734 |
+
|
| 735 |
+
for i, mod in enumerate(self.layers):
|
| 736 |
+
output = mod(
|
| 737 |
+
output,
|
| 738 |
+
pos_emb,
|
| 739 |
+
time_emb=time_emb,
|
| 740 |
+
attn_mask=attn_mask,
|
| 741 |
+
src_key_padding_mask=src_key_padding_mask,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
return output
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class BypassModule(nn.Module):
|
| 748 |
+
"""
|
| 749 |
+
An nn.Module that implements a learnable bypass scale, and also randomized
|
| 750 |
+
per-sequence layer-skipping. The bypass is limited during early stages of training
|
| 751 |
+
to be close to "straight-through", i.e. to not do the bypass operation much
|
| 752 |
+
initially, in order to force all the modules to learn something.
|
| 753 |
+
"""
|
| 754 |
+
|
| 755 |
+
def __init__(
|
| 756 |
+
self,
|
| 757 |
+
embed_dim: int,
|
| 758 |
+
skip_rate: FloatLike = 0.0,
|
| 759 |
+
straight_through_rate: FloatLike = 0.0,
|
| 760 |
+
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
|
| 761 |
+
scale_max: FloatLike = 1.0,
|
| 762 |
+
):
|
| 763 |
+
super().__init__()
|
| 764 |
+
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
|
| 765 |
+
self.skip_rate = copy.deepcopy(skip_rate)
|
| 766 |
+
self.straight_through_rate = copy.deepcopy(straight_through_rate)
|
| 767 |
+
self.scale_min = copy.deepcopy(scale_min)
|
| 768 |
+
self.scale_max = copy.deepcopy(scale_max)
|
| 769 |
+
|
| 770 |
+
def _get_bypass_scale(self, batch_size: int):
|
| 771 |
+
# returns bypass-scale of shape (num_channels,),
|
| 772 |
+
# or (batch_size, num_channels,). This is actually the
|
| 773 |
+
# scale on the non-residual term, so 0 corresponds to bypassing
|
| 774 |
+
# this module.
|
| 775 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
| 776 |
+
return self.bypass_scale
|
| 777 |
+
else:
|
| 778 |
+
ans = limit_param_value(
|
| 779 |
+
self.bypass_scale,
|
| 780 |
+
min=float(self.scale_min),
|
| 781 |
+
max=float(self.scale_max),
|
| 782 |
+
)
|
| 783 |
+
skip_rate = float(self.skip_rate)
|
| 784 |
+
if skip_rate != 0.0:
|
| 785 |
+
mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
|
| 786 |
+
ans = ans * mask
|
| 787 |
+
# now ans is of shape (batch_size, num_channels), and is zero for
|
| 788 |
+
# sequences on which we have randomly chosen to do layer-skipping.
|
| 789 |
+
straight_through_rate = float(self.straight_through_rate)
|
| 790 |
+
if straight_through_rate != 0.0:
|
| 791 |
+
mask = (
|
| 792 |
+
torch.rand((batch_size, 1), device=ans.device)
|
| 793 |
+
< straight_through_rate
|
| 794 |
+
)
|
| 795 |
+
ans = torch.maximum(ans, mask.to(ans.dtype))
|
| 796 |
+
return ans
|
| 797 |
+
|
| 798 |
+
def forward(self, src_orig: Tensor, src: Tensor):
|
| 799 |
+
"""
|
| 800 |
+
Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
|
| 801 |
+
Returns: something with the same shape as src and src_orig
|
| 802 |
+
"""
|
| 803 |
+
bypass_scale = self._get_bypass_scale(src.shape[1])
|
| 804 |
+
return src_orig + (src - src_orig) * bypass_scale
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
class DownsampledZipformer2Encoder(nn.Module):
|
| 808 |
+
r"""
|
| 809 |
+
DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame
|
| 810 |
+
rate, after convolutional downsampling, and then upsampled again at the output, and
|
| 811 |
+
combined with the origin input, so that the output has the same shape as the input.
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
def __init__(self, encoder: nn.Module, dim: int, downsample: int):
|
| 815 |
+
super(DownsampledZipformer2Encoder, self).__init__()
|
| 816 |
+
self.downsample_factor = downsample
|
| 817 |
+
self.downsample = SimpleDownsample(downsample)
|
| 818 |
+
self.num_layers = encoder.num_layers
|
| 819 |
+
self.encoder = encoder
|
| 820 |
+
self.upsample = SimpleUpsample(downsample)
|
| 821 |
+
self.out_combiner = BypassModule(dim, straight_through_rate=0)
|
| 822 |
+
|
| 823 |
+
def forward(
|
| 824 |
+
self,
|
| 825 |
+
src: Tensor,
|
| 826 |
+
time_emb: Optional[Tensor] = None,
|
| 827 |
+
attn_mask: Optional[Tensor] = None,
|
| 828 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 829 |
+
) -> Tensor:
|
| 830 |
+
r"""Downsample, go through encoder, upsample.
|
| 831 |
+
|
| 832 |
+
Args:
|
| 833 |
+
src: the sequence to the encoder (required):
|
| 834 |
+
shape (seq_len, batch_size, embedding_dim).
|
| 835 |
+
time_emb: the embedding representing the current timestep:
|
| 836 |
+
shape (batch_size, embedding_dim)
|
| 837 |
+
or (seq_len, batch_size, embedding_dim) .
|
| 838 |
+
feature_mask: something that broadcasts with src, that we'll multiply `src`
|
| 839 |
+
by at every layer: if a Tensor, likely of shape
|
| 840 |
+
(seq_len, batch_size, embedding_dim)
|
| 841 |
+
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
|
| 842 |
+
or (seq_len, seq_len), interpreted as
|
| 843 |
+
(batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
|
| 844 |
+
True means masked position. May be None.
|
| 845 |
+
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
|
| 846 |
+
True means masked position. May be None.
|
| 847 |
+
|
| 848 |
+
Returns: a Tensor with the same shape as src.
|
| 849 |
+
"""
|
| 850 |
+
src_orig = src
|
| 851 |
+
src = self.downsample(src)
|
| 852 |
+
ds = self.downsample_factor
|
| 853 |
+
if time_emb is not None and time_emb.dim() == 3:
|
| 854 |
+
time_emb = time_emb[::ds]
|
| 855 |
+
if attn_mask is not None:
|
| 856 |
+
attn_mask = attn_mask[::ds, ::ds]
|
| 857 |
+
if src_key_padding_mask is not None:
|
| 858 |
+
src_key_padding_mask = src_key_padding_mask[..., ::ds]
|
| 859 |
+
|
| 860 |
+
src = self.encoder(
|
| 861 |
+
src,
|
| 862 |
+
time_emb=time_emb,
|
| 863 |
+
attn_mask=attn_mask,
|
| 864 |
+
src_key_padding_mask=src_key_padding_mask,
|
| 865 |
+
)
|
| 866 |
+
src = self.upsample(src)
|
| 867 |
+
# remove any extra frames that are not a multiple of downsample_factor
|
| 868 |
+
src = src[: src_orig.shape[0]]
|
| 869 |
+
|
| 870 |
+
return self.out_combiner(src_orig, src)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
class SimpleDownsample(torch.nn.Module):
|
| 874 |
+
"""
|
| 875 |
+
Does downsampling by taking a learned (softmax-weighted) sum over each group of `downsample` frames.
|
| 876 |
+
"""
|
| 877 |
+
|
| 878 |
+
def __init__(self, downsample: int):
|
| 879 |
+
super(SimpleDownsample, self).__init__()
|
| 880 |
+
|
| 881 |
+
self.bias = nn.Parameter(torch.zeros(downsample))
|
| 882 |
+
|
| 883 |
+
self.name = None # will be set from training code
|
| 884 |
+
|
| 885 |
+
self.downsample = downsample
|
| 886 |
+
|
| 887 |
+
def forward(self, src: Tensor) -> Tensor:
|
| 888 |
+
"""
|
| 889 |
+
x: (seq_len, batch_size, in_channels)
|
| 890 |
+
Returns a tensor of shape
|
| 891 |
+
( (seq_len+downsample-1)//downsample, batch_size, channels)
|
| 892 |
+
"""
|
| 893 |
+
(seq_len, batch_size, in_channels) = src.shape
|
| 894 |
+
ds = self.downsample
|
| 895 |
+
d_seq_len = (seq_len + ds - 1) // ds
|
| 896 |
+
|
| 897 |
+
# Pad to an exact multiple of self.downsample
|
| 898 |
+
# right-pad src, repeating the last element.
|
| 899 |
+
pad = d_seq_len * ds - seq_len
|
| 900 |
+
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
|
| 901 |
+
src = torch.cat((src, src_extra), dim=0)
|
| 902 |
+
assert src.shape[0] == d_seq_len * ds
|
| 903 |
+
|
| 904 |
+
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
| 905 |
+
|
| 906 |
+
weights = self.bias.softmax(dim=0)
|
| 907 |
+
# weights: (downsample, 1, 1)
|
| 908 |
+
weights = weights.unsqueeze(-1).unsqueeze(-1)
|
| 909 |
+
|
| 910 |
+
# weighted sum over the `ds` frames in each group
|
| 911 |
+
ans = (src * weights).sum(dim=1)
|
| 912 |
+
|
| 913 |
+
return ans
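# --- Illustrative sketch (not part of the original file): the same weighted-sum
# downsampling written with plain tensor ops, assuming seq_len is already a
# multiple of the downsample factor. `bias` and `ds` are hypothetical names
# mirroring the attributes above.
import torch

def weighted_downsample(src: torch.Tensor, bias: torch.Tensor, ds: int) -> torch.Tensor:
    seq_len, batch_size, channels = src.shape
    assert seq_len % ds == 0
    weights = bias.softmax(dim=0).view(1, ds, 1, 1)            # (1, ds, 1, 1)
    src = src.reshape(seq_len // ds, ds, batch_size, channels)
    return (src * weights).sum(dim=1)                          # (seq_len // ds, batch, channels)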
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
class SimpleUpsample(torch.nn.Module):
|
| 917 |
+
"""
|
| 918 |
+
A very simple form of upsampling that just repeats the input.
|
| 919 |
+
"""
|
| 920 |
+
|
| 921 |
+
def __init__(self, upsample: int):
|
| 922 |
+
super(SimpleUpsample, self).__init__()
|
| 923 |
+
self.upsample = upsample
|
| 924 |
+
|
| 925 |
+
def forward(self, src: Tensor) -> Tensor:
|
| 926 |
+
"""
|
| 927 |
+
x: (seq_len, batch_size, num_channels)
|
| 928 |
+
Returns a tensor of shape
|
| 929 |
+
( (seq_len*upsample), batch_size, num_channels)
|
| 930 |
+
"""
|
| 931 |
+
upsample = self.upsample
|
| 932 |
+
(seq_len, batch_size, num_channels) = src.shape
|
| 933 |
+
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
|
| 934 |
+
src = src.reshape(seq_len * upsample, batch_size, num_channels)
|
| 935 |
+
return src
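# --- Illustrative sketch (not part of the original file): SimpleUpsample is
# equivalent to repeating each frame `upsample` times along the time axis.
import torch

src = torch.randn(5, 2, 3)                                  # (seq_len, batch, channels)
upsampled = torch.repeat_interleave(src, repeats=4, dim=0)  # same result as SimpleUpsample(4)
assert upsampled.shape == (20, 2, 3)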
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
class CompactRelPositionalEncoding(torch.nn.Module):
|
| 939 |
+
"""
|
| 940 |
+
Relative positional encoding module. This version is "compact" meaning it is able
|
| 941 |
+
to encode the important information about the relative position in a relatively
|
| 942 |
+
small number of dimensions. The goal is to make it so that small differences between
|
| 943 |
+
large relative offsets (e.g. 1000 vs. 1001) make very little difference to the
|
| 944 |
+
embedding. Such differences were potentially important when encoding absolute
|
| 945 |
+
position, but not important when encoding relative position because there is now no
|
| 946 |
+
need to compare two large offsets with each other.
|
| 947 |
+
|
| 948 |
+
Our embedding works by projecting the interval [-infinity,infinity] to a finite
|
| 949 |
+
interval using the atan() function, before doing the Fourier transform of that fixed
|
| 950 |
+
interval. The atan() function would compress the "long tails" too small, making it
|
| 951 |
+
hard to distinguish between different magnitudes of large offsets, so we use a
|
| 952 |
+
logarithmic function to compress large offsets to a smaller range before applying
|
| 953 |
+
atan(). Scalings are chosen in such a way that the embedding can clearly distinguish
|
| 954 |
+
individual offsets as long as they are quite close to the origin, e.g. abs(offset)
|
| 955 |
+
<= about sqrt(embedding_dim)
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
Args:
|
| 959 |
+
embed_dim: Embedding dimension.
|
| 960 |
+
dropout_rate: Dropout rate.
|
| 961 |
+
max_len: Maximum input length: just a heuristic for initialization.
|
| 962 |
+
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
|
| 963 |
+
less weight to small differences of offset near the origin.
|
| 964 |
+
"""
|
| 965 |
+
|
| 966 |
+
def __init__(
|
| 967 |
+
self,
|
| 968 |
+
embed_dim: int,
|
| 969 |
+
dropout_rate: FloatLike,
|
| 970 |
+
max_len: int = 1000,
|
| 971 |
+
length_factor: float = 1.0,
|
| 972 |
+
) -> None:
|
| 973 |
+
"""Construct a CompactRelPositionalEncoding object."""
|
| 974 |
+
super(CompactRelPositionalEncoding, self).__init__()
|
| 975 |
+
self.embed_dim = embed_dim
|
| 976 |
+
assert embed_dim % 2 == 0, embed_dim
|
| 977 |
+
self.dropout = Dropout2(dropout_rate)
|
| 978 |
+
self.pe = None
|
| 979 |
+
assert length_factor >= 1.0, length_factor
|
| 980 |
+
self.length_factor = length_factor
|
| 981 |
+
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
| 982 |
+
|
| 983 |
+
def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
|
| 984 |
+
"""Reset the positional encodings."""
|
| 985 |
+
T = x.size(0) + left_context_len
|
| 986 |
+
|
| 987 |
+
if self.pe is not None:
|
| 988 |
+
# self.pe contains both positive and negative parts
|
| 989 |
+
# the length of self.pe is 2 * input_len - 1
|
| 990 |
+
if self.pe.size(0) >= T * 2 - 1:
|
| 991 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 992 |
+
return
|
| 993 |
+
|
| 994 |
+
# if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
|
| 995 |
+
x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
|
| 996 |
+
|
| 997 |
+
freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
|
| 998 |
+
|
| 999 |
+
# `compression_length` this is arbitrary/heuristic, if it is larger we have more
|
| 1000 |
+
# resolution for small time offsets but less resolution for large time offsets.
|
| 1001 |
+
compression_length = self.embed_dim**0.5
|
| 1002 |
+
# x_compressed, like X, goes from -infinity to infinity as T goes from -infinity
|
| 1003 |
+
# to infinity; but it does so more slowly than T for large absolute values of T.
|
| 1004 |
+
# The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is
|
| 1005 |
+
# important.
|
| 1006 |
+
x_compressed = (
|
| 1007 |
+
compression_length
|
| 1008 |
+
* x.sign()
|
| 1009 |
+
* ((x.abs() + compression_length).log() - math.log(compression_length))
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
# if self.length_factor == 1.0, then length_scale is chosen so that the
|
| 1013 |
+
# FFT can exactly separate points close to the origin (T == 0). So this
|
| 1014 |
+
# part of the formulation is not really heuristic.
|
| 1015 |
+
# But empirically, for ASR at least, length_factor > 1.0 seems to work better.
|
| 1016 |
+
length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
|
| 1017 |
+
|
| 1018 |
+
# note for machine implementations: if atan is not available, we can use:
|
| 1019 |
+
# x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
|
| 1020 |
+
# check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 ,
|
| 1021 |
+
# atan(x))
|
| 1022 |
+
x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
|
| 1023 |
+
|
| 1024 |
+
cosines = (x_atan * freqs).cos()
|
| 1025 |
+
sines = (x_atan * freqs).sin()
|
| 1026 |
+
|
| 1027 |
+
pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
|
| 1028 |
+
pe[:, 0::2] = cosines
|
| 1029 |
+
pe[:, 1::2] = sines
|
| 1030 |
+
pe[:, -1] = 1.0 # for bias.
|
| 1031 |
+
|
| 1032 |
+
self.pe = pe.to(dtype=x.dtype)
|
| 1033 |
+
|
| 1034 |
+
def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
|
| 1035 |
+
"""Create positional encoding.
|
| 1036 |
+
|
| 1037 |
+
Args:
|
| 1038 |
+
x (Tensor): Input tensor (time, batch, `*`).
|
| 1039 |
+
left_context_len: (int): Length of cached left context.
|
| 1040 |
+
|
| 1041 |
+
Returns:
|
| 1042 |
+
positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
|
| 1043 |
+
"""
|
| 1044 |
+
self.extend_pe(x, left_context_len)
|
| 1045 |
+
x_size_left = x.size(0) + left_context_len
|
| 1046 |
+
# length of positive side: x.size(0) + left_context_len
|
| 1047 |
+
# length of negative side: x.size(0)
|
| 1048 |
+
pos_emb = self.pe[
|
| 1049 |
+
self.pe.size(0) // 2
|
| 1050 |
+
- x_size_left
|
| 1051 |
+
+ 1 : self.pe.size(0) // 2 # noqa E203
|
| 1052 |
+
+ x.size(0),
|
| 1053 |
+
:,
|
| 1054 |
+
]
|
| 1055 |
+
pos_emb = pos_emb.unsqueeze(0)
|
| 1056 |
+
return self.dropout(pos_emb)
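# --- Illustrative sketch (not part of the original file): the offset compression
# described above, as a standalone function. A relative offset is first squashed
# logarithmically, then mapped into (-pi/2, pi/2) with atan(); small offsets stay
# distinguishable while very large ones saturate. Argument names mirror the
# constructor arguments; this is an assumption-laden sketch, not the module itself.
import math
import torch

def compress_offsets(offsets: torch.Tensor, embed_dim: int, length_factor: float = 1.0) -> torch.Tensor:
    compression_length = embed_dim ** 0.5
    x_compressed = (
        compression_length
        * offsets.sign()
        * ((offsets.abs() + compression_length).log() - math.log(compression_length))
    )
    length_scale = length_factor * embed_dim / (2.0 * math.pi)
    return (x_compressed / length_scale).atan()

# e.g. offsets 0, 1, 10 map to clearly different values; 1000 vs. 1001 are nearly equal.
vals = compress_offsets(torch.tensor([0.0, 1.0, 10.0, 1000.0, 1001.0]), embed_dim=64)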
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
class RelPositionMultiheadAttentionWeights(nn.Module):
|
| 1060 |
+
r"""Module that computes multi-head attention weights with relative position
|
| 1061 |
+
encoding. Various other modules consume the resulting attention weights:
|
| 1062 |
+
see, for example, the SimpleAttention module which allows you to compute
|
| 1063 |
+
conventional attention.
|
| 1064 |
+
|
| 1065 |
+
This is quite heavily modified from: "Transformer-XL: Attentive Language
|
| 1066 |
+
Models Beyond a Fixed-Length Context",
|
| 1067 |
+
we have to write up the differences.
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
Args:
|
| 1071 |
+
embed_dim: number of channels at the input to this module, e.g. 256
|
| 1072 |
+
pos_dim: dimension of the positional encoding vectors, e.g. 128.
|
| 1073 |
+
num_heads: number of heads to compute weights for, e.g. 8
|
| 1074 |
+
query_head_dim: dimension of the query (and key), per head. e.g. 24.
|
| 1075 |
+
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
|
| 1076 |
+
dropout: dropout probability for attn_output_weights. Default: 0.0.
|
| 1077 |
+
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
|
| 1078 |
+
any given call to forward(), in training time.
|
| 1079 |
+
"""
|
| 1080 |
+
|
| 1081 |
+
def __init__(
|
| 1082 |
+
self,
|
| 1083 |
+
embed_dim: int,
|
| 1084 |
+
pos_dim: int,
|
| 1085 |
+
num_heads: int,
|
| 1086 |
+
query_head_dim: int,
|
| 1087 |
+
pos_head_dim: int,
|
| 1088 |
+
dropout: float = 0.0,
|
| 1089 |
+
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
|
| 1090 |
+
) -> None:
|
| 1091 |
+
super().__init__()
|
| 1092 |
+
self.embed_dim = embed_dim
|
| 1093 |
+
self.num_heads = num_heads
|
| 1094 |
+
self.query_head_dim = query_head_dim
|
| 1095 |
+
self.pos_head_dim = pos_head_dim
|
| 1096 |
+
self.dropout = dropout
|
| 1097 |
+
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
|
| 1098 |
+
self.name = None # will be overwritten in training code; for diagnostics.
|
| 1099 |
+
|
| 1100 |
+
key_head_dim = query_head_dim
|
| 1101 |
+
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
|
| 1102 |
+
|
| 1103 |
+
# the initial_scale is supposed to take over the "scaling" factor of
|
| 1104 |
+
# head_dim ** -0.5 that has been used in previous forms of attention,
|
| 1105 |
+
# dividing it between the query and key. Note: this module is intended
|
| 1106 |
+
# to be used with the ScaledAdam optimizer; with most other optimizers,
|
| 1107 |
+
# it would be necessary to apply the scaling factor in the forward function.
|
| 1108 |
+
self.in_proj = ScaledLinear(
|
| 1109 |
+
embed_dim,
|
| 1110 |
+
in_proj_dim,
|
| 1111 |
+
bias=True,
|
| 1112 |
+
initial_scale=query_head_dim**-0.25,
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
self.whiten_keys = Whiten(
|
| 1116 |
+
num_groups=num_heads,
|
| 1117 |
+
whitening_limit=_whitening_schedule(3.0),
|
| 1118 |
+
prob=(0.025, 0.25),
|
| 1119 |
+
grad_scale=0.025,
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
# add a balancer for the keys that runs with very small probability, and
|
| 1123 |
+
# tries to enforce that all dimensions have mean around zero. The
|
| 1124 |
+
# weights produced by this module are invariant to adding a constant to
|
| 1125 |
+
# the keys, so the derivative of the bias is mathematically zero; but
|
| 1126 |
+
# due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
|
| 1127 |
+
# bias because the small numerical roundoff tends to have a non-random
|
| 1128 |
+
# sign. This module is intended to prevent that. Use a very small
|
| 1129 |
+
# probability; that should be sufficient to fix the problem.
|
| 1130 |
+
self.balance_keys = Balancer(
|
| 1131 |
+
key_head_dim * num_heads,
|
| 1132 |
+
channel_dim=-1,
|
| 1133 |
+
min_positive=0.4,
|
| 1134 |
+
max_positive=0.6,
|
| 1135 |
+
min_abs=0.0,
|
| 1136 |
+
max_abs=100.0,
|
| 1137 |
+
prob=0.025,
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
# linear transformation for positional encoding.
|
| 1141 |
+
self.linear_pos = ScaledLinear(
|
| 1142 |
+
pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
|
| 1143 |
+
)
|
| 1144 |
+
|
| 1145 |
+
# the following are for diagnostics only, see --print-diagnostics option
|
| 1146 |
+
self.copy_pos_query = Identity()
|
| 1147 |
+
self.copy_query = Identity()
|
| 1148 |
+
|
| 1149 |
+
def forward(
|
| 1150 |
+
self,
|
| 1151 |
+
x: Tensor,
|
| 1152 |
+
pos_emb: Tensor,
|
| 1153 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 1154 |
+
attn_mask: Optional[Tensor] = None,
|
| 1155 |
+
) -> Tensor:
|
| 1156 |
+
r"""
|
| 1157 |
+
Args:
|
| 1158 |
+
x: input of shape (seq_len, batch_size, embed_dim)
|
| 1159 |
+
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
|
| 1160 |
+
key_padding_mask: a bool tensor of shape (batch_size, seq_len).
|
| 1161 |
+
Positions that are True in this mask will be ignored as sources in the
|
| 1162 |
+
attention weighting.
|
| 1163 |
+
attn_mask: mask of shape (seq_len, seq_len) or
|
| 1164 |
+
(batch_size, seq_len, seq_len), interpreted as
|
| 1165 |
+
([batch_size,] tgt_seq_len, src_seq_len)
|
| 1166 |
+
saying which positions are allowed to attend to which other positions.
|
| 1167 |
+
Returns:
|
| 1168 |
+
a tensor of attention weights, of
|
| 1169 |
+
shape (num_heads, batch_size, seq_len, seq_len)
|
| 1170 |
+
interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
|
| 1171 |
+
"""
|
| 1172 |
+
x = self.in_proj(x)
|
| 1173 |
+
query_head_dim = self.query_head_dim
|
| 1174 |
+
pos_head_dim = self.pos_head_dim
|
| 1175 |
+
num_heads = self.num_heads
|
| 1176 |
+
|
| 1177 |
+
seq_len, batch_size, _ = x.shape
|
| 1178 |
+
|
| 1179 |
+
query_dim = query_head_dim * num_heads
|
| 1180 |
+
|
| 1181 |
+
# self-attention
|
| 1182 |
+
q = x[..., 0:query_dim]
|
| 1183 |
+
k = x[..., query_dim : 2 * query_dim]
|
| 1184 |
+
# p is the position-encoding query
|
| 1185 |
+
p = x[..., 2 * query_dim :]
|
| 1186 |
+
assert p.shape[-1] == num_heads * pos_head_dim, (
|
| 1187 |
+
p.shape[-1],
|
| 1188 |
+
num_heads,
|
| 1189 |
+
pos_head_dim,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
q = self.copy_query(q) # for diagnostics only, does nothing.
|
| 1193 |
+
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
|
| 1194 |
+
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
|
| 1195 |
+
|
| 1196 |
+
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
|
| 1197 |
+
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
|
| 1198 |
+
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
|
| 1199 |
+
|
| 1200 |
+
# time1 refers to target, time2 refers to source.
|
| 1201 |
+
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
|
| 1202 |
+
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
|
| 1203 |
+
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
|
| 1204 |
+
|
| 1205 |
+
attn_scores = torch.matmul(q, k)
|
| 1206 |
+
|
| 1207 |
+
use_pos_scores = False
|
| 1208 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1209 |
+
# random.random() can't be used when scripting/tracing, so always use the pos scores.
|
| 1210 |
+
use_pos_scores = True
|
| 1211 |
+
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
| 1212 |
+
use_pos_scores = True
|
| 1213 |
+
|
| 1214 |
+
if use_pos_scores:
|
| 1215 |
+
pos_emb = self.linear_pos(pos_emb)
|
| 1216 |
+
seq_len2 = 2 * seq_len - 1
|
| 1217 |
+
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
| 1218 |
+
2, 0, 3, 1
|
| 1219 |
+
)
|
| 1220 |
+
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
| 1221 |
+
|
| 1222 |
+
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head,
|
| 1223 |
+
# batch, time1, seq_len2) [where seq_len2 represents relative position.]
|
| 1224 |
+
pos_scores = torch.matmul(p, pos_emb)
|
| 1225 |
+
# the following .as_strided() expression converts the last axis of
|
| 1226 |
+
# pos_scores from relative to absolute position. I don't know whether I
|
| 1227 |
+
# might have got the time-offsets backwards or not, but let this code define
|
| 1228 |
+
# which way round it is supposed to be.
|
| 1229 |
+
if torch.jit.is_tracing():
|
| 1230 |
+
(num_heads, batch_size, time1, n) = pos_scores.shape
|
| 1231 |
+
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
| 1232 |
+
cols = torch.arange(seq_len)
|
| 1233 |
+
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
| 1234 |
+
indexes = rows + cols
|
| 1235 |
+
pos_scores = pos_scores.reshape(-1, n)
|
| 1236 |
+
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
| 1237 |
+
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
| 1238 |
+
else:
|
| 1239 |
+
pos_scores = pos_scores.as_strided(
|
| 1240 |
+
(num_heads, batch_size, seq_len, seq_len),
|
| 1241 |
+
(
|
| 1242 |
+
pos_scores.stride(0),
|
| 1243 |
+
pos_scores.stride(1),
|
| 1244 |
+
pos_scores.stride(2) - pos_scores.stride(3),
|
| 1245 |
+
pos_scores.stride(3),
|
| 1246 |
+
),
|
| 1247 |
+
storage_offset=pos_scores.stride(3) * (seq_len - 1),
|
| 1248 |
+
)
|
| 1249 |
+
|
| 1250 |
+
attn_scores = attn_scores + pos_scores
|
| 1251 |
+
|
| 1252 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1253 |
+
pass
|
| 1254 |
+
elif self.training and random.random() < 0.1:
|
| 1255 |
+
# This is a harder way of limiting the attention scores to not be
|
| 1256 |
+
# too large. It incurs a penalty if any of them has an absolute
|
| 1257 |
+
# value greater than 25.0. This should be outside the normal range
|
| 1258 |
+
# of the attention scores. We use this mechanism instead of, say,
|
| 1259 |
+
# something added to the loss function involving the entropy,
|
| 1260 |
+
# because once the entropy gets very small gradients through the
|
| 1261 |
+
# softmax can become very small, and we'd get zero derivatives. The
|
| 1262 |
+
# choices of 1.0e-04 as the scale on the penalty makes this
|
| 1263 |
+
# mechanism vulnerable to the absolute scale of the loss function,
|
| 1264 |
+
# but we view this as a failsafe to avoid "implausible" parameter
|
| 1265 |
+
# values rather than a regularization method that should be active
|
| 1266 |
+
# under normal circumstances.
|
| 1267 |
+
attn_scores = penalize_abs_values_gt(
|
| 1268 |
+
attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
|
| 1269 |
+
)
|
| 1270 |
+
|
| 1271 |
+
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
| 1272 |
+
|
| 1273 |
+
if attn_mask is not None:
|
| 1274 |
+
assert attn_mask.dtype == torch.bool
|
| 1275 |
+
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
| 1276 |
+
# all scores zero. It's important that this be large enough that exp(-1000)
|
| 1277 |
+
# is exactly zero, for reasons related to const_attention_rate, it
|
| 1278 |
+
# compares the final weights with zero.
|
| 1279 |
+
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
| 1280 |
+
|
| 1281 |
+
if key_padding_mask is not None:
|
| 1282 |
+
assert key_padding_mask.shape == (
|
| 1283 |
+
batch_size,
|
| 1284 |
+
seq_len,
|
| 1285 |
+
), key_padding_mask.shape
|
| 1286 |
+
attn_scores = attn_scores.masked_fill(
|
| 1287 |
+
key_padding_mask.unsqueeze(1),
|
| 1288 |
+
-1000,
|
| 1289 |
+
)
|
| 1290 |
+
|
| 1291 |
+
# We use our own version of softmax, defined in scaling.py, which should
|
| 1292 |
+
# save a little of the memory used in backprop: if we are in
|
| 1293 |
+
# automatic mixed precision mode (amp / autocast), by only storing the
|
| 1294 |
+
# half-precision output for backprop purposes.
|
| 1295 |
+
attn_weights = softmax(attn_scores, dim=-1)
|
| 1296 |
+
|
| 1297 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1298 |
+
pass
|
| 1299 |
+
elif random.random() < 0.001 and not self.training:
|
| 1300 |
+
self._print_attn_entropy(attn_weights)
|
| 1301 |
+
|
| 1302 |
+
attn_weights = nn.functional.dropout(
|
| 1303 |
+
attn_weights, p=self.dropout, training=self.training
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
return attn_weights
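# --- Illustrative sketch (not part of the original file): the relative-to-absolute
# index shift performed above (the gather/as_strided step), on a tiny example with
# one head and one sequence. Entry rel[t, j] holds the score for relative offset
# j - (T - 1); the gather picks rel[t, (T - 1) + (s - t)], giving scores indexed by
# absolute positions (t, s). Names here are hypothetical.
import torch

T = 4
rel = torch.arange(T * (2 * T - 1), dtype=torch.float32).reshape(T, 2 * T - 1)
rows = torch.arange(T - 1, -1, -1).unsqueeze(-1)           # rows[t] = T - 1 - t
cols = torch.arange(T)                                     # cols[s] = s
abs_scores = torch.gather(rel, dim=1, index=rows + cols)   # (T, T)
# abs_scores[t, s] == rel[t, (T - 1) + (s - t)]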
|
| 1307 |
+
|
| 1308 |
+
def _print_attn_entropy(self, attn_weights: Tensor):
|
| 1309 |
+
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
|
| 1310 |
+
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
|
| 1311 |
+
|
| 1312 |
+
with torch.no_grad():
|
| 1313 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 1314 |
+
attn_weights = attn_weights.to(torch.float32)
|
| 1315 |
+
attn_weights_entropy = (
|
| 1316 |
+
-((attn_weights + 1.0e-20).log() * attn_weights)
|
| 1317 |
+
.sum(dim=-1)
|
| 1318 |
+
.mean(dim=(1, 2))
|
| 1319 |
+
)
|
| 1320 |
+
logging.debug(
|
| 1321 |
+
f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
|
| 1322 |
+
)
|
| 1323 |
+
|
| 1324 |
+
|
| 1325 |
+
class SelfAttention(nn.Module):
|
| 1326 |
+
"""
|
| 1327 |
+
The simplest possible attention module. This one works with already-computed
|
| 1328 |
+
attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
|
| 1329 |
+
|
| 1330 |
+
Args:
|
| 1331 |
+
embed_dim: the input and output embedding dimension
|
| 1332 |
+
num_heads: the number of attention heads
|
| 1333 |
+
value_head_dim: the value dimension per head
|
| 1334 |
+
"""
|
| 1335 |
+
|
| 1336 |
+
def __init__(
|
| 1337 |
+
self,
|
| 1338 |
+
embed_dim: int,
|
| 1339 |
+
num_heads: int,
|
| 1340 |
+
value_head_dim: int,
|
| 1341 |
+
) -> None:
|
| 1342 |
+
super().__init__()
|
| 1343 |
+
self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
|
| 1344 |
+
|
| 1345 |
+
self.out_proj = ScaledLinear(
|
| 1346 |
+
num_heads * value_head_dim,
|
| 1347 |
+
embed_dim,
|
| 1348 |
+
bias=True,
|
| 1349 |
+
initial_scale=0.05,
|
| 1350 |
+
)
|
| 1351 |
+
|
| 1352 |
+
self.whiten = Whiten(
|
| 1353 |
+
num_groups=1,
|
| 1354 |
+
whitening_limit=_whitening_schedule(7.5, ratio=3.0),
|
| 1355 |
+
prob=(0.025, 0.25),
|
| 1356 |
+
grad_scale=0.01,
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
def forward(
|
| 1360 |
+
self,
|
| 1361 |
+
x: Tensor,
|
| 1362 |
+
attn_weights: Tensor,
|
| 1363 |
+
) -> Tensor:
|
| 1364 |
+
"""
|
| 1365 |
+
Args:
|
| 1366 |
+
x: input tensor, of shape (seq_len, batch_size, embed_dim)
|
| 1367 |
+
attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
|
| 1368 |
+
with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
|
| 1369 |
+
attn_weights.sum(dim=-1) == 1.
|
| 1370 |
+
Returns:
|
| 1371 |
+
a tensor with the same shape as x.
|
| 1372 |
+
"""
|
| 1373 |
+
(seq_len, batch_size, embed_dim) = x.shape
|
| 1374 |
+
num_heads = attn_weights.shape[0]
|
| 1375 |
+
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
|
| 1376 |
+
|
| 1377 |
+
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
|
| 1378 |
+
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
|
| 1379 |
+
# now x: (num_heads, batch_size, seq_len, value_head_dim)
|
| 1380 |
+
value_head_dim = x.shape[-1]
|
| 1381 |
+
|
| 1382 |
+
# todo: see whether there is benefit in overriding matmul
|
| 1383 |
+
x = torch.matmul(attn_weights, x)
|
| 1384 |
+
# v: (num_heads, batch_size, seq_len, value_head_dim)
|
| 1385 |
+
|
| 1386 |
+
x = (
|
| 1387 |
+
x.permute(2, 1, 0, 3)
|
| 1388 |
+
.contiguous()
|
| 1389 |
+
.view(seq_len, batch_size, num_heads * value_head_dim)
|
| 1390 |
+
)
|
| 1391 |
+
|
| 1392 |
+
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
|
| 1393 |
+
x = self.out_proj(x)
|
| 1394 |
+
x = self.whiten(x)
|
| 1395 |
+
|
| 1396 |
+
return x
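# --- Illustrative sketch (not part of the original file): applying already-computed
# attention weights to per-head values, which is essentially what SelfAttention does
# between its input and output projections. Shapes follow the docstring above.
import torch

num_heads, batch, seq_len, head_dim = 2, 3, 5, 4
attn_weights = torch.softmax(torch.randn(num_heads, batch, seq_len, seq_len), dim=-1)
values = torch.randn(num_heads, batch, seq_len, head_dim)
out = torch.matmul(attn_weights, values)              # (num_heads, batch, seq_len, head_dim)
out = out.permute(2, 1, 0, 3).reshape(seq_len, batch, num_heads * head_dim)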
|
| 1397 |
+
|
| 1398 |
+
|
| 1399 |
+
class FeedforwardModule(nn.Module):
|
| 1400 |
+
"""Feedforward module in TTSZipformer model."""
|
| 1401 |
+
|
| 1402 |
+
def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
|
| 1403 |
+
super(FeedforwardModule, self).__init__()
|
| 1404 |
+
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
|
| 1405 |
+
|
| 1406 |
+
self.hidden_balancer = Balancer(
|
| 1407 |
+
feedforward_dim,
|
| 1408 |
+
channel_dim=-1,
|
| 1409 |
+
min_positive=0.3,
|
| 1410 |
+
max_positive=1.0,
|
| 1411 |
+
min_abs=0.75,
|
| 1412 |
+
max_abs=5.0,
|
| 1413 |
+
)
|
| 1414 |
+
|
| 1415 |
+
# shared_dim=0 means we share the dropout mask along the time axis
|
| 1416 |
+
self.out_proj = ActivationDropoutAndLinear(
|
| 1417 |
+
feedforward_dim,
|
| 1418 |
+
embed_dim,
|
| 1419 |
+
activation="SwooshL",
|
| 1420 |
+
dropout_p=dropout,
|
| 1421 |
+
dropout_shared_dim=0,
|
| 1422 |
+
bias=True,
|
| 1423 |
+
initial_scale=0.1,
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
self.out_whiten = Whiten(
|
| 1427 |
+
num_groups=1,
|
| 1428 |
+
whitening_limit=_whitening_schedule(7.5),
|
| 1429 |
+
prob=(0.025, 0.25),
|
| 1430 |
+
grad_scale=0.01,
|
| 1431 |
+
)
|
| 1432 |
+
|
| 1433 |
+
def forward(self, x: Tensor):
|
| 1434 |
+
x = self.in_proj(x)
|
| 1435 |
+
x = self.hidden_balancer(x)
|
| 1436 |
+
# out_proj contains SwooshL activation, then dropout, then linear.
|
| 1437 |
+
x = self.out_proj(x)
|
| 1438 |
+
x = self.out_whiten(x)
|
| 1439 |
+
return x
|
| 1440 |
+
|
| 1441 |
+
|
| 1442 |
+
class NonlinAttention(nn.Module):
|
| 1443 |
+
"""This is like the ConvolutionModule, but refactored so that we use multiplication
|
| 1444 |
+
by attention weights (borrowed from the attention module) in place of actual
|
| 1445 |
+
convolution. We also took out the second nonlinearity, the one after the
|
| 1446 |
+
attention mechanism.
|
| 1447 |
+
|
| 1448 |
+
Args:
|
| 1449 |
+
channels (int): The number of channels of conv layers.
|
| 1450 |
+
"""
|
| 1451 |
+
|
| 1452 |
+
def __init__(
|
| 1453 |
+
self,
|
| 1454 |
+
channels: int,
|
| 1455 |
+
hidden_channels: int,
|
| 1456 |
+
) -> None:
|
| 1457 |
+
super().__init__()
|
| 1458 |
+
|
| 1459 |
+
self.hidden_channels = hidden_channels
|
| 1460 |
+
|
| 1461 |
+
self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
|
| 1462 |
+
|
| 1463 |
+
# balancer that goes before the sigmoid. Have quite a large min_abs value, at
|
| 1464 |
+
# 2.0, because we noticed that well-trained instances of this module have
|
| 1465 |
+
# abs-value before the sigmoid starting from about 3, and poorly-trained
|
| 1466 |
+
# instances of the module have smaller abs values before the sigmoid.
|
| 1467 |
+
self.balancer = Balancer(
|
| 1468 |
+
hidden_channels,
|
| 1469 |
+
channel_dim=-1,
|
| 1470 |
+
min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
|
| 1471 |
+
max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
|
| 1472 |
+
min_abs=0.5,
|
| 1473 |
+
max_abs=5.0,
|
| 1474 |
+
)
|
| 1475 |
+
self.tanh = nn.Tanh()
|
| 1476 |
+
|
| 1477 |
+
self.identity1 = Identity() # for diagnostics.
|
| 1478 |
+
self.identity2 = Identity() # for diagnostics.
|
| 1479 |
+
self.identity3 = Identity() # for diagnostics.
|
| 1480 |
+
|
| 1481 |
+
self.out_proj = ScaledLinear(
|
| 1482 |
+
hidden_channels, channels, bias=True, initial_scale=0.05
|
| 1483 |
+
)
|
| 1484 |
+
|
| 1485 |
+
self.whiten1 = Whiten(
|
| 1486 |
+
num_groups=1,
|
| 1487 |
+
whitening_limit=_whitening_schedule(5.0),
|
| 1488 |
+
prob=(0.025, 0.25),
|
| 1489 |
+
grad_scale=0.01,
|
| 1490 |
+
)
|
| 1491 |
+
|
| 1492 |
+
self.whiten2 = Whiten(
|
| 1493 |
+
num_groups=1,
|
| 1494 |
+
whitening_limit=_whitening_schedule(5.0, ratio=3.0),
|
| 1495 |
+
prob=(0.025, 0.25),
|
| 1496 |
+
grad_scale=0.01,
|
| 1497 |
+
)
|
| 1498 |
+
|
| 1499 |
+
def forward(
|
| 1500 |
+
self,
|
| 1501 |
+
x: Tensor,
|
| 1502 |
+
attn_weights: Tensor,
|
| 1503 |
+
) -> Tensor:
|
| 1504 |
+
""".
|
| 1505 |
+
Args:
|
| 1506 |
+
x: a Tensor of shape (seq_len, batch_size, num_channels)
|
| 1507 |
+
attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
| 1508 |
+
Returns:
|
| 1509 |
+
a Tensor with the same shape as x
|
| 1510 |
+
"""
|
| 1511 |
+
x = self.in_proj(x)
|
| 1512 |
+
|
| 1513 |
+
(seq_len, batch_size, _) = x.shape
|
| 1514 |
+
hidden_channels = self.hidden_channels
|
| 1515 |
+
|
| 1516 |
+
s, x, y = x.chunk(3, dim=2)
|
| 1517 |
+
|
| 1518 |
+
# s will go through tanh.
|
| 1519 |
+
|
| 1520 |
+
s = self.balancer(s)
|
| 1521 |
+
s = self.tanh(s)
|
| 1522 |
+
|
| 1523 |
+
s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
|
| 1524 |
+
x = self.whiten1(x)
|
| 1525 |
+
x = x * s
|
| 1526 |
+
x = self.identity1(x) # diagnostics only, it's the identity.
|
| 1527 |
+
|
| 1528 |
+
(seq_len, batch_size, embed_dim) = x.shape
|
| 1529 |
+
num_heads = attn_weights.shape[0]
|
| 1530 |
+
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
|
| 1531 |
+
|
| 1532 |
+
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
|
| 1533 |
+
# now x: (num_heads, batch_size, seq_len, head_dim)
|
| 1534 |
+
x = torch.matmul(attn_weights, x)
|
| 1535 |
+
# now x: (num_heads, batch_size, seq_len, head_dim)
|
| 1536 |
+
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
|
| 1537 |
+
|
| 1538 |
+
y = self.identity2(y)
|
| 1539 |
+
x = x * y
|
| 1540 |
+
x = self.identity3(x)
|
| 1541 |
+
|
| 1542 |
+
x = self.out_proj(x)
|
| 1543 |
+
x = self.whiten2(x)
|
| 1544 |
+
return x
|
| 1545 |
+
|
| 1546 |
+
|
| 1547 |
+
class ConvolutionModule(nn.Module):
|
| 1548 |
+
"""ConvolutionModule in Zipformer2 model.
|
| 1549 |
+
|
| 1550 |
+
Args:
|
| 1551 |
+
channels (int): The number of channels of conv layers.
|
| 1552 |
+
kernel_size (int): Kernel size of conv layers.
|
| 1553 |
+
bias (bool): Whether to use bias in conv layers (default=True).
|
| 1554 |
+
|
| 1555 |
+
"""
|
| 1556 |
+
|
| 1557 |
+
def __init__(
|
| 1558 |
+
self,
|
| 1559 |
+
channels: int,
|
| 1560 |
+
kernel_size: int,
|
| 1561 |
+
) -> None:
|
| 1562 |
+
"""Construct a ConvolutionModule object."""
|
| 1563 |
+
super(ConvolutionModule, self).__init__()
|
| 1564 |
+
# kernel_size should be an odd number for 'SAME' padding
|
| 1565 |
+
assert (kernel_size - 1) % 2 == 0
|
| 1566 |
+
|
| 1567 |
+
bottleneck_dim = channels
|
| 1568 |
+
|
| 1569 |
+
self.in_proj = nn.Linear(
|
| 1570 |
+
channels,
|
| 1571 |
+
2 * bottleneck_dim,
|
| 1572 |
+
)
|
| 1573 |
+
# the gradients on in_proj are a little noisy, likely to do with the
|
| 1574 |
+
# sigmoid in glu.
|
| 1575 |
+
|
| 1576 |
+
# after in_proj we put x through a gated linear unit (nn.functional.glu). For
|
| 1577 |
+
# most layers the normal rms value of channels of x seems to be in the range 1
|
| 1578 |
+
# to 4, but sometimes, for some reason, for layer 0 the rms ends up being very
|
| 1579 |
+
# large, between 50 and 100 for different channels. This will cause very peaky
|
| 1580 |
+
# and sparse derivatives for the sigmoid gating function, which will tend to
|
| 1581 |
+
# make the loss function not learn effectively. (for most layers the average
|
| 1582 |
+
# absolute values are in the range 0.5..9.0, and the average p(x>0), i.e.
|
| 1583 |
+
# positive proportion, at the output of pointwise_conv1.output is around 0.35 to
|
| 1584 |
+
# 0.45 for different layers, which likely breaks down as 0.5 for the "linear"
|
| 1585 |
+
# half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that
|
| 1586 |
+
# if we constrain the rms values to a reasonable range via a constraint of
|
| 1587 |
+
# max_abs=10.0, it will be in a better position to start learning something,
|
| 1588 |
+
# i.e. to latch onto the correct range.
|
| 1589 |
+
self.balancer1 = Balancer(
|
| 1590 |
+
bottleneck_dim,
|
| 1591 |
+
channel_dim=-1,
|
| 1592 |
+
min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
|
| 1593 |
+
max_positive=1.0,
|
| 1594 |
+
min_abs=1.5,
|
| 1595 |
+
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
|
| 1596 |
+
)
|
| 1597 |
+
|
| 1598 |
+
self.activation1 = Identity() # for diagnostics
|
| 1599 |
+
|
| 1600 |
+
self.sigmoid = nn.Sigmoid()
|
| 1601 |
+
|
| 1602 |
+
self.activation2 = Identity() # for diagnostics
|
| 1603 |
+
|
| 1604 |
+
assert kernel_size % 2 == 1
|
| 1605 |
+
|
| 1606 |
+
self.depthwise_conv = nn.Conv1d(
|
| 1607 |
+
in_channels=bottleneck_dim,
|
| 1608 |
+
out_channels=bottleneck_dim,
|
| 1609 |
+
groups=bottleneck_dim,
|
| 1610 |
+
kernel_size=kernel_size,
|
| 1611 |
+
padding=kernel_size // 2,
|
| 1612 |
+
)
|
| 1613 |
+
|
| 1614 |
+
self.balancer2 = Balancer(
|
| 1615 |
+
bottleneck_dim,
|
| 1616 |
+
channel_dim=1,
|
| 1617 |
+
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
| 1618 |
+
max_positive=1.0,
|
| 1619 |
+
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
|
| 1620 |
+
max_abs=10.0,
|
| 1621 |
+
)
|
| 1622 |
+
|
| 1623 |
+
self.whiten = Whiten(
|
| 1624 |
+
num_groups=1,
|
| 1625 |
+
whitening_limit=_whitening_schedule(7.5),
|
| 1626 |
+
prob=(0.025, 0.25),
|
| 1627 |
+
grad_scale=0.01,
|
| 1628 |
+
)
|
| 1629 |
+
|
| 1630 |
+
self.out_proj = ActivationDropoutAndLinear(
|
| 1631 |
+
bottleneck_dim,
|
| 1632 |
+
channels,
|
| 1633 |
+
activation="SwooshR",
|
| 1634 |
+
dropout_p=0.0,
|
| 1635 |
+
initial_scale=0.05,
|
| 1636 |
+
)
|
| 1637 |
+
|
| 1638 |
+
def forward(
|
| 1639 |
+
self,
|
| 1640 |
+
x: Tensor,
|
| 1641 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 1642 |
+
) -> Tensor:
|
| 1643 |
+
"""Compute convolution module.
|
| 1644 |
+
|
| 1645 |
+
Args:
|
| 1646 |
+
x: Input tensor (#time, batch, channels).
|
| 1647 |
+
src_key_padding_mask: the mask for the src keys per batch (optional):
|
| 1648 |
+
(batch, #time), contains True in masked positions.
|
| 1649 |
+
|
| 1650 |
+
Returns:
|
| 1651 |
+
Tensor: Output tensor (#time, batch, channels).
|
| 1652 |
+
|
| 1653 |
+
"""
|
| 1654 |
+
|
| 1655 |
+
x = self.in_proj(x) # (time, batch, 2*channels)
|
| 1656 |
+
|
| 1657 |
+
x, s = x.chunk(2, dim=2)
|
| 1658 |
+
s = self.balancer1(s)
|
| 1659 |
+
s = self.sigmoid(s)
|
| 1660 |
+
x = self.activation1(x) # identity.
|
| 1661 |
+
x = x * s
|
| 1662 |
+
x = self.activation2(x) # identity
|
| 1663 |
+
|
| 1664 |
+
# (time, batch, channels)
|
| 1665 |
+
|
| 1666 |
+
# exchange the temporal dimension and the feature dimension
|
| 1667 |
+
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
| 1668 |
+
|
| 1669 |
+
if src_key_padding_mask is not None:
|
| 1670 |
+
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
| 1671 |
+
|
| 1672 |
+
x = self.depthwise_conv(x)
|
| 1673 |
+
|
| 1674 |
+
x = self.balancer2(x)
|
| 1675 |
+
x = x.permute(2, 0, 1) # (time, batch, channels)
|
| 1676 |
+
|
| 1677 |
+
x = self.whiten(x) # (time, batch, channels)
|
| 1678 |
+
x = self.out_proj(x) # (time, batch, channels)
|
| 1679 |
+
|
| 1680 |
+
return x
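# --- Illustrative sketch (not part of the original file): the depthwise "SAME"
# convolution at the heart of ConvolutionModule; with an odd kernel and
# padding = kernel_size // 2, the time dimension is preserved.
import torch
import torch.nn as nn

channels, kernel_size = 8, 31
conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, padding=kernel_size // 2)
x = torch.randn(2, channels, 100)                     # (batch, channels, time)
assert conv(x).shape == x.shape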
|
zipvoice/models/modules/zipformer_two_stream.py
ADDED
|
@@ -0,0 +1,264 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor, nn
|
| 23 |
+
|
| 24 |
+
from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
|
| 25 |
+
from zipvoice.models.modules.zipformer import (
|
| 26 |
+
DownsampledZipformer2Encoder,
|
| 27 |
+
TTSZipformer,
|
| 28 |
+
Zipformer2Encoder,
|
| 29 |
+
Zipformer2EncoderLayer,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
| 34 |
+
"""Create sinusoidal timestep embeddings.
|
| 35 |
+
|
| 36 |
+
:param timesteps: shape of (N) or (N, T)
|
| 37 |
+
:param dim: the dimension of the output.
|
| 38 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 39 |
+
:return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim)
|
| 40 |
+
"""
|
| 41 |
+
half = dim // 2
|
| 42 |
+
freqs = torch.exp(
|
| 43 |
+
-math.log(max_period)
|
| 44 |
+
* torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
|
| 45 |
+
/ half
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if timesteps.dim() == 2:
|
| 49 |
+
timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
|
| 50 |
+
|
| 51 |
+
args = timesteps[..., None].float() * freqs[None]
|
| 52 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 53 |
+
if dim % 2:
|
| 54 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
|
| 55 |
+
return embedding
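# --- Illustrative sketch (not part of the original file): output shapes of the
# timestep_embedding function defined above for its two supported input layouts.
import torch

emb = timestep_embedding(torch.rand(4), dim=192)            # (N,)   -> (4, 192)
emb_frames = timestep_embedding(torch.rand(4, 25), dim=192) # (N, T) -> (25, 4, 192)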
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class TTSZipformerTwoStream(TTSZipformer):
|
| 59 |
+
"""
|
| 60 |
+
Args:
|
| 61 |
+
|
| 62 |
+
Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
|
| 63 |
+
length as downsampling_factor if they are single ints or one-element tuples.
|
| 64 |
+
The length of downsampling_factor defines the number of stacks.
|
| 65 |
+
|
| 66 |
+
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
|
| 67 |
+
Note: this is in addition to the downsampling factor of 2 that is applied in
|
| 68 |
+
the frontend (self.encoder_embed).
|
| 69 |
+
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
|
| 70 |
+
one per encoder stack.
|
| 71 |
+
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
| 72 |
+
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
| 73 |
+
head; per stack, if a tuple.
|
| 74 |
+
pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
|
| 75 |
+
per attention head
|
| 76 |
+
value_head_dim (int or Tuple[int]): dimension of value in each attention head
|
| 77 |
+
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
|
| 78 |
+
Must be at least 4.
|
| 79 |
+
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
|
| 80 |
+
cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
|
| 81 |
+
|
| 82 |
+
pos_dim (int): the dimension of each positional-encoding vector prior to
|
| 83 |
+
projection, e.g. 128.
|
| 84 |
+
|
| 85 |
+
dropout (float): dropout rate
|
| 86 |
+
warmup_batches (float): number of batches to warm up over; this controls
|
| 87 |
+
dropout of encoder layers.
|
| 88 |
+
use_time_embed: (bool): if True, take the time embedding as an additional input.
|
| 89 |
+
time_embed_dim: (int): the dimension of the time embedding.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
in_dim: Tuple[int],
|
| 95 |
+
out_dim: Tuple[int],
|
| 96 |
+
downsampling_factor: Tuple[int] = (2, 4),
|
| 97 |
+
num_encoder_layers: Union[int, Tuple[int]] = 4,
|
| 98 |
+
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
| 99 |
+
encoder_dim: int = 384,
|
| 100 |
+
query_head_dim: int = 24,
|
| 101 |
+
pos_head_dim: int = 4,
|
| 102 |
+
value_head_dim: int = 12,
|
| 103 |
+
num_heads: int = 8,
|
| 104 |
+
feedforward_dim: int = 1536,
|
| 105 |
+
pos_dim: int = 192,
|
| 106 |
+
dropout: FloatLike = None, # see code below for default
|
| 107 |
+
warmup_batches: float = 4000.0,
|
| 108 |
+
use_time_embed: bool = True,
|
| 109 |
+
time_embed_dim: int = 192,
|
| 110 |
+
use_conv: bool = True,
|
| 111 |
+
) -> None:
|
| 112 |
+
nn.Module.__init__(self)
|
| 113 |
+
|
| 114 |
+
if dropout is None:
|
| 115 |
+
dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
|
| 116 |
+
if isinstance(downsampling_factor, int):
|
| 117 |
+
downsampling_factor = (downsampling_factor,)
|
| 118 |
+
|
| 119 |
+
def _to_tuple(x):
|
| 120 |
+
"""Converts a single int or a 1-tuple of an int to a tuple with the same
|
| 121 |
+
length as downsampling_factor"""
|
| 122 |
+
if isinstance(x, int):
|
| 123 |
+
x = (x,)
|
| 124 |
+
if len(x) == 1:
|
| 125 |
+
x = x * len(downsampling_factor)
|
| 126 |
+
else:
|
| 127 |
+
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
|
| 128 |
+
return x
|
| 129 |
+
|
| 130 |
+
def _assert_downsampling_factor(factors):
|
| 131 |
+
"""assert downsampling_factor follows u-net style"""
|
| 132 |
+
assert factors[0] == 1 and factors[-1] == 1
|
| 133 |
+
|
| 134 |
+
for i in range(1, len(factors) // 2 + 1):
|
| 135 |
+
assert factors[i] == factors[i - 1] * 2
|
| 136 |
+
|
| 137 |
+
for i in range(len(factors) // 2 + 1, len(factors)):
|
| 138 |
+
assert factors[i] * 2 == factors[i - 1]
|
| 139 |
+
|
| 140 |
+
_assert_downsampling_factor(downsampling_factor)
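# --- Illustrative note (not part of the original file): the assertions above accept
# U-Net style factor lists that start and end at 1, double on the way in and halve on
# the way out, e.g. (1, 2, 1) or (1, 2, 4, 2, 1); lists such as (1, 3, 1) or (2, 4, 2)
# would fail them.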
|
| 141 |
+
self.downsampling_factor = downsampling_factor # tuple
|
| 142 |
+
num_encoder_layers = _to_tuple(num_encoder_layers)
|
| 143 |
+
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
| 144 |
+
self.encoder_dim = encoder_dim
|
| 145 |
+
self.num_encoder_layers = num_encoder_layers
|
| 146 |
+
self.query_head_dim = query_head_dim
|
| 147 |
+
self.value_head_dim = value_head_dim
|
| 148 |
+
self.num_heads = num_heads
|
| 149 |
+
|
| 150 |
+
self.use_time_embed = use_time_embed
|
| 151 |
+
|
| 152 |
+
self.time_embed_dim = time_embed_dim
|
| 153 |
+
if self.use_time_embed:
|
| 154 |
+
assert time_embed_dim != -1
|
| 155 |
+
else:
|
| 156 |
+
time_embed_dim = -1
|
| 157 |
+
|
| 158 |
+
assert len(in_dim) == len(out_dim) == 2
|
| 159 |
+
|
| 160 |
+
self.in_dim = in_dim
|
| 161 |
+
self.in_proj = nn.ModuleList(
|
| 162 |
+
[nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
|
| 163 |
+
)
|
| 164 |
+
self.out_dim = out_dim
|
| 165 |
+
self.out_proj = nn.ModuleList(
|
| 166 |
+
[nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
|
| 170 |
+
encoders = []
|
| 171 |
+
|
| 172 |
+
num_encoders = len(downsampling_factor)
|
| 173 |
+
for i in range(num_encoders):
|
| 174 |
+
encoder_layer = Zipformer2EncoderLayer(
|
| 175 |
+
embed_dim=encoder_dim,
|
| 176 |
+
pos_dim=pos_dim,
|
| 177 |
+
num_heads=num_heads,
|
| 178 |
+
query_head_dim=query_head_dim,
|
| 179 |
+
pos_head_dim=pos_head_dim,
|
| 180 |
+
value_head_dim=value_head_dim,
|
| 181 |
+
feedforward_dim=feedforward_dim,
|
| 182 |
+
use_conv=use_conv,
|
| 183 |
+
cnn_module_kernel=cnn_module_kernel[i],
|
| 184 |
+
dropout=dropout,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# For the segment of the warmup period, we let the Conv2dSubsampling
|
| 188 |
+
# layer learn something. Then we start to warm up the other encoders.
|
| 189 |
+
encoder = Zipformer2Encoder(
|
| 190 |
+
encoder_layer,
|
| 191 |
+
num_encoder_layers[i],
|
| 192 |
+
embed_dim=encoder_dim,
|
| 193 |
+
time_embed_dim=time_embed_dim,
|
| 194 |
+
pos_dim=pos_dim,
|
| 195 |
+
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
| 196 |
+
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
| 197 |
+
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if downsampling_factor[i] != 1:
|
| 201 |
+
encoder = DownsampledZipformer2Encoder(
|
| 202 |
+
encoder,
|
| 203 |
+
dim=encoder_dim,
|
| 204 |
+
downsample=downsampling_factor[i],
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
encoders.append(encoder)
|
| 208 |
+
|
| 209 |
+
self.encoders = nn.ModuleList(encoders)
|
| 210 |
+
if self.use_time_embed:
|
| 211 |
+
self.time_embed = nn.Sequential(
|
| 212 |
+
nn.Linear(time_embed_dim, time_embed_dim * 2),
|
| 213 |
+
SwooshR(),
|
| 214 |
+
nn.Linear(time_embed_dim * 2, time_embed_dim),
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
self.time_embed = None
|
| 218 |
+
|
| 219 |
+
def forward(
|
| 220 |
+
self,
|
| 221 |
+
x: Tensor,
|
| 222 |
+
t: Optional[Tensor] = None,
|
| 223 |
+
padding_mask: Optional[Tensor] = None,
|
| 224 |
+
) -> Tuple[Tensor, Tensor]:
|
| 225 |
+
"""
|
| 226 |
+
Args:
|
| 227 |
+
x:
|
| 228 |
+
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
| 229 |
+
t:
|
| 230 |
+
The timestep tensor t, of shape (batch_size,) or (batch_size, seq_len)
|
| 231 |
+
padding_mask:
|
| 232 |
+
The mask for padding, of shape (batch_size, seq_len); True means
|
| 233 |
+
masked position. May be None.
|
| 234 |
+
Returns:
|
| 235 |
+
Return the output embeddings. Its shape is
|
| 236 |
+
(batch_size, output_seq_len, encoder_dim)
|
| 237 |
+
"""
|
| 238 |
+
assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
|
| 239 |
+
if x.size(2) == self.in_dim[0]:
|
| 240 |
+
index = 0
|
| 241 |
+
else:
|
| 242 |
+
index = 1
|
| 243 |
+
x = x.permute(1, 0, 2)
|
| 244 |
+
x = self.in_proj[index](x)
|
| 245 |
+
|
| 246 |
+
if t is not None:
|
| 247 |
+
assert t.dim() == 1 or t.dim() == 2, t.shape
|
| 248 |
+
time_emb = timestep_embedding(t, self.time_embed_dim)
|
| 249 |
+
time_emb = self.time_embed(time_emb)
|
| 250 |
+
else:
|
| 251 |
+
time_emb = None
|
| 252 |
+
|
| 253 |
+
attn_mask = None
|
| 254 |
+
|
| 255 |
+
for i, module in enumerate(self.encoders):
|
| 256 |
+
x = module(
|
| 257 |
+
x,
|
| 258 |
+
time_emb=time_emb,
|
| 259 |
+
src_key_padding_mask=padding_mask,
|
| 260 |
+
attn_mask=attn_mask,
|
| 261 |
+
)
|
| 262 |
+
x = self.out_proj[index](x)
|
| 263 |
+
x = x.permute(1, 0, 2)
|
| 264 |
+
return x
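# --- Illustrative sketch (not part of the original file): the stream-selection rule
# used in forward() above; the size of the last feature dimension decides which
# in_proj/out_proj pair handles the input. The numbers are hypothetical.
in_dim = (100, 80)
feat_dim = 80
index = 0 if feat_dim == in_dim[0] else 1   # -> 1: route through the second stream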
|
zipvoice/models/zipvoice.py
ADDED
|
@@ -0,0 +1,534 @@
|
| 1 |
+
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
| 2 |
+
# Han Zhu)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from typing import List, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
|
| 24 |
+
from zipvoice.models.modules.solver import EulerSolver
|
| 25 |
+
from zipvoice.models.modules.zipformer import TTSZipformer
|
| 26 |
+
from zipvoice.utils.common import (
|
| 27 |
+
condition_time_mask,
|
| 28 |
+
get_tokens_index,
|
| 29 |
+
make_pad_mask,
|
| 30 |
+
pad_labels,
    prepare_avg_tokens_durations,
)


class ZipVoice(nn.Module):
    """The ZipVoice model."""

    def __init__(
        self,
        fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
        fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
        fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
        fm_decoder_feedforward_dim: int = 1536,
        fm_decoder_num_heads: int = 4,
        fm_decoder_dim: int = 512,
        text_encoder_num_layers: int = 4,
        text_encoder_feedforward_dim: int = 512,
        text_encoder_cnn_module_kernel: int = 9,
        text_encoder_num_heads: int = 4,
        text_encoder_dim: int = 192,
        time_embed_dim: int = 192,
        text_embed_dim: int = 192,
        query_head_dim: int = 32,
        value_head_dim: int = 12,
        pos_head_dim: int = 4,
        pos_dim: int = 48,
        feat_dim: int = 100,
        vocab_size: int = 26,
        pad_id: int = 0,
    ):
        """
        Initialize the model with specified configuration parameters.

        Args:
            fm_decoder_downsampling_factor: List of downsampling factors for each layer
                in the flow-matching decoder.
            fm_decoder_num_layers: List of the number of layers for each block in the
                flow-matching decoder.
            fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
                flow-matching decoder.
            fm_decoder_feedforward_dim: Dimension of the feedforward network in the
                flow-matching decoder.
            fm_decoder_num_heads: Number of attention heads in the flow-matching
                decoder.
            fm_decoder_dim: Hidden dimension of the flow-matching decoder.
            text_encoder_num_layers: Number of layers in the text encoder.
            text_encoder_feedforward_dim: Dimension of the feedforward network in the
                text encoder.
            text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
                text encoder.
            text_encoder_num_heads: Number of attention heads in the text encoder.
            text_encoder_dim: Hidden dimension of the text encoder.
            time_embed_dim: Dimension of the time embedding.
            text_embed_dim: Dimension of the text embedding.
            query_head_dim: Dimension of the query attention head.
            value_head_dim: Dimension of the value attention head.
            pos_head_dim: Dimension of the position attention head.
            pos_dim: Dimension of the positional encoding.
            feat_dim: Dimension of the acoustic features.
            vocab_size: Size of the vocabulary.
            pad_id: ID used for padding tokens.
        """
        super().__init__()

        self.fm_decoder = TTSZipformer(
            in_dim=feat_dim * 3,
            out_dim=feat_dim,
            downsampling_factor=fm_decoder_downsampling_factor,
            num_encoder_layers=fm_decoder_num_layers,
            cnn_module_kernel=fm_decoder_cnn_module_kernel,
            encoder_dim=fm_decoder_dim,
            feedforward_dim=fm_decoder_feedforward_dim,
            num_heads=fm_decoder_num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            value_head_dim=value_head_dim,
            pos_dim=pos_dim,
            use_time_embed=True,
            time_embed_dim=time_embed_dim,
        )

        self.text_encoder = TTSZipformer(
            in_dim=text_embed_dim,
            out_dim=feat_dim,
            downsampling_factor=1,
            num_encoder_layers=text_encoder_num_layers,
            cnn_module_kernel=text_encoder_cnn_module_kernel,
            encoder_dim=text_encoder_dim,
            feedforward_dim=text_encoder_feedforward_dim,
            num_heads=text_encoder_num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            value_head_dim=value_head_dim,
            pos_dim=pos_dim,
            use_time_embed=False,
        )

        self.feat_dim = feat_dim
        self.text_embed_dim = text_embed_dim
        self.pad_id = pad_id

        self.embed = nn.Embedding(vocab_size, text_embed_dim)
        self.solver = EulerSolver(self, func_name="forward_fm_decoder")

    def forward_fm_decoder(
        self,
        t: torch.Tensor,
        xt: torch.Tensor,
        text_condition: torch.Tensor,
        speech_condition: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        guidance_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute velocity.
        Args:
            t: A tensor of shape (N, 1, 1) or a tensor of a float,
                in the range of (0, 1).
            xt: the input of the current timestep, including condition
                embeddings and noisy acoustic features.
            text_condition: the text condition embeddings, with the
                shape (batch, seq_len, emb_dim).
            speech_condition: the speech condition embeddings, with the
                shape (batch, seq_len, emb_dim).
            padding_mask: The mask for padding, True means masked
                position, with the shape (N, T).
            guidance_scale: The guidance scale in classifier-free guidance,
                which is a tensor of shape (N, 1, 1) or a tensor of a float.

        Returns:
            predicted velocity, with the shape (batch, seq_len, emb_dim).
        """

        xt = torch.cat([xt, text_condition, speech_condition], dim=2)

        assert t.dim() in (0, 3)
        # Handle t with the shape (N, 1, 1):
        # squeeze the last dimension if its size is 1.
        while t.dim() > 1 and t.size(-1) == 1:
            t = t.squeeze(-1)
        # Handle t with a single value: expand to the size of batch size.
        if t.dim() == 0:
            t = t.repeat(xt.shape[0])

        if guidance_scale is not None:
            while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
                guidance_scale = guidance_scale.squeeze(-1)
            if guidance_scale.dim() == 0:
                guidance_scale = guidance_scale.repeat(xt.shape[0])

            vt = self.fm_decoder(
                x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
            )
        else:
            vt = self.fm_decoder(x=xt, t=t, padding_mask=padding_mask)
        return vt

    def forward_text_embed(
        self,
        tokens: List[List[int]],
    ):
        """
        Get the text embeddings.
        Args:
            tokens: a list of list of token ids.
        Returns:
            embed: the text embeddings, shape (batch, seq_len, emb_dim).
            tokens_lens: the length of each token sequence, shape (batch,).
        """
        device = (
            self.device if isinstance(self, DDP) else next(self.parameters()).device
        )
        tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device)  # (B, S)
        embed = self.embed(tokens_padded)  # (B, S, C)
        tokens_lens = torch.tensor(
            [len(token) for token in tokens], dtype=torch.int64, device=device
        )
        tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1])  # (B, S)

        embed = self.text_encoder(
            x=embed, t=None, padding_mask=tokens_padding_mask
        )  # (B, S, C)
        return embed, tokens_lens

    def forward_text_condition(
        self,
        embed: torch.Tensor,
        tokens_lens: torch.Tensor,
        features_lens: torch.Tensor,
    ):
        """
        Get the text condition with the same length of the acoustic feature.
        Args:
            embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
            tokens_lens: the length of each token sequence, shape (batch,).
            features_lens: the length of each acoustic feature sequence,
                shape (batch,).
        Returns:
            text_condition: the text condition, shape
                (batch, feature_seq_len, emb_dim).
            padding_mask: the padding mask of text condition, shape
                (batch, feature_seq_len).
        """

        num_frames = int(features_lens.max())

        padding_mask = make_pad_mask(features_lens, max_len=num_frames)  # (B, T)

        tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)

        tokens_index = get_tokens_index(tokens_durations, num_frames).to(
            embed.device
        )  # (B, T)

        text_condition = torch.gather(
            embed,
            dim=1,
            index=tokens_index.unsqueeze(-1).expand(
                embed.size(0), num_frames, embed.size(-1)
            ),
        )  # (B, T, F)
        return text_condition, padding_mask

    def forward_text_train(
        self,
        tokens: List[List[int]],
        features_lens: torch.Tensor,
    ):
        """
        Process text for training, given text tokens and real feature lengths.
        """
        embed, tokens_lens = self.forward_text_embed(tokens)
        text_condition, padding_mask = self.forward_text_condition(
            embed, tokens_lens, features_lens
        )
        return (
            text_condition,
            padding_mask,
        )

    def forward_text_inference_gt_duration(
        self,
        tokens: List[List[int]],
        features_lens: torch.Tensor,
        prompt_tokens: List[List[int]],
        prompt_features_lens: torch.Tensor,
    ):
        """
        Process text for inference, given text tokens, real feature lengths and prompts.
        """
        tokens = [
            prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
        ]
        features_lens = prompt_features_lens + features_lens
        embed, tokens_lens = self.forward_text_embed(tokens)
        text_condition, padding_mask = self.forward_text_condition(
            embed, tokens_lens, features_lens
        )
        return text_condition, padding_mask

    def forward_text_inference_ratio_duration(
        self,
        tokens: List[List[int]],
        prompt_tokens: List[List[int]],
        prompt_features_lens: torch.Tensor,
        speed: float,
    ):
        """
        Process text for inference, given text tokens and prompts;
        feature lengths are predicted from the ratio of token numbers.
        """
        device = (
            self.device if isinstance(self, DDP) else next(self.parameters()).device
        )

        cat_tokens = [
            prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
        ]

        prompt_tokens_lens = torch.tensor(
            [len(token) for token in prompt_tokens],
            dtype=torch.int64,
            device=device,
        )

        tokens_lens = torch.tensor(
            [len(token) for token in tokens],
            dtype=torch.int64,
            device=device,
        )

        cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)

        features_lens = prompt_features_lens + torch.ceil(
            (prompt_features_lens / prompt_tokens_lens * tokens_lens / speed)
        ).to(dtype=torch.int64)

        text_condition, padding_mask = self.forward_text_condition(
            cat_embed, cat_tokens_lens, features_lens
        )
        return text_condition, padding_mask

    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        condition_drop_ratio: float = 0.0,
    ) -> torch.Tensor:
        """Forward pass of the model for training.
        Args:
            tokens: a list of list of token ids.
            features: the acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
            t: the time step, with the shape (batch, 1, 1).
            condition_drop_ratio: the ratio of dropped text condition.
        Returns:
            fm_loss: the flow-matching loss.
        """

        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )

        speech_condition_mask = condition_time_mask(
            features_lens=features_lens,
            mask_percent=(0.7, 1.0),
            max_len=features.size(1),
        )
        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)

        if condition_drop_ratio > 0.0:
            drop_mask = (
                torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
                > condition_drop_ratio
            )
            text_condition = text_condition * drop_mask

        xt = features * t + noise * (1 - t)
        ut = features - noise  # (B, T, F)

        vt = self.forward_fm_decoder(
            t=t,
            xt=xt,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
        )

        loss_mask = speech_condition_mask & (~padding_mask)
        fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)

        return fm_loss

    def sample(
        self,
        tokens: List[List[int]],
        prompt_tokens: List[List[int]],
        prompt_features: torch.Tensor,
        prompt_features_lens: torch.Tensor,
        features_lens: Optional[torch.Tensor] = None,
        speed: float = 1.0,
        t_shift: float = 1.0,
        duration: str = "predict",
        num_step: int = 5,
        guidance_scale: float = 0.5,
    ) -> torch.Tensor:
        """
        Generate acoustic features, given text tokens, prompt features
        and the prompt transcription's text tokens.
        Args:
            tokens: a list of list of text tokens.
            prompt_tokens: a list of list of prompt tokens.
            prompt_features: the prompt feature with the shape
                (batch_size, seq_len, feat_dim).
            prompt_features_lens: the length of each prompt feature,
                with the shape (batch_size,).
            features_lens: the length of the predicted feature, with the
                shape (batch_size,). It is used only when duration is "real".
            duration: "real" or "predict". If "real", the predicted
                feature length is given by features_lens.
            num_step: the number of steps to use in the ODE solver.
            guidance_scale: the guidance scale for classifier-free guidance.
        """

        assert duration in ["real", "predict"]

        if duration == "predict":
            (
                text_condition,
                padding_mask,
            ) = self.forward_text_inference_ratio_duration(
                tokens=tokens,
                prompt_tokens=prompt_tokens,
                prompt_features_lens=prompt_features_lens,
                speed=speed,
            )
        else:
            assert features_lens is not None
            text_condition, padding_mask = self.forward_text_inference_gt_duration(
                tokens=tokens,
                features_lens=features_lens,
                prompt_tokens=prompt_tokens,
                prompt_features_lens=prompt_features_lens,
            )
        batch_size, num_frames, _ = text_condition.shape

        speech_condition = torch.nn.functional.pad(
            prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
        )  # (B, T, F)

        # False means speech condition positions.
        speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
        speech_condition = torch.where(
            speech_condition_mask.unsqueeze(-1),
            torch.zeros_like(speech_condition),
            speech_condition,
        )

        x0 = torch.randn(
            batch_size,
            num_frames,
            prompt_features.size(-1),
            device=text_condition.device,
        )

        x1 = self.solver.sample(
            x=x0,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            num_step=num_step,
            guidance_scale=guidance_scale,
            t_shift=t_shift,
        )
        x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
        x1_prompt = torch.zeros(
            x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
        )
        x1_wo_prompt = torch.zeros(
            x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
        )
        for i in range(x1.size(0)):
            x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
                i,
                prompt_features_lens[i] : prompt_features_lens[i]
                + x1_wo_prompt_lens[i],
            ]
            x1_prompt[i, : prompt_features_lens[i], :] = x1[
                i, : prompt_features_lens[i]
            ]

        return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens

    def sample_intermediate(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        speech_condition_mask: torch.Tensor,
        t_start: float,
        t_end: float,
        num_step: int = 1,
        guidance_scale: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Generate acoustic features in intermediate timesteps.
        Args:
            tokens: List of list of token ids.
            features: The acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: The length of each acoustic feature sequence,
                with the shape (batch,).
            noise: The initial noise, with the shape (batch, seq_len, feat_dim).
            speech_condition_mask: The mask for speech condition, True means
                non-condition positions, with the shape (batch, seq_len).
            t_start: The start timestep.
            t_end: The end timestep.
            num_step: The number of steps for sampling.
            guidance_scale: The scale for classifier-free guidance inference,
                with the shape (batch, 1, 1).
        """
        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )

        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)

        x_t_end = self.solver.sample(
            x=noise,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            num_step=num_step,
            guidance_scale=guidance_scale,
            t_start=t_start,
            t_end=t_end,
        )
        x_t_end_lens = (~padding_mask).sum(-1)
        return x_t_end, x_t_end_lens
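
Note (not part of the diff): a minimal smoke-test sketch of the training-side API defined above. It assumes the zipvoice package in this Space is importable; the batch size, token ids, and tensor shapes below are invented for illustration only, and a real run would feed fbank features from compute_fbank.py and ids from the tokenizer.

import torch
from zipvoice.models.zipvoice import ZipVoice

# Build the model with its default configuration (100-dim features, 26 tokens).
model = ZipVoice()
batch, num_frames, feat_dim = 2, 200, 100
tokens = [[1, 2, 3, 4], [5, 6, 7]]                    # dummy token ids per utterance
features = torch.randn(batch, num_frames, feat_dim)   # target acoustic features (x1)
features_lens = torch.tensor([200, 180])
noise = torch.randn_like(features)                    # initial noise (x0)
t = torch.rand(batch, 1, 1)                           # flow-matching timestep

# Returns the scalar flow-matching loss described in forward().
loss = model(tokens, features, features_lens, noise, t, condition_drop_ratio=0.2)
print(loss.item())
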
zipvoice/models/zipvoice_dialog.py
ADDED
@@ -0,0 +1,358 @@
# Copyright    2025    Xiaomi Corp.        (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels


class ZipVoiceDialog(ZipVoice):
    """The ZipVoice-Dialog model."""

    def __init__(
        self,
        fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
        fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
        fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
        fm_decoder_feedforward_dim: int = 1536,
        fm_decoder_num_heads: int = 4,
        fm_decoder_dim: int = 512,
        text_encoder_num_layers: int = 4,
        text_encoder_feedforward_dim: int = 512,
        text_encoder_cnn_module_kernel: int = 9,
        text_encoder_num_heads: int = 4,
        text_encoder_dim: int = 192,
        time_embed_dim: int = 192,
        text_embed_dim: int = 192,
        query_head_dim: int = 32,
        value_head_dim: int = 12,
        pos_head_dim: int = 4,
        pos_dim: int = 48,
        feat_dim: int = 100,
        vocab_size: int = 26,
        pad_id: int = 0,
        spk_a_id: int = 360,
        spk_b_id: int = 361,
    ):
        """
        Initialize the model with specified configuration parameters.

        Args:
            fm_decoder_downsampling_factor: List of downsampling factors for each layer
                in the flow-matching decoder.
            fm_decoder_num_layers: List of the number of layers for each block in the
                flow-matching decoder.
            fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
                flow-matching decoder.
            fm_decoder_feedforward_dim: Dimension of the feedforward network in the
                flow-matching decoder.
            fm_decoder_num_heads: Number of attention heads in the flow-matching
                decoder.
            fm_decoder_dim: Hidden dimension of the flow-matching decoder.
            text_encoder_num_layers: Number of layers in the text encoder.
            text_encoder_feedforward_dim: Dimension of the feedforward network in the
                text encoder.
            text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
                text encoder.
            text_encoder_num_heads: Number of attention heads in the text encoder.
            text_encoder_dim: Hidden dimension of the text encoder.
            time_embed_dim: Dimension of the time embedding.
            text_embed_dim: Dimension of the text embedding.
            query_head_dim: Dimension of the query attention head.
            value_head_dim: Dimension of the value attention head.
            pos_head_dim: Dimension of the position attention head.
            pos_dim: Dimension of the positional encoding.
            feat_dim: Dimension of the acoustic features.
            vocab_size: Size of the vocabulary.
            pad_id: ID used for padding tokens.
            spk_a_id: ID of speaker A / [S1].
            spk_b_id: ID of speaker B / [S2].
        """
        super().__init__(
            fm_decoder_downsampling_factor=fm_decoder_downsampling_factor,
            fm_decoder_num_layers=fm_decoder_num_layers,
            fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel,
            fm_decoder_feedforward_dim=fm_decoder_feedforward_dim,
            fm_decoder_num_heads=fm_decoder_num_heads,
            fm_decoder_dim=fm_decoder_dim,
            text_encoder_num_layers=text_encoder_num_layers,
            text_encoder_feedforward_dim=text_encoder_feedforward_dim,
            text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel,
            text_encoder_num_heads=text_encoder_num_heads,
            text_encoder_dim=text_encoder_dim,
            time_embed_dim=time_embed_dim,
            text_embed_dim=text_embed_dim,
            query_head_dim=query_head_dim,
            value_head_dim=value_head_dim,
            pos_head_dim=pos_head_dim,
            pos_dim=pos_dim,
            feat_dim=feat_dim,
            vocab_size=vocab_size,
            pad_id=pad_id,
        )

        self.spk_a_id = spk_a_id
        self.spk_b_id = spk_b_id
        self.spk_embed = nn.Embedding(2, feat_dim)
        torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1)

    def extract_spk_indices(self, tensor):
        turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long()
        turn_counts = turn_mask.cumsum(dim=1)
        spk_mask = turn_counts % 2
        spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask)
        spk_a_indices = torch.where(spk_mask == 0)
        spk_b_indices = torch.where(spk_mask == 1)
        return spk_a_indices, spk_b_indices

    def forward_text_embed(
        self,
        tokens: List[List[int]],
    ):
        """
        Get the text embeddings.
        Args:
            tokens: a list of list of token ids.
        Returns:
            embed: the text embeddings, shape (batch, seq_len, emb_dim).
            tokens_lens: the length of each token sequence, shape (batch,).
        """
        device = (
            self.device if isinstance(self, DDP) else next(self.parameters()).device
        )
        tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device)  # (B, S)
        embed = self.embed(tokens_padded)  # (B, S, C)
        spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded)
        tokens_lens = torch.tensor(
            [len(token) for token in tokens], dtype=torch.int64, device=device
        )
        tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1])  # (B, S)

        embed = self.text_encoder(
            x=embed, t=None, padding_mask=tokens_padding_mask
        )  # (B, S, C)
        embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to(
            embed.dtype
        )
        embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to(
            embed.dtype
        )
        return embed, tokens_lens

    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        condition_drop_ratio: float = 0.0,
    ) -> torch.Tensor:
        """Forward pass of the model for training.
        Args:
            tokens: a list of list of token ids.
            features: the acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
            t: the time step, with the shape (batch, 1, 1).
            condition_drop_ratio: the ratio of dropped text condition.
        Returns:
            fm_loss: the flow-matching loss.
        """

        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )

        speech_condition_mask = condition_time_mask_suffix(
            features_lens=features_lens,
            mask_percent=(0.5, 1.0),
            max_len=features.size(1),
        )
        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)

        if condition_drop_ratio > 0.0:
            drop_mask = (
                torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
                > condition_drop_ratio
            )
            text_condition = text_condition * drop_mask

        xt = features * t + noise * (1 - t)
        ut = features - noise  # (B, T, F)

        vt = self.forward_fm_decoder(
            t=t,
            xt=xt,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
        )

        loss_mask = speech_condition_mask & (~padding_mask)
        fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)

        return fm_loss


class ZipVoiceDialogStereo(ZipVoiceDialog):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        required_params = {
            "feat_dim",
            "fm_decoder_downsampling_factor",
            "fm_decoder_num_layers",
            "fm_decoder_cnn_module_kernel",
            "fm_decoder_dim",
            "fm_decoder_feedforward_dim",
            "fm_decoder_num_heads",
            "query_head_dim",
            "pos_head_dim",
            "value_head_dim",
            "pos_dim",
            "time_embed_dim",
        }

        missing = [p for p in required_params if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {', '.join(missing)}")

        self.fm_decoder = TTSZipformerTwoStream(
            in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3),
            out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]),
            downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
            num_encoder_layers=kwargs["fm_decoder_num_layers"],
            cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
            encoder_dim=kwargs["fm_decoder_dim"],
            feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
            num_heads=kwargs["fm_decoder_num_heads"],
            query_head_dim=kwargs["query_head_dim"],
            pos_head_dim=kwargs["pos_head_dim"],
            value_head_dim=kwargs["value_head_dim"],
            pos_dim=kwargs["pos_dim"],
            use_time_embed=True,
            time_embed_dim=kwargs["time_embed_dim"],
        )

    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        condition_drop_ratio: float = 0.0,
        se_weight: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass of the model for training.
        Args:
            tokens: a list of list of token ids.
            features: the acoustic features, with the shape (batch, seq_len, feat_dim).
            features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
            t: the time step, with the shape (batch, 1, 1).
            condition_drop_ratio: the ratio of dropped text condition.
            se_weight: the weight of the speaker exclusive loss.
        Returns:
            fm_loss: the flow-matching loss.
        """

        (text_condition, padding_mask,) = self.forward_text_train(
            tokens=tokens,
            features_lens=features_lens,
        )

        speech_condition_mask = condition_time_mask_suffix(
            features_lens=features_lens,
            mask_percent=(0.5, 1.0),
            max_len=features.size(1),
        )
        speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)

        if condition_drop_ratio > 0.0:
            drop_mask = (
                torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
                > condition_drop_ratio
            )
            text_condition = text_condition * drop_mask

        xt = features * t + noise * (1 - t)
        ut = features - noise  # (B, T, F)

        vt = self.forward_fm_decoder(
            t=t,
            xt=xt,
            text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
        )

        loss_mask = speech_condition_mask & (~padding_mask)
        fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)

        if se_weight > 0:
            target = xt + vt * (1 - t)
            fbank_1 = target[:, :, : self.feat_dim]
            fbank_2 = target[:, :, self.feat_dim :]
            energy_loss = torch.mean(
                self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask]
            )
            loss = fm_loss + energy_loss * se_weight
        else:
            loss = fm_loss

        return loss

    def energy_based_loss(self, fbank1, fbank2, gt_fbank):
        energy1 = self.energy(fbank1)
        energy2 = self.energy(fbank2)

        energy_thresholds = self.adaptive_threshold_from_gt(
            torch.cat(
                [
                    gt_fbank[:, :, : self.feat_dim],
                    gt_fbank[:, :, self.feat_dim :],
                ],
                dim=1,
            )
        )

        both_speaking = (
            (energy1 > energy_thresholds) & (energy2 > energy_thresholds)
        ).float()

        penalty = (
            both_speaking
            * (energy1 - energy_thresholds)
            * (energy2 - energy_thresholds)
        )
        return penalty

    def energy(self, fbank):
        return torch.mean(fbank, dim=-1)

    def adaptive_threshold_from_gt(self, gt_fbank, percentile=50):
        frame_energies = self.energy(gt_fbank)
        thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1)
        return thresholds.unsqueeze(1)
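
Note (not part of the diff): extract_spk_indices above drives the per-speaker embedding added in forward_text_embed. A small hedged sketch of how it behaves; the token ids are invented, and only the speaker-tag ids (360/361) and pad_id=0 follow the defaults above.

import torch
from zipvoice.models.zipvoice_dialog import ZipVoiceDialog

model = ZipVoiceDialog(vocab_size=400)  # large enough to hold the [S1]/[S2] ids
# A toy dialogue: [S1] x x [S2] x x [S1] x, padded with pad_id=0.
tokens_padded = torch.tensor([[360, 5, 6, 361, 7, 8, 360, 9, 0, 0]])
spk_a_idx, spk_b_idx = model.extract_spk_indices(tokens_padded)
# The two index tuples split non-padding positions into the two alternating
# speaker streams according to the parity of the cumulative speaker-tag count;
# padding positions (set to -1 in the internal mask) fall into neither stream.
print(spk_a_idx)
print(spk_b_idx)
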
zipvoice/models/zipvoice_distill.py
ADDED
@@ -0,0 +1,94 @@
# Copyright    2024    Xiaomi Corp.        (authors: Wei Kang
#                                                    Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch

from zipvoice.models.modules.solver import DistillEulerSolver
from zipvoice.models.modules.zipformer import TTSZipformer
from zipvoice.models.zipvoice import ZipVoice


class ZipVoiceDistill(ZipVoice):
    """ZipVoice-Distill model."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        required_params = {
            "feat_dim",
            "fm_decoder_downsampling_factor",
            "fm_decoder_num_layers",
            "fm_decoder_cnn_module_kernel",
            "fm_decoder_dim",
            "fm_decoder_feedforward_dim",
            "fm_decoder_num_heads",
            "query_head_dim",
            "pos_head_dim",
            "value_head_dim",
            "pos_dim",
            "time_embed_dim",
        }

        missing = [p for p in required_params if p not in kwargs]
        if missing:
            raise ValueError(f"Missing required parameters: {', '.join(missing)}")

        self.fm_decoder = TTSZipformer(
            in_dim=kwargs["feat_dim"] * 3,
            out_dim=kwargs["feat_dim"],
            downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
            num_encoder_layers=kwargs["fm_decoder_num_layers"],
            cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
            encoder_dim=kwargs["fm_decoder_dim"],
            feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
            num_heads=kwargs["fm_decoder_num_heads"],
            query_head_dim=kwargs["query_head_dim"],
            pos_head_dim=kwargs["pos_head_dim"],
            value_head_dim=kwargs["value_head_dim"],
            pos_dim=kwargs["pos_dim"],
            use_time_embed=True,
            time_embed_dim=kwargs["time_embed_dim"],
            use_guidance_scale_embed=True,
        )
        self.solver = DistillEulerSolver(self, func_name="forward_fm_decoder")

    def forward(
        self,
        tokens: List[List[int]],
        features: torch.Tensor,
        features_lens: torch.Tensor,
        noise: torch.Tensor,
        speech_condition_mask: torch.Tensor,
        t_start: float,
        t_end: float,
        num_step: int = 1,
        guidance_scale: torch.Tensor = None,
    ) -> torch.Tensor:

        return self.sample_intermediate(
            tokens=tokens,
            features=features,
            features_lens=features_lens,
            noise=noise,
            speech_condition_mask=speech_condition_mask,
            t_start=t_start,
            t_end=t_end,
            num_step=num_step,
            guidance_scale=guidance_scale,
        )
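
Note (not part of the diff): unlike ZipVoice, ZipVoiceDistill checks that its decoder hyperparameters are passed explicitly as keyword arguments (the required_params set above), because it rebuilds fm_decoder with use_guidance_scale_embed=True. A hedged construction sketch follows; the values simply repeat the parent defaults and are not a training recipe.

from zipvoice.models.zipvoice_distill import ZipVoiceDistill

config = dict(
    fm_decoder_downsampling_factor=[1, 2, 4, 2, 1],
    fm_decoder_num_layers=[2, 2, 4, 4, 4],
    fm_decoder_cnn_module_kernel=[31, 15, 7, 15, 31],
    fm_decoder_feedforward_dim=1536,
    fm_decoder_num_heads=4,
    fm_decoder_dim=512,
    query_head_dim=32,
    value_head_dim=12,
    pos_head_dim=4,
    pos_dim=48,
    time_embed_dim=192,
    feat_dim=100,
    vocab_size=26,
)
model = ZipVoiceDistill(**config)  # forward() delegates to sample_intermediate()
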
zipvoice/tokenizer/normalizer.py
ADDED
@@ -0,0 +1,170 @@
import re
from abc import ABC, abstractmethod

import cn2an
import inflect


class TextNormalizer(ABC):
    """Abstract base class for text normalization, defining common interface."""

    @abstractmethod
    def normalize(self, text: str) -> str:
        """Normalize text."""
        raise NotImplementedError


class EnglishTextNormalizer(TextNormalizer):
    """
    A class to handle preprocessing of English text including normalization. Following:
    https://github.com/espnet/espnet_tts_frontend/blob/master/tacotron_cleaner/cleaners.py
    """

    def __init__(self):
        # List of (regular expression, replacement) pairs for abbreviations:
        self._abbreviations = [
            (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
            for x in [
                ("mrs", "misess"),
                ("mr", "mister"),
                ("dr", "doctor"),
                ("st", "saint"),
                ("co", "company"),
                ("jr", "junior"),
                ("maj", "major"),
                ("gen", "general"),
                ("drs", "doctors"),
                ("rev", "reverend"),
                ("lt", "lieutenant"),
                ("hon", "honorable"),
                ("sgt", "sergeant"),
                ("capt", "captain"),
                ("esq", "esquire"),
                ("ltd", "limited"),
                ("col", "colonel"),
                ("ft", "fort"),
                ("etc", "et cetera"),
                ("btw", "by the way"),
            ]
        ]

        self._inflect = inflect.engine()
        self._comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
        self._decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
        self._percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
        self._pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
        self._dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
        self._fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
        self._ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
        self._number_re = re.compile(r"[0-9]+")
        self._whitespace_re = re.compile(r"\s+")

    def normalize(self, text: str) -> str:
        """Custom pipeline for English text,
        including number and abbreviation expansion."""
        text = self.expand_abbreviations(text)
        text = self.normalize_numbers(text)

        return text

    def fraction_to_words(self, numerator, denominator):
        if numerator == 1 and denominator == 2:
            return " one half "
        if numerator == 1 and denominator == 4:
            return " one quarter "
        if denominator == 2:
            return " " + self._inflect.number_to_words(numerator) + " halves "
        if denominator == 4:
            return " " + self._inflect.number_to_words(numerator) + " quarters "
        return (
            " "
            + self._inflect.number_to_words(numerator)
            + " "
            + self._inflect.ordinal(self._inflect.number_to_words(denominator))
            + " "
        )

    def _remove_commas(self, m):
        return m.group(1).replace(",", "")

    def _expand_dollars(self, m):
        match = m.group(1)
        parts = match.split(".")
        if len(parts) > 2:
            return " " + match + " dollars "  # Unexpected format
        dollars = int(parts[0]) if parts[0] else 0
        cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
        if dollars and cents:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            cent_unit = "cent" if cents == 1 else "cents"
            return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
        elif dollars:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            return " %s %s " % (dollars, dollar_unit)
        elif cents:
            cent_unit = "cent" if cents == 1 else "cents"
            return " %s %s " % (cents, cent_unit)
        else:
            return " zero dollars "

    def _expand_fraction(self, m):
        numerator = int(m.group(1))
        denominator = int(m.group(2))
        return self.fraction_to_words(numerator, denominator)

    def _expand_decimal_point(self, m):
        return m.group(1).replace(".", " point ")

    def _expand_percent(self, m):
        return m.group(1).replace("%", " percent ")

    def _expand_ordinal(self, m):
        return " " + self._inflect.number_to_words(m.group(0)) + " "

    def _expand_number(self, m):
        num = int(m.group(0))
        if num > 1000 and num < 3000:
            if num == 2000:
                return " two thousand "
            elif num > 2000 and num < 2010:
                return " two thousand " + self._inflect.number_to_words(num % 100) + " "
            elif num % 100 == 0:
                return " " + self._inflect.number_to_words(num // 100) + " hundred "
            else:
                return (
                    " "
                    + self._inflect.number_to_words(
                        num, andword="", zero="oh", group=2
                    ).replace(", ", " ")
                    + " "
                )
        else:
            return " " + self._inflect.number_to_words(num, andword="") + " "

    def normalize_numbers(self, text):
        text = re.sub(self._comma_number_re, self._remove_commas, text)
        text = re.sub(self._pounds_re, r"\1 pounds", text)
        text = re.sub(self._dollars_re, self._expand_dollars, text)
        text = re.sub(self._fraction_re, self._expand_fraction, text)
        text = re.sub(self._decimal_number_re, self._expand_decimal_point, text)
        text = re.sub(self._percent_number_re, self._expand_percent, text)
        text = re.sub(self._ordinal_re, self._expand_ordinal, text)
        text = re.sub(self._number_re, self._expand_number, text)
        return text

    def expand_abbreviations(self, text):
        for regex, replacement in self._abbreviations:
            text = re.sub(regex, replacement, text)
        return text


class ChineseTextNormalizer(TextNormalizer):
    """
    A class to handle preprocessing of Chinese text including normalization.
    """

    def normalize(self, text: str) -> str:
        """Normalize text."""
        # Convert numbers to Chinese
        text = cn2an.transform(text, "an2cn")
        return text
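
Note (not part of the diff): a quick illustration of what the two normalizers do; the example sentences are invented and the comments describe the expected effect rather than guaranteed exact output.

from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer

en = EnglishTextNormalizer()
# Abbreviations and numbers are expanded to words, e.g. "Dr." -> "doctor",
# "$2.50" -> "two dollars, fifty cents", "25%" -> "twenty-five percent".
print(en.normalize("Dr. Smith paid $2.50, a 25% tip."))

zh = ChineseTextNormalizer()
# Arabic numerals are rewritten as Chinese numerals via cn2an.
print(zh.normalize("今天是2024年"))
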
zipvoice/tokenizer/tokenizer.py
ADDED
@@ -0,0 +1,618 @@
# Copyright    2023-2024    Xiaomi Corp.        (authors: Zengwei Yao
#                                                         Han Zhu,
#                                                         Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from abc import ABC, abstractmethod
from functools import reduce
from typing import Dict, List, Optional

import jieba
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer

try:
    from piper_phonemize import phonemize_espeak
except Exception as ex:
    raise RuntimeError(
        f"{ex}\nPlease run\n"
        "pip install piper_phonemize -f \
        https://k2-fsa.github.io/icefall/piper_phonemize.html"
    )


class Tokenizer(ABC):
    """Abstract base class for tokenizers, defining common interface."""

    @abstractmethod
    def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert list of texts to list of token id sequences."""
        raise NotImplementedError

    @abstractmethod
    def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
        """Convert list of texts to list of token sequences."""
        raise NotImplementedError

    @abstractmethod
    def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
        """Convert list of token sequences to list of token id sequences."""
        raise NotImplementedError


class SimpleTokenizer(Tokenizer):
    """The simplest tokenizer: treats every character as a token,
    without text normalization.
    """

    def __init__(self, token_file: Optional[str] = None):
        """
        Args:
            tokens: the file that contains information that maps tokens to ids,
                which is a text file with '{token}\t{token_id}' per line.
        """
        # Parse token file
        self.has_tokens = False
        if token_file is None:
            logging.debug(
                "Initialize Tokenizer without tokens file, \
                will fail when map to ids."
            )
            return
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split("\t")
                token, id = info[0], int(info[1])
                assert token not in self.token2id, token
                self.token2id[token] = id
        self.pad_id = self.token2id["_"]  # padding
        self.vocab_size = len(self.token2id)
        self.has_tokens = True

    def texts_to_token_ids(
        self,
        texts: List[str],
    ) -> List[List[int]]:
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(
        self,
        texts: List[str],
    ) -> List[List[str]]:
        tokens_list = [list(texts[i]) for i in range(len(texts))]
        return tokens_list

    def tokens_to_token_ids(
        self,
        tokens_list: List[List[str]],
    ) -> List[List[int]]:
        assert self.has_tokens, "Please initialize Tokenizer with a tokens file."

        token_ids_list = []

        for tokens in tokens_list:
            token_ids = []
            for t in tokens:
                if t not in self.token2id:
                    logging.debug(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])

            token_ids_list.append(token_ids)

        return token_ids_list


class EspeakTokenizer(Tokenizer):
    """A simple tokenizer with Espeak g2p function."""

    def __init__(self, token_file: Optional[str] = None, lang: str = "en-us"):
        """
        Args:
            tokens: the file that contains information that maps tokens to ids,
                which is a text file with '{token}\t{token_id}' per line.
            lang: the language identifier, see
                https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
        """
        # Parse token file
        self.has_tokens = False
        if token_file is None:
            logging.debug(
                "Initialize Tokenizer without tokens file, \
                will fail when map to ids."
            )
            return
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split("\t")
                token, id = info[0], int(info[1])
                assert token not in self.token2id, token
                self.token2id[token] = id
        self.pad_id = self.token2id["_"]  # padding
        self.vocab_size = len(self.token2id)
        self.has_tokens = True
        self.lang = lang

    def g2p(self, text: str) -> List[str]:
        try:
            tokens = phonemize_espeak(text, self.lang)
            tokens = reduce(lambda x, y: x + y, tokens)
            return tokens
        except Exception as ex:
            logging.warning(f"Tokenization of {self.lang} texts failed: {ex}")
            return []

    def texts_to_token_ids(
        self,
        texts: List[str],
    ) -> List[List[int]]:
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def texts_to_tokens(
        self,
        texts: List[str],
    ) -> List[List[str]]:
        tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
        return tokens_list

    def tokens_to_token_ids(
        self,
        tokens_list: List[List[str]],
    ) -> List[List[int]]:
        assert self.has_tokens, "Please initialize Tokenizer with a tokens file."

        token_ids_list = []

        for tokens in tokens_list:
            token_ids = []
            for t in tokens:
                if t not in self.token2id:
                    logging.debug(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])

            token_ids_list.append(token_ids)

        return token_ids_list


class EmiliaTokenizer(Tokenizer):
    def __init__(self, token_file: Optional[str] = None, token_type="phone"):
        """
        Args:
            tokens: the file that contains information that maps tokens to ids,
                which is a text file with '{token}\t{token_id}' per line.
        """
        assert (
            token_type == "phone"
        ), f"Only support phone tokenizer for Emilia, but get {token_type}."

        self.english_normalizer = EnglishTextNormalizer()
        self.chinese_normalizer = ChineseTextNormalizer()

        self.has_tokens = False
        if token_file is None:
            logging.debug(
                "Initialize Tokenizer without tokens file, \
                will fail when map to ids."
            )
            return
        self.token2id: Dict[str, int] = {}
        with open(token_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split("\t")
                token, id = info[0], int(info[1])
                assert token not in self.token2id, token
                self.token2id[token] = id
        self.pad_id = self.token2id["_"]  # padding

        self.vocab_size = len(self.token2id)
        self.has_tokens = True

    def texts_to_token_ids(
        self,
        texts: List[str],
    ) -> List[List[int]]:
        return self.tokens_to_token_ids(self.texts_to_tokens(texts))

    def preprocess_text(
        self,
        text: str,
    ) -> str:
        return self.map_punctuations(text)

    def texts_to_tokens(
        self,
        texts: List[str],
    ) -> List[List[str]]:
        for i in range(len(texts)):
            # Text normalization
            texts[i] = self.preprocess_text(texts[i])

        phoneme_list = []
        for text in texts:
            # now only en and ch
            segments = self.get_segment(text)
            all_phoneme = []
            for index in range(len(segments)):
                seg = segments[index]
                if seg[1] == "zh":
                    phoneme = self.tokenize_ZH(seg[0])
                elif seg[1] == "en":
                    phoneme = self.tokenize_EN(seg[0])
                elif seg[1] == "pinyin":
                    phoneme = self.tokenize_pinyin(seg[0])
                elif seg[1] == "tag":
                    phoneme = [seg[0]]
                else:
                    logging.warning(
                        f"No English or Chinese characters found, \
                        skipping segment of unknown language: {seg}"
                    )
                    continue
                all_phoneme += phoneme
            phoneme_list.append(all_phoneme)
        return phoneme_list

    def tokens_to_token_ids(
        self,
        tokens_list: List[List[str]],
    ) -> List[List[int]]:
        assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
|
| 281 |
+
token_ids_list = []
|
| 282 |
+
|
| 283 |
+
for tokens in tokens_list:
|
| 284 |
+
token_ids = []
|
| 285 |
+
for t in tokens:
|
| 286 |
+
if t not in self.token2id:
|
| 287 |
+
logging.debug(f"Skip OOV {t}")
|
| 288 |
+
continue
|
| 289 |
+
token_ids.append(self.token2id[t])
|
| 290 |
+
|
| 291 |
+
token_ids_list.append(token_ids)
|
| 292 |
+
|
| 293 |
+
return token_ids_list
|
| 294 |
+
|
| 295 |
+
def tokenize_ZH(self, text: str) -> List[str]:
|
| 296 |
+
try:
|
| 297 |
+
text = self.chinese_normalizer.normalize(text)
|
| 298 |
+
segs = list(jieba.cut(text))
|
| 299 |
+
full = lazy_pinyin(
|
| 300 |
+
segs,
|
| 301 |
+
style=Style.TONE3,
|
| 302 |
+
tone_sandhi=True,
|
| 303 |
+
neutral_tone_with_five=True,
|
| 304 |
+
)
|
| 305 |
+
phones = []
|
| 306 |
+
for x in full:
|
| 307 |
+
# valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
|
| 308 |
+
if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
|
| 309 |
+
phones.append(x)
|
| 310 |
+
continue
|
| 311 |
+
else:
|
| 312 |
+
phones.extend(self.seperate_pinyin(x))
|
| 313 |
+
return phones
|
| 314 |
+
except Exception as ex:
|
| 315 |
+
logging.warning(f"Tokenization of Chinese texts failed: {ex}")
|
| 316 |
+
return []
|
| 317 |
+
|
| 318 |
+
def tokenize_EN(self, text: str) -> List[str]:
|
| 319 |
+
try:
|
| 320 |
+
text = self.english_normalizer.normalize(text)
|
| 321 |
+
tokens = phonemize_espeak(text, "en-us")
|
| 322 |
+
tokens = reduce(lambda x, y: x + y, tokens)
|
| 323 |
+
return tokens
|
| 324 |
+
except Exception as ex:
|
| 325 |
+
logging.warning(f"Tokenization of English texts failed: {ex}")
|
| 326 |
+
return []
|
| 327 |
+
|
| 328 |
+
def tokenize_pinyin(self, text: str) -> List[str]:
|
| 329 |
+
try:
|
| 330 |
+
assert text.startswith("<") and text.endswith(">")
|
| 331 |
+
text = text.lstrip("<").rstrip(">")
|
| 332 |
+
# valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
|
| 333 |
+
if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")):
|
| 334 |
+
logging.warning(
|
| 335 |
+
f"Strings enclosed with <> should be pinyin, \
|
| 336 |
+
but got: {text}. Skipped it. "
|
| 337 |
+
)
|
| 338 |
+
return []
|
| 339 |
+
else:
|
| 340 |
+
return self.seperate_pinyin(text)
|
| 341 |
+
except Exception as ex:
|
| 342 |
+
logging.warning(f"Tokenize pinyin failed: {ex}")
|
| 343 |
+
return []
|
| 344 |
+
|
| 345 |
+
def seperate_pinyin(self, text: str) -> List[str]:
|
| 346 |
+
"""
|
| 347 |
+
Separate pinyin into initial and final
|
| 348 |
+
"""
|
| 349 |
+
pinyins = []
|
| 350 |
+
initial = to_initials(text, strict=False)
|
| 351 |
+
# don't want to share tokens with espeak tokens,
|
| 352 |
+
# so use tone3 style
|
| 353 |
+
final = to_finals_tone3(
|
| 354 |
+
text,
|
| 355 |
+
strict=False,
|
| 356 |
+
neutral_tone_with_five=True,
|
| 357 |
+
)
|
| 358 |
+
if initial != "":
|
| 359 |
+
# don't want to share tokens with espeak tokens,
|
| 360 |
+
# so add a '0' after each initial
|
| 361 |
+
pinyins.append(initial + "0")
|
| 362 |
+
if final != "":
|
| 363 |
+
pinyins.append(final)
|
| 364 |
+
return pinyins
|
| 365 |
+
|
| 366 |
+
def map_punctuations(self, text):
|
| 367 |
+
text = text.replace(",", ",")
|
| 368 |
+
text = text.replace("。", ".")
|
| 369 |
+
text = text.replace("!", "!")
|
| 370 |
+
text = text.replace("?", "?")
|
| 371 |
+
text = text.replace(";", ";")
|
| 372 |
+
text = text.replace(":", ":")
|
| 373 |
+
text = text.replace("、", ",")
|
| 374 |
+
text = text.replace("‘", "'")
|
| 375 |
+
text = text.replace("“", '"')
|
| 376 |
+
text = text.replace("”", '"')
|
| 377 |
+
text = text.replace("’", "'")
|
| 378 |
+
text = text.replace("⋯", "…")
|
| 379 |
+
text = text.replace("···", "…")
|
| 380 |
+
text = text.replace("・・・", "…")
|
| 381 |
+
text = text.replace("...", "…")
|
| 382 |
+
return text
|
| 383 |
+
|
| 384 |
+
def get_segment(self, text: str) -> List[str]:
|
| 385 |
+
"""
|
| 386 |
+
Split a text into segments based on language types
|
| 387 |
+
(Chinese, English, Pinyin, tags, etc.)
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
text (str): Input text to be segmented
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
List[str]: Segmented text parts with their language types
|
| 394 |
+
|
| 395 |
+
Example:
|
| 396 |
+
Input: 我们是小米人,是吗? Yes I think so!霍...啦啦啦
|
| 397 |
+
Output: [('我们是小米人,是吗? ', 'zh'),
|
| 398 |
+
('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
|
| 399 |
+
"""
|
| 400 |
+
# Stores the final segmented parts and their language types
|
| 401 |
+
segments = []
|
| 402 |
+
# Stores the language type of each character in the input text
|
| 403 |
+
types = []
|
| 404 |
+
temp_seg = ""
|
| 405 |
+
temp_lang = ""
|
| 406 |
+
|
| 407 |
+
# Each part is a character, or a special string enclosed in <> and []
|
| 408 |
+
# <> denotes pinyin string, [] denotes other special strings.
|
| 409 |
+
_part_pattern = re.compile(r"[<[].*?[>\]]|.")
|
| 410 |
+
text = _part_pattern.findall(text)
|
| 411 |
+
|
| 412 |
+
for i, part in enumerate(text):
|
| 413 |
+
if self.is_chinese(part) or self.is_pinyin(part):
|
| 414 |
+
types.append("zh")
|
| 415 |
+
elif self.is_alphabet(part):
|
| 416 |
+
types.append("en")
|
| 417 |
+
else:
|
| 418 |
+
types.append("other")
|
| 419 |
+
|
| 420 |
+
assert len(types) == len(text)
|
| 421 |
+
|
| 422 |
+
for i in range(len(types)):
|
| 423 |
+
# find the first char of the seg
|
| 424 |
+
if i == 0:
|
| 425 |
+
temp_seg += text[i]
|
| 426 |
+
temp_lang = types[i]
|
| 427 |
+
else:
|
| 428 |
+
if temp_lang == "other":
|
| 429 |
+
temp_seg += text[i]
|
| 430 |
+
temp_lang = types[i]
|
| 431 |
+
else:
|
| 432 |
+
if types[i] in [temp_lang, "other"]:
|
| 433 |
+
temp_seg += text[i]
|
| 434 |
+
else:
|
| 435 |
+
segments.append((temp_seg, temp_lang))
|
| 436 |
+
temp_seg = text[i]
|
| 437 |
+
temp_lang = types[i]
|
| 438 |
+
|
| 439 |
+
segments.append((temp_seg, temp_lang))
|
| 440 |
+
|
| 441 |
+
# Handle "pinyin" and "tag" types
|
| 442 |
+
segments = self.split_segments(segments)
|
| 443 |
+
return segments
|
| 444 |
+
|
| 445 |
+
def split_segments(self, segments):
|
| 446 |
+
"""
|
| 447 |
+
split segments into smaller parts if special strings enclosed by [] or <>
|
| 448 |
+
are found, where <> denotes pinyin strings, [] denotes other special strings.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
segments (list): A list of tuples where each tuple contains:
|
| 452 |
+
- temp_seg (str): The text segment to be split.
|
| 453 |
+
- temp_lang (str): The language code associated with the segment.
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
list: A list of smaller segments.
|
| 457 |
+
"""
|
| 458 |
+
result = []
|
| 459 |
+
for temp_seg, temp_lang in segments:
|
| 460 |
+
parts = re.split(r"([<[].*?[>\]])", temp_seg)
|
| 461 |
+
for part in parts:
|
| 462 |
+
if not part:
|
| 463 |
+
continue
|
| 464 |
+
if self.is_pinyin(part):
|
| 465 |
+
result.append((part, "pinyin"))
|
| 466 |
+
elif self.is_tag(part):
|
| 467 |
+
result.append((part, "tag"))
|
| 468 |
+
else:
|
| 469 |
+
result.append((part, temp_lang))
|
| 470 |
+
return result
|
| 471 |
+
|
| 472 |
+
def is_chinese(self, char: str) -> bool:
|
| 473 |
+
if char >= "\u4e00" and char <= "\u9fa5":
|
| 474 |
+
return True
|
| 475 |
+
else:
|
| 476 |
+
return False
|
| 477 |
+
|
| 478 |
+
def is_alphabet(self, char: str) -> bool:
|
| 479 |
+
if (char >= "\u0041" and char <= "\u005a") or (
|
| 480 |
+
char >= "\u0061" and char <= "\u007a"
|
| 481 |
+
):
|
| 482 |
+
return True
|
| 483 |
+
else:
|
| 484 |
+
return False
|
| 485 |
+
|
| 486 |
+
def is_pinyin(self, part: str) -> bool:
|
| 487 |
+
if part.startswith("<") and part.endswith(">"):
|
| 488 |
+
return True
|
| 489 |
+
else:
|
| 490 |
+
return False
|
| 491 |
+
|
| 492 |
+
def is_tag(self, part: str) -> bool:
|
| 493 |
+
if part.startswith("[") and part.endswith("]"):
|
| 494 |
+
return True
|
| 495 |
+
else:
|
| 496 |
+
return False
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class DialogTokenizer(EmiliaTokenizer):
|
| 500 |
+
def __init__(self, token_file: Optional[str] = None, token_type="phone"):
|
| 501 |
+
super().__init__(token_file=token_file, token_type=token_type)
|
| 502 |
+
self.spk_a_id = self.token2id["[S1]"]
|
| 503 |
+
self.spk_b_id = self.token2id["[S2]"]
|
| 504 |
+
|
| 505 |
+
def preprocess_text(
|
| 506 |
+
self,
|
| 507 |
+
text: str,
|
| 508 |
+
) -> str:
|
| 509 |
+
text = re.sub(r"\s*(\[S[12]\])\s*", r"\1", text)
|
| 510 |
+
text = self.map_punctuations(text)
|
| 511 |
+
return text
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class LibriTTSTokenizer(Tokenizer):
|
| 515 |
+
def __init__(self, token_file: Optional[str] = None, token_type="char"):
|
| 516 |
+
"""
|
| 517 |
+
Args:
|
| 518 |
+
type: the type of tokenizer, e.g., bpe, char, phone.
|
| 519 |
+
tokens: the file that contains information that maps tokens to ids,
|
| 520 |
+
which is a text file with '{token}\t{token_id}' per line if type is
|
| 521 |
+
char or phone, otherwise it is a bpe_model file.
|
| 522 |
+
"""
|
| 523 |
+
self.type = token_type
|
| 524 |
+
assert token_type in ["bpe", "char", "phone"]
|
| 525 |
+
try:
|
| 526 |
+
import tacotron_cleaner.cleaners
|
| 527 |
+
except Exception as ex:
|
| 528 |
+
raise RuntimeError(f"{ex}\nPlease run\n" "pip install espnet_tts_frontend")
|
| 529 |
+
|
| 530 |
+
self.normalize = tacotron_cleaner.cleaners.custom_english_cleaners
|
| 531 |
+
|
| 532 |
+
self.has_tokens = False
|
| 533 |
+
if token_file is None:
|
| 534 |
+
logging.debug(
|
| 535 |
+
"Initialize Tokenizer without tokens file, \
|
| 536 |
+
will fail when map to ids."
|
| 537 |
+
)
|
| 538 |
+
return
|
| 539 |
+
if token_type == "bpe":
|
| 540 |
+
import sentencepiece as spm
|
| 541 |
+
|
| 542 |
+
self.sp = spm.SentencePieceProcessor()
|
| 543 |
+
self.sp.load(token_file)
|
| 544 |
+
self.pad_id = self.sp.piece_to_id("<pad>")
|
| 545 |
+
self.vocab_size = self.sp.get_piece_size()
|
| 546 |
+
else:
|
| 547 |
+
self.token2id: Dict[str, int] = {}
|
| 548 |
+
with open(token_file, "r", encoding="utf-8") as f:
|
| 549 |
+
for line in f.readlines():
|
| 550 |
+
info = line.rstrip().split("\t")
|
| 551 |
+
token, id = info[0], int(info[1])
|
| 552 |
+
assert token not in self.token2id, token
|
| 553 |
+
self.token2id[token] = id
|
| 554 |
+
self.pad_id = self.token2id["_"] # padding
|
| 555 |
+
self.vocab_size = len(self.token2id)
|
| 556 |
+
self.has_tokens = True
|
| 557 |
+
|
| 558 |
+
def texts_to_token_ids(
|
| 559 |
+
self,
|
| 560 |
+
texts: List[str],
|
| 561 |
+
) -> List[List[int]]:
|
| 562 |
+
if self.type == "bpe":
|
| 563 |
+
for i in range(len(texts)):
|
| 564 |
+
texts[i] = self.normalize(texts[i])
|
| 565 |
+
return self.sp.encode(texts)
|
| 566 |
+
else:
|
| 567 |
+
return self.tokens_to_token_ids(self.texts_to_tokens(texts))
|
| 568 |
+
|
| 569 |
+
def texts_to_tokens(
|
| 570 |
+
self,
|
| 571 |
+
texts: List[str],
|
| 572 |
+
) -> List[List[str]]:
|
| 573 |
+
for i in range(len(texts)):
|
| 574 |
+
texts[i] = self.normalize(texts[i])
|
| 575 |
+
|
| 576 |
+
if self.type == "char":
|
| 577 |
+
tokens_list = [list(texts[i]) for i in range(len(texts))]
|
| 578 |
+
elif self.type == "phone":
|
| 579 |
+
tokens_list = [
|
| 580 |
+
phonemize_espeak(texts[i].lower(), "en-us") for i in range(len(texts))
|
| 581 |
+
]
|
| 582 |
+
elif self.type == "bpe":
|
| 583 |
+
tokens_list = self.sp.encode(texts, out_type=str)
|
| 584 |
+
|
| 585 |
+
return tokens_list
|
| 586 |
+
|
| 587 |
+
def tokens_to_token_ids(
|
| 588 |
+
self,
|
| 589 |
+
tokens_list: List[List[str]],
|
| 590 |
+
) -> List[List[int]]:
|
| 591 |
+
assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
|
| 592 |
+
|
| 593 |
+
assert self.type != "bpe", "BPE tokenizer does not support this function."
|
| 594 |
+
|
| 595 |
+
token_ids_list = []
|
| 596 |
+
|
| 597 |
+
for tokens in tokens_list:
|
| 598 |
+
token_ids = []
|
| 599 |
+
for t in tokens:
|
| 600 |
+
if t not in self.token2id:
|
| 601 |
+
logging.debug(f"Skip OOV {t}")
|
| 602 |
+
continue
|
| 603 |
+
token_ids.append(self.token2id[t])
|
| 604 |
+
|
| 605 |
+
token_ids_list.append(token_ids)
|
| 606 |
+
|
| 607 |
+
return token_ids_list
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
if __name__ == "__main__":
|
| 611 |
+
text = (
|
| 612 |
+
"我们是5年小米人,是吗? Yes I think so! "
|
| 613 |
+
"mr king, 5 years, from 2019 to 2024."
|
| 614 |
+
"霍...啦啦啦超过90%的人<le5>...?!9204"
|
| 615 |
+
)
|
| 616 |
+
tokenizer = EmiliaTokenizer()
|
| 617 |
+
tokens = tokenizer.texts_to_tokens([text])
|
| 618 |
+
print(f"tokens: {'|'.join(tokens[0])}")
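The listing above only shows the tokenizer classes as committed. As a quick orientation, the sketch below shows one way these classes could be driven end to end; it is not part of the commit. The "tokens.txt" path is a placeholder for whatever phone table ships with the pretrained model, and only methods defined above (texts_to_tokens, texts_to_token_ids) are used.

# Illustrative usage sketch (not part of the committed files).
# Assumptions: the zipvoice package is importable and "tokens.txt" stands in
# for the real token table distributed with the model.
from zipvoice.tokenizer.tokenizer import EmiliaTokenizer

tokenizer = EmiliaTokenizer(token_file="tokens.txt")
texts = ["我们是小米人, 是吗? Yes I think so!"]
# texts_to_tokens normalizes its input list in place, so pass a copy.
tokens = tokenizer.texts_to_tokens(list(texts))         # phones per utterance
token_ids = tokenizer.texts_to_token_ids(list(texts))   # integer ids per utterance
print(tokens[0][:10], token_ids[0][:10])

Mixed Chinese/English input works because get_segment splits the text by language before dispatching to tokenize_ZH or tokenize_EN, as shown in the class above.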
zipvoice/utils/checkpoint.py
ADDED
@@ -0,0 +1,572 @@
# Copyright 2021-2025 Xiaomi Corporation (authors: Fangjun Kuang,
#                                                  Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

from zipvoice.utils.common import AttributeDict

# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    model_ema: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      model_avg:
        The stored model averaged from the start of training.
      model_ema:
        The EMA version of model.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer:
        The optimizer to be saved. We only save its `state_dict()`.
      scheduler:
        The scheduler to be saved. We only save its `state_dict()`.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the labeled training dataset. We only
        save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose
        rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if model_avg is not None:
        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
    if model_ema is not None:
        checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)


def load_checkpoint(
    filename: Path,
    model: Optional[nn.Module] = None,
    model_avg: Optional[nn.Module] = None,
    model_ema: Optional[nn.Module] = None,
    strict: bool = False,
) -> Dict[str, Any]:
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu", weights_only=False)

    if model is not None:

        if next(iter(checkpoint["model"])).startswith("module."):
            logging.info("Loading checkpoint saved by DDP")

            dst_state_dict = model.state_dict()
            src_state_dict = checkpoint["model"]
            for key in dst_state_dict.keys():
                src_key = "{}.{}".format("module", key)
                dst_state_dict[key] = src_state_dict.pop(src_key)
            assert len(src_state_dict) == 0
            model.load_state_dict(dst_state_dict, strict=strict)
        else:
            logging.info("Loading checkpoint")
            model.load_state_dict(checkpoint["model"], strict=strict)

        checkpoint.pop("model")

    if model_avg is not None and "model_avg" in checkpoint:
        logging.info("Loading averaged model")
        model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
        checkpoint.pop("model_avg")

    if model_ema is not None and "model_ema" in checkpoint:
        logging.info("Loading ema model")
        model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
        checkpoint.pop("model_ema")

    return checkpoint


def load_checkpoint_extend_vocab_size(
    filename: Path, extend_size: int, model: nn.Module, strict: bool = True
) -> Dict[str, Any]:
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu", weights_only=False)

    if model is not None:
        if next(iter(checkpoint["model"])).startswith("module."):
            logging.info("Loading checkpoint saved by DDP")
            dst_state_dict = model.state_dict()
            src_state_dict = checkpoint["model"]
            for key in dst_state_dict.keys():
                src_key = "{}.{}".format("module", key)
                dst_state_dict[key] = src_state_dict.pop(src_key)
            assert len(src_state_dict) == 0
        else:
            logging.info("Loading checkpoint")
            dst_state_dict = checkpoint["model"]
        dst_state_dict["spk_embed.weight"] = model.state_dict()["spk_embed.weight"]
        embed_weight = model.state_dict()["embed.weight"]
        embed_weight[:-extend_size, :] = dst_state_dict["embed.weight"]
        dst_state_dict["embed.weight"] = embed_weight

        model.load_state_dict(dst_state_dict, strict=strict)


def load_checkpoint_copy_proj_three_channel_alter(
    filename: Path,
    in_proj_key: str,
    out_proj_key: str,
    dim: int,
    model: nn.Module,
) -> Dict[str, Any]:
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu", weights_only=False)

    if model is not None:
        if next(iter(checkpoint["model"])).startswith("module."):
            logging.info("Loading checkpoint saved by DDP")

            dst_state_dict = dict()
            src_state_dict = checkpoint["model"]
            for key in src_state_dict.keys():
                dst_state_dict[key.lstrip("module.")] = src_state_dict.pop(key)
            assert len(src_state_dict) == 0
        else:
            logging.info("Loading checkpoint")
            dst_state_dict = checkpoint["model"]
        keys = list(dst_state_dict.keys())
        for key in keys:
            if in_proj_key in key:
                if "weight" in key:
                    weight = dst_state_dict.pop(key)
                    dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
                        [
                            weight[:, :dim] / 2,
                            weight[:, :dim] / 2,
                            weight[:, dim : dim * 2],
                            weight[:, dim * 2 :] / 2,
                            weight[:, dim * 2 :] / 2,
                        ],
                        dim=-1,
                    )
                    dst_state_dict[key.replace("weight", "1.weight")] = weight
                if "bias" in key:
                    bias = dst_state_dict.pop(key)
                    dst_state_dict[key.replace("bias", "0.bias")] = bias
                    dst_state_dict[key.replace("bias", "1.bias")] = bias
            if out_proj_key in key:
                if "weight" in key:
                    weight = dst_state_dict.pop(key)
                    dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
                        [weight, weight], dim=0
                    )
                    dst_state_dict[key.replace("weight", "1.weight")] = weight
                elif "bias" in key:
                    bias = dst_state_dict.pop(key)
                    dst_state_dict[key.replace("bias", "0.bias")] = torch.cat(
                        [bias, bias], dim=0
                    )
                    dst_state_dict[key.replace("bias", "1.bias")] = bias

        model.load_state_dict(dst_state_dict, strict=True)


def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
    """Find all available checkpoints in a directory.

    The checkpoint filenames have the form: `checkpoint-xxx.pt`
    where xxx is a numerical value.

    Assume you have the following checkpoints in the folder `foo`:

        - checkpoint-1.pt
        - checkpoint-20.pt
        - checkpoint-300.pt
        - checkpoint-4000.pt

    Case 1 (Return all checkpoints)::

        find_checkpoints(out_dir='foo')

    Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)

        find_checkpoints(out_dir='foo', iteration=20)

    Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
    checkpoint-20.pt, checkpoint-1.pt)::

        find_checkpoints(out_dir='foo', iteration=-20)

    Args:
      out_dir:
        The directory where to search for checkpoints.
      iteration:
        If it is 0, return all available checkpoints.
        If it is positive, return the checkpoints whose iteration number is
        greater than or equal to `iteration`.
        If it is negative, return the checkpoints whose iteration number is
        less than or equal to `-iteration`.
    Returns:
      Return a list of checkpoint filenames, sorted in descending
      order by the numerical value in the filename.
    """
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
    pattern = re.compile(r"checkpoint-([0-9]+).pt")
    iter_checkpoints = []
    for c in checkpoints:
        result = pattern.search(c)
        if not result:
            logging.warn(f"Invalid checkpoint filename {c}")
            continue

        iter_checkpoints.append((int(result.group(1)), c))

    # iter_checkpoints is a list of tuples. Each tuple contains
    # two elements: (iteration_number, checkpoint-iteration_number.pt)

    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
    if iteration >= 0:
        ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
    else:
        ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]

    return ans


def average_checkpoints_with_averaged_model(
    filename_start: str,
    filename_end: str,
    device: torch.device = torch.device("cpu"),
) -> Dict[str, torch.Tensor]:
    """Average model parameters over the range with given
    start model (excluded) and end model.

    Let start = batch_idx_train of model-start;
        end = batch_idx_train of model-end;
        interval = end - start.
    Then the average model over range from start (excluded) to end is
    (1) avg = (model_end * end - model_start * start) / interval.
    It can be written as
    (2) avg = model_end * weight_end + model_start * weight_start,
        where weight_end = end / interval,
              weight_start = -start / interval = 1 - weight_end.
    Since the terms `weight_end` and `weight_start` would be large
    if the model has been trained for lots of batches, which would cause
    overflow when multiplying the model parameters.
    To avoid this, we rewrite (2) as:
    (3) avg = (model_end + model_start * (weight_start / weight_end))
              * weight_end

    The model index could be epoch number or iteration number.

    Args:
      filename_start:
        Checkpoint filename of the start model. We assume it
        is saved by :func:`save_checkpoint`.
      filename_end:
        Checkpoint filename of the end model. We assume it
        is saved by :func:`save_checkpoint`.
      device:
        Move checkpoints to this device before averaging.
    """
    state_dict_start = torch.load(
        filename_start, map_location=device, weights_only=False
    )
    state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)

    average_period = state_dict_start["average_period"]

    batch_idx_train_start = state_dict_start["batch_idx_train"]
    batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
    batch_idx_train_end = state_dict_end["batch_idx_train"]
    batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
    interval = batch_idx_train_end - batch_idx_train_start
    assert interval > 0, interval
    weight_end = batch_idx_train_end / interval
    weight_start = 1 - weight_end

    model_end = state_dict_end["model_avg"]
    model_start = state_dict_start["model_avg"]
    avg = model_end

    # scale the weight to avoid overflow
    average_state_dict(
        state_dict_1=avg,
        state_dict_2=model_start,
        weight_1=1.0,
        weight_2=weight_start / weight_end,
        scaling_factor=weight_end,
    )

    return avg


def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
    where xxx is a number, representing the number of processed batches
    when saving that checkpoint. We sort checkpoints by filename and keep
    only the `topk` checkpoints with the highest `xxx`.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if no DDP is used for training.
    """
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = find_checkpoints(out_dir)

    if len(checkpoints) == 0:
        logging.warn(f"No checkpoints found in {out_dir}")
        return

    if len(checkpoints) <= topk:
        return

    to_remove = checkpoints[topk:]
    for c in to_remove:
        os.remove(c)


def resume_checkpoint(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module,
    model_ema: Optional[nn.Module] = None,
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      Return a dict containing previously saved training info.
    """
    filename = params.exp_dir / f"epoch-{params.start_epoch - 1}.pt"

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(
        filename,
        model=model,
        model_avg=model_avg,
        model_ema=model_ema,
        strict=True,
    )

    if params.start_epoch > 1:
        keys = [
            "best_train_epoch",
            "best_valid_epoch",
            "batch_idx_train",
            "best_train_loss",
            "best_valid_loss",
        ]
        for k in keys:
            params[k] = saved_params[k]

    return saved_params


def average_state_dict(
    state_dict_1: Dict[str, torch.Tensor],
    state_dict_2: Dict[str, torch.Tensor],
    weight_1: float,
    weight_2: float,
    scaling_factor: float = 1.0,
) -> Dict[str, torch.Tensor]:
    """Average two state_dict with given weights:
    state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
      * scaling_factor
    It is an in-place operation on state_dict_1 itself.
    """
    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_1.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())
    for k in uniqued_names:
        v = state_dict_1[k]
        if torch.is_floating_point(v):
            v *= weight_1
            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
            v *= scaling_factor


def update_averaged_model(
    params: Dict[str, torch.Tensor],
    model_cur: Union[nn.Module, DDP],
    model_avg: nn.Module,
) -> None:
    """Update the averaged model:
    model_avg = model_cur * (average_period / batch_idx_train)
      + model_avg * ((batch_idx_train - average_period) / batch_idx_train)

    Args:
      params:
        User defined parameters, e.g., epoch, loss.
      model_cur:
        The current model.
      model_avg:
        The averaged model to be updated.
    """
    weight_cur = params.average_period / params.batch_idx_train
    weight_avg = 1 - weight_cur

    if isinstance(model_cur, DDP):
        model_cur = model_cur.module

    cur = model_cur.state_dict()
    avg = model_avg.state_dict()

    average_state_dict(
        state_dict_1=avg,
        state_dict_2=cur,
        weight_1=weight_avg,
        weight_2=weight_cur,
    )


def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
):
    """Save training info after processing given number of batches.

    Args:
      out_dir:
        The directory to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:

            f'out_dir / checkpoint-{global_batch_idx}.pt'
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      model_avg:
        The stored model averaged from the start of training.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its `state_dict` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its `state_dict` will
        be saved.
      scaler:
        The scaler used for mix precision training. Its `state_dict` will
        be saved.
      sampler:
        The sampler used in the training dataset.
      rank:
        The rank ID used in DDP training of the current node. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )
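For orientation, the sketch below shows how the helpers in this file fit together; it is illustrative only and not part of the commit. The linear layer, the file names, and the stored `params` keys are placeholders, and only functions defined above are called.

# Illustrative usage sketch (not part of the committed files).
# Assumptions: the zipvoice package and its dependencies (e.g. lhotse) are
# installed; "epoch-1.pt" is a placeholder file name.
import torch.nn as nn

from zipvoice.utils.checkpoint import load_checkpoint, save_checkpoint

model = nn.Linear(4, 4)  # stand-in for the real ZipVoice model
save_checkpoint("epoch-1.pt", model=model, params={"batch_idx_train": 1000})

# Restore the weights and read back the extra params saved alongside them.
extra = load_checkpoint("epoch-1.pt", model=model)
print(sorted(extra.keys()))

# When two checkpoints carry "model_avg" and "average_period" (as the training
# scripts in this commit save them), the averaged weights between those two
# points of training can be recovered with
# average_checkpoints_with_averaged_model(filename_start, filename_end),
# which implements formula (3) from its docstring to avoid overflow.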
zipvoice/utils/common.py
ADDED
@@ -0,0 +1,604 @@
import argparse
import collections
import json
import logging
import os
import socket
import subprocess
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

    def __str__(self, indent: int = 2):
        tmp = {}
        for k, v in self.items():
            # PosixPath is ont JSON serializable
            if isinstance(v, (Path, torch.device, torch.dtype)):
                v = str(v)
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")
        frames = "%.2f" % self["frames"]
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self["utterances"]
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
            tb_writer: a TensorBoard writer
            prefix: a prefix for the name of the loss, e.g. "train/valid_",
                or "train/current_"
            batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


def setup_dist(
    rank=None,
    world_size=None,
    master_port=None,
    use_ddp_launch=False,
    master_addr=None,
):
    """
    rank and world_size are used only if use_ddp_launch is False.
    """
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = (
            "localhost" if master_addr is None else str(master_addr)
        )

    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)

    if use_ddp_launch is False:
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
    else:
        dist.init_process_group("nccl")


def cleanup_dist():
    dist.destroy_process_group()


def prepare_input(
    params: AttributeDict,
    batch: dict,
    device: torch.device,
    return_tokens: bool = True,
    return_feature: bool = True,
    return_audio: bool = False,
):
    """
    Parse the features and targets of the current batch.
    Args:
      params:
        It is returned by :func:`get_params`.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      device:
        The device of Tensor.
    """
    return_list = []

    if return_tokens:
        return_list += [batch["tokens"]]

    if return_feature:
        features = batch["features"].to(device)
        features_lens = batch["features_lens"].to(device)
        return_list += [features * params.feat_scale, features_lens]

    if return_audio:
        return_list += [batch["audio"], batch["audio_lens"]]

    return return_list


def prepare_avg_tokens_durations(features_lens, tokens_lens):
    tokens_durations = []
    for i in range(len(features_lens)):
        utt_duration = features_lens[i]
        avg_token_duration = utt_duration // tokens_lens[i]
        tokens_durations.append([avg_token_duration] * tokens_lens[i])
    return tokens_durations


def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
    """
    Pad the transcripts to the same length with zeros.

    Args:
      y: the transcripts, which is a list of a list

    Returns:
      Return a Tensor of padded transcripts.
    """
    y = [token_ids + [pad_id] for token_ids in y]
    length = max([len(token_ids) for token_ids in y])
    y = [token_ids + [pad_id] * (length - len(token_ids)) for token_ids in y]
    return torch.tensor(y, dtype=torch.int64, device=device)


def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
    """
    Gets position in the transcript for each frame, i.e. the position
    in the symbol-sequence to look up.

    Args:
      durations:
        Duration of each token in transcripts.
      num_frames:
        The maximum frame length of the current batch.

    Returns:
      Return a Tensor of shape (batch_size, num_frames)
    """
    durations = [x + [num_frames - sum(x)] for x in durations]
    batch_size = len(durations)
    ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
    for b in range(batch_size):
        this_dur = durations[b]
        cur_frame = 0
        for i, d in enumerate(this_dur):
            ans[b, cur_frame : cur_frame + d] = i
            cur_frame += d
        assert cur_frame == num_frames, (cur_frame, num_frames)
    return ans


def to_int_tuple(s: Union[str, int]):
    if isinstance(s, int):
        return (s,)
    return tuple(map(int, s.split(",")))


def get_adjusted_batch_count(params: AttributeDict) -> float:
    # returns the number of batches we would have used so far if we had used the
    # reference duration. This is for purposes of set_batch_count().
    return (
        params.batch_idx_train
        * (params.max_duration * params.world_size)
        / params.ref_duration
    )


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
        if hasattr(module, "name"):
            module.name = name


def condition_time_mask(
    features_lens: torch.Tensor,
    mask_percent: Tuple[float, float],
    max_len: int = 0,
) -> torch.Tensor:
    """
    Apply Time masking.
    Args:
      features_lens:
        input tensor of shape ``(B)``
      mask_size:
        the width size for masking.
      max_len:
        the maximum length of the mask.
    Returns:
      Return a 2-D bool tensor (B, T), where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.
    """
    mask_size = (
        torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
        * features_lens
    ).to(torch.int64)
    mask_starts = (
        torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
    ).to(torch.int64)
    mask_ends = mask_starts + mask_size
    max_len = max(max_len, features_lens.max())
    seq_range = torch.arange(0, max_len, device=features_lens.device)
    mask = (seq_range[None, :] >= mask_starts[:, None]) & (
        seq_range[None, :] < mask_ends[:, None]
    )
    return mask


def condition_time_mask_suffix(
    features_lens: torch.Tensor,
    mask_percent: Tuple[float, float],
    max_len: int = 0,
) -> torch.Tensor:
    """
    Apply Time masking, mask from the end time index.
    Args:
      features_lens:
        input tensor of shape ``(B)``
      mask_size:
        the width size for masking.
      max_len:
        the maximum length of the mask.
    Returns:
      Return a 2-D bool tensor (B, T), where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.
    """
    mask_size = (
        torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
        * features_lens
    ).to(torch.int64)
    mask_starts = (
        torch.ones_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
    ).to(torch.int64)
    mask_ends = mask_starts + mask_size
    max_len = max(max_len, features_lens.max())
    seq_range = torch.arange(0, max_len, device=features_lens.device)
    mask = (seq_range[None, :] >= mask_starts[:, None]) & (
        seq_range[None, :] < mask_ends[:, None]
    )
    return mask


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

      >>> lengths = torch.tensor([1, 3, 2, 5])
      >>> make_pad_mask(lengths)
      tensor([[False, True, True, True, True],
              [False, False, False, True, True],
              [False, False, True, True, True],
              [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expaned_lengths >= lengths.unsqueeze(-1)


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False
|
| 387 |
+
|
| 388 |
+
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
|
| 389 |
+
"""
|
| 390 |
+
if isinstance(v, bool):
|
| 391 |
+
return v
|
| 392 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 393 |
+
return True
|
| 394 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 395 |
+
return False
|
| 396 |
+
else:
|
| 397 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def setup_logger(
|
| 401 |
+
log_filename: Pathlike,
|
| 402 |
+
log_level: str = "info",
|
| 403 |
+
use_console: bool = True,
|
| 404 |
+
) -> None:
|
| 405 |
+
"""Setup log level.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
log_filename:
|
| 409 |
+
The filename to save the log.
|
| 410 |
+
log_level:
|
| 411 |
+
The log level to use, e.g., "debug", "info", "warning", "error",
|
| 412 |
+
"critical"
|
| 413 |
+
use_console:
|
| 414 |
+
True to also print logs to console.
|
| 415 |
+
"""
|
| 416 |
+
now = datetime.now()
|
| 417 |
+
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
| 418 |
+
if dist.is_available() and dist.is_initialized():
|
| 419 |
+
world_size = dist.get_world_size()
|
| 420 |
+
rank = dist.get_rank()
|
| 421 |
+
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
| 422 |
+
log_filename = f"{log_filename}-{date_time}-{rank}"
|
| 423 |
+
else:
|
| 424 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
| 425 |
+
log_filename = f"{log_filename}-{date_time}"
|
| 426 |
+
|
| 427 |
+
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
| 428 |
+
|
| 429 |
+
level = logging.ERROR
|
| 430 |
+
if log_level == "debug":
|
| 431 |
+
level = logging.DEBUG
|
| 432 |
+
elif log_level == "info":
|
| 433 |
+
level = logging.INFO
|
| 434 |
+
elif log_level == "warning":
|
| 435 |
+
level = logging.WARNING
|
| 436 |
+
elif log_level == "critical":
|
| 437 |
+
level = logging.CRITICAL
|
| 438 |
+
|
| 439 |
+
logging.basicConfig(
|
| 440 |
+
filename=log_filename,
|
| 441 |
+
format=formatter,
|
| 442 |
+
level=level,
|
| 443 |
+
filemode="w",
|
| 444 |
+
force=True,
|
| 445 |
+
)
|
| 446 |
+
if use_console:
|
| 447 |
+
console = logging.StreamHandler()
|
| 448 |
+
console.setLevel(level)
|
| 449 |
+
console.setFormatter(logging.Formatter(formatter))
|
| 450 |
+
logging.getLogger("").addHandler(console)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def get_git_sha1():
|
| 454 |
+
try:
|
| 455 |
+
git_commit = (
|
| 456 |
+
subprocess.run(
|
| 457 |
+
["git", "rev-parse", "--short", "HEAD"],
|
| 458 |
+
check=True,
|
| 459 |
+
stdout=subprocess.PIPE,
|
| 460 |
+
)
|
| 461 |
+
.stdout.decode()
|
| 462 |
+
.rstrip("\n")
|
| 463 |
+
.strip()
|
| 464 |
+
)
|
| 465 |
+
dirty_commit = (
|
| 466 |
+
len(
|
| 467 |
+
subprocess.run(
|
| 468 |
+
["git", "diff", "--shortstat"],
|
| 469 |
+
check=True,
|
| 470 |
+
stdout=subprocess.PIPE,
|
| 471 |
+
)
|
| 472 |
+
.stdout.decode()
|
| 473 |
+
.rstrip("\n")
|
| 474 |
+
.strip()
|
| 475 |
+
)
|
| 476 |
+
> 0
|
| 477 |
+
)
|
| 478 |
+
git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
|
| 479 |
+
except: # noqa
|
| 480 |
+
return None
|
| 481 |
+
|
| 482 |
+
return git_commit
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def get_git_date():
|
| 486 |
+
try:
|
| 487 |
+
git_date = (
|
| 488 |
+
subprocess.run(
|
| 489 |
+
["git", "log", "-1", "--format=%ad", "--date=local"],
|
| 490 |
+
check=True,
|
| 491 |
+
stdout=subprocess.PIPE,
|
| 492 |
+
)
|
| 493 |
+
.stdout.decode()
|
| 494 |
+
.rstrip("\n")
|
| 495 |
+
.strip()
|
| 496 |
+
)
|
| 497 |
+
except: # noqa
|
| 498 |
+
return None
|
| 499 |
+
|
| 500 |
+
return git_date
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def get_git_branch_name():
|
| 504 |
+
try:
|
| 505 |
+
git_date = (
|
| 506 |
+
subprocess.run(
|
| 507 |
+
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
| 508 |
+
check=True,
|
| 509 |
+
stdout=subprocess.PIPE,
|
| 510 |
+
)
|
| 511 |
+
.stdout.decode()
|
| 512 |
+
.rstrip("\n")
|
| 513 |
+
.strip()
|
| 514 |
+
)
|
| 515 |
+
except: # noqa
|
| 516 |
+
return None
|
| 517 |
+
|
| 518 |
+
return git_date
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def get_env_info() -> Dict[str, Any]:
|
| 522 |
+
"""Get the environment information."""
|
| 523 |
+
return {
|
| 524 |
+
"torch-version": str(torch.__version__),
|
| 525 |
+
"torch-cuda-available": torch.cuda.is_available(),
|
| 526 |
+
"torch-cuda-version": torch.version.cuda,
|
| 527 |
+
"python-version": sys.version[:4],
|
| 528 |
+
"zipvoice-git-branch": get_git_branch_name(),
|
| 529 |
+
"zipvoice-git-sha1": get_git_sha1(),
|
| 530 |
+
"zipvoice-git-date": get_git_date(),
|
| 531 |
+
"zipvoice-path": str(Path(__file__).resolve().parent.parent),
|
| 532 |
+
"hostname": socket.gethostname(),
|
| 533 |
+
"IP address": socket.gethostbyname(socket.gethostname()),
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def get_parameter_groups_with_lrs(
|
| 538 |
+
model: nn.Module,
|
| 539 |
+
lr: float,
|
| 540 |
+
include_names: bool = False,
|
| 541 |
+
freeze_modules: List[str] = [],
|
| 542 |
+
) -> List[dict]:
|
| 543 |
+
"""
|
| 544 |
+
This is for use with the ScaledAdam optimizers (more recent versions that accept
|
| 545 |
+
lists of named-parameters; we can, if needed, create a version without the names).
|
| 546 |
+
|
| 547 |
+
It provides a way to specify learning-rate scales inside the module, so that if
|
| 548 |
+
any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
|
| 549 |
+
scale the LR of any parameters inside that module or its submodules. Note: you
|
| 550 |
+
can set module parameters outside the __init__ function, e.g.:
|
| 551 |
+
>>> a = nn.Linear(10, 10)
|
| 552 |
+
>>> a.lr_scale = 0.5
|
| 553 |
+
|
| 554 |
+
Returns: a list of dicts, of the following form:
|
| 555 |
+
if include_names == False:
|
| 556 |
+
[ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
|
| 557 |
+
{ 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
|
| 558 |
+
... ]
|
| 559 |
+
if include_names == true:
|
| 560 |
+
[ { 'named_params': [ (name1, tensor1, (name2, tensor2), ... ], 'lr': 0.01 },
|
| 561 |
+
{ 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
|
| 562 |
+
... ]
|
| 563 |
+
|
| 564 |
+
"""
|
| 565 |
+
# flat_lr_scale just contains the lr_scale explicitly specified
|
| 566 |
+
# for each prefix of the name, e.g. 'encoder.layers.3', these need
|
| 567 |
+
# to be multiplied for all prefix of the name of any given parameter.
|
| 568 |
+
flat_lr_scale = defaultdict(lambda: 1.0)
|
| 569 |
+
names = []
|
| 570 |
+
for name, m in model.named_modules():
|
| 571 |
+
names.append(name)
|
| 572 |
+
if hasattr(m, "lr_scale"):
|
| 573 |
+
flat_lr_scale[name] = m.lr_scale
|
| 574 |
+
|
| 575 |
+
# lr_to_parames is a dict from learning rate (floating point) to: if
|
| 576 |
+
# include_names == true, a list of (name, parameter) for that learning rate;
|
| 577 |
+
# otherwise a list of parameters for that learning rate.
|
| 578 |
+
lr_to_params = defaultdict(list)
|
| 579 |
+
|
| 580 |
+
for name, parameter in model.named_parameters():
|
| 581 |
+
split_name = name.split(".")
|
| 582 |
+
# caution: as a special case, if the name is '', split_name will be [ '' ].
|
| 583 |
+
prefix = split_name[0]
|
| 584 |
+
if prefix == "module": # DDP
|
| 585 |
+
module_name = split_name[1]
|
| 586 |
+
if module_name in freeze_modules:
|
| 587 |
+
logging.info(f"Remove {name} from parameters")
|
| 588 |
+
continue
|
| 589 |
+
else:
|
| 590 |
+
if prefix in freeze_modules:
|
| 591 |
+
logging.info(f"Remove {name} from parameters")
|
| 592 |
+
continue
|
| 593 |
+
cur_lr = lr * flat_lr_scale[prefix]
|
| 594 |
+
if prefix != "":
|
| 595 |
+
cur_lr *= flat_lr_scale[""]
|
| 596 |
+
for part in split_name[1:]:
|
| 597 |
+
prefix = ".".join([prefix, part])
|
| 598 |
+
cur_lr *= flat_lr_scale[prefix]
|
| 599 |
+
lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)
|
| 600 |
+
|
| 601 |
+
if include_names:
|
| 602 |
+
return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
|
| 603 |
+
else:
|
| 604 |
+
return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]
|
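For quick orientation, here is a minimal sketch (not part of the commit) of how the padding and masking helpers above behave on a toy batch. The import path and the tensor values are assumptions for illustration only.

import torch
from zipvoice.utils.common import make_pad_mask, condition_time_mask  # path assumed

# Toy batch of three utterances with feature lengths 4, 6 and 8 frames.
features_lens = torch.tensor([4, 6, 8])

# Padding mask: True marks the padded frames beyond each utterance's length.
pad_mask = make_pad_mask(features_lens)
print(pad_mask.shape)   # torch.Size([3, 8])

# Random time mask covering 30%-50% of each utterance, the kind of mask
# applied to the condition features during training.
time_mask = condition_time_mask(features_lens, mask_percent=(0.3, 0.5))
print(time_mask.shape)  # torch.Size([3, 8])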
zipvoice/utils/diagnostics.py
ADDED
@@ -0,0 +1,723 @@

# Copyright      2022-2024  Xiaomi Corp.        (authors: Daniel Povey
#                                                         Zengwei Yao
#                                                         Mingshuang Luo,
#                                                         Zengrui Jin,)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class TensorDiagnosticOptions(object):
    """Options object for tensor diagnostics:

    Args:
      max_eig_dim:
        The maximum dimension for which we print out eigenvalues
        (limited for speed reasons).
    """

    def __init__(self, max_eig_dim: int = 512):
        self.max_eig_dim = max_eig_dim

    def dim_is_summarized(self, size: int):
        return size > 10 and size != 31


def get_tensor_stats(
    x: Tensor,
    dim: int,
    stats_type: str,
) -> Tuple[Tensor, int]:
    """
    Returns the specified transformation of the Tensor (either x or x.abs()
    or (x > 0)), summed over all dims but the index `dim`.

    Args:
      x:
        Tensor, tensor to be analyzed
      dim:
        Dimension with 0 <= dim < x.ndim
      stats_type:
        The stats_type includes several types:
        "abs" -> take abs() before summing
        "positive" -> take (x > 0) before summing
        "rms" -> square before summing, we'll take sqrt later
        "value" -> just sum x itself
        "max", "min" -> take the maximum or minimum [over all other dims but dim]
         instead of summing
        "rms-sort" -> this is a bit different than the others, it's based on computing
         the rms over the specified dim and returning percentiles of the result
         (11 of them).
    Returns:
      stats: a Tensor of shape (x.shape[dim],).
      count: an integer saying how many items were counted in each element
      of stats.
    """

    if stats_type == "rms-sort":
        rms = (x**2).mean(dim=dim).sqrt()
        rms = rms.flatten()
        rms = rms.sort()[0]
        rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
        count = 1.0
        return rms, count

    count = x.numel() // x.shape[dim]

    if stats_type == "eigs":
        x = x.transpose(dim, -1)
        x = x.reshape(-1, x.shape[-1])
        # shape of returned tensor: (s, s),
        # where s is size of dimension `dim` of original x.
        return torch.matmul(x.transpose(0, 1), x), count
    elif stats_type == "abs":
        x = x.abs()
    elif stats_type == "rms":
        x = x**2
    elif stats_type == "positive":
        x = (x > 0).to(dtype=torch.float)
    else:
        assert stats_type in ["value", "max", "min"]

    sum_dims = [d for d in range(x.ndim) if d != dim]
    if len(sum_dims) > 0:
        if stats_type == "max":
            for dim in reversed(sum_dims):
                x = torch.max(x, dim=dim)[0]
        elif stats_type == "min":
            for dim in reversed(sum_dims):
                x = torch.min(x, dim=dim)[0]
        else:
            x = torch.sum(x, dim=sum_dims)
    x = x.flatten().clone()
    return x, count


@dataclass
class TensorAndCount:
    tensor: Tensor
    count: int


class TensorDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a module or parameter tensor of a torch.nn.Module.

    Args:
      opts:
        Options object.
      name:
        The name associated with this diagnostics object, will probably be
        {module_name}.X where X is "output" or "grad", or {parameter_name}.Y
        where Y is param_value or param_grad.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate()

        self.stats = None  # we'll later assign a list to self.stats.
        # It's a list of dicts, indexed by dim (i.e. by the
        # axis of the tensor).  The dicts, in turn, are
        # indexed by `stats-type` which are strings in
        # ["abs", "max", "min", "positive", "value", "rms"].

        # scalar_stats contains some analysis of the activations and gradients,
        self.scalar_stats = None

        # the keys into self.stats[dim] are strings, which can be
        # "abs", "max", "min", "value", "positive", "rms".
        # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
        # containing a tensor and its associated count (which is the sum of the other
        # dims that we aggregated over, e.g. the number of frames and/or batch elements
        # and/or channels).
        # ... we actually accumulate the Tensors / counts any time we have the same-dim
        # tensor, only adding a new element to the list if there was a different dim.
        # For the "eigs" stats type, if we detect a length mismatch we put None
        # as the value.

    def accumulate(self, x, class_name: Optional[str] = None):
        """
        Accumulate tensors.
        """
        if class_name is not None:
            self.class_name = class_name
        if isinstance(x, Tuple):
            x = x[0]
        if not isinstance(x, Tensor):
            return
        if x.numel() == 0:  # for empty tensor
            return
        x = x.detach().clone()
        if x.ndim == 0:
            x = x.unsqueeze(0)
        ndim = x.ndim
        if self.stats is None:
            self.stats = [dict() for _ in range(ndim)]

        for dim in range(ndim):
            this_dim_stats = self.stats[dim]
            if ndim > 1:
                # rms-sort is different from the others, it's based on summing over just
                # this dim, then sorting and returning the percentiles.
                stats_types = [
                    "abs",
                    "max",
                    "min",
                    "positive",
                    "value",
                    "rms",
                    "rms-sort",
                ]
                if x.shape[dim] <= self.opts.max_eig_dim:
                    stats_types.append("eigs")
            else:
                stats_types = ["value", "abs", "max", "min"]

            for stats_type in stats_types:
                stats, count = get_tensor_stats(x, dim, stats_type)
                if stats_type not in this_dim_stats:
                    this_dim_stats[stats_type] = []  # list of TensorAndCount

                done = False
                if this_dim_stats[stats_type] is None:
                    # we can reach here if we detected for stats_type "eigs" that
                    # there was more than one different size for this dim.  Then we
                    # disable accumulating this stats type, as it uses too much memory.
                    continue
                for s in this_dim_stats[stats_type]:
                    if s.tensor.shape == stats.shape:
                        if stats_type == "max":
                            s.tensor = torch.maximum(s.tensor, stats)
                        elif stats_type == "min":
                            s.tensor = torch.minimum(s.tensor, stats)
                        else:
                            assert stats_type != "max"
                            s.tensor += stats
                        s.count += count
                        done = True
                        break
                if not done:
                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                        # >1 size encountered on this dim, e.g. it's a batch or time
                        # dimension, don't accumulate "eigs" stats type, it uses too much
                        # memory
                        this_dim_stats[stats_type] = None
                    else:
                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))

    def print_diagnostics(self):
        """Print diagnostics for each dimension of the tensor."""
        if self.stats is None:
            print(f"Warning: the stats of {self.name} is None.")
            return
        for dim, this_dim_stats in enumerate(self.stats):
            if "rms" in this_dim_stats and "value" in this_dim_stats:
                # produce "stddev" stats, which is centered RMS.
                rms_stats_list = this_dim_stats["rms"]
                value_stats_list = this_dim_stats["value"]
                if len(rms_stats_list) == len(value_stats_list):
                    stddev_stats_list = []
                    for r, v in zip(rms_stats_list, value_stats_list):
                        stddev_stats_list.append(
                            # r.count and v.count should be the same, but we don't check
                            # this.
                            TensorAndCount(
                                r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
                                r.count,
                            )
                        )
                    this_dim_stats["stddev"] = stddev_stats_list

            for stats_type, stats_list in this_dim_stats.items():
                # stats_type could be "rms", "value", "abs", "eigs", "positive", "min"
                # or "max".  "stats_list" could be a list of TensorAndCount (one list per
                # distinct tensor shape of the stats), or None
                if stats_list is None:
                    assert stats_type == "eigs"
                    continue

                def get_count(count):
                    return 1 if stats_type in ["max", "min"] else count

                if len(stats_list) == 1:
                    stats = stats_list[0].tensor / get_count(stats_list[0].count)
                else:
                    # a dimension that has variable size in different nnet
                    # forwards, e.g. a time dimension in an ASR model.
                    stats = torch.cat(
                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                    )

                if stats_type == "eigs":
                    try:
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                            eigs, _ = torch.linalg.eigh(stats)
                        else:
                            eigs, _ = torch.symeig(stats)
                        stats = eigs.abs().sqrt()
                    except:  # noqa
                        print("Error getting eigenvalues, trying another method.")
                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                            eigs, _ = torch.linalg.eig(stats)
                            eigs = eigs.abs()
                        else:
                            eigs, _ = torch.eig(stats)
                            eigs = eigs.norm(dim=1)
                        stats = eigs.sqrt()
                    # sqrt so it reflects data magnitude, like stddev - not variance

                if stats_type in ["rms", "stddev"]:
                    # we stored the square; after aggregation we need to take sqrt.
                    stats = stats.sqrt()

                # if `summarize` we print percentiles of the stats; else,
                # we print out individual elements.
                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
                    stats.numel()
                )
                if summarize:  # usually `summarize` will be true
                    # print out percentiles.
                    stats = stats.sort()[0]
                    num_percentiles = 10
                    size = stats.numel()
                    percentiles = []
                    for i in range(num_percentiles + 1):
                        index = (i * (size - 1)) // num_percentiles
                        percentiles.append(stats[index].item())
                    percentiles = ["%.2g" % x for x in percentiles]
                    percentiles = " ".join(percentiles)
                    ans = f"percentiles: [{percentiles}]"
                else:
                    ans = stats.tolist()
                    ans = ["%.2g" % x for x in ans]
                    ans = "[" + " ".join(ans) + "]"
                if stats_type in ["value", "rms", "stddev", "eigs"]:
                    # This norm is useful because it is strictly less than the largest
                    # sqrt(eigenvalue) of the variance, which we print out, and shows,
                    # speaking in an approximate way, how much of that largest
                    # eigenvalue can be attributed to the mean of the distribution.
                    norm = (stats**2).sum().sqrt().item()
                    ans += f", norm={norm:.2g}"
                mean = stats.mean().item()
                rms = (stats**2).mean().sqrt().item()
                ans += f", mean={mean:.3g}, rms={rms:.3g}"

                # OK, "ans" contains the actual stats, e.g.
                # ans = "percentiles: \
                #  [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], \
                #  mean=0.5, rms=0.5"

                sizes = [x.tensor.shape[0] for x in stats_list]
                size_str = (
                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
                )
                maybe_class_name = (
                    f" type={self.class_name}," if self.class_name is not None else ""
                )
                print(
                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, "
                    f"{stats_type} {ans}"
                )


class ScalarDiagnostic(object):
    """This class is not directly used by the user, it is responsible for
    collecting diagnostics for a single module (subclass of torch.nn.Module) that
    represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
    """

    def __init__(self, opts: TensorDiagnosticOptions, name: str):
        self.opts = opts
        self.name = name
        self.class_name = None  # will assign in accumulate()
        self.is_forward_pass = True

        self.tick_scale = None

        self.saved_inputs = []
        self.is_ok = True

        self.counts = None
        self.sum_grad = None
        self.sum_gradsq = None
        self.sum_abs_grad = None

    def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
        """
        Called in forward pass.
        """
        if not self.is_forward_pass:
            # in case we did a forward pass without a backward pass, for some reason.
            self.saved_inputs = []
            self.is_forward_pass = True

        if class_name is not None:
            self.class_name = class_name
        if not self.is_ok:
            return

        limit = 10
        if len(self.saved_inputs) > limit:
            print(
                f"ERROR: forward pass called for this module over {limit} times "
                f"with no backward pass.  Will not accumulate scalar stats."
            )
            self.is_ok = False
            return
        self.saved_inputs.append(x)

    def accumulate_output_grad(self, grad: Tensor):
        if not self.is_ok:
            return
        if self.is_forward_pass:
            self.is_forward_pass = False

        last_shape = (
            "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
        )
        if len(self.saved_inputs) == 0 or grad.shape != last_shape:
            print(
                f"ERROR: shape mismatch or no forward activation present when backward "
                f"pass called: grad shape ={tuple(grad.shape)}"
                f", num-saved-inputs={len(self.saved_inputs)}"
                f", shape-of-last-saved-input={last_shape}"
            )
            self.is_ok = False
            return

        x = self.saved_inputs.pop()
        self.process_input_and_grad(x, grad)

    def process_input_and_grad(self, x: Tensor, grad: Tensor):
        assert x.shape == grad.shape
        x = x.flatten()
        grad = grad.flatten()

        num_ticks_per_side = 256

        if self.tick_scale is None:
            x_abs_sorted = x.abs().sort()[0]
            # take the 98th percentile as the largest value we count separately.
            index = int(x.numel() * 0.98)
            self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

            # integerize from tick * (-num_ticks_per_side .. num_ticks_per_side - 1]
            self.counts = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.long, device=x.device
            )
            self.sum_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            # sum_gradsq is for getting error bars.
            self.sum_gradsq = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )
            self.sum_abs_grad = torch.zeros(
                2 * num_ticks_per_side, dtype=torch.double, device=x.device
            )

        # this will round down.
        x = (x / self.tick_scale).to(torch.long)
        x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
        x = x + num_ticks_per_side

        self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
        self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
        self.sum_gradsq.index_add_(
            dim=0, index=x, source=(grad * grad).to(torch.double)
        )
        self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))

    def print_diagnostics(self):
        """Print diagnostics."""
        if self.is_ok is False or self.counts is None:
            print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
            return

        counts = self.counts.to("cpu")
        sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
        sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
        sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

        counts_cumsum = counts.cumsum(dim=0)
        counts_tot = counts_cumsum[-1]

        # subdivide the distribution up into `num_bins` intervals for analysis, for
        # greater statistical significance.  each bin corresponds to multiple of the
        # original 'tick' intervals.
        num_bins = 20

        # integer division
        counts_per_bin = (counts_tot // num_bins) + 1
        bin_indexes = counts_cumsum // counts_per_bin
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)

        bin_counts = torch.zeros(num_bins, dtype=torch.long)
        bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
        bin_grad = torch.zeros(num_bins)
        bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
        bin_gradsq = torch.zeros(num_bins)
        bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
        bin_abs_grad = torch.zeros(num_bins)
        bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

        bin_boundary_counts = (
            torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
        )
        bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
        # boundaries are the "x" values between the bins, e.g. corresponding to the
        # locations of percentiles of the distribution.
        num_ticks_per_side = counts.numel() // 2
        bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale

        bin_grad = bin_grad / (bin_counts + 1)
        bin_conf_interval = bin_gradsq.sqrt() / (
            bin_counts + 1
        )  # consider this a standard deviation.
        # bin_grad / bin_abs_grad will give us a sense for how important in a practical
        # sense, the gradients are.
        bin_abs_grad = bin_abs_grad / (bin_counts + 1)

        bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
        bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)

        def tensor_to_str(x: Tensor):
            x = ["%.2g" % f for f in x]
            x = "[" + " ".join(x) + "]"
            return x

        maybe_class_name = (
            f" type={self.class_name}," if self.class_name is not None else ""
        )

        print(
            f"module={self.name},{maybe_class_name} "
            f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
            f"rel_grad={tensor_to_str(bin_rel_grad)}, "
            f"grad_conf={tensor_to_str(bin_conf)}"
        )


class ModelDiagnostic(object):
    """This class stores diagnostics for all tensors in the torch.nn.Module.

    Args:
      opts:
        Options object.
    """

    def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
        # In this dictionary, the keys are tensor names and the values
        # are corresponding TensorDiagnostic objects.
        if opts is None:
            self.opts = TensorDiagnosticOptions()
        else:
            self.opts = opts
        self.diagnostics = dict()

    def __getitem__(self, name: str):
        T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
        if name not in self.diagnostics:
            self.diagnostics[name] = T(self.opts, name)
        return self.diagnostics[name]

    def print_diagnostics(self):
        """Print diagnostics for each tensor."""
        for k in sorted(self.diagnostics.keys()):
            self.diagnostics[k].print_diagnostics()


def get_class_name(module: nn.Module):
    ans = type(module).__name__
    # we put the below in try blocks in case anyone is using a different version of
    # these modules that might have different member names.
    if ans == "Balancer" or ans == "ActivationBalancer":
        try:
            ans += f"[{float(module.min_positive)},{float(module.max_positive)},"
            f"{float(module.min_abs)},{float(module.max_abs)}]"
        except:
            pass
    elif ans == "AbsValuePenalizer":
        try:
            ans += f"[{module.limit}]"
        except:
            pass
    return ans


def attach_diagnostics(
    model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
) -> ModelDiagnostic:
    """Attach a ModelDiagnostic object to the model by
    1) registering forward hook and backward hook on each module, to accumulate
    its output tensors and gradient tensors, respectively;
    2) registering backward hook on each module parameter, to accumulate its
    values and gradients.

    Args:
      model:
        the model to be analyzed.
      opts:
        Options object.

    Returns:
      The ModelDiagnostic object attached to the model.
    """

    ans = ModelDiagnostic(opts)
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # Setting model_diagnostic=ans and n=name below, instead of trying to
        # capture the variables, ensures that we use the current values.
        # (this matters for `name`, since the variable gets overwritten).
        # These closures don't really capture by value, only by
        # "the final value the variable got in the function" :-(
        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]

            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.output"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]
            if isinstance(_output, Tensor) and _output.dtype in (
                torch.float32,
                torch.float16,
                torch.float64,
            ):
                _model_diagnostic[f"{_name}.grad"].accumulate(
                    _output, class_name=get_class_name(_module)
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, Tensor) and o.dtype in (
                        torch.float32,
                        torch.float16,
                        torch.float64,
                    ):
                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
                            o, class_name=get_class_name(_module)
                        )

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

        if type(module).__name__ in [
            "Sigmoid",
            "Tanh",
            "ReLU",
            "TanSwish",
            "Swish",
            "DoubleSwish",
            "Swoosh",
        ]:
            # For these specific module types, accumulate some additional diagnostics
            # that can help us improve the activation function.  These require a lot of
            # memory, to save the forward activations, so limit this to some select
            # classes.  Note: this will not work correctly for all model types.
            def scalar_forward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_input, tuple):
                    (_input,) = _input
                assert isinstance(_input, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
                    _input, class_name=get_class_name(_module)
                )

            def scalar_backward_hook(
                _module, _input, _output, _model_diagnostic=ans, _name=name
            ):
                if isinstance(_output, tuple):
                    (_output,) = _output
                assert isinstance(_output, Tensor)
                _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

            module.register_forward_hook(scalar_forward_hook)
            module.register_backward_hook(scalar_backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(
            grad, _parameter=parameter, _model_diagnostic=ans, _name=name
        ):
            _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
            _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)

        try:
            parameter.register_hook(param_backward_hook)
        except:  # noqa
            logging.warning(
                f"Warning: could not register backward hook for parameter {name}, "
                f"it might not be differentiable."
            )

    return ans


def _test_tensor_diagnostic():
    opts = TensorDiagnosticOptions(512)

    diagnostic = TensorDiagnostic(opts, "foo")

    for _ in range(10):
        diagnostic.accumulate(torch.randn(50, 100) * 10.0)

    diagnostic.print_diagnostics()

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))

    diagnostic = attach_diagnostics(model, opts)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        y = model(x)
        y.sum().backward()

    diagnostic.print_diagnostics()


if __name__ == "__main__":
    _test_tensor_diagnostic()
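The file's own `_test_tensor_diagnostic()` shows the basic flow; below is a minimal sketch (not part of the commit) of how a training script would typically use it: attach once, run a handful of batches, then print. The toy model and data here are stand-ins; a real run would use the ZipVoice model and its dataloader.

import torch
from torch import nn
from zipvoice.utils.diagnostics import TensorDiagnosticOptions, attach_diagnostics

# Hypothetical toy model, for illustration only.
model = nn.Sequential(nn.Linear(80, 32), nn.ReLU(), nn.Linear(32, 80))
diagnostic = attach_diagnostics(model, TensorDiagnosticOptions(max_eig_dim=512))

for step in range(30):
    x = torch.randn(16, 80)
    loss = model(x).pow(2).mean()
    loss.backward()          # hooks accumulate output/grad/parameter stats
    model.zero_grad()

# Prints per-module and per-parameter statistics (percentiles, rms, eigenvalues, ...)
diagnostic.print_diagnostics()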
zipvoice/utils/feature.py
ADDED
@@ -0,0 +1,120 @@

#!/usr/bin/env python3
# Copyright         2024  Xiaomi Corp.        (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torchaudio
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


@dataclass
class VocosFbankConfig:
    sampling_rate: int = 24000
    n_mels: int = 100
    n_fft: int = 1024
    hop_length: int = 256


@register_extractor
class VocosFbank(FeatureExtractor):

    name = "VocosFbank"
    config_type = VocosFbankConfig

    def __init__(self, num_channels: int = 1):
        config = VocosFbankConfig
        super().__init__(config=config)
        assert num_channels in (1, 2)
        self.num_channels = num_channels
        self.fbank = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.config.sampling_rate,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            n_mels=self.config.n_mels,
            center=True,
            power=1,
        )

    def _feature_fn(self, sample):
        mel = self.fbank(sample)
        logmel = mel.clamp(min=1e-7).log()

        return logmel

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: Union[np.ndarray, torch.Tensor],
        sampling_rate: int,
    ) -> Union[np.ndarray, torch.Tensor]:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        is_numpy = False
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
            is_numpy = True

        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        else:
            assert samples.ndim == 2, samples.shape

        if self.num_channels == 1:
            if samples.shape[0] == 2:
                samples = samples.mean(dim=0, keepdims=True)
        else:
            assert samples.shape[0] == 2, samples.shape

        mel = self._feature_fn(samples)
        # (1, n_mels, time) or (2, n_mels, time)
        mel = mel.reshape(-1, mel.shape[-1]).t()
        # (time, n_mels) or (time, 2 * n_mels)

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        if is_numpy:
            return mel.cpu().numpy()
        else:
            return mel

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate
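As a quick illustration (not part of the commit), here is a minimal sketch of extracting Vocos-style log-mel features with this class; the random waveform is a stand-in for real decoded audio at the expected 24 kHz rate.

import torch
from zipvoice.utils.feature import VocosFbank

extractor = VocosFbank()
samples = torch.randn(24000)  # one second of (random) mono audio at 24 kHz
feats = extractor.extract(samples, sampling_rate=24000)
print(feats.shape)            # roughly (num_frames, 100) with hop_length=256
print(extractor.frame_shift)  # 256 / 24000 ≈ 0.0107 s per frame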
zipvoice/utils/hooks.py
ADDED
@@ -0,0 +1,111 @@

# Copyright    2021-2024  Xiaomi Corporation  (authors: Zengwei Yao,
#                                                       Daniel Povey,
#                                                       Zengrui Jin,)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random

import torch
from torch import Tensor, nn


def register_inf_check_hooks(model: nn.Module) -> None:
    """Register a forward hook on each module, to check
    whether its output tensors are finite.

    Args:
      model:
        the model to be analyzed.
    """

    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # default param _name is a way to capture the current value of the variable
        # "name".
        def forward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.output is not finite")
                except RuntimeError:  # e.g. CUDA out of memory
                    pass
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    try:
                        if not torch.isfinite(o.to(torch.float32).sum()):
                            logging.warning(
                                f"The sum of {_name}.output[{i}] is not finite"
                            )
                    except RuntimeError:  # e.g. CUDA out of memory
                        pass

        # default param _name is a way to capture the current value of the variable
        # "name".
        def backward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad is not finite")
                except RuntimeError:  # e.g. CUDA out of memory
                    pass

            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    if not torch.isfinite(o.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(grad, _name=name):
            if not torch.isfinite(grad.to(torch.float32).sum()):
                logging.warning(f"The sum of {_name}.param_grad is not finite")

        try:
            parameter.register_hook(param_backward_hook)
        except Exception as e:
            logging.warning(
                f"Warning: could not register backward hook for parameter {name}"
                f" with error {e}, it might not be differentiable."
            )


def _test_inf_check_hooks():
    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))

    register_inf_check_hooks(model)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100) + float("inf") * (T % 2)
        y = model(x)
        y.sum().backward()


if __name__ == "__main__":
    _test_inf_check_hooks()
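Beyond the self-test above, a minimal sketch (not part of the commit) of the intended use: register the hooks on the model once before training, then rely on the logged warnings to locate the first layer that produced inf/nan values. The toy model here is a placeholder for the ZipVoice model.

from torch import nn
from zipvoice.utils.hooks import register_inf_check_hooks

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU())  # placeholder model
register_inf_check_hooks(model)
# From here on, any non-finite module output, gradient, or parameter gradient
# triggers a logging.warning naming the offending module or parameter.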
zipvoice/utils/lr_scheduler.py
ADDED
@@ -0,0 +1,245 @@
# Copyright      2022  Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional, Union

import torch
from torch.optim import Optimizer


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            # the user might try to override the base_lr, so don't include this in the
            # state.  previously they were included.
            # "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # the things with base_lrs are a work-around for a previous problem
        # where base_lrs were written with the state dict.
        base_lrs = self.base_lrs
        self.__dict__.update(state_dict)
        self.base_lrs = base_lrs

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler.
        Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr)
        #              for base_lr in self.base_lrs]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it.  If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it.  If you provide the 'epoch' arg, you
        # should call this at the start of the epoch; if you don't provide the 'epoch'
        # arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.warning(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )


class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

    If you don't have the concept of epochs, or one epoch takes a very long time,
    you can replace the notion of 'epoch' with some measure of the amount of data
    processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
    some measure representing "quite a lot of data": say, one fifth or one third
    of an entire training run, but it doesn't matter much.  You could also use
    Eden2 which has only the notion of batches.

    We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need smaller number if dataset is huge
              and you will do few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        warmup_start: float = 0.5,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

        assert 0.0 <= warmup_start <= 1.0, warmup_start
        self.warmup_start = warmup_start

    def get_lr(self):
        factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25 * (
            ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
        )
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else self.warmup_start
            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
            # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]


class FixedLRScheduler(LRScheduler):
    """
    Fixed learning rate scheduler.

    Args:
        optimizer: the optimizer to change the learning rates on
    """

    def __init__(
        self,
        optimizer: Optimizer,
        verbose: bool = False,
    ):
        super(FixedLRScheduler, self).__init__(optimizer, verbose)

    def get_lr(self):
        return [x for x in self.base_lrs]


def _test_eden():
    m = torch.nn.Linear(100, 100)
    from zipvoice.utils.optim import ScaledAdam

    optim = ScaledAdam(m.parameters(), lr=0.03)

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)

    _test_eden()
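For orientation, a minimal sketch (toy model and loop, not from the training scripts in this commit) of how Eden is typically driven together with ScaledAdam: step_epoch() at each epoch boundary and step_batch() once per optimizer step.

    import torch

    from zipvoice.utils.lr_scheduler import Eden
    from zipvoice.utils.optim import ScaledAdam

    model = torch.nn.Linear(80, 80)
    optimizer = ScaledAdam(model.named_parameters(), lr=0.04, clipping_scale=2.0)
    scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=4)

    for epoch in range(2):
        scheduler.step_epoch(epoch)   # set the epoch index at the start of the epoch
        for _ in range(10):
            loss = model(torch.randn(8, 80)).pow(2).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step_batch()    # advance the batch index once per optimizer step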
zipvoice/utils/optim.py
ADDED
@@ -0,0 +1,868 @@
# Copyright      2022  Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.optim import Optimizer


class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: the parameters or parameter groups to optimize (as for torch.optim.Optimizer).
      defaults: a dict of default hyperparameter values for the parameter groups.
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters, the last one that has any particular shape and dtype).

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
              state = self.state[p]
              ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
              for p, state, p_names in batches:
                  ...
        </code>

        Args:
          group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()

        # turn batches into a list, in deterministic order.
        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
        # one for each batch in `batches`.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack(
                [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
            )
            p_stacked.grad = grad
            stacked_params_dict[key] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        for (stacked_params, _state, _names), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])


def basic_step(group, p, state, grad):
    # computes basic Adam update using beta2 (dividing by gradient stddev) only. no
    # momentum yet.
    lr = group["lr"]
    if p.numel() == p.shape[0]:
        lr = lr * group["scalar_lr_scale"]
    beta2 = group["betas"][1]
    eps = group["eps"]
    # p shape: (batch_size,) or (batch_size, 1, [1,..])
    try:
        exp_avg_sq = state[
            "exp_avg_sq"
        ]  # shape: (batch_size,) or (batch_size, 1, [1,..])
    except KeyError:
        exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
        state["exp_avg_sq"] = exp_avg_sq

    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # bias_correction2 is like in Adam.
    # slower update at the start will help stability anyway.
    bias_correction2 = 1 - beta2 ** (state["step"] + 1)
    if bias_correction2 < 0.99:
        # note: not in-place.
        exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
    denom = exp_avg_sq.sqrt().add_(eps)

    return -lr * grad / denom


def scaling_step(group, p, state, grad):
    delta = basic_step(group, p, state, grad)
    if p.numel() == p.shape[0]:
        return delta
        # there is no scaling for scalar parameters.
        # (p.shape[0] is the batch of parameters.)

    step = state["step"]
    size_update_period = group["size_update_period"]

    try:
        param_rms = state["param_rms"]
        scale_grads = state["scale_grads"]
        scale_exp_avg_sq = state["scale_exp_avg_sq"]
    except KeyError:
        # we know p.ndim > 1 because we'd have returned above if not, so don't worry
        # about the special case of dim=[] that pytorch treats inconsistently.
        param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
        param_rms = param_rms.to(torch.float)
        scale_exp_avg_sq = torch.zeros_like(param_rms)
        scale_grads = torch.zeros(
            size_update_period,
            *param_rms.shape,
            dtype=torch.float,
            device=p.device,
        )
        state["param_rms"] = param_rms
        state["scale_grads"] = scale_grads
        state["scale_exp_avg_sq"] = scale_exp_avg_sq

    # on every step, update the gradient w.r.t. the scale of the parameter, we
    # store these as a batch and periodically update the size (for speed only, to
    # avoid too many operations).
    scale_grads[step % size_update_period] = (p * grad).sum(
        dim=list(range(1, p.ndim)), keepdim=True
    )

    # periodically recompute the value of param_rms.
    if step % size_update_period == size_update_period - 1:
        param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())

    param_min_rms = group["param_min_rms"]

    # scale the step size by param_rms.  This is the most important "scaling" part of
    # ScaledAdam
    delta *= param_rms.clamp(min=param_min_rms)

    if step % size_update_period == size_update_period - 1 and step > 0:
        # This block updates the size of parameter by adding a step ("delta") value in
        # the direction of either shrinking or growing it.
        beta2 = group["betas"][1]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2**size_update_period
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr**size_step

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (
            -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
        )

        is_too_small = param_rms < param_min_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)

        # The following may help prevent instability: don't allow the scale step to be
        # too large in either direction.
        scale_step.clamp_(min=-0.1, max=0.1)

        # and ensure the parameter rms after update never exceeds param_max_rms.
        # We have to look at the trained model for parameters at or around the
        # param_max_rms, because sometimes they can indicate a problem with the
        # topology or settings.
        scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

        delta.add_(p * scale_step)

    return delta


def momentum_step(group, p, state, grad):
    delta = scaling_step(group, p, state, grad)
    beta1 = group["betas"][0]
    try:
        stored_delta = state["delta"]
    except KeyError:
        stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
        state["delta"] = stored_delta
    stored_delta.mul_(beta1)
    stored_delta.add_(delta, alpha=(1 - beta1))
    # we don't bother doing the "bias correction" part of Adam for beta1 because this is
    # just an edge effect that affects the first 10 or so batches; and the effect of not
    # doing it is just to do a slower update for the first few batches, which will help
    # stability.
    return stored_delta


class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the
    parameter, in log space, subject to upper and lower limits (as if we had factored
    each parameter as param = underlying_param * log_scale.exp())

    Args:
      params: The parameters or param_groups to optimize (like other Optimizer
           subclasses)  Unlike common optimizers, which accept model.parameters()
           or groups of parameters(), this optimizer could accept
           model.named_parameters() or groups of named_parameters().
           See comments of function _get_names_of_parameters for its 4 possible cases.
      lr: The learning rate.  We will typically use a learning rate schedule that starts
           at 0.03 and decreases over time, i.e. much higher than other common
           optimizers.
      clipping_scale: (e.g. 2.0)
           A scale for gradient-clipping: if specified, the normalized gradients
           over the whole model will be clipped to have 2-norm equal to
           `clipping_scale` times the median 2-norm over the most recent period
           of `clipping_update_period` minibatches.  By "normalized gradients",
           we mean after multiplying by the rms parameter value for this tensor
           [for non-scalars]; this is appropriate because our update is scaled
           by this quantity.
      betas: beta1,beta2 are momentum constants for regular momentum, and moving
           sum-sq grad.  Must satisfy 0 < beta1 <= beta2 < 1.
      scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
           scale of each parameter tensor and scalar parameters of the model.
           If each parameter were decomposed as p * p_scale.exp(), where
           (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be the scaling
           factor on the learning rate of p_scale.
      eps: A general-purpose epsilon to prevent division by zero
      param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each
           non-scalar parameter tensor to be >= this value)
      param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
           learning the scale on the parameters (we'll constrain the rms of each
           non-scalar parameter tensor to be <= this value)
      scalar_max: Maximum absolute value for scalar parameters (applicable if your
           model has any parameters with numel() == 1).
      size_update_period: The periodicity, in steps, with which we update the size (scale)
           of the parameter tensor.  This is provided to save a little time
           in the update.
      clipping_update_period: if clipping_scale is specified, this is the period
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
    ):

        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
            betas=betas,
            scalar_lr_scale=scalar_lr_scale,
            eps=eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period,
        )

        # If params only contains parameters or groups of parameters,
        # i.e. when parameter names are not given,
        # this flag will be set to False in function _get_names_of_parameters.
        self.show_dominant_parameters = True
        param_groups, parameters_names = self._get_names_of_parameters(params)
        super(ScaledAdam, self).__init__(param_groups, defaults)
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names

    def _get_names_of_parameters(
        self, params_or_named_params
    ) -> Tuple[List[Dict], List[List[str]]]:
        """
        Args:
          params_or_named_params: according to the way ScaledAdam is initialized
            in train.py, this argument could be one of following 4 cases,

            case 1, a generator of parameter, e.g.:
              optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
                                     clipping_scale=3.0)

            case 2, a list of parameter groups with different config, e.g.:
              model_param_groups = [
                  {'params': model.encoder.parameters(), 'lr': 0.05},
                  {'params': model.decoder.parameters(), 'lr': 0.01},
                  {'params': model.joiner.parameters(), 'lr': 0.03},
              ]
              optimizer = ScaledAdam(model_param_groups, lr=params.base_lr,
                                     clipping_scale=3.0)

            case 3, a generator of named_parameter, e.g.:
              optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr,
                                     clipping_scale=3.0)

            case 4, a list of named_parameter groups with different config, e.g.:
              model_named_param_groups = [
                  {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
                  {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
                  {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
              ]
              optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr,
                                     clipping_scale=3.0)

          For case 1 and case 2, input params is used to initialize the underlying
          torch.optimizer.
          For case 3 and case 4, firstly, names and params are extracted from input
          named_params, then, these extracted params are used to initialize the
          underlying torch.optimizer, and these extracted names are mainly used by
          function `_show_gradient_dominating_parameter`

        Returns:
          Returns a tuple containing 2 elements:
          - `param_groups` with type List[Dict], each Dict element is a parameter
            group.  An example of `param_groups` could be:
            [
                {'params': `one iterable of Parameter`, 'lr': 0.05},
                {'params': `another iterable of Parameter`, 'lr': 0.08},
                {'params': `a third iterable of Parameter`, 'lr': 0.1},
            ]
          - `param_groups_names` with type List[List[str]],
            each `List[str]` is for a group['params'] in param_groups,
            and each `str` is the name of a parameter.
            A dummy name "foo" is related to each parameter,
            if input are params without names, i.e. case 1 or case 2.
        """
        # variable naming convention in this function:
        # p is short for param.
        # np is short for named_param.
        # p_or_np is short for param_or_named_param.
        # cur is short for current.
        # group is a dict,
        # e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
        # groups is a List[group]

        iterable_or_groups = list(params_or_named_params)
        if len(iterable_or_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")

        # The first value of returned tuple.  A list of dicts containing at
        # least 'params' as a key.
        param_groups = []

        # The second value of returned tuple,
        # a List[List[str]], each sub-List is for a group.
        param_groups_names = []

        if not isinstance(iterable_or_groups[0], dict):
            # case 1 or case 3,
            # the input is an iterable of parameter or named parameter.
            param_iterable_cur_group = []
            param_names_cur_group = []
            for p_or_np in iterable_or_groups:
                if isinstance(p_or_np, tuple):
                    # case 3
                    name, param = p_or_np
                else:
                    # case 1
                    assert isinstance(p_or_np, torch.Tensor)
                    param = p_or_np
                    # Assign a dummy name as a placeholder
                    name = "foo"
                    self.show_dominant_parameters = False
                param_iterable_cur_group.append(param)
                param_names_cur_group.append(name)
            param_groups.append({"params": param_iterable_cur_group})
            param_groups_names.append(param_names_cur_group)
        else:
            # case 2 or case 4
            # the input is groups of parameter or named parameter.
            for cur_group in iterable_or_groups:
                if "named_params" in cur_group:
                    name_list = [x[0] for x in cur_group["named_params"]]
                    p_list = [x[1] for x in cur_group["named_params"]]
                    del cur_group["named_params"]
                    cur_group["params"] = p_list
                else:
                    assert "params" in cur_group
                    name_list = ["foo" for _ in cur_group["params"]]
                param_groups.append(cur_group)
                param_groups_names.append(name_list)

        return param_groups, param_groups_names

    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, group_params_names in zip(self.param_groups, self.parameters_names):

            with self.batched_params(group["params"], group_params_names) as batches:

                # batches is list of pairs (stacked_param, state).  stacked_param is
                # like a regular parameter, and will have a .grad, but the 1st dim
                # corresponds to a stacking dim, it is not a real dim.

                if (
                    len(batches[0][1]) == 0
                ):  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # Perform optimization step.
                    # grad is not going to be None, we handled that when creating the
                    # batches.
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )

                    try:
                        cur_step = state["step"]
                    except KeyError:
                        state["step"] = 0
                        cur_step = 0

                    grad = (
                        p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
                    )
                    p += momentum_step(group, p.detach(), state, grad)

                    if p.numel() == p.shape[0]:  # scalar parameter
                        scalar_max = group["scalar_max"]
                        p.clamp_(min=-scalar_max, max=scalar_max)

                    state["step"] = cur_step + 1

        return loss

    def _get_clipping_scale(
        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
    ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
        scale the gradients by this amount before applying the rest of the update.

        Args:
           group: the parameter group, an item in self.param_groups
           tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # no clipping.  return early on step == 0 because the other
            # parameters' state won't have been initialized yet.
            return 1.0
        clipping_update_period = group["clipping_update_period"]
        scalar_lr_scale = group["scalar_lr_scale"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for p, state, param_names in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients"
                )
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum() * (
                    scalar_lr_scale**2
                )  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device
            )
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        irregular_estimate_steps = [
            i for i in [10, 20, 40] if i < clipping_update_period
        ]
        if step % clipping_update_period == 0 or step in irregular_estimate_steps:
            # Print some stats.
            # We don't reach here if step == 0 because we would have returned
            # above.
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            if step in irregular_estimate_steps:
                sorted_norms = sorted_norms[-step:]
            num_norms = sorted_norms.numel()
            quartiles = []
            for n in range(0, 5):
                index = min(num_norms - 1, (num_norms // 4) * n)
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            if median - median != 0:
                raise RuntimeError("Too many grads were not finite")
            threshold = clipping_scale * median
            if step in irregular_estimate_steps:
                # use larger thresholds on first few steps of estimating threshold,
                # as norm may be changing rapidly.
                threshold = threshold * 2.0
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / num_norms
                if "num_clipped" in first_state
                else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.warning(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        try:
            model_norm_threshold = first_state["model_norm_threshold"]
        except KeyError:
            return 1.0  # threshold has not yet been set.

        ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
        if ans != ans:  # e.g. ans is nan
            ans = 0.0
        if ans < 1.0:
            first_state["num_clipped"] += 1
        if ans < 0.5:
            logging.warning(
                f"Scaling gradients by {ans}, "
                f"model_norm_threshold={model_norm_threshold}"
            )
            if self.show_dominant_parameters:
                assert p.shape[0] == len(param_names)
                self._show_gradient_dominating_parameter(
                    tuples, tot_sumsq, group["scalar_lr_scale"]
                )
            self._show_param_with_unusual_grad(tuples)

        if ans == 0.0:
            for p, state, param_names in tuples:
                p.grad.zero_()  # get rid of infinity()

        return ans

    def _show_param_with_unusual_grad(
        self,
        tuples: List[Tuple[Tensor, dict, List[str]]],
    ):
        """
        Print information about parameter which has the largest ratio of
        grad-on-this-batch divided by normal grad size.
        tuples: a list of tuples of (param, state, param_names)
            where param is a batched set of parameters,
            with a .grad (1st dim is batch dim)
            and state is the state-dict where optimization parameters are kept.
            param_names is a List[str] while each str is name for a parameter
            in batched set of parameters "param".
        """
        # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
        ratios_names = []
        for p, state, batch_param_names in tuples:
            dims = list(range(1, p.ndim))

            def mean(x):
                # workaround for bad interface of torch's "mean" for when dims is the
                # empty list.
                if len(dims) > 0:
                    return x.mean(dim=dims)
                else:
                    return x

            grad_ratio = (
                (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
                .sqrt()
                .to("cpu")
            )

            ratios_names += zip(
                grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
            )

        ratios_names = sorted(ratios_names, reverse=True)
        ratios_names = ratios_names[:10]
        ratios_names = [
            (ratio, name, largest_index(tensor))
            for (ratio, name, tensor) in ratios_names
        ]

        logging.debug(
            f"Parameters with most larger-than-usual grads, with ratios, "
            f"are: {ratios_names}"
        )

    def _show_gradient_dominating_parameter(
        self,
        tuples: List[Tuple[Tensor, dict, List[str]]],
        tot_sumsq: Tensor,
        scalar_lr_scale: float,
    ):
        """
        Show information of parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
            tot_sumsq: sumsq of all parameters.  Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                # Dummy values used by following `zip` statement.
                batch_rms_orig = torch.full(
                    p.shape, scalar_lr_scale, device=batch_grad.device
                )
            else:
                batch_rms_orig = state["param_rms"]
            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
            if batch_grad.ndim > 1:
                # need to guard it with if-statement because sum() sums over
                # all dims if dim == ().
                batch_sumsq_orig = batch_sumsq_orig.sum(
                    dim=list(range(1, batch_grad.ndim))
                )
            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.debug(
            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq={(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )


def largest_index(x: Tensor):
    x = x.contiguous()
    argmax = x.abs().argmax().item()
    return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]


def _test_scaled_adam(hidden_dim: int):
    import timeit

    from zipvoice.models.modules.scaling import ScaledLinear
    from zipvoice.utils.lr_scheduler import Eden

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    fix_random_seed(42)
    Linear = ScaledLinear

    m = torch.nn.Sequential(
        Linear(E, hidden_dim),
        torch.nn.PReLU(),
        Linear(hidden_dim, hidden_dim),
        torch.nn.PReLU(),
        Linear(hidden_dim, E),
    ).to(device)

    train_pairs = [
        (
            100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
            torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
        )
        for _ in range(20)
    ]
    optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
    scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

    start = timeit.default_timer()
    avg_loss = 0.0
    for epoch in range(180):
        scheduler.step_epoch()
        # if epoch == 100 and iter in [2,3]:
        #     optim.reset_speedup()  # check it doesn't crash.

        # if epoch == 130:
        #     opts = diagnostics.TensorDiagnosticOptions(
        #         512
        #     )  # allow 4 megabytes per sub-module
        #     diagnostic = diagnostics.attach_diagnostics(m, opts)

        for n, (x, y) in enumerate(train_pairs):
            y_out = m(x)
            loss = ((y_out - y) ** 2).mean() * 100.0
            if epoch == 0 and n == 0:
                avg_loss = loss.item()
            else:
                avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
            if n == 0 and epoch % 5 == 0:
                # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                lr = scheduler.get_last_lr()[0]
                logging.info(
                    f"Iter {iter}, epoch {epoch}, batch {n}, "
                    f"avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                )  # , norms={norm1,norm1b,norm2,norm2b}")
                # scales={scale1,scale1b,scale2,scale2b}
            loss.log().backward()
            optim.step()
            optim.zero_grad()
            scheduler.step_batch()

    # diagnostic.print_diagnostics()

    stop = timeit.default_timer()
    logging.info(f"Iter={iter}, Time taken: {stop - start}")

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    # logging.info("state dict = ", scheduler.state_dict())
    # logging.info("optim state_dict = ", optim.state_dict())
    logging.info(f"input_magnitudes = {input_magnitudes}")
    logging.info(f"output_magnitudes = {output_magnitudes}")


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:
        hidden_dim = 200

    _test_scaled_adam(hidden_dim)
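A minimal sketch (toy modules, not from this commit) of the "case 4" input format described in _get_names_of_parameters: passing groups of named parameters lets the optimizer report dominant parameters by name when gradient clipping triggers. The generators are wrapped in list() here since the implementation iterates named_params twice.

    import torch

    from zipvoice.utils.optim import ScaledAdam

    encoder = torch.nn.Linear(32, 32)
    decoder = torch.nn.Linear(32, 16)
    param_groups = [
        {"named_params": list(encoder.named_parameters()), "lr": 0.04},
        {"named_params": list(decoder.named_parameters()), "lr": 0.02},
    ]
    optimizer = ScaledAdam(param_groups, lr=0.04, clipping_scale=2.0)

    loss = decoder(encoder(torch.randn(4, 32))).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()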
zipvoice/utils/scaling_converter.py
ADDED
@@ -0,0 +1,105 @@
# Copyright      2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                         Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file replaces various modules in a model.
Specifically, ActivationBalancer is replaced with an identity operator;
Whiten is also replaced with an identity operator;
BasicNorm is replaced by a module with `exp` removed.
"""

import copy
from typing import List

import torch
import torch.nn as nn

from zipvoice.models.modules.scaling import (
    Balancer,
    Dropout3,
    SwooshL,
    SwooshLOnnx,
    SwooshR,
    SwooshROnnx,
    Whiten,
)
from zipvoice.models.modules.zipformer import CompactRelPositionalEncoding


# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule  # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
    if target == "":
        return model
    atoms: List[str] = target.split(".")
    mod: torch.nn.Module = model
    for item in atoms:
        if not hasattr(mod, item):
            raise AttributeError(
                mod._get_name() + " has no " "attribute `" + item + "`"
            )
        mod = getattr(mod, item)
        if not isinstance(mod, torch.nn.Module):
            raise AttributeError("`" + item + "` is not " "an nn.Module")
    return mod


def convert_scaled_to_non_scaled(
    model: nn.Module,
    inplace: bool = False,
    is_pnnx: bool = False,
    is_onnx: bool = False,
):
    """
    Args:
      model:
        The model to be converted.
      inplace:
        If True, the input model is modified inplace.
        If False, the input model is copied and we modify the copied version.
      is_pnnx:
        True if we are going to export the model for PNNX.
      is_onnx:
        True if we are going to export the model for ONNX.
    Return:
      Return a model without scaled layers.
    """
    if not inplace:
        model = copy.deepcopy(model)

    d = {}
    for name, m in model.named_modules():
        if isinstance(m, (Balancer, Dropout3, Whiten)):
            d[name] = nn.Identity()
        elif is_onnx and isinstance(m, SwooshR):
            d[name] = SwooshROnnx()
        elif is_onnx and isinstance(m, SwooshL):
            d[name] = SwooshLOnnx()
        elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
            # We want to recreate the positional encoding vector when
            # the input changes, so we have to use torch.jit.script()
            # to replace torch.jit.trace()
            d[name] = torch.jit.script(m)

    for k, v in d.items():
        if "." in k:
            parent, child = k.rsplit(".", maxsplit=1)
            setattr(get_submodule(model, parent), child, v)
        else:
            setattr(model, k, v)

    return model
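A minimal sketch of how this converter is typically applied before tracing or export, e.g. prior to ONNX export of a model; `build_model_somehow` is a hypothetical placeholder for however the checkpoint is actually loaded, not a function in this repo.

    from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled

    model = build_model_somehow()  # hypothetical loader, returns an nn.Module
    model.eval()
    export_model = convert_scaled_to_non_scaled(model, inplace=False, is_onnx=True)
    # export_model now has Balancer/Dropout3/Whiten replaced by nn.Identity and the
    # Swoosh activations swapped for their ONNX-friendly variants.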