from io import BytesIO
from urllib.request import urlopen
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Audio
from transformers import AutoModel, AutoProcessor, BatchFeature
from tqdm import tqdm

import sacrebleu
from jiwer import cer, wer
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import opencc

from ASRDataset import *
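# ASRDataset is the project-local dataset module; it is assumed here to provide
# the MultiturnAudioDataset class and the covost_collate_fn collate function used below.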

# Convert model output from Simplified to Traditional Chinese (Taiwan standard)
# before scoring.
converter = opencc.OpenCC('s2tw.json')

# Text normalizers applied to references and predictions before metric computation.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "other": BasicTextNormalizer(),
}
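
# Load the model and processor from a local checkpoint; trust_remote_code is
# enabled so any custom modeling code bundled with the checkpoint can be loaded.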
model_id = "/mnt/jeff/gemma_test"
revision = "main"

model = AutoModel.from_pretrained(
    model_id, device_map="cuda", revision=revision, trust_remote_code=True
).eval()

processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)
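
# Each run writes its JSON outputs into a fresh timestamped directory.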
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)
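
# Prompt templates for the supported tasks: speech translation (ast) and
# speech recognition (asr).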
INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}
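

# save_results: serialize a results dict to JSON, naming the file after the
# task, dataset, and language pair, and stamping it with the current time.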
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"

    filepath = os.path.join(results_dir, f"{filename}.json")

    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return filepath
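

# evaluate_task: run generation over a dataset, score each prediction with
# BLEU / WER / CER, and periodically flush partial results to disk.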
def evaluate_task(dataset, task_type="asr", source_lang="other", target_lang=None):
    # task_type, source_lang, and target_lang were previously read from globals
    # that were never defined; they are now parameters with assumed defaults.
    # The normalizer is picked from the source language, falling back to the
    # basic normalizer for non-English data.
    eval_normalizer = normalizer.get(source_lang, normalizer["other"])

    sample_results = []

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)

    evaluated_samples = {}

    for batch_idx, batch in enumerate(tqdm(dataloader)):

        if torch.cuda.is_available():
            try:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            except Exception as e:
                print(f"Failed to move batch {batch_idx} to GPU: {e}")
                break

        with torch.inference_mode():
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
            )
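
        # The generated ids include the prompt, so drop the prompt tokens and
        # decode only the newly generated continuation.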
        input_lengths = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_lengths:]

        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # References are decoded from the label ids supplied by the collate_fn
        # (assumed to be provided under "labels"); re-slicing generate_ids by the
        # prompt length here, as the original code did, yields empty references.
        label_ids = batch['labels']
        batch_references = processor.batch_decode(
            label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            # batch_size is 1, so the batch index doubles as the sample index
            idx = batch_idx + i
            sample_result = {
                "id": idx,
                "reference": reference,
                "prediction": converter.convert(prediction),
            }
            sample_results.append(sample_result)
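
        # Every 10 batches, score everything collected so far and write a
        # partial results file so long runs can be monitored while in progress.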
        if (batch_idx + 1) % 10 == 0:
            temp_results = []

            for item in sample_results:
                sample_id = item["id"]

                if sample_id in evaluated_samples:
                    temp_item = item.copy()
                    temp_item.update(evaluated_samples[sample_id])
                    temp_results.append(temp_item)
                else:
                    temp_item = item.copy()
                    try:
                        ref = eval_normalizer(item["reference"])
                        pred = eval_normalizer(item["prediction"])

                        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
                        utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
                        utt_wer = round(wer(ref, pred) * 100, 2)

                        metrics = {
                            "bleu": utt_bleu,
                            "cer": min(100, utt_cer),
                            "wer": utt_wer,
                        }

                        evaluated_samples[sample_id] = metrics
                        temp_item.update(metrics)
                    except Exception as e:
                        print(f"Error evaluating sample {sample_id}: {e}")
                        metrics = {
                            "bleu": 0,
                            "cer": 100,
                            "wer": 100,
                            "error": str(e),
                        }
                        evaluated_samples[sample_id] = metrics
                        temp_item.update(metrics)

                    temp_results.append(temp_item)

            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(temp_results),
                "sample_results": temp_results,
            }
            save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
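
    # Final pass: normalize and score every collected sample.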
    for item in sample_results:
        ref = eval_normalizer(item["reference"])
        pred = eval_normalizer(item["prediction"])

        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
        utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
        utt_wer = round(wer(ref, pred) * 100, 2)

        item.update({
            "bleu": utt_bleu,
            "cer": min(100, utt_cer),
            "wer": utt_wer,
        })

    avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
    avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
    avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)

    results = {
        "dataset": dataset.name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": len(sample_results),
        "metrics": {
            "bleu": avg_bleu,
            "cer": avg_cer,
            "wer": avg_wer,
        },
        "sample_results": sample_results,
    }

    save_results(results, dataset.name, task_type, source_lang, target_lang)
    return results
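

# Evaluate each configured dataset and print its aggregate metrics.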
if __name__ == "__main__":
    datasets = []
    pickup_dataset = MultiturnAudioDataset(
        split='eval',
        processor=processor,
        json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
    )
    datasets.append(pickup_dataset)

    for dataset in datasets:
        asr_results = evaluate_task(dataset)

        print(f"\n=== {asr_results.get('dataset', 'Dataset')}")
        print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
        print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
        print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")