# go / eval.py
# Uploaded by jva96160 ("Upload 32 files", commit 4c1ba5a, verified)
from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
# Text normalizers applied to references and predictions before scoring:
# Whisper's English normalizer for the "en_us" language code, and the
# language-agnostic basic normalizer for any other code (evaluate_task falls
# back to the "other" entry when a code is not a key here).
normalizer = {
    "en_us" : EnglishTextNormalizer(),
    "other" : BasicTextNormalizer()
}
# Local checkpoint of the model under evaluation.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
model_id = "/home/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
revision = "main" #"v1.0"
# Model and processor are loaded at import time, before any evaluation runs.
# trust_remote_code=True lets the checkpoint's own modeling/processing code run.
model = AutoModel.from_pretrained(
    model_id, device_map="cuda", revision = revision, trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(
    model_id, revision = revision, trust_remote_code=True
)
# One timestamped output directory per run; save_results writes its JSON files here.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)
# Prompt templates per task. "ast" (speech translation) receives the target
# language name through str.format; "asr" has no placeholder.
INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}
class BaseAudioDataset(Dataset):
    """Common base for the speech evaluation datasets in this script.

    Subclasses populate ``self.data`` (a Hugging Face dataset, or a list of
    example dicts) and turn each example into model inputs through
    :meth:`prepare_model_inputs`.
    """

    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        """
        Args:
            processor: multimodal processor; its ``tokenizer`` renders the chat
                prompt and calling it encodes text together with audio.
            split: dataset split name; any name containing "train" marks the
                dataset as a training split (``self.training``).
            sampling_rate: target audio sampling rate in Hz.
            debug: when True, subclasses may print extra loading information.
        """
        self.processor = processor
        self.training = "train" in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        """Set the name used in result file names and summaries."""
        self.name = name

    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
        """Drop examples whose audio cannot be read or whose text is empty.

        Audio decoding is switched off while filtering so each file is opened
        directly with ``soundfile``; decoding (resampled to ``sampling_rate``)
        is switched back on for the returned dataset. A text field made up only
        of double quotes counts as empty.

        ``dataset_name`` and ``debug`` are accepted for API compatibility and
        are currently unused.
        """
        data = data.cast_column(audio_field, Audio(decode=False))

        def is_valid(example):
            try:
                sf.read(example[audio_field]["path"])
                return all(example[field].replace('"', '') != "" for field in text_fields)
            except Exception:
                # Best-effort filter: unreadable audio or a missing/non-string
                # text field simply excludes the example.
                return False

        data = data.filter(is_valid, num_proc=16)
        return data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
        """Keep only examples whose duration lies in ``[min_sec, max_sec]`` seconds.

        Duration is ``num_frames / sampling_rate``. For multi-channel arrays the
        frame count is the length of the longest axis, so ``(frames, channels)``,
        ``(channels, frames)`` and ``(1, frames)`` layouts all yield the true
        duration. Examples whose audio cannot be inspected are dropped.
        ``debug`` is accepted for API compatibility and is currently unused.
        """
        def within_bounds(example):
            try:
                audio = example[audio_field]
                samples = np.asarray(audio["array"])
                # NOTE: assumes the time axis is the longest one, which holds for
                # any realistic clip (far more frames than channels).
                num_frames = max(samples.shape) if samples.ndim > 0 else 0
                duration = num_frames / audio["sampling_rate"]
                return min_sec <= duration <= max_sec
            except Exception:
                return False

        return data.filter(within_bounds, num_proc=16)

    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        """Encode one audio clip plus its instruction into model inputs.

        The user turn is ``'<start_of_audio>' + instruction``; it is rendered
        with the tokenizer's chat template (generation prompt and BOS added) and
        encoded together with ``audio_array`` by the processor.

        Returns:
            dict with ``input_ids``, ``token_type_ids``, ``input_audio_embeds``,
            ``audio_embed_sizes`` and ``input_modes`` from the processor output
            (presumably each with a leading batch dimension of 1 — the collate
            function indexes ``input_ids[0]``), plus the untouched ``answer``.
        """
        user_message = {
            'role': 'user',
            'content': '<start_of_audio>' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )
        inputs = self.processor(
            text=prompt,
            audio=[audio_array],
            add_special_tokens=False,
            return_tensors='pt'
        )
        input_ids = inputs.input_ids
        token_type_ids = inputs.token_type_ids
        return {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
            'answer': answer_text,
        }
