import os
# Set the Hugging Face cache locations before importing `datasets`, since the default
# cache paths are resolved when the library is imported.
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
import datasets
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
import json
from pathlib import Path
import numpy as np
import torch
import sacrebleu
from datasets import load_dataset
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from transformers import (
BatchFeature,
)
import pandas as pd
import soundfile as sf
from datasets import Audio
import random
from copy import deepcopy
import torchaudio
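# "<end_of_turn>" closes a model turn in the Gemma-style chat template used by the
# processor below; -100 is PyTorch's CrossEntropyLoss ignore_index, so label positions
# set to it are excluded from the loss.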
ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100
class BaseAudioDataset(Dataset):
def __init__(self, processor, split, sampling_rate=16000, debug=False):
self.processor = processor
self.training = "train" in split or 'other' in split
self.debug = debug
self.sampling_rate = sampling_rate
self.name = ""
def set_dataset_name(self, name):
self.name = name
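    # Remove samples whose audio file cannot be decoded (or whose reference text is
    # empty after stripping quotes), then re-enable audio decoding at `sampling_rate`.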
@staticmethod
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
original_size = len(data)
data = data.cast_column(audio_field, Audio(decode=False))
def identify_corrupted_files(example):
try:
sf.read(example[audio_field]["path"])
for field in text_fields:
if field in example and example[field].replace('"', '') == "":
return False
return True
except Exception:
return False
data = data.filter(identify_corrupted_files, num_proc=16)
validated_size = len(data)
# Audio Decoding
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
if debug:
print(f"Dataset: {dataset_name}")
print(f"Original data nums: {original_size}")
print(f"After filtering data nums: {validated_size}")
print(f"Filtering ratio: {validated_size/original_size:.2%}")
return data
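    # Keep only samples whose estimated duration (samples / sampling_rate, corrected by
    # a rough channel count) falls within [min_sec, max_sec].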
@staticmethod
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
original_size = len(data)
def filter_audio_by_length(example):
try:
audio = example[audio_field]['array']
channel = 1
if hasattr(audio, 'ndim') and audio.ndim > 1:
channel = audio.ndim
audio = audio.squeeze()
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
return min_sec <= audio_length <= max_sec
except Exception as e:
if debug:
print(f"Error : {str(e)[:100]}... - sample excluded")
return False
data = data.filter(filter_audio_by_length, num_proc=16)
filtered_size = len(data)
if debug:
print(f"Before Length Filtering data nums: {original_size}")
print(f"After Length Filtering data nums: {filtered_size}")
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
return data
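    # Build one training/eval example: the user prompt is "<start_of_audio>" plus the
    # instruction rendered with the tokenizer's chat template, and the answer is the
    # reference text followed by ANSWER_SUFFIX. For training, prompt and answer ids are
    # concatenated and the prompt positions are masked with _IGNORE_INDEX so that only
    # answer tokens contribute to the loss; for evaluation, the answer ids are returned
    # separately as labels.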
def prepare_model_inputs(self, audio_array, instruction, answer_text):
user_message = {
'role': 'user',
'content': '<start_of_audio>' + instruction,
}
prompt = self.processor.tokenizer.apply_chat_template(
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
)
inputs = self.processor(
text=prompt,
audio=[audio_array],
add_special_tokens=False,
return_tensors='pt'
)
answer = f"{answer_text}{ANSWER_SUFFIX}"
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
if self.debug:
self.debug = False
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
if self.training:
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
labels = torch.full_like(input_ids, _IGNORE_INDEX)
labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]), dtype=inputs.token_type_ids.dtype)
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
else:
input_ids = inputs.input_ids
labels = answer_ids
token_type_ids = inputs.token_type_ids
return {
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'input_audio_embeds': inputs.input_audio_embeds,
'audio_embed_sizes': inputs.audio_embed_sizes,
'input_modes': inputs.input_modes,
}
# Libri Speech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name(f"LibriSpeech_{subset}")
# only ASR
self.ast = False
self.lang = "en"
# load dataset
self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
subset,
split=split,
trust_remote_code=True,
cache_dir=Path("/mnt/jeff/InCar/data")
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
# Libri Speech is only for ASR
answer_text = data["text"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# common_voice_16_1 dataset
class CommonVoiceDataset(BaseAudioDataset):
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name(f"CommonVoice_{source_lang}")
# only ASR
self.ast = False
self.lang=source_lang
# load dataset
if source_lang=="zh-TW":
data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
else:
data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
self.data = load_dataset(data_path,
source_lang,
split=split,
trust_remote_code=True,
cache_dir=Path("/mnt/jeff/InCar/data")
)
def prepare_dataset(batch):
"""Function to preprocess the dataset with the .map method"""
transcription = batch["sentence"]
if transcription.startswith('"') and transcription.endswith('"'):
# we can remove trailing quotation marks as they do not affect the transcription
transcription = transcription[1:-1]
            if transcription and transcription[-1] not in [".", "?", "!"]:
# append a full-stop to sentences that do not end in punctuation
transcription = transcription + "."
batch["sentence"] = transcription
return batch
import opencc
converter = opencc.OpenCC('s2tw.json')
def To_zhTW(batch):
transcription = converter.convert(batch["sentence"])
batch["sentence"] = transcription
return batch
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
if source_lang=='zh-CN':
self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
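        # For zh-TW training data, add 1.1x and 0.9x speed-perturbed copies of every
        # utterance on top of the originals; the perturbed datasets are cached as
        # pickle files so the augmentation only has to run once.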
if source_lang == "zh-TW" and split=='train':
import torchaudio
from torchaudio import transforms
import copy
import pickle
import os
def subsample(batch):
batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
batch['audio']['sampling_rate']=16000
return batch
def TW_data_augment_fast(batch):
speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
batch['audio']['array'] = new_array_fast
return batch
def TW_data_augment_slow(batch):
speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
batch['audio']['array'] = new_array_slow
return batch
# data = self.data.map(subsample, num_proc=1, desc="subsample")
fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
if not os.path.exists(fast_path):
data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
with open(fast_path,'wb') as f:
pickle.dump(data_fast,f)
else:
with open(fast_path,'rb') as f:
data_fast=pickle.load(f)
slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
if not os.path.exists(slow_path):
data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
with open(slow_path,'wb') as f:
pickle.dump(data_slow,f)
else:
with open(slow_path,'rb') as f:
data_slow=pickle.load(f)
self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
answer_text = data["sentence"]
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Fleurs Dataset Class
class FleursDataset(BaseAudioDataset):
def __init__(self, processor, split, source_lang, target_lang=None,
mode="asr", sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("Fleurs")
# Mode Setting (ASR or AST)
if mode not in ["asr", "ast"]:
raise ValueError("mode must be 'asr' or 'ast'.")
self.mode = mode
self.ast = (mode == "ast")
self.source_lang = source_lang
# Language name mapping (expand if needed)
self.lang_names = {
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
}
# load dataset - source language dataset
self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
source_lang,
split=split,
trust_remote_code=True,
cache_dir=Path("/mnt/jeff/InCar/data")
)
import opencc
converter = opencc.OpenCC('s2tw.json')
def prepare_dataset(batch):
transcription = converter.convert(batch["transcription"])
batch["transcription"] = transcription
return batch
if (source_lang=="cmn_hans_cn"):
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
self.target_lang_name = ""
# When AST mode, load target language dataset.
if self.ast:
if target_lang is None:
raise ValueError("AST mode requires target_lang.")
self.target_lang = target_lang
self.lang = f"{source_lang}_{target_lang}"
# load dataset - target language dataset (for translation)
target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
target_lang,
split=split,
trust_remote_code=True,
cache_dir=Path("/mnt/jeff/InCar/data")
)
if target_lang=="cmn_hans_cn":
target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
source_dict = {item['id']: item for item in self.data}
target_dict = {item['id']: item for item in target_data}
# only Common ID, add translation fields
common_ids = set(source_dict.keys()) & set(target_dict.keys())
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
self.data = [
{**source_dict[id], 'translation': target_dict[id]['transcription']}
for id in common_ids
]
# Instruction Setting - use target language name
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
self.instruction = random.choice(INSTRUCTION["ast"])
else:
# ASR mode
self.lang = source_lang
self.instruction = random.choice(INSTRUCTION["asr"])
if self.debug:
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
if self.ast:
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
print(f"dataset size: {len(self.data)}")
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
audio_array = data["audio"]["array"]
if self.ast:
answer_text = data["translation"]
else:
answer_text = data["transcription"]
return self.prepare_model_inputs(
audio_array,
self.instruction.format(self.target_lang_name),
answer_text
)
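# Custom ASR data loaded from a CSV file with an "audio" column (file paths) and a
# "text" column (reference transcripts); audio is decoded and resampled to 16 kHz.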
class TWCostumData(BaseAudioDataset):
def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
super().__init__(processor, split, sampling_rate, debug)
import pandas as pd
from datasets import Dataset, Audio
df = pd.read_csv(csv_path).fillna('')
self.set_dataset_name(f"TWCostumData")
self.data = Dataset.from_dict(
{
"audio": [audio for audio in df['audio']],
"sentence": [text for text in df['text']]
}
).cast_column("audio", Audio(sampling_rate=16000))
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
answer_text = data["sentence"]
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
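# Like TWCostumData, but samples come from a conversation-style JSON file: every
# message that carries an "audio_path" contributes one (audio, sentence) ASR pair.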
class TWCostumDataTasks(BaseAudioDataset):
def __init__(self, processor, split="train", sampling_rate=16000,json_path="", debug=False):
super().__init__(processor, split, sampling_rate, debug)
import pandas as pd
from datasets import Dataset, Audio
with open(json_path) as f:
js_data = json.load(f)
raw_data = {
"audio": [],
"sentence": []
}
for conv in js_data:
for mess in conv['conversations']:
if 'audio_path' in mess:
raw_data['audio'].append(mess['audio_path'])
raw_data['sentence'].append(mess["value"])
self.set_dataset_name("TWCostumDataTasks"+json_path)
self.data = Dataset.from_dict(raw_data).cast_column("audio", Audio(sampling_rate=16000))
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
answer_text = data["sentence"]
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
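# Collate function: left-pads input_ids / labels / token_type_ids to the longest
# sequence in the batch, builds the attention mask from non-zero input_ids, and stacks
# audio features (with per-sample audio attention masks) only for samples that actually
# contain audio (input_modes == 2).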
def covost_collate_fn(batch):
input_ids_list = []
labels_list = []
token_type_ids_list = []
input_audio_embeds_list = []
audio_embed_sizes_list = []
audio_attention_mask_list = []
input_modes_list = []
audio_paths = []
for inputs in batch:
if 'audio_path' in inputs:
audio_paths.append(inputs['audio_path'])
input_ids_list.append(inputs['input_ids'][0])
labels_list.append(inputs['labels'][0])
token_type_ids_list.append(inputs['token_type_ids'][0])
if inputs['input_modes']==2:
input_audio_embeds_list.append(inputs['input_audio_embeds'])
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
audio_attention_mask_list.append(
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
)
input_modes_list.append(inputs['input_modes'])
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
audio_attention_mask = (
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
if len(audio_attention_mask_list) > 1
else None
)
attention_mask = (input_ids != 0).long()
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
input_modes = torch.cat(input_modes_list)
if len(audio_paths)>0:
return BatchFeature(
{
"audio_path": audio_paths,
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'input_audio_embeds': input_audio_embeds,
'audio_embed_sizes': audio_embed_sizes,
'audio_attention_mask': audio_attention_mask,
'input_modes': input_modes,
}
)
else:
return BatchFeature(
{
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'input_audio_embeds': input_audio_embeds,
'audio_embed_sizes': audio_embed_sizes,
'audio_attention_mask': audio_attention_mask,
'input_modes': input_modes,
}
)
def pad_sequence(sequences, padding_side='left', padding_value=0):
"""
Pad a list of sequences to the same length.
sequences: list of tensors in [seq_len, *] shape
"""
assert padding_side in ['right', 'left']
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max(len(seq) for seq in sequences)
batch_size = len(sequences)
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
for i, seq in enumerate(sequences):
length = seq.size(0)
if padding_side == 'right':
output.data[i, :length] = seq
else:
output.data[i, -length:] = seq
return output
def cat_with_pad(tensors, dim, padding_value=0):
"""
cat along dim, while pad to max for all other dims
"""
ndim = tensors[0].dim()
assert all(
t.dim() == ndim for t in tensors[1:]
), 'All tensors must have the same number of dimensions'
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
out_size[dim] = sum(t.shape[dim] for t in tensors)
output = tensors[0].new_full(out_size, padding_value)
index = 0
for t in tensors:
# Create a slice list where every dimension except dim is full slice
slices = [slice(0, t.shape[d]) for d in range(ndim)]
# Update only the concat dimension slice
slices[dim] = slice(index, index + t.shape[dim])
output[slices] = t
index += t.shape[dim]
return output
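# Multi-turn dataset built from LLaMA-Factory style conversation JSON. Each human or
# observation turn together with the turn that answers it becomes one sample; samples
# whose final user turn has no audio are kept as text-only samples and are used
# exclusively when text_only=True.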
class MultiturnAudioDataset(BaseAudioDataset):
def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
super().__init__(processor, split, sampling_rate, debug)
from llamafactory.data.template import Llama2Template,parse_template
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from llamafactory.data.mm_plugin import get_mm_plugin
import json
self.train=False
self.text_only=text_only
with open(json_path) as f:
js_data = json.load(f)
        test_len = int(min(len(js_data) * 0.2, 200))
if split=='train':
self.train=True
js_data = js_data[:int(len(js_data)-test_len)]
else:
js_data = js_data[-test_len:]
for conv in js_data:
for mess in conv['conversations']:
if 'audio_path' in mess:
mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
self.template=Llama2Template(
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["{{content}}<end_of_turn>\n"], tool_format="default"),
format_tools = ToolFormatter(tool_format="default"),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
default_system=default_system,
thought_words=("<think>", "</think>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin(name="base"),
enable_thinking=False
)
self.set_dataset_name(f"MultiturnCostumData")
self.data = []
self.text_only_data = []
for conv in js_data:
tools = conv['tools'] if 'tools' in conv else ""
system = conv['system'] if 'system' in conv else default_system
tmp = {
'tools':tools,
'system':system,
'messages':[],
}
for i,mess in enumerate(conv['conversations']):
tmp['messages'].append(mess)
if mess['from']=='human':
tmp['messages'].append(conv['conversations'][i+1])
d = deepcopy(tmp)
if not self.text_only and 'audio_path' in mess:
d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
self.data.append(d)
else:
self.text_only_data.append(deepcopy(tmp))
tmp['messages'].pop()
elif mess['from']=='observation':
tmp['messages'].append(conv['conversations'][i+1])
d = deepcopy(tmp)
self.text_only_data.append(d)
tmp['messages'].pop()
if text_only:
self.data=self.text_only_data
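    # Render every turn except the last one into the prompt using the template; the
    # final assistant turn becomes the answer. When the last user turn carries audio,
    # the target additionally asks the model to transcribe that turn
    # ("User transcribe is : ...") before producing the assistant reply.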
def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
ANSWER_SUFFIX = "<end_of_turn>"
prompt = ""
answer_text = ""
user_transcribe = ""
audio_paths = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0:
elements += self.template.format_prefix.apply()
if system or tools:
tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.template.format_system.apply(content=(system + tool_text))[0]
                    elements.append(system_text)
if message["from"] == "human":
if i==len(messages)-2 and not self.text_only:
user_transcribe = message["value"]
elements += self.template.format_user.apply(content='<start_of_audio>')
else:
elements += self.template.format_user.apply(content=message["value"])
if not self.text_only:
audio_paths.append(message['audio_path'])
elif message["from"] == "gpt":
elements += self.template.format_assistant.apply(content=message["value"])
elif message["from"] == "observation":
elements += self.template.format_observation.apply(content=message["value"])
elif message["from"] == "function_call":
elements += self.template.format_function.apply(content=message["value"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["from"]))
for elem in elements:
ele_str = ""
if isinstance(elem, str):
ele_str=elem
elif isinstance(elem, set):
if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
ele_str = self.processor.tokenizer.bos_token
elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
ele_str = self.processor.tokenizer.eos_token
if i == len(messages)-1:
answer_text+=ele_str
else:
prompt+=ele_str
        if audio_array is not None:
inputs = self.processor(
text=prompt,
audio=[audio_array],
add_special_tokens=False,
return_tensors='pt'
)
answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
else:
inputs = self.processor(
text=prompt,
audio=None,
add_special_tokens=False,
return_tensors='pt'
)
answer = f"{answer_text}{ANSWER_SUFFIX}"
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
if self.debug:
self.debug = False
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
if self.training:
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
labels = torch.full_like(input_ids, _IGNORE_INDEX)
labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]), dtype=inputs.token_type_ids.dtype)
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
else:
input_ids = inputs.input_ids
labels = answer_ids
token_type_ids = inputs.token_type_ids
        if audio_array is not None:
if not self.train:
return {
"audio_path": audio_paths,
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'input_audio_embeds': inputs.input_audio_embeds,
'audio_embed_sizes': inputs.audio_embed_sizes,
'input_modes': inputs.input_modes,
}
else:
return {
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'input_audio_embeds': inputs.input_audio_embeds,
'audio_embed_sizes': inputs.audio_embed_sizes,
'input_modes': inputs.input_modes,
}
else:
return {
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'input_audio_embeds': None,
'audio_embed_sizes': None,
'input_modes': inputs.input_modes,
}
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
return self.prepare_multiturn_model_inputs(
audio_array=data["audio_array"] if "audio_array" in data else None,
messages=data['messages'],
system=data["system"],
tools=data["tools"]
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
INSTRUCTION = {
"ast": [
"Translate the audio to {0}.",
"Translate the audio clip into {0}.",
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
"Translate the provided audio file into {0}.",
"Convert the audio speech to {0} text.",
"Write an {0} translation of the audio file.",
"Translate spoken words from the audio into {0}.",
"Create an {0} version of the audio content.",
"Produce an accurate {0} translation of the audio.",
"Extract speech from the audio and translate it to {0}.",
"Turn the audio into readable {0} text.",
"Write all spoken content from the audio in {0}.",
"Generate an {0} translation of the speech in the file.",
"Convert the recording into {0} text.",
"Accurately translate the audio recording to {0}.",
"Write down dialogue from the given audio in {0}.",
"Translate all speech in this audio file to {0}.",
"Create an accurate {0} version of the speech.",
"Perform a complete {0} translation of the audio."
],
"asr": [
"Transcribe the audio clip into text.",
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
"Transcribe the provided audio file into text.",
"Convert the audio speech to text.",
"Write a transcript of the audio file.",
"Transcribe spoken words from the audio.",
"Create a text version of the audio content.",
"Produce a verbatim transcript of the audio.",
"Extract and transcribe speech from the audio.",
"Turn the audio into readable text.",
"Write all spoken words from the audio.",
"Generate a transcript of the speech in the file.",
"Convert the recording into a text transcript.",
"Accurately transcribe the audio recording.",
"Write down dialogue from the given audio.",
"Transcribe all speech in this audio file.",
"Create an accurate text version of the speech.",
"Perform a complete transcription of the audio."
],
}