import os

# NOTE: HF_HOME must be set before `datasets`/`transformers` are imported,
# otherwise the cache location is resolved too early to pick it up.
os.environ["HF_HOME"] = "/mnt/jeff/huggingface"

import datasets

datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"

import json
import random
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
import sacrebleu
import soundfile as sf
import torch
import torchaudio
from datasets import Audio, load_dataset
from torch.utils.data import ConcatDataset, Dataset
from tqdm import tqdm
from transformers import BatchFeature

ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100


class BaseAudioDataset(Dataset):
    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        self.processor = processor
        self.training = "train" in split or "other" in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        self.name = name

    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
        original_size = len(data)

        # Disable decoding so corrupted files can be detected without loading full audio.
        data = data.cast_column(audio_field, Audio(decode=False))

        def identify_corrupted_files(example):
            try:
                sf.read(example[audio_field]["path"])
                for field in text_fields:
                    if field in example and example[field].replace('"', "") == "":
                        return False
                return True
            except Exception:
                return False

        data = data.filter(identify_corrupted_files, num_proc=16)
        validated_size = len(data)

        # Re-enable decoding at the target sampling rate for downstream use.
        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

        if debug:
            print(f"Dataset: {dataset_name}")
            print(f"Samples before filtering: {original_size}")
            print(f"Samples after filtering: {validated_size}")
            print(f"Retention ratio: {validated_size / original_size:.2%}")

        return data

    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
        original_size = len(data)

        def filter_audio_by_length(example):
            try:
                audio = example[audio_field]["array"]
                channel = 1
                if hasattr(audio, "ndim") and audio.ndim > 1:
                    channel = audio.ndim
                    audio = audio.squeeze()
                audio_length = len(audio) / example[audio_field]["sampling_rate"] / channel
                return min_sec <= audio_length <= max_sec
            except Exception as e:
                if debug:
                    print(f"Error: {str(e)[:100]}... - sample excluded")
                return False

        data = data.filter(filter_audio_by_length, num_proc=16)
        filtered_size = len(data)

        if debug:
            print(f"Samples before length filtering: {original_size}")
            print(f"Samples after length filtering: {filtered_size}")
            print(f"Retention ratio: {filtered_size / original_size:.2%}")

        return data

    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        user_message = {
            'role': 'user',
            'content': '<start_of_audio>' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )

        inputs = self.processor(
            text=prompt,
            audio=[audio_array],
            add_special_tokens=False,
            return_tensors='pt'
        )

        answer = f"{answer_text}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids

        if self.debug:
            # Only print the first sample to avoid flooding the log.
            self.debug = False
            task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
            lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
            print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
            print(f"INPUT_MODE: {inputs.input_modes[0].item()}")

        if self.training:
            # Append the answer tokens and mask everything except the answer in the labels.
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
            # Keep the dtype consistent with the existing token_type_ids when padding.
            padding = torch.zeros(
                (inputs.token_type_ids.shape[0], answer_ids.shape[1]),
                dtype=inputs.token_type_ids.dtype,
            )
            token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
            token_type_ids = inputs.token_type_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'token_type_ids': token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
        }

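
# Note (for reference only): during training the labels produced above are the prompt
# positions masked with -100 followed by the answer ids, e.g. for a 5-token prompt and a
# 3-token answer the labels row is [-100, -100, -100, -100, -100, a1, a2, a3], so the
# loss is computed only on the answer (including its <end_of_turn> suffix).
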

class LibriSpeechDataset(BaseAudioDataset):
    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"LibriSpeech_{subset}")
        self.ast = False
        self.lang = "en"

        self.data = load_dataset(
            "/mnt/jeff/InCar/data/librispeech_asr",
            subset,
            split=split,
            trust_remote_code=True,
            cache_dir=Path("/mnt/jeff/InCar/data"),
        )

        self.data = self.filter_by_audio_length(self.data, "audio")

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        # LibriSpeech transcripts occasionally contain stray double quotes; strip them.
        answer_text = data["text"].replace('"', '')
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text,
        )

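
# --- Illustrative usage sketch (not called anywhere in this file) ---------------------
# A minimal example of how one of these ASR datasets is expected to be built, assuming a
# compatible audio-capable processor object is available. The helper name and the
# subset/split values below are placeholders, not something this file defines.
def _example_build_librispeech(processor):
    """Hypothetical helper: return a debug-enabled LibriSpeech training dataset."""
    dataset = LibriSpeechDataset(processor, subset="clean", split="train.360", debug=True)
    sample = dataset[0]  # dict with input_ids, labels, token_type_ids, audio features
    return dataset, sample
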

class CommonVoiceDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"CommonVoice_{source_lang}")
        self.ast = False
        self.lang = source_lang

        # zh-TW is read from the local Common Voice 16.1 dump; other languages use 17.0.
        if source_lang == "zh-TW":
            data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
        else:
            data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
        self.data = load_dataset(
            data_path,
            source_lang,
            split=split,
            trust_remote_code=True,
            cache_dir=Path("/mnt/jeff/InCar/data"),
        )

        def prepare_dataset(batch):
            """Normalize transcriptions before training (used with `.map`)."""
            transcription = batch["sentence"]

            # Strip surrounding double quotes.
            if transcription.startswith('"') and transcription.endswith('"'):
                transcription = transcription[1:-1]

            # Make sure the sentence ends with terminal punctuation.
            if transcription and transcription[-1] not in [".", "?", "!"]:
                transcription = transcription + "."

            batch["sentence"] = transcription
            return batch

        # Convert Simplified Chinese transcriptions to Traditional (Taiwan) characters.
        import opencc

        converter = opencc.OpenCC('s2tw.json')

        def To_zhTW(batch):
            batch["sentence"] = converter.convert(batch["sentence"])
            return batch

        self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
        if source_lang == 'zh-CN':
            self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")

        self.data = self.filter_by_audio_length(self.data, "audio")

        # For the zh-TW training split, add speed-perturbed copies (0.9x and 1.1x) as data
        # augmentation. The perturbed versions are cached to disk with pickle so the
        # augmentation only has to run once.
        if source_lang == "zh-TW" and split == 'train':
            import pickle

            from torchaudio import transforms

            def subsample(batch):
                # Currently unused helper: resample a clip to 16 kHz.
                batch['audio']['array'] = torchaudio.functional.resample(
                    torch.FloatTensor(batch['audio']['array']),
                    orig_freq=batch['audio']['sampling_rate'],
                    new_freq=16000,
                )
                batch['audio']['sampling_rate'] = 16000
                return batch

            def TW_data_augment_fast(batch):
                speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
                new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
                batch['audio']['array'] = new_array_fast
                return batch

            def TW_data_augment_slow(batch):
                speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
                new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
                batch['audio']['array'] = new_array_slow
                return batch

            fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
            if not os.path.exists(fast_path):
                data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
                with open(fast_path, 'wb') as f:
                    pickle.dump(data_fast, f)
            else:
                with open(fast_path, 'rb') as f:
                    data_fast = pickle.load(f)

            slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
            if not os.path.exists(slow_path):
                data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
                with open(slow_path, 'wb') as f:
                    pickle.dump(data_slow, f)
            else:
                with open(slow_path, 'rb') as f:
                    data_slow = pickle.load(f)

            self.data = [d for d in self.data] + [d for d in data_fast] + [d for d in data_slow]

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        answer_text = data["sentence"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text,
        )

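
# --- Illustrative sketch: speed perturbation on a single waveform ---------------------
# A standalone version of the torchaudio transform used for the zh-TW augmentation above,
# handy for sanity-checking a single clip. The helper name is illustrative only and is
# not referenced elsewhere in this file.
def _example_speed_perturb(waveform: torch.Tensor, sample_rate: int = 16000, factor: float = 1.1) -> torch.Tensor:
    """Return `waveform` time-stretched by `factor` (1.1 means roughly 10% faster)."""
    from torchaudio import transforms

    perturb = transforms.SpeedPerturbation(sample_rate, [factor])
    new_waveform, _lengths = perturb(waveform)
    return new_waveform
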

class FleursDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, target_lang=None,
                 mode="asr", sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("Fleurs")

        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")

        self.mode = mode
        self.ast = (mode == "ast")
        self.source_lang = source_lang

        self.lang_names = {
            'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
        }

        self.data = load_dataset(
            "/mnt/jeff/InCar/data/fleurs",
            source_lang,
            split=split,
            trust_remote_code=True,
            cache_dir=Path("/mnt/jeff/InCar/data"),
        )

        # Convert Simplified Chinese transcriptions to Traditional (Taiwan) characters.
        import opencc

        converter = opencc.OpenCC('s2tw.json')

        def prepare_dataset(batch):
            batch["transcription"] = converter.convert(batch["transcription"])
            return batch

        if source_lang == "cmn_hans_cn":
            self.data = self.data.map(prepare_dataset, desc="preprocess dataset")

        self.data = self.filter_by_audio_length(self.data, "audio")
        self.target_lang_name = ""

        if self.ast:
            if target_lang is None:
                raise ValueError("AST mode requires target_lang.")

            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"

            # Load the target-language split and pair utterances by their shared FLEURS id.
            target_data = load_dataset(
                "/mnt/jeff/InCar/data/fleurs",
                target_lang,
                split=split,
                trust_remote_code=True,
                cache_dir=Path("/mnt/jeff/InCar/data"),
            )
            if target_lang == "cmn_hans_cn":
                target_data = target_data.map(prepare_dataset, desc="preprocess dataset")

            source_dict = {item['id']: item for item in self.data}
            target_dict = {item['id']: item for item in target_data}

            common_ids = set(source_dict.keys()) & set(target_dict.keys())
            print(f"FLEURS AST common data filtering: {len(self.data)} -> {len(common_ids)}")
            self.data = [
                {**source_dict[id], 'translation': target_dict[id]['transcription']}
                for id in common_ids
            ]

            self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = random.choice(INSTRUCTION["ast"])
        else:
            self.lang = source_lang
            self.instruction = random.choice(INSTRUCTION["asr"])

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        audio_array = data["audio"]["array"]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["transcription"]

        # ASR instructions contain no placeholder, so format() is a no-op in that case.
        return self.prepare_model_inputs(
            audio_array,
            self.instruction.format(self.target_lang_name),
            answer_text,
        )

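
# --- Illustrative usage sketch (not called anywhere in this file) ---------------------
# How the FLEURS dataset is expected to be built in speech-translation (AST) mode,
# assuming a compatible processor object; the language codes mirror those used above and
# the helper name is a placeholder.
def _example_build_fleurs_ast(processor):
    """Hypothetical helper: Mandarin speech paired with English translations."""
    return FleursDataset(
        processor,
        split="train",
        source_lang="cmn_hans_cn",
        target_lang="en_us",
        mode="ast",
        debug=True,
    )
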

class TWCostumData(BaseAudioDataset):
    def __init__(self, processor, split="train", sampling_rate=16000, csv_path="", debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        # The local import intentionally shadows torch's Dataset inside this scope.
        from datasets import Audio, Dataset

        df = pd.read_csv(csv_path).fillna('')

        self.set_dataset_name("TWCostumData")
        self.data = Dataset.from_dict(
            {
                "audio": [audio for audio in df['audio']],
                "sentence": [text for text in df['text']],
            }
        ).cast_column("audio", Audio(sampling_rate=16000))

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        answer_text = data["sentence"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text,
        )


class TWCostumDataTasks(BaseAudioDataset):
    def __init__(self, processor, split="train", sampling_rate=16000, json_path="", debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        from datasets import Audio, Dataset

        with open(json_path) as f:
            js_data = json.load(f)

        # Flatten every audio turn of every conversation into (audio, sentence) pairs.
        raw_data = {
            "audio": [],
            "sentence": [],
        }
        for conv in js_data:
            for mess in conv['conversations']:
                if 'audio_path' in mess:
                    raw_data['audio'].append(mess['audio_path'])
                    raw_data['sentence'].append(mess["value"])

        self.set_dataset_name("TWCostumDataTasks" + json_path)
        self.data = Dataset.from_dict(raw_data).cast_column("audio", Audio(sampling_rate=16000))

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        answer_text = data["sentence"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text,
        )


def covost_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    token_type_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    audio_paths = []
    for inputs in batch:
        if 'audio_path' in inputs:
            audio_paths.append(inputs['audio_path'])
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        token_type_ids_list.append(inputs['token_type_ids'][0])
        # input_modes == 2 marks samples that actually carry audio features.
        if inputs['input_modes'] == 2:
            input_audio_embeds_list.append(inputs['input_audio_embeds'])
            audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
            audio_attention_mask_list.append(
                inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
            )
        input_modes_list.append(inputs['input_modes'])

    token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
    input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
    labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
    # A padded audio attention mask is only built when more than one audio clip has to be
    # batched together; with a single clip the mask is left as None (treated as all-True).
    audio_attention_mask = (
        pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
        if len(audio_attention_mask_list) > 1
        else None
    )

    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list) > 0 else None
    audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list) > 0 else None
    input_modes = torch.cat(input_modes_list)

    features = {
        'input_ids': input_ids,
        'labels': labels,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'input_audio_embeds': input_audio_embeds,
        'audio_embed_sizes': audio_embed_sizes,
        'audio_attention_mask': audio_attention_mask,
        'input_modes': input_modes,
    }
    # Only evaluation samples carry raw audio paths (see MultiturnAudioDataset).
    if len(audio_paths) > 0:
        features = {"audio_path": audio_paths, **features}
    return BatchFeature(features)

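
# --- Illustrative usage sketch (not called anywhere in this file) ---------------------
# How the collate function above is meant to be wired into a torch DataLoader; the helper
# name and batch size are placeholders, not values this file defines.
def _example_build_dataloader(dataset, batch_size=4):
    """Hypothetical helper: produce left-padded batches ready for the audio LM."""
    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=covost_collate_fn,
    )
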

def pad_sequence(sequences, padding_side='left', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors of shape [seq_len, *]
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output

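
# Example (for reference only): left-padding two ID sequences of lengths 3 and 1 yields a
# [2, 3] tensor where the shorter row is prefixed with the padding value, e.g.
#   pad_sequence([torch.tensor([7, 8, 9]), torch.tensor([5])], padding_side='left')
#   -> tensor([[7, 8, 9],
#              [0, 0, 5]])
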

def cat_with_pad(tensors, dim, padding_value=0):
    """
    Concatenate along `dim`, padding every other dimension to the largest size.
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Copy each tensor into its slot along `dim`, leaving padding elsewhere.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
        index += t.shape[dim]

    return output

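
# Example (for reference only): concatenating a [1, 4, 8] and a [1, 6, 8] tensor with
# cat_with_pad(tensors, dim=0) produces a [2, 6, 8] tensor in which the first row is
# zero-padded from length 4 to length 6 along dim 1.
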

class MultiturnAudioDataset(BaseAudioDataset):
    def __init__(self, processor, split="train", sampling_rate=16000, json_path="", text_only=False, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
        from llamafactory.data.mm_plugin import get_mm_plugin
        from llamafactory.data.template import Llama2Template

        self.train = False
        self.text_only = text_only
        with open(json_path) as f:
            js_data = json.load(f)

        # Hold out at most 200 conversations (or 20%, whichever is smaller) for evaluation.
        test_len = int(min(len(js_data) * 0.2, 200))
        if split == 'train':
            self.train = True
            js_data = js_data[:len(js_data) - test_len]
        else:
            js_data = js_data[-test_len:]

        # Remap audio paths from the TTS-generation machine to the local mount.
        for conv in js_data:
            for mess in conv['conversations']:
                if 'audio_path' in mess:
                    mess['audio_path'] = mess['audio_path'].replace(
                        '/home/jeff/codes/llm/InCar/srdc_generate_tts/',
                        '/mnt/jeff/InCar/data/multiturn_data/',
                    )

        default_system = ""
        # Gemma-style chat template expressed with LLaMA-Factory formatter primitives.
        self.template = Llama2Template(
            format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
            format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
            format_system=StringFormatter(slots=["{{content}}\n\n"]),
            format_function=FunctionFormatter(slots=["{{content}}<end_of_turn>\n"], tool_format="default"),
            format_tools=ToolFormatter(tool_format="default"),
            format_observation=StringFormatter(
                slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
            ),
            default_system=default_system,
            thought_words=("<think>", "</think>"),
            efficient_eos=False,
            replace_eos=False,
            replace_jinja_template=False,
            format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
            stop_words=["<end_of_turn>"],
            mm_plugin=get_mm_plugin(name="base"),
            enable_thinking=False,
        )

        self.set_dataset_name("MultiturnCostumData")

        # Build one training example per assistant turn. Turns that answer an audio user
        # message become audio samples; turns that answer a text user message or a tool
        # observation become text-only samples.
        self.data = []
        self.text_only_data = []
        for conv in js_data:
            tools = conv['tools'] if 'tools' in conv else ""
            system = conv['system'] if 'system' in conv else default_system
            tmp = {
                'tools': tools,
                'system': system,
                'messages': [],
            }
            for i, mess in enumerate(conv['conversations']):
                tmp['messages'].append(mess)
                if mess['from'] == 'human':
                    # Assumes every human/observation turn is immediately followed by a reply.
                    tmp['messages'].append(conv['conversations'][i + 1])
                    d = deepcopy(tmp)
                    if not self.text_only and 'audio_path' in mess:
                        d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
                        self.data.append(d)
                    else:
                        self.text_only_data.append(deepcopy(tmp))
                    tmp['messages'].pop()
                elif mess['from'] == 'observation':
                    tmp['messages'].append(conv['conversations'][i + 1])
                    d = deepcopy(tmp)
                    self.text_only_data.append(d)
                    tmp['messages'].pop()
        if text_only:
            self.data = self.text_only_data

    def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
        ANSWER_SUFFIX = "<end_of_turn>"
        prompt = ""
        answer_text = ""
        user_transcribe = ""
        audio_paths = []
        for i, message in enumerate(messages):
            elements = []

            if i == 0:
                elements += self.template.format_prefix.apply()
                if system or tools:
                    tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
                    system_text = self.template.format_system.apply(content=(system + tool_text))[0]
                    elements += system_text

            if message["from"] == "human":
                # The final user turn is replaced by the audio placeholder; its transcript is
                # kept so it can be included in the training target below.
                if i == len(messages) - 2 and not self.text_only:
                    user_transcribe = message["value"]
                    elements += self.template.format_user.apply(content='<start_of_audio>')
                else:
                    elements += self.template.format_user.apply(content=message["value"])
                if not self.text_only:
                    audio_paths.append(message['audio_path'])
            elif message["from"] == "gpt":
                elements += self.template.format_assistant.apply(content=message["value"])
            elif message["from"] == "observation":
                elements += self.template.format_observation.apply(content=message["value"])
            elif message["from"] == "function_call":
                elements += self.template.format_function.apply(content=message["value"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["from"]))

            # Render the formatter slots to text; the last message becomes the answer,
            # everything before it becomes the prompt.
            for elem in elements:
                ele_str = ""
                if isinstance(elem, str):
                    ele_str = elem
                elif isinstance(elem, set):
                    if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
                        ele_str = self.processor.tokenizer.bos_token
                    elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
                        ele_str = self.processor.tokenizer.eos_token
                if i == len(messages) - 1:
                    answer_text += ele_str
                else:
                    prompt += ele_str

        # Audio samples get the audio clip plus a combined "transcribe + respond" target;
        # text-only samples just predict the assistant reply.
        if audio_array is not None:
            inputs = self.processor(
                text=prompt,
                audio=[audio_array],
                add_special_tokens=False,
                return_tensors='pt'
            )
            answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe, answer_text, ANSWER_SUFFIX)
        else:
            inputs = self.processor(
                text=prompt,
                audio=None,
                add_special_tokens=False,
                return_tensors='pt'
            )
            answer = f"{answer_text}{ANSWER_SUFFIX}"

        answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids

        if self.debug:
            self.debug = False
            task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
            lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
            print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
            print(f"INPUT_MODE: {inputs.input_modes[0].item()}")

        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros(
                (inputs.token_type_ids.shape[0], answer_ids.shape[1]),
                dtype=inputs.token_type_ids.dtype,
            )
            token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
            token_type_ids = inputs.token_type_ids

        if audio_array is not None:
            output = {
                'input_ids': input_ids,
                'labels': labels,
                'token_type_ids': token_type_ids,
                'input_audio_embeds': inputs.input_audio_embeds,
                'audio_embed_sizes': inputs.audio_embed_sizes,
                'input_modes': inputs.input_modes,
            }
            if not self.train:
                # Keep the raw audio paths around for evaluation-time bookkeeping.
                output['audio_path'] = audio_paths
            return output
        else:
            return {
                'input_ids': input_ids,
                'labels': labels,
                'token_type_ids': token_type_ids,
                'input_audio_embeds': None,
                'audio_embed_sizes': None,
                'input_modes': inputs.input_modes,
            }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        return self.prepare_multiturn_model_inputs(
            audio_array=data["audio_array"] if "audio_array" in data else None,
            messages=data['messages'],
            system=data["system"],
            tools=data["tools"],
        )

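
# --- Illustrative usage sketch (not called anywhere in this file) ---------------------
# ConcatDataset is imported above but not used here; a typical training mix would combine
# several of these datasets and hand the result to the collate function, roughly as
# sketched below. The helper name, paths, subset, and splits are placeholders, assuming a
# compatible processor object.
def _example_build_training_mix(processor):
    """Hypothetical helper: one ConcatDataset over ASR and multiturn data."""
    parts = [
        LibriSpeechDataset(processor, subset="clean", split="train.360"),
        CommonVoiceDataset(processor, split="train", source_lang="zh-TW"),
        MultiturnAudioDataset(processor, split="train", json_path="/mnt/jeff/InCar/data/multiturn.json"),
    ]
    return ConcatDataset(parts)
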
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
INSTRUCTION = {
    "ast": [
        "Translate the audio to {0}.",
        "Translate the audio clip into {0}.",
        "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
        "Translate the provided audio file into {0}.",
        "Convert the audio speech to {0} text.",
        "Write an {0} translation of the audio file.",
        "Translate spoken words from the audio into {0}.",
        "Create an {0} version of the audio content.",
        "Produce an accurate {0} translation of the audio.",
        "Extract speech from the audio and translate it to {0}.",
        "Turn the audio into readable {0} text.",
        "Write all spoken content from the audio in {0}.",
        "Generate an {0} translation of the speech in the file.",
        "Convert the recording into {0} text.",
        "Accurately translate the audio recording to {0}.",
        "Write down dialogue from the given audio in {0}.",
        "Translate all speech in this audio file to {0}.",
        "Create an accurate {0} version of the speech.",
        "Perform a complete {0} translation of the audio.",
    ],
    "asr": [
        "Transcribe the audio clip into text.",
        "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
        "Transcribe the provided audio file into text.",
        "Convert the audio speech to text.",
        "Write a transcript of the audio file.",
        "Transcribe spoken words from the audio.",
        "Create a text version of the audio content.",
        "Produce a verbatim transcript of the audio.",
        "Extract and transcribe speech from the audio.",
        "Turn the audio into readable text.",
        "Write all spoken words from the audio.",
        "Generate a transcript of the speech in the file.",
        "Convert the recording into a text transcript.",
        "Accurately transcribe the audio recording.",
        "Write down dialogue from the given audio.",
        "Transcribe all speech in this audio file.",
        "Create an accurate text version of the speech.",
        "Perform a complete transcription of the audio.",
    ],
}
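
# Example (for reference only): in AST mode a prompt is drawn and formatted with the
# human-readable target language name, e.g.
#   random.choice(INSTRUCTION["ast"]).format("English")
#   -> "Translate the audio to English."  (or another variant from the list)
# ASR prompts contain no placeholder, so .format() leaves them unchanged.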