import sys

import matplotlib.pyplot as plt
import torch
import torchaudio
import torchinfo
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

# Usage: python <script> <hf_model_path_or_id> <audio_file>
hf_path = sys.argv[1]
audio_file = sys.argv[2]

# Load the feature extractor (input normalization/padding) and the
# wav2vec 2.0 encoder from the given Hugging Face path.
extractor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)
hf_model = Wav2Vec2Model.from_pretrained(hf_path)

# Layer-by-layer summary of the model.
torchinfo.summary(hf_model)

# Inference mode: disables dropout.
hf_model.eval()
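# The checkpoint's config records the architecture; num_hidden_layers and
# hidden_size are standard Wav2Vec2Config fields.
print("Transformer layers:", hf_model.config.num_hidden_layers)
print("Hidden size:", hf_model.config.hidden_size)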
# Load the audio and downmix multi-channel input to mono.
waveform, sample_rate = torchaudio.load(audio_file)
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

# The published wav2vec 2.0 checkpoints expect 16 kHz input, and the feature
# extractor raises on a sampling-rate mismatch, so resample if needed.
if sample_rate != extractor.sampling_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, extractor.sampling_rate)
    sample_rate = extractor.sampling_rate
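# Rough sanity check (sketch): the wav2vec 2.0 convolutional front-end has a
# total stride of 320 samples (one frame every ~20 ms at 16 kHz), so the model
# should emit close to num_samples // 320 frames for this clip.
print("Approx. expected frames:", waveform.shape[1] // 320)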
with torch.no_grad():
    # Normalize the raw waveform and wrap it as model-ready tensors.
    feat = extractor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt", padding=True)
    out = hf_model(feat.input_values, output_hidden_states=True)

last_hidden_states = out.last_hidden_state  # final transformer layer: (batch, frames, hidden)
cnn_out = out.extract_features              # CNN encoder output: (batch, frames, conv_dim)
hidden_states = out.hidden_states           # tuple of per-layer outputs, embeddings first

print("CNN features shape:", cnn_out.shape)
print("Hidden states length:", len(hidden_states))
print("Last hidden states shape:", last_hidden_states.shape)
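# Layer-wise view (sketch): the tuple holds num_hidden_layers + 1 entries, the
# first being the pre-transformer feature projection, so it can be stacked
# into a single tensor for per-layer analysis.
stacked = torch.stack(hidden_states)
print("Stacked hidden states shape:", stacked.shape)  # (layers + 1, batch, frames, hidden)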
# Rescale both feature maps to [0, 1], then apply a log curve so low-magnitude
# activations remain visible in the plots. With a single input file the batch
# dimension is 1, so normalizing the whole tensor is equivalent to normalizing
# each batch element separately.
last_hidden_states = (last_hidden_states - last_hidden_states.min()) / (last_hidden_states.max() - last_hidden_states.min())
cnn_out = (cnn_out - cnn_out.min()) / (cnn_out.max() - cnn_out.min())

last_hidden_states = torch.log1p(last_hidden_states * 100)
cnn_out = torch.log1p(cnn_out * 100)
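# The log curve expands the bottom of the range: after min-max scaling,
# x = 0.01 maps to log1p(1.0) ≈ 0.69 while x = 1.0 maps to log1p(100) ≈ 4.62,
# so low-energy regions get a larger share of the color map.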
# Plot both feature maps as (feature dimension x time) images.
plt.figure(figsize=(12, 6))

plt.subplot(2, 1, 1)
plt.title("Last Hidden States")
plt.imshow(last_hidden_states[0].cpu().numpy().T, aspect='auto', origin='lower')
plt.colorbar()

plt.subplot(2, 1, 2)
plt.title("CNN Features")
plt.imshow(cnn_out[0].cpu().numpy().T, aspect='auto', origin='lower')
plt.colorbar()

plt.tight_layout()
plt.savefig("hf_wav2vec2_features.png")
plt.close()
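# Optional extension (sketch): any intermediate layer has the same
# (batch, frames, hidden) layout as last_hidden_state, so a mid-depth layer
# can be visualized the same way (shown here unscaled; output filename is
# illustrative).
mid = hidden_states[len(hidden_states) // 2]
plt.figure(figsize=(12, 3))
plt.title("Mid-depth Transformer Layer")
plt.imshow(mid[0].cpu().numpy().T, aspect='auto', origin='lower')
plt.colorbar()
plt.tight_layout()
plt.savefig("hf_wav2vec2_mid_layer.png")
plt.close()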