|
|
""" |
|
|
Modified From https://github.com/XXXXRT666/GPT-SoVITS |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import List, Literal, MutableSequence, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
from AR.models.t2s_model_abc import KVCacheABC, Sampler, T2SDecoderABC |
|
|
|
|
|
# Short alias for ``torch.Tensor``, used by the annotations in this module.
Tensor = torch.Tensor
|
|
|
|
|
|
|
|
@dataclass
class T2SResult:
    """Result record for a T2S inference run.

    Every field has a default, so ``T2SResult()`` is a record whose status is
    ``"Success"`` and which carries no tensors, exception or traceback.
    """

    # Output tensors; ``None`` (the default) when there is nothing to return.
    result: Optional[List[Tensor]] = None

    # Speed figure for the run. The unit is not defined in this module
    # (presumably tokens per second -- confirm with the producer).
    infer_speed: float = 0.0

    # Outcome tag; only these two literal values are allowed by the annotation.
    status: Literal["Success", "Error"] = "Success"

    # Exception object attached to the record, if any.
    exception: Optional[Exception] = None

    # Traceback text attached to the record, if any.
    traceback: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass
class T2SRequest:
    """Inputs and options for one batched T2S inference request.

    The first five fields are required; all remaining fields have defaults.
    """

    # ---- model inputs -------------------------------------------------
    # One tensor per batch item; ``T2SSession`` uses ``len(x)`` as batch size.
    x: List[Tensor]
    # ``T2SSession`` casts this to int32 and reads one scalar per batch item
    # (presumably the length of the matching entry in ``x`` -- confirm).
    x_lens: Tensor
    # ``T2SSession`` takes the prompt length from the last dimension.
    prompts: Tensor
    # Passed to ``decoder.embed`` together with ``x`` and ``prompts``.
    bert_feature: List[Tensor]
    # Not read by the code visible in this module; meaning is defined by the
    # consumer -- TODO confirm.
    valid_length: int

    # ---- sampling options ---------------------------------------------
    # NOTE(review): names follow the usual top-k / top-p / temperature /
    # repetition-penalty conventions; exact semantics live in the sampler.
    top_k: int = 5
    top_p: float = 1
    # -1 presumably means "no early stop" -- confirm with the decoding loop.
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35

    # ---- runtime switches ---------------------------------------------
    use_cuda_graph: bool = False
    debug: bool = False
|
|
|
|
|
|
|
|
class T2SSession:
    """Per-request decoding state built from a decoder and a ``T2SRequest``.

    The constructor stores the inputs, creates the sampler, embeds the inputs
    through ``decoder.embed`` and builds one boolean attention mask per batch
    item. ``kv_cache``, ``xy_pos_`` and ``xy_dec_`` are only declared here;
    they receive no value in this constructor.
    """

    def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
        """Initialise session state ahead of decoding.

        Args:
            decoder: Decoder providing ``vocab_size`` and ``embed``.
            request: Inputs for this run; ``len(request.x)`` is the batch size
                and ``request.prompts.size(-1)`` the prompt length.
            device: Entered as a context manager for the whole constructor and
                used explicitly for the ``completed`` flags.
            dtype: Stored as ``self.dtype``; not otherwise used here.
        """
        # Entering the device makes it the default target for the tensor
        # factory calls below (mask construction, nested tensor).
        with device:
            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            batch_size = len(request.x)
            prompt_len = request.prompts.size(-1)
            self.bsz = batch_size
            self.y_len = prompt_len

            # Declaration only: the cache is assigned elsewhere.
            self.kv_cache: MutableSequence[KVCacheABC]
            self.sampler = Sampler(batch_size, decoder.vocab_size)

            self.x = request.x
            self.x_lens = request.x_lens.to(torch.int32)
            self.y = request.prompts
            self.bert_feature = request.bert_feature

            self.prefill_len = self.x_lens + self.y.size(1)

            # Starts out equal to ``prefill_len`` but lives in its own storage,
            # so later in-place updates do not touch ``prefill_len``.
            self.input_pos = torch.zeros_like(self.prefill_len)
            self.input_pos.add_(self.prefill_len)

            # No graph has been captured yet; the two buffers below are
            # declared without a value.
            self.graph: Optional[torch.cuda.CUDAGraph] = None
            self.xy_pos_: Tensor
            self.xy_dec_: Tensor

            # One "finished" flag and one result slot per batch item. The
            # flags are allocated directly as bool on the target device rather
            # than via the legacy ``torch.Tensor(list).bool().to(device)`` path.
            self.completed = torch.zeros(batch_size, dtype=torch.bool, device=device)
            self.y_results: List[Optional[Tensor]] = [None] * batch_size

            self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)

            # Masks differ in size per item (text lengths vary), hence the
            # nested tensor instead of a regular stacked tensor.
            per_item_masks = [
                self._build_prefill_mask(int(self.x_lens[row].item()), prompt_len)
                for row in range(batch_size)
            ]
            self.attn_mask_nested = torch.nested.nested_tensor(per_item_masks)

            # -1 until an identifier is assigned from outside.
            self.id: int = -1

    @staticmethod
    def _build_prefill_mask(text_len: int, prompt_len: int) -> Tensor:
        """Build the square boolean mask for one sequence.

        The mask has side ``text_len + prompt_len``. All rows are True over
        the first ``text_len`` columns; the trailing ``prompt_len`` x
        ``prompt_len`` block is lower-triangular with the diagonal included.
        Everything else stays False. (True presumably means "may attend" --
        confirm against the decoder's attention call.)

        Args:
            text_len: Number of leading text positions.
            prompt_len: Number of trailing prompt positions.

        Returns:
            A ``torch.bool`` tensor of shape
            ``(text_len + prompt_len, text_len + prompt_len)``.
        """
        side = text_len + prompt_len
        mask = torch.zeros(side, side, dtype=torch.bool)
        mask[:, :text_len].fill_(True)
        # Guarded because a ``-0:`` slice would select the whole tensor.
        if prompt_len > 0:
            strictly_upper = torch.triu(torch.ones(prompt_len, prompt_len, dtype=torch.bool), diagonal=1)
            mask[-prompt_len:, -prompt_len:] = ~strictly_upper
        return mask
|
|
|