"""Policy wrapper for making voice models RL-compatible.""" import torch import torch.nn as nn import torch.nn.functional as F from typing import Tuple, Optional import logging logger = logging.getLogger(__name__) class PolicyValueHead(nn.Module): """ Policy and value head for RL training on voice models. Adds a policy head (for action log probabilities) and value head (for state value estimation) on top of a voice model's hidden states. """ def __init__( self, hidden_size: int, action_dim: int = 256, value_hidden_size: int = 128 ): """ Initialize policy and value heads. Args: hidden_size: Size of the base model's hidden states action_dim: Dimensionality of the action space value_hidden_size: Hidden size for value network """ super().__init__() # Policy head - outputs action logits self.policy_head = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size // 2, action_dim) ) # Value head - outputs state value estimate self.value_head = nn.Sequential( nn.Linear(hidden_size, value_hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(value_hidden_size, 1) ) logger.info(f"Initialized PolicyValueHead with hidden_size={hidden_size}, action_dim={action_dim}") def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through policy and value heads. Args: hidden_states: Hidden states from base model [batch, seq_len, hidden_size] Returns: Tuple of (action_logits, state_values) """ # Pool hidden states (mean pooling over sequence) pooled = hidden_states.mean(dim=1) # [batch, hidden_size] # Get action logits and values action_logits = self.policy_head(pooled) # [batch, action_dim] state_values = self.value_head(pooled) # [batch, 1] return action_logits, state_values class RLVoiceModel(nn.Module): """ RL-compatible wrapper for voice models. Wraps a HuggingFace voice model and adds policy/value heads for reinforcement learning training. """ def __init__( self, base_model: nn.Module, hidden_size: int, action_dim: int = 256, action_representation: str = "discrete" ): """ Initialize RL voice model wrapper. Args: base_model: Base voice model (e.g., wav2vec2) hidden_size: Hidden size of base model action_dim: Dimensionality of action space action_representation: "discrete" or "continuous" """ super().__init__() self.base_model = base_model self.hidden_size = hidden_size self.action_dim = action_dim self.action_representation = action_representation # Add policy and value heads self.policy_value_head = PolicyValueHead( hidden_size=hidden_size, action_dim=action_dim ) logger.info(f"Initialized RLVoiceModel with action_representation={action_representation}") def forward( self, input_features: torch.Tensor, return_hidden_states: bool = False, **kwargs ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Forward pass for RL training. 

        Args:
            input_features: Input audio features [batch, seq_len, features]
            return_hidden_states: Whether to return base model hidden states
            **kwargs: Additional arguments for base model

        Returns:
            Tuple of (log_probs, values, hidden_states)
        """
        # Get base model outputs
        base_outputs = self.base_model(input_features, **kwargs)

        # Extract hidden states
        if hasattr(base_outputs, 'last_hidden_state'):
            hidden_states = base_outputs.last_hidden_state
        elif isinstance(base_outputs, torch.Tensor):
            hidden_states = base_outputs
        else:
            hidden_states = base_outputs[0]

        # Get policy and value outputs
        action_logits, state_values = self.policy_value_head(hidden_states)

        # Compute log probabilities
        if self.action_representation == "discrete":
            log_probs = F.log_softmax(action_logits, dim=-1)
        else:
            # For continuous actions, return the logits directly
            log_probs = action_logits

        if return_hidden_states:
            return log_probs, state_values, hidden_states
        else:
            return log_probs, state_values, None

    def sample_action(
        self,
        input_features: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample actions from the policy.

        Args:
            input_features: Input audio features
            deterministic: If True, take the most likely action

        Returns:
            Tuple of (actions, log_probs, values)
        """
        log_probs, values, _ = self.forward(input_features)

        if self.action_representation == "discrete":
            if deterministic:
                actions = log_probs.argmax(dim=-1)
            else:
                # Sample from categorical distribution
                probs = torch.exp(log_probs)
                actions = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # Get log prob of selected actions
            action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        else:
            # For continuous actions, add noise for exploration
            if deterministic:
                actions = log_probs
            else:
                actions = log_probs + torch.randn_like(log_probs) * 0.1

            action_log_probs = -0.5 * ((actions - log_probs) ** 2).sum(dim=-1)

        return actions, action_log_probs, values

    def evaluate_actions(
        self,
        input_features: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate actions (for PPO training).

        Args:
            input_features: Input audio features
            actions: Actions to evaluate

        Returns:
            Tuple of (log_probs, values, entropy)
        """
        log_probs, values, _ = self.forward(input_features)

        if self.action_representation == "discrete":
            # Get log probs of given actions
            action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

            # Compute entropy
            probs = torch.exp(log_probs)
            entropy = -(probs * log_probs).sum(dim=-1).mean()
        else:
            # For continuous actions
            action_log_probs = -0.5 * ((actions - log_probs) ** 2).sum(dim=-1)

            # Entropy for continuous (Gaussian assumption)
            entropy = 0.5 * log_probs.shape[-1] * (1.0 + torch.log(torch.tensor(2.0 * 3.14159)))

        return action_log_probs, values.squeeze(-1), entropy

    def get_base_model(self) -> nn.Module:
        """Get the underlying base model."""
        return self.base_model

    def freeze_base_model(self) -> None:
        """Freeze base model parameters (only train policy/value heads)."""
        for param in self.base_model.parameters():
            param.requires_grad = False
        logger.info("Froze base model parameters")

    def unfreeze_base_model(self) -> None:
        """Unfreeze base model parameters."""
        for param in self.base_model.parameters():
            param.requires_grad = True
        logger.info("Unfroze base model parameters")
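

# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal clipped-PPO policy loss showing how the
# outputs of RLVoiceModel.evaluate_actions() are typically consumed during
# training. The function name and the clip_ratio / entropy_coef defaults are
# assumptions for this sketch, not part of the wrapper's required API.
# ---------------------------------------------------------------------------
def ppo_policy_loss(
    new_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    entropy: torch.Tensor,
    clip_ratio: float = 0.2,
    entropy_coef: float = 0.01,
) -> torch.Tensor:
    """Clipped PPO surrogate objective (to be minimized by the optimizer)."""
    # Probability ratio between the current policy and the policy that
    # collected the data
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Take the pessimistic (clipped) surrogate and add an entropy bonus
    surrogate = torch.min(ratio * advantages, clipped * advantages).mean()
    return -(surrogate + entropy_coef * entropy)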
""" def __init__( self, base_model: nn.Module, hidden_size: int, frame_size: int = 80, # e.g., 80-dim mel spectrogram max_seq_len: int = 1000 ): """ Initialize sequential voice policy. Args: base_model: Base model for processing context hidden_size: Hidden size frame_size: Size of each output frame max_seq_len: Maximum sequence length """ super().__init__() self.base_model = base_model self.hidden_size = hidden_size self.frame_size = frame_size self.max_seq_len = max_seq_len # Frame generation network self.frame_generator = nn.LSTM( input_size=hidden_size + frame_size, hidden_size=hidden_size, num_layers=2, batch_first=True ) # Output projection self.output_projection = nn.Linear(hidden_size, frame_size) # Value network self.value_net = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.ReLU(), nn.Linear(hidden_size // 2, 1) ) logger.info(f"Initialized SequentialVoicePolicy with frame_size={frame_size}") def forward( self, input_features: torch.Tensor, previous_frames: Optional[torch.Tensor] = None, num_frames: int = 10 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Generate sequence of frames. Args: input_features: Input conditioning features previous_frames: Previous generated frames (for autoregression) num_frames: Number of frames to generate Returns: Tuple of (generated_frames, log_probs, values) """ batch_size = input_features.shape[0] # Get context from base model base_outputs = self.base_model(input_features) if hasattr(base_outputs, 'last_hidden_state'): context = base_outputs.last_hidden_state.mean(dim=1) # [batch, hidden] else: context = base_outputs.mean(dim=1) if len(base_outputs.shape) > 2 else base_outputs # Initialize if previous_frames is None: current_frame = torch.zeros(batch_size, self.frame_size, device=input_features.device) else: current_frame = previous_frames[:, -1] hidden = None generated_frames = [] log_probs = [] # Generate frames autoregressively for t in range(num_frames): # Combine context and previous frame lstm_input = torch.cat([context, current_frame], dim=-1).unsqueeze(1) # LSTM step lstm_out, hidden = self.frame_generator(lstm_input, hidden) # Project to frame frame_logits = self.output_projection(lstm_out.squeeze(1)) # Sample frame (treat as continuous output) current_frame = torch.tanh(frame_logits) # Bound to [-1, 1] # Compute log prob (simplified) frame_log_prob = -0.5 * (frame_logits ** 2).sum(dim=-1) generated_frames.append(current_frame) log_probs.append(frame_log_prob) # Stack results generated_frames = torch.stack(generated_frames, dim=1) # [batch, num_frames, frame_size] log_probs = torch.stack(log_probs, dim=1) # [batch, num_frames] # Compute values values = self.value_net(context) # [batch, 1] return generated_frames, log_probs, values