import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
from dataclasses import dataclass
from pathlib import Path
import re
import torchaudio

from transformers import processing_utils

# Register the custom "audio_tokenizer" modality so that ProcessorMixin treats
# it like any other attribute backed by a PreTrainedModel.
processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"

import torch
from transformers import (
    PreTrainedTokenizerBase,
    BatchFeature,
    ProcessorMixin,
    logging,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
)

from .configuration_moss_tts import MossTTSDelayConfig


logger = logging.get_logger(__name__)


AUDIO_PLACEHOLDER = "<|audio|>"


@dataclass
class Message:
    def to_dict(self) -> Dict[str, Any]:
        raise NotImplementedError


@dataclass
class UserMessage(Message):
    text: Optional[str] = None
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None

    def __post_init__(self):
        template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
{tokens}
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Text:
{text}
</user_inst>"""

        audio_codes_list = []
        if self.reference is None:
            reference = "None"
        elif isinstance(self.reference, list):
            # One placeholder per speaker; `None` entries are skipped.
            rendered = []
            for speaker_idx, speaker_reference in enumerate(self.reference):
                if speaker_reference is not None:
                    rendered.append(f"[S{speaker_idx + 1}]:\n{AUDIO_PLACEHOLDER}")
            reference = "\n".join(rendered)
            audio_codes_list = [
                speaker_reference
                for speaker_reference in self.reference
                if speaker_reference is not None
            ]
        else:
            raise TypeError("`reference` must be a list when it is not None.")

        content = (
            template.replace("{reference}", str(reference))
            .replace("{instruction}", str(self.instruction))
            .replace("{tokens}", str(self.tokens))
            .replace("{quality}", str(self.quality))
            .replace("{sound_event}", str(self.sound_event))
            .replace("{ambient_sound}", str(self.ambient_sound))
            .replace("{language}", str(self.language))
            .replace("{text}", str(self.text))
        )

        self._content = content
        self._audio_codes_list = audio_codes_list

    def to_dict(self):
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list,
        }
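
# Illustrative sketch (not executed anywhere in this module): with a single
# speaker reference, UserMessage(text="Hi", reference=["ref.wav"]) renders
# `_content` roughly as
#
#   <user_inst>
#   - Reference(s):
#   [S1]:
#   <|audio|>
#   - Instruction:
#   None
#   ...
#   - Text:
#   Hi
#   </user_inst>
#
# and stores ["ref.wav"] in `_audio_codes_list` for later encoding.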


@dataclass
class AssistantMessage(Message):
    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        return {
            "role": "assistant",
            "content": self.content,
            "audio_codes_list": self.audio_codes_list,
        }


USER_MESSAGE_FIELDS = (
    "text",
    "reference",
    "instruction",
    "tokens",
    "quality",
    "sound_event",
    "ambient_sound",
    "language",
)


class MossTTSDelayProcessor(ProcessorMixin):
    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "AutoModel"

    tokenizer: PreTrainedTokenizerBase
    audio_tokenizer: Any

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        audio_tokenizer: Any = None,
        model_config: Optional[MossTTSDelayConfig] = None,
        **kwargs,
    ):
        super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)

        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        if model_config is None:
            model_config = MossTTSDelayConfig()
        self.model_config = model_config

        self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # Hard-coded token id of "\n" in the base tokenizer's vocabulary.
        self.newline_token_id = 198

        def _id_to_token(token_id: int) -> str:
            tok = tokenizer.convert_ids_to_tokens(int(token_id))
            if isinstance(tok, list):
                return tok[0] if len(tok) > 0 else ""
            return cast(str, tok)

        self.audio_user_slot_token = _id_to_token(
            self.model_config.audio_user_slot_token_id
        )
        self.audio_assistant_gen_slot_token = _id_to_token(
            self.model_config.audio_assistant_gen_slot_token_id
        )
        self.audio_assistant_delay_slot_token = _id_to_token(
            self.model_config.audio_assistant_delay_slot_token_id
        )
        self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
        self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        kwargs.pop("_from_auto", None)

        audio_tokenizer_name_or_path = kwargs.pop(
            "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
        )

        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        model_config = cast(
            MossTTSDelayConfig,
            AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                *args,
                trust_remote_code=trust_remote_code,
                **kwargs,
            ),
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        audio_tokenizer = AutoModel.from_pretrained(
            audio_tokenizer_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        return cls(
            tokenizer=tokenizer,
            audio_tokenizer=audio_tokenizer,
            model_config=model_config,
            **kwargs,
        )
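
    # Minimal loading sketch (the checkpoint path below is a placeholder, not
    # a shipped default; `codec_path` falls back to the id above when omitted):
    #
    #   processor = MossTTSDelayProcessor.from_pretrained(
    #       "/path/to/moss-tts-checkpoint",
    #       codec_path="OpenMOSS-Team/MOSS-Audio-Tokenizer",
    #   )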

    def __call__(self, *args, **kwargs) -> BatchFeature:
        conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
        mode: str = kwargs.pop("mode", "generation")
        apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
        n_vq: Optional[int] = kwargs.pop("n_vq", None)

        # These common tokenizer kwargs are accepted but ignored; the
        # processor builds tensors and padding itself.
        kwargs.pop("return_tensors", None)
        kwargs.pop("padding", None)
        kwargs.pop("truncation", None)

        # `mode` only takes effect while a Message is converted to a dict.
        if mode not in {"generation", "continuation"}:
            raise RuntimeError(
                f"Unsupported mode {mode!r}; expected 'generation' or 'continuation'."
            )

        if isinstance(conversations, (Message, dict)):
            conversations = [conversations]

        truncation = False
        if mode == "continuation":
            truncation = True

        input_ids_list = []
        for conversation in conversations:
            if isinstance(conversation, (Message, dict)):
                conversation = [conversation]

            conversation = [self._normalize_message(m) for m in conversation]

            if (mode == "generation") ^ (len(conversation) % 2 != 0):
                raise ValueError(
                    "Generation mode expects an odd number of messages; "
                    "continuation mode expects an even number."
                )

            if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
                raise ValueError(
                    "Generation mode must end with a user message; "
                    "continuation mode must end with an assistant message."
                )

            unified_codes = []
            for message_idx, message in enumerate(conversation):
                if apply_chat_template:
                    add_generation_prompt = (
                        mode == "generation" and message_idx == len(conversation) - 1
                    )
                    try:
                        content = self.tokenizer.apply_chat_template(
                            [{"role": message["role"], "content": message["content"]}],
                            add_generation_prompt=add_generation_prompt,
                            tokenize=False,
                        )
                    except TypeError:
                        try:
                            content = self.tokenizer.apply_chat_template(
                                [
                                    {
                                        "role": message["role"],
                                        "content": message["content"],
                                    }
                                ],
                                add_generation_prompt=add_generation_prompt,
                            )
                        except Exception:
                            logger.warning(
                                "apply_chat_template failed; falling back to raw content."
                            )
                            content = message["content"]
                else:
                    content = message["content"]

                if not isinstance(content, str):
                    content = str(content)

                # Audio items may be pre-encoded code tensors or path-like
                # strings; all paths are encoded in one batched call below.
                raw_audio_items = message.get("audio_codes_list", [])

                audio_codes_list: List[torch.Tensor] = []
                if len(raw_audio_items) > 0:
                    encoded_items: List[Optional[torch.Tensor]] = [None] * len(
                        raw_audio_items
                    )
                    paths: List[str] = []
                    path_positions: List[int] = []

                    for idx, item in enumerate(raw_audio_items):
                        if isinstance(item, torch.Tensor):
                            if n_vq is not None and item.shape[1] != n_vq:
                                raise RuntimeError(
                                    "The provided audio codes have a different "
                                    "`n_vq` than the `n_vq` parameter. Set "
                                    "`n_vq` to None if the wavs are already "
                                    "tokenized."
                                )
                            encoded_items[idx] = item
                            continue

                        if isinstance(item, (str, os.PathLike)):
                            paths.append(str(item))
                            path_positions.append(idx)
                            continue

                        raise TypeError(
                            "Each audio item must be a torch.Tensor of codes or a path-like string."
                        )

                    if len(paths) > 0:
                        encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
                        if len(encoded_from_paths) != len(paths):
                            raise RuntimeError(
                                "encode_audios_from_path returned an unexpected number of items."
                            )
                        for pos, codes in zip(path_positions, encoded_from_paths):
                            encoded_items[pos] = codes

                    audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]

                unified_codes.append(
                    self._get_unified_codes(
                        message["role"], content, audio_codes_list, truncation
                    )
                )

            unified_codes = torch.cat(unified_codes)
            input_ids_list.append(unified_codes)

        return BatchFeature(data=self._pad(input_ids_list))
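
    # Sketch of a typical call: the result is a left-padded batch whose
    # `input_ids` have shape (batch, seq_len, 1 + n_vq), where column 0 is the
    # text channel and the remaining columns hold the delayed audio codebooks.
    #
    #   batch = processor(
    #       [processor.build_user_message(text="Hello!", reference=["spk1.wav"])],
    #       mode="generation",
    #   )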

    @staticmethod
    def build_user_message(
        text: Optional[str] = None,
        reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
        instruction: Optional[str] = None,
        tokens: Optional[int] = 125,
        quality: Optional[str] = None,
        sound_event: Optional[str] = None,
        ambient_sound: Optional[str] = None,
        language: Optional[str] = None,
    ) -> Dict:
        if reference is not None and not isinstance(reference, list):
            reference = [reference]
        return UserMessage(
            text=text,
            reference=reference,
            instruction=instruction,
            tokens=tokens,
            quality=quality,
            sound_event=sound_event,
            ambient_sound=ambient_sound,
            language=language,
        ).to_dict()

    @staticmethod
    def build_assistant_message(
        audio_codes_list: List[Union[str, torch.Tensor]],
        content: str = AUDIO_PLACEHOLDER,
    ) -> Dict:
        return AssistantMessage(
            audio_codes_list=audio_codes_list,
            content=content,
        ).to_dict()

    def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
        if isinstance(message, Message):
            return message.to_dict()
        if not isinstance(message, dict):
            raise TypeError("Each message must be a Message or dict.")
        if "role" not in message:
            raise ValueError("Message dict must include a 'role' field.")
        if "content" in message and "audio_codes_list" in message:
            return message
        role = message["role"]
        if role == "user":
            kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
            return self.build_user_message(**kwargs)
        if role == "assistant":
            return self.build_assistant_message(
                audio_codes_list=message.get("audio_codes_list", []),
                content=message.get("content", AUDIO_PLACEHOLDER),
            )
        raise ValueError(f"Unsupported role: {role}")

    def _pad(self, input_ids_list: List[torch.Tensor]):
        device = input_ids_list[0].device
        lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
        # Left-pad every channel with the audio pad code.
        pad_input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=self.model_config.audio_pad_code,
            padding_side="left",
        )
        # Mark the padded (left) positions, and overwrite the text channel
        # there with the text pad token id.
        other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
            1
        ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
        pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
        attention_mask = torch.zeros(
            pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
        )
        attention_mask[~other_channel_mask] = 1
        attention_mask = attention_mask.bool()
        return {
            "input_ids": pad_input_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _replace_audio_placeholders(
        content: str,
        lengths: List[int],
        n_vq: int,
        gen_slot_token: str,
        delay_slot_token: str,
        audio_start_token: str,
        audio_end_token: str,
    ) -> str:
        if n_vq < 1:
            raise ValueError(f"n_vq must be >= 1, got {n_vq}")

        num_placeholders = content.count(AUDIO_PLACEHOLDER)
        if num_placeholders != len(lengths):
            raise ValueError(
                f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
                f"does not match lengths ({len(lengths)})"
            )

        def build_audio_block(length: int) -> str:
            if length < 0:
                raise ValueError(f"length must be >= 0, got {length}")

            if length == 0:
                return f"{audio_start_token}{audio_end_token}"

            step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
            return f"{audio_start_token}{step_tokens}{audio_end_token}"

        lengths_iter = iter(lengths)

        def replacer(match: re.Match) -> str:
            length = next(lengths_iter)
            return build_audio_block(length)

        result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)

        return result
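
    # Sketch of the expansion: with n_vq = 4, a placeholder of length 2 becomes
    #
    #   <start><gen><gen><delay><delay><delay><end>
    #
    # i.e. `length` generation slots plus n_vq - 1 delay slots that flush the
    # delayed codebooks (token spellings here are illustrative; the real
    # strings come from the model config).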

    @staticmethod
    def _merge_consecutive_audio_placeholders(
        content: str,
        audio_codes_list: List[torch.Tensor],
    ) -> Tuple[str, List[torch.Tensor]]:
        matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
        if len(matches) <= 1:
            return content, audio_codes_list

        if len(matches) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        new_audio_codes_list = []
        new_parts = []
        last_pos = 0
        i = 0
        while i < len(matches):
            j = i
            while (
                j + 1 < len(matches)
                and content[matches[j].end() : matches[j + 1].start()].strip() == ""
            ):
                j += 1

            new_parts.append(content[last_pos : matches[i].start()])
            new_parts.append(AUDIO_PLACEHOLDER)
            last_pos = matches[j].end()

            if j == i:
                new_audio_codes_list.append(audio_codes_list[i])
            else:
                new_audio_codes_list.append(
                    torch.cat(audio_codes_list[i : j + 1], dim=0)
                )

            i = j + 1

        new_parts.append(content[last_pos:])
        return "".join(new_parts), new_audio_codes_list

    @staticmethod
    def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
        # Shift codebook i down by i steps; pad the exposed cells.
        delayed_tokens = torch.full(
            (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
            pad_code,
            device=codes.device,
            dtype=codes.dtype,
        )
        for i in range(codes.shape[1]):
            delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
        return delayed_tokens

    @staticmethod
    def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
        # Inverse of apply_delay_pattern: re-align codebook i by reading it
        # with an offset of i steps.
        tokens = torch.full(
            (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
            0,
            device=delay_codes.device,
            dtype=delay_codes.dtype,
        )
        for i in range(delay_codes.shape[1]):
            tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
        return tokens
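
    # Worked example of the delay pattern for T = 3 frames and n_vq = 3
    # codebooks (pad_code drawn as "."):
    #
    #   codes            delayed
    #   [a0 b0 c0]       [a0 .  . ]
    #   [a1 b1 c1]  ->   [a1 b0 . ]
    #   [a2 b2 c2]       [a2 b1 c0]
    #                    [.  b2 c1]
    #                    [.  .  c2]
    #
    # The two transforms are exact inverses:
    #
    #   t = torch.arange(12).reshape(4, 3)
    #   d = MossTTSDelayProcessor.apply_delay_pattern(t, pad_code=-1)
    #   assert torch.equal(MossTTSDelayProcessor.apply_de_delay_pattern(d), t)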

    def _get_unified_codes(
        self,
        role: str,
        content: str,
        audio_codes_list: List[torch.Tensor],
        truncation: bool,
    ) -> torch.Tensor:
        """`content` already carries the chat-template formatting here."""
        if role == "user":
            audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
            truncation = False
        else:
            audio_gen_slot_token = self.audio_assistant_gen_slot_token
            audio_delay_slot_token = self.audio_assistant_delay_slot_token

        if len(audio_codes_list):
            n_vq = audio_codes_list[0].shape[1]
        else:
            n_vq = self.model_config.n_vq

        if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
            content, audio_codes_list = self._merge_consecutive_audio_placeholders(
                content, audio_codes_list
            )
        content = self._replace_audio_placeholders(
            content=content,
            lengths=[len(audio_codes) for audio_codes in audio_codes_list],
            n_vq=n_vq,
            gen_slot_token=audio_gen_slot_token,
            delay_slot_token=audio_delay_slot_token,
            audio_start_token=self.audio_start_token,
            audio_end_token=self.audio_end_token,
        )
        text_codes = torch.tensor(
            self.tokenizer.encode(content),
            device=audio_codes_list[0].device if audio_codes_list else None,
        )

        audio_start_indices = torch.where(
            text_codes == self.model_config.audio_start_token_id
        )[0]
        audio_end_indices = torch.where(
            text_codes == self.model_config.audio_end_token_id
        )[0]
        if len(audio_start_indices) != len(audio_codes_list) or len(
            audio_end_indices
        ) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        delay_audio_codes_list = []
        if len(audio_codes_list) == 0:
            # Text-only message: the audio channels are all padding.
            delay_audio_codes_list = torch.full(
                (len(text_codes), n_vq),
                self.model_config.audio_pad_code,
                device=text_codes.device,
                dtype=text_codes.dtype,
            )
        else:
            # Interleave pad blocks (over text spans) with delayed audio codes.
            prefix_idx = 0
            for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
                audio_start_indices, audio_end_indices, audio_codes_list
            ):
                audio_start_idx = int(audio_start_idx_t.item())
                audio_end_idx = int(audio_end_idx_t.item())
                delay_audio_codes = self.apply_delay_pattern(
                    audio_codes, self.model_config.audio_pad_code
                )
                pad_codes = torch.full(
                    (audio_start_idx - prefix_idx + 1, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes.device,
                    dtype=audio_codes.dtype,
                )
                delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
                prefix_idx = audio_end_idx

            if truncation:
                # Drop the trailing delay frames of the last audio block so
                # that the sequence can be continued seamlessly.
                delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
                    : -(n_vq - 1), :
                ]
            else:
                last_audio_end_idx = int(audio_end_indices[-1].item())
                pad_codes = torch.full(
                    (len(text_codes) - last_audio_end_idx, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes_list[0].device,
                    dtype=audio_codes_list[0].dtype,
                )
                delay_audio_codes_list.append(pad_codes)

            delay_audio_codes_list = torch.cat(delay_audio_codes_list)

        if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
            text_codes = text_codes[: delay_audio_codes_list.shape[0]]

        unified_codes = torch.cat(
            [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
        )
        return unified_codes

    def _parse_text_codes(self, start_length, text_codes):
        text = cast(str, self.tokenizer.decode(text_codes))
        prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
        text = text[len(prefix) :]

        # Escape the special tokens: they usually contain regex
        # metacharacters such as "|".
        AUDIO_PATTERN = re.compile(
            rf"(?:{re.escape(self.audio_start_token)})?"
            rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
            rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
            rf"{re.escape(self.audio_end_token)}"
        )

        def normalize_audio_segments(text: str) -> str:
            def repl(match: re.Match) -> str:
                seg = match.group(0)
                # Segments containing generation slots carry real audio.
                if self.audio_assistant_gen_slot_token in seg:
                    return AUDIO_PLACEHOLDER
                # Empty audio blocks are dropped.
                return ""

            return AUDIO_PATTERN.sub(repl, text)

        return normalize_audio_segments(text)

    def _parse_audio_codes(self, start_length, audio_codes):
        # Undo the delay pattern so that each row is one time step again.
        audio_codes = self.apply_de_delay_pattern(audio_codes)

        # Rows that are entirely pad codes separate the audio segments.
        is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
        non_pad = ~is_pad
        if not non_pad.any():
            return []

        idx = torch.nonzero(non_pad).squeeze(1)
        breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
        if breaks.numel() == 0:
            segments_idx = [idx]
        else:
            segments_idx = torch.split(idx, breaks.tolist())

        audio_codes_list = [audio_codes[s] for s in segments_idx]

        # Decode every contiguous non-pad segment into a waveform.
        decoded_audio_list = self.decode_audio_codes(audio_codes_list)

        # If decoding started mid-audio (start_length > 0), trim the part of
        # the first waveform that belongs to the prompt, proportionally to
        # the number of prompt codes.
        if (
            start_length > 0
            and len(audio_codes_list) > 0
            and len(decoded_audio_list) > 0
        ):
            first_codes_length = audio_codes_list[0].shape[0]
            if first_codes_length > 0:
                trim_ratio = max(
                    0.0, min(float(start_length) / float(first_codes_length), 1.0)
                )
                first_audio = decoded_audio_list[0]
                if trim_ratio >= 1.0:
                    decoded_audio_list = decoded_audio_list[1:]
                elif trim_ratio > 0.0:
                    trim_samples = int(first_audio.shape[-1] * trim_ratio)
                    decoded_audio_list[0] = first_audio[..., trim_samples:]

        return decoded_audio_list

    def decode(self, output: List[Tuple[int, torch.Tensor]]):
        """
        1. A complete sequence of assistant generation ids is required here,
           no matter what;
        2. Truncation from an arbitrary position is supported (via the
           per-item `start_length`).
        """
        generated_messages = []
        for start_length, generation_ids in output:
            content = self._parse_text_codes(start_length, generation_ids[:, 0])
            audio_codes_list = self._parse_audio_codes(
                start_length, generation_ids[:, 1:]
            )
            if content == "":
                message = None
            else:
                message = AssistantMessage(
                    content=content,
                    audio_codes_list=cast(
                        List[Union[str, torch.Tensor]], audio_codes_list
                    ),
                )
            generated_messages.append(message)
        return generated_messages
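
    # Decoding sketch (shapes follow _get_unified_codes: each item pairs the
    # prompt length with unified ids of shape (seq_len, 1 + n_vq), column 0
    # text, the rest audio):
    #
    #   messages = processor.decode([(prompt_len, generation_ids)])
    #   for msg in messages:
    #       if msg is not None:
    #           text, wavs = msg.content, msg.audio_codes_list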

    @staticmethod
    def loudness_normalize(
        wav: torch.Tensor,
        target_dbfs: float = -20,
        gain_range: tuple[float, float] = (-3.0, 3.0),
    ) -> torch.Tensor:
        wav = wav.to(torch.float32)
        if wav.numel() == 0:
            return wav
        # RMS level in dBFS; the epsilon guards against log10(0) on silence.
        current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
        gain = float(target_dbfs - current_dbfs)
        gain = max(gain_range[0], min(gain, gain_range[1]))
        factor = 10.0 ** (gain / 20.0)
        return wav * factor
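
    # Worked example: a signal with mean power 1e-3 sits at
    # 10 * log10(1e-3) = -30 dBFS. Reaching -20 dBFS would need +10 dB, but
    # the gain is clamped to (-3, 3) dB, so only +3 dB is applied,
    # a factor of 10 ** (3 / 20) ≈ 1.41.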

    def _get_audio_tokenizer_device(self) -> torch.device:
        """Best-effort device inference for `self.audio_tokenizer`.

        Notes:
            - The old TAC wrapper exposed `.device`, but a standard
              `torch.nn.Module` does not.
            - The new MossAudioTokenizerModel is a `PreTrainedModel`; its
              parameters define its device.
        """
        audio_tokenizer = getattr(self, "audio_tokenizer", None)
        if audio_tokenizer is None:
            logger.warning(
                "audio_tokenizer is not set on processor. Using CPU as default."
            )
            return torch.device("cpu")

        device_attr = getattr(audio_tokenizer, "device", None)
        if isinstance(device_attr, torch.device):
            return device_attr

        try:
            return next(audio_tokenizer.parameters()).device
        except StopIteration:
            logger.warning(
                "No parameters found on audio_tokenizer. Using CPU as default."
            )
            return torch.device("cpu")
    def encode_audios_from_wav(
        self,
        wav_list: List[torch.Tensor],
        sampling_rate: int,
        n_vq: Optional[int] = None,
    ):
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(wav_list, torch.Tensor):
            wav_list = [wav_list]
        wav_list_ = []
        resample = False
        if sampling_rate != self.model_config.sampling_rate:
            resample = True
        device = self._get_audio_tokenizer_device()
        for wav in wav_list:
            # Downmix to mono, resample if needed, then loudness-normalize.
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)
            if resample:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=sampling_rate,
                    new_freq=self.model_config.sampling_rate,
                )
            wav = wav.to(device)
            wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

        # Prefer the tokenizer's batch API when available.
        if hasattr(audio_tokenizer, "batch_encode"):
            enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
            audio_codes = enc.audio_codes
            audio_codes_lengths = enc.audio_codes_lengths
        else:
            # Fallback: right-pad the waveforms into one batch with a mask.
            max_len = max(int(wav.shape[-1]) for wav in wav_list_)
            input_values = torch.zeros(
                len(wav_list_), 1, max_len, device=device, dtype=torch.float32
            )
            padding_mask = torch.zeros(
                len(wav_list_), max_len, device=device, dtype=torch.bool
            )
            for i, wav in enumerate(wav_list_):
                this_len = int(wav.shape[-1])
                input_values[i, 0, :this_len] = wav
                padding_mask[i, :this_len] = True
            enc = audio_tokenizer.encode(
                input_values,
                padding_mask=padding_mask,
                num_quantizers=n_vq,
                return_dict=True,
            )
            audio_codes = enc.audio_codes
            audio_codes_lengths = enc.audio_codes_lengths

        if audio_codes is None or audio_codes_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
            )

        # Unpack per-item codes: `audio_codes` is laid out as
        # (n_vq, batch, frames); each returned tensor is (frames_i, n_vq).
        codes_list: List[torch.Tensor] = []
        for i in range(int(audio_codes.shape[1])):
            length_i = int(audio_codes_lengths[i].item())
            codes_i = (
                audio_codes[:, i, :length_i]
                .transpose(0, 1)
                .contiguous()
                .to(torch.long)
                .cpu()
            )
            codes_list.append(codes_i)
        return codes_list

    def encode_audios_from_path(
        self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
    ):
        if isinstance(wav_path_list, str):
            wav_path_list = [wav_path_list]

        if len(wav_path_list) == 0:
            raise ValueError("Empty wav_path_list")

        # Load and resample everything to the model's sampling rate before
        # handing the batch to encode_audios_from_wav.
        target_sr = int(self.model_config.sampling_rate)
        wav_list: List[torch.Tensor] = []
        for wav_path in wav_path_list:
            wav, sr = torchaudio.load(wav_path)
            if int(sr) != target_sr:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=int(sr),
                    new_freq=target_sr,
                )
            wav_list.append(wav)

        return self.encode_audios_from_wav(wav_list, target_sr, n_vq)
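
    # Encoding sketch (file names are placeholders; n_vq=8 is an assumed
    # codebook count): each returned tensor has shape (frames, n_vq).
    #
    #   codes_list = processor.encode_audios_from_path(["a.wav", "b.wav"], n_vq=8)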

    def decode_audio_codes(
        self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
    ):
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(audio_tokens_list, torch.Tensor):
            audio_tokens_list = [audio_tokens_list]
        if len(audio_tokens_list) == 0:
            return []

        device = self._get_audio_tokenizer_device()

        # (frames, n_vq) -> (n_vq, frames) on the tokenizer's device.
        codes_list = [
            codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
            for codes in audio_tokens_list
        ]

        # Batch the segments as (n_vq, batch, max_frames) plus a padding mask.
        nq = int(codes_list[0].shape[0])
        max_t = max(int(c.shape[1]) for c in codes_list)
        audio_codes = torch.zeros(
            nq, len(codes_list), max_t, device=device, dtype=torch.long
        )
        padding_mask = torch.zeros(
            len(codes_list), max_t, device=device, dtype=torch.bool
        )
        for i, c in enumerate(codes_list):
            t = int(c.shape[1])
            audio_codes[:, i, :t] = c
            padding_mask[i, :t] = True
        dec = audio_tokenizer.decode(
            audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
        )
        audio = dec.audio
        audio_lengths = dec.audio_lengths

        if audio is None or audio_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
            )

        # Unpack per-item waveforms, trimming the right padding.
        wav_list: List[torch.Tensor] = []
        for i in range(int(audio.shape[0])):
            length_i = int(audio_lengths[i].item())
            wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
            wav_list.append(wav)
        return wav_list