|
|
""" |
|
|
Modified From https://github.com/XXXXRT666/GPT-SoVITS |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import random |
|
|
from abc import ABC, abstractmethod |
|
|
from contextlib import nullcontext |
|
|
from typing import Any, Dict, List, MutableSequence, Tuple, Type |
|
|
|
|
|
import torch |
|
|
import torch._inductor.config |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.cuda.graphs import CUDAGraph |
|
|
from torch.profiler import ProfilerAction, tensorboard_trace_handler |
|
|
|
|
|
from AR.models.embedding import ( |
|
|
SinePositionalEmbeddingNested as SinePositionalEmbedding, |
|
|
) |
|
|
from AR.models.embedding import TokenEmbedding |
|
|
|
|
|
Tensor = torch.Tensor |
|
|
|
|
|
|
|
|
class Sampler(nn.Module): |
|
|
def __init__(self, batch_size: int, vocab_size: int) -> None: |
|
|
super().__init__() |
|
|
        self.batch_size = batch_size
        self.vocab_size = vocab_size
|
|
|
|
|
|
|
|
def sample( |
|
|
self, |
|
|
logits: Tensor, |
|
|
previous_tokens: Tensor, |
|
|
temperature: float, |
|
|
top_k: int, |
|
|
top_p: float, |
|
|
repetition_penalty: float, |
|
|
) -> Tensor: |
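        """Sample one token id per batch row from `logits`.

        Steps: apply a repetition penalty to `previous_tokens`, mask tokens
        outside the top-p (nucleus) set, scale by `temperature`, keep only the
        top-k logits, then draw a sample via the Gumbel-max trick
        (argmax of probs divided by Exponential(1) noise).
        """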
|
|
previous_tokens = previous_tokens.long() |
|
|
score = torch.gather(logits, dim=1, index=previous_tokens) |
|
|
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) |
|
|
logits.scatter_(dim=1, index=previous_tokens, src=score) |
|
|
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|
|
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
sorted_indices_to_remove = cum_probs > top_p |
|
|
sorted_indices_to_remove[:, 0] = False |
|
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) |
|
|
logits = logits.masked_fill(indices_to_remove, -float("Inf")) |
|
|
|
|
|
logits = logits / max(temperature, 1e-5) |
|
|
|
|
|
v, _ = torch.topk(logits, top_k) |
|
|
pivot = v[:, -1].unsqueeze(-1) |
|
|
logits = torch.where(logits < pivot, -float("Inf"), logits) |
|
|
|
|
|
probs = torch.nn.functional.softmax(logits, dim=-1) |
|
|
q = torch.empty_like(probs).exponential_(1.0) |
|
|
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32) |
|
|
|
|
|
return idx_next |
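# Illustrative decode-step sketch (hypothetical variable names and hyperparameter
# values; the real loop lives in the concrete T2S decoder implementations):
#
#   sampler = Sampler(batch_size=bsz, vocab_size=decoder.vocab_size)
#   logits = decoder.ar_predict_layer(xy_dec[:, -1])   # (bsz, vocab_size)
#   next_token = sampler.sample(
#       logits,
#       previous_tokens=y,          # (bsz, generated_len) token ids so far
#       temperature=1.0,
#       top_k=15,
#       top_p=1.0,
#       repetition_penalty=1.35,
#   )                               # -> (bsz, 1) int32 token ids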
|
|
|
|
|
|
|
|
class KVCacheABC(ABC, nn.Module): |
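    """Per-layer key/value cache for incremental decoding.

    Concrete subclasses differ only in memory layout: KVCacheNHD stores
    (batch, seq, heads, head_dim) and KVCacheHND stores
    (batch, heads, seq, head_dim).
    """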
|
|
def __init__(self, *args, **kwds) -> None: |
|
|
super().__init__() |
|
|
self.k_cache: Tensor |
|
|
self.v_cache: Tensor |
|
|
self.n_head: int |
|
|
self.head_dim: int |
|
|
self.batch_size: int |
|
|
self.max_seq_length: int |
|
|
|
|
|
def empty(self): |
|
|
self.k_cache.zero_() |
|
|
self.v_cache.zero_() |
|
|
|
|
|
@abstractmethod |
|
|
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> Tuple[Tensor, Tensor]: ... |
|
|
|
|
|
@abstractmethod |
|
|
def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int) -> None: ... |
|
|
|
|
|
def sync_cache(self, kv_cache: KVCacheABC): |
|
|
self.k_cache.copy_(kv_cache.k_cache) |
|
|
self.v_cache.copy_(kv_cache.v_cache) |
|
|
|
|
|
def forward(self): |
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
class KVCacheNHD(KVCacheABC): |
|
|
def __init__(self, batch_size, max_seq_length, n_heads, head_dim): |
|
|
super().__init__() |
|
|
assert batch_size > 0 |
|
|
cache_shape = (batch_size, max_seq_length, n_heads, head_dim) |
|
|
self.n_head = n_heads |
|
|
self.head_dim = head_dim |
|
|
self.batch_size = batch_size |
|
|
self.max_seq_length = max_seq_length |
|
|
|
|
|
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False) |
|
|
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False) |
|
|
|
|
|
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor): |
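        # Build a scatter index of shape (batch, 1, n_head, head_dim) that writes the
        # new k/v entries into slot `input_pos - 1` along the sequence axis (dim=1).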
|
|
|
|
|
|
|
|
index = ( |
|
|
(input_pos - 1) |
|
|
.unsqueeze(-1) |
|
|
.unsqueeze(-1) |
|
|
.unsqueeze(-1) |
|
|
.expand( |
|
|
-1, |
|
|
-1, |
|
|
self.n_head, |
|
|
self.head_dim, |
|
|
) |
|
|
.to(torch.int64) |
|
|
) |
|
|
|
|
|
k_out = self.k_cache |
|
|
v_out = self.v_cache |
|
|
k_out.scatter_(1, index, k_val) |
|
|
v_out.scatter_(1, index, v_val) |
|
|
|
|
|
return k_out, v_out |
|
|
|
|
|
def empty(self): |
|
|
self.k_cache.zero_() |
|
|
self.v_cache.zero_() |
|
|
|
|
|
def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int): |
|
|
|
|
|
|
|
|
self.k_cache[[bs], : k_val.shape[1]] = k_val |
|
|
self.v_cache[[bs], : v_val.shape[1]] = v_val |
|
|
|
|
|
|
|
|
class KVCacheHND(KVCacheABC): |
|
|
def __init__(self, batch_size, max_seq_length, n_heads, head_dim): |
|
|
super().__init__() |
|
|
assert batch_size > 0 |
|
|
cache_shape = (batch_size, n_heads, max_seq_length, head_dim) |
|
|
self.n_head = n_heads |
|
|
self.head_dim = head_dim |
|
|
self.batch_size = batch_size |
|
|
self.max_seq_length = max_seq_length |
|
|
|
|
|
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False) |
|
|
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False) |
|
|
|
|
|
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor): |
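        # Same as KVCacheNHD.update, but the sequence axis is dim=2 in the
        # (batch, heads, seq, head_dim) layout, so the index expands over n_head at dim=1.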
|
|
|
|
|
|
|
|
index = ( |
|
|
(input_pos - 1) |
|
|
.unsqueeze(-1) |
|
|
.unsqueeze(-1) |
|
|
.unsqueeze(-1) |
|
|
.expand( |
|
|
-1, |
|
|
self.n_head, |
|
|
-1, |
|
|
self.head_dim, |
|
|
) |
|
|
.to(torch.int64) |
|
|
) |
|
|
|
|
|
k_out = self.k_cache |
|
|
v_out = self.v_cache |
|
|
k_out.scatter_(2, index, k_val) |
|
|
v_out.scatter_(2, index, v_val) |
|
|
|
|
|
return k_out, v_out |
|
|
|
|
|
def empty(self): |
|
|
self.k_cache.zero_() |
|
|
self.v_cache.zero_() |
|
|
|
|
|
def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int): |
|
|
|
|
|
|
|
|
self.k_cache[[bs], :, : k_val.shape[1]] = k_val.transpose(1, 2) |
|
|
self.v_cache[[bs], :, : v_val.shape[1]] = v_val.transpose(1, 2) |
|
|
|
|
|
|
|
|
class AttentionABC(ABC, nn.Module): |
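    """Multi-head self-attention base class with a fused q/k/v projection (`in_proj`)."""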
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.n_head: int |
|
|
self.hidden_dim: int |
|
|
self.head_dim: int |
|
|
|
|
|
|
|
|
self.in_proj: nn.Linear |
|
|
self.out_proj: nn.Linear |
|
|
|
|
|
self.dropout = nn.Dropout(0.1) |
|
|
|
|
|
self._register_load_state_dict_pre_hook(self.load_hook) |
|
|
|
|
|
def load_hook(self, state_dict: dict, prefix, *args): |
|
|
keys_to_modify = [key for key in state_dict if "in_proj_" in key] |
|
|
for key in keys_to_modify: |
|
|
new_key = key.replace("in_proj_", "in_proj.") |
|
|
state_dict[new_key] = state_dict.pop(key) |
|
|
|
|
|
@abstractmethod |
|
|
def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: ... |
|
|
|
|
|
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor: |
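        """Run full attention over each ragged prompt and fill its KV-cache slot.

        `x` is a nested tensor of variable-length prompts; samples are processed
        one at a time and the per-sample outputs are re-packed into a nested tensor.
        """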
|
|
bsz = x.size(0) |
|
|
|
|
|
outputs = [] |
|
|
|
|
|
for bs in range(bsz): |
|
|
x_b = x[bs].unsqueeze(0) |
|
|
|
|
|
            q, k, v = self.in_proj.forward(x_b).chunk(3, dim=-1)
|
|
|
|
|
q = q.contiguous().view(1, -1, self.n_head, self.head_dim) |
|
|
k = k.contiguous().view(1, -1, self.n_head, self.head_dim) |
|
|
v = v.contiguous().view(1, -1, self.n_head, self.head_dim) |
|
|
|
|
|
kv_cache.prefill_kv(k, v, bs) |
|
|
|
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) |
|
|
|
|
|
attn_mask = mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1) |
|
|
|
|
|
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) |
|
|
|
|
|
attn = self.dropout.forward(attn) |
|
|
|
|
|
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim) |
|
|
|
|
|
output = self.out_proj.forward(attn) |
|
|
|
|
|
outputs.append(output.squeeze(0)) |
|
|
|
|
|
return torch.nested.nested_tensor(outputs) |
|
|
|
|
|
|
|
|
class FeedForward(nn.Module): |
|
|
def __init__(self, dim: int, hidden_dim: int) -> None: |
|
|
super().__init__() |
|
|
self.linear1 = nn.Linear(dim, hidden_dim, bias=True) |
|
|
self.linear2 = nn.Linear(hidden_dim, dim, bias=True) |
|
|
self.dropout = nn.Dropout(0.1) |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
return self.dropout.forward(self.linear2(self.dropout.forward(F.relu(self.linear1(x))))) |
|
|
|
|
|
|
|
|
class TransformerBlockABC(ABC, nn.Module): |
|
|
def __init__(self) -> None: |
|
|
super().__init__() |
|
|
self.hidden_dim: int |
|
|
self.attention: AttentionABC |
|
|
self.feed_forward: FeedForward |
|
|
self.attention_norm: nn.LayerNorm |
|
|
self.ffn_norm: nn.LayerNorm |
|
|
self.dropout = nn.Dropout(0.1) |
|
|
|
|
|
self._register_load_state_dict_pre_hook(self.load_hook) |
|
|
|
|
|
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args): |
|
|
for key in list(state_dict.keys()): |
|
|
new_key = ( |
|
|
key.replace("self_attn", "attention") |
|
|
.replace("linear", "feed_forward.linear") |
|
|
.replace("norm1", "attention_norm") |
|
|
.replace("norm2", "ffn_norm") |
|
|
) |
|
|
state_dict[new_key] = state_dict.pop(key) |
|
|
|
|
|
def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: |
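        # Post-norm residual block: LayerNorm(x + dropout(attention(x))),
        # then LayerNorm(h + feed_forward(h)).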
|
|
h = self.attention_norm.forward( |
|
|
x |
|
|
+ self.dropout.forward( |
|
|
self.attention.forward( |
|
|
x, |
|
|
input_pos, |
|
|
kv_cache, |
|
|
*args, |
|
|
**kwds, |
|
|
) |
|
|
) |
|
|
) |
|
|
out = self.ffn_norm.forward(h + self.feed_forward.forward(h)) |
|
|
return out |
|
|
|
|
|
def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor: |
|
|
h = self.attention_norm.forward( |
|
|
x |
|
|
+ self.dropout.forward( |
|
|
self.attention.prefill( |
|
|
x, |
|
|
mask, |
|
|
kv_cache, |
|
|
) |
|
|
) |
|
|
) |
|
|
out = self.ffn_norm.forward(h + self.feed_forward.forward(h)) |
|
|
return out |
|
|
|
|
|
|
|
|
class TransformerDecoderABC(ABC, nn.Module): |
|
|
def __init__(self) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.hidden_dim: int |
|
|
self.n_head: int |
|
|
self.head_dim: int |
|
|
self.vocab_size: int |
|
|
self.n_layer: int |
|
|
|
|
|
self.layers: MutableSequence[TransformerBlockABC] |
|
|
|
|
|
self.max_seq_length: int |
|
|
self.max_batch_size: int |
|
|
|
|
|
self.input_pos: Tensor |
|
|
self.xy_pos: Tensor |
|
|
self.xy_dec: Tensor |
|
|
|
|
|
def forward(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheABC], *args, **kwds): |
|
|
for layer, kv_cache in zip(self.layers, kv_caches): |
|
|
x = layer.forward(x, input_pos, kv_cache, *args, **kwds) |
|
|
return x |
|
|
|
|
|
def prefill(self, x: Tensor, mask: Tensor, kv_caches: MutableSequence[KVCacheABC]): |
|
|
for layer, kv_cache in zip(self.layers, kv_caches): |
|
|
x = layer.prefill(x, mask, kv_cache) |
|
|
return x |
|
|
|
|
|
|
|
|
class T2SDecoderABC(ABC, nn.Module): |
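    """Abstract text-to-semantic decoder.

    Bundles text/audio embeddings, positional encodings, the transformer stack
    (`h`), the prediction head, and optional CUDA-graph caching.
    """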
|
|
def __init__(self) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.n_layer: int |
|
|
self.hidden_dim: int |
|
|
self.n_head: int |
|
|
|
|
|
self.head_dim: int |
|
|
self.embedding_dim: int |
|
|
self.vocab_size: int |
|
|
self.phoneme_vocab_size: int |
|
|
self.p_dropout: float |
|
|
self.max_seq_length: int |
|
|
self.max_batch_size: int |
|
|
self.EOS: int |
|
|
|
|
|
self.bert_proj: nn.Linear |
|
|
self.ar_text_embedding: TokenEmbedding |
|
|
self.ar_text_position: SinePositionalEmbedding |
|
|
self.ar_audio_embedding: TokenEmbedding |
|
|
self.ar_audio_position: SinePositionalEmbedding |
|
|
self.ar_predict_layer: nn.Linear |
|
|
self.h: TransformerDecoderABC |
|
|
|
|
|
self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND] |
|
|
|
|
|
self.GraphCache: CUDAGraphCacheABC | None |
|
|
|
|
|
self._register_load_state_dict_pre_hook(self.load_hook) |
|
|
|
|
|
def load_hook(self, state_dict, prefix, *args): |
|
|
model_keys = [key for key in state_dict if key.startswith("model.")] |
|
|
for key in model_keys: |
|
|
new_key = key[len("model.") :] |
|
|
state_dict[new_key] = state_dict.pop(key) |
|
|
|
|
|
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheABC]: |
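        """Allocate one KV cache per transformer layer on the model's device and dtype."""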
|
|
bsz = bsz or self.h.max_batch_size |
|
|
assert bsz <= self.h.max_batch_size |
|
|
seq_lens = self.h.max_seq_length |
|
|
device = self.bert_proj.bias.device |
|
|
dtype = self.bert_proj.bias.dtype |
|
|
kvclass = self.kv_class |
|
|
return nn.ModuleList( |
|
|
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)], |
|
|
).to(device, dtype) |
|
|
|
|
|
@abstractmethod |
|
|
def embed(self, x: List[torch.Tensor], y: torch.Tensor, bert_features: List[Tensor]) -> Tensor: ... |
|
|
|
|
|
def compile(self, *args, **kwds): |
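        # Inductor settings chosen to work with CUDA graphs (cudagraph trees, input
        # mutation support), then compile the transformer stack in reduce-overhead mode.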
|
|
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True |
|
|
torch._inductor.config.coordinate_descent_tuning = True |
|
|
torch._inductor.config.triton.unique_kernel_names = True |
|
|
|
|
|
torch._inductor.config.fx_graph_cache = True |
|
|
torch._inductor.config.triton.cudagraph_trees = True |
|
|
torch._inductor.config.triton.cudagraph_support_input_mutation = True |
|
|
self.h.compile(fullgraph=True, mode="reduce-overhead") |
|
|
|
|
|
def capture(self, input_pos: Tensor, x: Tensor, x_dec: Tensor, *args, **kwds) -> CUDAGraph: |
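        """Warm up one decode step on a side stream, then record it into a CUDA graph.

        The graph replays `self.h.forward` against the captured static tensors
        (`input_pos`, `x`, `x_dec`); callers update those buffers in place and
        call `graph.replay()`.
        """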
|
|
assert torch.cuda.is_available() |
|
|
s = torch.cuda.Stream() |
|
|
s.wait_stream(torch.cuda.current_stream()) |
|
|
|
|
|
graph = torch.cuda.CUDAGraph() |
|
|
|
|
|
with torch.cuda.stream(s): |
|
|
for _ in range(5): |
|
|
self.h.forward(input_pos, x, *args, **kwds) |
|
|
torch.cuda.current_stream().wait_stream(s) |
|
|
|
|
|
with torch.cuda.graph(graph): |
|
|
x_dec.copy_(self.h.forward(input_pos, x, *args, **kwds)) |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
return graph |
|
|
|
|
|
@abstractmethod |
|
|
def pre_forward(self, session: Any) -> Tuple[List, Dict]: ... |
|
|
|
|
|
@abstractmethod |
|
|
def post_forward(self, idx: int, session: Any) -> None: ... |
|
|
|
|
|
|
|
|
class CUDAGraphCacheABC(ABC): |
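    """Caches a captured decode-step CUDA graph together with its static input buffers."""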
|
|
def __init__( |
|
|
self, |
|
|
decoder: T2SDecoderABC, |
|
|
device: torch.device = torch.device("cpu"), |
|
|
dtype: torch.dtype = torch.float32, |
|
|
) -> None: |
|
|
assert torch.cuda.is_available() |
|
|
|
|
|
self.assigned: bool = False |
|
|
|
|
|
self.decoder: T2SDecoderABC = decoder |
|
|
self.kv_cache: MutableSequence[KVCacheABC] = decoder.init_cache(1) |
|
|
self.xy_pos = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype) |
|
|
self.xy_dec = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype) |
|
|
self.input_pos = torch.tensor([10]).int().cuda() |
|
|
self.graph: torch.cuda.CUDAGraph | None = None |
|
|
|
|
|
self.id: int = random.randint(1, 2**32 - 1) |
|
|
|
|
|
def assign_graph(self, session: Any): |
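        # Capture the decode-step graph lazily on first use; hand the cached graph to
        # this session, or re-capture when the cache is already assigned to another one.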
|
|
if self.graph is None: |
|
|
args, kwds = self.decoder.pre_forward(session) |
|
|
graph = self.decoder.capture( |
|
|
self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds |
|
|
) |
|
|
self.graph = graph |
|
|
|
|
|
if self.assigned is False: |
|
|
self.get_cache_graph(session) |
|
|
session.id = self.id |
|
|
self.assigned = True |
|
|
else: |
|
|
self.capture_new_graph(session) |
|
|
|
|
|
@abstractmethod |
|
|
def release_graph(self, session: Any): ... |
|
|
|
|
|
@abstractmethod |
|
|
    def get_cache_graph(self, session: Any): ...
|
|
|
|
|
@abstractmethod |
|
|
    def capture_new_graph(self, session: Any): ...
|
|
|
|
|
|
|
|
class TorchProfiler: |
|
|
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None: |
|
|
self.debug = debug |
|
|
self.log_dir = log_dir |
|
|
        self.__profiler: torch.profiler.profile | None = None
|
|
|
|
|
if self.debug and not os.path.exists(self.log_dir): |
|
|
os.makedirs(self.log_dir) |
|
|
|
|
|
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir) |
|
|
|
|
|
def profiler_callback(self, prof: torch.profiler.profile): |
|
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) |
|
|
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30)) |
|
|
self.tensorboard_handler(prof) |
|
|
|
|
|
@staticmethod |
|
|
def three_step_schedule(step: int) -> ProfilerAction: |
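        """Step 0: warm-up (NONE); step 1: RECORD; step 2: RECORD_AND_SAVE; later steps: NONE.

        `start()` and `end()` each advance the profiler by one step.
        """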
|
|
if step == 0: |
|
|
return ProfilerAction.NONE |
|
|
elif step == 1: |
|
|
return ProfilerAction.RECORD |
|
|
elif step == 2: |
|
|
return ProfilerAction.RECORD_AND_SAVE |
|
|
else: |
|
|
return ProfilerAction.NONE |
|
|
|
|
|
def start(self): |
|
|
if not self.debug: |
|
|
return |
|
|
assert self.__profiler is not None |
|
|
self.__profiler.step() |
|
|
|
|
|
def end(self): |
|
|
if not self.debug: |
|
|
return |
|
|
assert self.__profiler is not None |
|
|
self.__profiler.step() |
|
|
|
|
|
def profiler(self): |
|
|
if self.debug: |
|
|
activities_list = [torch.profiler.ProfilerActivity.CPU] |
|
|
if torch.cuda.is_available(): |
|
|
activities_list.append(torch.profiler.ProfilerActivity.CUDA) |
|
|
|
|
|
self.__profiler = torch.profiler.profile( |
|
|
activities=activities_list, |
|
|
record_shapes=True, |
|
|
with_stack=True, |
|
|
with_modules=True, |
|
|
profile_memory=True, |
|
|
schedule=self.three_step_schedule, |
|
|
on_trace_ready=self.profiler_callback, |
|
|
) |
|
|
return self.__profiler |
|
|
else: |
|
|
return nullcontext() |
|
|
|
|
|
def record(self, func_name: str): |
|
|
if self.debug: |
|
|
return torch.profiler.record_function(func_name) |
|
|
else: |
|
|
return nullcontext() |
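# Minimal usage sketch for TorchProfiler (illustrative only; the profiled work is a placeholder):
#
#   prof = TorchProfiler(debug=True, log_dir="./profiler")
#   with prof.profiler():
#       with prof.record("t2s_decode"):
#           ...              # warm-up work at schedule step 0
#       prof.start()         # advance to step 1 (RECORD)
#       ...                  # profiled work
#       prof.end()           # advance to step 2 (RECORD_AND_SAVE); trace goes to log_dir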
|
|
|