Upload 25 files
Browse files- .gitattributes +1 -0
- ASRDataset.py +793 -0
- added_tokens.json +3 -0
- chat_template.json +3 -0
- config.json +118 -0
- configuration_gemma3omni.py +206 -0
- eval.py +635 -0
- eval_multiturn.ipynb +0 -0
- eval_multiturn.py +211 -0
- merge_lora.ipynb +119 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_gemma3omni.py +668 -0
- preprocessing_gemma3omni.py +444 -0
- preprocessor_config.json +41 -0
- processor_config.json +7 -0
- special_tokens_map.json +36 -0
- speech_conformer_encoder.py +0 -0
- tokenizer.json +3 -0
- tokenizer.model +3 -0
- tokenizer_config.json +0 -0
- training.py +883 -0
- training_multiturn.py +329 -0
- training_multiturn_textonly.py +333 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
ASRDataset.py
ADDED
|
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import sacrebleu
|
| 13 |
+
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from transformers import (
|
| 18 |
+
BatchFeature,
|
| 19 |
+
)
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
from datasets import Audio
|
| 23 |
+
import random
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
import torchaudio
|
| 26 |
+
|
| 27 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 28 |
+
_IGNORE_INDEX = -100
|
| 29 |
+
class BaseAudioDataset(Dataset):
|
| 30 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 31 |
+
self.processor = processor
|
| 32 |
+
self.training = "train" in split or 'other' in split
|
| 33 |
+
self.debug = debug
|
| 34 |
+
self.sampling_rate = sampling_rate
|
| 35 |
+
self.name = ""
|
| 36 |
+
|
| 37 |
+
def set_dataset_name(self, name):
|
| 38 |
+
self.name = name
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 42 |
+
original_size = len(data)
|
| 43 |
+
|
| 44 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 45 |
+
|
| 46 |
+
def identify_corrupted_files(example):
|
| 47 |
+
try:
|
| 48 |
+
sf.read(example[audio_field]["path"])
|
| 49 |
+
|
| 50 |
+
for field in text_fields:
|
| 51 |
+
if field in example and example[field].replace('"', '') == "":
|
| 52 |
+
return False
|
| 53 |
+
return True
|
| 54 |
+
except Exception:
|
| 55 |
+
return False
|
| 56 |
+
|
| 57 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 58 |
+
validated_size = len(data)
|
| 59 |
+
|
| 60 |
+
# Audio Decoding
|
| 61 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 62 |
+
|
| 63 |
+
if debug:
|
| 64 |
+
print(f"Dataset: {dataset_name}")
|
| 65 |
+
print(f"Original data nums: {original_size}")
|
| 66 |
+
print(f"After filtering data nums: {validated_size}")
|
| 67 |
+
print(f"Filtering ratio: {validated_size/original_size:.2%}")
|
| 68 |
+
|
| 69 |
+
return data
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 73 |
+
original_size = len(data)
|
| 74 |
+
|
| 75 |
+
def filter_audio_by_length(example):
|
| 76 |
+
try:
|
| 77 |
+
audio = example[audio_field]['array']
|
| 78 |
+
channel = 1
|
| 79 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 80 |
+
channel = audio.ndim
|
| 81 |
+
audio = audio.squeeze()
|
| 82 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 83 |
+
return min_sec <= audio_length <= max_sec
|
| 84 |
+
except Exception as e:
|
| 85 |
+
if debug:
|
| 86 |
+
print(f"Error : {str(e)[:100]}... - sample excluded")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 90 |
+
filtered_size = len(data)
|
| 91 |
+
|
| 92 |
+
if debug:
|
| 93 |
+
print(f"Before Length Filtering data nums: {original_size}")
|
| 94 |
+
print(f"After Length Filtering data nums: {filtered_size}")
|
| 95 |
+
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
|
| 96 |
+
|
| 97 |
+
return data
|
| 98 |
+
|
| 99 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 100 |
+
user_message = {
|
| 101 |
+
'role': 'user',
|
| 102 |
+
'content': '<start_of_audio>' + instruction,
|
| 103 |
+
}
|
| 104 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 105 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
inputs = self.processor(
|
| 109 |
+
text=prompt,
|
| 110 |
+
audio=[audio_array],
|
| 111 |
+
add_special_tokens=False,
|
| 112 |
+
return_tensors='pt'
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 116 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 117 |
+
|
| 118 |
+
if self.debug:
|
| 119 |
+
self.debug = False
|
| 120 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 121 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 122 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 123 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 124 |
+
|
| 125 |
+
if self.training:
|
| 126 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 127 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 128 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 129 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 130 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 131 |
+
else:
|
| 132 |
+
input_ids = inputs.input_ids
|
| 133 |
+
labels = answer_ids
|
| 134 |
+
token_type_ids = inputs.token_type_ids
|
| 135 |
+
|
| 136 |
+
return {
|
| 137 |
+
'input_ids': input_ids,
|
| 138 |
+
'labels': labels,
|
| 139 |
+
'token_type_ids': token_type_ids,
|
| 140 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 141 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 142 |
+
'input_modes': inputs.input_modes,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
# Libri Speech Dataset Class
|
| 146 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 147 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 148 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 149 |
+
|
| 150 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 151 |
+
# only ASR
|
| 152 |
+
self.ast = False
|
| 153 |
+
self.lang = "en"
|
| 154 |
+
|
| 155 |
+
# load dataset
|
| 156 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
|
| 157 |
+
subset,
|
| 158 |
+
split=split,
|
| 159 |
+
trust_remote_code=True,
|
| 160 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# (Optional) Audio length Filtering
|
| 164 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 165 |
+
|
| 166 |
+
# Instruction Setting
|
| 167 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 168 |
+
|
| 169 |
+
def __len__(self):
|
| 170 |
+
return len(self.data)
|
| 171 |
+
|
| 172 |
+
def __getitem__(self, idx):
|
| 173 |
+
data = self.data[idx]
|
| 174 |
+
|
| 175 |
+
# Libri Speech is only for ASR
|
| 176 |
+
answer_text = data["text"].replace('"', '')
|
| 177 |
+
|
| 178 |
+
return self.prepare_model_inputs(
|
| 179 |
+
data["audio"]["array"],
|
| 180 |
+
self.instruction,
|
| 181 |
+
answer_text
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# common_voice_16_1 dataset
|
| 185 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 186 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 187 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 188 |
+
|
| 189 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 190 |
+
# only ASR
|
| 191 |
+
self.ast = False
|
| 192 |
+
self.lang=source_lang
|
| 193 |
+
|
| 194 |
+
# load dataset
|
| 195 |
+
if source_lang=="zh-TW":
|
| 196 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
|
| 197 |
+
else:
|
| 198 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
|
| 199 |
+
self.data = load_dataset(data_path,
|
| 200 |
+
source_lang,
|
| 201 |
+
split=split,
|
| 202 |
+
trust_remote_code=True,
|
| 203 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 204 |
+
)
|
| 205 |
+
def prepare_dataset(batch):
|
| 206 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 207 |
+
transcription = batch["sentence"]
|
| 208 |
+
|
| 209 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 210 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 211 |
+
transcription = transcription[1:-1]
|
| 212 |
+
|
| 213 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 214 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 215 |
+
transcription = transcription + "."
|
| 216 |
+
|
| 217 |
+
batch["sentence"] = transcription
|
| 218 |
+
|
| 219 |
+
return batch
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
import opencc
|
| 223 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 224 |
+
def To_zhTW(batch):
|
| 225 |
+
|
| 226 |
+
transcription = converter.convert(batch["sentence"])
|
| 227 |
+
batch["sentence"] = transcription
|
| 228 |
+
|
| 229 |
+
return batch
|
| 230 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 231 |
+
if source_lang=='zh-CN':
|
| 232 |
+
self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# (Optional) Audio length Filtering
|
| 236 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 237 |
+
|
| 238 |
+
if source_lang == "zh-TW" and split=='train':
|
| 239 |
+
import torchaudio
|
| 240 |
+
from torchaudio import transforms
|
| 241 |
+
import copy
|
| 242 |
+
import pickle
|
| 243 |
+
import os
|
| 244 |
+
def subsample(batch):
|
| 245 |
+
batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
|
| 246 |
+
batch['audio']['sampling_rate']=16000
|
| 247 |
+
return batch
|
| 248 |
+
def TW_data_augment_fast(batch):
|
| 249 |
+
speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
|
| 250 |
+
new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
|
| 251 |
+
batch['audio']['array'] = new_array_fast
|
| 252 |
+
return batch
|
| 253 |
+
def TW_data_augment_slow(batch):
|
| 254 |
+
speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
|
| 255 |
+
new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
|
| 256 |
+
batch['audio']['array'] = new_array_slow
|
| 257 |
+
return batch
|
| 258 |
+
# data = self.data.map(subsample, num_proc=1, desc="subsample")
|
| 259 |
+
fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
|
| 260 |
+
if not os.path.exists(fast_path):
|
| 261 |
+
data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
|
| 262 |
+
with open(fast_path,'wb') as f:
|
| 263 |
+
pickle.dump(data_fast,f)
|
| 264 |
+
else:
|
| 265 |
+
with open(fast_path,'rb') as f:
|
| 266 |
+
data_fast=pickle.load(f)
|
| 267 |
+
|
| 268 |
+
slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
|
| 269 |
+
if not os.path.exists(slow_path):
|
| 270 |
+
data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
|
| 271 |
+
with open(slow_path,'wb') as f:
|
| 272 |
+
pickle.dump(data_slow,f)
|
| 273 |
+
else:
|
| 274 |
+
with open(slow_path,'rb') as f:
|
| 275 |
+
data_slow=pickle.load(f)
|
| 276 |
+
self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
|
| 277 |
+
|
| 278 |
+
# Instruction Setting
|
| 279 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 280 |
+
|
| 281 |
+
def __len__(self):
|
| 282 |
+
return len(self.data)
|
| 283 |
+
|
| 284 |
+
def __getitem__(self, idx):
|
| 285 |
+
data = self.data[idx]
|
| 286 |
+
|
| 287 |
+
answer_text = data["sentence"]
|
| 288 |
+
return self.prepare_model_inputs(
|
| 289 |
+
data["audio"]["array"],
|
| 290 |
+
self.instruction,
|
| 291 |
+
answer_text
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Fleurs Dataset Class
|
| 296 |
+
class FleursDataset(BaseAudioDataset):
|
| 297 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 298 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 299 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 300 |
+
|
| 301 |
+
self.set_dataset_name("Fleurs")
|
| 302 |
+
# Mode Setting (ASR or AST)
|
| 303 |
+
if mode not in ["asr", "ast"]:
|
| 304 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 305 |
+
|
| 306 |
+
self.mode = mode
|
| 307 |
+
self.ast = (mode == "ast")
|
| 308 |
+
self.source_lang = source_lang
|
| 309 |
+
|
| 310 |
+
# Language name mapping (expand if needed)
|
| 311 |
+
self.lang_names = {
|
| 312 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
# load dataset - source language dataset
|
| 316 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 317 |
+
source_lang,
|
| 318 |
+
split=split,
|
| 319 |
+
trust_remote_code=True,
|
| 320 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 321 |
+
)
|
| 322 |
+
import opencc
|
| 323 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 324 |
+
def prepare_dataset(batch):
|
| 325 |
+
transcription = converter.convert(batch["transcription"])
|
| 326 |
+
batch["transcription"] = transcription
|
| 327 |
+
|
| 328 |
+
return batch
|
| 329 |
+
if (source_lang=="cmn_hans_cn"):
|
| 330 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 331 |
+
|
| 332 |
+
# (Optional) Audio length Filtering
|
| 333 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 334 |
+
self.target_lang_name = ""
|
| 335 |
+
# When AST mode, load target language dataset.
|
| 336 |
+
if self.ast:
|
| 337 |
+
if target_lang is None:
|
| 338 |
+
raise ValueError("AST mode requires target_lang.")
|
| 339 |
+
|
| 340 |
+
self.target_lang = target_lang
|
| 341 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 342 |
+
|
| 343 |
+
# load dataset - target language dataset (for translation)
|
| 344 |
+
target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 345 |
+
target_lang,
|
| 346 |
+
split=split,
|
| 347 |
+
trust_remote_code=True,
|
| 348 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 349 |
+
)
|
| 350 |
+
if target_lang=="cmn_hans_cn":
|
| 351 |
+
target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
|
| 352 |
+
source_dict = {item['id']: item for item in self.data}
|
| 353 |
+
target_dict = {item['id']: item for item in target_data}
|
| 354 |
+
|
| 355 |
+
# only Common ID, add translation fields
|
| 356 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 357 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 358 |
+
self.data = [
|
| 359 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 360 |
+
for id in common_ids
|
| 361 |
+
]
|
| 362 |
+
|
| 363 |
+
# Instruction Setting - use target language name
|
| 364 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 365 |
+
self.instruction = random.choice(INSTRUCTION["ast"])
|
| 366 |
+
else:
|
| 367 |
+
# ASR mode
|
| 368 |
+
self.lang = source_lang
|
| 369 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 370 |
+
|
| 371 |
+
if self.debug:
|
| 372 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 373 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 374 |
+
if self.ast:
|
| 375 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 376 |
+
print(f"dataset size: {len(self.data)}")
|
| 377 |
+
|
| 378 |
+
def __len__(self):
|
| 379 |
+
return len(self.data)
|
| 380 |
+
|
| 381 |
+
def __getitem__(self, idx):
|
| 382 |
+
data = self.data[idx]
|
| 383 |
+
audio_array = data["audio"]["array"]
|
| 384 |
+
|
| 385 |
+
if self.ast:
|
| 386 |
+
answer_text = data["translation"]
|
| 387 |
+
else:
|
| 388 |
+
answer_text = data["transcription"]
|
| 389 |
+
|
| 390 |
+
return self.prepare_model_inputs(
|
| 391 |
+
audio_array,
|
| 392 |
+
self.instruction.format(self.target_lang_name),
|
| 393 |
+
answer_text
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
class TWCostumData(BaseAudioDataset):
|
| 397 |
+
|
| 398 |
+
def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
|
| 399 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 400 |
+
import pandas as pd
|
| 401 |
+
from datasets import Dataset, Audio
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
df = pd.read_csv(csv_path).fillna('')
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
self.set_dataset_name(f"TWCostumData")
|
| 408 |
+
self.data = Dataset.from_dict(
|
| 409 |
+
{
|
| 410 |
+
"audio": [audio for audio in df['audio']],
|
| 411 |
+
"sentence": [text for text in df['text']]
|
| 412 |
+
}
|
| 413 |
+
).cast_column("audio", Audio(sampling_rate=16000))
|
| 414 |
+
|
| 415 |
+
# Instruction Setting
|
| 416 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 417 |
+
|
| 418 |
+
def __len__(self):
|
| 419 |
+
return len(self.data)
|
| 420 |
+
|
| 421 |
+
def __getitem__(self, idx):
|
| 422 |
+
data = self.data[idx]
|
| 423 |
+
|
| 424 |
+
answer_text = data["sentence"]
|
| 425 |
+
return self.prepare_model_inputs(
|
| 426 |
+
data["audio"]["array"],
|
| 427 |
+
self.instruction,
|
| 428 |
+
answer_text
|
| 429 |
+
)
|
| 430 |
+
def covost_collate_fn(batch):
|
| 431 |
+
input_ids_list = []
|
| 432 |
+
labels_list = []
|
| 433 |
+
token_type_ids_list = []
|
| 434 |
+
input_audio_embeds_list = []
|
| 435 |
+
audio_embed_sizes_list = []
|
| 436 |
+
audio_attention_mask_list = []
|
| 437 |
+
input_modes_list = []
|
| 438 |
+
audio_paths = []
|
| 439 |
+
for inputs in batch:
|
| 440 |
+
if 'audio_path' in inputs:
|
| 441 |
+
audio_paths.append(inputs['audio_path'])
|
| 442 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 443 |
+
labels_list.append(inputs['labels'][0])
|
| 444 |
+
token_type_ids_list.append(inputs['token_type_ids'][0])
|
| 445 |
+
if inputs['input_modes']==2:
|
| 446 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 447 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 448 |
+
audio_attention_mask_list.append(
|
| 449 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 450 |
+
)
|
| 451 |
+
# else:
|
| 452 |
+
# input_audio_embeds_list.append(None)
|
| 453 |
+
# audio_embed_sizes_list.append(None)
|
| 454 |
+
# audio_attention_mask_list.append(None)
|
| 455 |
+
input_modes_list.append(inputs['input_modes'])
|
| 456 |
+
# try:
|
| 457 |
+
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
|
| 458 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 459 |
+
labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
|
| 460 |
+
audio_attention_mask = (
|
| 461 |
+
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
|
| 462 |
+
if len(audio_attention_mask_list) > 1
|
| 463 |
+
else None
|
| 464 |
+
)
|
| 465 |
+
# except Exception as e:
|
| 466 |
+
# print(e)
|
| 467 |
+
# print(input_ids_list)
|
| 468 |
+
# print(labels_list)
|
| 469 |
+
# raise
|
| 470 |
+
attention_mask = (input_ids != 0).long()
|
| 471 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
|
| 472 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
|
| 473 |
+
input_modes = torch.cat(input_modes_list)
|
| 474 |
+
if len(audio_paths)>0:
|
| 475 |
+
return BatchFeature(
|
| 476 |
+
{
|
| 477 |
+
"audio_path": audio_paths,
|
| 478 |
+
'input_ids': input_ids,
|
| 479 |
+
'labels': labels,
|
| 480 |
+
'token_type_ids': token_type_ids,
|
| 481 |
+
'attention_mask': attention_mask,
|
| 482 |
+
'input_audio_embeds': input_audio_embeds,
|
| 483 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 484 |
+
'audio_attention_mask': audio_attention_mask,
|
| 485 |
+
'input_modes': input_modes,
|
| 486 |
+
}
|
| 487 |
+
)
|
| 488 |
+
else:
|
| 489 |
+
return BatchFeature(
|
| 490 |
+
{
|
| 491 |
+
'input_ids': input_ids,
|
| 492 |
+
'labels': labels,
|
| 493 |
+
'token_type_ids': token_type_ids,
|
| 494 |
+
'attention_mask': attention_mask,
|
| 495 |
+
'input_audio_embeds': input_audio_embeds,
|
| 496 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 497 |
+
'audio_attention_mask': audio_attention_mask,
|
| 498 |
+
'input_modes': input_modes,
|
| 499 |
+
}
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 503 |
+
"""
|
| 504 |
+
Pad a list of sequences to the same length.
|
| 505 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 506 |
+
"""
|
| 507 |
+
assert padding_side in ['right', 'left']
|
| 508 |
+
max_size = sequences[0].size()
|
| 509 |
+
trailing_dims = max_size[1:]
|
| 510 |
+
max_len = max(len(seq) for seq in sequences)
|
| 511 |
+
batch_size = len(sequences)
|
| 512 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 513 |
+
for i, seq in enumerate(sequences):
|
| 514 |
+
length = seq.size(0)
|
| 515 |
+
if padding_side == 'right':
|
| 516 |
+
output.data[i, :length] = seq
|
| 517 |
+
else:
|
| 518 |
+
output.data[i, -length:] = seq
|
| 519 |
+
return output
|
| 520 |
+
|
| 521 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 522 |
+
"""
|
| 523 |
+
cat along dim, while pad to max for all other dims
|
| 524 |
+
"""
|
| 525 |
+
ndim = tensors[0].dim()
|
| 526 |
+
assert all(
|
| 527 |
+
t.dim() == ndim for t in tensors[1:]
|
| 528 |
+
), 'All tensors must have the same number of dimensions'
|
| 529 |
+
|
| 530 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 531 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 532 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 533 |
+
|
| 534 |
+
index = 0
|
| 535 |
+
for t in tensors:
|
| 536 |
+
# Create a slice list where every dimension except dim is full slice
|
| 537 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 538 |
+
# Update only the concat dimension slice
|
| 539 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 540 |
+
|
| 541 |
+
output[slices] = t
|
| 542 |
+
index += t.shape[dim]
|
| 543 |
+
|
| 544 |
+
return output
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class MultiturnAudioDataset(BaseAudioDataset):
|
| 549 |
+
def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
|
| 550 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 551 |
+
from llamafactory.data.template import Llama2Template,parse_template
|
| 552 |
+
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
| 553 |
+
from llamafactory.data.mm_plugin import get_mm_plugin
|
| 554 |
+
import json
|
| 555 |
+
self.train=False
|
| 556 |
+
self.text_only=text_only
|
| 557 |
+
with open(json_path) as f:
|
| 558 |
+
js_data = json.load(f)
|
| 559 |
+
if split=='train':
|
| 560 |
+
self.train=True
|
| 561 |
+
js_data = js_data[:int(len(js_data)*0.8)]
|
| 562 |
+
else:
|
| 563 |
+
js_data = js_data[-int(len(js_data)*0.2):]
|
| 564 |
+
for conv in js_data:
|
| 565 |
+
for mess in conv['conversations']:
|
| 566 |
+
if 'audio_path' in mess:
|
| 567 |
+
mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
|
| 568 |
+
default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
|
| 569 |
+
self.template=Llama2Template(
|
| 570 |
+
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
| 571 |
+
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
| 572 |
+
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
| 573 |
+
format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
|
| 574 |
+
format_tools = ToolFormatter(tool_format="default"),
|
| 575 |
+
format_observation=StringFormatter(
|
| 576 |
+
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
| 577 |
+
),
|
| 578 |
+
default_system=default_system,
|
| 579 |
+
thought_words=("<think>", "</think>"),
|
| 580 |
+
efficient_eos=False,
|
| 581 |
+
replace_eos=False,
|
| 582 |
+
replace_jinja_template=False,
|
| 583 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 584 |
+
stop_words=["<end_of_turn>"],
|
| 585 |
+
mm_plugin=get_mm_plugin(name="base"),
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
self.set_dataset_name(f"MultiturnCostumData")
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
self.data = []
|
| 592 |
+
self.text_only_data = []
|
| 593 |
+
for conv in js_data:
|
| 594 |
+
tools = conv['tools'] if 'tools' in conv else ""
|
| 595 |
+
system = conv['system'] if 'system' in conv else default_system
|
| 596 |
+
tmp = {
|
| 597 |
+
'tools':tools,
|
| 598 |
+
'system':system,
|
| 599 |
+
'messages':[],
|
| 600 |
+
}
|
| 601 |
+
for i,mess in enumerate(conv['conversations']):
|
| 602 |
+
tmp['messages'].append(mess)
|
| 603 |
+
if mess['from']=='human':
|
| 604 |
+
tmp['messages'].append(conv['conversations'][i+1])
|
| 605 |
+
d = deepcopy(tmp)
|
| 606 |
+
d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
|
| 607 |
+
self.data.append(d)
|
| 608 |
+
if self.text_only:
|
| 609 |
+
self.text_only_data.append(deepcopy(tmp))
|
| 610 |
+
tmp['messages'].pop()
|
| 611 |
+
elif mess['from']=='observation':
|
| 612 |
+
tmp['messages'].append(conv['conversations'][i+1])
|
| 613 |
+
d = deepcopy(tmp)
|
| 614 |
+
self.text_only_data.append(d)
|
| 615 |
+
tmp['messages'].pop()
|
| 616 |
+
if text_only:
|
| 617 |
+
self.data=self.text_only_data
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
|
| 621 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 622 |
+
prompt = ""
|
| 623 |
+
answer_text = ""
|
| 624 |
+
user_transcribe = ""
|
| 625 |
+
audio_paths = []
|
| 626 |
+
for i, message in enumerate(messages):
|
| 627 |
+
elements = []
|
| 628 |
+
|
| 629 |
+
system_text = ""
|
| 630 |
+
if i == 0:
|
| 631 |
+
elements += self.template.format_prefix.apply()
|
| 632 |
+
if system or tools:
|
| 633 |
+
tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
|
| 634 |
+
system_text = self.template.format_system.apply(content=(system + tool_text))[0]
|
| 635 |
+
|
| 636 |
+
if message["from"] == "human":
|
| 637 |
+
if i==len(messages)-2 and not self.text_only:
|
| 638 |
+
user_transcribe = message["value"]
|
| 639 |
+
elements += self.template.format_user.apply(content=system_text+'<start_of_audio>')
|
| 640 |
+
else:
|
| 641 |
+
elements += self.template.format_user.apply(content=system_text + message["value"])
|
| 642 |
+
audio_paths.append(message['audio_path'])
|
| 643 |
+
elif message["from"] == "gpt":
|
| 644 |
+
elements += self.template.format_assistant.apply(content=message["value"])
|
| 645 |
+
elif message["from"] == "observation":
|
| 646 |
+
elements += self.template.format_observation.apply(content=message["value"])
|
| 647 |
+
elif message["from"] == "function_call":
|
| 648 |
+
elements += self.template.format_function.apply(content=message["value"])
|
| 649 |
+
else:
|
| 650 |
+
raise NotImplementedError("Unexpected role: {}".format(message["from"]))
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
for elem in elements:
|
| 654 |
+
ele_str = ""
|
| 655 |
+
if isinstance(elem, str):
|
| 656 |
+
ele_str=elem
|
| 657 |
+
elif isinstance(elem, set):
|
| 658 |
+
if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
|
| 659 |
+
ele_str = self.processor.tokenizer.bos_token
|
| 660 |
+
elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
|
| 661 |
+
ele_str = self.processor.tokenizer.eos_token
|
| 662 |
+
if i == len(messages)-1:
|
| 663 |
+
answer_text+=ele_str
|
| 664 |
+
else:
|
| 665 |
+
prompt+=ele_str
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
if type(audio_array)!=type(None):
|
| 669 |
+
inputs = self.processor(
|
| 670 |
+
text=prompt,
|
| 671 |
+
audio=[audio_array],
|
| 672 |
+
add_special_tokens=False,
|
| 673 |
+
return_tensors='pt'
|
| 674 |
+
)
|
| 675 |
+
answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
|
| 676 |
+
else:
|
| 677 |
+
inputs = self.processor(
|
| 678 |
+
text=prompt,
|
| 679 |
+
audio=None,
|
| 680 |
+
add_special_tokens=False,
|
| 681 |
+
return_tensors='pt'
|
| 682 |
+
)
|
| 683 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 684 |
+
# print('user_transcribe',user_transcribe)
|
| 685 |
+
# print('answer_text', answer)
|
| 686 |
+
# print('prompt',prompt)
|
| 687 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 688 |
+
|
| 689 |
+
if self.debug:
|
| 690 |
+
self.debug = False
|
| 691 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 692 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 693 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 694 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 695 |
+
|
| 696 |
+
if self.training:
|
| 697 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 698 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 699 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 700 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 701 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 702 |
+
else:
|
| 703 |
+
input_ids = inputs.input_ids
|
| 704 |
+
labels = answer_ids
|
| 705 |
+
token_type_ids = inputs.token_type_ids
|
| 706 |
+
if type(audio_array)!=type(None):
|
| 707 |
+
if not self.train:
|
| 708 |
+
return {
|
| 709 |
+
"audio_path": audio_paths,
|
| 710 |
+
'input_ids': input_ids,
|
| 711 |
+
'labels': labels,
|
| 712 |
+
'token_type_ids': token_type_ids,
|
| 713 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 714 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 715 |
+
'input_modes': inputs.input_modes,
|
| 716 |
+
}
|
| 717 |
+
else:
|
| 718 |
+
return {
|
| 719 |
+
'input_ids': input_ids,
|
| 720 |
+
'labels': labels,
|
| 721 |
+
'token_type_ids': token_type_ids,
|
| 722 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 723 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 724 |
+
'input_modes': inputs.input_modes,
|
| 725 |
+
}
|
| 726 |
+
else:
|
| 727 |
+
return {
|
| 728 |
+
'input_ids': input_ids,
|
| 729 |
+
'labels': labels,
|
| 730 |
+
'token_type_ids': token_type_ids,
|
| 731 |
+
'input_audio_embeds': None,
|
| 732 |
+
'audio_embed_sizes': None,
|
| 733 |
+
'input_modes': inputs.input_modes,
|
| 734 |
+
}
|
| 735 |
+
def __len__(self):
|
| 736 |
+
return len(self.data)
|
| 737 |
+
|
| 738 |
+
def __getitem__(self, idx):
|
| 739 |
+
data = self.data[idx]
|
| 740 |
+
return self.prepare_multiturn_model_inputs(
|
| 741 |
+
audio_array=data["audio_array"] if "audio_array" in data else None,
|
| 742 |
+
messages=data['messages'],
|
| 743 |
+
system=data["system"],
|
| 744 |
+
tools=data["tools"]
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 750 |
+
|
| 751 |
+
INSTRUCTION = {
|
| 752 |
+
"ast": [
|
| 753 |
+
"Translate the audio to {0}.",
|
| 754 |
+
"Translate the audio clip into {0}.",
|
| 755 |
+
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
|
| 756 |
+
"Translate the provided audio file into {0}.",
|
| 757 |
+
"Convert the audio speech to {0} text.",
|
| 758 |
+
"Write an {0} translation of the audio file.",
|
| 759 |
+
"Translate spoken words from the audio into {0}.",
|
| 760 |
+
"Create an {0} version of the audio content.",
|
| 761 |
+
"Produce an accurate {0} translation of the audio.",
|
| 762 |
+
"Extract speech from the audio and translate it to {0}.",
|
| 763 |
+
"Turn the audio into readable {0} text.",
|
| 764 |
+
"Write all spoken content from the audio in {0}.",
|
| 765 |
+
"Generate an {0} translation of the speech in the file.",
|
| 766 |
+
"Convert the recording into {0} text.",
|
| 767 |
+
"Accurately translate the audio recording to {0}.",
|
| 768 |
+
"Write down dialogue from the given audio in {0}.",
|
| 769 |
+
"Translate all speech in this audio file to {0}.",
|
| 770 |
+
"Create an accurate {0} version of the speech.",
|
| 771 |
+
"Perform a complete {0} translation of the audio."
|
| 772 |
+
],
|
| 773 |
+
"asr": [
|
| 774 |
+
"Transcribe the audio clip into text.",
|
| 775 |
+
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
|
| 776 |
+
"Transcribe the provided audio file into text.",
|
| 777 |
+
"Convert the audio speech to text.",
|
| 778 |
+
"Write a transcript of the audio file.",
|
| 779 |
+
"Transcribe spoken words from the audio.",
|
| 780 |
+
"Create a text version of the audio content.",
|
| 781 |
+
"Produce a verbatim transcript of the audio.",
|
| 782 |
+
"Extract and transcribe speech from the audio.",
|
| 783 |
+
"Turn the audio into readable text.",
|
| 784 |
+
"Write all spoken words from the audio.",
|
| 785 |
+
"Generate a transcript of the speech in the file.",
|
| 786 |
+
"Convert the recording into a text transcript.",
|
| 787 |
+
"Accurately transcribe the audio recording.",
|
| 788 |
+
"Write down dialogue from the given audio.",
|
| 789 |
+
"Transcribe all speech in this audio file.",
|
| 790 |
+
"Create an accurate text version of the speech.",
|
| 791 |
+
"Perform a complete transcription of the audio."
|
| 792 |
+
],
|
| 793 |
+
}
|
added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
chat_template.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'audio' -%}\n {{ '<start_of_audio>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
|
| 3 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Gemma3OmniForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"audio_config": {
|
| 6 |
+
"activation": "swish",
|
| 7 |
+
"activation_checkpointing": {
|
| 8 |
+
"interval": 1,
|
| 9 |
+
"module": "transformer",
|
| 10 |
+
"offload": false
|
| 11 |
+
},
|
| 12 |
+
"attention_dim": 1024,
|
| 13 |
+
"attention_heads": 16,
|
| 14 |
+
"batch_norm": false,
|
| 15 |
+
"bias_in_glu": true,
|
| 16 |
+
"causal": true,
|
| 17 |
+
"chunk_size": -1,
|
| 18 |
+
"cnn_layer_norm": true,
|
| 19 |
+
"conv_activation": "swish",
|
| 20 |
+
"conv_glu_type": "swish",
|
| 21 |
+
"depthwise_multiplier": 1,
|
| 22 |
+
"depthwise_seperable_out_channel": 1024,
|
| 23 |
+
"dropout_rate": 0.0,
|
| 24 |
+
"encoder_embedding_config": {
|
| 25 |
+
"input_size": 80
|
| 26 |
+
},
|
| 27 |
+
"ext_pw_kernel_size": 1,
|
| 28 |
+
"ext_pw_out_channel": 1024,
|
| 29 |
+
"input_layer": "nemo_conv",
|
| 30 |
+
"input_size": 80,
|
| 31 |
+
"kernel_size": 3,
|
| 32 |
+
"left_chunk": 18,
|
| 33 |
+
"linear_units": 1536,
|
| 34 |
+
"nemo_conv_settings": {
|
| 35 |
+
"conv_channels": 1024
|
| 36 |
+
},
|
| 37 |
+
"num_blocks": 24,
|
| 38 |
+
"relative_attention_bias_args": {
|
| 39 |
+
"t5_bias_max_distance": 500,
|
| 40 |
+
"type": "t5"
|
| 41 |
+
},
|
| 42 |
+
"time_reduction": 8
|
| 43 |
+
},
|
| 44 |
+
"audio_token_index": 262143,
|
| 45 |
+
"auto_map": {
|
| 46 |
+
"AutoConfig": "configuration_gemma3omni.Gemma3OmniConfig",
|
| 47 |
+
"AutoModel": "modeling_gemma3omni.Gemma3OmniForConditionalGeneration"
|
| 48 |
+
},
|
| 49 |
+
"boa_token_index": 256001,
|
| 50 |
+
"boi_token_index": 255999,
|
| 51 |
+
"eoa_token_index": 256002,
|
| 52 |
+
"eoi_token_index": 256000,
|
| 53 |
+
"eos_token_id": [
|
| 54 |
+
1,
|
| 55 |
+
106
|
| 56 |
+
],
|
| 57 |
+
"image_token_index": 262144,
|
| 58 |
+
"initializer_range": 0.02,
|
| 59 |
+
"mm_tokens_per_image": 256,
|
| 60 |
+
"model_type": "gemma3omni",
|
| 61 |
+
"speech_lora": {
|
| 62 |
+
"dp": 0.01,
|
| 63 |
+
"layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
|
| 64 |
+
"lora_alpha": 320,
|
| 65 |
+
"r": 320,
|
| 66 |
+
"use_rslora": true
|
| 67 |
+
},
|
| 68 |
+
"text_lora": {
|
| 69 |
+
"dp": 0.01,
|
| 70 |
+
"layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
|
| 71 |
+
"lora_alpha": 16,
|
| 72 |
+
"r": 8,
|
| 73 |
+
"use_rslora": true
|
| 74 |
+
},
|
| 75 |
+
"text_config": {
|
| 76 |
+
"attention_bias": false,
|
| 77 |
+
"attention_dropout": 0.0,
|
| 78 |
+
"attn_logit_softcapping": null,
|
| 79 |
+
"cache_implementation": "hybrid",
|
| 80 |
+
"final_logit_softcapping": null,
|
| 81 |
+
"head_dim": 256,
|
| 82 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 83 |
+
"hidden_size": 2560,
|
| 84 |
+
"initializer_range": 0.02,
|
| 85 |
+
"intermediate_size": 10240,
|
| 86 |
+
"max_position_embeddings": 131072,
|
| 87 |
+
"model_type": "gemma3_text",
|
| 88 |
+
"num_attention_heads": 8,
|
| 89 |
+
"num_hidden_layers": 34,
|
| 90 |
+
"num_key_value_heads": 4,
|
| 91 |
+
"query_pre_attn_scalar": 256,
|
| 92 |
+
"rms_norm_eps": 1e-06,
|
| 93 |
+
"rope_local_base_freq": 10000.0,
|
| 94 |
+
"rope_scaling": {
|
| 95 |
+
"factor": 8.0,
|
| 96 |
+
"rope_type": "linear"
|
| 97 |
+
},
|
| 98 |
+
"rope_theta": 1000000.0,
|
| 99 |
+
"sliding_window": 1024,
|
| 100 |
+
"sliding_window_pattern": 6,
|
| 101 |
+
"torch_dtype": "float",
|
| 102 |
+
"use_cache": true,
|
| 103 |
+
"vocab_size": 262208
|
| 104 |
+
},
|
| 105 |
+
"torch_dtype": "float",
|
| 106 |
+
"transformers_version": "4.51.3",
|
| 107 |
+
"use_cache": false,
|
| 108 |
+
"vision_config": {
|
| 109 |
+
"hidden_size": 1152,
|
| 110 |
+
"image_size": 896,
|
| 111 |
+
"intermediate_size": 4304,
|
| 112 |
+
"model_type": "siglip_vision_model",
|
| 113 |
+
"num_attention_heads": 16,
|
| 114 |
+
"num_hidden_layers": 27,
|
| 115 |
+
"patch_size": 14,
|
| 116 |
+
"vision_use_head": false
|
| 117 |
+
}
|
| 118 |
+
}
|
configuration_gemma3omni.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from transformers import AutoConfig, Gemma3TextConfig
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 6 |
+
from transformers.utils import logging
|
| 7 |
+
from transformers.models.siglip import SiglipVisionConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
class AudioConfig(PretrainedConfig):
|
| 13 |
+
model_type = "gemma3_audio"
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
input_size=80,
|
| 18 |
+
attention_dim=1024,
|
| 19 |
+
attention_heads=16,
|
| 20 |
+
num_blocks=24,
|
| 21 |
+
linear_units=1536,
|
| 22 |
+
dropout_rate=0.0,
|
| 23 |
+
kernel_size=3,
|
| 24 |
+
ext_pw_kernel_size=1,
|
| 25 |
+
ext_pw_out_channel=1024,
|
| 26 |
+
depthwise_seperable_out_channel=1024,
|
| 27 |
+
depthwise_multiplier=1,
|
| 28 |
+
activation="swish",
|
| 29 |
+
conv_activation="swish",
|
| 30 |
+
conv_glu_type="swish",
|
| 31 |
+
bias_in_glu=True,
|
| 32 |
+
causal=True,
|
| 33 |
+
batch_norm=False,
|
| 34 |
+
cnn_layer_norm=True,
|
| 35 |
+
time_reduction=8,
|
| 36 |
+
input_layer="nemo_conv",
|
| 37 |
+
nemo_conv_settings=None,
|
| 38 |
+
chunk_size=-1,
|
| 39 |
+
left_chunk=18,
|
| 40 |
+
relative_attention_bias_args=None,
|
| 41 |
+
activation_checkpointing=None,
|
| 42 |
+
encoder_embedding_config=None,
|
| 43 |
+
**kwargs
|
| 44 |
+
):
|
| 45 |
+
super().__init__(**kwargs)
|
| 46 |
+
|
| 47 |
+
self.input_size = input_size
|
| 48 |
+
self.attention_dim = attention_dim
|
| 49 |
+
self.attention_heads = attention_heads
|
| 50 |
+
self.num_blocks = num_blocks
|
| 51 |
+
self.linear_units = linear_units
|
| 52 |
+
self.dropout_rate = dropout_rate
|
| 53 |
+
self.kernel_size = kernel_size
|
| 54 |
+
self.ext_pw_kernel_size = ext_pw_kernel_size
|
| 55 |
+
self.ext_pw_out_channel = ext_pw_out_channel
|
| 56 |
+
self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
|
| 57 |
+
self.depthwise_multiplier = depthwise_multiplier
|
| 58 |
+
self.activation = activation
|
| 59 |
+
self.conv_activation = conv_activation
|
| 60 |
+
self.conv_glu_type = conv_glu_type
|
| 61 |
+
self.bias_in_glu = bias_in_glu
|
| 62 |
+
self.causal = causal
|
| 63 |
+
self.batch_norm = batch_norm
|
| 64 |
+
self.cnn_layer_norm = cnn_layer_norm
|
| 65 |
+
self.time_reduction = time_reduction
|
| 66 |
+
self.input_layer = input_layer
|
| 67 |
+
|
| 68 |
+
if nemo_conv_settings is None:
|
| 69 |
+
self.nemo_conv_settings = {"conv_channels": 1024}
|
| 70 |
+
else:
|
| 71 |
+
self.nemo_conv_settings = nemo_conv_settings
|
| 72 |
+
|
| 73 |
+
self.chunk_size = chunk_size
|
| 74 |
+
self.left_chunk = left_chunk
|
| 75 |
+
|
| 76 |
+
if relative_attention_bias_args is None:
|
| 77 |
+
self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
|
| 78 |
+
else:
|
| 79 |
+
self.relative_attention_bias_args = relative_attention_bias_args
|
| 80 |
+
|
| 81 |
+
if activation_checkpointing is None:
|
| 82 |
+
self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
|
| 83 |
+
else:
|
| 84 |
+
self.activation_checkpointing = activation_checkpointing
|
| 85 |
+
|
| 86 |
+
if encoder_embedding_config is None:
|
| 87 |
+
self.encoder_embedding_config = {"input_size": input_size}
|
| 88 |
+
else:
|
| 89 |
+
self.encoder_embedding_config = encoder_embedding_config
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Gemma3OmniConfig(PretrainedConfig):
|
| 93 |
+
r"""
|
| 94 |
+
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
| 95 |
+
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 96 |
+
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
| 97 |
+
|
| 98 |
+
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
| 99 |
+
|
| 100 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 101 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
| 105 |
+
The config object of the text backbone.
|
| 106 |
+
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
| 107 |
+
Custom vision config or dict.
|
| 108 |
+
audio_config (`Union[AutoConfig, dict]`, *optional*):
|
| 109 |
+
Custom audio config or dict.
|
| 110 |
+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
| 111 |
+
The number of tokens per image embedding.
|
| 112 |
+
boi_token_index (`int`, *optional*, defaults to 255999):
|
| 113 |
+
The begin-of-image token index to wrap the image prompt.
|
| 114 |
+
eoi_token_index (`int`, *optional*, defaults to 256000):
|
| 115 |
+
The end-of-image token index to wrap the image prompt.
|
| 116 |
+
image_token_index (`int`, *optional*, defaults to 262144):
|
| 117 |
+
The image token index to encode the image prompt.
|
| 118 |
+
audio_token_index (`int`, *optional*, defaults to 262145):
|
| 119 |
+
The audio token index to encode the audio prompt.
|
| 120 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 121 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
Example:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
| 128 |
+
|
| 129 |
+
>>> # Initializing a Siglip-like vision config
|
| 130 |
+
>>> vision_config = SiglipVisionConfig()
|
| 131 |
+
|
| 132 |
+
>>> # Initializing a Siglip-like vision config
|
| 133 |
+
>>> audio_config = AudioConfig()
|
| 134 |
+
|
| 135 |
+
>>> # Initializing a Gemma3 Text config
|
| 136 |
+
>>> text_config = Gemma3TextConfig()
|
| 137 |
+
|
| 138 |
+
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
| 139 |
+
>>> configuration = Gemma3Config(vision_config, text_config)
|
| 140 |
+
|
| 141 |
+
>>> # Initializing a model from the gemma-3-4b style configuration
|
| 142 |
+
>>> model = Gemma3TextConfig(configuration)
|
| 143 |
+
|
| 144 |
+
>>> # Accessing the model configuration
|
| 145 |
+
>>> configuration = model.config
|
| 146 |
+
```"""
|
| 147 |
+
|
| 148 |
+
model_type = "gemma3omni"
|
| 149 |
+
sub_configs = {
|
| 150 |
+
"text_config": Gemma3TextConfig,
|
| 151 |
+
"vision_config": SiglipVisionConfig,
|
| 152 |
+
"audio_config": AudioConfig,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
text_config: Optional[Gemma3TextConfig] = None,
|
| 158 |
+
vision_config: Optional[SiglipVisionConfig] = None,
|
| 159 |
+
audio_config: Optional[AudioConfig] = None,
|
| 160 |
+
mm_tokens_per_image: int = 256,
|
| 161 |
+
boi_token_index: int = 255_999,
|
| 162 |
+
eoi_token_index: int = 256_000,
|
| 163 |
+
boa_token_index: int = 256_001,
|
| 164 |
+
eoa_token_index: int = 256_002,
|
| 165 |
+
image_token_index: int = 262_144,
|
| 166 |
+
audio_token_index: int = 262_143,
|
| 167 |
+
initializer_range: float = 0.02,
|
| 168 |
+
**kwargs,
|
| 169 |
+
):
|
| 170 |
+
if text_config is None:
|
| 171 |
+
text_config = Gemma3TextConfig()
|
| 172 |
+
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
|
| 173 |
+
elif isinstance(text_config, dict):
|
| 174 |
+
text_config = Gemma3TextConfig(**text_config)
|
| 175 |
+
|
| 176 |
+
if isinstance(vision_config, dict):
|
| 177 |
+
vision_config = SiglipVisionConfig(**vision_config)
|
| 178 |
+
else:
|
| 179 |
+
vision_config = SiglipVisionConfig()
|
| 180 |
+
logger.info(
|
| 181 |
+
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
| 182 |
+
"to text tasks."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if isinstance(audio_config, dict):
|
| 186 |
+
audio_config = AudioConfig(**audio_config)
|
| 187 |
+
else:
|
| 188 |
+
audio_config = AudioConfig()
|
| 189 |
+
logger.info(
|
| 190 |
+
"audio_config is None or incompatible with Gemma3AudioConfig intialization. Gemma3 will be limited "
|
| 191 |
+
"to text tasks."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.text_config = text_config
|
| 195 |
+
self.vision_config = vision_config
|
| 196 |
+
self.audio_config = audio_config
|
| 197 |
+
self.mm_tokens_per_image = mm_tokens_per_image
|
| 198 |
+
self.boi_token_index = boi_token_index
|
| 199 |
+
self.eoi_token_index = eoi_token_index
|
| 200 |
+
self.boa_token_index = boa_token_index
|
| 201 |
+
self.eoa_token_index = eoa_token_index
|
| 202 |
+
self.image_token_index = image_token_index
|
| 203 |
+
self.audio_token_index = audio_token_index
|
| 204 |
+
self.initializer_range = initializer_range
|
| 205 |
+
|
| 206 |
+
super().__init__(**kwargs)
|
eval.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from urllib.request import urlopen
|
| 3 |
+
import soundfile
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset, Audio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import AutoModel, AutoProcessor, BatchFeature
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 14 |
+
from whisper_normalizer.basic import BasicTextNormalizer
|
| 15 |
+
import sacrebleu
|
| 16 |
+
from jiwer import cer, wer
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import soundfile as sf
|
| 19 |
+
import re
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import opencc
|
| 22 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 23 |
+
normalizer = {
|
| 24 |
+
"en_us" : EnglishTextNormalizer(),
|
| 25 |
+
"other" : BasicTextNormalizer()
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
model_id = "/mnt/jeff/gemma_test"
|
| 29 |
+
revision = "main" #"v1.0"
|
| 30 |
+
|
| 31 |
+
model = AutoModel.from_pretrained(
|
| 32 |
+
model_id, device_map="cuda", revision = revision, trust_remote_code=True
|
| 33 |
+
).eval()
|
| 34 |
+
|
| 35 |
+
processor = AutoProcessor.from_pretrained(
|
| 36 |
+
model_id, revision = revision, trust_remote_code=True
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 40 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
INSTRUCTION = {
|
| 44 |
+
"ast": "Translate the audio to {0}.",
|
| 45 |
+
"asr": "Transcribe the audio clip into text.",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
class BaseAudioDataset(Dataset):
|
| 49 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 50 |
+
self.processor = processor
|
| 51 |
+
self.training = "train" in split
|
| 52 |
+
self.debug = debug
|
| 53 |
+
self.sampling_rate = sampling_rate
|
| 54 |
+
self.name = ""
|
| 55 |
+
|
| 56 |
+
def set_dataset_name(self, name):
|
| 57 |
+
self.name = name
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 61 |
+
original_size = len(data)
|
| 62 |
+
|
| 63 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 64 |
+
|
| 65 |
+
def identify_corrupted_files(example):
|
| 66 |
+
try:
|
| 67 |
+
sf.read(example[audio_field]["path"])
|
| 68 |
+
|
| 69 |
+
for field in text_fields:
|
| 70 |
+
if example[field].replace('"', '') == "":
|
| 71 |
+
return False
|
| 72 |
+
return True
|
| 73 |
+
except Exception:
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 77 |
+
validated_size = len(data)
|
| 78 |
+
|
| 79 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
return data
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 86 |
+
original_size = len(data)
|
| 87 |
+
|
| 88 |
+
def filter_audio_by_length(example):
|
| 89 |
+
try:
|
| 90 |
+
audio = example[audio_field]['array']
|
| 91 |
+
channel = 1
|
| 92 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 93 |
+
channel = audio.ndim
|
| 94 |
+
audio = audio.squeeze()
|
| 95 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 96 |
+
return min_sec <= audio_length <= max_sec
|
| 97 |
+
except Exception as e:
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 101 |
+
filtered_size = len(data)
|
| 102 |
+
|
| 103 |
+
return data
|
| 104 |
+
|
| 105 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 106 |
+
user_message = {
|
| 107 |
+
'role': 'user',
|
| 108 |
+
'content': '<start_of_audio>' + instruction,
|
| 109 |
+
}
|
| 110 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 111 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
inputs = self.processor(
|
| 115 |
+
text=prompt,
|
| 116 |
+
audio=[audio_array],
|
| 117 |
+
add_special_tokens=False,
|
| 118 |
+
return_tensors='pt'
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
input_ids = inputs.input_ids
|
| 122 |
+
token_type_ids = inputs.token_type_ids
|
| 123 |
+
|
| 124 |
+
return {
|
| 125 |
+
'input_ids': input_ids,
|
| 126 |
+
'token_type_ids': token_type_ids,
|
| 127 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 128 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 129 |
+
'input_modes': inputs.input_modes,
|
| 130 |
+
'answer': answer_text,
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Libri Speech Dataset Class
|
| 135 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 136 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 137 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 138 |
+
|
| 139 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 140 |
+
# only ASR
|
| 141 |
+
self.ast = False
|
| 142 |
+
self.lang = "en"
|
| 143 |
+
|
| 144 |
+
# load dataset
|
| 145 |
+
self.data = load_dataset("openslr/librispeech_asr",
|
| 146 |
+
subset,
|
| 147 |
+
split=split,
|
| 148 |
+
trust_remote_code=True,
|
| 149 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# (Optional) Audio length Filtering
|
| 153 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 154 |
+
|
| 155 |
+
# Instruction Setting
|
| 156 |
+
self.instruction = INSTRUCTION["asr"]
|
| 157 |
+
|
| 158 |
+
def __len__(self):
|
| 159 |
+
return len(self.data)
|
| 160 |
+
|
| 161 |
+
def __getitem__(self, idx):
|
| 162 |
+
data = self.data[idx]
|
| 163 |
+
|
| 164 |
+
# Libri Speech is only for ASR
|
| 165 |
+
answer_text = data["text"].replace('"', '')
|
| 166 |
+
|
| 167 |
+
return self.prepare_model_inputs(
|
| 168 |
+
data["audio"]["array"],
|
| 169 |
+
INSTRUCTION["asr"],
|
| 170 |
+
answer_text
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# common_voice_16_1 dataset
|
| 174 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 175 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 176 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 177 |
+
|
| 178 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 179 |
+
# only ASR
|
| 180 |
+
self.ast = False
|
| 181 |
+
self.lang=source_lang
|
| 182 |
+
|
| 183 |
+
# load dataset
|
| 184 |
+
self.data = load_dataset("mozilla-foundation/common_voice_16_1",
|
| 185 |
+
source_lang,
|
| 186 |
+
split=split,
|
| 187 |
+
trust_remote_code=True,
|
| 188 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 189 |
+
)
|
| 190 |
+
def prepare_dataset(batch):
|
| 191 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 192 |
+
transcription = batch["sentence"]
|
| 193 |
+
|
| 194 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 195 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 196 |
+
transcription = transcription[1:-1]
|
| 197 |
+
|
| 198 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 199 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 200 |
+
transcription = transcription + "."
|
| 201 |
+
|
| 202 |
+
batch["sentence"] = transcription
|
| 203 |
+
|
| 204 |
+
return batch
|
| 205 |
+
self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 206 |
+
|
| 207 |
+
# (Optional) Audio length Filtering
|
| 208 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 209 |
+
|
| 210 |
+
# Instruction Setting
|
| 211 |
+
self.instruction = INSTRUCTION["asr"]
|
| 212 |
+
|
| 213 |
+
def __len__(self):
|
| 214 |
+
return len(self.data)
|
| 215 |
+
|
| 216 |
+
def __getitem__(self, idx):
|
| 217 |
+
data = self.data[idx]
|
| 218 |
+
|
| 219 |
+
answer_text = data["sentence"]
|
| 220 |
+
return self.prepare_model_inputs(
|
| 221 |
+
data["audio"]["array"],
|
| 222 |
+
INSTRUCTION["asr"],
|
| 223 |
+
answer_text
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Fleurs Dataset Class
|
| 228 |
+
class FleursDataset(BaseAudioDataset):
|
| 229 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 230 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 231 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 232 |
+
|
| 233 |
+
self.set_dataset_name("Fleurs")
|
| 234 |
+
# Mode Setting (ASR or AST)
|
| 235 |
+
if mode not in ["asr", "ast"]:
|
| 236 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 237 |
+
|
| 238 |
+
self.mode = mode
|
| 239 |
+
self.ast = (mode == "ast")
|
| 240 |
+
self.source_lang = source_lang
|
| 241 |
+
|
| 242 |
+
# Language name mapping (expand if needed)
|
| 243 |
+
self.lang_names = {
|
| 244 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
# load dataset - source language dataset
|
| 248 |
+
self.data = load_dataset("google/fleurs",
|
| 249 |
+
source_lang,
|
| 250 |
+
split=split,
|
| 251 |
+
trust_remote_code=True,
|
| 252 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 253 |
+
)
|
| 254 |
+
def prepare_dataset(batch):
|
| 255 |
+
import opencc
|
| 256 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 257 |
+
if self.ast:
|
| 258 |
+
translation = converter.convert(batch["translation"])
|
| 259 |
+
batch["translation"] = translation
|
| 260 |
+
else:
|
| 261 |
+
transcription = converter.convert(batch["transcription"])
|
| 262 |
+
batch["transcription"] = transcription
|
| 263 |
+
|
| 264 |
+
return batch
|
| 265 |
+
if (source_lang=="cmn_hans_cn" and not self.ast) or (self.ast and target_lang=="cmn_hans_cn"):
|
| 266 |
+
self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 267 |
+
|
| 268 |
+
# (Optional) Audio length Filtering
|
| 269 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 270 |
+
self.target_lang_name = ""
|
| 271 |
+
# When AST mode, load target language dataset.
|
| 272 |
+
if self.ast:
|
| 273 |
+
if target_lang is None:
|
| 274 |
+
raise ValueError("AST mode requires target_lang.")
|
| 275 |
+
|
| 276 |
+
self.target_lang = target_lang
|
| 277 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 278 |
+
|
| 279 |
+
# load dataset - target language dataset (for translation)
|
| 280 |
+
target_data = load_dataset("google/fleurs",
|
| 281 |
+
target_lang,
|
| 282 |
+
split=split,
|
| 283 |
+
trust_remote_code=True,
|
| 284 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
source_dict = {item['id']: item for item in self.data}
|
| 288 |
+
target_dict = {item['id']: item for item in target_data}
|
| 289 |
+
|
| 290 |
+
# only Common ID, add translation fields
|
| 291 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 292 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 293 |
+
self.data = [
|
| 294 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 295 |
+
for id in common_ids
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
# Instruction Setting - use target language name
|
| 299 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 300 |
+
self.instruction = INSTRUCTION["ast"]
|
| 301 |
+
else:
|
| 302 |
+
# ASR mode
|
| 303 |
+
self.lang = source_lang
|
| 304 |
+
self.instruction = INSTRUCTION["asr"]
|
| 305 |
+
|
| 306 |
+
if self.debug:
|
| 307 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 308 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 309 |
+
if self.ast:
|
| 310 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 311 |
+
print(f"dataset size: {len(self.data)}")
|
| 312 |
+
|
| 313 |
+
def __len__(self):
|
| 314 |
+
return len(self.data)
|
| 315 |
+
|
| 316 |
+
def __getitem__(self, idx):
|
| 317 |
+
data = self.data[idx]
|
| 318 |
+
audio_array = data["audio"]["array"]
|
| 319 |
+
|
| 320 |
+
if self.ast:
|
| 321 |
+
answer_text = data["translation"]
|
| 322 |
+
else:
|
| 323 |
+
answer_text = data["transcription"]
|
| 324 |
+
|
| 325 |
+
return self.prepare_model_inputs(
|
| 326 |
+
audio_array,
|
| 327 |
+
self.instruction.format(self.target_lang_name),
|
| 328 |
+
answer_text
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 332 |
+
"""
|
| 333 |
+
Pad a list of sequences to the same length.
|
| 334 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 335 |
+
"""
|
| 336 |
+
assert padding_side in ['right', 'left']
|
| 337 |
+
max_size = sequences[0].size()
|
| 338 |
+
trailing_dims = max_size[1:]
|
| 339 |
+
max_len = max(len(seq) for seq in sequences)
|
| 340 |
+
batch_size = len(sequences)
|
| 341 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 342 |
+
for i, seq in enumerate(sequences):
|
| 343 |
+
length = seq.size(0)
|
| 344 |
+
if padding_side == 'right':
|
| 345 |
+
output.data[i, :length] = seq
|
| 346 |
+
else:
|
| 347 |
+
output.data[i, -length:] = seq
|
| 348 |
+
return output
|
| 349 |
+
|
| 350 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 351 |
+
"""
|
| 352 |
+
cat along dim, while pad to max for all other dims
|
| 353 |
+
"""
|
| 354 |
+
ndim = tensors[0].dim()
|
| 355 |
+
assert all(
|
| 356 |
+
t.dim() == ndim for t in tensors[1:]
|
| 357 |
+
), 'All tensors must have the same number of dimensions'
|
| 358 |
+
|
| 359 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 360 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 361 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 362 |
+
|
| 363 |
+
index = 0
|
| 364 |
+
for t in tensors:
|
| 365 |
+
# Create a slice list where every dimension except dim is full slice
|
| 366 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 367 |
+
# Update only the concat dimension slice
|
| 368 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 369 |
+
|
| 370 |
+
output[slices] = t
|
| 371 |
+
index += t.shape[dim]
|
| 372 |
+
|
| 373 |
+
return output
|
| 374 |
+
|
| 375 |
+
def covost_collate_fn(batch):
|
| 376 |
+
input_ids_list = []
|
| 377 |
+
input_audio_embeds_list = []
|
| 378 |
+
audio_embed_sizes_list = []
|
| 379 |
+
audio_attention_mask_list = []
|
| 380 |
+
input_modes_list = []
|
| 381 |
+
answer_list = []
|
| 382 |
+
for inputs in batch:
|
| 383 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 384 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 385 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 386 |
+
audio_attention_mask_list.append(
|
| 387 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 388 |
+
)
|
| 389 |
+
input_modes_list.append(inputs['input_modes'])
|
| 390 |
+
answer_list.append(inputs['answer'])
|
| 391 |
+
|
| 392 |
+
try:
|
| 393 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 394 |
+
audio_attention_mask = (
|
| 395 |
+
pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
|
| 396 |
+
if len(audio_attention_mask_list) > 1
|
| 397 |
+
else None
|
| 398 |
+
)
|
| 399 |
+
except Exception as e:
|
| 400 |
+
print(e)
|
| 401 |
+
print(input_ids_list)
|
| 402 |
+
print(audio_attention_mask)
|
| 403 |
+
raise
|
| 404 |
+
attention_mask = (input_ids != 0).long()
|
| 405 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
|
| 406 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
|
| 407 |
+
input_modes = torch.cat(input_modes_list)
|
| 408 |
+
|
| 409 |
+
return BatchFeature(
|
| 410 |
+
{
|
| 411 |
+
'input_ids': input_ids,
|
| 412 |
+
'attention_mask': attention_mask,
|
| 413 |
+
'input_audio_embeds': input_audio_embeds,
|
| 414 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 415 |
+
'audio_attention_mask': audio_attention_mask,
|
| 416 |
+
'input_modes': input_modes,
|
| 417 |
+
'answer': answer_list,
|
| 418 |
+
}
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
|
| 422 |
+
filename = f"{task}_{dataset_name}_{source_lang}"
|
| 423 |
+
if target_lang:
|
| 424 |
+
filename += f"_to_{target_lang}"
|
| 425 |
+
if sample_idx is not None:
|
| 426 |
+
filename += f"_sample_{sample_idx}"
|
| 427 |
+
|
| 428 |
+
filepath = os.path.join(results_dir, f"{filename}.json")
|
| 429 |
+
|
| 430 |
+
results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 431 |
+
|
| 432 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 433 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 434 |
+
|
| 435 |
+
return filepath
|
| 436 |
+
|
| 437 |
+
def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
|
| 438 |
+
task_type = "asr" if is_asr else "translation"
|
| 439 |
+
eval_lang = source_lang if is_asr else target_lang
|
| 440 |
+
if eval_lang in normalizer:
|
| 441 |
+
eval_normalizer = normalizer[eval_lang]
|
| 442 |
+
else:
|
| 443 |
+
eval_normalizer = normalizer['other']
|
| 444 |
+
sample_results = []
|
| 445 |
+
|
| 446 |
+
if num_samples > 0 and num_samples < len(dataset):
|
| 447 |
+
indices = np.random.choice(len(dataset), num_samples, replace=False)
|
| 448 |
+
dataset = dataset.select(indices)
|
| 449 |
+
|
| 450 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
|
| 451 |
+
|
| 452 |
+
evaluated_samples = {}
|
| 453 |
+
|
| 454 |
+
for batch_idx, batch in enumerate(tqdm(dataloader)):
|
| 455 |
+
batch_references = batch.pop("answer")
|
| 456 |
+
|
| 457 |
+
if torch.cuda.is_available():
|
| 458 |
+
try:
|
| 459 |
+
batch = {k: v.to("cuda") for k, v in batch.items()}
|
| 460 |
+
except:
|
| 461 |
+
print('error')
|
| 462 |
+
break
|
| 463 |
+
|
| 464 |
+
with torch.inference_mode():
|
| 465 |
+
generate_ids = model.generate(**batch,
|
| 466 |
+
max_new_tokens=256,
|
| 467 |
+
#temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 471 |
+
generate_ids = generate_ids[:, input_lengths:]
|
| 472 |
+
|
| 473 |
+
batch_predictions = processor.batch_decode(
|
| 474 |
+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
|
| 478 |
+
idx = batch_idx * batch_size + i
|
| 479 |
+
sample_result = {
|
| 480 |
+
"id": idx,
|
| 481 |
+
"reference": reference,
|
| 482 |
+
"prediction": converter.convert(prediction)
|
| 483 |
+
}
|
| 484 |
+
sample_results.append(sample_result)
|
| 485 |
+
|
| 486 |
+
if (batch_idx + 1) % 10 == 0:
|
| 487 |
+
temp_results = []
|
| 488 |
+
|
| 489 |
+
for item in sample_results:
|
| 490 |
+
sample_id = item["id"]
|
| 491 |
+
|
| 492 |
+
if sample_id in evaluated_samples:
|
| 493 |
+
temp_item = item.copy()
|
| 494 |
+
temp_item.update(evaluated_samples[sample_id])
|
| 495 |
+
temp_results.append(temp_item)
|
| 496 |
+
else:
|
| 497 |
+
temp_item = item.copy()
|
| 498 |
+
try:
|
| 499 |
+
ref = eval_normalizer(item["reference"])
|
| 500 |
+
pred = eval_normalizer(item["prediction"])
|
| 501 |
+
|
| 502 |
+
# BLEU, WER/CER
|
| 503 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 504 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 505 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 506 |
+
|
| 507 |
+
metrics = {
|
| 508 |
+
"bleu": utt_bleu,
|
| 509 |
+
"cer": min(100,utt_cer),
|
| 510 |
+
"wer": utt_wer
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
evaluated_samples[sample_id] = metrics
|
| 514 |
+
temp_item.update(metrics)
|
| 515 |
+
except Exception as e:
|
| 516 |
+
print(f"Error evaluating sample {sample_id}: {e}")
|
| 517 |
+
metrics = {
|
| 518 |
+
"bleu": 0,
|
| 519 |
+
"cer": 100,
|
| 520 |
+
"wer": 100,
|
| 521 |
+
"error": str(e)
|
| 522 |
+
}
|
| 523 |
+
evaluated_samples[sample_id] = metrics
|
| 524 |
+
temp_item.update(metrics)
|
| 525 |
+
|
| 526 |
+
temp_results.append(temp_item)
|
| 527 |
+
|
| 528 |
+
partial_results = {
|
| 529 |
+
"task": task_type,
|
| 530 |
+
"source_lang": source_lang,
|
| 531 |
+
"target_lang": target_lang,
|
| 532 |
+
"num_samples": len(temp_results),
|
| 533 |
+
"sample_results": temp_results
|
| 534 |
+
}
|
| 535 |
+
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
|
| 536 |
+
|
| 537 |
+
for item in sample_results:
|
| 538 |
+
ref = eval_normalizer(item["reference"])
|
| 539 |
+
pred = eval_normalizer(item["prediction"])
|
| 540 |
+
|
| 541 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 542 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 543 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 544 |
+
|
| 545 |
+
item.update({
|
| 546 |
+
"bleu": utt_bleu,
|
| 547 |
+
"cer": min(100,utt_cer),
|
| 548 |
+
"wer": utt_wer
|
| 549 |
+
})
|
| 550 |
+
|
| 551 |
+
avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
|
| 552 |
+
avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
|
| 553 |
+
avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
|
| 554 |
+
|
| 555 |
+
results = {
|
| 556 |
+
"dataset": dataset.name,
|
| 557 |
+
"task": task_type,
|
| 558 |
+
"source_lang": source_lang,
|
| 559 |
+
"target_lang": target_lang,
|
| 560 |
+
"num_samples": len(sample_results),
|
| 561 |
+
"metrics": {
|
| 562 |
+
"bleu": avg_bleu,
|
| 563 |
+
"cer": avg_cer,
|
| 564 |
+
"wer": avg_wer
|
| 565 |
+
},
|
| 566 |
+
"sample_results": sample_results
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
save_results(results, dataset.name, task_type, source_lang, target_lang)
|
| 570 |
+
return results
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
if __name__ == "__main__":
|
| 574 |
+
|
| 575 |
+
source_languages = [
|
| 576 |
+
("en_us", "English"),
|
| 577 |
+
]
|
| 578 |
+
|
| 579 |
+
target_languages = [
|
| 580 |
+
("zh-TW", "zh-TW"),
|
| 581 |
+
]
|
| 582 |
+
|
| 583 |
+
num_samples = -1
|
| 584 |
+
batch_size = 32
|
| 585 |
+
|
| 586 |
+
for source_lang, target_lang in zip(source_languages, target_languages):
|
| 587 |
+
print(f"\n===== {source_lang[0]} ASR =====")
|
| 588 |
+
|
| 589 |
+
split = "test"
|
| 590 |
+
|
| 591 |
+
datasets = []
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
commonvoice_speech_tw = CommonVoiceDataset(
|
| 596 |
+
processor=processor,
|
| 597 |
+
source_lang="zh-TW",
|
| 598 |
+
split=split
|
| 599 |
+
)
|
| 600 |
+
datasets.append(commonvoice_speech_tw)
|
| 601 |
+
fleurs = FleursDataset(
|
| 602 |
+
processor=processor,
|
| 603 |
+
split=split,
|
| 604 |
+
source_lang="en_us", # English
|
| 605 |
+
mode="asr"
|
| 606 |
+
)
|
| 607 |
+
datasets.append(fleurs)
|
| 608 |
+
|
| 609 |
+
# Libri Speech Clean ASR mode (English -> English text)
|
| 610 |
+
# libri_speech_clean = LibriSpeechDataset(
|
| 611 |
+
# processor=processor,
|
| 612 |
+
# subset="clean",
|
| 613 |
+
# split=split
|
| 614 |
+
# )
|
| 615 |
+
# datasets.append(libri_speech_clean)
|
| 616 |
+
|
| 617 |
+
# # Libri Speech Other ASR mode (English -> English text)
|
| 618 |
+
# libri_speech_other = LibriSpeechDataset(
|
| 619 |
+
# processor=processor,
|
| 620 |
+
# subset="other",
|
| 621 |
+
# split=split
|
| 622 |
+
# )
|
| 623 |
+
# datasets.append(libri_speech_other)
|
| 624 |
+
|
| 625 |
+
# Fleurs ASR mode (English -> English text)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
for dataset in datasets:
|
| 629 |
+
# ASR
|
| 630 |
+
asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
|
| 631 |
+
|
| 632 |
+
print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
|
| 633 |
+
print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
|
| 634 |
+
print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
|
| 635 |
+
print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
|
eval_multiturn.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
eval_multiturn.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from urllib.request import urlopen
|
| 3 |
+
import soundfile
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset, Audio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import AutoModel, AutoProcessor, BatchFeature
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 14 |
+
from whisper_normalizer.basic import BasicTextNormalizer
|
| 15 |
+
import sacrebleu
|
| 16 |
+
from jiwer import cer, wer
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import soundfile as sf
|
| 19 |
+
import re
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import opencc
|
| 22 |
+
from ASRDataset import *
|
| 23 |
+
|
| 24 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 25 |
+
normalizer = {
|
| 26 |
+
"en_us" : EnglishTextNormalizer(),
|
| 27 |
+
"other" : BasicTextNormalizer()
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
model_id = "/mnt/jeff/gemma_test"
|
| 31 |
+
revision = "main" #"v1.0"
|
| 32 |
+
|
| 33 |
+
model = AutoModel.from_pretrained(
|
| 34 |
+
model_id, device_map="cuda", revision = revision, trust_remote_code=True
|
| 35 |
+
).eval()
|
| 36 |
+
|
| 37 |
+
processor = AutoProcessor.from_pretrained(
|
| 38 |
+
model_id, revision = revision, trust_remote_code=True
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 42 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
INSTRUCTION = {
|
| 46 |
+
"ast": "Translate the audio to {0}.",
|
| 47 |
+
"asr": "Transcribe the audio clip into text.",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
|
| 53 |
+
filename = f"{task}_{dataset_name}_{source_lang}"
|
| 54 |
+
if target_lang:
|
| 55 |
+
filename += f"_to_{target_lang}"
|
| 56 |
+
if sample_idx is not None:
|
| 57 |
+
filename += f"_sample_{sample_idx}"
|
| 58 |
+
|
| 59 |
+
filepath = os.path.join(results_dir, f"{filename}.json")
|
| 60 |
+
|
| 61 |
+
results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 62 |
+
|
| 63 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 64 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 65 |
+
|
| 66 |
+
return filepath
|
| 67 |
+
|
| 68 |
+
def evaluate_task(dataset):
|
| 69 |
+
sample_results = []
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
|
| 73 |
+
|
| 74 |
+
evaluated_samples = {}
|
| 75 |
+
|
| 76 |
+
for batch_idx, batch in enumerate(tqdm(dataloader)):
|
| 77 |
+
|
| 78 |
+
if torch.cuda.is_available():
|
| 79 |
+
try:
|
| 80 |
+
batch = {k: v.to("cuda") for k, v in batch.items()}
|
| 81 |
+
except:
|
| 82 |
+
print('error')
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
with torch.inference_mode():
|
| 86 |
+
generate_ids = model.generate(**batch,
|
| 87 |
+
max_new_tokens=256,
|
| 88 |
+
#temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 92 |
+
generate_ids = generate_ids[:, input_lengths:]
|
| 93 |
+
|
| 94 |
+
batch_predictions = processor.batch_decode(
|
| 95 |
+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 96 |
+
)
|
| 97 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 98 |
+
label_ids = generate_ids[:, input_lengths:]
|
| 99 |
+
batch_references = processor.batch_decode(
|
| 100 |
+
label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
|
| 104 |
+
idx = batch_idx + i
|
| 105 |
+
sample_result = {
|
| 106 |
+
"id": idx,
|
| 107 |
+
"reference": reference,
|
| 108 |
+
"prediction": converter.convert(prediction)
|
| 109 |
+
}
|
| 110 |
+
sample_results.append(sample_result)
|
| 111 |
+
|
| 112 |
+
if (batch_idx + 1) % 10 == 0:
|
| 113 |
+
temp_results = []
|
| 114 |
+
|
| 115 |
+
for item in sample_results:
|
| 116 |
+
sample_id = item["id"]
|
| 117 |
+
|
| 118 |
+
if sample_id in evaluated_samples:
|
| 119 |
+
temp_item = item.copy()
|
| 120 |
+
temp_item.update(evaluated_samples[sample_id])
|
| 121 |
+
temp_results.append(temp_item)
|
| 122 |
+
else:
|
| 123 |
+
temp_item = item.copy()
|
| 124 |
+
try:
|
| 125 |
+
ref = eval_normalizer(item["reference"])
|
| 126 |
+
pred = eval_normalizer(item["prediction"])
|
| 127 |
+
|
| 128 |
+
# BLEU, WER/CER
|
| 129 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 130 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 131 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 132 |
+
|
| 133 |
+
metrics = {
|
| 134 |
+
"bleu": utt_bleu,
|
| 135 |
+
"cer": min(100,utt_cer),
|
| 136 |
+
"wer": utt_wer
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
evaluated_samples[sample_id] = metrics
|
| 140 |
+
temp_item.update(metrics)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
print(f"Error evaluating sample {sample_id}: {e}")
|
| 143 |
+
metrics = {
|
| 144 |
+
"bleu": 0,
|
| 145 |
+
"cer": 100,
|
| 146 |
+
"wer": 100,
|
| 147 |
+
"error": str(e)
|
| 148 |
+
}
|
| 149 |
+
evaluated_samples[sample_id] = metrics
|
| 150 |
+
temp_item.update(metrics)
|
| 151 |
+
|
| 152 |
+
temp_results.append(temp_item)
|
| 153 |
+
|
| 154 |
+
partial_results = {
|
| 155 |
+
"task": task_type,
|
| 156 |
+
"source_lang": source_lang,
|
| 157 |
+
"target_lang": target_lang,
|
| 158 |
+
"num_samples": len(temp_results),
|
| 159 |
+
"sample_results": temp_results
|
| 160 |
+
}
|
| 161 |
+
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
|
| 162 |
+
|
| 163 |
+
for item in sample_results:
|
| 164 |
+
ref = eval_normalizer(item["reference"])
|
| 165 |
+
pred = eval_normalizer(item["prediction"])
|
| 166 |
+
|
| 167 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 168 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 169 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 170 |
+
|
| 171 |
+
item.update({
|
| 172 |
+
"bleu": utt_bleu,
|
| 173 |
+
"cer": min(100,utt_cer),
|
| 174 |
+
"wer": utt_wer
|
| 175 |
+
})
|
| 176 |
+
|
| 177 |
+
avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
|
| 178 |
+
avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
|
| 179 |
+
avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
|
| 180 |
+
|
| 181 |
+
results = {
|
| 182 |
+
"dataset": dataset.name,
|
| 183 |
+
"task": task_type,
|
| 184 |
+
"source_lang": source_lang,
|
| 185 |
+
"target_lang": target_lang,
|
| 186 |
+
"num_samples": len(sample_results),
|
| 187 |
+
"metrics": {
|
| 188 |
+
"bleu": avg_bleu,
|
| 189 |
+
"cer": avg_cer,
|
| 190 |
+
"wer": avg_wer
|
| 191 |
+
},
|
| 192 |
+
"sample_results": sample_results
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
save_results(results, dataset.name, task_type, source_lang, target_lang)
|
| 196 |
+
return results
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
|
| 201 |
+
datasets = []
|
| 202 |
+
pickup_dataset = MultiturnAudioDataset(split='eval',processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 203 |
+
datasets.append(pickup_dataset)
|
| 204 |
+
for dataset in datasets:
|
| 205 |
+
# ASR
|
| 206 |
+
asr_results = evaluate_task(dataset)
|
| 207 |
+
|
| 208 |
+
print(f"\n=== {asr_results.get('dataset', 'Dataset')}")
|
| 209 |
+
print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
|
| 210 |
+
print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
|
| 211 |
+
print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
|
merge_lora.ipynb
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from safetensors import safe_open\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"lora = {}\n",
|
| 12 |
+
"with safe_open(\"/data2/bjh/diffusion-pipe/cosmos_test/20250327_02-37-25/epoch5/adapter_model.safetensors\", framework=\"pt\", device='cpu') as f:\n",
|
| 13 |
+
" for k in f.keys():\n",
|
| 14 |
+
" lora[k] = f.get_tensor(k)"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 2,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"tensors = {}\n",
|
| 24 |
+
"with safe_open(\"/data2/bjh/ComfyUI/models/diffusion_models/Cosmos-1_0-Diffusion-14B-Text2World.safetensors\", framework=\"pt\", device='cpu') as f:\n",
|
| 25 |
+
" for k in f.keys():\n",
|
| 26 |
+
" tensors[k] = f.get_tensor(k)"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": 3,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [
|
| 34 |
+
{
|
| 35 |
+
"data": {
|
| 36 |
+
"text/plain": [
|
| 37 |
+
"1152"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
"execution_count": 3,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"output_type": "execute_result"
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"source": [
|
| 46 |
+
"len(lora)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 4,
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"name_lis = []\n",
|
| 56 |
+
"for k in lora:\n",
|
| 57 |
+
" a = k.split('.')[1:][:-2]\n",
|
| 58 |
+
" name = '.'.join(a)\n",
|
| 59 |
+
" name_lis.append(name)\n",
|
| 60 |
+
"name_lis=set(name_lis)"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 5,
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"import torch\n",
|
| 70 |
+
"new_dic = {}\n",
|
| 71 |
+
"for k in tensors:\n",
|
| 72 |
+
" name='.'.join(k.split('.')[1:][:-1])\n",
|
| 73 |
+
" if name in name_lis:\n",
|
| 74 |
+
" a,b = lora['diffusion_model.'+name+'.lora_A.weight'],lora['diffusion_model.'+name+'.lora_B.weight']\n",
|
| 75 |
+
" new_dic[k]=tensors[k]+torch.matmul(b,a)\n",
|
| 76 |
+
" else:\n",
|
| 77 |
+
" new_dic[k]=tensors[k]"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 6,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"from safetensors.torch import save_file\n",
|
| 87 |
+
"save_file(new_dic,'test.safetensors')"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": []
|
| 96 |
+
}
|
| 97 |
+
],
|
| 98 |
+
"metadata": {
|
| 99 |
+
"kernelspec": {
|
| 100 |
+
"display_name": "dp",
|
| 101 |
+
"language": "python",
|
| 102 |
+
"name": "python3"
|
| 103 |
+
},
|
| 104 |
+
"language_info": {
|
| 105 |
+
"codemirror_mode": {
|
| 106 |
+
"name": "ipython",
|
| 107 |
+
"version": 3
|
| 108 |
+
},
|
| 109 |
+
"file_extension": ".py",
|
| 110 |
+
"mimetype": "text/x-python",
|
| 111 |
+
"name": "python",
|
| 112 |
+
"nbconvert_exporter": "python",
|
| 113 |
+
"pygments_lexer": "ipython3",
|
| 114 |
+
"version": "3.12.9"
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"nbformat": 4,
|
| 118 |
+
"nbformat_minor": 2
|
| 119 |
+
}
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ddd3e8916f7ad6ad92651ac288227995c1d34628f0f888eb2dc5b9acb4dc0121
|
| 3 |
+
size 4976361384
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c343a455ca768923cb3b9ab77cbb91c9cd2526a1bee5740cf9cf86bfa85a0a7b
|
| 3 |
+
size 4984907872
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad04449d015f4efbda75d6cc41e06296b4da996cd84053fa6f9791fe16d55d03
|
| 3 |
+
size 732141104
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_gemma3omni.py
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.cache_utils import Cache, HybridCache, StaticCache
|
| 11 |
+
from transformers.generation import GenerationMixin
|
| 12 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
| 14 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 15 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 16 |
+
from transformers.processing_utils import Unpack
|
| 17 |
+
from transformers.utils import (
|
| 18 |
+
add_start_docstrings,
|
| 19 |
+
add_start_docstrings_to_model_forward,
|
| 20 |
+
is_torchdynamo_compiling,
|
| 21 |
+
logging,
|
| 22 |
+
replace_return_docstrings,
|
| 23 |
+
)
|
| 24 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 25 |
+
from transformers import AutoModel, AutoModelForCausalLM
|
| 26 |
+
|
| 27 |
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
|
| 28 |
+
|
| 29 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 30 |
+
|
| 31 |
+
from .configuration_gemma3omni import Gemma3OmniConfig
|
| 32 |
+
from .speech_conformer_encoder import ConformerEncoder
|
| 33 |
+
from enum import Enum
|
| 34 |
+
class InputMode(Enum):
|
| 35 |
+
LANGUAGE = 0
|
| 36 |
+
VISION = 1
|
| 37 |
+
SPEECH = 2
|
| 38 |
+
VISION_SPEECH = 3
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
_CONFIG_FOR_DOC = "Gemma3OmniConfig"
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class Gemma3OmniCausalLMOutputWithPast(Gemma3CausalLMOutputWithPast):
|
| 44 |
+
"""
|
| 45 |
+
Multimodal version of `Gemma3CausalLMOutputWithPast`.
|
| 46 |
+
Adds audio-specific hidden states.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
audio_hidden_states (`torch.FloatTensor`, *optional*):
|
| 50 |
+
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
|
| 51 |
+
Audio hidden states produced by the audio encoder.
|
| 52 |
+
"""
|
| 53 |
+
audio_hidden_states: Optional[torch.FloatTensor] = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
GEMMA3_START_DOCSTRING = r"""
|
| 57 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 58 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 59 |
+
etc.)
|
| 60 |
+
|
| 61 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 62 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 63 |
+
and behavior.
|
| 64 |
+
|
| 65 |
+
Parameters:
|
| 66 |
+
config ([`Gemma3Config`]):
|
| 67 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 68 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 69 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
GEMMA3_INPUTS_DOCSTRING = r"""
|
| 75 |
+
Args:
|
| 76 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 77 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 78 |
+
it.
|
| 79 |
+
|
| 80 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 81 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 82 |
+
|
| 83 |
+
[What are input IDs?](../glossary#input-ids)
|
| 84 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 85 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 86 |
+
|
| 87 |
+
- 1 for tokens that are **not masked**,
|
| 88 |
+
- 0 for tokens that are **masked**.
|
| 89 |
+
|
| 90 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 91 |
+
|
| 92 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 93 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 94 |
+
|
| 95 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 96 |
+
`past_key_values`).
|
| 97 |
+
|
| 98 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 99 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 100 |
+
information on the default strategy.
|
| 101 |
+
|
| 102 |
+
- 1 indicates the head is **not masked**,
|
| 103 |
+
- 0 indicates the head is **masked**.
|
| 104 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 105 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 106 |
+
config.n_positions - 1]`.
|
| 107 |
+
|
| 108 |
+
[What are position IDs?](../glossary#position-ids)
|
| 109 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 110 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 111 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 112 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 113 |
+
|
| 114 |
+
Two formats are allowed:
|
| 115 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 116 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 117 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 118 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 119 |
+
cache format.
|
| 120 |
+
|
| 121 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 122 |
+
legacy cache format will be returned.
|
| 123 |
+
|
| 124 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 125 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 126 |
+
of shape `(batch_size, sequence_length)`.
|
| 127 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 128 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 129 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 130 |
+
model's internal embedding lookup matrix.
|
| 131 |
+
use_cache (`bool`, *optional*):
|
| 132 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 133 |
+
`past_key_values`).
|
| 134 |
+
output_attentions (`bool`, *optional*):
|
| 135 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 136 |
+
tensors for more detail.
|
| 137 |
+
output_hidden_states (`bool`, *optional*):
|
| 138 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 139 |
+
more detail.
|
| 140 |
+
return_dict (`bool`, *optional*):
|
| 141 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 142 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 143 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 144 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 145 |
+
the complete sequence length.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
@add_start_docstrings(
|
| 149 |
+
"The bare Gemma3 Model outputting raw hidden-states without any specific head on top.",
|
| 150 |
+
GEMMA3_START_DOCSTRING,
|
| 151 |
+
)
|
| 152 |
+
class Gemma3OmniPreTrainedModel(Gemma3PreTrainedModel):
|
| 153 |
+
config_class = Gemma3OmniConfig
|
| 154 |
+
|
| 155 |
+
@add_start_docstrings(
|
| 156 |
+
"""The GEMMA3 model which consists of a vision backbone and a language model.""",
|
| 157 |
+
GEMMA3_START_DOCSTRING,
|
| 158 |
+
)
|
| 159 |
+
class Gemma3OmniForConditionalGeneration(Gemma3OmniPreTrainedModel, GenerationMixin):
|
| 160 |
+
def __init__(self, config: Gemma3OmniConfig):
|
| 161 |
+
super().__init__(config)
|
| 162 |
+
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
| 163 |
+
audio_config = config.audio_config.to_diff_dict()
|
| 164 |
+
for item in ['transformers_version', 'model_type', 'torch_dtype']:
|
| 165 |
+
if item in audio_config:
|
| 166 |
+
audio_config.pop(item)
|
| 167 |
+
self.audio_tower = ConformerEncoder(**audio_config)
|
| 168 |
+
self.audio_tower.post_init({})
|
| 169 |
+
self.audio_tower = self.audio_tower.to(dtype=self.dtype)
|
| 170 |
+
self.audio_projector = nn.Sequential(
|
| 171 |
+
nn.Linear(in_features=config.audio_config.attention_dim, out_features=config.text_config.hidden_size, bias=True),
|
| 172 |
+
nn.GELU(approximate='none'),
|
| 173 |
+
nn.Linear(in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, bias=True)
|
| 174 |
+
).to(dtype=self.dtype)
|
| 175 |
+
|
| 176 |
+
self.multi_modal_projector = Gemma3MultiModalProjector(config)
|
| 177 |
+
self.vocab_size = config.text_config.vocab_size
|
| 178 |
+
|
| 179 |
+
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
|
| 180 |
+
|
| 181 |
+
if language_model._tied_weights_keys is not None:
|
| 182 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
| 183 |
+
self.language_model = language_model
|
| 184 |
+
|
| 185 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 186 |
+
self.init_lora()
|
| 187 |
+
self.post_init()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def init_lora(self):
|
| 191 |
+
from peft import LoraConfig, get_peft_model
|
| 192 |
+
import warnings
|
| 193 |
+
print('######################## speech lora #############')
|
| 194 |
+
speech_lora_config = LoraConfig(
|
| 195 |
+
r=self.config.speech_lora['r'],
|
| 196 |
+
lora_alpha=self.config.speech_lora['lora_alpha'],
|
| 197 |
+
target_modules=self.config.speech_lora['layer'],
|
| 198 |
+
use_rslora=self.config.speech_lora['use_rslora'],
|
| 199 |
+
lora_dropout=self.config.speech_lora['dp'],
|
| 200 |
+
task_type="CAUSAL_LM",
|
| 201 |
+
)
|
| 202 |
+
self.language_model.model = get_peft_model(self.language_model.model, speech_lora_config, adapter_name="speech")
|
| 203 |
+
print('######################## text lora #############')
|
| 204 |
+
text_lora_config = LoraConfig(
|
| 205 |
+
r=self.config.text_lora['r'],
|
| 206 |
+
lora_alpha=self.config.text_lora['lora_alpha'],
|
| 207 |
+
target_modules=self.config.text_lora['layer'],
|
| 208 |
+
use_rslora=self.config.text_lora['use_rslora'],
|
| 209 |
+
lora_dropout=self.config.text_lora['dp'],
|
| 210 |
+
task_type="CAUSAL_LM",
|
| 211 |
+
)
|
| 212 |
+
self.language_model.model.base_model.active_adapter.append("text")
|
| 213 |
+
self.language_model.model.add_adapter("text", text_lora_config)
|
| 214 |
+
|
| 215 |
+
def set_lora_adapter(self, adapter_name) -> None:
|
| 216 |
+
from peft.tuners.lora.layer import LoraLayer
|
| 217 |
+
for module in self.modules():
|
| 218 |
+
if isinstance(module, LoraLayer):
|
| 219 |
+
if module.merged:
|
| 220 |
+
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
|
| 221 |
+
module.unmerge()
|
| 222 |
+
module.set_adapter(adapter_name)
|
| 223 |
+
module._disable_adapters = False
|
| 224 |
+
|
| 225 |
+
def unset_lora_adapter(self) -> None:
|
| 226 |
+
# Ref: peft/tuners/tuners_utils.py - enable_adapters()
|
| 227 |
+
# Ref: peft/tuners/lora/layer.py
|
| 228 |
+
from peft.tuners.lora.layer import LoraLayer
|
| 229 |
+
for module in self.modules():
|
| 230 |
+
if isinstance(module, LoraLayer):
|
| 231 |
+
# disable grads on all adapter layers
|
| 232 |
+
# TODO weijian: may use enable_adapters() instead
|
| 233 |
+
for layer_name in module.adapter_layer_names:
|
| 234 |
+
layer = getattr(module, layer_name)
|
| 235 |
+
layer.requires_grad_(False)
|
| 236 |
+
module._disable_adapters = True
|
| 237 |
+
|
| 238 |
+
def get_input_embeddings(self):
|
| 239 |
+
return self.language_model.get_input_embeddings()
|
| 240 |
+
|
| 241 |
+
def set_input_embeddings(self, value):
|
| 242 |
+
self.language_model.set_input_embeddings(value)
|
| 243 |
+
|
| 244 |
+
def get_output_embeddings(self):
|
| 245 |
+
return self.language_model.get_output_embeddings()
|
| 246 |
+
|
| 247 |
+
def set_output_embeddings(self, new_embeddings):
|
| 248 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 249 |
+
|
| 250 |
+
def set_decoder(self, decoder):
|
| 251 |
+
self.language_model.set_decoder(decoder)
|
| 252 |
+
|
| 253 |
+
def get_decoder(self):
|
| 254 |
+
return self.language_model.get_decoder()
|
| 255 |
+
|
| 256 |
+
def _update_causal_mask(
|
| 257 |
+
self,
|
| 258 |
+
attention_mask,
|
| 259 |
+
token_type_ids,
|
| 260 |
+
past_key_values,
|
| 261 |
+
cache_position,
|
| 262 |
+
input_tensor,
|
| 263 |
+
is_training: bool = False,
|
| 264 |
+
):
|
| 265 |
+
if self.config.text_config._attn_implementation == "flash_attention_2":
|
| 266 |
+
return attention_mask
|
| 267 |
+
|
| 268 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 269 |
+
# In this case we assume that the mask comes already in inverted
|
| 270 |
+
# form and requires no inversion or slicing.
|
| 271 |
+
return attention_mask
|
| 272 |
+
|
| 273 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 274 |
+
min_dtype = torch.finfo(self.dtype).min
|
| 275 |
+
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
| 276 |
+
if using_static_cache:
|
| 277 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 278 |
+
elif isinstance(past_key_values, HybridCache):
|
| 279 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 280 |
+
else:
|
| 281 |
+
target_length = (
|
| 282 |
+
attention_mask.shape[-1]
|
| 283 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 284 |
+
else cache_position[0] + sequence_length + 1
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 288 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 289 |
+
return attention_mask
|
| 290 |
+
|
| 291 |
+
causal_mask = torch.full(
|
| 292 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
| 296 |
+
if sequence_length != 1:
|
| 297 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 298 |
+
|
| 299 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 300 |
+
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
| 301 |
+
|
| 302 |
+
# Apply bidirectional mask on images if token type ids are provided
|
| 303 |
+
if token_type_ids is not None and sequence_length != 1:
|
| 304 |
+
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
| 305 |
+
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
| 306 |
+
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
| 307 |
+
causal_mask = causal_mask.clone()
|
| 308 |
+
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
| 309 |
+
token_type_mask, 0.0
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
if attention_mask is not None:
|
| 313 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 314 |
+
mask_length = attention_mask.shape[-1]
|
| 315 |
+
|
| 316 |
+
# Then apply padding mask (will mask pad tokens)
|
| 317 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
| 318 |
+
padding_mask = padding_mask == 0
|
| 319 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 320 |
+
padding_mask, min_dtype
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
return causal_mask
|
| 324 |
+
|
| 325 |
+
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 326 |
+
"""
|
| 327 |
+
Projects the last hidden state from the vision model into language model space.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
| 331 |
+
The tensors corresponding to the input images.
|
| 332 |
+
Returns:
|
| 333 |
+
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
| 334 |
+
"""
|
| 335 |
+
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
| 336 |
+
image_features = self.multi_modal_projector(vision_outputs)
|
| 337 |
+
return image_features
|
| 338 |
+
|
| 339 |
+
def get_audio_features(self, input_audio_embeds: torch.FloatTensor, audio_attention_mask: torch.FloatTensor, audio_embed_sizes: torch.FloatTensor):
|
| 340 |
+
"""
|
| 341 |
+
Projects the last hidden state from the audio model into language model space.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
audio_inputs (`torch.FloatTensor]` of shape `(batch_size, sequence_length, feature_dim)`)
|
| 345 |
+
The tensors corresponding to the input audio features.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`).
|
| 349 |
+
"""
|
| 350 |
+
audio_features, masks = self.audio_tower(input_audio_embeds, audio_attention_mask)
|
| 351 |
+
audio_outputs = self.audio_projector(audio_features)
|
| 352 |
+
return audio_outputs
|
| 353 |
+
|
| 354 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
| 355 |
+
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
| 356 |
+
@replace_return_docstrings(output_type=Gemma3OmniCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 357 |
+
def forward(
|
| 358 |
+
self,
|
| 359 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 360 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 361 |
+
input_audio_embeds: torch.FloatTensor = None,
|
| 362 |
+
audio_embed_sizes: torch.FloatTensor = None,
|
| 363 |
+
audio_attention_mask: torch.FloatTensor = None,
|
| 364 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 365 |
+
input_modes: torch.LongTensor = None,
|
| 366 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 367 |
+
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
| 368 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 369 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 370 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 371 |
+
labels: Optional[torch.LongTensor] = None,
|
| 372 |
+
use_cache: Optional[bool] = None,
|
| 373 |
+
output_attentions: Optional[bool] = None,
|
| 374 |
+
output_hidden_states: Optional[bool] = None,
|
| 375 |
+
return_dict: Optional[bool] = None,
|
| 376 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 377 |
+
**lm_kwargs,
|
| 378 |
+
) -> Union[Tuple, Gemma3OmniCausalLMOutputWithPast]:
|
| 379 |
+
r"""
|
| 380 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 381 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 382 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 383 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 384 |
+
|
| 385 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 386 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 387 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 388 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 389 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 390 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
|
| 394 |
+
Example:
|
| 395 |
+
|
| 396 |
+
```python
|
| 397 |
+
>>> from PIL import Image
|
| 398 |
+
>>> import requests
|
| 399 |
+
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 400 |
+
|
| 401 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 402 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 403 |
+
|
| 404 |
+
>>> messages = [
|
| 405 |
+
... {
|
| 406 |
+
... "role": "system",
|
| 407 |
+
... "content": [
|
| 408 |
+
... {"type": "text", "text": "You are a helpful assistant."}
|
| 409 |
+
... ]
|
| 410 |
+
... },
|
| 411 |
+
... {
|
| 412 |
+
... "role": "user", "content": [
|
| 413 |
+
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 414 |
+
... {"type": "text", "text": "Where is the cat standing?"},
|
| 415 |
+
... ]
|
| 416 |
+
... },
|
| 417 |
+
... ]
|
| 418 |
+
|
| 419 |
+
>>> inputs = processor.apply_chat_template(
|
| 420 |
+
... messages,
|
| 421 |
+
... tokenizer=True,
|
| 422 |
+
... return_dict=True,
|
| 423 |
+
... return_tensors="pt",
|
| 424 |
+
... add_generation_prompt=True
|
| 425 |
+
... )
|
| 426 |
+
>>> # Generate
|
| 427 |
+
>>> generate_ids = model.generate(**inputs)
|
| 428 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 429 |
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 430 |
+
```
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 434 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 435 |
+
|
| 436 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 437 |
+
output_hidden_states = (
|
| 438 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 439 |
+
)
|
| 440 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 441 |
+
|
| 442 |
+
if isinstance(input_modes, torch.Tensor):
|
| 443 |
+
# len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
|
| 444 |
+
input_modes = input_modes.unique()
|
| 445 |
+
if len(input_modes) != 1:
|
| 446 |
+
raise ValueError("Elements of input_modes should have the same value")
|
| 447 |
+
|
| 448 |
+
input_mode = InputMode(input_modes.item())
|
| 449 |
+
|
| 450 |
+
if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
|
| 451 |
+
self.unset_lora_adapter()
|
| 452 |
+
#self.set_lora_adapter('vision')
|
| 453 |
+
#audio_projection_mode = 'vision'
|
| 454 |
+
elif input_mode == InputMode.SPEECH:
|
| 455 |
+
self.unset_lora_adapter()
|
| 456 |
+
self.set_lora_adapter('speech')
|
| 457 |
+
#audio_projection_mode = 'speech'
|
| 458 |
+
elif input_mode == InputMode.LANGUAGE:
|
| 459 |
+
self.unset_lora_adapter()
|
| 460 |
+
self.set_lora_adapter('text')
|
| 461 |
+
|
| 462 |
+
#audio_projection_mode = 'speech'
|
| 463 |
+
else:
|
| 464 |
+
raise ValueError(f"Invalid input_mode: {input_mode}")
|
| 465 |
+
|
| 466 |
+
is_training = token_type_ids is not None and labels is not None
|
| 467 |
+
|
| 468 |
+
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
| 469 |
+
if input_ids is not None and self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size:
|
| 470 |
+
special_image_mask = input_ids == self.config.image_token_index
|
| 471 |
+
special_audio_mask = input_ids == self.config.audio_token_index
|
| 472 |
+
llm_input_ids = input_ids.clone()
|
| 473 |
+
llm_input_ids[special_image_mask] = 0
|
| 474 |
+
llm_input_ids[special_audio_mask] = 0
|
| 475 |
+
else:
|
| 476 |
+
llm_input_ids = input_ids
|
| 477 |
+
|
| 478 |
+
if inputs_embeds is None:
|
| 479 |
+
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
| 480 |
+
inputs_embeds = inputs_embeds.to(dtype=self.dtype)
|
| 481 |
+
if cache_position is None:
|
| 482 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 483 |
+
cache_position = torch.arange(
|
| 484 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if position_ids is None:
|
| 488 |
+
position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
|
| 489 |
+
|
| 490 |
+
# Merge text and images
|
| 491 |
+
if pixel_values is not None:
|
| 492 |
+
image_features = self.get_image_features(pixel_values)
|
| 493 |
+
|
| 494 |
+
if input_ids is None:
|
| 495 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 496 |
+
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
| 497 |
+
)
|
| 498 |
+
else:
|
| 499 |
+
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
| 500 |
+
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 501 |
+
|
| 502 |
+
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 503 |
+
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
| 504 |
+
raise ValueError(
|
| 505 |
+
f"Number of images does not match number of special image tokens in the input text. "
|
| 506 |
+
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
| 507 |
+
"tokens from image embeddings."
|
| 508 |
+
)
|
| 509 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 510 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 511 |
+
|
| 512 |
+
# Merge text and audios
|
| 513 |
+
if input_audio_embeds is not None:
|
| 514 |
+
input_audio_embeds=input_audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 515 |
+
if audio_attention_mask is not None:
|
| 516 |
+
audio_attention_mask=audio_attention_mask.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 517 |
+
audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
|
| 518 |
+
if input_ids is None:
|
| 519 |
+
special_audio_mask = inputs_embeds == self.get_input_embeddings()(
|
| 520 |
+
torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
|
| 524 |
+
special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 525 |
+
masked_audio_features = []
|
| 526 |
+
for i, size in enumerate(audio_embed_sizes):
|
| 527 |
+
masked_audio_features.append(audio_features[i, :size, :])
|
| 528 |
+
masked_audio_features = torch.cat(masked_audio_features, dim=0)
|
| 529 |
+
|
| 530 |
+
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel():
|
| 531 |
+
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
|
| 532 |
+
masked_audio_size = audio_embed_sizes#.sum()[0]
|
| 533 |
+
raise ValueError(
|
| 534 |
+
f"Number of audio does not match number of special audio tokens in the input text. "
|
| 535 |
+
f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_size} "
|
| 536 |
+
"tokens from audio embeddings. "
|
| 537 |
+
f"{masked_audio_features.numel()} \n"
|
| 538 |
+
f"{inputs_embeds[special_audio_mask].numel()} \n"
|
| 539 |
+
f"{audio_features} \n"
|
| 540 |
+
f"{inputs_embeds[special_audio_mask]} \n"
|
| 541 |
+
f"{special_audio_mask} \n"
|
| 542 |
+
)
|
| 543 |
+
masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 544 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)
|
| 545 |
+
# mask out pad-token-ids in labels for BC
|
| 546 |
+
if labels is not None and self.pad_token_id in labels:
|
| 547 |
+
logger.warning_once(
|
| 548 |
+
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
| 549 |
+
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
| 550 |
+
)
|
| 551 |
+
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
|
| 552 |
+
|
| 553 |
+
causal_mask = self._update_causal_mask(
|
| 554 |
+
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
| 555 |
+
)
|
| 556 |
+
outputs = self.language_model(
|
| 557 |
+
attention_mask=causal_mask,
|
| 558 |
+
position_ids=position_ids,
|
| 559 |
+
past_key_values=past_key_values,
|
| 560 |
+
inputs_embeds=inputs_embeds,
|
| 561 |
+
use_cache=use_cache,
|
| 562 |
+
output_attentions=output_attentions,
|
| 563 |
+
output_hidden_states=output_hidden_states,
|
| 564 |
+
return_dict=return_dict,
|
| 565 |
+
cache_position=cache_position,
|
| 566 |
+
logits_to_keep=logits_to_keep,
|
| 567 |
+
**lm_kwargs,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
logits = outputs.logits
|
| 571 |
+
loss = None
|
| 572 |
+
# print('#############################')
|
| 573 |
+
# print(logits)
|
| 574 |
+
if labels is not None:
|
| 575 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 576 |
+
logits = logits.float()
|
| 577 |
+
shift_logits = logits[..., :-1, :]
|
| 578 |
+
shift_labels = labels[..., 1:]
|
| 579 |
+
if attention_mask is not None:
|
| 580 |
+
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
| 581 |
+
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
| 582 |
+
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
| 583 |
+
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
| 584 |
+
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
| 585 |
+
else:
|
| 586 |
+
shift_logits = shift_logits.contiguous()
|
| 587 |
+
shift_labels = shift_labels.contiguous()
|
| 588 |
+
# Flatten the tokens
|
| 589 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 590 |
+
|
| 591 |
+
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 592 |
+
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
| 593 |
+
loss = loss_fct(flat_logits, flat_labels)
|
| 594 |
+
# print('flat logits',flat_logits)
|
| 595 |
+
# print(flat_labels)
|
| 596 |
+
# print(loss)
|
| 597 |
+
if not return_dict:
|
| 598 |
+
output = (logits,) + outputs[1:]
|
| 599 |
+
return (loss,) + output if loss is not None else output
|
| 600 |
+
|
| 601 |
+
return Gemma3OmniCausalLMOutputWithPast(
|
| 602 |
+
loss=loss,
|
| 603 |
+
logits=logits,
|
| 604 |
+
past_key_values=outputs.past_key_values,
|
| 605 |
+
hidden_states=outputs.hidden_states,
|
| 606 |
+
attentions=outputs.attentions,
|
| 607 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 608 |
+
audio_hidden_states=audio_features if input_audio_embeds is not None else None,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
def prepare_inputs_for_generation(
|
| 612 |
+
self,
|
| 613 |
+
input_ids,
|
| 614 |
+
past_key_values=None,
|
| 615 |
+
input_modes=None,
|
| 616 |
+
inputs_embeds=None,
|
| 617 |
+
cache_position=None,
|
| 618 |
+
position_ids=None,
|
| 619 |
+
pixel_values=None,
|
| 620 |
+
input_audio_embeds=None,
|
| 621 |
+
audio_embed_sizes=None,
|
| 622 |
+
audio_attention_mask=None,
|
| 623 |
+
attention_mask=None,
|
| 624 |
+
token_type_ids=None,
|
| 625 |
+
use_cache=True,
|
| 626 |
+
logits_to_keep=None,
|
| 627 |
+
labels=None,
|
| 628 |
+
**kwargs,
|
| 629 |
+
):
|
| 630 |
+
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
| 631 |
+
model_inputs = self.language_model.prepare_inputs_for_generation(
|
| 632 |
+
input_ids,
|
| 633 |
+
past_key_values=past_key_values,
|
| 634 |
+
input_modes=input_modes,
|
| 635 |
+
inputs_embeds=inputs_embeds,
|
| 636 |
+
attention_mask=attention_mask,
|
| 637 |
+
position_ids=position_ids,
|
| 638 |
+
cache_position=cache_position,
|
| 639 |
+
use_cache=use_cache,
|
| 640 |
+
logits_to_keep=logits_to_keep,
|
| 641 |
+
token_type_ids=token_type_ids,
|
| 642 |
+
**kwargs,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# position_ids in Gemma3 are 1-indexed
|
| 646 |
+
if model_inputs.get("position_ids") is not None:
|
| 647 |
+
model_inputs["position_ids"] += 1
|
| 648 |
+
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 649 |
+
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
| 650 |
+
if cache_position[0] == 0:
|
| 651 |
+
model_inputs["pixel_values"] = pixel_values
|
| 652 |
+
model_inputs["input_audio_embeds"] = input_audio_embeds
|
| 653 |
+
model_inputs["audio_embed_sizes"] = audio_embed_sizes
|
| 654 |
+
model_inputs["audio_attention_mask"] = audio_attention_mask
|
| 655 |
+
model_inputs["input_modes"] = input_modes
|
| 656 |
+
is_training = token_type_ids is not None and labels is not None
|
| 657 |
+
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
| 658 |
+
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
| 659 |
+
causal_mask = self._update_causal_mask(
|
| 660 |
+
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
| 661 |
+
)
|
| 662 |
+
model_inputs["attention_mask"] = causal_mask
|
| 663 |
+
|
| 664 |
+
return model_inputs
|
| 665 |
+
|
| 666 |
+
def tie_weights(self):
|
| 667 |
+
return self.language_model.tie_weights()
|
| 668 |
+
|
preprocessing_gemma3omni.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List, Optional, Union, Tuple
|
| 3 |
+
from math import ceil
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import scipy
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
from transformers import AutoFeatureExtractor
|
| 13 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 14 |
+
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 15 |
+
from transformers.image_utils import ImageInput, make_nested_list_of_images
|
| 16 |
+
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
|
| 17 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 18 |
+
from transformers.utils import to_py_obj, TensorType
|
| 19 |
+
from transformers.audio_utils import AudioInput
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Gemma3ImagesKwargs(ImagesKwargs):
|
| 23 |
+
do_pan_and_scan: Optional[bool]
|
| 24 |
+
pan_and_scan_min_crop_size: Optional[int]
|
| 25 |
+
pan_and_scan_max_num_crops: Optional[int]
|
| 26 |
+
pan_and_scan_min_ratio_to_activate: Optional[float]
|
| 27 |
+
do_convert_rgb: Optional[bool]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
| 31 |
+
images_kwargs: Gemma3ImagesKwargs
|
| 32 |
+
_defaults = {
|
| 33 |
+
"text_kwargs": {
|
| 34 |
+
"padding": False,
|
| 35 |
+
},
|
| 36 |
+
"images_kwargs": {
|
| 37 |
+
"do_pan_and_scan": False,
|
| 38 |
+
"pan_and_scan_min_crop_size": 256,
|
| 39 |
+
"pan_and_scan_max_num_crops": 4,
|
| 40 |
+
"pan_and_scan_min_ratio_to_activate": 1.2,
|
| 41 |
+
},
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
|
| 45 |
+
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
|
| 49 |
+
n_fft (int): FFT size. int > 0 [scalar]
|
| 50 |
+
n_mel (int): Mel filter size. int > 0 [scalar]
|
| 51 |
+
fmin (float): lowest frequency (in Hz). If None use 0.0.
|
| 52 |
+
float >= 0 [scalar]
|
| 53 |
+
fmax: highest frequency (in Hz). If None use sample_rate / 2.
|
| 54 |
+
float >= 0 [scalar]
|
| 55 |
+
|
| 56 |
+
Returns
|
| 57 |
+
out (numpy.ndarray): Mel transform matrix
|
| 58 |
+
[shape=(n_mels, 1 + n_fft/2)]
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
bank_width = int(n_fft // 2 + 1)
|
| 62 |
+
if fmax is None:
|
| 63 |
+
fmax = sample_rate / 2
|
| 64 |
+
if fmin is None:
|
| 65 |
+
fmin = 0
|
| 66 |
+
assert fmin >= 0, "fmin cannot be negtive"
|
| 67 |
+
assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
|
| 68 |
+
|
| 69 |
+
def mel(f):
|
| 70 |
+
return 1127.0 * np.log(1.0 + f / 700.0)
|
| 71 |
+
|
| 72 |
+
def bin2mel(fft_bin):
|
| 73 |
+
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
|
| 74 |
+
|
| 75 |
+
def f2bin(f):
|
| 76 |
+
return int((f * n_fft / sample_rate) + 0.5)
|
| 77 |
+
|
| 78 |
+
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
|
| 79 |
+
klo = f2bin(fmin) + 1
|
| 80 |
+
khi = f2bin(fmax)
|
| 81 |
+
|
| 82 |
+
khi = max(khi, klo)
|
| 83 |
+
|
| 84 |
+
# Spec 2: SpeechLib uses trianges in Mel space
|
| 85 |
+
mlo = mel(fmin)
|
| 86 |
+
mhi = mel(fmax)
|
| 87 |
+
m_centers = np.linspace(mlo, mhi, n_mels + 2)
|
| 88 |
+
ms = (mhi - mlo) / (n_mels + 1)
|
| 89 |
+
|
| 90 |
+
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
|
| 91 |
+
for m in range(0, n_mels):
|
| 92 |
+
left = m_centers[m]
|
| 93 |
+
center = m_centers[m + 1]
|
| 94 |
+
right = m_centers[m + 2]
|
| 95 |
+
for fft_bin in range(klo, khi):
|
| 96 |
+
mbin = bin2mel(fft_bin)
|
| 97 |
+
if left < mbin < right:
|
| 98 |
+
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
|
| 99 |
+
|
| 100 |
+
return matrix
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
|
| 104 |
+
model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
|
| 105 |
+
|
| 106 |
+
def __init__(self, audio_compression_rate=8,
|
| 107 |
+
audio_downsample_rate=1,
|
| 108 |
+
audio_feat_stride=1,
|
| 109 |
+
feature_size = 80,
|
| 110 |
+
sampling_rate = 16000,
|
| 111 |
+
padding_value = 0.0,
|
| 112 |
+
**kwargs):
|
| 113 |
+
|
| 114 |
+
super().__init__(feature_size=feature_size,
|
| 115 |
+
sampling_rate=sampling_rate,
|
| 116 |
+
padding_value=padding_value, **kwargs)
|
| 117 |
+
|
| 118 |
+
self.compression_rate = audio_compression_rate
|
| 119 |
+
self.qformer_compression_rate = audio_downsample_rate
|
| 120 |
+
self.feat_stride = audio_feat_stride
|
| 121 |
+
|
| 122 |
+
self._eightk_method = "fillzero"
|
| 123 |
+
self._mel = speechlib_mel(self.sampling_rate, 512, self.feature_size, fmin=None, fmax=self.sampling_rate//2-self.feature_size-230).T
|
| 124 |
+
|
| 125 |
+
self._hamming400 = np.hamming(400) # for 16k audio
|
| 126 |
+
self._hamming200 = np.hamming(200) # for 8k audio
|
| 127 |
+
|
| 128 |
+
def duration_to_frames(self, duration):
|
| 129 |
+
"""duration in s, estimated frames"""
|
| 130 |
+
frame_rate = 10
|
| 131 |
+
|
| 132 |
+
num_frames = duration * 1000 // frame_rate
|
| 133 |
+
return num_frames
|
| 134 |
+
|
| 135 |
+
def __call__(
|
| 136 |
+
self,
|
| 137 |
+
audios: List[AudioInput],
|
| 138 |
+
sampling_rate = 16000,
|
| 139 |
+
return_attention_mask=True,
|
| 140 |
+
padding="max_length",
|
| 141 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 142 |
+
):
|
| 143 |
+
# Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
|
| 144 |
+
returned_input_audio_embeds = []
|
| 145 |
+
returned_audio_embed_sizes = []
|
| 146 |
+
audio_frames_list = []
|
| 147 |
+
|
| 148 |
+
for audio_data in audios:
|
| 149 |
+
audio_embeds = self._extract_features(audio_data, sampling_rate)
|
| 150 |
+
audio_frames = len(audio_embeds) * self.feat_stride
|
| 151 |
+
audio_embed_size = self._compute_audio_embed_size(audio_frames)
|
| 152 |
+
|
| 153 |
+
returned_input_audio_embeds.append(torch.tensor(audio_embeds))
|
| 154 |
+
returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
|
| 155 |
+
audio_frames_list.append(audio_frames)
|
| 156 |
+
|
| 157 |
+
returned_input_audio_embeds = pad_sequence(
|
| 158 |
+
returned_input_audio_embeds, batch_first=True
|
| 159 |
+
)
|
| 160 |
+
returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
|
| 161 |
+
audio_frames = torch.tensor(audio_frames_list)
|
| 162 |
+
returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
|
| 163 |
+
|
| 164 |
+
data = {
|
| 165 |
+
"input_audio_embeds": returned_input_audio_embeds,
|
| 166 |
+
"audio_embed_sizes": returned_audio_embed_sizes,
|
| 167 |
+
}
|
| 168 |
+
if returned_audio_attention_mask is not None and return_attention_mask:
|
| 169 |
+
data["audio_attention_mask"] = returned_audio_attention_mask
|
| 170 |
+
|
| 171 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 172 |
+
|
| 173 |
+
def _extract_spectrogram(self, wav, fs):
|
| 174 |
+
"""Extract spectrogram features from waveform.
|
| 175 |
+
Args:
|
| 176 |
+
wav (1D array): waveform of the input
|
| 177 |
+
fs (int): sampling rate of the waveform, 16000 or 8000.
|
| 178 |
+
If fs=8000, the waveform will be resampled to 16000Hz.
|
| 179 |
+
Output:
|
| 180 |
+
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
|
| 181 |
+
D=80, and T is the number of frames.
|
| 182 |
+
"""
|
| 183 |
+
if wav.ndim > 1:
|
| 184 |
+
wav = np.squeeze(wav)
|
| 185 |
+
|
| 186 |
+
# by default, we extract the mean if stereo
|
| 187 |
+
if len(wav.shape) == 2:
|
| 188 |
+
wav = wav.mean(1)
|
| 189 |
+
|
| 190 |
+
# Resample to 16000 or 8000 if needed
|
| 191 |
+
if fs > 16000:
|
| 192 |
+
wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
|
| 193 |
+
fs = 16000
|
| 194 |
+
elif 8000 < fs < 16000:
|
| 195 |
+
wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
|
| 196 |
+
fs = 8000
|
| 197 |
+
elif fs < 8000:
|
| 198 |
+
raise RuntimeError(f"Unsupported sample rate {fs}")
|
| 199 |
+
|
| 200 |
+
if fs == 8000:
|
| 201 |
+
if self._eightk_method == "resample":
|
| 202 |
+
# Input audio is 8 kHz. Convert to 16 kHz before feature
|
| 203 |
+
# extraction
|
| 204 |
+
wav = scipy.signal.resample_poly(wav, 2, 1)
|
| 205 |
+
fs = 16000
|
| 206 |
+
# Do nothing here for fillzero method
|
| 207 |
+
elif fs != 16000:
|
| 208 |
+
# Input audio is not a supported sample rate.
|
| 209 |
+
raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
|
| 210 |
+
|
| 211 |
+
preemphasis = 0.97
|
| 212 |
+
|
| 213 |
+
if fs == 8000:
|
| 214 |
+
n_fft = 256
|
| 215 |
+
win_length = 200
|
| 216 |
+
hop_length = 80
|
| 217 |
+
fft_window = self._hamming200
|
| 218 |
+
elif fs == 16000:
|
| 219 |
+
n_fft = 512
|
| 220 |
+
win_length = 400
|
| 221 |
+
hop_length = 160
|
| 222 |
+
fft_window = self._hamming400
|
| 223 |
+
|
| 224 |
+
# Spec 1: SpeechLib cut remaining sample insufficient for a hop
|
| 225 |
+
n_batch = (wav.shape[0] - win_length) // hop_length + 1
|
| 226 |
+
# Here we don't use stride_tricks since the input array may not satisfy
|
| 227 |
+
# memory layout requirement and we need writeable output
|
| 228 |
+
# Here we only use list of views before copy to desination
|
| 229 |
+
# so it is more efficient than broadcasting
|
| 230 |
+
y_frames = np.array(
|
| 231 |
+
[wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
|
| 232 |
+
dtype=np.float32,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Spec 2: SpeechLib applies preemphasis within each batch
|
| 236 |
+
y_frames_prev = np.roll(y_frames, 1, axis=1)
|
| 237 |
+
y_frames_prev[:, 0] = y_frames_prev[:, 1]
|
| 238 |
+
y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
|
| 239 |
+
|
| 240 |
+
S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
|
| 241 |
+
|
| 242 |
+
if fs == 8000:
|
| 243 |
+
# Need to pad the output to look like 16 kHz data but with zeros in
|
| 244 |
+
# the 4 to 8 kHz bins.
|
| 245 |
+
frames, bins = S.shape
|
| 246 |
+
padarray = np.zeros((frames, bins))
|
| 247 |
+
S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
|
| 248 |
+
|
| 249 |
+
spec = np.abs(S).astype(np.float32)
|
| 250 |
+
return spec
|
| 251 |
+
|
| 252 |
+
def _extract_features(self, wav, fs):
|
| 253 |
+
"""Extract log filterbank features from waveform.
|
| 254 |
+
Args:
|
| 255 |
+
wav (1D array): waveform of the input
|
| 256 |
+
fs (int): sampling rate of the waveform, 16000 or 8000.
|
| 257 |
+
If fs=8000, the waveform will be resampled to 16000Hz.
|
| 258 |
+
Output:
|
| 259 |
+
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
|
| 260 |
+
D=80, and T is the number of frames.
|
| 261 |
+
"""
|
| 262 |
+
spec = self._extract_spectrogram(wav, fs)
|
| 263 |
+
spec_power = spec**2
|
| 264 |
+
|
| 265 |
+
fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
|
| 266 |
+
log_fbank = np.log(fbank_power).astype(np.float32)
|
| 267 |
+
|
| 268 |
+
return log_fbank
|
| 269 |
+
|
| 270 |
+
def _compute_audio_embed_size(self, audio_frames):
|
| 271 |
+
integer = audio_frames // self.compression_rate
|
| 272 |
+
remainder = audio_frames % self.compression_rate
|
| 273 |
+
|
| 274 |
+
result = integer if remainder == 0 else integer + 1
|
| 275 |
+
|
| 276 |
+
integer = result // self.qformer_compression_rate
|
| 277 |
+
remainder = result % self.qformer_compression_rate
|
| 278 |
+
result = integer if remainder == 0 else integer + 1 # qformer compression
|
| 279 |
+
|
| 280 |
+
return result
|
| 281 |
+
|
| 282 |
+
class Gemma3OmniProcessor(ProcessorMixin):
|
| 283 |
+
attributes = ["image_processor", "feature_extractor", "tokenizer"]
|
| 284 |
+
valid_kwargs = ["chat_template", "image_seq_length"]
|
| 285 |
+
image_processor_class = "AutoImageProcessor"
|
| 286 |
+
feature_extractor_class = "Gemma3AudioFeatureExtractor"
|
| 287 |
+
tokenizer_class = "AutoTokenizer"
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
image_processor,
|
| 292 |
+
feature_extractor,
|
| 293 |
+
tokenizer,
|
| 294 |
+
chat_template=None,
|
| 295 |
+
image_seq_length: int = 256,
|
| 296 |
+
**kwargs,
|
| 297 |
+
):
|
| 298 |
+
self.image_seq_length = image_seq_length
|
| 299 |
+
self.image_token_id = tokenizer.image_token_id
|
| 300 |
+
self.boi_token = tokenizer.boi_token
|
| 301 |
+
self.image_token = tokenizer.image_token
|
| 302 |
+
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
|
| 303 |
+
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
| 304 |
+
|
| 305 |
+
self.audio_token_id = tokenizer.audio_token_id
|
| 306 |
+
self.boa_token = tokenizer.boa_token
|
| 307 |
+
self.eoa_token = tokenizer.eoa_token
|
| 308 |
+
self.audio_token = tokenizer.audio_token
|
| 309 |
+
|
| 310 |
+
super().__init__(
|
| 311 |
+
image_processor=image_processor,
|
| 312 |
+
feature_extractor=feature_extractor,
|
| 313 |
+
tokenizer=tokenizer,
|
| 314 |
+
chat_template=chat_template,
|
| 315 |
+
**kwargs,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def __call__(
|
| 319 |
+
self,
|
| 320 |
+
images: ImageInput = None,
|
| 321 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 322 |
+
videos=None,
|
| 323 |
+
audio: List[AudioInput] = None,
|
| 324 |
+
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
| 325 |
+
) -> BatchFeature:
|
| 326 |
+
if text is None and images is None:
|
| 327 |
+
raise ValueError("Provide at least one of `text` or `images`.")
|
| 328 |
+
|
| 329 |
+
output_kwargs = self._merge_kwargs(
|
| 330 |
+
Gemma3ProcessorKwargs,
|
| 331 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 332 |
+
**kwargs,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if isinstance(text, str):
|
| 336 |
+
text = [text]
|
| 337 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 338 |
+
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
| 339 |
+
|
| 340 |
+
image_inputs = {}
|
| 341 |
+
if images is not None:
|
| 342 |
+
batched_images = make_nested_list_of_images(images)
|
| 343 |
+
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
|
| 344 |
+
|
| 345 |
+
# Create empty text to be replaced with placeholders
|
| 346 |
+
if not text:
|
| 347 |
+
text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
|
| 348 |
+
|
| 349 |
+
if len(batched_images) != len(text):
|
| 350 |
+
raise ValueError(
|
| 351 |
+
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Replace image tokens by the full expanded sequence
|
| 355 |
+
num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
| 356 |
+
batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
|
| 357 |
+
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
| 358 |
+
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
| 359 |
+
|
| 360 |
+
if len(images) != len(image_indexes):
|
| 361 |
+
raise ValueError(
|
| 362 |
+
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Insert additional image tokens for Pan-and-Scan crops
|
| 366 |
+
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
| 367 |
+
if num:
|
| 368 |
+
formatted_image_text = (
|
| 369 |
+
f"Here is the original image {self.boi_token} and here are some crops to help you see better "
|
| 370 |
+
+ " ".join([self.boi_token] * num)
|
| 371 |
+
)
|
| 372 |
+
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
| 373 |
+
text[batch_idx] = prompt
|
| 374 |
+
|
| 375 |
+
# Expand placeholder image tokens to the full image token sequence
|
| 376 |
+
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
| 377 |
+
|
| 378 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 379 |
+
|
| 380 |
+
audio_inputs = {}
|
| 381 |
+
if audio is not None:
|
| 382 |
+
full_audio_sequences = []
|
| 383 |
+
audio_inputs = self.feature_extractor(audio)
|
| 384 |
+
def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
|
| 385 |
+
parts = prompt.split(boa_token)
|
| 386 |
+
result = ""
|
| 387 |
+
for i in range(len(parts) - 1):
|
| 388 |
+
result += parts[i]
|
| 389 |
+
if i < len(audio_sequences):
|
| 390 |
+
result += audio_sequences[i]
|
| 391 |
+
else:
|
| 392 |
+
result += boa_token
|
| 393 |
+
result += parts[-1]
|
| 394 |
+
return result
|
| 395 |
+
for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
|
| 396 |
+
audio_tokens_expanded = "".join([self.audio_token] * embed_size)
|
| 397 |
+
full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
|
| 398 |
+
full_audio_sequences.append(full_audio_sequence)
|
| 399 |
+
|
| 400 |
+
text = [replace_tokens_sequentially(prompt, self.boa_token, [audio_sequences]) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
|
| 401 |
+
#text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
|
| 402 |
+
|
| 403 |
+
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
| 404 |
+
|
| 405 |
+
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
| 406 |
+
array_ids = text_inputs["input_ids"]
|
| 407 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 408 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 409 |
+
mm_token_type_ids[array_ids == self.audio_token_id] = 2
|
| 410 |
+
|
| 411 |
+
has_vision_ids = np.any(mm_token_type_ids == 1, axis=1)
|
| 412 |
+
has_audio_ids = np.any(mm_token_type_ids == 2, axis=1)
|
| 413 |
+
|
| 414 |
+
input_modes = (has_audio_ids << 1) | has_vision_ids
|
| 415 |
+
|
| 416 |
+
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
| 417 |
+
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
| 418 |
+
text_inputs["input_modes"] = input_modes.tolist()
|
| 419 |
+
|
| 420 |
+
return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
|
| 421 |
+
|
| 422 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
| 423 |
+
def batch_decode(self, *args, **kwargs):
|
| 424 |
+
"""
|
| 425 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 426 |
+
refer to the docstring of this method for more information.
|
| 427 |
+
"""
|
| 428 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 429 |
+
|
| 430 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
| 431 |
+
def decode(self, *args, **kwargs):
|
| 432 |
+
"""
|
| 433 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 434 |
+
the docstring of this method for more information.
|
| 435 |
+
"""
|
| 436 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 437 |
+
|
| 438 |
+
@property
|
| 439 |
+
def model_input_names(self):
|
| 440 |
+
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 441 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 442 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 443 |
+
|
| 444 |
+
AutoFeatureExtractor.register("Gemma3AudioFeatureExtractor", Gemma3AudioFeatureExtractor)
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"audio_compression_rate": 8,
|
| 3 |
+
"audio_downsample_rate": 1,
|
| 4 |
+
"audio_feat_stride": 1,
|
| 5 |
+
"compression_rate": 8,
|
| 6 |
+
"do_convert_rgb": null,
|
| 7 |
+
"do_normalize": true,
|
| 8 |
+
"do_pan_and_scan": null,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"feat_stride": 1,
|
| 12 |
+
"feature_extractor_type": "Gemma3AudioFeatureExtractor",
|
| 13 |
+
"feature_size": 80,
|
| 14 |
+
"image_mean": [
|
| 15 |
+
0.5,
|
| 16 |
+
0.5,
|
| 17 |
+
0.5
|
| 18 |
+
],
|
| 19 |
+
"image_processor_type": "Gemma3ImageProcessor",
|
| 20 |
+
"image_seq_length": 256,
|
| 21 |
+
"image_std": [
|
| 22 |
+
0.5,
|
| 23 |
+
0.5,
|
| 24 |
+
0.5
|
| 25 |
+
],
|
| 26 |
+
"padding_side": "right",
|
| 27 |
+
"padding_value": 0.0,
|
| 28 |
+
"pan_and_scan_max_num_crops": null,
|
| 29 |
+
"pan_and_scan_min_crop_size": null,
|
| 30 |
+
"pan_and_scan_min_ratio_to_activate": null,
|
| 31 |
+
"processor_class": "Gemma3OmniProcessor",
|
| 32 |
+
"qformer_compression_rate": 1,
|
| 33 |
+
"resample": 2,
|
| 34 |
+
"rescale_factor": 0.00392156862745098,
|
| 35 |
+
"return_attention_mask": true,
|
| 36 |
+
"sampling_rate": 16000,
|
| 37 |
+
"size": {
|
| 38 |
+
"height": 896,
|
| 39 |
+
"width": 896
|
| 40 |
+
}
|
| 41 |
+
}
|
processor_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "preprocessing_gemma3omni.Gemma3OmniProcessor"
|
| 4 |
+
},
|
| 5 |
+
"image_seq_length": 256,
|
| 6 |
+
"processor_class": "Gemma3Processor"
|
| 7 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"audio_token": "<audio_soft_token>",
|
| 3 |
+
"boa_token": "<start_of_audio>",
|
| 4 |
+
"boi_token": "<start_of_image>",
|
| 5 |
+
"bos_token": {
|
| 6 |
+
"content": "<bos>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"eoa_token": "<end_of_audio>",
|
| 13 |
+
"eoi_token": "<end_of_image>",
|
| 14 |
+
"eos_token": {
|
| 15 |
+
"content": "<eos>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"image_token": "<image_soft_token>",
|
| 22 |
+
"pad_token": {
|
| 23 |
+
"content": "<pad>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false
|
| 28 |
+
},
|
| 29 |
+
"unk_token": {
|
| 30 |
+
"content": "<unk>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false
|
| 35 |
+
}
|
| 36 |
+
}
|
speech_conformer_encoder.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52941f2ba60fdcc48edb940f4252f6d874d0c369323dab293168015122e556be
|
| 3 |
+
size 33384559
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
|
| 3 |
+
size 4689074
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training.py
ADDED
|
@@ -0,0 +1,883 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 33 |
+
_IGNORE_INDEX = -100
|
| 34 |
+
class BaseAudioDataset(Dataset):
|
| 35 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 36 |
+
self.processor = processor
|
| 37 |
+
self.training = "train" in split or 'other' in split
|
| 38 |
+
self.debug = debug
|
| 39 |
+
self.sampling_rate = sampling_rate
|
| 40 |
+
self.name = ""
|
| 41 |
+
|
| 42 |
+
def set_dataset_name(self, name):
|
| 43 |
+
self.name = name
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 47 |
+
original_size = len(data)
|
| 48 |
+
|
| 49 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 50 |
+
|
| 51 |
+
def identify_corrupted_files(example):
|
| 52 |
+
try:
|
| 53 |
+
sf.read(example[audio_field]["path"])
|
| 54 |
+
|
| 55 |
+
for field in text_fields:
|
| 56 |
+
if field in example and example[field].replace('"', '') == "":
|
| 57 |
+
return False
|
| 58 |
+
return True
|
| 59 |
+
except Exception:
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 63 |
+
validated_size = len(data)
|
| 64 |
+
|
| 65 |
+
# Audio Decoding
|
| 66 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 67 |
+
|
| 68 |
+
if debug:
|
| 69 |
+
print(f"Dataset: {dataset_name}")
|
| 70 |
+
print(f"Original data nums: {original_size}")
|
| 71 |
+
print(f"After filtering data nums: {validated_size}")
|
| 72 |
+
print(f"Filtering ratio: {validated_size/original_size:.2%}")
|
| 73 |
+
|
| 74 |
+
return data
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 78 |
+
original_size = len(data)
|
| 79 |
+
|
| 80 |
+
def filter_audio_by_length(example):
|
| 81 |
+
try:
|
| 82 |
+
audio = example[audio_field]['array']
|
| 83 |
+
channel = 1
|
| 84 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 85 |
+
channel = audio.ndim
|
| 86 |
+
audio = audio.squeeze()
|
| 87 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 88 |
+
return min_sec <= audio_length <= max_sec
|
| 89 |
+
except Exception as e:
|
| 90 |
+
if debug:
|
| 91 |
+
print(f"Error : {str(e)[:100]}... - sample excluded")
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 95 |
+
filtered_size = len(data)
|
| 96 |
+
|
| 97 |
+
if debug:
|
| 98 |
+
print(f"Before Length Filtering data nums: {original_size}")
|
| 99 |
+
print(f"After Length Filtering data nums: {filtered_size}")
|
| 100 |
+
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
|
| 101 |
+
|
| 102 |
+
return data
|
| 103 |
+
|
| 104 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 105 |
+
user_message = {
|
| 106 |
+
'role': 'user',
|
| 107 |
+
'content': '<start_of_audio>' + instruction,
|
| 108 |
+
}
|
| 109 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 110 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
inputs = self.processor(
|
| 114 |
+
text=prompt,
|
| 115 |
+
audio=[audio_array],
|
| 116 |
+
add_special_tokens=False,
|
| 117 |
+
return_tensors='pt'
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 121 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 122 |
+
|
| 123 |
+
if self.debug:
|
| 124 |
+
self.debug = False
|
| 125 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 126 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 127 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 128 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 129 |
+
|
| 130 |
+
if self.training:
|
| 131 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 132 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 133 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 134 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 135 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 136 |
+
else:
|
| 137 |
+
input_ids = inputs.input_ids
|
| 138 |
+
labels = answer_ids
|
| 139 |
+
token_type_ids = inputs.token_type_ids
|
| 140 |
+
|
| 141 |
+
return {
|
| 142 |
+
'input_ids': input_ids,
|
| 143 |
+
'labels': labels,
|
| 144 |
+
'token_type_ids': token_type_ids,
|
| 145 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 146 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 147 |
+
'input_modes': inputs.input_modes,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Libri Speech Dataset Class
|
| 152 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 153 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 154 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 155 |
+
|
| 156 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 157 |
+
# only ASR
|
| 158 |
+
self.ast = False
|
| 159 |
+
self.lang = "en"
|
| 160 |
+
|
| 161 |
+
# load dataset
|
| 162 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
|
| 163 |
+
subset,
|
| 164 |
+
split=split,
|
| 165 |
+
trust_remote_code=True,
|
| 166 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# (Optional) Audio length Filtering
|
| 170 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 171 |
+
|
| 172 |
+
# Instruction Setting
|
| 173 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 174 |
+
|
| 175 |
+
def __len__(self):
|
| 176 |
+
return len(self.data)
|
| 177 |
+
|
| 178 |
+
def __getitem__(self, idx):
|
| 179 |
+
data = self.data[idx]
|
| 180 |
+
|
| 181 |
+
# Libri Speech is only for ASR
|
| 182 |
+
answer_text = data["text"].replace('"', '')
|
| 183 |
+
|
| 184 |
+
return self.prepare_model_inputs(
|
| 185 |
+
data["audio"]["array"],
|
| 186 |
+
self.instruction,
|
| 187 |
+
answer_text
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# common_voice_16_1 dataset
|
| 191 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 192 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 193 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 194 |
+
|
| 195 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 196 |
+
# only ASR
|
| 197 |
+
self.ast = False
|
| 198 |
+
self.lang=source_lang
|
| 199 |
+
|
| 200 |
+
# load dataset
|
| 201 |
+
if source_lang=="zh-TW":
|
| 202 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
|
| 203 |
+
else:
|
| 204 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
|
| 205 |
+
self.data = load_dataset(data_path,
|
| 206 |
+
source_lang,
|
| 207 |
+
split=split,
|
| 208 |
+
trust_remote_code=True,
|
| 209 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 210 |
+
)
|
| 211 |
+
def prepare_dataset(batch):
|
| 212 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 213 |
+
transcription = batch["sentence"]
|
| 214 |
+
|
| 215 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 216 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 217 |
+
transcription = transcription[1:-1]
|
| 218 |
+
|
| 219 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 220 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 221 |
+
transcription = transcription + "."
|
| 222 |
+
|
| 223 |
+
batch["sentence"] = transcription
|
| 224 |
+
|
| 225 |
+
return batch
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
import opencc
|
| 229 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 230 |
+
def To_zhTW(batch):
|
| 231 |
+
|
| 232 |
+
transcription = converter.convert(batch["sentence"])
|
| 233 |
+
batch["sentence"] = transcription
|
| 234 |
+
|
| 235 |
+
return batch
|
| 236 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 237 |
+
if source_lang=='zh-CN':
|
| 238 |
+
self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# (Optional) Audio length Filtering
|
| 242 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 243 |
+
|
| 244 |
+
if source_lang == "zh-TW" and split=='train':
|
| 245 |
+
import torchaudio
|
| 246 |
+
from torchaudio import transforms
|
| 247 |
+
import copy
|
| 248 |
+
import pickle
|
| 249 |
+
import os
|
| 250 |
+
def subsample(batch):
|
| 251 |
+
batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
|
| 252 |
+
batch['audio']['sampling_rate']=16000
|
| 253 |
+
return batch
|
| 254 |
+
def TW_data_augment_fast(batch):
|
| 255 |
+
speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
|
| 256 |
+
new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
|
| 257 |
+
batch['audio']['array'] = new_array_fast
|
| 258 |
+
return batch
|
| 259 |
+
def TW_data_augment_slow(batch):
|
| 260 |
+
speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
|
| 261 |
+
new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
|
| 262 |
+
batch['audio']['array'] = new_array_slow
|
| 263 |
+
return batch
|
| 264 |
+
# data = self.data.map(subsample, num_proc=1, desc="subsample")
|
| 265 |
+
fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
|
| 266 |
+
if not os.path.exists(fast_path):
|
| 267 |
+
data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
|
| 268 |
+
with open(fast_path,'wb') as f:
|
| 269 |
+
pickle.dump(data_fast,f)
|
| 270 |
+
else:
|
| 271 |
+
with open(fast_path,'rb') as f:
|
| 272 |
+
data_fast=pickle.load(f)
|
| 273 |
+
|
| 274 |
+
slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
|
| 275 |
+
if not os.path.exists(slow_path):
|
| 276 |
+
data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
|
| 277 |
+
with open(slow_path,'wb') as f:
|
| 278 |
+
pickle.dump(data_slow,f)
|
| 279 |
+
else:
|
| 280 |
+
with open(slow_path,'rb') as f:
|
| 281 |
+
data_slow=pickle.load(f)
|
| 282 |
+
self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
|
| 283 |
+
|
| 284 |
+
# Instruction Setting
|
| 285 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 286 |
+
|
| 287 |
+
def __len__(self):
|
| 288 |
+
return len(self.data)
|
| 289 |
+
|
| 290 |
+
def __getitem__(self, idx):
|
| 291 |
+
data = self.data[idx]
|
| 292 |
+
|
| 293 |
+
answer_text = data["sentence"]
|
| 294 |
+
return self.prepare_model_inputs(
|
| 295 |
+
data["audio"]["array"],
|
| 296 |
+
self.instruction,
|
| 297 |
+
answer_text
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# Fleurs Dataset Class
|
| 302 |
+
class FleursDataset(BaseAudioDataset):
|
| 303 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 304 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 305 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 306 |
+
|
| 307 |
+
self.set_dataset_name("Fleurs")
|
| 308 |
+
# Mode Setting (ASR or AST)
|
| 309 |
+
if mode not in ["asr", "ast"]:
|
| 310 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 311 |
+
|
| 312 |
+
self.mode = mode
|
| 313 |
+
self.ast = (mode == "ast")
|
| 314 |
+
self.source_lang = source_lang
|
| 315 |
+
|
| 316 |
+
# Language name mapping (expand if needed)
|
| 317 |
+
self.lang_names = {
|
| 318 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
# load dataset - source language dataset
|
| 322 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 323 |
+
source_lang,
|
| 324 |
+
split=split,
|
| 325 |
+
trust_remote_code=True,
|
| 326 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 327 |
+
)
|
| 328 |
+
import opencc
|
| 329 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 330 |
+
def prepare_dataset(batch):
|
| 331 |
+
transcription = converter.convert(batch["transcription"])
|
| 332 |
+
batch["transcription"] = transcription
|
| 333 |
+
|
| 334 |
+
return batch
|
| 335 |
+
if (source_lang=="cmn_hans_cn"):
|
| 336 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 337 |
+
|
| 338 |
+
# (Optional) Audio length Filtering
|
| 339 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 340 |
+
self.target_lang_name = ""
|
| 341 |
+
# When AST mode, load target language dataset.
|
| 342 |
+
if self.ast:
|
| 343 |
+
if target_lang is None:
|
| 344 |
+
raise ValueError("AST mode requires target_lang.")
|
| 345 |
+
|
| 346 |
+
self.target_lang = target_lang
|
| 347 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 348 |
+
|
| 349 |
+
# load dataset - target language dataset (for translation)
|
| 350 |
+
target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 351 |
+
target_lang,
|
| 352 |
+
split=split,
|
| 353 |
+
trust_remote_code=True,
|
| 354 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 355 |
+
)
|
| 356 |
+
if target_lang=="cmn_hans_cn":
|
| 357 |
+
target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
|
| 358 |
+
source_dict = {item['id']: item for item in self.data}
|
| 359 |
+
target_dict = {item['id']: item for item in target_data}
|
| 360 |
+
|
| 361 |
+
# only Common ID, add translation fields
|
| 362 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 363 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 364 |
+
self.data = [
|
| 365 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 366 |
+
for id in common_ids
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
# Instruction Setting - use target language name
|
| 370 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 371 |
+
self.instruction = random.choice(INSTRUCTION["ast"])
|
| 372 |
+
else:
|
| 373 |
+
# ASR mode
|
| 374 |
+
self.lang = source_lang
|
| 375 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 376 |
+
|
| 377 |
+
if self.debug:
|
| 378 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 379 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 380 |
+
if self.ast:
|
| 381 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 382 |
+
print(f"dataset size: {len(self.data)}")
|
| 383 |
+
|
| 384 |
+
def __len__(self):
|
| 385 |
+
return len(self.data)
|
| 386 |
+
|
| 387 |
+
def __getitem__(self, idx):
|
| 388 |
+
data = self.data[idx]
|
| 389 |
+
audio_array = data["audio"]["array"]
|
| 390 |
+
|
| 391 |
+
if self.ast:
|
| 392 |
+
answer_text = data["translation"]
|
| 393 |
+
else:
|
| 394 |
+
answer_text = data["transcription"]
|
| 395 |
+
|
| 396 |
+
return self.prepare_model_inputs(
|
| 397 |
+
audio_array,
|
| 398 |
+
self.instruction.format(self.target_lang_name),
|
| 399 |
+
answer_text
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
def covost_collate_fn(batch):
|
| 403 |
+
input_ids_list = []
|
| 404 |
+
labels_list = []
|
| 405 |
+
token_type_ids_list = []
|
| 406 |
+
input_audio_embeds_list = []
|
| 407 |
+
audio_embed_sizes_list = []
|
| 408 |
+
audio_attention_mask_list = []
|
| 409 |
+
input_modes_list = []
|
| 410 |
+
for inputs in batch:
|
| 411 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 412 |
+
labels_list.append(inputs['labels'][0])
|
| 413 |
+
token_type_ids_list.append(inputs['token_type_ids'][0])
|
| 414 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 415 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 416 |
+
audio_attention_mask_list.append(
|
| 417 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 418 |
+
)
|
| 419 |
+
input_modes_list.append(inputs['input_modes'])
|
| 420 |
+
|
| 421 |
+
try:
|
| 422 |
+
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
|
| 423 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 424 |
+
labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
|
| 425 |
+
audio_attention_mask = (
|
| 426 |
+
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
|
| 427 |
+
if len(audio_attention_mask_list) > 1
|
| 428 |
+
else None
|
| 429 |
+
)
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(e)
|
| 432 |
+
print(input_ids_list)
|
| 433 |
+
print(labels_list)
|
| 434 |
+
raise
|
| 435 |
+
attention_mask = (input_ids != 0).long()
|
| 436 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
|
| 437 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
|
| 438 |
+
input_modes = torch.cat(input_modes_list)
|
| 439 |
+
|
| 440 |
+
return BatchFeature(
|
| 441 |
+
{
|
| 442 |
+
'input_ids': input_ids,
|
| 443 |
+
'labels': labels,
|
| 444 |
+
'token_type_ids': token_type_ids,
|
| 445 |
+
'attention_mask': attention_mask,
|
| 446 |
+
'input_audio_embeds': input_audio_embeds,
|
| 447 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 448 |
+
'audio_attention_mask': audio_attention_mask,
|
| 449 |
+
'input_modes': input_modes,
|
| 450 |
+
}
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 454 |
+
"""
|
| 455 |
+
Pad a list of sequences to the same length.
|
| 456 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 457 |
+
"""
|
| 458 |
+
assert padding_side in ['right', 'left']
|
| 459 |
+
max_size = sequences[0].size()
|
| 460 |
+
trailing_dims = max_size[1:]
|
| 461 |
+
max_len = max(len(seq) for seq in sequences)
|
| 462 |
+
batch_size = len(sequences)
|
| 463 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 464 |
+
for i, seq in enumerate(sequences):
|
| 465 |
+
length = seq.size(0)
|
| 466 |
+
if padding_side == 'right':
|
| 467 |
+
output.data[i, :length] = seq
|
| 468 |
+
else:
|
| 469 |
+
output.data[i, -length:] = seq
|
| 470 |
+
return output
|
| 471 |
+
|
| 472 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 473 |
+
"""
|
| 474 |
+
cat along dim, while pad to max for all other dims
|
| 475 |
+
"""
|
| 476 |
+
ndim = tensors[0].dim()
|
| 477 |
+
assert all(
|
| 478 |
+
t.dim() == ndim for t in tensors[1:]
|
| 479 |
+
), 'All tensors must have the same number of dimensions'
|
| 480 |
+
|
| 481 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 482 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 483 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 484 |
+
|
| 485 |
+
index = 0
|
| 486 |
+
for t in tensors:
|
| 487 |
+
# Create a slice list where every dimension except dim is full slice
|
| 488 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 489 |
+
# Update only the concat dimension slice
|
| 490 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 491 |
+
|
| 492 |
+
output[slices] = t
|
| 493 |
+
index += t.shape[dim]
|
| 494 |
+
|
| 495 |
+
return output
|
| 496 |
+
|
| 497 |
+
def count_parameters_by_module(model):
|
| 498 |
+
# dictionary for parameters number by modules
|
| 499 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 500 |
+
|
| 501 |
+
# all params
|
| 502 |
+
total_params = 0
|
| 503 |
+
total_trainable_params = 0
|
| 504 |
+
|
| 505 |
+
# Check Embedding Token masks
|
| 506 |
+
embedding_masks = {}
|
| 507 |
+
for name, param in model.named_parameters():
|
| 508 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 509 |
+
# check if params has embedding_grad_mask_hook
|
| 510 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 511 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 512 |
+
# Accessing mask variables in the closure of hook functions
|
| 513 |
+
for cell in hook_fn.__closure__ or []:
|
| 514 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 515 |
+
# check mask tensor
|
| 516 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 517 |
+
|
| 518 |
+
# Count params by modules
|
| 519 |
+
for name, param in model.named_parameters():
|
| 520 |
+
# extracts top module_name
|
| 521 |
+
module_name = name.split('.')[0]
|
| 522 |
+
param_count = param.numel()
|
| 523 |
+
|
| 524 |
+
module_params[module_name]["total"] += param_count
|
| 525 |
+
total_params += param_count
|
| 526 |
+
|
| 527 |
+
if param.requires_grad:
|
| 528 |
+
# Only count for real trainable params. (with masks)
|
| 529 |
+
if name in embedding_masks:
|
| 530 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 531 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 532 |
+
total_trainable_params += trainable_count
|
| 533 |
+
else:
|
| 534 |
+
module_params[module_name]["trainable"] += param_count
|
| 535 |
+
total_trainable_params += param_count
|
| 536 |
+
|
| 537 |
+
print(f"All Params: {total_params:,}")
|
| 538 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 539 |
+
print("\nParams by Module:")
|
| 540 |
+
|
| 541 |
+
for module_name, counts in sorted(module_params.items()):
|
| 542 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 543 |
+
total_percentage = counts["total"] / total_params * 100
|
| 544 |
+
|
| 545 |
+
print(f"- {module_name}:")
|
| 546 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 547 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 548 |
+
|
| 549 |
+
return module_params
|
| 550 |
+
|
| 551 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 552 |
+
model = AutoModel.from_pretrained(
|
| 553 |
+
model_name_or_path,
|
| 554 |
+
revision=revision,
|
| 555 |
+
torch_dtype=torch.bfloat16,
|
| 556 |
+
device_map="auto",
|
| 557 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 558 |
+
trust_remote_code=True,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# Set use_cache to False after model loaded
|
| 562 |
+
model.config.use_cache = False
|
| 563 |
+
|
| 564 |
+
# Freeze all parameters
|
| 565 |
+
for param in model.parameters():
|
| 566 |
+
param.requires_grad = False
|
| 567 |
+
|
| 568 |
+
model.set_lora_adapter('speech')
|
| 569 |
+
model.to(torch.bfloat16)
|
| 570 |
+
|
| 571 |
+
# (Optional) unfreeze audio_tower parameters
|
| 572 |
+
# for param in model.audio_tower.parameters():
|
| 573 |
+
# param.requires_grad = True
|
| 574 |
+
|
| 575 |
+
# Only unfreeze audio_projector parameters
|
| 576 |
+
for param in model.audio_projector.parameters():
|
| 577 |
+
param.requires_grad = True
|
| 578 |
+
|
| 579 |
+
# (Optional) unfreeze audio embed_tokens
|
| 580 |
+
train_embed = True
|
| 581 |
+
if train_embed:
|
| 582 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 583 |
+
|
| 584 |
+
embed_tokens.weight.requires_grad = False
|
| 585 |
+
|
| 586 |
+
# Added Speech token IDs (only this tokens be trainable)
|
| 587 |
+
trainable_token_ids = [256001, 256002]
|
| 588 |
+
|
| 589 |
+
embed_tokens.weight.requires_grad = True
|
| 590 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 591 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 592 |
+
|
| 593 |
+
# backward hook, with gradient masking
|
| 594 |
+
def embedding_grad_mask_hook(grad):
|
| 595 |
+
return grad.masked_fill(mask, 0)
|
| 596 |
+
|
| 597 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 598 |
+
|
| 599 |
+
model.language_model.model.model.embed_tokens = embed_tokens
|
| 600 |
+
|
| 601 |
+
count_parameters_by_module(model)
|
| 602 |
+
|
| 603 |
+
return model
|
| 604 |
+
|
| 605 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 606 |
+
|
| 607 |
+
INSTRUCTION = {
|
| 608 |
+
"ast": [
|
| 609 |
+
"Translate the audio to {0}.",
|
| 610 |
+
"Translate the audio clip into {0}.",
|
| 611 |
+
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
|
| 612 |
+
"Translate the provided audio file into {0}.",
|
| 613 |
+
"Convert the audio speech to {0} text.",
|
| 614 |
+
"Write an {0} translation of the audio file.",
|
| 615 |
+
"Translate spoken words from the audio into {0}.",
|
| 616 |
+
"Create an {0} version of the audio content.",
|
| 617 |
+
"Produce an accurate {0} translation of the audio.",
|
| 618 |
+
"Extract speech from the audio and translate it to {0}.",
|
| 619 |
+
"Turn the audio into readable {0} text.",
|
| 620 |
+
"Write all spoken content from the audio in {0}.",
|
| 621 |
+
"Generate an {0} translation of the speech in the file.",
|
| 622 |
+
"Convert the recording into {0} text.",
|
| 623 |
+
"Accurately translate the audio recording to {0}.",
|
| 624 |
+
"Write down dialogue from the given audio in {0}.",
|
| 625 |
+
"Translate all speech in this audio file to {0}.",
|
| 626 |
+
"Create an accurate {0} version of the speech.",
|
| 627 |
+
"Perform a complete {0} translation of the audio."
|
| 628 |
+
],
|
| 629 |
+
"asr": [
|
| 630 |
+
"Transcribe the audio clip into text.",
|
| 631 |
+
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
|
| 632 |
+
"Transcribe the provided audio file into text.",
|
| 633 |
+
"Convert the audio speech to text.",
|
| 634 |
+
"Write a transcript of the audio file.",
|
| 635 |
+
"Transcribe spoken words from the audio.",
|
| 636 |
+
"Create a text version of the audio content.",
|
| 637 |
+
"Produce a verbatim transcript of the audio.",
|
| 638 |
+
"Extract and transcribe speech from the audio.",
|
| 639 |
+
"Turn the audio into readable text.",
|
| 640 |
+
"Write all spoken words from the audio.",
|
| 641 |
+
"Generate a transcript of the speech in the file.",
|
| 642 |
+
"Convert the recording into a text transcript.",
|
| 643 |
+
"Accurately transcribe the audio recording.",
|
| 644 |
+
"Write down dialogue from the given audio.",
|
| 645 |
+
"Transcribe all speech in this audio file.",
|
| 646 |
+
"Create an accurate text version of the speech.",
|
| 647 |
+
"Perform a complete transcription of the audio."
|
| 648 |
+
],
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 652 |
+
_IGNORE_INDEX = -100
|
| 653 |
+
|
| 654 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 655 |
+
use_flash_attention = True
|
| 656 |
+
|
| 657 |
+
output_dir = '../gemma_tmp7'
|
| 658 |
+
batch_size = 128
|
| 659 |
+
batch_size_per_gpu = 16
|
| 660 |
+
learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
|
| 661 |
+
wd = 0.01
|
| 662 |
+
num_train_epochs = 15
|
| 663 |
+
|
| 664 |
+
revision = "main" #"v1.0"
|
| 665 |
+
|
| 666 |
+
processor = AutoProcessor.from_pretrained(
|
| 667 |
+
model_name_or_path,
|
| 668 |
+
revision=revision,
|
| 669 |
+
trust_remote_code=True,
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
model = create_model(
|
| 673 |
+
model_name_or_path,
|
| 674 |
+
revision=revision,
|
| 675 |
+
use_flash_attention=use_flash_attention,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
train_datasets = []
|
| 679 |
+
|
| 680 |
+
# common voice asr
|
| 681 |
+
commonvoice_speech_tw2 = CommonVoiceDataset(
|
| 682 |
+
processor=processor,
|
| 683 |
+
source_lang="zh-TW",
|
| 684 |
+
split="other[:70%]"
|
| 685 |
+
)
|
| 686 |
+
train_datasets.append(commonvoice_speech_tw2)
|
| 687 |
+
|
| 688 |
+
commonvoice_speech_cn = CommonVoiceDataset(
|
| 689 |
+
processor=processor,
|
| 690 |
+
source_lang="zh-CN",
|
| 691 |
+
split="train[:50%]"
|
| 692 |
+
)
|
| 693 |
+
train_datasets.append(commonvoice_speech_cn)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
commonvoice_speech_tw = CommonVoiceDataset(
|
| 697 |
+
processor=processor,
|
| 698 |
+
source_lang="zh-TW",
|
| 699 |
+
split="train"
|
| 700 |
+
)
|
| 701 |
+
train_datasets.append(commonvoice_speech_tw)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
# Libri Speech Clean ASR mode (English -> English text)
|
| 707 |
+
libri_speech_clean = LibriSpeechDataset(
|
| 708 |
+
processor=processor,
|
| 709 |
+
subset="clean",
|
| 710 |
+
split="train.360[:50%]"
|
| 711 |
+
)
|
| 712 |
+
train_datasets.append(libri_speech_clean)
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Fleurs ASR mode (English -> English text)
|
| 716 |
+
en_asr_fleurs = FleursDataset(
|
| 717 |
+
processor=processor,
|
| 718 |
+
split="train",
|
| 719 |
+
source_lang="en_us", # English
|
| 720 |
+
mode="asr"
|
| 721 |
+
)
|
| 722 |
+
train_datasets.append(en_asr_fleurs)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
# en_ch_ast_fleurs = FleursDataset(
|
| 726 |
+
# processor=processor,
|
| 727 |
+
# split="train",
|
| 728 |
+
# source_lang="en_us",
|
| 729 |
+
# target_lang="cmn_hans_cn",
|
| 730 |
+
# mode="ast"
|
| 731 |
+
# )
|
| 732 |
+
# train_datasets.append(en_ch_ast_fleurs)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
ch_asr_fleurs = FleursDataset(
|
| 737 |
+
processor=processor,
|
| 738 |
+
split="train",
|
| 739 |
+
source_lang="cmn_hans_cn",
|
| 740 |
+
mode="asr"
|
| 741 |
+
)
|
| 742 |
+
train_datasets.append(ch_asr_fleurs)
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
# ch_en_ast_fleurs = FleursDataset(
|
| 746 |
+
# processor=processor,
|
| 747 |
+
# split="train",
|
| 748 |
+
# source_lang="cmn_hans_cn",
|
| 749 |
+
# target_lang="en_us",
|
| 750 |
+
# mode="ast"
|
| 751 |
+
# )
|
| 752 |
+
# train_datasets.append(ch_en_ast_fleurs)
|
| 753 |
+
|
| 754 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 755 |
+
print([len(dataset) for dataset in train_datasets])
|
| 756 |
+
|
| 757 |
+
# ConcatDataset
|
| 758 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 759 |
+
print("Count Length of Datas", len(train_dataset))
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# Check GPUs
|
| 764 |
+
num_gpus = torch.cuda.device_count()
|
| 765 |
+
print(f'training on {num_gpus} GPUs')
|
| 766 |
+
|
| 767 |
+
assert (
|
| 768 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 769 |
+
), 'Batch size must be divisible by the number of GPUs'
|
| 770 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
|
| 771 |
+
|
| 772 |
+
# hard coded training args
|
| 773 |
+
dp_config = {
|
| 774 |
+
"fp16": {
|
| 775 |
+
"enabled": "auto",
|
| 776 |
+
"loss_scale": 0,
|
| 777 |
+
"loss_scale_window": 1000,
|
| 778 |
+
"initial_scale_power": 16,
|
| 779 |
+
"hysteresis": 2,
|
| 780 |
+
"min_loss_scale": 1
|
| 781 |
+
},
|
| 782 |
+
"zero_optimization": {
|
| 783 |
+
"stage": 2,
|
| 784 |
+
"allgather_partitions": True,
|
| 785 |
+
"allgather_bucket_size": 5e8,
|
| 786 |
+
"overlap_comm": False,
|
| 787 |
+
"reduce_scatter": True,
|
| 788 |
+
"reduce_bucket_size": 5e8,
|
| 789 |
+
"contiguous_gradients": True,
|
| 790 |
+
"cpu_offload": True
|
| 791 |
+
},
|
| 792 |
+
|
| 793 |
+
"train_batch_size": "auto",
|
| 794 |
+
"gradient_accumulation_steps": "auto",
|
| 795 |
+
"optimizer": {
|
| 796 |
+
"type": "AdamW",
|
| 797 |
+
"params": {
|
| 798 |
+
"lr": "auto",
|
| 799 |
+
"betas": 'auto',
|
| 800 |
+
"eps": 'auto',
|
| 801 |
+
"weight_decay": "auto"
|
| 802 |
+
}
|
| 803 |
+
},
|
| 804 |
+
"scheduler": {
|
| 805 |
+
"type": "WarmupDecayLR",
|
| 806 |
+
"params": {
|
| 807 |
+
"warmup_min_lr": "auto",
|
| 808 |
+
"warmup_max_lr": "auto",
|
| 809 |
+
"warmup_num_steps": "auto",
|
| 810 |
+
"total_num_steps": "auto"
|
| 811 |
+
}
|
| 812 |
+
},
|
| 813 |
+
"gradient_clipping": 1.0,
|
| 814 |
+
"zero_optimization": {
|
| 815 |
+
"stage": 0
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
training_args = TrainingArguments(
|
| 819 |
+
num_train_epochs=num_train_epochs,
|
| 820 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 821 |
+
gradient_checkpointing=True,
|
| 822 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 823 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 824 |
+
optim='adamw_torch',
|
| 825 |
+
adam_beta1=0.9,
|
| 826 |
+
adam_beta2=0.95,
|
| 827 |
+
adam_epsilon=1e-7,
|
| 828 |
+
learning_rate=learning_rate,
|
| 829 |
+
weight_decay=wd,
|
| 830 |
+
max_grad_norm=1.0,
|
| 831 |
+
lr_scheduler_type='cosine',
|
| 832 |
+
warmup_steps=50,
|
| 833 |
+
logging_steps=10,
|
| 834 |
+
output_dir=output_dir,
|
| 835 |
+
save_total_limit=10,
|
| 836 |
+
save_only_model=True,
|
| 837 |
+
bf16=True,
|
| 838 |
+
fp16=False,
|
| 839 |
+
remove_unused_columns=False,
|
| 840 |
+
report_to='none',
|
| 841 |
+
deepspeed=dp_config if num_gpus==1 else None,
|
| 842 |
+
disable_tqdm=False,
|
| 843 |
+
dataloader_num_workers=4,
|
| 844 |
+
save_strategy='steps',
|
| 845 |
+
save_steps=1000,
|
| 846 |
+
ddp_find_unused_parameters=True,
|
| 847 |
+
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
out_path = Path(training_args.output_dir)
|
| 851 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 852 |
+
|
| 853 |
+
# create optimizer only for trainable params
|
| 854 |
+
optimizer = torch.optim.AdamW(
|
| 855 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 856 |
+
lr=learning_rate,
|
| 857 |
+
weight_decay=wd,
|
| 858 |
+
betas=(0.9, 0.95),
|
| 859 |
+
eps=1e-7,
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
# Trainer Setting
|
| 863 |
+
trainer = Trainer(
|
| 864 |
+
model=model,
|
| 865 |
+
args=training_args,
|
| 866 |
+
data_collator=covost_collate_fn,
|
| 867 |
+
train_dataset=train_dataset,
|
| 868 |
+
optimizers=(optimizer, None)
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
trainer.train()
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
# # 1. Save LoRA Adapter
|
| 875 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 876 |
+
|
| 877 |
+
# # 1-1. Delete Markdown file
|
| 878 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 879 |
+
# if os.path.exists(markdown_file):
|
| 880 |
+
# os.remove(markdown_file)
|
| 881 |
+
|
| 882 |
+
# 2. Save entire model
|
| 883 |
+
model.save_pretrained(output_dir)
|
training_multiturn.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
from ASRDataset import *
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def count_parameters_by_module(model):
|
| 36 |
+
# dictionary for parameters number by modules
|
| 37 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 38 |
+
|
| 39 |
+
# all params
|
| 40 |
+
total_params = 0
|
| 41 |
+
total_trainable_params = 0
|
| 42 |
+
|
| 43 |
+
# Check Embedding Token masks
|
| 44 |
+
embedding_masks = {}
|
| 45 |
+
for name, param in model.named_parameters():
|
| 46 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 47 |
+
# check if params has embedding_grad_mask_hook
|
| 48 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 49 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 50 |
+
# Accessing mask variables in the closure of hook functions
|
| 51 |
+
for cell in hook_fn.__closure__ or []:
|
| 52 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 53 |
+
# check mask tensor
|
| 54 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 55 |
+
|
| 56 |
+
# Count params by modules
|
| 57 |
+
for name, param in model.named_parameters():
|
| 58 |
+
# extracts top module_name
|
| 59 |
+
module_name = name.split('.')[0]
|
| 60 |
+
param_count = param.numel()
|
| 61 |
+
|
| 62 |
+
module_params[module_name]["total"] += param_count
|
| 63 |
+
total_params += param_count
|
| 64 |
+
|
| 65 |
+
if param.requires_grad:
|
| 66 |
+
# Only count for real trainable params. (with masks)
|
| 67 |
+
if name in embedding_masks:
|
| 68 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 69 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 70 |
+
total_trainable_params += trainable_count
|
| 71 |
+
else:
|
| 72 |
+
module_params[module_name]["trainable"] += param_count
|
| 73 |
+
total_trainable_params += param_count
|
| 74 |
+
|
| 75 |
+
print(f"All Params: {total_params:,}")
|
| 76 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 77 |
+
print("\nParams by Module:")
|
| 78 |
+
|
| 79 |
+
for module_name, counts in sorted(module_params.items()):
|
| 80 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 81 |
+
total_percentage = counts["total"] / total_params * 100
|
| 82 |
+
|
| 83 |
+
print(f"- {module_name}:")
|
| 84 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 85 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 86 |
+
|
| 87 |
+
return module_params
|
| 88 |
+
|
| 89 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 90 |
+
model = AutoModel.from_pretrained(
|
| 91 |
+
model_name_or_path,
|
| 92 |
+
revision=revision,
|
| 93 |
+
torch_dtype=torch.bfloat16,
|
| 94 |
+
device_map="auto",
|
| 95 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Set use_cache to False after model loaded
|
| 100 |
+
model.config.use_cache = False
|
| 101 |
+
|
| 102 |
+
# Freeze all parameters
|
| 103 |
+
for param in model.parameters():
|
| 104 |
+
param.requires_grad = False
|
| 105 |
+
|
| 106 |
+
model.set_lora_adapter('speech')
|
| 107 |
+
model.to(torch.bfloat16)
|
| 108 |
+
|
| 109 |
+
# (Optional) unfreeze audio_tower parameters
|
| 110 |
+
# for param in model.audio_tower.parameters():
|
| 111 |
+
# param.requires_grad = True
|
| 112 |
+
|
| 113 |
+
# Only unfreeze audio_projector parameters
|
| 114 |
+
# for param in model.audio_projector.parameters():
|
| 115 |
+
# param.requires_grad = True
|
| 116 |
+
|
| 117 |
+
# (Optional) unfreeze audio embed_tokens
|
| 118 |
+
train_embed = True
|
| 119 |
+
if train_embed:
|
| 120 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 121 |
+
|
| 122 |
+
embed_tokens.weight.requires_grad = False
|
| 123 |
+
|
| 124 |
+
# Added Speech token IDs (only this tokens be trainable)
|
| 125 |
+
trainable_token_ids = [256001, 256002]
|
| 126 |
+
|
| 127 |
+
embed_tokens.weight.requires_grad = True
|
| 128 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 129 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 130 |
+
|
| 131 |
+
# backward hook, with gradient masking
|
| 132 |
+
def embedding_grad_mask_hook(grad):
|
| 133 |
+
return grad.masked_fill(mask, 0)
|
| 134 |
+
|
| 135 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 136 |
+
|
| 137 |
+
model.language_model.model.model.embed_tokens = embed_tokens
|
| 138 |
+
|
| 139 |
+
count_parameters_by_module(model)
|
| 140 |
+
|
| 141 |
+
return model
|
| 142 |
+
|
| 143 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 144 |
+
_IGNORE_INDEX = -100
|
| 145 |
+
|
| 146 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 147 |
+
_IGNORE_INDEX = -100
|
| 148 |
+
|
| 149 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 150 |
+
use_flash_attention = False
|
| 151 |
+
|
| 152 |
+
output_dir = '../gemma_tmp13'
|
| 153 |
+
batch_size = 24
|
| 154 |
+
batch_size_per_gpu = 8
|
| 155 |
+
learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
|
| 156 |
+
wd = 0.01
|
| 157 |
+
num_train_epochs = 10
|
| 158 |
+
|
| 159 |
+
revision = "main" #"v1.0"
|
| 160 |
+
|
| 161 |
+
processor = AutoProcessor.from_pretrained(
|
| 162 |
+
model_name_or_path,
|
| 163 |
+
revision=revision,
|
| 164 |
+
trust_remote_code=True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
model = create_model(
|
| 168 |
+
model_name_or_path,
|
| 169 |
+
revision=revision,
|
| 170 |
+
use_flash_attention=use_flash_attention,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
train_datasets = []
|
| 174 |
+
|
| 175 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 176 |
+
train_datasets.append(pickup_dataset)
|
| 177 |
+
|
| 178 |
+
# custom_tw_loc = TWCostumData(processor=processor,
|
| 179 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 180 |
+
# train_datasets.append(custom_tw_loc) # 1500
|
| 181 |
+
|
| 182 |
+
# custom_tw_loc2 = TWCostumData(processor=processor,
|
| 183 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 184 |
+
# train_datasets.append(custom_tw_loc2) # 9458
|
| 185 |
+
|
| 186 |
+
# custom_yating_tw_road = TWCostumData(processor=processor,
|
| 187 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
|
| 188 |
+
# train_datasets.append(custom_yating_tw_road) # 35224
|
| 189 |
+
|
| 190 |
+
# custom_tw_road = TWCostumData(processor=processor,
|
| 191 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 192 |
+
# train_datasets.append(custom_tw_road) # 1500
|
| 193 |
+
|
| 194 |
+
# custom_tw_road2 = TWCostumData(processor=processor,
|
| 195 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 196 |
+
# train_datasets.append(custom_tw_road2) # 35224
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 201 |
+
print([len(dataset) for dataset in train_datasets])
|
| 202 |
+
|
| 203 |
+
# ConcatDataset
|
| 204 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 205 |
+
print("Count Length of Datas", len(train_dataset))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Check GPUs
|
| 210 |
+
num_gpus = torch.cuda.device_count()
|
| 211 |
+
print(f'training on {num_gpus} GPUs')
|
| 212 |
+
|
| 213 |
+
assert (
|
| 214 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 215 |
+
), 'Batch size must be divisible by the number of GPUs'
|
| 216 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
|
| 217 |
+
|
| 218 |
+
# hard coded training args
|
| 219 |
+
dp_config = {
|
| 220 |
+
"fp16": {
|
| 221 |
+
"enabled": "auto",
|
| 222 |
+
"loss_scale": 0,
|
| 223 |
+
"loss_scale_window": 1000,
|
| 224 |
+
"initial_scale_power": 16,
|
| 225 |
+
"hysteresis": 2,
|
| 226 |
+
"min_loss_scale": 1
|
| 227 |
+
},
|
| 228 |
+
"zero_optimization": {
|
| 229 |
+
"stage": 2,
|
| 230 |
+
"allgather_partitions": True,
|
| 231 |
+
"allgather_bucket_size": 5e8,
|
| 232 |
+
"overlap_comm": False,
|
| 233 |
+
"reduce_scatter": True,
|
| 234 |
+
"reduce_bucket_size": 5e8,
|
| 235 |
+
"contiguous_gradients": True,
|
| 236 |
+
"cpu_offload": True
|
| 237 |
+
},
|
| 238 |
+
|
| 239 |
+
"train_batch_size": "auto",
|
| 240 |
+
"gradient_accumulation_steps": "auto",
|
| 241 |
+
"optimizer": {
|
| 242 |
+
"type": "AdamW",
|
| 243 |
+
"params": {
|
| 244 |
+
"lr": "auto",
|
| 245 |
+
"betas": 'auto',
|
| 246 |
+
"eps": 'auto',
|
| 247 |
+
"weight_decay": "auto"
|
| 248 |
+
}
|
| 249 |
+
},
|
| 250 |
+
"scheduler": {
|
| 251 |
+
"type": "WarmupDecayLR",
|
| 252 |
+
"params": {
|
| 253 |
+
"warmup_min_lr": "auto",
|
| 254 |
+
"warmup_max_lr": "auto",
|
| 255 |
+
"warmup_num_steps": "auto",
|
| 256 |
+
"total_num_steps": "auto"
|
| 257 |
+
}
|
| 258 |
+
},
|
| 259 |
+
"gradient_clipping": 1.0,
|
| 260 |
+
"zero_optimization": {
|
| 261 |
+
"stage": 0
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
training_args = TrainingArguments(
|
| 265 |
+
num_train_epochs=num_train_epochs,
|
| 266 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 267 |
+
gradient_checkpointing=True,
|
| 268 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 269 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 270 |
+
optim='adamw_torch',
|
| 271 |
+
adam_beta1=0.9,
|
| 272 |
+
adam_beta2=0.95,
|
| 273 |
+
adam_epsilon=1e-7,
|
| 274 |
+
learning_rate=learning_rate,
|
| 275 |
+
weight_decay=wd,
|
| 276 |
+
max_grad_norm=1.0,
|
| 277 |
+
lr_scheduler_type='cosine',
|
| 278 |
+
warmup_steps=50,
|
| 279 |
+
logging_steps=10,
|
| 280 |
+
output_dir=output_dir,
|
| 281 |
+
save_total_limit=10,
|
| 282 |
+
save_only_model=True,
|
| 283 |
+
bf16=True,
|
| 284 |
+
fp16=False,
|
| 285 |
+
remove_unused_columns=False,
|
| 286 |
+
report_to='none',
|
| 287 |
+
deepspeed=None,
|
| 288 |
+
disable_tqdm=False,
|
| 289 |
+
dataloader_num_workers=16,
|
| 290 |
+
save_strategy='epoch',
|
| 291 |
+
# save_steps=2500,
|
| 292 |
+
ddp_find_unused_parameters=True,
|
| 293 |
+
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
out_path = Path(training_args.output_dir)
|
| 297 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 298 |
+
|
| 299 |
+
# create optimizer only for trainable params
|
| 300 |
+
optimizer = torch.optim.AdamW(
|
| 301 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 302 |
+
lr=learning_rate,
|
| 303 |
+
weight_decay=wd,
|
| 304 |
+
betas=(0.9, 0.95),
|
| 305 |
+
eps=1e-7,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Trainer Setting
|
| 309 |
+
trainer = Trainer(
|
| 310 |
+
model=model,
|
| 311 |
+
args=training_args,
|
| 312 |
+
data_collator=covost_collate_fn,
|
| 313 |
+
train_dataset=train_dataset,
|
| 314 |
+
optimizers=(optimizer, None)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
trainer.train()
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# # 1. Save LoRA Adapter
|
| 321 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 322 |
+
|
| 323 |
+
# # 1-1. Delete Markdown file
|
| 324 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 325 |
+
# if os.path.exists(markdown_file):
|
| 326 |
+
# os.remove(markdown_file)
|
| 327 |
+
|
| 328 |
+
# 2. Save entire model
|
| 329 |
+
model.save_pretrained(output_dir)
|
training_multiturn_textonly.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
from ASRDataset import *
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def count_parameters_by_module(model):
|
| 36 |
+
# dictionary for parameters number by modules
|
| 37 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 38 |
+
|
| 39 |
+
# all params
|
| 40 |
+
total_params = 0
|
| 41 |
+
total_trainable_params = 0
|
| 42 |
+
|
| 43 |
+
# Check Embedding Token masks
|
| 44 |
+
embedding_masks = {}
|
| 45 |
+
for name, param in model.named_parameters():
|
| 46 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 47 |
+
# check if params has embedding_grad_mask_hook
|
| 48 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 49 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 50 |
+
# Accessing mask variables in the closure of hook functions
|
| 51 |
+
for cell in hook_fn.__closure__ or []:
|
| 52 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 53 |
+
# check mask tensor
|
| 54 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 55 |
+
|
| 56 |
+
# Count params by modules
|
| 57 |
+
for name, param in model.named_parameters():
|
| 58 |
+
# extracts top module_name
|
| 59 |
+
module_name = name.split('.')[0]
|
| 60 |
+
param_count = param.numel()
|
| 61 |
+
|
| 62 |
+
module_params[module_name]["total"] += param_count
|
| 63 |
+
total_params += param_count
|
| 64 |
+
|
| 65 |
+
if param.requires_grad:
|
| 66 |
+
# Only count for real trainable params. (with masks)
|
| 67 |
+
if name in embedding_masks:
|
| 68 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 69 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 70 |
+
total_trainable_params += trainable_count
|
| 71 |
+
else:
|
| 72 |
+
module_params[module_name]["trainable"] += param_count
|
| 73 |
+
total_trainable_params += param_count
|
| 74 |
+
|
| 75 |
+
print(f"All Params: {total_params:,}")
|
| 76 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 77 |
+
print("\nParams by Module:")
|
| 78 |
+
|
| 79 |
+
for module_name, counts in sorted(module_params.items()):
|
| 80 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 81 |
+
total_percentage = counts["total"] / total_params * 100
|
| 82 |
+
|
| 83 |
+
print(f"- {module_name}:")
|
| 84 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 85 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 86 |
+
|
| 87 |
+
return module_params
|
| 88 |
+
|
| 89 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 90 |
+
model = AutoModel.from_pretrained(
|
| 91 |
+
model_name_or_path,
|
| 92 |
+
revision=revision,
|
| 93 |
+
torch_dtype=torch.bfloat16,
|
| 94 |
+
device_map="auto",
|
| 95 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Set use_cache to False after model loaded
|
| 100 |
+
model.config.use_cache = False
|
| 101 |
+
|
| 102 |
+
# Freeze all parameters
|
| 103 |
+
for param in model.parameters():
|
| 104 |
+
param.requires_grad = False
|
| 105 |
+
|
| 106 |
+
model.set_lora_adapter('speech')
|
| 107 |
+
# model.set_lora_adapter('text')
|
| 108 |
+
model.to(torch.bfloat16)
|
| 109 |
+
|
| 110 |
+
# (Optional) unfreeze audio_tower parameters
|
| 111 |
+
# for param in model.audio_tower.parameters():
|
| 112 |
+
# param.requires_grad = True
|
| 113 |
+
|
| 114 |
+
# Only unfreeze audio_projector parameters
|
| 115 |
+
# for param in model.audio_projector.parameters():
|
| 116 |
+
# param.requires_grad = True
|
| 117 |
+
|
| 118 |
+
# (Optional) unfreeze audio embed_tokens
|
| 119 |
+
train_embed = True
|
| 120 |
+
if train_embed:
|
| 121 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 122 |
+
|
| 123 |
+
embed_tokens.weight.requires_grad = False
|
| 124 |
+
|
| 125 |
+
# Added Speech token IDs (only this tokens be trainable)
|
| 126 |
+
trainable_token_ids = [256001, 256002]
|
| 127 |
+
|
| 128 |
+
embed_tokens.weight.requires_grad = True
|
| 129 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 130 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 131 |
+
|
| 132 |
+
# backward hook, with gradient masking
|
| 133 |
+
def embedding_grad_mask_hook(grad):
|
| 134 |
+
return grad.masked_fill(mask, 0)
|
| 135 |
+
|
| 136 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 137 |
+
|
| 138 |
+
model.language_model.model.model.embed_tokens = embed_tokens
|
| 139 |
+
|
| 140 |
+
count_parameters_by_module(model)
|
| 141 |
+
|
| 142 |
+
return model
|
| 143 |
+
|
| 144 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 145 |
+
_IGNORE_INDEX = -100
|
| 146 |
+
|
| 147 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 148 |
+
_IGNORE_INDEX = -100
|
| 149 |
+
|
| 150 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 151 |
+
use_flash_attention = False
|
| 152 |
+
|
| 153 |
+
output_dir = '../gemma_tmp14_audio_and_text_speechlora'
|
| 154 |
+
batch_size = 16
|
| 155 |
+
batch_size_per_gpu = 1
|
| 156 |
+
learning_rate = 5.0e-5 # 1.0e-4 for fine-tuning
|
| 157 |
+
wd = 0.01
|
| 158 |
+
num_train_epochs = 10
|
| 159 |
+
|
| 160 |
+
revision = "main" #"v1.0"
|
| 161 |
+
|
| 162 |
+
processor = AutoProcessor.from_pretrained(
|
| 163 |
+
model_name_or_path,
|
| 164 |
+
revision=revision,
|
| 165 |
+
trust_remote_code=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
model = create_model(
|
| 169 |
+
model_name_or_path,
|
| 170 |
+
revision=revision,
|
| 171 |
+
use_flash_attention=use_flash_attention,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
train_datasets = []
|
| 175 |
+
|
| 176 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,text_only=True,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 177 |
+
train_datasets.append(pickup_dataset)
|
| 178 |
+
|
| 179 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 180 |
+
train_datasets.append(pickup_dataset)
|
| 181 |
+
|
| 182 |
+
# custom_tw_loc = TWCostumData(processor=processor,
|
| 183 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 184 |
+
# train_datasets.append(custom_tw_loc) # 1500
|
| 185 |
+
|
| 186 |
+
# custom_tw_loc2 = TWCostumData(processor=processor,
|
| 187 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 188 |
+
# train_datasets.append(custom_tw_loc2) # 9458
|
| 189 |
+
|
| 190 |
+
# custom_yating_tw_road = TWCostumData(processor=processor,
|
| 191 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
|
| 192 |
+
# train_datasets.append(custom_yating_tw_road) # 35224
|
| 193 |
+
|
| 194 |
+
# custom_tw_road = TWCostumData(processor=processor,
|
| 195 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 196 |
+
# train_datasets.append(custom_tw_road) # 1500
|
| 197 |
+
|
| 198 |
+
# custom_tw_road2 = TWCostumData(processor=processor,
|
| 199 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 200 |
+
# train_datasets.append(custom_tw_road2) # 35224
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 205 |
+
print([len(dataset) for dataset in train_datasets])
|
| 206 |
+
|
| 207 |
+
# ConcatDataset
|
| 208 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 209 |
+
print("Count Length of Datas", len(train_dataset))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# Check GPUs
|
| 214 |
+
num_gpus = torch.cuda.device_count()
|
| 215 |
+
print(f'training on {num_gpus} GPUs')
|
| 216 |
+
|
| 217 |
+
assert (
|
| 218 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 219 |
+
), 'Batch size must be divisible by the number of GPUs'
|
| 220 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
|
| 221 |
+
|
| 222 |
+
# hard coded training args
|
| 223 |
+
dp_config = {
|
| 224 |
+
"fp16": {
|
| 225 |
+
"enabled": "auto",
|
| 226 |
+
"loss_scale": 0,
|
| 227 |
+
"loss_scale_window": 1000,
|
| 228 |
+
"initial_scale_power": 16,
|
| 229 |
+
"hysteresis": 2,
|
| 230 |
+
"min_loss_scale": 1
|
| 231 |
+
},
|
| 232 |
+
"zero_optimization": {
|
| 233 |
+
"stage": 2,
|
| 234 |
+
"allgather_partitions": True,
|
| 235 |
+
"allgather_bucket_size": 5e8,
|
| 236 |
+
"overlap_comm": False,
|
| 237 |
+
"reduce_scatter": True,
|
| 238 |
+
"reduce_bucket_size": 5e8,
|
| 239 |
+
"contiguous_gradients": True,
|
| 240 |
+
"cpu_offload": True
|
| 241 |
+
},
|
| 242 |
+
|
| 243 |
+
"train_batch_size": "auto",
|
| 244 |
+
"gradient_accumulation_steps": "auto",
|
| 245 |
+
"optimizer": {
|
| 246 |
+
"type": "AdamW",
|
| 247 |
+
"params": {
|
| 248 |
+
"lr": "auto",
|
| 249 |
+
"betas": 'auto',
|
| 250 |
+
"eps": 'auto',
|
| 251 |
+
"weight_decay": "auto"
|
| 252 |
+
}
|
| 253 |
+
},
|
| 254 |
+
"scheduler": {
|
| 255 |
+
"type": "WarmupDecayLR",
|
| 256 |
+
"params": {
|
| 257 |
+
"warmup_min_lr": "auto",
|
| 258 |
+
"warmup_max_lr": "auto",
|
| 259 |
+
"warmup_num_steps": "auto",
|
| 260 |
+
"total_num_steps": "auto"
|
| 261 |
+
}
|
| 262 |
+
},
|
| 263 |
+
"gradient_clipping": 1.0,
|
| 264 |
+
"zero_optimization": {
|
| 265 |
+
"stage": 0
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
training_args = TrainingArguments(
|
| 269 |
+
num_train_epochs=num_train_epochs,
|
| 270 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 271 |
+
gradient_checkpointing=True,
|
| 272 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 273 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 274 |
+
optim='adamw_torch',
|
| 275 |
+
adam_beta1=0.9,
|
| 276 |
+
adam_beta2=0.95,
|
| 277 |
+
adam_epsilon=1e-7,
|
| 278 |
+
learning_rate=learning_rate,
|
| 279 |
+
weight_decay=wd,
|
| 280 |
+
max_grad_norm=1.0,
|
| 281 |
+
lr_scheduler_type='cosine',
|
| 282 |
+
warmup_steps=50,
|
| 283 |
+
logging_steps=10,
|
| 284 |
+
output_dir=output_dir,
|
| 285 |
+
save_total_limit=10,
|
| 286 |
+
save_only_model=True,
|
| 287 |
+
bf16=True,
|
| 288 |
+
fp16=False,
|
| 289 |
+
remove_unused_columns=False,
|
| 290 |
+
report_to='none',
|
| 291 |
+
deepspeed=None,
|
| 292 |
+
disable_tqdm=False,
|
| 293 |
+
dataloader_num_workers=16,
|
| 294 |
+
save_strategy='epoch',
|
| 295 |
+
# save_steps=2500,
|
| 296 |
+
ddp_find_unused_parameters=True,
|
| 297 |
+
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
out_path = Path(training_args.output_dir)
|
| 301 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 302 |
+
|
| 303 |
+
# create optimizer only for trainable params
|
| 304 |
+
optimizer = torch.optim.AdamW(
|
| 305 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 306 |
+
lr=learning_rate,
|
| 307 |
+
weight_decay=wd,
|
| 308 |
+
betas=(0.9, 0.95),
|
| 309 |
+
eps=1e-7,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Trainer Setting
|
| 313 |
+
trainer = Trainer(
|
| 314 |
+
model=model,
|
| 315 |
+
args=training_args,
|
| 316 |
+
data_collator=covost_collate_fn,
|
| 317 |
+
train_dataset=train_dataset,
|
| 318 |
+
optimizers=(optimizer, None)
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
trainer.train()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# # 1. Save LoRA Adapter
|
| 325 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 326 |
+
|
| 327 |
+
# # 1-1. Delete Markdown file
|
| 328 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 329 |
+
# if os.path.exists(markdown_file):
|
| 330 |
+
# os.remove(markdown_file)
|
| 331 |
+
|
| 332 |
+
# 2. Save entire model
|
| 333 |
+
model.save_pretrained(output_dir)
|