# OmniAvatar / supertonic.py
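"""ONNX Runtime text-to-speech pipeline built around the Supertone "supertonic"
checkpoints downloaded from the Hugging Face Hub.

The pipeline runs four exported ONNX models in sequence: a duration predictor,
a text encoder, a vector-field estimator that iteratively refines a noisy
latent, and a vocoder that turns the final latent into a waveform.
"""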
import json
import os
import re
import uuid
from typing import Optional, Union
from unicodedata import normalize

import numpy as np
import onnxruntime as ort
import soundfile as sf
from huggingface_hub import snapshot_download
class UnicodeProcessor:
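    """Maps raw text to integer token ids via a JSON unicode indexer and builds
    the binary text mask expected by the ONNX text models."""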
def __init__(self, unicode_indexer_path: str):
with open(unicode_indexer_path, "r") as f:
self.indexer = json.load(f)
def _preprocess_text(self, text: str) -> str:
# TODO: add more preprocessing
text = normalize("NFKD", text)
return text
def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
text_mask = length_to_mask(text_ids_lengths)
return text_mask
def _text_to_unicode_values(self, text: str) -> np.ndarray:
unicode_values = np.array(
[ord(char) for char in text], dtype=np.uint16
) # 2 bytes
return unicode_values
def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
text_list = [self._preprocess_text(t) for t in text_list]
text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
for i, text in enumerate(text_list):
unicode_vals = self._text_to_unicode_values(text)
text_ids[i, : len(unicode_vals)] = np.array(
[self.indexer[val] for val in unicode_vals], dtype=np.int64
)
text_mask = self._get_text_mask(text_ids_lengths)
return text_ids, text_mask
class Style:
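    """Container for the two per-speaker style embeddings: `ttl` feeds the text
    encoder and vector estimator, `dp` feeds the duration predictor."""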
def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
self.ttl = style_ttl_onnx
self.dp = style_dp_onnx
class TextToSpeech:
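    """Runs the four-stage ONNX pipeline (duration predictor, text encoder,
    vector estimator, vocoder) to synthesize speech from text."""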
def __init__(
self,
cfgs: dict,
text_processor: UnicodeProcessor,
dp_ort: ort.InferenceSession,
text_enc_ort: ort.InferenceSession,
vector_est_ort: ort.InferenceSession,
vocoder_ort: ort.InferenceSession,
):
self.cfgs = cfgs
self.text_processor = text_processor
self.dp_ort = dp_ort
self.text_enc_ort = text_enc_ort
self.vector_est_ort = vector_est_ort
self.vocoder_ort = vocoder_ort
self.sample_rate = cfgs["ae"]["sample_rate"]
self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
self.ldim = cfgs["ttl"]["latent_dim"]
def sample_noisy_latent(
self, duration: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
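        """Sample Gaussian noise of shape (bsz, latent_dim, latent_len) for the
        requested durations (in seconds), together with the matching latent mask."""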
        bsz = len(duration)
        wav_lengths = (duration * self.sample_rate).astype(np.int64)
        chunk_size = self.base_chunk_size * self.chunk_compress_factor
        # Ceil-divide on the integer sample lengths so the latent length always
        # matches the mask produced by get_latent_mask below.
        latent_len = int((wav_lengths.max() + chunk_size - 1) // chunk_size)
        latent_dim = self.ldim * self.chunk_compress_factor
        noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
        latent_mask = get_latent_mask(
            wav_lengths, self.base_chunk_size, self.chunk_compress_factor
        )
        noisy_latent = noisy_latent * latent_mask
        return noisy_latent, latent_mask
def _infer(
self,
text_list: list[str],
style: Style,
total_step: int,
speed: float = 1.05,
suggested_duration: Optional[Union[float, list[float], np.ndarray]] = None,
speed_min_factor: float = 0.75,
speed_max_factor: float = 1.2,
) -> tuple[np.ndarray, np.ndarray]:
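        """Synthesize a batch of texts that share one batch of style vectors.

        Steps: (1) predict a base duration per text, (2) optionally rescale it
        toward `suggested_duration` while keeping the effective speed within
        [speed * speed_min_factor, speed * speed_max_factor], (3) encode the
        text and iteratively refine a noisy latent with the vector estimator,
        (4) decode the latent to a waveform with the vocoder.

        Returns (wav, dur_used): the waveform batch and per-item durations in seconds.
        """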
assert (
len(text_list) == style.ttl.shape[0]
), "Number of texts must match number of style vectors"
bsz = len(text_list)
text_ids, text_mask = self.text_processor(text_list)
# 1) Predict base duration
dur_pred, *_ = self.dp_ort.run(
None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
)
dur_pred = np.array(dur_pred, dtype=np.float32).reshape(bsz) # (bsz,)
# 2) Adjust duration based on suggested_duration (if given)
if suggested_duration is not None:
sugg = np.array(suggested_duration, dtype=np.float32)
if sugg.ndim == 0:
# same suggestion for all
sugg = np.full((bsz,), float(sugg), dtype=np.float32)
else:
sugg = sugg.reshape(bsz)
eps = 1e-3
sugg = np.clip(sugg, eps, None)
# we want dur_used ≈ sugg
# dur_used = dur_pred / speed_used => speed_target = dur_pred / sugg
speed_target = dur_pred / sugg
speed_min = speed * speed_min_factor
speed_max = speed * speed_max_factor
speed_used = np.clip(speed_target, speed_min, speed_max)
dur_used = dur_pred / speed_used
else:
# default behaviour
speed_used = np.full((bsz,), speed, dtype=np.float32)
dur_used = dur_pred / speed_used
# 3) Continue as before, using dur_used
text_emb_onnx, *_ = self.text_enc_ort.run(
None,
{"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
)
xt, latent_mask = self.sample_noisy_latent(dur_used)
total_step_np = np.array([total_step] * bsz, dtype=np.float32)
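        # Iteratively refine the noisy latent: each pass of the vector estimator
        # receives the current step index and the total step count and returns
        # the next latent.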
for step in range(total_step):
current_step = np.array([step] * bsz, dtype=np.float32)
xt, *_ = self.vector_est_ort.run(
None,
{
"noisy_latent": xt,
"text_emb": text_emb_onnx,
"style_ttl": style.ttl,
"text_mask": text_mask,
"latent_mask": latent_mask,
"current_step": current_step,
"total_step": total_step_np,
},
)
wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
return wav, dur_used
def batch(
self,
text_list: list[str],
style: Style,
total_step: int,
speed: float = 1.05,
suggested_duration: Optional[Union[float, list[float], np.ndarray]] = None,
speed_min_factor: float = 0.75,
speed_max_factor: float = 1.2,
) -> tuple[np.ndarray, np.ndarray]:
return self._infer(
text_list,
style,
total_step,
speed=speed,
suggested_duration=suggested_duration,
speed_min_factor=speed_min_factor,
speed_max_factor=speed_max_factor,
)
def __call__(
self,
text: str,
style: Style,
total_step: int,
speed: float = 1.05,
silence_duration: float = 0.3,
) -> tuple[np.ndarray, np.ndarray]:
assert (
style.ttl.shape[0] == 1
), "Single speaker text to speech only supports single style"
text_list = chunk_text(text)
wav_cat = None
dur_cat = None
        for chunk in text_list:
            wav, dur_onnx = self._infer([chunk], style, total_step, speed)
if wav_cat is None:
wav_cat = wav
dur_cat = dur_onnx
else:
silence = np.zeros(
(1, int(silence_duration * self.sample_rate)), dtype=np.float32
)
wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
dur_cat += dur_onnx + silence_duration
return wav_cat, dur_cat
def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
"""
Convert lengths to binary mask.
Args:
lengths: (B,)
max_len: int
Returns:
mask: (B, 1, max_len)
"""
max_len = max_len or lengths.max()
ids = np.arange(0, max_len)
mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
return mask.reshape(-1, 1, max_len)
def get_latent_mask(
wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
) -> np.ndarray:
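    """Binary mask of shape (B, 1, max_latent_len) marking valid latent frames
    per waveform length (ceil division by base_chunk_size * chunk_compress_factor)."""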
latent_size = base_chunk_size * chunk_compress_factor
latent_lengths = (wav_lengths + latent_size - 1) // latent_size
latent_mask = length_to_mask(latent_lengths)
return latent_mask
def load_onnx(
onnx_path: str, opts: ort.SessionOptions, providers: list[str]
) -> ort.InferenceSession:
return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)
def load_onnx_all(
onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
) -> tuple[
ort.InferenceSession,
ort.InferenceSession,
ort.InferenceSession,
ort.InferenceSession,
]:
dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
dp_ort = load_onnx(dp_onnx_path, opts, providers)
text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
def load_cfgs(onnx_dir: str) -> dict:
cfg_path = os.path.join(onnx_dir, "tts.json")
with open(cfg_path, "r") as f:
cfgs = json.load(f)
return cfgs
def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
text_processor = UnicodeProcessor(unicode_indexer_path)
return text_processor
# text_to_speech = load_text_to_speech(False)
model_dir = snapshot_download("Supertone/supertonic")
onnx_dir = f"{model_dir}/onnx"
def load_text_to_speech(use_gpu: bool = False) -> TextToSpeech:
opts = ort.SessionOptions()
if use_gpu:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
print("Using CPU for inference")
cfgs = load_cfgs(onnx_dir)
dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
onnx_dir, opts, providers
)
text_processor = load_text_processor(onnx_dir)
return TextToSpeech(
cfgs, text_processor, dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
)
def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
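    """Load one or more voice-style JSON files into a batched Style.

    Each file stores `style_ttl` and `style_dp` as dicts with `dims` and a
    flattened `data` array; all files in the batch must share the same dims.
    """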
bsz = len(voice_style_paths)
# Read first file to get dimensions
with open(voice_style_paths[0], "r") as f:
first_style = json.load(f)
ttl_dims = first_style["style_ttl"]["dims"]
dp_dims = first_style["style_dp"]["dims"]
# Pre-allocate arrays with full batch size
ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)
# Fill in the data
for i, voice_style_path in enumerate(voice_style_paths):
with open(voice_style_path, "r") as f:
voice_style = json.load(f)
ttl_data = np.array(
voice_style["style_ttl"]["data"], dtype=np.float32
).flatten()
ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
dp_data = np.array(voice_style["style_dp"]["data"], dtype=np.float32).flatten()
dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])
if verbose:
print(f"Loaded {bsz} voice styles")
return Style(ttl_style, dp_style)
def sanitize_filename(text: str, max_len: int) -> str:
"""Sanitize filename by replacing non-alphanumeric characters with underscores"""
prefix = text[:max_len]
return re.sub(r"[^a-zA-Z0-9]", "_", prefix)
def chunk_text(text: str, max_len: int = 300) -> list[str]:
"""
Split text into chunks by paragraphs and sentences.
Args:
text: Input text to chunk
max_len: Maximum length of each chunk (default: 300)
Returns:
List of text chunks
"""
# Split by paragraph (two or more newlines)
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]
chunks = []
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
# Split by sentence boundaries (period, question mark, exclamation mark followed by space)
# But exclude common abbreviations like Mr., Mrs., Dr., etc. and single capital letters like F.
pattern = r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+"
sentences = re.split(pattern, paragraph)
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) + 1 <= max_len:
current_chunk += (" " if current_chunk else "") + sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
def generate_speech(
text_to_speech,
text_list,
save_dir,
voice_style="M1",
total_step=5,
speed=1.05,
n_test=1,
batch=None,
    suggested_durations=None,  # optional per-text target durations in seconds, len == len(text_list)
speed_min_factor=0.75,
speed_max_factor=1.2,
):
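    """Synthesize each text in `text_list` with the given voice style and write
    one .wav per item into `save_dir`, returning the list of saved file paths.

    With `batch=True` all texts are synthesized in one batched call; otherwise
    only `text_list[0]` is synthesized via the chunking single-text path, which
    requires a single style and hence a single text.
    """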
saved_files_list = []
voice_style_paths = [f"{model_dir}/voice_styles/{voice_style}.json"] * len(text_list)
assert len(voice_style_paths) == len(
text_list
), f"Number of voice styles ({len(voice_style_paths)}) must match number of texts ({len(text_list)})"
bsz = len(voice_style_paths)
style = load_voice_style(voice_style_paths, verbose=True)
for n in range(n_test):
if batch:
wav, duration = text_to_speech.batch(
text_list,
style,
total_step,
speed=speed,
suggested_duration=suggested_durations,
speed_min_factor=speed_min_factor,
speed_max_factor=speed_max_factor,
)
else:
# optional: could support suggested_durations[0] here too
wav, duration = text_to_speech(
text_list[0], style, total_step, speed
)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for b in range(bsz):
unique = uuid.uuid4().hex[:8]
fname = f"{sanitize_filename(text_list[b], 20)}_{unique}_{n+1}.wav"
w = wav[b, : int(text_to_speech.sample_rate * duration[b].item())]
            out_path = os.path.join(save_dir, fname)
            sf.write(out_path, w, text_to_speech.sample_rate)
            saved_files_list.append(out_path)
return saved_files_list
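

# Example usage (a minimal sketch): synthesizes two short texts with the bundled
# "M1" voice style and writes .wav files under ./outputs. The style name matches
# this module's default; the output directory and texts are illustrative and can
# be replaced with any style JSON shipped in the downloaded snapshot.
if __name__ == "__main__":
    tts = load_text_to_speech(use_gpu=False)
    files = generate_speech(
        tts,
        text_list=["Hello there.", "This is a quick synthesis test."],
        save_dir="outputs",
        voice_style="M1",
        total_step=5,
        batch=True,
    )
    print(files)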