"""Evaluate a multimodal checkpoint on multi-turn audio data (ASR/AST),
scoring per-utterance BLEU, WER, and CER."""

import json
import os
import re
from datetime import datetime

import opencc
import sacrebleu
import torch
from jiwer import cer, wer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer

# Provides MultiturnAudioDataset and covost_collate_fn.
from ASRDataset import *
# Convert model output from Simplified to Traditional Chinese (Taiwan standard).
converter = opencc.OpenCC('s2tw.json')
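# Hedged example of the conversion (illustrative string, not from the data):
#   converter.convert("汉语模型")  # -> "漢語模型"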
# Text normalizers applied to references and predictions before scoring.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "other": BasicTextNormalizer(),
}
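# Hedged example of normalization (whisper_normalizer behavior): the English
# normalizer lowercases and strips punctuation before scoring, e.g.
#   normalizer["en_us"]("Hello, World!")  # -> "hello world"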
# Local checkpoint; trust_remote_code is required because the model and
# processor classes live in the checkpoint repository.
model_id = "/mnt/jeff/gemma_test"
revision = "main"  # "v1.0"
model = AutoModel.from_pretrained(
    model_id, device_map="cuda", revision=revision, trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)
# Each run writes its JSON results into a fresh timestamped directory.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

# Prompt templates per task (ASR = transcription, AST = speech translation).
INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}
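# This script never formats INSTRUCTION itself; the dataset classes from
# ASRDataset are assumed to embed the prompt. A hypothetical direct use:
#   INSTRUCTION["ast"].format("Traditional Chinese")
#   # -> "Translate the audio to Traditional Chinese."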
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Write one evaluation result dict to a timestamped JSON file and return its path."""
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"
    filepath = os.path.join(results_dir, f"{filename}.json")
    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return filepath
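# Illustrative example (hypothetical dataset/language names): calling
# save_results(results, "covost2", "ast", "en_us", "zh_tw") writes to
# f"{results_dir}/ast_covost2_en_us_to_zh_tw.json".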
def evaluate_task(dataset, task_type="asr", source_lang="other", target_lang=None):
    # Pick the scoring normalizer; fall back to the basic normalizer for any
    # language without a dedicated entry.
    eval_normalizer = normalizer.get(source_lang, normalizer["other"])
    sample_results = []
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
    # Cache of per-sample metrics so the periodic checkpoints below do not
    # re-score samples that were already evaluated.
    evaluated_samples = {}
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if torch.cuda.is_available():
            try:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            except Exception as e:
                print(f"Failed to move batch to GPU: {e}")
                break
        # Set the reference labels aside so they are not forwarded to generate()
        # (assumes covost_collate_fn returns the target token ids as 'labels').
        label_ids = batch.pop('labels')
        with torch.inference_mode():
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
                # temperature=1.0, top_p=0.95, top_k=64, do_sample=True
            )
        # Strip the prompt so only newly generated tokens are decoded.
        input_lengths = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_lengths:]
        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        # Replace the loss-masking ignore index before decoding the references
        # (assumes the collator masks prompt positions with -100).
        label_ids = label_ids.masked_fill(label_ids == -100, processor.tokenizer.pad_token_id)
        batch_references = processor.batch_decode(
            label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            # Global sample index (robust to batch sizes larger than 1).
            idx = batch_idx * dataloader.batch_size + i
            sample_result = {
                "id": idx,
                "reference": reference,
                # Convert the prediction to Traditional Chinese before scoring.
                "prediction": converter.convert(prediction)
            }
            sample_results.append(sample_result)
        # Every 10 batches, checkpoint the metrics computed so far to disk.
        if (batch_idx + 1) % 10 == 0:
temp_results = []
for item in sample_results:
sample_id = item["id"]
if sample_id in evaluated_samples:
temp_item = item.copy()
temp_item.update(evaluated_samples[sample_id])
temp_results.append(temp_item)
else:
temp_item = item.copy()
                    try:
                        ref = eval_normalizer(item["reference"])
                        pred = eval_normalizer(item["prediction"])
                        # Per-utterance BLEU, plus WER/CER as percentages; CER is
                        # computed on whitespace-stripped text and capped at 100.
                        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
                        utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
                        utt_wer = round(wer(ref, pred) * 100, 2)
                        metrics = {
                            "bleu": utt_bleu,
                            "cer": min(100, utt_cer),
                            "wer": utt_wer
                        }
evaluated_samples[sample_id] = metrics
temp_item.update(metrics)
except Exception as e:
print(f"Error evaluating sample {sample_id}: {e}")
metrics = {
"bleu": 0,
"cer": 100,
"wer": 100,
"error": str(e)
}
evaluated_samples[sample_id] = metrics
temp_item.update(metrics)
temp_results.append(temp_item)
partial_results = {
"task": task_type,
"source_lang": source_lang,
"target_lang": target_lang,
"num_samples": len(temp_results),
"sample_results": temp_results
}
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
    # Final pass: score every sample so the saved results are complete even
    # when the periodic checkpoint above never fired.
    for item in sample_results:
        ref = eval_normalizer(item["reference"])
        pred = eval_normalizer(item["prediction"])
        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
        utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
        utt_wer = round(wer(ref, pred) * 100, 2)
        item.update({
            "bleu": utt_bleu,
            "cer": min(100, utt_cer),
            "wer": utt_wer
        })
    # Corpus-level averages; guard against an empty run (e.g. an early break).
    num_scored = max(len(sample_results), 1)
    avg_bleu = sum(item["bleu"] for item in sample_results) / num_scored
    avg_cer = sum(item["cer"] for item in sample_results) / num_scored
    avg_wer = sum(item["wer"] for item in sample_results) / num_scored
results = {
"dataset": dataset.name,
"task": task_type,
"source_lang": source_lang,
"target_lang": target_lang,
"num_samples": len(sample_results),
"metrics": {
"bleu": avg_bleu,
"cer": avg_cer,
"wer": avg_wer
},
"sample_results": sample_results
}
save_results(results, dataset.name, task_type, source_lang, target_lang)
return results
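# Hypothetical sketch of a translation (AST) run, assuming a dataset class with
# the same interface as MultiturnAudioDataset and English source audio
# (illustrative argument values only):
#
#   ast_results = evaluate_task(
#       ast_dataset,
#       task_type="ast",
#       source_lang="en_us",
#       target_lang="zh_tw",
#   )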
if __name__ == "__main__":
    datasets = []
    pickup_dataset = MultiturnAudioDataset(
        split='eval',
        processor=processor,
        json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
    )
    datasets.append(pickup_dataset)
    for dataset in datasets:
        # ASR evaluation with the default task settings.
        asr_results = evaluate_task(dataset)
        print(f"\n=== {asr_results.get('dataset', 'Dataset')} ===")
        print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
        print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
        print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")