| """ | |
| Maya1 Model Loader | |
| Loads Maya1 model with vLLM engine and validates emotion tags. | |
| """ | |
| import os | |
| from transformers import AutoTokenizer | |
| from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams | |
| from .constants import ( | |
| ALL_EMOTION_TAGS, | |
| DEFAULT_MAX_MODEL_LEN, | |
| SOH_ID, EOH_ID, SOA_ID, BOS_ID, TEXT_EOT_ID, CODE_START_TOKEN_ID, | |
| ) | |


class Maya1Model:
    """Maya1 TTS model with vLLM inference engine."""

    def __init__(
        self,
        model_path: Optional[str] = None,
        dtype: str = "bfloat16",
        max_model_len: int = DEFAULT_MAX_MODEL_LEN,
        gpu_memory_utilization: float = 0.85,
        tensor_parallel_size: int = 1,
        **engine_kwargs
    ):
        """
        Initialize the Maya1 model with vLLM.

        Args:
            model_path: Path to checkpoint (local or HF repo)
            dtype: Model precision (bfloat16 recommended)
            max_model_len: Maximum sequence length
            gpu_memory_utilization: GPU memory fraction
            tensor_parallel_size: Number of GPUs
        """
        # Resolution order: explicit argument, then the MAYA1_MODEL_PATH
        # environment variable, then the default local path.
        if model_path is None:
            model_path = os.environ.get(
                'MAYA1_MODEL_PATH',
                os.path.expanduser('~/models/maya1-voice')
            )

        self.model_path = model_path
        self.dtype = dtype

        print("Initializing Maya1 Model")
        print(f"Model: {model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print(f"Tokenizer loaded: {len(self.tokenizer)} tokens")

        # Validate emotion tags
        self._validate_emotion_tags()

        # Precompute special token strings
        self._init_special_tokens()

        # Initialize vLLM engine
        print("Initializing vLLM engine...")
        engine_args = AsyncEngineArgs(
            model=model_path,
            tokenizer=model_path,
            dtype=dtype,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
            disable_log_stats=False,
            **engine_kwargs
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        print("Maya1 Model ready\n")

    def _validate_emotion_tags(self):
        """Validate that all 20 emotion tags encode to a single token."""
        failed_tags = []
        for tag in ALL_EMOTION_TAGS:
            token_ids = self.tokenizer.encode(tag, add_special_tokens=False)
            if len(token_ids) != 1:
                failed_tags.append((tag, len(token_ids)))

        if failed_tags:
            print(f"ERROR: {len(failed_tags)} emotion tags are NOT single tokens!")
            raise AssertionError(f"Emotion tag validation failed: {failed_tags}")

        print(f"All {len(ALL_EMOTION_TAGS)} emotion tags validated")

    def _init_special_tokens(self):
        """Precompute special token strings for fast prefix building."""
        self.soh_token = self.tokenizer.decode([SOH_ID])    # start-of-human marker
        self.bos_token = self.tokenizer.bos_token           # beginning-of-sequence
        self.eot_token = self.tokenizer.decode([TEXT_EOT_ID])  # end-of-text
        self.eoh_token = self.tokenizer.decode([EOH_ID])    # end-of-human marker
        self.soa_token = self.tokenizer.decode([SOA_ID])    # start-of-assistant marker
        self.sos_token = self.tokenizer.decode([CODE_START_TOKEN_ID])  # start of speech codes
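    # A hedged sketch (an assumption, not taken from this file) of how the
    # strings above might compose into a generation prefix, following the
    # conventional order of the markers:
    #
    #     prefix = (self.soh_token + self.bos_token + text + self.eot_token
    #               + self.eoh_token + self.soa_token + self.sos_token)
    #
    # i.e. a human turn wrapping the text, then the assistant turn opening
    # directly into the speech-code stream.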

    async def generate(self, prompt: str, sampling_params: SamplingParams):
        """
        Generate tokens from a prompt (non-streaming).

        Args:
            prompt: Input prompt
            sampling_params: vLLM sampling parameters

        Returns:
            List holding the final vLLM RequestOutput, or an empty list
        """
        # Unique per call; id(prompt) could collide across concurrent
        # requests that share an interned string, so use a UUID instead.
        request_id = f"req_{uuid.uuid4().hex}"

        # Drain the async generator, keeping only the final output
        final_output = None
        async for output in self.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id
        ):
            final_output = output

        return [final_output] if final_output else []

    async def generate_stream(self, prompt: str, sampling_params: SamplingParams):
        """
        Generate tokens from a prompt (streaming).

        Args:
            prompt: Input prompt
            sampling_params: vLLM sampling parameters

        Yields:
            Incremental RequestOutput objects from vLLM
        """
        # Unique per call, for the same reason as in generate()
        request_id = f"req_{uuid.uuid4().hex}"

        # Stream outputs straight from the engine
        async for output in self.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id
        ):
            yield output
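

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module).
# Assumptions: a checkpoint resolves via MAYA1_MODEL_PATH or the default
# path, plain text is an acceptable prompt for a smoke test, and the module
# is run as part of its package (e.g. `python -m <package>.model`) so the
# relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        model = Maya1Model()  # path from argument, env var, or default
        params = SamplingParams(temperature=0.7, max_tokens=256)

        # Non-streaming: a single final RequestOutput (or an empty list).
        outputs = await model.generate("Hello there!", params)
        if outputs:
            print(f"Generated {len(outputs[0].outputs[0].token_ids)} tokens")

        # Streaming: incremental RequestOutput objects as tokens arrive.
        n_tokens = 0
        async for out in model.generate_stream("Hello there!", params):
            n_tokens = len(out.outputs[0].token_ids)
        print(f"Streamed {n_tokens} tokens")

    asyncio.run(_demo())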