from dataclasses import dataclass
from typing import Callable, List, Tuple

import torch
import safetensors.torch as st
from huggingface_hub import hf_hub_download

from model import EchoDiT
from autoencoder import build_ae, DAC

import torchaudio
from torchcodec.decoders import AudioDecoder

# from samplers import Sampler

SampleFn = Callable[
    [EchoDiT, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int],
    torch.Tensor
]


### Loading

def load_model_from_hf(repo_id: str = 'jordand/echo-tts-base',
                       device: str = 'cuda',
                       dtype: torch.dtype | None = torch.bfloat16,
                       compile: bool = False,
                       token: str | None = None) -> EchoDiT:
    with torch.device('meta'):
        model = EchoDiT(
            latent_size=80,
            model_size=2048,
            num_layers=24,
            num_heads=16,
            intermediate_size=5888,
            norm_eps=1e-5,
            max_seq_len=640,
            text_vocab_size=256,
            text_model_size=1280,
            text_num_layers=14,
            text_num_heads=10,
            text_intermediate_size=3328,
            text_max_seq_len=768,
            speaker_patch_size=4,
            speaker_model_size=1280,
            speaker_num_layers=14,
            speaker_num_heads=10,
            speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640,
            timestep_embed_size=512,
            adaln_rank=256,
        )

    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)

    # Load to CPU first
    state = st.load_file(w_path, device='cpu')

    # Convert dtype on CPU if needed
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}

    # Now move to device
    state = {k: v.to(device=device) for k, v in state.items()}

    model.load_state_dict(state, strict=False, assign=True)
    model = model.eval()

    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)

    return model


def load_fish_ae_from_hf(repo_id: str = 'jordand/fish-s1-dac-min',
                         device: str = 'cuda',
                         dtype: torch.dtype | None = torch.float32,
                         compile: bool = False,
                         token: str | None = None) -> DAC:
    # have not tested lower precisions with fish AE yet
    with torch.device('meta'):
        fish_ae = build_ae()

    w_path = hf_hub_download(repo_id, 'pytorch_model.safetensors', token=token)

    if dtype is not None and dtype != torch.float32:
        state = st.load_file(w_path, device='cpu')
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
        state = {k: v.to(device=device) for k, v in state.items()}
        fish_ae.load_state_dict(state, strict=False, assign=True)
    else:
        state = st.load_file(w_path, device=device)
        fish_ae.load_state_dict(state, strict=False, assign=True)

    fish_ae = fish_ae.eval().to(device)

    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)

    return fish_ae


@dataclass
class PCAState:
    pca_components: torch.Tensor
    pca_mean: torch.Tensor
    latent_scale: float


def load_pca_state_from_hf(repo_id: str = 'jordand/echo-tts',
                           device: str = 'cuda',
                           filename: str = 'pca_state.safetensors',
                           token: str | None = None) -> PCAState:
    p_path = hf_hub_download(repo_id, filename, token=token)
    t = st.load_file(p_path, device=device)
    return PCAState(
        pca_components=t["pca_components"],
        pca_mean=t["pca_mean"],
        latent_scale=float(t["latent_scale"].item()),
    )


### default load audio

def load_audio(path: str) -> torch.Tensor:
    decoder = AudioDecoder(path)
    sr = decoder.metadata.sample_rate
    audio = decoder.get_samples_played_in_range(0, 120)
    audio = audio.data.mean(dim=0).unsqueeze(0)
    audio = torchaudio.functional.resample(audio, sr, 44_100)
    audio = audio / torch.maximum(audio.abs().max(), torch.tensor(1.))  # TODO is this better than clipping? should we target a specific energy level?
    return audio
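
# Example usage of the loaders above (illustrative sketch, kept commented so it
# is not executed on import; 'reference.wav' is a placeholder path and a CUDA
# device with the default repo ids is assumed):
#
#     model = load_model_from_hf()             # EchoDiT on cuda, bf16
#     fish_ae = load_fish_ae_from_hf()         # fish DAC autoencoder on cuda
#     pca_state = load_pca_state_from_hf()     # PCA projection for the latents
#     reference = load_audio('reference.wav')  # (1, num_samples) mono @ 44.1 kHz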
### Text helpers

def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    if normalize:
        text = text.replace('…', '...')
        text = text.replace('“', '"')
        text = text.replace('”', '"')
        text = text.replace('’', "'")
        text = text.replace('\n', " ")
        text = text.replace(':', ',')
        text = text.replace(';', ',')
    b = list(text.encode('utf-8'))
    if append_bos:
        b.insert(0, 0)
    return torch.tensor(b)


def get_text_input_ids_and_mask(text_arr: List[str],
                                max_length: int | None,
                                device: str | None = None) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(text_arr)

    if max_length is None:
        max_length = max(len(tokenizer_encode(text)) for text in text_arr)  # obviously bad...

    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)

    for i, text in enumerate(text_arr):
        encoded = tokenizer_encode(text)
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1

    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)

    return tokens, mask


### Autoencoder Inference

@torch.inference_mode()
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    assert audio.ndim == 3 and audio.shape[1] == 1  # (b, 1, length)
    z_q = fish_ae.encode_zq(audio).float()
    z_q = (z_q.transpose(1, 2) - pca_state.pca_mean) @ pca_state.pca_components.T
    z_q = z_q * pca_state.latent_scale
    return z_q


@torch.inference_mode()
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    z_q = (z_q / pca_state.latent_scale) @ pca_state.pca_components + pca_state.pca_mean
    return fish_ae.decode_zq(z_q.transpose(1, 2).to(fish_ae.dtype)).float()


@torch.inference_mode()
def ae_reconstruct(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    # (audio is (b, 1, length))
    z_q = ae_encode(fish_ae, pca_state, audio.to(fish_ae.dtype))
    return ae_decode(fish_ae, pca_state, z_q)


@torch.inference_mode()
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, length)
    max_speaker_latent_len: int = 2560,   # pretrained max length
    audio_chunk_size: int = 640 * 2048,   # ~30 seconds, 1/4 max speaker condition size
) -> tuple[torch.Tensor, torch.Tensor]:
    # Gets speaker latent and mask from audio; computes in chunks and concatenates
    # (similar to pretraining setup).
    AE_DOWNSAMPLE_FACTOR = 2048
    max_audio_len = max_speaker_latent_len * AE_DOWNSAMPLE_FACTOR

    assert audio.ndim == 2 and audio.shape[0] == 1  # (1, length)
    audio = audio[:, :max_audio_len]
    audio_len = audio.shape[1]

    latent_arr = []
    for i in range(0, audio_len, audio_chunk_size):
        audio_chunk = audio[:, i:i + audio_chunk_size]
        if audio_chunk.shape[1] < audio_chunk_size:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, audio_chunk_size - audio_chunk.shape[1]))
        latent_chunk = ae_encode(fish_ae, pca_state, audio_chunk.unsqueeze(0))
        latent_arr.append(latent_chunk)

    speaker_latent = torch.cat(latent_arr, dim=1)
    actual_latent_len = audio_len // AE_DOWNSAMPLE_FACTOR
    speaker_mask = (torch.arange(speaker_latent.shape[1], device=speaker_latent.device) < actual_latent_len).unsqueeze(0)

    if speaker_latent.shape[1] < max_speaker_latent_len:
        speaker_latent = torch.nn.functional.pad(speaker_latent, (0, 0, 0, max_speaker_latent_len - speaker_latent.shape[1]))
        speaker_mask = torch.nn.functional.pad(speaker_mask, (0, max_speaker_latent_len - speaker_mask.shape[1]))

    return speaker_latent, speaker_mask
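
# Example of the speaker-conditioning path (illustrative sketch, commented out;
# assumes `fish_ae` and `pca_state` from the loading example above, everything
# on 'cuda', and 'reference.wav' as a placeholder path):
#
#     audio = load_audio('reference.wav').to('cuda')      # (1, num_samples)
#     speaker_latent, speaker_mask = get_speaker_latent_and_mask(
#         fish_ae, pca_state, audio.to(fish_ae.dtype))
#     # speaker_latent: (1, 2560, 80) PCA-projected latents, zero-padded
#     # speaker_mask:   (1, 2560) bool, True over frames that cover real audio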
### Full sample pipeline

def find_flattening_point(data, target_value=0.0, window_size=20, std_threshold=0.05):
    # Returns the first latent frame index where a window of `window_size` frames
    # is nearly constant and close to `target_value` (i.e. the output has gone
    # silent); falls back to the full length if no such window is found.
    padded_data = torch.cat([data, torch.zeros(window_size, *data.shape[1:], device=data.device, dtype=data.dtype)])
    for i in range(len(padded_data) - window_size):
        window = padded_data[i:i + window_size]
        if window.std() < std_threshold and abs(window.mean() - target_value) < 0.1:
            return i
    return len(data)


@torch.inference_mode()
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
    pad_to_max_speaker_latent_len: int | None = 2560,
    pad_to_max_text_seq_len: int | None = 768,
) -> torch.Tensor:
    MAX_SPEAKER_LATENT_LEN = 2560
    MAX_TEXT_SEQ_LEN = 768

    device, dtype = model.device, model.dtype

    text_input_ids, text_mask = get_text_input_ids_and_mask(
        [text_prompt],
        min(pad_to_max_text_seq_len or MAX_TEXT_SEQ_LEN, MAX_TEXT_SEQ_LEN),
        device=device,
    )

    if speaker_audio is None:
        # No speaker prompt - use zero speaker latent and mask
        speaker_latent = torch.zeros(
            (1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN, 80),
            device=device, dtype=dtype,
        )
        speaker_mask = torch.zeros(
            (1, pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN),
            device=device, dtype=torch.bool,
        )
    else:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae, pca_state, speaker_audio.to(fish_ae.dtype),
            max_speaker_latent_len=pad_to_max_speaker_latent_len if pad_to_max_speaker_latent_len else MAX_SPEAKER_LATENT_LEN,
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)

    latent_out = sample_fn(model, speaker_latent, speaker_mask, text_input_ids, text_mask, rng_seed)

    audio_out = ae_decode(fish_ae, pca_state, latent_out)

    # Trim trailing silence: find where the latent flattens out and cut the audio there.
    flattening_point = find_flattening_point(latent_out[0])
    audio_out = audio_out[..., :flattening_point * 2048]

    return audio_out
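

if __name__ == '__main__':
    # Minimal end-to-end sketch. The sampler import below is an assumption:
    # any callable matching SampleFn works, e.g. one built from this repo's
    # `samplers` module (see the commented import at the top of the file).
    # `sample_euler`, 'reference.wav', and 'output.wav' are placeholders.
    from samplers import sample_euler  # hypothetical name; swap in your actual sampler

    model = load_model_from_hf()
    fish_ae = load_fish_ae_from_hf()
    pca_state = load_pca_state_from_hf()

    speaker_audio = load_audio('reference.wav').to('cuda')  # move to the AE's device

    audio_out = sample_pipeline(
        model, fish_ae, pca_state,
        sample_fn=sample_euler,
        text_prompt="Hello from Echo TTS.",
        speaker_audio=speaker_audio,
        rng_seed=0,
    )
    torchaudio.save('output.wav', audio_out[0].cpu(), 44_100)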