""" |
|
|
Model Architecture for Speaker Profiling |
|
|
Supports multiple encoders: WavLM, HuBERT, Wav2Vec2, Whisper, ECAPA-TDNN |
|
|
Architecture: Encoder + Attentive Pooling + LayerNorm + Classification Heads |
|
|
""" |
|
|
|
|
|
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    WavLMModel,
    HubertModel,
    Wav2Vec2Model,
    WhisperModel,
    AutoConfig
)

logger = logging.getLogger("speaker_profiling")

# SpeechBrain is an optional dependency; it is imported lazily on first use.
SPEECHBRAIN_AVAILABLE = None
EncoderClassifier = None


def _check_speechbrain():
    """Lazily check and import SpeechBrain"""
    global SPEECHBRAIN_AVAILABLE, EncoderClassifier
    if SPEECHBRAIN_AVAILABLE is None:
        try:
            from speechbrain.inference.speaker import EncoderClassifier as _EncoderClassifier
            EncoderClassifier = _EncoderClassifier
            SPEECHBRAIN_AVAILABLE = True
        except (ImportError, AttributeError) as e:
            SPEECHBRAIN_AVAILABLE = False
            logger.warning(f"SpeechBrain not available: {e}")
    return SPEECHBRAIN_AVAILABLE


class ECAPATDNNEncoder(nn.Module):
    """
    Wrapper for the SpeechBrain ECAPA-TDNN encoder.

    ECAPA-TDNN outputs fixed-size utterance embeddings (192 or 512 dim) instead of
    frame-level features like WavLM/HuBERT. This wrapper handles the difference.

    Supported models:
        - speechbrain/spkrec-ecapa-voxceleb: 192-dim embeddings
        - speechbrain/spkrec-xvect-voxceleb: 512-dim embeddings (x-vector)
    """

def __init__(self, model_name: str = "speechbrain/spkrec-ecapa-voxceleb"): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
if not _check_speechbrain(): |
|
|
raise ImportError( |
|
|
"SpeechBrain is required for ECAPA-TDNN. " |
|
|
"Install with: pip install speechbrain" |
|
|
) |
|
|
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
self.encoder = EncoderClassifier.from_hparams( |
|
|
source=model_name, |
|
|
savedir=f"pretrained_models/{model_name.split('/')[-1]}", |
|
|
run_opts={"device": device} |
|
|
) |
|
|
|
|
|
|
|
|
self.encoder.mods.float() |
|
|
|
|
|
|
|
|
if "ecapa" in model_name.lower(): |
|
|
self.embedding_size = 192 |
|
|
elif "xvect" in model_name.lower(): |
|
|
self.embedding_size = 512 |
|
|
else: |
|
|
self.embedding_size = 192 |
|
|
|
|
|
|
|
|
class Config: |
|
|
def __init__(self, hidden_size): |
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
self.config = Config(self.embedding_size) |
|
|
|
|
|
|
|
|
self._current_device = device |
|
|
|
|
|
    def forward(self, input_values: torch.Tensor, attention_mask: torch.Tensor = None):
        """
        Extract embeddings from audio.

        Args:
            input_values: Audio waveform [B, T]
            attention_mask: Not used for ECAPA-TDNN

        Returns:
            Object with a last_hidden_state attribute [B, 1, H]
        """
        device = input_values.device

        # Move the SpeechBrain encoder if the input lives on a different device.
        if str(device) != str(self._current_device):
            self.encoder.to(device)
            self.encoder.mods.float()
            self._current_device = device

        input_values = input_values.float().to(device)

        # The ECAPA-TDNN encoder always runs in inference mode, so its weights are
        # effectively frozen regardless of the freeze_encoder setting.
        with torch.no_grad():
            self.encoder.eval()
            embeddings = self.encoder.encode_batch(input_values)

        embeddings = embeddings.float()

        # Wrap the [B, 1, H] embedding so callers can read .last_hidden_state, as
        # with a HuggingFace encoder output.
        class Output:
            def __init__(self, hidden_state):
                self.last_hidden_state = hidden_state

        return Output(embeddings)
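

# Illustrative ECAPA wrapper usage (a sketch, not executed on import; shapes and
# values are examples only):
#
#   enc = ECAPATDNNEncoder()                 # downloads speechbrain/spkrec-ecapa-voxceleb
#   wav = torch.randn(3, 32000)              # [B, T] ~2 s of 16 kHz audio
#   enc(wav).last_hidden_state.shape         # torch.Size([3, 1, 192])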
ENCODER_REGISTRY = {
    # WavLM
    "microsoft/wavlm-base": {"class": WavLMModel, "hidden_size": 768},
    "microsoft/wavlm-base-plus": {"class": WavLMModel, "hidden_size": 768},
    "microsoft/wavlm-large": {"class": WavLMModel, "hidden_size": 1024},

    # HuBERT
    "facebook/hubert-base-ls960": {"class": HubertModel, "hidden_size": 768},
    "facebook/hubert-large-ls960-ft": {"class": HubertModel, "hidden_size": 1024},
    "facebook/hubert-xlarge-ls960-ft": {"class": HubertModel, "hidden_size": 1280},

    # Wav2Vec2
    "facebook/wav2vec2-base": {"class": Wav2Vec2Model, "hidden_size": 768},
    "facebook/wav2vec2-base-960h": {"class": Wav2Vec2Model, "hidden_size": 768},
    "facebook/wav2vec2-large": {"class": Wav2Vec2Model, "hidden_size": 1024},
    "facebook/wav2vec2-large-960h": {"class": Wav2Vec2Model, "hidden_size": 1024},
    "facebook/wav2vec2-xls-r-300m": {"class": Wav2Vec2Model, "hidden_size": 1024},

    # Vietnamese Wav2Vec2
    "nguyenvulebinh/wav2vec2-base-vi-vlsp2020": {"class": Wav2Vec2Model, "hidden_size": 768},

    # Whisper (encoder only)
    "openai/whisper-tiny": {"class": WhisperModel, "hidden_size": 384, "is_whisper": True},
    "openai/whisper-base": {"class": WhisperModel, "hidden_size": 512, "is_whisper": True},
    "openai/whisper-small": {"class": WhisperModel, "hidden_size": 768, "is_whisper": True},
    "openai/whisper-medium": {"class": WhisperModel, "hidden_size": 1024, "is_whisper": True},
    "openai/whisper-large": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
    "openai/whisper-large-v2": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},
    "openai/whisper-large-v3": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},

    # PhoWhisper (Vietnamese Whisper)
    "vinai/PhoWhisper-tiny": {"class": WhisperModel, "hidden_size": 384, "is_whisper": True},
    "vinai/PhoWhisper-base": {"class": WhisperModel, "hidden_size": 512, "is_whisper": True},
    "vinai/PhoWhisper-small": {"class": WhisperModel, "hidden_size": 768, "is_whisper": True},
    "vinai/PhoWhisper-medium": {"class": WhisperModel, "hidden_size": 1024, "is_whisper": True},
    "vinai/PhoWhisper-large": {"class": WhisperModel, "hidden_size": 1280, "is_whisper": True},

    # SpeechBrain speaker encoders
    "speechbrain/spkrec-ecapa-voxceleb": {
        "class": ECAPATDNNEncoder,
        "hidden_size": 192,
        "is_ecapa": True
    },
    "speechbrain/spkrec-xvect-voxceleb": {
        "class": ECAPATDNNEncoder,
        "hidden_size": 512,
        "is_ecapa": True
    },
}


def get_encoder_info(model_name: str) -> dict:
    """Get encoder class and hidden size for a model name"""
    if model_name in ENCODER_REGISTRY:
        return ENCODER_REGISTRY[model_name]

    # Unregistered SpeechBrain / ECAPA checkpoints.
    if 'ecapa' in model_name.lower() or 'speechbrain' in model_name.lower():
        hidden_size = 512 if 'xvect' in model_name.lower() else 192
        return {"class": ECAPATDNNEncoder, "hidden_size": hidden_size, "is_ecapa": True}

    # Otherwise auto-detect from the HuggingFace config and the model name.
    try:
        config = AutoConfig.from_pretrained(model_name)
        hidden_size = getattr(config, 'hidden_size', 768)

        if 'wavlm' in model_name.lower():
            return {"class": WavLMModel, "hidden_size": hidden_size}
        elif 'hubert' in model_name.lower():
            return {"class": HubertModel, "hidden_size": hidden_size}
        elif 'wav2vec2' in model_name.lower():
            return {"class": Wav2Vec2Model, "hidden_size": hidden_size}
        elif 'whisper' in model_name.lower() or 'phowhisper' in model_name.lower():
            return {"class": WhisperModel, "hidden_size": hidden_size, "is_whisper": True}
        else:
            # Default to Wav2Vec2 for unknown architectures.
            return {"class": Wav2Vec2Model, "hidden_size": hidden_size}
    except Exception as e:
        logger.warning(f"Could not auto-detect encoder for {model_name}: {e}")
        return {"class": WavLMModel, "hidden_size": 768}
class AttentivePooling(nn.Module):
    """
    Attention-based pooling for temporal aggregation.

    Takes a sequence of hidden states and produces a single vector by
    computing attention weights and taking the weighted sum over time.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: Hidden states [B, T, H]
            mask: Attention mask [B, T]

        Returns:
            pooled: Pooled representation [B, H]
            attn_weights: Attention weights [B, T]
        """
        attn_weights = self.attention(x)

        # Mask out padded frames before the softmax so they receive ~zero weight.
        if mask is not None:
            mask = mask.unsqueeze(-1)
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(attn_weights, dim=1)
        pooled = torch.sum(x * attn_weights, dim=1)

        return pooled, attn_weights.squeeze(-1)
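

# Shape sketch for AttentivePooling (illustrative values only):
#
#   x:    [B, T, H] = [4, 250, 768]   frame-level hidden states
#   mask: [B, T]    = [4, 250]        1 for real frames, 0 for padding
#
#   scores  = attention(x)                     -> [4, 250, 1]
#   masked  = scores.masked_fill(pad, -1e9)       padded frames get ~0 weight
#   weights = softmax(masked, dim=1)           -> [4, 250, 1], sums to 1 over T
#   pooled  = sum(x * weights, dim=1)          -> [4, 768]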
class MultiTaskSpeakerModel(nn.Module):
    """
    Multi-task model for gender and dialect classification.

    Architecture:
        Audio -> Encoder (WavLM/HuBERT/Wav2Vec2/Whisper/ECAPA-TDNN) -> Last Hidden [B,T,H]
                                        |
                     Attentive Pooling [B,H] (skipped for ECAPA-TDNN)
                                        |
                              Layer Normalization
                                        |
                                  Dropout(0.1)
                                        |
                        +---------------+---------------+
                        |                               |
            Gender Head (2 layers)          Dialect Head (3 layers)
                        |                               |
                      [B,2]                           [B,3]

    Supported encoders:
        - WavLM: microsoft/wavlm-base-plus, microsoft/wavlm-large
        - HuBERT: facebook/hubert-base-ls960, facebook/hubert-large-ls960-ft
        - Wav2Vec2: facebook/wav2vec2-base, facebook/wav2vec2-large-960h
        - Whisper: openai/whisper-base, openai/whisper-small, openai/whisper-medium
        - ECAPA-TDNN: speechbrain/spkrec-ecapa-voxceleb (192-dim embeddings)

    Args:
        model_name: Pretrained encoder model name or path
        num_genders: Number of gender classes (default: 2)
        num_dialects: Number of dialect classes (default: 3)
        dropout: Dropout probability (default: 0.1)
        head_hidden_dim: Hidden dimension for classification heads (default: 256)
        freeze_encoder: Whether to freeze encoder (default: False)
        dialect_loss_weight: Weight for dialect loss in multi-task learning (default: 3.0)
    """

    def __init__(
        self,
        model_name: str,
        num_genders: int = 2,
        num_dialects: int = 3,
        dropout: float = 0.1,
        head_hidden_dim: int = 256,
        freeze_encoder: bool = False,
        dialect_loss_weight: float = 3.0
    ):
        super().__init__()

        self.model_name = model_name
        self.dialect_loss_weight = dialect_loss_weight

        # Resolve the encoder class and embedding size from the registry
        # (or by auto-detection for unregistered checkpoints).
        encoder_info = get_encoder_info(model_name)
        encoder_class = encoder_info["class"]
        self.is_whisper = encoder_info.get("is_whisper", False)
        self.is_ecapa = encoder_info.get("is_ecapa", False)

        logger.info(f"Loading encoder: {model_name}")
        logger.info(f"Encoder class: {encoder_class.__name__}")

        if self.is_ecapa:
            # SpeechBrain models are loaded through the wrapper, not from_pretrained.
            self.encoder = encoder_class(model_name)
        else:
            self.encoder = encoder_class.from_pretrained(model_name)

        hidden_size = self.encoder.config.hidden_size
        self.hidden_size = hidden_size

        logger.info(f"Hidden size: {hidden_size}")

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            logger.info("Encoder weights frozen")

        # Shared pooling and normalization before the task heads.
        self.attentive_pooling = AttentivePooling(hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

        # Gender head: 2-layer MLP.
        self.gender_head = nn.Sequential(
            nn.Linear(hidden_size, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, num_genders)
        )

        # Dialect head: 3-layer MLP.
        self.dialect_head = nn.Sequential(
            nn.Linear(hidden_size, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, head_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim // 2, num_dialects)
        )

    def forward(
        self,
        input_values: torch.Tensor = None,
        input_features: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        gender_labels: torch.Tensor = None,
        dialect_labels: torch.Tensor = None
    ):
        """
        Forward pass - supports both raw audio and pre-extracted features.

        Args:
            input_values: Audio waveform [B, T] (for raw audio mode)
            input_features: Pre-extracted features [B, T, H] or [B, 1, H] for ECAPA
            attention_mask: Attention mask [B, T]
            gender_labels: Gender labels [B] (optional, for training)
            dialect_labels: Dialect labels [B] (optional, for training)

        Returns:
            dict with keys:
                - loss: Combined loss (if labels provided)
                - gender_logits: Gender predictions [B, num_genders]
                - dialect_logits: Dialect predictions [B, num_dialects]
                - attention_weights: Attention weights from pooling [B, T] (None for ECAPA)
        """
        if input_features is not None:
            # Pre-extracted features: skip the encoder entirely.
            hidden_states = input_features
        elif input_values is not None:
            # Raw audio: run the encoder.
            hidden_states = self._encode(input_values, attention_mask)
        else:
            raise ValueError("Either input_values or input_features must be provided")

        if self.is_ecapa or hidden_states.shape[1] == 1:
            # ECAPA-TDNN already returns a single utterance-level embedding [B, 1, H].
            pooled = hidden_states.squeeze(1)
            attn_weights = None
        else:
            # The encoder downsamples time, so the waveform-level attention mask may
            # not match the number of frames; fall back to an all-ones mask if so.
            if attention_mask is not None and hidden_states.shape[1] != attention_mask.shape[1]:
                batch_size, seq_len, _ = hidden_states.shape
                pooled_mask = torch.ones(batch_size, seq_len, device=hidden_states.device)
            else:
                pooled_mask = attention_mask

            pooled, attn_weights = self.attentive_pooling(hidden_states, pooled_mask)

        pooled = self.layer_norm(pooled)
        pooled = self.dropout(pooled)

        gender_logits = self.gender_head(pooled)
        dialect_logits = self.dialect_head(pooled)

        # Multi-task loss: cross-entropy per task, with the dialect term up-weighted.
        loss = None
        if gender_labels is not None and dialect_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            gender_loss = loss_fct(gender_logits, gender_labels)
            dialect_loss = loss_fct(dialect_logits, dialect_labels)
            loss = gender_loss + self.dialect_loss_weight * dialect_loss

        return {
            'loss': loss,
            'gender_logits': gender_logits,
            'dialect_logits': dialect_logits,
            'attention_weights': attn_weights
        }

    def _encode(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Extract hidden states from the encoder.

        Args:
            input_values: Audio waveform [B, T] (log-mel features for Whisper)
            attention_mask: Attention mask [B, T]

        Returns:
            hidden_states: Hidden states [B, T, H] or [B, 1, H] for ECAPA-TDNN
        """
        if self.is_ecapa:
            # ECAPA-TDNN wrapper: returns a single embedding per utterance.
            outputs = self.encoder(input_values, attention_mask)
            hidden_states = outputs.last_hidden_state
        elif self.is_whisper:
            # Whisper: use only the encoder stack; it expects log-mel input features.
            outputs = self.encoder.encoder(input_values)
            hidden_states = outputs.last_hidden_state
        else:
            # WavLM / HuBERT / Wav2Vec2: standard HuggingFace forward on raw audio.
            outputs = self.encoder(input_values, attention_mask=attention_mask)
            hidden_states = outputs.last_hidden_state

        return hidden_states

    def get_embeddings(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Extract speaker embeddings (pooled representations).

        Args:
            input_values: Audio waveform [B, T]
            attention_mask: Attention mask [B, T]

        Returns:
            embeddings: Speaker embeddings [B, H]
        """
        hidden_states = self._encode(input_values, attention_mask)

        if self.is_ecapa or hidden_states.shape[1] == 1:
            # ECAPA-TDNN already returns an utterance-level embedding.
            pooled = hidden_states.squeeze(1)
        else:
            # As in forward(), drop the mask if its length no longer matches the
            # downsampled frame sequence.
            if attention_mask is not None and hidden_states.shape[1] != attention_mask.shape[1]:
                attention_mask = None
            pooled, _ = self.attentive_pooling(hidden_states, attention_mask)

        pooled = self.layer_norm(pooled)
        return pooled
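

# Minimal usage sketch (not executed on import). The random waveform below only
# illustrates tensor shapes; real input should come from the project's feature
# extractor / dataloader:
#
#   model = MultiTaskSpeakerModel("microsoft/wavlm-base-plus")
#   wav = torch.randn(2, 16000)                    # [B, T] raw audio, ~1 s at 16 kHz
#   out = model(input_values=wav,
#               gender_labels=torch.tensor([0, 1]),
#               dialect_labels=torch.tensor([2, 0]))
#   out["loss"].backward()                         # gender CE + 3.0 * dialect CE
#   gender_pred = out["gender_logits"].argmax(dim=-1)
#   speaker_emb = model.get_embeddings(wav)        # [2, hidden_size]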
class MultiTaskSpeakerModelFromConfig(MultiTaskSpeakerModel):
    """
    Multi-task model initialized from an OmegaConf config.

    Supports multiple encoders: WavLM, HuBERT, Wav2Vec2, Whisper, ECAPA-TDNN.
    Use this for inference with raw audio input.

    Usage:
        config = OmegaConf.load('configs/finetune.yaml')
        model = MultiTaskSpeakerModelFromConfig(config)
    """

    def __init__(self, config):
        model_config = config['model']

        super().__init__(
            model_name=model_config['name'],
            num_genders=model_config.get('num_genders', 2),
            num_dialects=model_config.get('num_dialects', 3),
            dropout=model_config.get('dropout', 0.1),
            head_hidden_dim=model_config.get('head_hidden_dim', 256),
            freeze_encoder=model_config.get('freeze_encoder', False),
            dialect_loss_weight=config.get('loss', {}).get('dialect_weight', 3.0)
        )

        logger.info(f"Architecture: {model_config['name']} + Attentive Pooling + LayerNorm")
        logger.info(f"Hidden size: {self.hidden_size}")
        logger.info(f"Head hidden dim: {model_config.get('head_hidden_dim', 256)}")
        logger.info(f"Dropout: {model_config.get('dropout', 0.1)}")


class ClassificationHeadModel(nn.Module):
    """
    Lightweight model with only the classification heads (no encoder).

    Use this for training with pre-extracted features to save memory.
    hidden_size depends on the encoder: WavLM-base=768, WavLM-large=1024, etc.

    Usage:
        model = ClassificationHeadModel(hidden_size=768)
        output = model(input_features=features, gender_labels=y_gender, dialect_labels=y_dialect)
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_genders: int = 2,
        num_dialects: int = 3,
        dropout: float = 0.1,
        head_hidden_dim: int = 256,
        dialect_loss_weight: float = 3.0
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.dialect_loss_weight = dialect_loss_weight

        # Same pooling / normalization / heads as MultiTaskSpeakerModel, minus the encoder.
        self.attentive_pooling = AttentivePooling(hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

        self.gender_head = nn.Sequential(
            nn.Linear(hidden_size, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, num_genders)
        )

        self.dialect_head = nn.Sequential(
            nn.Linear(hidden_size, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, head_hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim // 2, num_dialects)
        )

        logger.info(f"ClassificationHeadModel initialized (hidden_size={hidden_size})")

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor = None,
        gender_labels: torch.Tensor = None,
        dialect_labels: torch.Tensor = None
    ):
        """
        Forward pass on pre-extracted features.

        Args:
            input_features: Pre-extracted encoder features [B, T, H]
            attention_mask: Frame-level attention mask [B, T] (must match T of input_features)
            gender_labels: Gender labels [B] (optional, for training)
            dialect_labels: Dialect labels [B] (optional, for training)

        Returns:
            dict with keys:
                - loss: Combined loss (if labels provided)
                - gender_logits: Gender predictions [B, num_genders]
                - dialect_logits: Dialect predictions [B, num_dialects]
                - attention_weights: Attention weights from pooling [B, T]
        """
        pooled, attn_weights = self.attentive_pooling(input_features, attention_mask)

        pooled = self.layer_norm(pooled)
        pooled = self.dropout(pooled)

        gender_logits = self.gender_head(pooled)
        dialect_logits = self.dialect_head(pooled)

        # Same multi-task loss as MultiTaskSpeakerModel.
        loss = None
        if gender_labels is not None and dialect_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            gender_loss = loss_fct(gender_logits, gender_labels)
            dialect_loss = loss_fct(dialect_logits, dialect_labels)
            loss = gender_loss + self.dialect_loss_weight * dialect_loss

        return {
            'loss': loss,
            'gender_logits': gender_logits,
            'dialect_logits': dialect_logits,
            'attention_weights': attn_weights
        }
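

# Minimal usage sketch with cached features (not executed on import). The random
# tensor stands in for pre-extracted WavLM-base features:
#
#   head = ClassificationHeadModel(hidden_size=768)
#   feats = torch.randn(4, 250, 768)               # [B, T, H]
#   out = head(input_features=feats,
#              gender_labels=torch.zeros(4, dtype=torch.long),
#              dialect_labels=torch.ones(4, dtype=torch.long))
#   out["gender_logits"].shape, out["dialect_logits"].shape   # [4, 2], [4, 3]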
class ClassificationHeadModelFromConfig(ClassificationHeadModel):
    """
    Lightweight classification model initialized from an OmegaConf config.

    Use this for training with pre-extracted features.
    """

    def __init__(self, config):
        model_config = config['model']

        super().__init__(
            hidden_size=model_config.get('hidden_size', 768),
            num_genders=model_config.get('num_genders', 2),
            num_dialects=model_config.get('num_dialects', 3),
            dropout=model_config.get('dropout', 0.1),
            head_hidden_dim=model_config.get('head_hidden_dim', 256),
            dialect_loss_weight=config.get('loss', {}).get('dialect_weight', 3.0)
        )

        logger.info("Architecture: Attentive Pooling + LayerNorm + Classification Heads")
        logger.info(f"Hidden size: {self.hidden_size}")
        logger.info(f"Head hidden dim: {model_config.get('head_hidden_dim', 256)}")
        logger.info(f"Dropout: {model_config.get('dropout', 0.1)}")