# go / eval_multiturn_textonly.py
# Uploaded by jva96160 ("Upload 32 files"), commit 4c1ba5a (verified), 5.93 kB.
from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature,Gemma3ForCausalLM,Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *
# Processor for the checkpoint stored in the working directory.
# trust_remote_code=True allows custom processor code found in that
# directory to be executed while loading.
# (An OpenCC converter, opencc.OpenCC('s2tw.json'), is disabled here.)
model_id = "./"
revision = "main"  # alternative pin: "v1.0"
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    revision=revision,
)

# Timestamped results-directory name. The directory itself is not created
# (the os.makedirs call is disabled) and the name is not referenced
# elsewhere in this script.
results_dir = "evaluation_results_" + datetime.now().strftime('%Y%m%d_%H%M%S')
# Phrases removed before comparing a reference and a prediction for
# function-call accuracy. They are removed before single characters because
# the phrases themselves contain spaces.
_FUNC_CALL_PHRASES = ('User transcribe is', 'GPT output is')
# Single characters dropped for the same comparison: newlines, spaces and
# sentence punctuation. Both ASCII and full-width question/exclamation marks
# are included so Chinese and English punctuation compare equal.
_FUNC_CALL_MARKS = str.maketrans('', '', '\n ?？!！。.')


def _strip_marks(text):
    """Return *text* without prompt-prefix phrases, newlines, spaces and punctuation."""
    for phrase in _FUNC_CALL_PHRASES:
        text = text.replace(phrase, '')
    return text.translate(_FUNC_CALL_MARKS)


def _cer_percent(label, output):
    """Return the CER of *output* against *label* in percent, capped at 100.

    All whitespace is removed from both strings first, so spacing
    differences are not counted as errors.
    """
    reference = re.sub(r"\s+", "", label)
    hypothesis = re.sub(r"\s+", "", output)
    if not reference:
        # jiwer raises on an empty reference; score this case directly.
        return 0.0 if not hypothesis else 100
    return min(100, round(cer(reference, hypothesis) * 100, 2))


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _print_summary(records, func_error, total_func_call):
    """Print sample-level CER metrics and function-call error counts."""
    total = len(records)
    total_error = sum(r['cer'] != 0 for r in records)
    avg_cer = _ratio(sum(r['cer'] for r in records), total)
    print('total', total)
    print('total_error & rate', total_error, _ratio(total_error, total))
    print('avg_cer', avg_cer)
    print('total_func_call', total_func_call)
    print('func_error & rate', func_error, ',', _ratio(func_error, total_func_call))


def _save_json(path, records):
    """Write *records* to *path* as UTF-8 JSON; do nothing when *path* is empty."""
    if not path:
        return
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)


def _generate_and_score(model, batch):
    """Generate for one batch and return one result record per sample."""
    with torch.inference_mode():
        generate_ids = model.generate(
            **batch,
            max_new_tokens=256,
            temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
        )
    # generate() returns prompt + continuation; split at the prompt length.
    prompt_len = batch['input_ids'].shape[1]
    decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)
    inputs = processor.batch_decode(generate_ids[:, :prompt_len], **decode_kwargs)
    predictions = processor.batch_decode(generate_ids[:, prompt_len:], **decode_kwargs)
    # NOTE(review): if collate_fn masks labels with -100, decoding them may
    # fail; confirm what covost_collate_fn puts in 'labels'.
    references = processor.batch_decode(batch['labels'], **decode_kwargs)
    return [
        {
            'input': inp,
            'label': label,
            'output': output,
            'cer': _cer_percent(label, output),
        }
        for inp, label, output in zip(inputs, references, predictions)
    ]


