# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin


def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
    """Sample `num_sample` frame indices at `target_fps` from a clip recorded at `original_fps`."""
    required_duration = num_sample / target_fps
    required_origin_frames = int(np.ceil(required_duration * original_fps))
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)
    start_time = start_frame / original_fps

    end_time = start_time + required_duration
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(time_points * original_fps).astype(int)
    frame_indices = np.clip(frame_indices, 0, total_frames - 1)
    return frame_indices


def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """
    Resample features along the time axis from `input_fps` to `output_fps`.

    features: [B, T, C] audio features
    input_fps: fps for audio, f_a
    output_fps: fps for video, f_m
    output_len: target length; computed from the fps ratio if None
    """
    features = features.transpose(1, 2)  # [B, C, T]
    seq_len = features.shape[2] / float(input_fps)  # T/f_a seconds of audio
    if output_len is None:
        output_len = int(seq_len * output_fps)  # f_m*T/f_a
    output_features = F.interpolate(
        features, size=output_len, align_corners=True,
        mode='linear')  # [B, C, output_len]
    return output_features.transpose(1, 2)  # [B, output_len, C]


class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
        super().__init__()
        # Load the pretrained wav2vec2 processor and model.
        self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path)
        self.model = self.model.to(device)
        self.video_rate = 30

    def extract_audio_feat(self, audio_path, return_all_layers=False, dtype=torch.float32):
        # Load audio resampled to 16 kHz, the rate wav2vec2 expects.
        audio_input, sample_rate = librosa.load(audio_path, sr=16000)
        input_values = self.processor(
            audio_input, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values

        # Forward pass; we keep the hidden states, not the CTC logits.
        res = self.model(
            input_values.to(self.model.device), output_hidden_states=True)
        if return_all_layers:
            # Concatenate all layers along the batch dim: [num_layers, T, C].
            feat = torch.cat(res.hidden_states)
        else:
            feat = res.hidden_states[-1]
        # Resample from the wav2vec2 feature rate (50 fps) to the video rate.
        feat = linear_interpolation(
            feat, input_fps=50, output_fps=self.video_rate)
        z = feat.to(dtype)  # encoding for the motion
        return z

    def extract_audio_feat_without_file_load(self, audio_input, sample_rate,
                                             return_all_layers=False, dtype=torch.float32):
        # Same as `extract_audio_feat`, but takes an in-memory waveform.
        input_values = self.processor(
            audio_input, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values

        # Forward pass; we keep the hidden states, not the CTC logits.
        res = self.model(
            input_values.to(self.model.device), output_hidden_states=True)
        if return_all_layers:
            feat = torch.cat(res.hidden_states)
        else:
            feat = res.hidden_states[-1]
        feat = linear_interpolation(
            feat, input_fps=50, output_fps=self.video_rate)
        z = feat.to(dtype)  # encoding for the motion
        return z
    def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        # A layer dimension > 1 means per-layer features were kept upstream.
        return_all_layers = num_layers > 1

        # Round the timeline up to a whole number of batches of `batch_frames`.
        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
        bucket_num = min_batch_num * batch_frames
        batch_idx = [stride * i for i in range(bucket_num)]
        batch_audio_eb = []
        for bi in batch_idx:
            if bi < audio_frame_num:
                # Gather a window of 2*m+1 frames centered on `bi`, spaced
                # `audio_sample_stride` apart (fixed at 2 here, independent of
                # the `stride` argument) and clamped to the valid range.
                audio_sample_stride = 2
                chosen_idx = list(
                    range(bi - m * audio_sample_stride,
                          bi + (m + 1) * audio_sample_stride,
                          audio_sample_stride))
                chosen_idx = [max(c, 0) for c in chosen_idx]
                chosen_idx = [min(c, audio_frame_num - 1) for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                        start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Past the end of the audio: pad the window with zeros.
                zero_shape = [audio_dim * (2 * m + 1)] if not return_all_layers \
                    else [num_layers, audio_dim * (2 * m + 1)]
                frame_audio_embed = torch.zeros(zero_shape, device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)

        return batch_audio_eb, min_batch_num

    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        return_all_layers = num_layers > 1

        # Ratio between the internal feature rate (video_rate) and the target fps.
        scale = self.video_rate / fps

        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
        bucket_num = min_batch_num * batch_frames
        # Zero-pad the audio timeline so the sampled indices span whole batches.
        pad_audio_num = math.ceil(
            min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
        batch_idx = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=audio_frame_num + pad_audio_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)
        batch_audio_eb = []
        audio_sample_stride = int(self.video_rate / fps)
        for bi in batch_idx:
            if bi < audio_frame_num:
                # Gather a window of 2*m+1 frames centered on `bi`, clamped
                # to the valid range.
                chosen_idx = list(
                    range(bi - m * audio_sample_stride,
                          bi + (m + 1) * audio_sample_stride,
                          audio_sample_stride))
                chosen_idx = [max(c, 0) for c in chosen_idx]
                chosen_idx = [min(c, audio_frame_num - 1) for c in chosen_idx]

                if return_all_layers:
                    frame_audio_embed = audio_embed[:, chosen_idx].flatten(
                        start_dim=-2, end_dim=-1)
                else:
                    frame_audio_embed = audio_embed[0][chosen_idx].flatten()
            else:
                # Past the end of the audio: pad the window with zeros.
                zero_shape = [audio_dim * (2 * m + 1)] if not return_all_layers \
                    else [num_layers, audio_dim * (2 * m + 1)]
                frame_audio_embed = torch.zeros(zero_shape, device=audio_embed.device)
            batch_audio_eb.append(frame_audio_embed)
        batch_audio_eb = torch.stack(batch_audio_eb, dim=0)

        return batch_audio_eb, min_batch_num
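
# A minimal usage sketch, not part of the upstream module: it assumes a local
# "audio.wav" file (hypothetical path) and network access to download
# "facebook/wav2vec2-base-960h". The fps/batch_frames/m values simply mirror
# the defaults of `get_audio_embed_bucket_fps`.
if __name__ == "__main__":
    encoder = WanAudioEncoder(device="cpu")
    with torch.no_grad():
        # [num_layers, T, C] per-layer features, resampled to encoder.video_rate.
        feat = encoder.extract_audio_feat("audio.wav", return_all_layers=True)
    bucket, num_batches = encoder.get_audio_embed_bucket_fps(
        feat, fps=16, batch_frames=81, m=0)
    # bucket: [num_batches * 81, num_layers, C * (2*m + 1)] windowed embeddings,
    # one row per sampled video frame, zero-padded past the end of the audio.
    print(bucket.shape, num_batches)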