from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature, Gemma3ForCausalLM, Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *

# converter = opencc.OpenCC('s2tw.json')
model_id = "./"
revision = "main"  # "v1.0"
processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# os.makedirs(results_dir, exist_ok=True)

# Substrings removed before comparing a function-call reference with the model
# output: role prefixes first (multi-character), then whitespace/newlines and
# sentence punctuation. The full-width '？' and '！' are included because the
# ASCII '?'/'!' were listed twice, which presumably meant to cover them too.
_FUNC_CALL_STRIP = (
    'User transcribe is',
    'GPT output is',
    '\n',
    ' ',
    '?',
    '？',
    '!',
    '！',
    '。',
    '.',
)


def _remove_sign(text):
    """Return ``text`` with role prefixes, spaces/newlines and punctuation removed."""
    for token in _FUNC_CALL_STRIP:
        text = text.replace(token, '')
    return text


def _cer_percent(label, output):
    """Whitespace-insensitive character error rate in percent, capped at 100.

    ``jiwer.cer`` rejects empty references, so that case is resolved here:
    0 when both sides are empty, 100 when only the reference is empty.
    """
    ref = re.sub(r"\s+", "", label)
    hyp = re.sub(r"\s+", "", output)
    if not ref:
        return 0.0 if not hyp else 100
    if not hyp:
        # Every reference character is a deletion -> CER of 1.0.
        return 100
    return min(100, round(cer(ref, hyp) * 100, 2))


def _safe_rate(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def _save_json(records, save_path):
    """Write ``records`` to ``save_path`` as UTF-8 JSON; no-op for an empty path."""
    if not save_path:
        return
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(records, f, ensure_ascii=False, indent=4)


def _print_summary(all_output_text, total_func_call, func_error):
    """Print sample count, non-zero-CER count/rate, mean CER and function-call error rate."""
    n_samples = len(all_output_text)
    avg_cer = _safe_rate(sum(a['cer'] for a in all_output_text), n_samples)
    total_error = sum(a['cer'] != 0 for a in all_output_text)
    print('total', n_samples)
    print('total_error & rate', total_error, _safe_rate(total_error, n_samples))
    print('avg_cer', avg_cer)
    print('total_func_call', total_func_call)
    print('func_error & rate', func_error, ',', _safe_rate(func_error, total_func_call))


def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0):
    """Generate a continuation for every batch of ``dataloader`` and score it.

    For each sample, the prompt, the reference (``labels``) and the generated
    continuation are decoded with the module-level ``processor``. The
    character error rate is then computed in percent, whitespace-insensitive
    and capped at 100. Samples whose reference contains ``'Action:'`` also
    count as function calls. Such a call is an error when the reference and
    the prediction differ after ``_remove_sign`` normalisation.

    Args:
        model: model exposing ``generate(**batch, ...)``.
        dataloader: iterable of dict batches containing at least
            ``input_ids`` and ``labels`` tensors. ``None`` entries are dropped
            before generation.
        with_input_mode: keep the ``input_modes`` entry when calling
            ``generate``. It is removed by default.
        save_path: JSON file that per-sample records are written to, every
            100th batch and at the end. Nothing is written when empty.
        start_idx: batches with an index below this value are skipped, e.g.
            to resume an interrupted run.

    Returns:
        ``(res, all_output_text)``. ``res`` holds the parallel per-sample
        lists ``label``, ``pred`` and ``cer``. ``all_output_text`` is a list of
        dicts with the keys ``input``, ``label``, ``output`` and ``cer``.
    """
    res = {'label': [], 'pred': [], 'cer': []}
    func_error = 0
    total_func_call = 0
    all_output_text = []

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        # Skip already-processed batches (indices 0 .. start_idx - 1).
        if batch_idx < start_idx:
            continue
        batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
        try:
            with torch.inference_mode():
                if not with_input_mode:
                    # The key may already be gone if the collate fn set it to None.
                    batch.pop('input_modes', None)
                generate_ids = model.generate(
                    **batch,
                    max_new_tokens=256,
                    temperature=0.001,  # near-greedy despite do_sample=True
                    top_p=0.95,
                    top_k=64,
                    do_sample=True,
                )
            # generate() returns prompt + continuation; split at the prompt length.
            prompt_len = batch['input_ids'].shape[1]
            batch_inputs = processor.batch_decode(
                generate_ids[:, :prompt_len],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            batch_predictions = processor.batch_decode(
                generate_ids[:, prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            batch_references = processor.batch_decode(
                batch['labels'],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for inp, label, output in zip(batch_inputs, batch_references, batch_predictions):
                cer_o = _cer_percent(label, output)
                # One entry per sample so the lists stay parallel to res['cer'].
                res['label'].append(label)
                res['pred'].append(output)
                res['cer'].append(cer_o)
                all_output_text.append({
                    'input': inp,
                    'label': label,
                    'output': output,
                    'cer': cer_o,
                })
                if 'Action:' in label:
                    func_error += _remove_sign(label) != _remove_sign(output)
                    total_func_call += 1
            if batch_idx % 100 == 0:
                # Periodic checkpoint so a long run can be inspected or resumed.
                _save_json(all_output_text, save_path)
                _print_summary(all_output_text, total_func_call, func_error)
        except Exception as err:
            # Best effort: report the failing batch and continue with the next one.
            print("error at ", batch_idx, repr(err))
            time.sleep(2)

    # Persist results before printing so they survive any reporting problem.
    _save_json(all_output_text, save_path)
    _print_summary(all_output_text, total_func_call, func_error)
    return res, all_output_text


# NOTE(review): split='eval' is combined with *_train.json files; confirm that
# MultiturnAudioDataset carves out a held-out portion rather than reusing
# training data.
nav_data = MultiturnAudioDataset(
    split='eval',
    text_only=True,
    processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json',
)
ctrl_data = MultiturnAudioDataset(
    split='eval',
    text_only=True,
    processor=processor,
    json_path='/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json',
)
# Driver: evaluate a LoRA-finetuned Gemma-3-4B-it on the navigation and
# control tool-call evaluation sets defined above.
ctrl_dataloader = DataLoader(ctrl_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
nav_dataloader = DataLoader(nav_data, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
from peft import PeftModel

model_id_org = "google/gemma-3-4b-it"
model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org,
    device_map="auto",
    attn_implementation="eager",
).eval()

# Attach the LoRA adapter from the LLaMA-Factory training checkpoint.
model_org = PeftModel.from_pretrained(
    model_org,
    '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270',
)

# One filesystem-safe timestamp for both result files. The str(datetime)
# form contains a space and a ':', and ':' is invalid on Windows and is
# misread as host:path by scp/rsync.
run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M')
res_org_nav, output_org_nav = eval_text(
    model_org,
    nav_dataloader,
    save_path=f'./output_org_nav_{run_stamp}.json',
)
res_org_ctrl, output_org_ctrl = eval_text(
    model_org,
    ctrl_dataloader,
    save_path=f'./output_org_ctrl_{run_stamp}.json',
)