def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0, device="cuda"):
    """Generate a response for every batch and score it against its labels.

    Each sample's decoded generation is compared with its decoded ``labels``
    by character error rate (whitespace ignored, capped at 100). Samples whose
    reference contains ``'Action:'`` are also counted as function calls. Such
    a call is an error when reference and prediction still differ after
    removing the prompt-prefix phrases, newlines, spaces and punctuation.
    Running metrics are printed, and the records are written to
    *save_path* on every batch whose index is a multiple of 100, and once
    more at the end.

    A batch whose generation or scoring raises is reported and skipped; the
    evaluation continues with the next batch.

    Args:
        model: object exposing ``generate`` (e.g. a transformers/PEFT model).
        dataloader: iterable of dict batches with at least ``input_ids`` and
            ``labels`` tensors. Entries whose value is ``None`` are dropped.
        with_input_mode: keep ``input_modes`` in the batch passed to
            ``generate``; when False it is removed if present.
        save_path: JSON output file. Nothing is written when empty.
        start_idx: index of the first batch to evaluate. Earlier batches are
            skipped, e.g. to resume an interrupted run.
        device: device every batch tensor is moved to.

    Returns:
        ``(res, all_output_text)``. ``res`` holds parallel per-sample lists
        ``'label'``, ``'pred'`` and ``'cer'``. ``all_output_text`` is a list
        of dicts with keys ``'input'``, ``'label'``, ``'output'``, ``'cer'``.
    """
    res = {'label': [], 'pred': [], 'cer': []}
    all_output_text = []
    func_error = 0
    total_func_call = 0
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if batch_idx < start_idx:
            continue
        batch = {k: v.to(device) for k, v in batch.items() if v is not None}
        if not with_input_mode:
            batch.pop('input_modes', None)
        try:
            records = _generate_and_score(model, batch)
        except Exception as err:
            # Best-effort: report the failing batch and keep evaluating.
            print("error at ", batch_idx, repr(err))
            time.sleep(2)
            continue
        for record in records:
            res['label'].append(record['label'])
            res['pred'].append(record['output'])
            res['cer'].append(record['cer'])
            all_output_text.append(record)
            if 'Action:' in record['label']:
                total_func_call += 1
                func_error += _strip_marks(record['label']) != _strip_marks(record['output'])
        if batch_idx % 100 == 0:
            # Periodic checkpoint so long runs keep partial results.
            _save_json(save_path, all_output_text)
            _print_summary(all_output_text, func_error, total_func_call)
    _print_summary(all_output_text, func_error, total_func_call)
    _save_json(save_path, all_output_text)
    return res, all_output_text
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
from peft import PeftModel

# Tool-call evaluation sets (navigation and control), loaded text-only
# through the project's MultiturnAudioDataset with split='eval'.
# NOTE(review): both json_path values point at *_train.json files; presumably
# MultiturnAudioDataset selects the eval portion from split='eval' — confirm.
nav_data = MultiturnAudioDataset(
    split='eval',
    text_only=True,
    processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json',
)
ctrl_data = MultiturnAudioDataset(
    split='eval',
    text_only=True,
    processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json',
)

# One sample per batch, in dataset order.
ctrl_dataloader = DataLoader(ctrl_data, collate_fn=covost_collate_fn, batch_size=1, shuffle=False)
nav_dataloader = DataLoader(nav_data, collate_fn=covost_collate_fn, batch_size=1, shuffle=False)

# Gemma 3 4B instruction-tuned base model (eager attention, automatic device
# placement), switched to eval mode, with the fine-tuned adapter checkpoint
# (a LoRA run, judging by its path) applied on top.
model_id_org = "google/gemma-3-4b-it"
model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org,
    attn_implementation="eager",
    device_map="auto",
)
model_org.eval()
model_org = PeftModel.from_pretrained(
    model_org,
    '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270',
)

# Each output file name embeds the time at which that evaluation starts,
# truncated to minutes ("YYYY-MM-DD HH:MM").
_nav_save_path = './output_org_nav_{}.json'.format(str(datetime.now())[:16])
res_org_nav, output_org_nav = eval_text(model_org, nav_dataloader, save_path=_nav_save_path)
_ctrl_save_path = './output_org_ctrl_{}.json'.format(str(datetime.now())[:16])
res_org_ctrl, output_org_ctrl = eval_text(model_org, ctrl_dataloader, save_path=_ctrl_save_path)