# Libri Speech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
    """LibriSpeech data for English ASR; the reference is the ``text`` column."""

    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        """Load ``openslr/librispeech_asr`` for ``subset`` (e.g. "clean" or "other")
        and keep only clips within ``filter_by_audio_length``'s default bounds."""
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name(f"LibriSpeech_{subset}")

        # LibriSpeech is only used for ASR.
        self.ast = False
        self.lang = "en"

        cache_root = Path("/home/jeff/codes/llm/InCar/data")
        raw = load_dataset(
            "openslr/librispeech_asr",
            subset,
            split=split,
            trust_remote_code=True,
            cache_dir=cache_root,
        )
        # Optional length filtering.
        self.data = self.filter_by_audio_length(raw, "audio")

        self.instruction = INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        # Double quotes are stripped from the reference transcript.
        reference = example["text"].replace('"', '')
        audio = example["audio"]["array"]
        return self.prepare_model_inputs(audio, INSTRUCTION["asr"], reference)
# common_voice_16_1 dataset
class CommonVoiceDataset(BaseAudioDataset):
    """Common Voice 16.1 data for ASR in ``source_lang``; the reference is the
    (cleaned) ``sentence`` column."""

    def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name(f"CommonVoice_{source_lang}")
        # only ASR
        self.ast = False
        self.lang = source_lang
        # load dataset
        self.data = load_dataset("mozilla-foundation/common_voice_16_1",
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True,
                                 cache_dir=Path("/home/jeff/codes/llm/InCar/data")
                                 )

        # Sentence-final punctuation, including full-width CJK marks, so that
        # e.g. Chinese sentences ending in "。" do not get an extra ".".
        terminal_punctuation = (".", "?", "!", "。", "？", "！")

        def normalize_sentence(sentence):
            """Strip wrapping quotes and make sure the sentence ends in punctuation."""
            if sentence.startswith('"') and sentence.endswith('"'):
                # Surrounding quotation marks do not affect the transcription.
                sentence = sentence[1:-1]
            # Guard against empty sentences (e.g. a bare '""').
            if sentence and not sentence.endswith(terminal_punctuation):
                # Append a full stop to sentences that do not end in punctuation.
                sentence = sentence + "."
            return {"sentence": sentence}

        # Dataset.map returns a NEW dataset, so the result must be reassigned.
        # Passing only the "sentence" column avoids decoding/re-encoding audio.
        self.data = self.data.map(
            normalize_sentence,
            input_columns="sentence",
            desc="preprocess dataset",
        )
        # (Optional) Audio length Filtering
        self.data = self.filter_by_audio_length(self.data, "audio")
        # Instruction Setting
        self.instruction = INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            INSTRUCTION["asr"],
            data["sentence"],
        )
