In [2]:
from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature,Gemma3ForCausalLM,Gemma3Processor
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re
from pathlib import Path
import opencc
from ASRDataset import *

# converter = opencc.OpenCC('s2tw.json')

# Load the model and its processor from the current directory. The repo ships
# its own modelling code, hence trust_remote_code=True.
model_id = "./"
revision = "main"  # alternative pin: "v1.0"

model = AutoModel.from_pretrained(
    model_id,
    revision=revision,
    device_map="cuda",
    trust_remote_code=True,
)
model.eval()  # switch to inference mode (eval() returns the same module)

processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

# Per-run output folder name, stamped with the start time as YYYYmmdd_HHMMSS.
results_dir = "evaluation_results_" + datetime.now().strftime("%Y%m%d_%H%M%S")
# The folder is intentionally not created here (the makedirs call is disabled).
# os.makedirs(results_dir, exist_ok=True)


def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None, output_dir=None):
    """Write an evaluation-results dict to ``<output_dir>/<name>.json``.

    The file name is ``{task}_{dataset_name}_{source_lang}``, followed by
    ``_to_{target_lang}`` when ``target_lang`` is truthy and by
    ``_sample_{sample_idx}`` when ``sample_idx`` is not None (0 included).

    Args:
        results: JSON-serialisable dict. A ``"timestamp"`` key is added to it
            in place before it is written.
        dataset_name: dataset component of the file name.
        task: task component (prefix) of the file name.
        source_lang: source-language component of the file name.
        target_lang: optional target language appended as ``_to_<lang>``.
        sample_idx: optional index appended as ``_sample_<idx>``.
        output_dir: directory to write into. Defaults to the module-level
            ``results_dir``. Created (with parents) if it does not exist.

    Returns:
        The path of the written JSON file.
    """
    out_dir = results_dir if output_dir is None else output_dir
    # The results folder is not created when results_dir is defined, so make
    # sure it exists here; otherwise open() below raises FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)

    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"

    filepath = os.path.join(out_dir, f"{filename}.json")

    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # ensure_ascii=False keeps Chinese text readable in the saved file.
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return filepath



  lambda i: encoder_checkpoint_wrapper(


######################## speech lora #############
######################## text lora #############


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# from safetensors import safe_open

# tensors = {}
# with safe_open("/mnt/jeff/LLaMA-Factory/saves/Gemma-3-4B/lora/pickup_jimmy/adapter_model.safetensors", framework="pt", device=0) as f:
#     for k in f.keys():
#         tensors[k] = f.get_tensor(k)
# lang_lora = {}
# for k in tensors:
#     new_k=k.replace('base_model.model.language_model.model','language_model.model.base_model.model')
#     lang_lora[new_k]=tensors[k]
# from safetensors.torch import save_file
# os.makedirs('./lora')
# save_file(lang_lora,'./lora/adapter_model.safetensors')


In [3]:
# !cp /mnt/jeff/LLaMA-Factory/saves/Gemma-3-4B/lora/pickup_jimmy/adapter_config.json ./lora

In [4]:
# model.load_adapter('./lora')

In [3]:

# Audio multi-turn pickup data, eval split; one sample per batch, fixed order.
pickup_dataset = MultiturnAudioDataset(
    processor=processor,
    split='eval',
    json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
)
dataloader = DataLoader(
    pickup_dataset,
    collate_fn=covost_collate_fn,
    shuffle=False,
    batch_size=1,
)


In [None]:
# Evaluate the audio model on the pickup multi-turn set.
# Labels and well-formed predictions are expected to contain
#   "User transcribe is : <transcript>;\nGPT output is : <reply>"
# (inferred from the markers and the ';\n' separator checked below).
# The two halves are scored separately: exact match after normalisation, plus CER.

MAX_EVAL_BATCHES = None  # set to a small int for a quick smoke run; None = whole set

transcribe_error = 0
output_error = 0
format_error = 0
func_error = 0
total_func_call = 0
all_output = []


def remove_sign(text: str) -> str:
    """Drop the turn markers, newlines, spaces and sentence punctuation."""
    for token in ('User transcribe is', 'GPT output is', '\n', ' ',
                  '?', '？', '!', '。', '.', '！'):
        text = text.replace(token, '')
    return text


def _split_turn(text: str) -> tuple:
    """Split '<transcript>;\\n<reply>' into two normalised halves.

    Only the first ';\\n' is treated as the separator, so an extra one inside
    the reply no longer breaks tuple unpacking.

    Raises:
        ValueError: if ``text`` contains no ';\\n' separator at all.
    """
    transcript, sep, reply = text.partition(';\n')
    if not sep:
        raise ValueError(f"missing ';\\n' separator in: {text!r}")
    return remove_sign(transcript), remove_sign(reply)


def _cer_percent(reference: str, hypothesis: str) -> float:
    """Whitespace-insensitive CER in percent, capped at 100 and rounded to 2 dp.

    jiwer raises on an empty reference, so that case is scored here instead:
    0 when the hypothesis is empty as well, otherwise 100.
    """
    ref = re.sub(r"\s+", "", reference)
    hyp = re.sub(r"\s+", "", hypothesis)
    if not ref:
        return 0.0 if not hyp else 100
    return min(100, round(cer(ref, hyp) * 100, 2))


def _rate(count, total):
    """Return count / total, or NaN when total is 0 (instead of ZeroDivisionError)."""
    return count / total if total else float('nan')


for batch_idx, batch in enumerate(tqdm(dataloader)):
    if MAX_EVAL_BATCHES is not None and batch_idx >= MAX_EVAL_BATCHES:
        break
    audio_path = batch.pop('audio_path')
    batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
    with torch.inference_mode():
        generate_ids = model.generate(
            **batch,
            max_new_tokens=256,
            # Near-greedy sampling; no RNG seed is set, so reruns may differ slightly.
            temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
        )
        prompt_len = batch['input_ids'].shape[1]
        batch_inp = processor.batch_decode(
            batch['input_ids'], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        # Only the newly generated tokens (prompt sliced off).
        batch_predictions = processor.batch_decode(
            generate_ids[:, prompt_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        batch_references = processor.batch_decode(
            batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    for inp, label, output in zip(batch_inp, batch_references, batch_predictions):
        well_formed = (
            ';\n' in output
            and 'User transcribe is' in output
            and 'GPT output is' in output
        )
        if well_formed:
            label_t, label_o = _split_turn(label)
            output_t, output_o = _split_turn(output)
            transcribe_error += label_t != output_t
            reply_wrong = label_o != output_o
            cer_t = _cer_percent(label_t, output_t)
            cer_o = _cer_percent(label_o, output_o)
        else:
            # Unparseable prediction: count it as wrong on every metric.
            transcribe_error += 1
            format_error += 1
            reply_wrong = True
            cer_t = cer_o = 100
        output_error += reply_wrong
        if 'Action:' in label:
            # Function-call samples count towards the denominator and the
            # error count even when the prediction is malformed.
            total_func_call += 1
            func_error += reply_wrong

        all_output.append({
            'audio_path': audio_path,
            'input': inp,
            'label': label,
            'output': output,
            'cer_o': cer_o,
            'cer_t': cer_t,
        })

n_samples = len(all_output)
avg_cer_o = _rate(sum(a['cer_o'] for a in all_output), n_samples)
avg_cer_t = _rate(sum(a['cer_t'] for a in all_output), n_samples)
print('total', n_samples)
print('avg_cer_o', avg_cer_o)
print('avg_cer_t', avg_cer_t)
print('transcribe_error & rate', transcribe_error, ',', _rate(transcribe_error, n_samples))
print('output_error & rate', output_error, ',', _rate(output_error, n_samples))
print('format_error', format_error)
print('total_func_call', total_func_call)
print('func_error & rate', func_error, ',', _rate(func_error, total_func_call))

100%|██████████| 434/434 [49:58<00:00,  6.91s/it] 

total 434
avg_cer_o 1.4704147465437787
avg_cer_t 1.8358986175115206
transcribe_error & rate 13 , 0.029953917050691243
output_error & rate 13 , 0.029953917050691243
format_error 0
total_func_call 229
func_error & rate 3 , 0.013100436681222707





In [18]:
# Show only the audio samples where the transcript or the reply has non-zero CER.
for rec in filter(lambda r: r['cer_o'] != 0 or r['cer_t'] != 0, all_output):
    display(rec)

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_26987044-breezyvoice-01789.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19324340-breezyvoice-01790.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_26743189-breezyvoice-01791.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_up

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19313207-breezyvoice-01793.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接阿德;\nGPT output 

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27003818-breezyvoice-01806.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接小白;\nGPT output 

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_25905179-breezyvoice-01855.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接小白;\nGPT output 

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27358623-breezyvoice-01976.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接小白;\nGPT output 

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27358623-breezyvoice-01976.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_24477926-breezyvoice-01977.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input":

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_26776103-breezyvoice-02001.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_31045905-breezyvoice-02002.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_22007606-breezyvoice-02003.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_up

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19324372-breezyvoice-02030.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_26775796-breezyvoice-02031.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27166142-breezyvoice-02032.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_up

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_30160540-breezyvoice-02069.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_25042163-breezyvoice-02070.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_30160932-breezyvoice-02071.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19981972-breezyvoice-02072.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, re

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_21544641-breezyvoice-02105.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_35101161-breezyvoice-02106.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_24478208-breezyvoice-02107.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_up

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_22460722-breezyvoice-02124.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接Emily;\nGPT outp

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19239246-breezyvoice-02133.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
 'label': '\nUser transcribe is : 我要去接Emily;\nGPT outp

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27358283-breezyvoice-02135.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_26947370-breezyvoice-02136.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_19251372-breezyvoice-02137.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_up

{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_30460669-breezyvoice-02183.wav',
   '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_30423060-breezyvoice-02184.wav']],
 'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input":

In [4]:
# Preview a few records only: each one embeds the full tool-description prompt,
# so displaying the whole list floods the notebook.
all_output[:3]

[{'audio_path': [['/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_18855684-breezyvoice-01778.wav']],
  'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n\n\n\n\n\nmodel\n',
  'label': '\nUser transcribe is : 我要去寶高園區接人;\nGPT o

In [11]:
# Same pickup conversations, text-only variant; one sample per batch, fixed order.
pickup_dataset = MultiturnAudioDataset(
    processor=processor,
    split='eval',
    text_only=True,
    json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
)
dataloader = DataLoader(
    pickup_dataset,
    collate_fn=covost_collate_fn,
    shuffle=False,
    batch_size=1,
)

In [10]:
# for batch_idx, batch in enumerate(tqdm(dataloader)):
#     batch_inputs = processor.batch_decode(
#         batch['input_ids'], skip_special_tokens=True, 
#         clean_up_tokenization_spaces=False
#     )
#     batch_references = processor.batch_decode(
#         batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
#     )
#     print('input',batch_inputs)
#     print('label',batch_references)
#     print('-----------------')
#     if batch_idx>5:break

In [13]:
# Text-only evaluation: the whole model turn (a reply or an Action call) is
# scored against the label with CER; function-call turns are also checked by
# exact match.
func_error = 0
total_func_call = 0
all_output_text = []


def _turn_cer(reference: str, hypothesis: str) -> float:
    """Whitespace-insensitive CER in percent, capped at 100 and rounded to 2 dp.

    jiwer raises on an empty reference, so that case is scored here instead:
    0 when the hypothesis is empty as well, otherwise 100.
    """
    ref = re.sub(r"\s+", "", reference)
    hyp = re.sub(r"\s+", "", hypothesis)
    if not ref:
        return 0.0 if not hyp else 100
    return min(100, round(cer(ref, hyp) * 100, 2))


for batch in tqdm(dataloader):
    batch = {k: v.to("cuda") for k, v in batch.items() if v is not None}
    with torch.inference_mode():
        generate_ids = model.generate(
            **batch,
            max_new_tokens=256,
            # Near-greedy sampling; no RNG seed is set, so reruns may differ slightly.
            temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
        )
        prompt_len = batch['input_ids'].shape[1]
        batch_inputs = processor.batch_decode(
            generate_ids[:, :prompt_len], skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        batch_predictions = processor.batch_decode(
            generate_ids[:, prompt_len:], skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        batch_references = processor.batch_decode(
            batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    for inp, label, output in zip(batch_inputs, batch_references, batch_predictions):
        all_output_text.append({
            'input': inp,
            'label': label,
            'output': output,
            'cer': _turn_cer(label, output),
        })
        if 'Action:' in label:
            total_func_call += 1
            # Decoded labels end with '\n' while decoded outputs may not, so
            # compare with surrounding whitespace stripped.
            func_error += label.strip() != output.strip()

n_samples = len(all_output_text)
avg_cer = sum(a['cer'] for a in all_output_text) / n_samples if n_samples else float('nan')
total_error = sum(a['cer'] != 0 for a in all_output_text)
error_rate = total_error / n_samples if n_samples else float('nan')
func_rate = func_error / total_func_call if total_func_call else float('nan')
print('total', n_samples)
print('total_error & rate', total_error, error_rate)
print('avg_cer', avg_cer)
print('total_func_call', total_func_call)
print('func_error & rate', func_error, ',', func_rate)

100%|██████████| 955/955 [47:04<00:00,  2.96s/it]  

total 955
total_error & rate 3 0.0031413612565445027
avg_cer 0.25132984293193716
total_func_call 375
func_error & rate 2 , 0.005333333333333333





In [16]:
# Show only the text-only samples whose model turn has non-zero CER.
for rec in filter(lambda r: r['cer'] != 0, all_output_text):
    display(rec)

{'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n我要去接Emily\nmodel\n請問要去哪裡接Emily呢\nuser\n我要去接Emily\nmodel\n請問要去哪裡接Emily呢\nuser\n去市政府捷運站接人\nmodel\n',
 'label': 'Action: get_phonebook\nAction Input: {"name": "Emily", "destination": "市政府捷運站"}\n',
 'output': '請問要去

{'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n我要去接阿德\nmodel\n請問要去哪裡接阿德呢\nuser\n我要去市政府捷運站接阿德\nmodel\n',
 'label': 'Action: get_phonebook\nAction Input: {"name": "阿德", "destination": "市政府捷運站"}\n',
 'output': '請問要去哪裡接阿德呢',
 'cer': 97.06}

{'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n我要去接Emily\nmodel\n請問要去哪裡接Emily呢\nuser\n我要去接Emily\nmodel\n請問要去哪裡接Emily呢\nuser\n我要去接小白\nmodel\n',
 'label': '請問要去哪裡接小白呢\n',
 'output': '請問要去哪裡接Emily呢',
 'cer': 50.0}

In [6]:
# Preview a few records only: each one embeds the full tool-description prompt,
# so displaying the whole list floods the notebook.
all_output_text[:3]

[{'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指定聯絡人，若無則回傳空字串\nTool Args:\n  - name (string, required): name參數\n\n> Tool Name: message_update\nTool Description: 更新乘客通知簡訊內容\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n  - message (string, required): message參數\n\n> Tool Name: save_phonebook\nTool Description: 儲存新聯絡人資訊到通訊錄\nTool Args:\n  - name (string, required): name參數\n  - phone (string, required): phone參數\n\nUse the following format if using a tool:\n```\nAction: tool name (one of [get_phonebook, message_update, save_phonebook])\nAction Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)\n```\n\n\n我要去寶高園區接人\nmodel\n',
  'label': '請問你要去寶高園區接誰？\n',
  'output': '請問你要去寶高園區接誰？',
  'cer': 0.0},
 {'input': 'user\nYou have access to the following tools:\n> Tool Name: get_phonebook\nTool Description: 查詢通訊錄內是否存在指

In [8]:
# Single-clip sanity check: run a plain transcription prompt through the model.
# NOTE(review): `torchaudio` is not in the import cell at the top; it is presumably
# provided by `from ASRDataset import *` — confirm, or import it explicitly.
# NOTE(review): in top-to-bottom order, `pickup_dataset` here is the text_only=True
# dataset built above; confirm prepare_model_inputs still attaches the audio.
aud = [
    # '/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529/common_voice_16_1-TW/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW-common_voice_zh-TW_17372022-breezyvoice-05137.wav',
    '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_24478208-breezyvoice-02107.wav',
    '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_27358623-breezyvoice-01976.wav',
    '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_30423060-breezyvoice-02184.wav',
]
# torchaudio.load returns (waveform, sample_rate); only the waveform is passed on.
waveform, sample_rate = torchaudio.load(aud[0])
prompt = "Transcribe the audio clip into text."
inp = pickup_dataset.prepare_model_inputs(waveform, prompt, "")
# Skip absent (None) entries before moving to the GPU, like the evaluation loops do.
inp = {k: v.to('cuda') for k, v in inp.items() if v is not None}
with torch.inference_mode():
    generate_ids = model.generate(
        **inp,
        max_new_tokens=256,
        temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
    )

input_lengths = inp['input_ids'].shape[1]

# Generated continuation only (prompt tokens sliced off).
batch_predictions = processor.batch_decode(
    generate_ids[:, input_lengths:], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
# Displayed: prompt + continuation, to eyeball the full exchange.
processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

['user\n\n\n\n\nTranscribe the audio clip into text.\nmodel\n屏東縣九如鄉民生路']

In [1]:
!cp '/mnt/jeff/InCar/data/multiturn_data/pickup_breezy/pickup_breezy-common_voice_zh-TW_24478208-breezyvoice-02107.wav' .