from io import BytesIO
from urllib.request import urlopen
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import sacrebleu
import opencc
from tqdm import tqdm
from jiwer import cer, wer
from datasets import load_dataset, Audio
from transformers import AutoModel, AutoProcessor, BatchFeature, Gemma3ForCausalLM, Gemma3Processor
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer

from ASRDataset import *
|
|
# Processor for the local fine-tuned checkpoint (custom code via trust_remote_code).
model_id = "./"
revision = "main"

processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)

# Timestamped directory for this evaluation run's artifacts.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
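# Hedge: create results_dir up front; the save paths below currently write to
# './', but outputs can be redirected here without further changes.
os.makedirs(results_dir, exist_ok=True)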
def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0):
    """Generate predictions over a dataloader and report CER plus tool-call error rates."""
    res = {'label': [], 'pred': [], 'cer': []}
    func_error = 0
    total_func_call = 0
    all_output_text = []

    def remove_sign(x):
        # Strip role prefixes, whitespace, and half-/full-width punctuation
        # before comparing tool-call strings.
        for token in ('User transcribe is', 'GPT output is', '\n', ' ',
                      '?', '？', '!', '！', '。', '.'):
            x = x.replace(token, '')
        return x

    def print_metrics():
        # Running metrics over everything evaluated so far.
        avg_cer = sum(a['cer'] for a in all_output_text) / len(all_output_text)
        total_error = sum(a['cer'] != 0 for a in all_output_text)
        print('total', len(all_output_text))
        print('total_error & rate', total_error, total_error / len(all_output_text))
        print('avg_cer', avg_cer)
        print('total_func_call', total_func_call)
        if total_func_call:
            print('func_error & rate', func_error, ',', func_error / total_func_call)

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if batch_idx < start_idx:  # resume support: skip already-evaluated batches
            continue
        batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
        try:
            with torch.inference_mode():
                if not with_input_mode:
                    batch.pop('input_modes', None)
                generate_ids = model.generate(
                    **batch,
                    max_new_tokens=256,
                    temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
                )
            prompt_len = batch['input_ids'].shape[1]
            batch_inputs = processor.batch_decode(
                generate_ids[:, :prompt_len], skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            batch_predictions = processor.batch_decode(
                generate_ids[:, prompt_len:], skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            batch_references = processor.batch_decode(
                batch['labels'], skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for inp, label, output in zip(batch_inputs, batch_references, batch_predictions):
                # Character error rate on whitespace-stripped text, capped at 100.
                cer_o = min(100, round(cer(re.sub(r"\s+", "", label),
                                           re.sub(r"\s+", "", output)) * 100, 2))
                res['label'].append(label)
                res['pred'].append(output)
                res['cer'].append(cer_o)
                all_output_text.append({
                    'input': inp,
                    'label': label,
                    'output': output,
                    'cer': cer_o,
                })
                # Tool-call accuracy: exact match after punctuation normalization.
                if 'Action:' in label:
                    func_error += (remove_sign(label) != remove_sign(output))
                    total_func_call += 1
            # Checkpoint partial results every 100 batches.
            if batch_idx % 100 == 0:
                with open(save_path, 'w', encoding='utf-8') as f:
                    json.dump(all_output_text, f, ensure_ascii=False, indent=4)
                print_metrics()
        except Exception as e:
            print("error at", batch_idx, ":", e)
            time.sleep(2)

    if all_output_text:
        print_metrics()
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_output_text, f, ensure_ascii=False, indent=4)
    return res, all_output_text
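
# The local checkpoint in model_id can presumably be loaded like its processor.
# A minimal sketch, assuming the repo ships custom model code via
# trust_remote_code and that bfloat16 fits on the available GPU:
# model = AutoModel.from_pretrained(
#     model_id, revision=revision, trust_remote_code=True,
#     torch_dtype=torch.bfloat16, device_map="auto",
# ).eval()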

# Text-only eval splits for the navigation and control tool-call sets.
nav_data = MultiturnAudioDataset(split='eval', text_only=True, processor=processor,
                                 json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json')
ctrl_data = MultiturnAudioDataset(split='eval', text_only=True, processor=processor,
                                  json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json')
nav_dataloader = DataLoader(nav_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
ctrl_dataloader = DataLoader(ctrl_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
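
# Optional sanity check on one collated batch (hypothetical: the exact keys,
# e.g. 'input_ids'/'labels'/'input_modes', depend on covost_collate_fn):
# sample = next(iter(nav_dataloader))
# print({k: getattr(v, 'shape', type(v)) for k, v in sample.items()})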

# Baseline for comparison: stock Gemma-3-4B-Instruct (eager attention, auto device map).
from transformers import Gemma3ForConditionalGeneration

model_id_org = "google/gemma-3-4b-it"

model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org, device_map="auto", attn_implementation="eager"
).eval()

# Attach the fine-tuned LoRA adapter produced by LLaMA-Factory.
from peft import PeftModel

model_org = PeftModel.from_pretrained(
    model_org,
    '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270',
)
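
# Optionally fold the LoRA weights into the base model for faster generation
# (standard PEFT call; skip it to keep the adapter separately swappable):
# model_org = model_org.merge_and_unload()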

# Run both evaluations; per-sample outputs are dumped to timestamped JSON files.
res_org_nav, output_org_nav = eval_text(
    model_org, nav_dataloader,
    save_path='./output_org_nav_{}.json'.format(str(datetime.now())[:16]),
)
res_org_ctrl, output_org_ctrl = eval_text(
    model_org, ctrl_dataloader,
    save_path='./output_org_ctrl_{}.json'.format(str(datetime.now())[:16]),
)