# Fleurs Dataset Class
class FleursDataset(BaseAudioDataset):
    """FLEURS data for ASR (speech -> same-language text) or AST
    (``source_lang`` speech -> ``target_lang`` text).

    In AST mode, source and target utterances are paired by FLEURS ``id`` and
    the target transcription is stored as the example's ``translation``.
    When the reference language is Mandarin (config ``cmn_hans_cn``), the
    reference text is converted from Simplified to Traditional Chinese with
    OpenCC (``s2tw``).
    """

    # FLEURS config name for Simplified Mandarin.
    _MANDARIN = "cmn_hans_cn"
    _CACHE_DIR = Path("/home/jeff/codes/llm/InCar/data")

    def __init__(self, processor, split, source_lang, target_lang=None,
                 mode="asr", sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name("Fleurs")
        # Mode Setting (ASR or AST)
        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")
        self.mode = mode
        self.ast = (mode == "ast")
        # Validate before downloading anything so a bad call fails fast.
        if self.ast and target_lang is None:
            raise ValueError("AST mode requires target_lang.")
        self.source_lang = source_lang
        # Language name mapping (expand if needed). FLEURS names its Mandarin
        # config "cmn_hans_cn"; the "cmn_hans" key is kept for compatibility.
        self.lang_names = {
            'en_us': 'English',
            'cmn_hans': 'Mandarin Chinese',
            'cmn_hans_cn': 'Mandarin Chinese',
        }

        # Mandarin reference text (ASR source or AST target) -> Traditional Chinese.
        reference_lang = target_lang if self.ast else source_lang
        converter = None
        if reference_lang == self._MANDARIN:
            import opencc
            converter = opencc.OpenCC('s2tw.json')

        # load dataset - source language dataset
        self.data = load_dataset("google/fleurs",
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True,
                                 cache_dir=self._CACHE_DIR
                                 )
        if converter is not None and not self.ast:
            # Dataset.map returns a new dataset, so reassign. Only the text
            # column is passed in, so audio is not decoded and re-encoded.
            self.data = self.data.map(
                lambda text: {"transcription": converter.convert(text)},
                input_columns="transcription",
                desc="preprocess dataset",
            )
        # (Optional) Audio length Filtering
        self.data = self.filter_by_audio_length(self.data, "audio")
        self.target_lang_name = ""

        if self.ast:
            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"
            # load dataset - target language dataset (for translation)
            target_data = load_dataset("google/fleurs",
                                       target_lang,
                                       split=split,
                                       trust_remote_code=True,
                                       cache_dir=self._CACHE_DIR
                                       )
            # Read only the text columns, so target audio is never decoded.
            # Later duplicates of an id win, matching the source mapping below.
            target_text = dict(zip(target_data["id"], target_data["transcription"]))
            source_by_id = {item['id']: item for item in self.data}
            # Only ids present in both languages; sorted for a deterministic order.
            common_ids = sorted(source_by_id.keys() & target_text.keys())
            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
            paired = []
            for utt_id in common_ids:
                translation = target_text[utt_id]
                if converter is not None:
                    # The translation column only exists from here on, so the
                    # Mandarin conversion is applied while pairing.
                    translation = converter.convert(translation)
                paired.append({**source_by_id[utt_id], 'translation': translation})
            self.data = paired
            # Instruction Setting - use target language name
            self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = INSTRUCTION["ast"]
        else:
            # ASR mode
            self.lang = source_lang
            self.instruction = INSTRUCTION["asr"]

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        answer_text = data["translation"] if self.ast else data["transcription"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction.format(self.target_lang_name),
            answer_text
        )
def pad_sequence(sequences, padding_side='left', padding_value=0):
    """
    Pad a list of sequences to the same length and stack them.

    Args:
        sequences: non-empty list of tensors shaped ``[seq_len, *]``; trailing
            dimensions must match those of ``sequences[0]``.
        padding_side: 'left' places padding before each sequence, 'right'
            places it after.
        padding_value: fill value for padded positions.

    Returns:
        Tensor of shape ``[len(sequences), max_seq_len, *]`` with the dtype and
        device of ``sequences[0]``.

    Raises:
        ValueError: if ``padding_side`` is invalid or ``sequences`` is empty.
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    if not sequences:
        raise ValueError("pad_sequence() requires at least one sequence")
    trailing_dims = tuple(sequences[0].shape[1:])
    max_len = max(seq.size(0) for seq in sequences)
    output = sequences[0].new_full((len(sequences), max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            # Start at max_len - length rather than -length: for an empty
            # sequence, -0 == 0 would select the whole row.
            output[i, max_len - length:] = seq
    return output
def cat_with_pad(tensors, dim, padding_value=0):
    """
    Concatenate ``tensors`` along ``dim``, padding all other dimensions up to
    their per-dimension maximum.

    Each input occupies the leading corner of its slot along ``dim``; every
    other position holds ``padding_value``. The result takes its dtype and
    device from ``tensors[0]``.

    Raises:
        ValueError: if ``tensors`` is empty or the inputs differ in rank.
    """
    if not tensors:
        raise ValueError("cat_with_pad() requires at least one tensor")
    ndim = tensors[0].dim()
    if any(t.dim() != ndim for t in tensors[1:]):
        raise ValueError('All tensors must have the same number of dimensions')
    out_size = [max(t.shape[d] for t in tensors) for d in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)
    offset = 0
    for t in tensors:
        # Full extent of t in every dimension except the concat dimension.
        region = [slice(0, size) for size in t.shape]
        region[dim] = slice(offset, offset + t.shape[dim])
        # Index with an explicit tuple — the documented multi-dimensional form;
        # a list of slices relies on legacy sequence-as-tuple handling.
        output[tuple(region)] = t
        offset += t.shape[dim]
    return output
def covost_collate_fn(batch):
    """Collate outputs of ``prepare_model_inputs`` into a single model batch.

    - ``input_ids`` are left-padded with 0, and ``attention_mask`` marks the
      non-zero positions (token id 0 is therefore treated as padding).
    - ``input_audio_embeds`` are concatenated along dim 0 and padded in the
      other dims; ``audio_attention_mask`` marks each example's real frames
      (length taken from dim 1) and is ``None`` for a single-example batch.
    - ``answer`` keeps the reference strings in batch order.
    """
    input_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    answer_list = []
    for inputs in batch:
        audio_embeds = inputs['input_audio_embeds']
        input_ids_list.append(inputs['input_ids'][0])
        input_audio_embeds_list.append(audio_embeds)
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            audio_embeds.new_full((audio_embeds.size(1),), True, dtype=torch.bool)
        )
        input_modes_list.append(inputs['input_modes'])
        answer_list.append(inputs['answer'])
    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        # Dump the raw inputs for debugging, then re-raise. Print the mask
        # *list*: the padded mask is still unbound if the input_ids padding
        # failed, and touching it would mask the real error.
        print(e)
        print(input_ids_list)
        print(audio_attention_mask_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)
    return BatchFeature(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
            'answer': answer_list,
        }
    )
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None, output_dir=None):
    """Write ``results`` as pretty-printed UTF-8 JSON and return the file path.

    The file is named
    ``{task}_{dataset_name}_{source_lang}[_to_{target_lang}][_sample_{sample_idx}].json``.
    A ``timestamp`` key is added to ``results`` in place before writing, and an
    existing file with the same name is overwritten.

    Args:
        output_dir: directory to write into; defaults to the module-level
            ``results_dir`` created for this run.
    """
    if output_dir is None:
        output_dir = results_dir
    parts = [f"{task}_{dataset_name}_{source_lang}"]
    if target_lang:
        parts.append(f"_to_{target_lang}")
    if sample_idx is not None:
        parts.append(f"_sample_{sample_idx}")
    filepath = os.path.join(output_dir, "".join(parts) + ".json")
    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return filepath
def _utterance_metrics(reference, prediction, text_normalizer):
    """Score one utterance after normalizing both strings.

    Returns ``{"bleu", "cer", "wer"}``: sentence BLEU, plus CER/WER as
    percentages rounded to 2 decimals (CER ignores all whitespace). If scoring
    raises (e.g. the normalized reference is empty), worst-case scores and an
    ``error`` message are returned instead, so one bad sample cannot abort a run.
    """
    try:
        ref = text_normalizer(reference)
        pred = text_normalizer(prediction)
        return {
            "bleu": sacrebleu.sentence_bleu(pred, [ref]).score,
            "cer": round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2),
            "wer": round(wer(ref, pred) * 100, 2),
        }
    except Exception as e:
        return {"bleu": 0, "cer": 100, "wer": 100, "error": str(e)}


def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
    """Generate predictions for ``dataset`` and score them against the references.

    Args:
        dataset: a BaseAudioDataset subclass (uses ``name`` and items produced
            by ``prepare_model_inputs``).
        source_lang, target_lang: language codes. The scoring normalizer is
            looked up with ``source_lang`` for ASR and ``target_lang`` otherwise.
        num_samples: when in ``(0, len(dataset))``, evaluate a random subset of
            that many examples; otherwise evaluate the whole dataset.
        batch_size: DataLoader batch size.
        is_asr: selects the task label ("asr" or "translation").

    Predictions are converted to Traditional Chinese with OpenCC (``s2tw``)
    before scoring. Partial results are written every 10 batches, and the
    final results (with averaged metrics) overwrite the same file.

    Returns:
        dict with dataset/task/language info, averaged ``metrics`` (values are
        ``None`` when there are no samples) and per-sample ``sample_results``.
    """
    import opencc
    from torch.utils.data import Subset

    converter = opencc.OpenCC('s2tw.json')
    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer.get(eval_lang, normalizer['other'])

    # Capture the name first: a Subset does not carry the dataset's attributes.
    dataset_name = dataset.name
    if 0 < num_samples < len(dataset):
        # torch Datasets have no .select(); sample positions and wrap in a Subset.
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        dataset = Subset(dataset, indices.tolist())

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
    device = "cuda" if torch.cuda.is_available() else None

    sample_results = []
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_references = batch.pop("answer")
        if device is not None:
            # Move tensors only: audio_attention_mask is None for a
            # single-example batch (e.g. the last batch of an odd-sized set).
            batch = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
        with torch.inference_mode():
            generate_ids = model.generate(**batch,
                                          max_new_tokens=256,
                                          #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
                                          )
        # Prompts are left-padded, so generated tokens start at the same column.
        prompt_len = batch['input_ids'].shape[1]
        batch_predictions = processor.batch_decode(
            generate_ids[:, prompt_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            sample_id = batch_idx * batch_size + i
            sample_result = {
                "id": sample_id,
                "reference": reference,
                "prediction": converter.convert(prediction),
            }
            metrics = _utterance_metrics(sample_result["reference"], sample_result["prediction"], eval_normalizer)
            if "error" in metrics:
                print(f"Error evaluating sample {sample_id}: {metrics['error']}")
            sample_result.update(metrics)
            sample_results.append(sample_result)

        if (batch_idx + 1) % 10 == 0:
            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(sample_results),
                "sample_results": [dict(item) for item in sample_results],
            }
            save_results(partial_results, dataset_name, task_type, source_lang, target_lang)

    count = len(sample_results)
    averages = {
        key: (sum(item[key] for item in sample_results) / count if count else None)
        for key in ("bleu", "cer", "wer")
    }
    results = {
        "dataset": dataset_name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": count,
        "metrics": averages,
        "sample_results": sample_results,
    }
    save_results(results, dataset_name, task_type, source_lang, target_lang)
    return results
if __name__ == "__main__":
    source_languages = [
        ("en_us", "English"),
    ]
    # Target languages are only meaningful for AST runs; ASR below passes None.
    target_languages = [
        ("zh-TW", "zh-TW"),
    ]
    num_samples = -1
    batch_size = 2
    for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== {source_lang[0]} ASR =====")
        split = "test"
        # (dataset, language code of its speech/reference text). The code picks
        # the scoring normalizer and names the result file, so it must match
        # the dataset rather than the outer loop's source language.
        datasets = []
        commonvoice_speech_tw = CommonVoiceDataset(
            processor=processor,
            source_lang="zh-TW",
            split=split
        )
        datasets.append((commonvoice_speech_tw, "zh-TW"))
        # Libri Speech Clean ASR mode (English -> English text)
        # libri_speech_clean = LibriSpeechDataset(
        #     processor=processor,
        #     subset="clean",
        #     split=split
        # )
        # datasets.append((libri_speech_clean, "en_us"))
        # # Libri Speech Other ASR mode (English -> English text)
        # libri_speech_other = LibriSpeechDataset(
        #     processor=processor,
        #     subset="other",
        #     split=split
        # )
        # datasets.append((libri_speech_other, "en_us"))
        # Fleurs ASR mode (English -> English text)
        fleurs = FleursDataset(
            processor=processor,
            split=split,
            source_lang="en_us",  # English
            mode="asr"
        )
        datasets.append((fleurs, "en_us"))
        for dataset, dataset_lang in datasets:
            # ASR has no target language.
            asr_results = evaluate_task(dataset, dataset_lang, None, num_samples, batch_size=batch_size, is_asr=True)
            metrics = asr_results.get('metrics', {})
            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {dataset_lang} ASR===")
            print(f"BLEU: {metrics.get('bleu', 'N/A')}")
            print(f"WER: {metrics.get('wer', 'N/A')}")
            print(f"CER: {metrics.get('cer', 'N/A')}")