jva96160 committed
Commit a16e4aa · verified · 1 Parent(s): 1d9d001

Upload 25 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
ASRDataset.py ADDED
@@ -0,0 +1,793 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ import sacrebleu
13
+
14
+ from datasets import load_dataset
15
+ from torch.utils.data import Dataset, ConcatDataset
16
+ from tqdm import tqdm
17
+ from transformers import (
18
+ BatchFeature,
19
+ )
20
+ import pandas as pd
21
+ import soundfile as sf
22
+ from datasets import Audio
23
+ import random
24
+ from copy import deepcopy
25
+ import torchaudio
26
+
27
+ ANSWER_SUFFIX = "<end_of_turn>"
28
+ _IGNORE_INDEX = -100
29
+ class BaseAudioDataset(Dataset):
30
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
31
+ self.processor = processor
32
+ self.training = "train" in split or 'other' in split
33
+ self.debug = debug
34
+ self.sampling_rate = sampling_rate
35
+ self.name = ""
36
+
37
+ def set_dataset_name(self, name):
38
+ self.name = name
39
+
40
+ @staticmethod
41
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
42
+ original_size = len(data)
43
+
44
+ data = data.cast_column(audio_field, Audio(decode=False))
45
+
46
+ def identify_corrupted_files(example):
47
+ try:
48
+ sf.read(example[audio_field]["path"])
49
+
50
+ for field in text_fields:
51
+ if field in example and example[field].replace('"', '') == "":
52
+ return False
53
+ return True
54
+ except Exception:
55
+ return False
56
+
57
+ data = data.filter(identify_corrupted_files, num_proc=16)
58
+ validated_size = len(data)
59
+
60
+ # Audio Decoding
61
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
62
+
63
+ if debug:
64
+ print(f"Dataset: {dataset_name}")
65
+ print(f"Original data nums: {original_size}")
66
+ print(f"After filtering data nums: {validated_size}")
67
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
68
+
69
+ return data
70
+
71
+ @staticmethod
72
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
73
+ original_size = len(data)
74
+
75
+ def filter_audio_by_length(example):
76
+ try:
77
+ audio = example[audio_field]['array']
78
+ channel = 1
79
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
80
+ channel = audio.ndim
81
+ audio = audio.squeeze()
82
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
83
+ return min_sec <= audio_length <= max_sec
84
+ except Exception as e:
85
+ if debug:
86
+ print(f"Error : {str(e)[:100]}... - sample excluded")
87
+ return False
88
+
89
+ data = data.filter(filter_audio_by_length, num_proc=16)
90
+ filtered_size = len(data)
91
+
92
+ if debug:
93
+ print(f"Before Length Filtering data nums: {original_size}")
94
+ print(f"After Length Filtering data nums: {filtered_size}")
95
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
96
+
97
+ return data
98
+
99
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
100
+ user_message = {
101
+ 'role': 'user',
102
+ 'content': '<start_of_audio>' + instruction,
103
+ }
104
+ prompt = self.processor.tokenizer.apply_chat_template(
105
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
106
+ )
107
+
108
+ inputs = self.processor(
109
+ text=prompt,
110
+ audio=[audio_array],
111
+ add_special_tokens=False,
112
+ return_tensors='pt'
113
+ )
114
+
115
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
116
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
117
+
118
+ if self.debug:
119
+ self.debug = False
120
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
121
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
122
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
123
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
124
+
125
+ if self.training:
126
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
127
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
128
+ labels[:, -answer_ids.shape[1]:] = answer_ids
129
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
130
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
131
+ else:
132
+ input_ids = inputs.input_ids
133
+ labels = answer_ids
134
+ token_type_ids = inputs.token_type_ids
135
+
136
+ return {
137
+ 'input_ids': input_ids,
138
+ 'labels': labels,
139
+ 'token_type_ids': token_type_ids,
140
+ 'input_audio_embeds': inputs.input_audio_embeds,
141
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
142
+ 'input_modes': inputs.input_modes,
143
+ }
144
+
145
+ # Libri Speech Dataset Class
146
+ class LibriSpeechDataset(BaseAudioDataset):
147
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
148
+ super().__init__(processor, split, sampling_rate, debug)
149
+
150
+ self.set_dataset_name(f"LibriSpeech_{subset}")
151
+ # only ASR
152
+ self.ast = False
153
+ self.lang = "en"
154
+
155
+ # load dataset
156
+ self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
157
+ subset,
158
+ split=split,
159
+ trust_remote_code=True,
160
+ cache_dir=Path("/mnt/jeff/InCar/data")
161
+ )
162
+
163
+ # (Optional) Audio length Filtering
164
+ self.data = self.filter_by_audio_length(self.data, "audio")
165
+
166
+ # Instruction Setting
167
+ self.instruction = random.choice(INSTRUCTION["asr"])
168
+
169
+ def __len__(self):
170
+ return len(self.data)
171
+
172
+ def __getitem__(self, idx):
173
+ data = self.data[idx]
174
+
175
+ # Libri Speech is only for ASR
176
+ answer_text = data["text"].replace('"', '')
177
+
178
+ return self.prepare_model_inputs(
179
+ data["audio"]["array"],
180
+ self.instruction,
181
+ answer_text
182
+ )
183
+
184
+ # common_voice_16_1 dataset
185
+ class CommonVoiceDataset(BaseAudioDataset):
186
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
187
+ super().__init__(processor, split, sampling_rate, debug)
188
+
189
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
190
+ # only ASR
191
+ self.ast = False
192
+ self.lang=source_lang
193
+
194
+ # load dataset
195
+ if source_lang=="zh-TW":
196
+ data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
197
+ else:
198
+ data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
199
+ self.data = load_dataset(data_path,
200
+ source_lang,
201
+ split=split,
202
+ trust_remote_code=True,
203
+ cache_dir=Path("/mnt/jeff/InCar/data")
204
+ )
205
+ def prepare_dataset(batch):
206
+ """Function to preprocess the dataset with the .map method"""
207
+ transcription = batch["sentence"]
208
+
209
+ if transcription.startswith('"') and transcription.endswith('"'):
210
+ # remove surrounding quotation marks as they do not affect the transcription
211
+ transcription = transcription[1:-1]
212
+
213
+ if transcription[-1] not in [".", "?", "!"]:
214
+ # append a full-stop to sentences that do not end in punctuation
215
+ transcription = transcription + "."
216
+
217
+ batch["sentence"] = transcription
218
+
219
+ return batch
220
+
221
+
222
+ import opencc
223
+ converter = opencc.OpenCC('s2tw.json')
224
+ def To_zhTW(batch):
225
+
226
+ transcription = converter.convert(batch["sentence"])
227
+ batch["sentence"] = transcription
228
+
229
+ return batch
230
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
231
+ if source_lang=='zh-CN':
232
+ self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
233
+
234
+
235
+ # (Optional) Audio length Filtering
236
+ self.data = self.filter_by_audio_length(self.data, "audio")
237
+
238
+ if source_lang == "zh-TW" and split=='train':
239
+ import torchaudio
240
+ from torchaudio import transforms
241
+ import copy
242
+ import pickle
243
+ import os
244
+ def subsample(batch):
245
+ batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
246
+ batch['audio']['sampling_rate']=16000
247
+ return batch
248
+ def TW_data_augment_fast(batch):
249
+ speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
250
+ new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
251
+ batch['audio']['array'] = new_array_fast
252
+ return batch
253
+ def TW_data_augment_slow(batch):
254
+ speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
255
+ new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
256
+ batch['audio']['array'] = new_array_slow
257
+ return batch
258
+ # data = self.data.map(subsample, num_proc=1, desc="subsample")
259
+ fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
260
+ if not os.path.exists(fast_path):
261
+ data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
262
+ with open(fast_path,'wb') as f:
263
+ pickle.dump(data_fast,f)
264
+ else:
265
+ with open(fast_path,'rb') as f:
266
+ data_fast=pickle.load(f)
267
+
268
+ slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
269
+ if not os.path.exists(slow_path):
270
+ data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
271
+ with open(slow_path,'wb') as f:
272
+ pickle.dump(data_slow,f)
273
+ else:
274
+ with open(slow_path,'rb') as f:
275
+ data_slow=pickle.load(f)
276
+ self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
277
+
278
+ # Instruction Setting
279
+ self.instruction = random.choice(INSTRUCTION["asr"])
280
+
281
+ def __len__(self):
282
+ return len(self.data)
283
+
284
+ def __getitem__(self, idx):
285
+ data = self.data[idx]
286
+
287
+ answer_text = data["sentence"]
288
+ return self.prepare_model_inputs(
289
+ data["audio"]["array"],
290
+ self.instruction,
291
+ answer_text
292
+ )
293
+
294
+
295
+ # Fleurs Dataset Class
296
+ class FleursDataset(BaseAudioDataset):
297
+ def __init__(self, processor, split, source_lang, target_lang=None,
298
+ mode="asr", sampling_rate=16000, debug=False):
299
+ super().__init__(processor, split, sampling_rate, debug)
300
+
301
+ self.set_dataset_name("Fleurs")
302
+ # Mode Setting (ASR or AST)
303
+ if mode not in ["asr", "ast"]:
304
+ raise ValueError("mode must be 'asr' or 'ast'.")
305
+
306
+ self.mode = mode
307
+ self.ast = (mode == "ast")
308
+ self.source_lang = source_lang
309
+
310
+ # Language name mapping (expand if needed)
311
+ self.lang_names = {
312
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
313
+ }
314
+
315
+ # load dataset - source language dataset
316
+ self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
317
+ source_lang,
318
+ split=split,
319
+ trust_remote_code=True,
320
+ cache_dir=Path("/mnt/jeff/InCar/data")
321
+ )
322
+ import opencc
323
+ converter = opencc.OpenCC('s2tw.json')
324
+ def prepare_dataset(batch):
325
+ transcription = converter.convert(batch["transcription"])
326
+ batch["transcription"] = transcription
327
+
328
+ return batch
329
+ if (source_lang=="cmn_hans_cn"):
330
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
331
+
332
+ # (Optional) Audio length Filtering
333
+ self.data = self.filter_by_audio_length(self.data, "audio")
334
+ self.target_lang_name = ""
335
+ # When AST mode, load target language dataset.
336
+ if self.ast:
337
+ if target_lang is None:
338
+ raise ValueError("AST mode requires target_lang.")
339
+
340
+ self.target_lang = target_lang
341
+ self.lang = f"{source_lang}_{target_lang}"
342
+
343
+ # load dataset - target language dataset (for translation)
344
+ target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
345
+ target_lang,
346
+ split=split,
347
+ trust_remote_code=True,
348
+ cache_dir=Path("/mnt/jeff/InCar/data")
349
+ )
350
+ if target_lang=="cmn_hans_cn":
351
+ target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
352
+ source_dict = {item['id']: item for item in self.data}
353
+ target_dict = {item['id']: item for item in target_data}
354
+
355
+ # only Common ID, add translation fields
356
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
357
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
358
+ self.data = [
359
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
360
+ for id in common_ids
361
+ ]
362
+
363
+ # Instruction Setting - use target language name
364
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
365
+ self.instruction = random.choice(INSTRUCTION["ast"])
366
+ else:
367
+ # ASR mode
368
+ self.lang = source_lang
369
+ self.instruction = random.choice(INSTRUCTION["asr"])
370
+
371
+ if self.debug:
372
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
373
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
374
+ if self.ast:
375
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
376
+ print(f"dataset size: {len(self.data)}")
377
+
378
+ def __len__(self):
379
+ return len(self.data)
380
+
381
+ def __getitem__(self, idx):
382
+ data = self.data[idx]
383
+ audio_array = data["audio"]["array"]
384
+
385
+ if self.ast:
386
+ answer_text = data["translation"]
387
+ else:
388
+ answer_text = data["transcription"]
389
+
390
+ return self.prepare_model_inputs(
391
+ audio_array,
392
+ self.instruction.format(self.target_lang_name),
393
+ answer_text
394
+ )
395
+
396
+ class TWCostumData(BaseAudioDataset):
397
+
398
+ def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
399
+ super().__init__(processor, split, sampling_rate, debug)
400
+ import pandas as pd
401
+ from datasets import Dataset, Audio
402
+
403
+
404
+ df = pd.read_csv(csv_path).fillna('')
405
+
406
+
407
+ self.set_dataset_name(f"TWCostumData")
408
+ self.data = Dataset.from_dict(
409
+ {
410
+ "audio": [audio for audio in df['audio']],
411
+ "sentence": [text for text in df['text']]
412
+ }
413
+ ).cast_column("audio", Audio(sampling_rate=16000))
414
+
415
+ # Instruction Setting
416
+ self.instruction = random.choice(INSTRUCTION["asr"])
417
+
418
+ def __len__(self):
419
+ return len(self.data)
420
+
421
+ def __getitem__(self, idx):
422
+ data = self.data[idx]
423
+
424
+ answer_text = data["sentence"]
425
+ return self.prepare_model_inputs(
426
+ data["audio"]["array"],
427
+ self.instruction,
428
+ answer_text
429
+ )
430
+ def covost_collate_fn(batch):
431
+ input_ids_list = []
432
+ labels_list = []
433
+ token_type_ids_list = []
434
+ input_audio_embeds_list = []
435
+ audio_embed_sizes_list = []
436
+ audio_attention_mask_list = []
437
+ input_modes_list = []
438
+ audio_paths = []
439
+ for inputs in batch:
440
+ if 'audio_path' in inputs:
441
+ audio_paths.append(inputs['audio_path'])
442
+ input_ids_list.append(inputs['input_ids'][0])
443
+ labels_list.append(inputs['labels'][0])
444
+ token_type_ids_list.append(inputs['token_type_ids'][0])
445
+ if inputs['input_modes']==2:
446
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
447
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
448
+ audio_attention_mask_list.append(
449
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
450
+ )
451
+ # else:
452
+ # input_audio_embeds_list.append(None)
453
+ # audio_embed_sizes_list.append(None)
454
+ # audio_attention_mask_list.append(None)
455
+ input_modes_list.append(inputs['input_modes'])
456
+ # try:
457
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
458
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
459
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
460
+ audio_attention_mask = (
461
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
462
+ if len(audio_attention_mask_list) > 1
463
+ else None
464
+ )
465
+ # except Exception as e:
466
+ # print(e)
467
+ # print(input_ids_list)
468
+ # print(labels_list)
469
+ # raise
470
+ attention_mask = (input_ids != 0).long()
471
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
472
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
473
+ input_modes = torch.cat(input_modes_list)
474
+ if len(audio_paths)>0:
475
+ return BatchFeature(
476
+ {
477
+ "audio_path": audio_paths,
478
+ 'input_ids': input_ids,
479
+ 'labels': labels,
480
+ 'token_type_ids': token_type_ids,
481
+ 'attention_mask': attention_mask,
482
+ 'input_audio_embeds': input_audio_embeds,
483
+ 'audio_embed_sizes': audio_embed_sizes,
484
+ 'audio_attention_mask': audio_attention_mask,
485
+ 'input_modes': input_modes,
486
+ }
487
+ )
488
+ else:
489
+ return BatchFeature(
490
+ {
491
+ 'input_ids': input_ids,
492
+ 'labels': labels,
493
+ 'token_type_ids': token_type_ids,
494
+ 'attention_mask': attention_mask,
495
+ 'input_audio_embeds': input_audio_embeds,
496
+ 'audio_embed_sizes': audio_embed_sizes,
497
+ 'audio_attention_mask': audio_attention_mask,
498
+ 'input_modes': input_modes,
499
+ }
500
+ )
501
+
502
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
503
+ """
504
+ Pad a list of sequences to the same length.
505
+ sequences: list of tensors in [seq_len, *] shape
506
+ """
507
+ assert padding_side in ['right', 'left']
508
+ max_size = sequences[0].size()
509
+ trailing_dims = max_size[1:]
510
+ max_len = max(len(seq) for seq in sequences)
511
+ batch_size = len(sequences)
512
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
513
+ for i, seq in enumerate(sequences):
514
+ length = seq.size(0)
515
+ if padding_side == 'right':
516
+ output.data[i, :length] = seq
517
+ else:
518
+ output.data[i, -length:] = seq
519
+ return output
520
+
521
+ def cat_with_pad(tensors, dim, padding_value=0):
522
+ """
523
+ cat along dim, while pad to max for all other dims
524
+ """
525
+ ndim = tensors[0].dim()
526
+ assert all(
527
+ t.dim() == ndim for t in tensors[1:]
528
+ ), 'All tensors must have the same number of dimensions'
529
+
530
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
531
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
532
+ output = tensors[0].new_full(out_size, padding_value)
533
+
534
+ index = 0
535
+ for t in tensors:
536
+ # Create a slice list where every dimension except dim is full slice
537
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
538
+ # Update only the concat dimension slice
539
+ slices[dim] = slice(index, index + t.shape[dim])
540
+
541
+ output[slices] = t
542
+ index += t.shape[dim]
543
+
544
+ return output
545
+
546
+
547
+
548
+ class MultiturnAudioDataset(BaseAudioDataset):
549
+ def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
550
+ super().__init__(processor, split, sampling_rate, debug)
551
+ from llamafactory.data.template import Llama2Template,parse_template
552
+ from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
553
+ from llamafactory.data.mm_plugin import get_mm_plugin
554
+ import json
555
+ self.train=False
556
+ self.text_only=text_only
557
+ with open(json_path) as f:
558
+ js_data = json.load(f)
559
+ if split=='train':
560
+ self.train=True
561
+ js_data = js_data[:int(len(js_data)*0.8)]
562
+ else:
563
+ js_data = js_data[-int(len(js_data)*0.2):]
564
+ for conv in js_data:
565
+ for mess in conv['conversations']:
566
+ if 'audio_path' in mess:
567
+ mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
568
+ default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
569
+ self.template=Llama2Template(
570
+ format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
571
+ format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
572
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
573
+ format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
574
+ format_tools = ToolFormatter(tool_format="default"),
575
+ format_observation=StringFormatter(
576
+ slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
577
+ ),
578
+ default_system=default_system,
579
+ thought_words=("<think>", "</think>"),
580
+ efficient_eos=False,
581
+ replace_eos=False,
582
+ replace_jinja_template=False,
583
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
584
+ stop_words=["<end_of_turn>"],
585
+ mm_plugin=get_mm_plugin(name="base"),
586
+ )
587
+
588
+ self.set_dataset_name(f"MultiturnCostumData")
589
+
590
+
591
+ self.data = []
592
+ self.text_only_data = []
593
+ for conv in js_data:
594
+ tools = conv['tools'] if 'tools' in conv else ""
595
+ system = conv['system'] if 'system' in conv else default_system
596
+ tmp = {
597
+ 'tools':tools,
598
+ 'system':system,
599
+ 'messages':[],
600
+ }
601
+ for i,mess in enumerate(conv['conversations']):
602
+ tmp['messages'].append(mess)
603
+ if mess['from']=='human':
604
+ tmp['messages'].append(conv['conversations'][i+1])
605
+ d = deepcopy(tmp)
606
+ d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
607
+ self.data.append(d)
608
+ if self.text_only:
609
+ self.text_only_data.append(deepcopy(tmp))
610
+ tmp['messages'].pop()
611
+ elif mess['from']=='observation':
612
+ tmp['messages'].append(conv['conversations'][i+1])
613
+ d = deepcopy(tmp)
614
+ self.text_only_data.append(d)
615
+ tmp['messages'].pop()
616
+ if text_only:
617
+ self.data=self.text_only_data
618
+
619
+
620
+ def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
621
+ ANSWER_SUFFIX = "<end_of_turn>"
622
+ prompt = ""
623
+ answer_text = ""
624
+ user_transcribe = ""
625
+ audio_paths = []
626
+ for i, message in enumerate(messages):
627
+ elements = []
628
+
629
+ system_text = ""
630
+ if i == 0:
631
+ elements += self.template.format_prefix.apply()
632
+ if system or tools:
633
+ tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
634
+ system_text = self.template.format_system.apply(content=(system + tool_text))[0]
635
+
636
+ if message["from"] == "human":
637
+ if i==len(messages)-2 and not self.text_only:
638
+ user_transcribe = message["value"]
639
+ elements += self.template.format_user.apply(content=system_text+'<start_of_audio>')
640
+ else:
641
+ elements += self.template.format_user.apply(content=system_text + message["value"])
642
+ audio_paths.append(message['audio_path'])
643
+ elif message["from"] == "gpt":
644
+ elements += self.template.format_assistant.apply(content=message["value"])
645
+ elif message["from"] == "observation":
646
+ elements += self.template.format_observation.apply(content=message["value"])
647
+ elif message["from"] == "function_call":
648
+ elements += self.template.format_function.apply(content=message["value"])
649
+ else:
650
+ raise NotImplementedError("Unexpected role: {}".format(message["from"]))
651
+
652
+
653
+ for elem in elements:
654
+ ele_str = ""
655
+ if isinstance(elem, str):
656
+ ele_str=elem
657
+ elif isinstance(elem, set):
658
+ if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
659
+ ele_str = self.processor.tokenizer.bos_token
660
+ elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
661
+ ele_str = self.processor.tokenizer.eos_token
662
+ if i == len(messages)-1:
663
+ answer_text+=ele_str
664
+ else:
665
+ prompt+=ele_str
666
+
667
+
668
+ if type(audio_array)!=type(None):
669
+ inputs = self.processor(
670
+ text=prompt,
671
+ audio=[audio_array],
672
+ add_special_tokens=False,
673
+ return_tensors='pt'
674
+ )
675
+ answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
676
+ else:
677
+ inputs = self.processor(
678
+ text=prompt,
679
+ audio=None,
680
+ add_special_tokens=False,
681
+ return_tensors='pt'
682
+ )
683
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
684
+ # print('user_transcribe',user_transcribe)
685
+ # print('answer_text', answer)
686
+ # print('prompt',prompt)
687
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
688
+
689
+ if self.debug:
690
+ self.debug = False
691
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
692
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
693
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
694
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
695
+
696
+ if self.training:
697
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
698
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
699
+ labels[:, -answer_ids.shape[1]:] = answer_ids
700
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
701
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
702
+ else:
703
+ input_ids = inputs.input_ids
704
+ labels = answer_ids
705
+ token_type_ids = inputs.token_type_ids
706
+ if type(audio_array)!=type(None):
707
+ if not self.train:
708
+ return {
709
+ "audio_path": audio_paths,
710
+ 'input_ids': input_ids,
711
+ 'labels': labels,
712
+ 'token_type_ids': token_type_ids,
713
+ 'input_audio_embeds': inputs.input_audio_embeds,
714
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
715
+ 'input_modes': inputs.input_modes,
716
+ }
717
+ else:
718
+ return {
719
+ 'input_ids': input_ids,
720
+ 'labels': labels,
721
+ 'token_type_ids': token_type_ids,
722
+ 'input_audio_embeds': inputs.input_audio_embeds,
723
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
724
+ 'input_modes': inputs.input_modes,
725
+ }
726
+ else:
727
+ return {
728
+ 'input_ids': input_ids,
729
+ 'labels': labels,
730
+ 'token_type_ids': token_type_ids,
731
+ 'input_audio_embeds': None,
732
+ 'audio_embed_sizes': None,
733
+ 'input_modes': inputs.input_modes,
734
+ }
735
+ def __len__(self):
736
+ return len(self.data)
737
+
738
+ def __getitem__(self, idx):
739
+ data = self.data[idx]
740
+ return self.prepare_multiturn_model_inputs(
741
+ audio_array=data["audio_array"] if "audio_array" in data else None,
742
+ messages=data['messages'],
743
+ system=data["system"],
744
+ tools=data["tools"]
745
+ )
746
+
747
+
748
+
749
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
750
+
751
+ INSTRUCTION = {
752
+ "ast": [
753
+ "Translate the audio to {0}.",
754
+ "Translate the audio clip into {0}.",
755
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
756
+ "Translate the provided audio file into {0}.",
757
+ "Convert the audio speech to {0} text.",
758
+ "Write an {0} translation of the audio file.",
759
+ "Translate spoken words from the audio into {0}.",
760
+ "Create an {0} version of the audio content.",
761
+ "Produce an accurate {0} translation of the audio.",
762
+ "Extract speech from the audio and translate it to {0}.",
763
+ "Turn the audio into readable {0} text.",
764
+ "Write all spoken content from the audio in {0}.",
765
+ "Generate an {0} translation of the speech in the file.",
766
+ "Convert the recording into {0} text.",
767
+ "Accurately translate the audio recording to {0}.",
768
+ "Write down dialogue from the given audio in {0}.",
769
+ "Translate all speech in this audio file to {0}.",
770
+ "Create an accurate {0} version of the speech.",
771
+ "Perform a complete {0} translation of the audio."
772
+ ],
773
+ "asr": [
774
+ "Transcribe the audio clip into text.",
775
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
776
+ "Transcribe the provided audio file into text.",
777
+ "Convert the audio speech to text.",
778
+ "Write a transcript of the audio file.",
779
+ "Transcribe spoken words from the audio.",
780
+ "Create a text version of the audio content.",
781
+ "Produce a verbatim transcript of the audio.",
782
+ "Extract and transcribe speech from the audio.",
783
+ "Turn the audio into readable text.",
784
+ "Write all spoken words from the audio.",
785
+ "Generate a transcript of the speech in the file.",
786
+ "Convert the recording into a text transcript.",
787
+ "Accurately transcribe the audio recording.",
788
+ "Write down dialogue from the given audio.",
789
+ "Transcribe all speech in this audio file.",
790
+ "Create an accurate text version of the speech.",
791
+ "Perform a complete transcription of the audio."
792
+ ],
793
+ }
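A minimal usage sketch of the dataset classes and covost_collate_fn defined above, assuming the matching processor is available from a local checkpoint with trust_remote_code; the checkpoint path and the subset/split names below are placeholders, not values taken from this repo:

from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Placeholder checkpoint path; any checkpoint that ships this repo's processor should work.
processor = AutoProcessor.from_pretrained("/path/to/checkpoint", trust_remote_code=True)

# Placeholder subset/split names; they are passed straight through to load_dataset.
train_set = LibriSpeechDataset(processor, subset="clean", split="train.100")
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=covost_collate_fn)

batch = next(iter(loader))  # BatchFeature with input_ids, labels, token_type_ids, audio embeds, masks
print(batch["input_ids"].shape, batch["input_modes"])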
added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
chat_template.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'audio' -%}\n {{ '<start_of_audio>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
3
+ }
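When prepare_model_inputs in ASRDataset.py applies this template to its single user turn ('<start_of_audio>' plus an ASR instruction) with add_generation_prompt=True, the rendered prompt should look roughly like the sketch below; the string is derived by hand from the template, not a captured output:

# Hand-derived sketch of the rendered prompt (not a captured output).
prompt = (
    "<bos><start_of_turn>user\n"
    "<start_of_audio>Transcribe the audio clip into text.<end_of_turn>\n"
    "<start_of_turn>model\n"
)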
config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "Gemma3OmniForConditionalGeneration"
4
+ ],
5
+ "audio_config": {
6
+ "activation": "swish",
7
+ "activation_checkpointing": {
8
+ "interval": 1,
9
+ "module": "transformer",
10
+ "offload": false
11
+ },
12
+ "attention_dim": 1024,
13
+ "attention_heads": 16,
14
+ "batch_norm": false,
15
+ "bias_in_glu": true,
16
+ "causal": true,
17
+ "chunk_size": -1,
18
+ "cnn_layer_norm": true,
19
+ "conv_activation": "swish",
20
+ "conv_glu_type": "swish",
21
+ "depthwise_multiplier": 1,
22
+ "depthwise_seperable_out_channel": 1024,
23
+ "dropout_rate": 0.0,
24
+ "encoder_embedding_config": {
25
+ "input_size": 80
26
+ },
27
+ "ext_pw_kernel_size": 1,
28
+ "ext_pw_out_channel": 1024,
29
+ "input_layer": "nemo_conv",
30
+ "input_size": 80,
31
+ "kernel_size": 3,
32
+ "left_chunk": 18,
33
+ "linear_units": 1536,
34
+ "nemo_conv_settings": {
35
+ "conv_channels": 1024
36
+ },
37
+ "num_blocks": 24,
38
+ "relative_attention_bias_args": {
39
+ "t5_bias_max_distance": 500,
40
+ "type": "t5"
41
+ },
42
+ "time_reduction": 8
43
+ },
44
+ "audio_token_index": 262143,
45
+ "auto_map": {
46
+ "AutoConfig": "configuration_gemma3omni.Gemma3OmniConfig",
47
+ "AutoModel": "modeling_gemma3omni.Gemma3OmniForConditionalGeneration"
48
+ },
49
+ "boa_token_index": 256001,
50
+ "boi_token_index": 255999,
51
+ "eoa_token_index": 256002,
52
+ "eoi_token_index": 256000,
53
+ "eos_token_id": [
54
+ 1,
55
+ 106
56
+ ],
57
+ "image_token_index": 262144,
58
+ "initializer_range": 0.02,
59
+ "mm_tokens_per_image": 256,
60
+ "model_type": "gemma3omni",
61
+ "speech_lora": {
62
+ "dp": 0.01,
63
+ "layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
64
+ "lora_alpha": 320,
65
+ "r": 320,
66
+ "use_rslora": true
67
+ },
68
+ "text_lora": {
69
+ "dp": 0.01,
70
+ "layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
71
+ "lora_alpha": 16,
72
+ "r": 8,
73
+ "use_rslora": true
74
+ },
75
+ "text_config": {
76
+ "attention_bias": false,
77
+ "attention_dropout": 0.0,
78
+ "attn_logit_softcapping": null,
79
+ "cache_implementation": "hybrid",
80
+ "final_logit_softcapping": null,
81
+ "head_dim": 256,
82
+ "hidden_activation": "gelu_pytorch_tanh",
83
+ "hidden_size": 2560,
84
+ "initializer_range": 0.02,
85
+ "intermediate_size": 10240,
86
+ "max_position_embeddings": 131072,
87
+ "model_type": "gemma3_text",
88
+ "num_attention_heads": 8,
89
+ "num_hidden_layers": 34,
90
+ "num_key_value_heads": 4,
91
+ "query_pre_attn_scalar": 256,
92
+ "rms_norm_eps": 1e-06,
93
+ "rope_local_base_freq": 10000.0,
94
+ "rope_scaling": {
95
+ "factor": 8.0,
96
+ "rope_type": "linear"
97
+ },
98
+ "rope_theta": 1000000.0,
99
+ "sliding_window": 1024,
100
+ "sliding_window_pattern": 6,
101
+ "torch_dtype": "float",
102
+ "use_cache": true,
103
+ "vocab_size": 262208
104
+ },
105
+ "torch_dtype": "float",
106
+ "transformers_version": "4.51.3",
107
+ "use_cache": false,
108
+ "vision_config": {
109
+ "hidden_size": 1152,
110
+ "image_size": 896,
111
+ "intermediate_size": 4304,
112
+ "model_type": "siglip_vision_model",
113
+ "num_attention_heads": 16,
114
+ "num_hidden_layers": 27,
115
+ "patch_size": 14,
116
+ "vision_use_head": false
117
+ }
118
+ }
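The auto_map block above is what lets the stock Auto classes resolve the custom configuration and model code shipped alongside this config; a minimal sketch with a placeholder local path:

from transformers import AutoConfig, AutoModel

# Placeholder path to a local copy of this repo (a Hub repo id works the same way).
cfg = AutoConfig.from_pretrained("/path/to/this/repo", trust_remote_code=True)
print(cfg.model_type, cfg.audio_token_index)  # "gemma3omni", 262143

model = AutoModel.from_pretrained("/path/to/this/repo", trust_remote_code=True).eval()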
configuration_gemma3omni.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Optional
2
+
3
+ from transformers import AutoConfig, Gemma3TextConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.modeling_rope_utils import rope_config_validation
6
+ from transformers.utils import logging
7
+ from transformers.models.siglip import SiglipVisionConfig
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ class AudioConfig(PretrainedConfig):
13
+ model_type = "gemma3_audio"
14
+
15
+ def __init__(
16
+ self,
17
+ input_size=80,
18
+ attention_dim=1024,
19
+ attention_heads=16,
20
+ num_blocks=24,
21
+ linear_units=1536,
22
+ dropout_rate=0.0,
23
+ kernel_size=3,
24
+ ext_pw_kernel_size=1,
25
+ ext_pw_out_channel=1024,
26
+ depthwise_seperable_out_channel=1024,
27
+ depthwise_multiplier=1,
28
+ activation="swish",
29
+ conv_activation="swish",
30
+ conv_glu_type="swish",
31
+ bias_in_glu=True,
32
+ causal=True,
33
+ batch_norm=False,
34
+ cnn_layer_norm=True,
35
+ time_reduction=8,
36
+ input_layer="nemo_conv",
37
+ nemo_conv_settings=None,
38
+ chunk_size=-1,
39
+ left_chunk=18,
40
+ relative_attention_bias_args=None,
41
+ activation_checkpointing=None,
42
+ encoder_embedding_config=None,
43
+ **kwargs
44
+ ):
45
+ super().__init__(**kwargs)
46
+
47
+ self.input_size = input_size
48
+ self.attention_dim = attention_dim
49
+ self.attention_heads = attention_heads
50
+ self.num_blocks = num_blocks
51
+ self.linear_units = linear_units
52
+ self.dropout_rate = dropout_rate
53
+ self.kernel_size = kernel_size
54
+ self.ext_pw_kernel_size = ext_pw_kernel_size
55
+ self.ext_pw_out_channel = ext_pw_out_channel
56
+ self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
57
+ self.depthwise_multiplier = depthwise_multiplier
58
+ self.activation = activation
59
+ self.conv_activation = conv_activation
60
+ self.conv_glu_type = conv_glu_type
61
+ self.bias_in_glu = bias_in_glu
62
+ self.causal = causal
63
+ self.batch_norm = batch_norm
64
+ self.cnn_layer_norm = cnn_layer_norm
65
+ self.time_reduction = time_reduction
66
+ self.input_layer = input_layer
67
+
68
+ if nemo_conv_settings is None:
69
+ self.nemo_conv_settings = {"conv_channels": 1024}
70
+ else:
71
+ self.nemo_conv_settings = nemo_conv_settings
72
+
73
+ self.chunk_size = chunk_size
74
+ self.left_chunk = left_chunk
75
+
76
+ if relative_attention_bias_args is None:
77
+ self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
78
+ else:
79
+ self.relative_attention_bias_args = relative_attention_bias_args
80
+
81
+ if activation_checkpointing is None:
82
+ self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
83
+ else:
84
+ self.activation_checkpointing = activation_checkpointing
85
+
86
+ if encoder_embedding_config is None:
87
+ self.encoder_embedding_config = {"input_size": input_size}
88
+ else:
89
+ self.encoder_embedding_config = encoder_embedding_config
90
+
91
+
92
+ class Gemma3OmniConfig(PretrainedConfig):
93
+ r"""
94
+ This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
95
+ Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
96
+ with the defaults will yield a similar configuration to that of the PaliGemma-2B.
97
+
98
+ e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
99
+
100
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
101
+ documentation from [`PretrainedConfig`] for more information.
102
+
103
+ Args:
104
+ text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
105
+ The config object of the text backbone.
106
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
107
+ Custom vision config or dict.
108
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
109
+ Custom audio config or dict.
110
+ mm_tokens_per_image (`int`, *optional*, defaults to 256):
111
+ The number of tokens per image embedding.
112
+ boi_token_index (`int`, *optional*, defaults to 255999):
113
+ The begin-of-image token index to wrap the image prompt.
114
+ eoi_token_index (`int`, *optional*, defaults to 256000):
115
+ The end-of-image token index to wrap the image prompt.
116
+ image_token_index (`int`, *optional*, defaults to 262144):
117
+ The image token index to encode the image prompt.
118
+ audio_token_index (`int`, *optional*, defaults to 262143):
119
+ The audio token index to encode the audio prompt.
120
+ initializer_range (`float`, *optional*, defaults to 0.02):
121
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
122
+
123
+
124
+ Example:
125
+
126
+ ```python
127
+ >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
128
+
129
+ >>> # Initializing a Siglip-like vision config
130
+ >>> vision_config = SiglipVisionConfig()
131
+
132
+ >>> # Initializing an audio config
133
+ >>> audio_config = AudioConfig()
134
+
135
+ >>> # Initializing a Gemma3 Text config
136
+ >>> text_config = Gemma3TextConfig()
137
+
138
+ >>> # Initializing a Gemma3 gemma-3-4b style configuration
139
+ >>> configuration = Gemma3Config(vision_config, text_config)
140
+
141
+ >>> # Initializing a model from the gemma-3-4b style configuration
142
+ >>> model = Gemma3ForConditionalGeneration(configuration)
143
+
144
+ >>> # Accessing the model configuration
145
+ >>> configuration = model.config
146
+ ```"""
147
+
148
+ model_type = "gemma3omni"
149
+ sub_configs = {
150
+ "text_config": Gemma3TextConfig,
151
+ "vision_config": SiglipVisionConfig,
152
+ "audio_config": AudioConfig,
153
+ }
154
+
155
+ def __init__(
156
+ self,
157
+ text_config: Optional[Gemma3TextConfig] = None,
158
+ vision_config: Optional[SiglipVisionConfig] = None,
159
+ audio_config: Optional[AudioConfig] = None,
160
+ mm_tokens_per_image: int = 256,
161
+ boi_token_index: int = 255_999,
162
+ eoi_token_index: int = 256_000,
163
+ boa_token_index: int = 256_001,
164
+ eoa_token_index: int = 256_002,
165
+ image_token_index: int = 262_144,
166
+ audio_token_index: int = 262_143,
167
+ initializer_range: float = 0.02,
168
+ **kwargs,
169
+ ):
170
+ if text_config is None:
171
+ text_config = Gemma3TextConfig()
172
+ logger.info("text_config is None, using default Gemma3TextConfig.")
173
+ elif isinstance(text_config, dict):
174
+ text_config = Gemma3TextConfig(**text_config)
175
+
176
+ if isinstance(vision_config, dict):
177
+ vision_config = SiglipVisionConfig(**vision_config)
178
+ else:
179
+ vision_config = SiglipVisionConfig()
180
+ logger.info(
181
+ "vision_config is None or incompatible with SiglipVisionConfig initialization. Gemma3 will be limited "
182
+ "to text tasks."
183
+ )
184
+
185
+ if isinstance(audio_config, dict):
186
+ audio_config = AudioConfig(**audio_config)
187
+ else:
188
+ audio_config = AudioConfig()
189
+ logger.info(
190
+ "audio_config is None or incompatible with AudioConfig initialization. Gemma3 will be limited "
191
+ "to text and vision tasks."
192
+ )
193
+
194
+ self.text_config = text_config
195
+ self.vision_config = vision_config
196
+ self.audio_config = audio_config
197
+ self.mm_tokens_per_image = mm_tokens_per_image
198
+ self.boi_token_index = boi_token_index
199
+ self.eoi_token_index = eoi_token_index
200
+ self.boa_token_index = boa_token_index
201
+ self.eoa_token_index = eoa_token_index
202
+ self.image_token_index = image_token_index
203
+ self.audio_token_index = audio_token_index
204
+ self.initializer_range = initializer_range
205
+
206
+ super().__init__(**kwargs)
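The zero-argument defaults of AudioConfig mirror the audio_config block in config.json, and Gemma3OmniConfig accepts its sub-configs as plain dicts; a minimal sketch of both behaviors:

# Sketch: AudioConfig() defaults reproduce the audio_config block of config.json.
audio_cfg = AudioConfig()
assert audio_cfg.attention_dim == 1024 and audio_cfg.num_blocks == 24 and audio_cfg.time_reduction == 8

# Sub-configs passed as dicts are promoted to their config classes; omitted ones fall back to defaults.
cfg = Gemma3OmniConfig(audio_config={"num_blocks": 24})
print(cfg.model_type, cfg.audio_config.num_blocks)  # "gemma3omni", 24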
eval.py ADDED
@@ -0,0 +1,635 @@
1
+ from io import BytesIO
2
+ from urllib.request import urlopen
3
+ import soundfile
4
+ import torch
5
+ from datasets import load_dataset, Audio
6
+ import numpy as np
7
+ from transformers import AutoModel, AutoProcessor, BatchFeature
8
+ from tqdm import tqdm
9
+ import json
10
+ import os
11
+ import time
12
+ from datetime import datetime
13
+ from whisper_normalizer.english import EnglishTextNormalizer
14
+ from whisper_normalizer.basic import BasicTextNormalizer
15
+ import sacrebleu
16
+ from jiwer import cer, wer
17
+ from torch.utils.data import Dataset, DataLoader
18
+ import soundfile as sf
19
+ import re
20
+ from pathlib import Path
21
+ import opencc
22
+ converter = opencc.OpenCC('s2tw.json')
23
+ normalizer = {
24
+ "en_us" : EnglishTextNormalizer(),
25
+ "other" : BasicTextNormalizer()
26
+ }
27
+
28
+ model_id = "/mnt/jeff/gemma_test"
29
+ revision = "main" #"v1.0"
30
+
31
+ model = AutoModel.from_pretrained(
32
+ model_id, device_map="cuda", revision = revision, trust_remote_code=True
33
+ ).eval()
34
+
35
+ processor = AutoProcessor.from_pretrained(
36
+ model_id, revision = revision, trust_remote_code=True
37
+ )
38
+
39
+ results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
40
+ os.makedirs(results_dir, exist_ok=True)
41
+
42
+
43
+ INSTRUCTION = {
44
+ "ast": "Translate the audio to {0}.",
45
+ "asr": "Transcribe the audio clip into text.",
46
+ }
47
+
48
+ class BaseAudioDataset(Dataset):
49
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
50
+ self.processor = processor
51
+ self.training = "train" in split
52
+ self.debug = debug
53
+ self.sampling_rate = sampling_rate
54
+ self.name = ""
55
+
56
+ def set_dataset_name(self, name):
57
+ self.name = name
58
+
59
+ @staticmethod
60
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
61
+ original_size = len(data)
62
+
63
+ data = data.cast_column(audio_field, Audio(decode=False))
64
+
65
+ def identify_corrupted_files(example):
66
+ try:
67
+ sf.read(example[audio_field]["path"])
68
+
69
+ for field in text_fields:
70
+ if example[field].replace('"', '') == "":
71
+ return False
72
+ return True
73
+ except Exception:
74
+ return False
75
+
76
+ data = data.filter(identify_corrupted_files, num_proc=16)
77
+ validated_size = len(data)
78
+
79
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
80
+
81
+
82
+ return data
83
+
84
+ @staticmethod
85
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
86
+ original_size = len(data)
87
+
88
+ def filter_audio_by_length(example):
89
+ try:
90
+ audio = example[audio_field]['array']
91
+ channel = 1
92
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
93
+ channel = audio.ndim
94
+ audio = audio.squeeze()
95
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
96
+ return min_sec <= audio_length <= max_sec
97
+ except Exception as e:
98
+ return False
99
+
100
+ data = data.filter(filter_audio_by_length, num_proc=16)
101
+ filtered_size = len(data)
102
+
103
+ return data
104
+
105
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
106
+ user_message = {
107
+ 'role': 'user',
108
+ 'content': '<start_of_audio>' + instruction,
109
+ }
110
+ prompt = self.processor.tokenizer.apply_chat_template(
111
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
112
+ )
113
+
114
+ inputs = self.processor(
115
+ text=prompt,
116
+ audio=[audio_array],
117
+ add_special_tokens=False,
118
+ return_tensors='pt'
119
+ )
120
+
121
+ input_ids = inputs.input_ids
122
+ token_type_ids = inputs.token_type_ids
123
+
124
+ return {
125
+ 'input_ids': input_ids,
126
+ 'token_type_ids': token_type_ids,
127
+ 'input_audio_embeds': inputs.input_audio_embeds,
128
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
129
+ 'input_modes': inputs.input_modes,
130
+ 'answer': answer_text,
131
+ }
132
+
133
+
134
+ # Libri Speech Dataset Class
135
+ class LibriSpeechDataset(BaseAudioDataset):
136
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
137
+ super().__init__(processor, split, sampling_rate, debug)
138
+
139
+ self.set_dataset_name(f"LibriSpeech_{subset}")
140
+ # only ASR
141
+ self.ast = False
142
+ self.lang = "en"
143
+
144
+ # load dataset
145
+ self.data = load_dataset("openslr/librispeech_asr",
146
+ subset,
147
+ split=split,
148
+ trust_remote_code=True,
149
+ cache_dir=Path("/mnt/jeff/InCar/data")
150
+ )
151
+
152
+ # (Optional) Audio length Filtering
153
+ self.data = self.filter_by_audio_length(self.data, "audio")
154
+
155
+ # Instruction Setting
156
+ self.instruction = INSTRUCTION["asr"]
157
+
158
+ def __len__(self):
159
+ return len(self.data)
160
+
161
+ def __getitem__(self, idx):
162
+ data = self.data[idx]
163
+
164
+ # Libri Speech is only for ASR
165
+ answer_text = data["text"].replace('"', '')
166
+
167
+ return self.prepare_model_inputs(
168
+ data["audio"]["array"],
169
+ INSTRUCTION["asr"],
170
+ answer_text
171
+ )
172
+
173
+ # common_voice_16_1 dataset
174
+ class CommonVoiceDataset(BaseAudioDataset):
175
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
176
+ super().__init__(processor, split, sampling_rate, debug)
177
+
178
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
179
+ # only ASR
180
+ self.ast = False
181
+ self.lang=source_lang
182
+
183
+ # load dataset
184
+ self.data = load_dataset("mozilla-foundation/common_voice_16_1",
185
+ source_lang,
186
+ split=split,
187
+ trust_remote_code=True,
188
+ cache_dir=Path("/mnt/jeff/InCar/data")
189
+ )
190
+ def prepare_dataset(batch):
191
+ """Function to preprocess the dataset with the .map method"""
192
+ transcription = batch["sentence"]
193
+
194
+ if transcription.startswith('"') and transcription.endswith('"'):
195
+ # remove surrounding quotation marks as they do not affect the transcription
196
+ transcription = transcription[1:-1]
197
+
198
+ if transcription[-1] not in [".", "?", "!"]:
199
+ # append a full-stop to sentences that do not end in punctuation
200
+ transcription = transcription + "."
201
+
202
+ batch["sentence"] = transcription
203
+
204
+ return batch
205
+ self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
206
+
207
+ # (Optional) Audio length Filtering
208
+ self.data = self.filter_by_audio_length(self.data, "audio")
209
+
210
+ # Instruction Setting
211
+ self.instruction = INSTRUCTION["asr"]
212
+
213
+ def __len__(self):
214
+ return len(self.data)
215
+
216
+ def __getitem__(self, idx):
217
+ data = self.data[idx]
218
+
219
+ answer_text = data["sentence"]
220
+ return self.prepare_model_inputs(
221
+ data["audio"]["array"],
222
+ INSTRUCTION["asr"],
223
+ answer_text
224
+ )
225
+
226
+
227
+ # Fleurs Dataset Class
228
+ class FleursDataset(BaseAudioDataset):
229
+ def __init__(self, processor, split, source_lang, target_lang=None,
230
+ mode="asr", sampling_rate=16000, debug=False):
231
+ super().__init__(processor, split, sampling_rate, debug)
232
+
233
+ self.set_dataset_name("Fleurs")
234
+ # Mode Setting (ASR or AST)
235
+ if mode not in ["asr", "ast"]:
236
+ raise ValueError("mode must be 'asr' or 'ast'.")
237
+
238
+ self.mode = mode
239
+ self.ast = (mode == "ast")
240
+ self.source_lang = source_lang
241
+
242
+ # Language name mapping (expand if needed)
243
+ self.lang_names = {
244
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
245
+ }
246
+
247
+ # load dataset - source language dataset
248
+ self.data = load_dataset("google/fleurs",
249
+ source_lang,
250
+ split=split,
251
+ trust_remote_code=True,
252
+ cache_dir=Path("/mnt/jeff/InCar/data")
253
+ )
254
+ def prepare_dataset(batch):
255
+ import opencc
256
+ converter = opencc.OpenCC('s2tw.json')
257
+ if self.ast:
258
+ translation = converter.convert(batch["translation"])
259
+ batch["translation"] = translation
260
+ else:
261
+ transcription = converter.convert(batch["transcription"])
262
+ batch["transcription"] = transcription
263
+
264
+ return batch
265
+ if (source_lang=="cmn_hans_cn" and not self.ast) or (self.ast and target_lang=="cmn_hans_cn"):
266
+ self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
267
+
268
+ # (Optional) Audio length Filtering
269
+ self.data = self.filter_by_audio_length(self.data, "audio")
270
+ self.target_lang_name = ""
271
+ # When AST mode, load target language dataset.
272
+ if self.ast:
273
+ if target_lang is None:
274
+ raise ValueError("AST mode requires target_lang.")
275
+
276
+ self.target_lang = target_lang
277
+ self.lang = f"{source_lang}_{target_lang}"
278
+
279
+ # load dataset - target language dataset (for translation)
280
+ target_data = load_dataset("google/fleurs",
281
+ target_lang,
282
+ split=split,
283
+ trust_remote_code=True,
284
+ cache_dir=Path("/mnt/jeff/InCar/data")
285
+ )
286
+
287
+ source_dict = {item['id']: item for item in self.data}
288
+ target_dict = {item['id']: item for item in target_data}
289
+
290
+ # only Common ID, add translation fields
291
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
292
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
293
+ self.data = [
294
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
295
+ for id in common_ids
296
+ ]
297
+
298
+ # Instruction Setting - use target language name
299
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
300
+ self.instruction = INSTRUCTION["ast"]
301
+ else:
302
+ # ASR mode
303
+ self.lang = source_lang
304
+ self.instruction = INSTRUCTION["asr"]
305
+
306
+ if self.debug:
307
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
308
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
309
+ if self.ast:
310
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
311
+ print(f"dataset size: {len(self.data)}")
312
+
313
+ def __len__(self):
314
+ return len(self.data)
315
+
316
+ def __getitem__(self, idx):
317
+ data = self.data[idx]
318
+ audio_array = data["audio"]["array"]
319
+
320
+ if self.ast:
321
+ answer_text = data["translation"]
322
+ else:
323
+ answer_text = data["transcription"]
324
+
325
+ return self.prepare_model_inputs(
326
+ audio_array,
327
+ self.instruction.format(self.target_lang_name),
328
+ answer_text
329
+ )
330
+
331
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
332
+ """
333
+ Pad a list of sequences to the same length.
334
+ sequences: list of tensors in [seq_len, *] shape
335
+ """
336
+ assert padding_side in ['right', 'left']
337
+ max_size = sequences[0].size()
338
+ trailing_dims = max_size[1:]
339
+ max_len = max(len(seq) for seq in sequences)
340
+ batch_size = len(sequences)
341
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
342
+ for i, seq in enumerate(sequences):
343
+ length = seq.size(0)
344
+ if padding_side == 'right':
345
+ output.data[i, :length] = seq
346
+ else:
347
+ output.data[i, -length:] = seq
348
+ return output
349
+
350
+ def cat_with_pad(tensors, dim, padding_value=0):
351
+ """
352
+ cat along dim, while pad to max for all other dims
353
+ """
354
+ ndim = tensors[0].dim()
355
+ assert all(
356
+ t.dim() == ndim for t in tensors[1:]
357
+ ), 'All tensors must have the same number of dimensions'
358
+
359
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
360
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
361
+ output = tensors[0].new_full(out_size, padding_value)
362
+
363
+ index = 0
364
+ for t in tensors:
365
+ # Create a slice list where every dimension except dim is full slice
366
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
367
+ # Update only the concat dimension slice
368
+ slices[dim] = slice(index, index + t.shape[dim])
369
+
370
+ output[slices] = t
371
+ index += t.shape[dim]
372
+
373
+ return output
374
+
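A minimal illustration of the two helpers above, assuming they are importable as module-level functions from ASRDataset.py: pad_sequence pads 1-D token-id tensors to a common length on the chosen side, while cat_with_pad concatenates along one dimension and zero-pads every other dimension to the largest size.

    import torch

    a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
    print(pad_sequence([a, b], padding_side='left', padding_value=0))
    # tensor([[1, 2, 3],
    #         [0, 4, 5]])

    x, y = torch.ones(2, 3), torch.ones(1, 5)
    print(cat_with_pad([x, y], dim=0).shape)   # torch.Size([3, 5]); x is zero-padded to width 5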
375
+ def covost_collate_fn(batch):
376
+ input_ids_list = []
377
+ input_audio_embeds_list = []
378
+ audio_embed_sizes_list = []
379
+ audio_attention_mask_list = []
380
+ input_modes_list = []
381
+ answer_list = []
382
+ for inputs in batch:
383
+ input_ids_list.append(inputs['input_ids'][0])
384
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
385
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
386
+ audio_attention_mask_list.append(
387
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
388
+ )
389
+ input_modes_list.append(inputs['input_modes'])
390
+ answer_list.append(inputs['answer'])
391
+
392
+ try:
393
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
394
+ audio_attention_mask = (
395
+ pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
396
+ if len(audio_attention_mask_list) > 1
397
+ else None
398
+ )
399
+ except Exception as e:
400
+ print(e)
401
+ print(input_ids_list)
402
+ print(audio_attention_mask)
403
+ raise
404
+ attention_mask = (input_ids != 0).long()
405
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
406
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list)
407
+ input_modes = torch.cat(input_modes_list)
408
+
409
+ return BatchFeature(
410
+ {
411
+ 'input_ids': input_ids,
412
+ 'attention_mask': attention_mask,
413
+ 'input_audio_embeds': input_audio_embeds,
414
+ 'audio_embed_sizes': audio_embed_sizes,
415
+ 'audio_attention_mask': audio_attention_mask,
416
+ 'input_modes': input_modes,
417
+ 'answer': answer_list,
418
+ }
419
+ )
420
+
421
+ def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
422
+ filename = f"{task}_{dataset_name}_{source_lang}"
423
+ if target_lang:
424
+ filename += f"_to_{target_lang}"
425
+ if sample_idx is not None:
426
+ filename += f"_sample_{sample_idx}"
427
+
428
+ filepath = os.path.join(results_dir, f"{filename}.json")
429
+
430
+ results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
431
+
432
+ with open(filepath, 'w', encoding='utf-8') as f:
433
+ json.dump(results, f, ensure_ascii=False, indent=2)
434
+
435
+ return filepath
436
+
437
+ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
438
+ task_type = "asr" if is_asr else "translation"
439
+ eval_lang = source_lang if is_asr else target_lang
440
+ if eval_lang in normalizer:
441
+ eval_normalizer = normalizer[eval_lang]
442
+ else:
443
+ eval_normalizer = normalizer['other']
444
+ sample_results = []
445
+
446
+ if num_samples > 0 and num_samples < len(dataset):
447
+ indices = np.random.choice(len(dataset), num_samples, replace=False)
448
+ dataset = dataset.select(indices)
449
+
450
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
451
+
452
+ evaluated_samples = {}
453
+
454
+ for batch_idx, batch in enumerate(tqdm(dataloader)):
455
+ batch_references = batch.pop("answer")
456
+
457
+ if torch.cuda.is_available():
458
+ try:
459
+ batch = {k: v.to("cuda") for k, v in batch.items()}
460
+ except:
461
+ print('error')
462
+ break
463
+
464
+ with torch.inference_mode():
465
+ generate_ids = model.generate(**batch,
466
+ max_new_tokens=256,
467
+ #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
468
+ )
469
+
470
+ input_lengths = batch['input_ids'].shape[1]
471
+ generate_ids = generate_ids[:, input_lengths:]
472
+
473
+ batch_predictions = processor.batch_decode(
474
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
475
+ )
476
+
477
+ for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
478
+ idx = batch_idx * batch_size + i
479
+ sample_result = {
480
+ "id": idx,
481
+ "reference": reference,
482
+ "prediction": converter.convert(prediction)
483
+ }
484
+ sample_results.append(sample_result)
485
+
486
+ if (batch_idx + 1) % 10 == 0:
487
+ temp_results = []
488
+
489
+ for item in sample_results:
490
+ sample_id = item["id"]
491
+
492
+ if sample_id in evaluated_samples:
493
+ temp_item = item.copy()
494
+ temp_item.update(evaluated_samples[sample_id])
495
+ temp_results.append(temp_item)
496
+ else:
497
+ temp_item = item.copy()
498
+ try:
499
+ ref = eval_normalizer(item["reference"])
500
+ pred = eval_normalizer(item["prediction"])
501
+
502
+ # BLEU, WER/CER
503
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
504
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
505
+ utt_wer = round(wer(ref, pred) * 100, 2)
506
+
507
+ metrics = {
508
+ "bleu": utt_bleu,
509
+ "cer": min(100,utt_cer),
510
+ "wer": utt_wer
511
+ }
512
+
513
+ evaluated_samples[sample_id] = metrics
514
+ temp_item.update(metrics)
515
+ except Exception as e:
516
+ print(f"Error evaluating sample {sample_id}: {e}")
517
+ metrics = {
518
+ "bleu": 0,
519
+ "cer": 100,
520
+ "wer": 100,
521
+ "error": str(e)
522
+ }
523
+ evaluated_samples[sample_id] = metrics
524
+ temp_item.update(metrics)
525
+
526
+ temp_results.append(temp_item)
527
+
528
+ partial_results = {
529
+ "task": task_type,
530
+ "source_lang": source_lang,
531
+ "target_lang": target_lang,
532
+ "num_samples": len(temp_results),
533
+ "sample_results": temp_results
534
+ }
535
+ save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
536
+
537
+ for item in sample_results:
538
+ ref = eval_normalizer(item["reference"])
539
+ pred = eval_normalizer(item["prediction"])
540
+
541
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
542
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
543
+ utt_wer = round(wer(ref, pred) * 100, 2)
544
+
545
+ item.update({
546
+ "bleu": utt_bleu,
547
+ "cer": min(100,utt_cer),
548
+ "wer": utt_wer
549
+ })
550
+
551
+ avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
552
+ avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
553
+ avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
554
+
555
+ results = {
556
+ "dataset": dataset.name,
557
+ "task": task_type,
558
+ "source_lang": source_lang,
559
+ "target_lang": target_lang,
560
+ "num_samples": len(sample_results),
561
+ "metrics": {
562
+ "bleu": avg_bleu,
563
+ "cer": avg_cer,
564
+ "wer": avg_wer
565
+ },
566
+ "sample_results": sample_results
567
+ }
568
+
569
+ save_results(results, dataset.name, task_type, source_lang, target_lang)
570
+ return results
571
+
572
+
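The per-utterance metrics inside evaluate_task can be reproduced in isolation; a minimal sketch assuming jiwer and sacrebleu are installed (the whitespace stripping mirrors how the CER is computed above):

    import re
    import sacrebleu
    from jiwer import cer, wer

    ref, pred = "turn on the air conditioner", "turn on air conditioner"
    print(round(wer(ref, pred) * 100, 2))   # 20.0 (one deleted word out of five)
    print(round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2))
    print(sacrebleu.sentence_bleu(pred, [ref]).score)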
573
+ if __name__ == "__main__":
574
+
575
+ source_languages = [
576
+ ("en_us", "English"),
577
+ ]
578
+
579
+ target_languages = [
580
+ ("zh-TW", "zh-TW"),
581
+ ]
582
+
583
+ num_samples = -1
584
+ batch_size = 32
585
+
586
+ for source_lang, target_lang in zip(source_languages, target_languages):
587
+ print(f"\n===== {source_lang[0]} ASR =====")
588
+
589
+ split = "test"
590
+
591
+ datasets = []
592
+
593
+
594
+
595
+ commonvoice_speech_tw = CommonVoiceDataset(
596
+ processor=processor,
597
+ source_lang="zh-TW",
598
+ split=split
599
+ )
600
+ datasets.append(commonvoice_speech_tw)
601
+ fleurs = FleursDataset(
602
+ processor=processor,
603
+ split=split,
604
+ source_lang="en_us", # English
605
+ mode="asr"
606
+ )
607
+ datasets.append(fleurs)
608
+
609
+ # Libri Speech Clean ASR mode (English -> English text)
610
+ # libri_speech_clean = LibriSpeechDataset(
611
+ # processor=processor,
612
+ # subset="clean",
613
+ # split=split
614
+ # )
615
+ # datasets.append(libri_speech_clean)
616
+
617
+ # # Libri Speech Other ASR mode (English -> English text)
618
+ # libri_speech_other = LibriSpeechDataset(
619
+ # processor=processor,
620
+ # subset="other",
621
+ # split=split
622
+ # )
623
+ # datasets.append(libri_speech_other)
624
+
625
+ # Fleurs ASR mode (English -> English text)
626
+
627
+
628
+ for dataset in datasets:
629
+ # ASR
630
+ asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
631
+
632
+ print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
633
+ print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
634
+ print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
635
+ print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
eval_multiturn.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
eval_multiturn.py ADDED
@@ -0,0 +1,211 @@
1
+ from io import BytesIO
2
+ from urllib.request import urlopen
3
+ import soundfile
4
+ import torch
5
+ from datasets import load_dataset, Audio
6
+ import numpy as np
7
+ from transformers import AutoModel, AutoProcessor, BatchFeature
8
+ from tqdm import tqdm
9
+ import json
10
+ import os
11
+ import time
12
+ from datetime import datetime
13
+ from whisper_normalizer.english import EnglishTextNormalizer
14
+ from whisper_normalizer.basic import BasicTextNormalizer
15
+ import sacrebleu
16
+ from jiwer import cer, wer
17
+ from torch.utils.data import Dataset, DataLoader
18
+ import soundfile as sf
19
+ import re
20
+ from pathlib import Path
21
+ import opencc
22
+ from ASRDataset import *
23
+
24
+ converter = opencc.OpenCC('s2tw.json')
25
+ normalizer = {
26
+ "en_us" : EnglishTextNormalizer(),
27
+ "other" : BasicTextNormalizer()
28
+ }
29
+
30
+ model_id = "/mnt/jeff/gemma_test"
31
+ revision = "main" #"v1.0"
32
+
33
+ model = AutoModel.from_pretrained(
34
+ model_id, device_map="cuda", revision = revision, trust_remote_code=True
35
+ ).eval()
36
+
37
+ processor = AutoProcessor.from_pretrained(
38
+ model_id, revision = revision, trust_remote_code=True
39
+ )
40
+
41
+ results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
42
+ os.makedirs(results_dir, exist_ok=True)
43
+
44
+
45
+ INSTRUCTION = {
46
+ "ast": "Translate the audio to {0}.",
47
+ "asr": "Transcribe the audio clip into text.",
48
+ }
49
+
50
+
51
+
52
+ def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
53
+ filename = f"{task}_{dataset_name}_{source_lang}"
54
+ if target_lang:
55
+ filename += f"_to_{target_lang}"
56
+ if sample_idx is not None:
57
+ filename += f"_sample_{sample_idx}"
58
+
59
+ filepath = os.path.join(results_dir, f"{filename}.json")
60
+
61
+ results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
62
+
63
+ with open(filepath, 'w', encoding='utf-8') as f:
64
+ json.dump(results, f, ensure_ascii=False, indent=2)
65
+
66
+ return filepath
67
+
68
+ def evaluate_task(dataset):
69
+ sample_results = []
70
+ # Assumed defaults: this script is not given a task name or language pair, so label the run explicitly.
+ task_type, source_lang, target_lang = "multiturn", "zh_tw", None
71
+ eval_normalizer = normalizer["other"]
72
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
73
+
74
+ evaluated_samples = {}
75
+
76
+ for batch_idx, batch in enumerate(tqdm(dataloader)):
77
+ batch_references = batch.pop("answer")  # gold answers are plain strings; pop them before moving tensors to GPU
78
+ if torch.cuda.is_available():
79
+ try:
80
+ batch = {k: v.to("cuda") for k, v in batch.items()}
81
+ except:
82
+ print('error')
83
+ break
84
+
85
+ with torch.inference_mode():
86
+ generate_ids = model.generate(**batch,
87
+ max_new_tokens=256,
88
+ #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
89
+ )
90
+
91
+ input_lengths = batch['input_ids'].shape[1]
92
+ generate_ids = generate_ids[:, input_lengths:]
93
+
94
+ batch_predictions = processor.batch_decode(
95
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
96
+ )
97
+ # references come from the collated 'answer' field popped at the top of the loop
102
+
103
+ for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
104
+ idx = batch_idx + i
105
+ sample_result = {
106
+ "id": idx,
107
+ "reference": reference,
108
+ "prediction": converter.convert(prediction)
109
+ }
110
+ sample_results.append(sample_result)
111
+
112
+ if (batch_idx + 1) % 10 == 0:
113
+ temp_results = []
114
+
115
+ for item in sample_results:
116
+ sample_id = item["id"]
117
+
118
+ if sample_id in evaluated_samples:
119
+ temp_item = item.copy()
120
+ temp_item.update(evaluated_samples[sample_id])
121
+ temp_results.append(temp_item)
122
+ else:
123
+ temp_item = item.copy()
124
+ try:
125
+ ref = eval_normalizer(item["reference"])
126
+ pred = eval_normalizer(item["prediction"])
127
+
128
+ # BLEU, WER/CER
129
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
130
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
131
+ utt_wer = round(wer(ref, pred) * 100, 2)
132
+
133
+ metrics = {
134
+ "bleu": utt_bleu,
135
+ "cer": min(100,utt_cer),
136
+ "wer": utt_wer
137
+ }
138
+
139
+ evaluated_samples[sample_id] = metrics
140
+ temp_item.update(metrics)
141
+ except Exception as e:
142
+ print(f"Error evaluating sample {sample_id}: {e}")
143
+ metrics = {
144
+ "bleu": 0,
145
+ "cer": 100,
146
+ "wer": 100,
147
+ "error": str(e)
148
+ }
149
+ evaluated_samples[sample_id] = metrics
150
+ temp_item.update(metrics)
151
+
152
+ temp_results.append(temp_item)
153
+
154
+ partial_results = {
155
+ "task": task_type,
156
+ "source_lang": source_lang,
157
+ "target_lang": target_lang,
158
+ "num_samples": len(temp_results),
159
+ "sample_results": temp_results
160
+ }
161
+ save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
162
+
163
+ for item in sample_results:
164
+ ref = eval_normalizer(item["reference"])
165
+ pred = eval_normalizer(item["prediction"])
166
+
167
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
168
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
169
+ utt_wer = round(wer(ref, pred) * 100, 2)
170
+
171
+ item.update({
172
+ "bleu": utt_bleu,
173
+ "cer": min(100,utt_cer),
174
+ "wer": utt_wer
175
+ })
176
+
177
+ avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
178
+ avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
179
+ avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
180
+
181
+ results = {
182
+ "dataset": dataset.name,
183
+ "task": task_type,
184
+ "source_lang": source_lang,
185
+ "target_lang": target_lang,
186
+ "num_samples": len(sample_results),
187
+ "metrics": {
188
+ "bleu": avg_bleu,
189
+ "cer": avg_cer,
190
+ "wer": avg_wer
191
+ },
192
+ "sample_results": sample_results
193
+ }
194
+
195
+ save_results(results, dataset.name, task_type, source_lang, target_lang)
196
+ return results
197
+
198
+
199
+ if __name__ == "__main__":
200
+
201
+ datasets = []
202
+ pickup_dataset = MultiturnAudioDataset(split='eval',processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
203
+ datasets.append(pickup_dataset)
204
+ for dataset in datasets:
205
+ # ASR
206
+ asr_results = evaluate_task(dataset)
207
+
208
+ print(f"\n=== {asr_results.get('dataset', 'Dataset')}")
209
+ print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
210
+ print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
211
+ print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
merge_lora.ipynb ADDED
@@ -0,0 +1,119 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from safetensors import safe_open\n",
10
+ "\n",
11
+ "lora = {}\n",
12
+ "with safe_open(\"/data2/bjh/diffusion-pipe/cosmos_test/20250327_02-37-25/epoch5/adapter_model.safetensors\", framework=\"pt\", device='cpu') as f:\n",
13
+ " for k in f.keys():\n",
14
+ " lora[k] = f.get_tensor(k)"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "tensors = {}\n",
24
+ "with safe_open(\"/data2/bjh/ComfyUI/models/diffusion_models/Cosmos-1_0-Diffusion-14B-Text2World.safetensors\", framework=\"pt\", device='cpu') as f:\n",
25
+ " for k in f.keys():\n",
26
+ " tensors[k] = f.get_tensor(k)"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 3,
32
+ "metadata": {},
33
+ "outputs": [
34
+ {
35
+ "data": {
36
+ "text/plain": [
37
+ "1152"
38
+ ]
39
+ },
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "len(lora)"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 4,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "name_lis = []\n",
56
+ "for k in lora:\n",
57
+ " a = k.split('.')[1:][:-2]\n",
58
+ " name = '.'.join(a)\n",
59
+ " name_lis.append(name)\n",
60
+ "name_lis=set(name_lis)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 5,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "import torch\n",
70
+ "new_dic = {}\n",
71
+ "for k in tensors:\n",
72
+ " name='.'.join(k.split('.')[1:][:-1])\n",
73
+ " if name in name_lis:\n",
74
+ " a,b = lora['diffusion_model.'+name+'.lora_A.weight'],lora['diffusion_model.'+name+'.lora_B.weight']\n",
75
+ " new_dic[k]=tensors[k]+torch.matmul(b,a)\n",
76
+ " else:\n",
77
+ " new_dic[k]=tensors[k]"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 6,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "from safetensors.torch import save_file\n",
87
+ "save_file(new_dic,'test.safetensors')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": []
96
+ }
97
+ ],
98
+ "metadata": {
99
+ "kernelspec": {
100
+ "display_name": "dp",
101
+ "language": "python",
102
+ "name": "python3"
103
+ },
104
+ "language_info": {
105
+ "codemirror_mode": {
106
+ "name": "ipython",
107
+ "version": 3
108
+ },
109
+ "file_extension": ".py",
110
+ "mimetype": "text/x-python",
111
+ "name": "python",
112
+ "nbconvert_exporter": "python",
113
+ "pygments_lexer": "ipython3",
114
+ "version": "3.12.9"
115
+ }
116
+ },
117
+ "nbformat": 4,
118
+ "nbformat_minor": 2
119
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddd3e8916f7ad6ad92651ac288227995c1d34628f0f888eb2dc5b9acb4dc0121
3
+ size 4976361384
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c343a455ca768923cb3b9ab77cbb91c9cd2526a1bee5740cf9cf86bfa85a0a7b
3
+ size 4984907872
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad04449d015f4efbda75d6cc41e06296b4da996cd84053fa6f9791fe16d55d03
3
+ size 732141104
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_gemma3omni.py ADDED
@@ -0,0 +1,668 @@
1
+ import copy
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, HybridCache, StaticCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
14
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
15
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
16
+ from transformers.processing_utils import Unpack
17
+ from transformers.utils import (
18
+ add_start_docstrings,
19
+ add_start_docstrings_to_model_forward,
20
+ is_torchdynamo_compiling,
21
+ logging,
22
+ replace_return_docstrings,
23
+ )
24
+ from transformers.utils.deprecation import deprecate_kwarg
25
+ from transformers import AutoModel, AutoModelForCausalLM
26
+
27
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
28
+
29
+ from transformers import AutoConfig, AutoModelForCausalLM
30
+
31
+ from .configuration_gemma3omni import Gemma3OmniConfig
32
+ from .speech_conformer_encoder import ConformerEncoder
33
+ from enum import Enum
34
+ class InputMode(Enum):
35
+ LANGUAGE = 0
36
+ VISION = 1
37
+ SPEECH = 2
38
+ VISION_SPEECH = 3
39
+ logger = logging.get_logger(__name__)
40
+ _CONFIG_FOR_DOC = "Gemma3OmniConfig"
41
+
42
+ @dataclass
43
+ class Gemma3OmniCausalLMOutputWithPast(Gemma3CausalLMOutputWithPast):
44
+ """
45
+ Multimodal version of `Gemma3CausalLMOutputWithPast`.
46
+ Adds audio-specific hidden states.
47
+
48
+ Args:
49
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
50
+ A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
51
+ Audio hidden states produced by the audio encoder.
52
+ """
53
+ audio_hidden_states: Optional[torch.FloatTensor] = None
54
+
55
+
56
+ GEMMA3_START_DOCSTRING = r"""
57
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
58
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
59
+ etc.)
60
+
61
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
62
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
63
+ and behavior.
64
+
65
+ Parameters:
66
+ config ([`Gemma3Config`]):
67
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
68
+ load the weights associated with the model, only the configuration. Check out the
69
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
70
+ """
71
+
72
+
73
+
74
+ GEMMA3_INPUTS_DOCSTRING = r"""
75
+ Args:
76
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
77
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
78
+ it.
79
+
80
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
81
+ [`PreTrainedTokenizer.__call__`] for details.
82
+
83
+ [What are input IDs?](../glossary#input-ids)
84
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
85
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
86
+
87
+ - 1 for tokens that are **not masked**,
88
+ - 0 for tokens that are **masked**.
89
+
90
+ [What are attention masks?](../glossary#attention-mask)
91
+
92
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
93
+ [`PreTrainedTokenizer.__call__`] for details.
94
+
95
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
96
+ `past_key_values`).
97
+
98
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
99
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
100
+ information on the default strategy.
101
+
102
+ - 1 indicates the head is **not masked**,
103
+ - 0 indicates the head is **masked**.
104
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
105
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
106
+ config.n_positions - 1]`.
107
+
108
+ [What are position IDs?](../glossary#position-ids)
109
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
110
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
111
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
112
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
113
+
114
+ Two formats are allowed:
115
+ - a [`~cache_utils.Cache`] instance, see our
116
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
117
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
118
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
119
+ cache format.
120
+
121
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
122
+ legacy cache format will be returned.
123
+
124
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
125
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
126
+ of shape `(batch_size, sequence_length)`.
127
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
128
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
129
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
130
+ model's internal embedding lookup matrix.
131
+ use_cache (`bool`, *optional*):
132
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
133
+ `past_key_values`).
134
+ output_attentions (`bool`, *optional*):
135
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
136
+ tensors for more detail.
137
+ output_hidden_states (`bool`, *optional*):
138
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
139
+ more detail.
140
+ return_dict (`bool`, *optional*):
141
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
142
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
143
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
144
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
145
+ the complete sequence length.
146
+ """
147
+
148
+ @add_start_docstrings(
149
+ "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.",
150
+ GEMMA3_START_DOCSTRING,
151
+ )
152
+ class Gemma3OmniPreTrainedModel(Gemma3PreTrainedModel):
153
+ config_class = Gemma3OmniConfig
154
+
155
+ @add_start_docstrings(
156
+ """The GEMMA3 model which consists of a vision backbone and a language model.""",
157
+ GEMMA3_START_DOCSTRING,
158
+ )
159
+ class Gemma3OmniForConditionalGeneration(Gemma3OmniPreTrainedModel, GenerationMixin):
160
+ def __init__(self, config: Gemma3OmniConfig):
161
+ super().__init__(config)
162
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
163
+ audio_config = config.audio_config.to_diff_dict()
164
+ for item in ['transformers_version', 'model_type', 'torch_dtype']:
165
+ if item in audio_config:
166
+ audio_config.pop(item)
167
+ self.audio_tower = ConformerEncoder(**audio_config)
168
+ self.audio_tower.post_init({})
169
+ self.audio_tower = self.audio_tower.to(dtype=self.dtype)
170
+ self.audio_projector = nn.Sequential(
171
+ nn.Linear(in_features=config.audio_config.attention_dim, out_features=config.text_config.hidden_size, bias=True),
172
+ nn.GELU(approximate='none'),
173
+ nn.Linear(in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, bias=True)
174
+ ).to(dtype=self.dtype)
175
+
176
+ self.multi_modal_projector = Gemma3MultiModalProjector(config)
177
+ self.vocab_size = config.text_config.vocab_size
178
+
179
+ language_model = AutoModelForCausalLM.from_config(config=config.text_config)
180
+
181
+ if language_model._tied_weights_keys is not None:
182
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
183
+ self.language_model = language_model
184
+
185
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
186
+ self.init_lora()
187
+ self.post_init()
188
+
189
+
190
+ def init_lora(self):
191
+ from peft import LoraConfig, get_peft_model
192
+ import warnings
193
+ print('######################## speech lora #############')
194
+ speech_lora_config = LoraConfig(
195
+ r=self.config.speech_lora['r'],
196
+ lora_alpha=self.config.speech_lora['lora_alpha'],
197
+ target_modules=self.config.speech_lora['layer'],
198
+ use_rslora=self.config.speech_lora['use_rslora'],
199
+ lora_dropout=self.config.speech_lora['dp'],
200
+ task_type="CAUSAL_LM",
201
+ )
202
+ self.language_model.model = get_peft_model(self.language_model.model, speech_lora_config, adapter_name="speech")
203
+ print('######################## text lora #############')
204
+ text_lora_config = LoraConfig(
205
+ r=self.config.text_lora['r'],
206
+ lora_alpha=self.config.text_lora['lora_alpha'],
207
+ target_modules=self.config.text_lora['layer'],
208
+ use_rslora=self.config.text_lora['use_rslora'],
209
+ lora_dropout=self.config.text_lora['dp'],
210
+ task_type="CAUSAL_LM",
211
+ )
212
+ self.language_model.model.base_model.active_adapter.append("text")
213
+ self.language_model.model.add_adapter("text", text_lora_config)
214
+
215
+ def set_lora_adapter(self, adapter_name) -> None:
216
+ from peft.tuners.lora.layer import LoraLayer
217
+ for module in self.modules():
218
+ if isinstance(module, LoraLayer):
219
+ if module.merged:
220
+ warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
221
+ module.unmerge()
222
+ module.set_adapter(adapter_name)
223
+ module._disable_adapters = False
224
+
225
+ def unset_lora_adapter(self) -> None:
226
+ # Ref: peft/tuners/tuners_utils.py - enable_adapters()
227
+ # Ref: peft/tuners/lora/layer.py
228
+ from peft.tuners.lora.layer import LoraLayer
229
+ for module in self.modules():
230
+ if isinstance(module, LoraLayer):
231
+ # disable grads on all adapter layers
232
+ # TODO weijian: may use enable_adapters() instead
233
+ for layer_name in module.adapter_layer_names:
234
+ layer = getattr(module, layer_name)
235
+ layer.requires_grad_(False)
236
+ module._disable_adapters = True
237
+
238
+ def get_input_embeddings(self):
239
+ return self.language_model.get_input_embeddings()
240
+
241
+ def set_input_embeddings(self, value):
242
+ self.language_model.set_input_embeddings(value)
243
+
244
+ def get_output_embeddings(self):
245
+ return self.language_model.get_output_embeddings()
246
+
247
+ def set_output_embeddings(self, new_embeddings):
248
+ self.language_model.set_output_embeddings(new_embeddings)
249
+
250
+ def set_decoder(self, decoder):
251
+ self.language_model.set_decoder(decoder)
252
+
253
+ def get_decoder(self):
254
+ return self.language_model.get_decoder()
255
+
256
+ def _update_causal_mask(
257
+ self,
258
+ attention_mask,
259
+ token_type_ids,
260
+ past_key_values,
261
+ cache_position,
262
+ input_tensor,
263
+ is_training: bool = False,
264
+ ):
265
+ if self.config.text_config._attn_implementation == "flash_attention_2":
266
+ return attention_mask
267
+
268
+ if attention_mask is not None and attention_mask.dim() == 4:
269
+ # In this case we assume that the mask comes already in inverted
270
+ # form and requires no inversion or slicing.
271
+ return attention_mask
272
+
273
+ using_static_cache = isinstance(past_key_values, StaticCache)
274
+ min_dtype = torch.finfo(self.dtype).min
275
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
276
+ if using_static_cache:
277
+ target_length = past_key_values.get_max_cache_shape()
278
+ elif isinstance(past_key_values, HybridCache):
279
+ target_length = past_key_values.get_max_cache_shape()
280
+ else:
281
+ target_length = (
282
+ attention_mask.shape[-1]
283
+ if isinstance(attention_mask, torch.Tensor)
284
+ else cache_position[0] + sequence_length + 1
285
+ )
286
+
287
+ if attention_mask is not None and attention_mask.dim() == 4:
288
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
289
+ return attention_mask
290
+
291
+ causal_mask = torch.full(
292
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
293
+ )
294
+
295
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
296
+ if sequence_length != 1:
297
+ causal_mask = torch.triu(causal_mask, diagonal=1)
298
+
299
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
300
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
301
+
302
+ # Apply bidirectional mask on images if token type ids are provided
303
+ if token_type_ids is not None and sequence_length != 1:
304
+ token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
305
+ token_type_mask[token_type_ids == 0] = False # if text token do not change anything
306
+ token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
307
+ causal_mask = causal_mask.clone()
308
+ causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
309
+ token_type_mask, 0.0
310
+ )
311
+
312
+ if attention_mask is not None:
313
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
314
+ mask_length = attention_mask.shape[-1]
315
+
316
+ # Then apply padding mask (will mask pad tokens)
317
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
318
+ padding_mask = padding_mask == 0
319
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
320
+ padding_mask, min_dtype
321
+ )
322
+
323
+ return causal_mask
324
+
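The token-type handling in _update_causal_mask lets image (and, by the same mechanism, audio) spans attend bidirectionally inside an otherwise causal mask. A toy reproduction of just that masking logic, detached from the model (only torch assumed):

    import torch

    min_dtype = torch.finfo(torch.float32).min
    token_type_ids = torch.tensor([[0, 1, 1, 0]])   # 1 marks image tokens
    seq = token_type_ids.shape[1]

    causal = torch.triu(torch.full((seq, seq), min_dtype), diagonal=1)[None, None, :, :].clone()
    token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
    token_type_mask[token_type_ids == 0] = False    # text tokens keep the causal pattern
    causal.masked_fill_(token_type_mask.unsqueeze(1), 0.0)

    print((causal[0, 0] == 0).int())
    # rows 1 and 2 (the image span) can attend to each other in both directions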
325
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
326
+ """
327
+ Projects the last hidden state from the vision model into language model space.
328
+
329
+ Args:
330
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
331
+ The tensors corresponding to the input images.
332
+ Returns:
333
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
334
+ """
335
+ vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
336
+ image_features = self.multi_modal_projector(vision_outputs)
337
+ return image_features
338
+
339
+ def get_audio_features(self, input_audio_embeds: torch.FloatTensor, audio_attention_mask: torch.FloatTensor, audio_embed_sizes: torch.FloatTensor):
340
+ """
341
+ Projects the last hidden state from the audio model into language model space.
342
+
343
+ Args:
344
+ audio_inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_dim)`)
345
+ The tensors corresponding to the input audio features.
346
+
347
+ Returns:
348
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`).
349
+ """
350
+ audio_features, masks = self.audio_tower(input_audio_embeds, audio_attention_mask)
351
+ audio_outputs = self.audio_projector(audio_features)
352
+ return audio_outputs
353
+
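A shape-level sketch of the projection path in get_audio_features, with purely illustrative sizes (the real values come from config.audio_config.attention_dim and config.text_config.hidden_size, which are not assumed here):

    import torch
    import torch.nn as nn

    attention_dim, hidden_size = 1024, 2560           # hypothetical sizes for illustration
    audio_projector = nn.Sequential(
        nn.Linear(attention_dim, hidden_size),
        nn.GELU(approximate='none'),
        nn.Linear(hidden_size, hidden_size),
    )
    encoder_out = torch.randn(2, 120, attention_dim)  # (batch, audio frames, attention_dim)
    print(audio_projector(encoder_out).shape)         # torch.Size([2, 120, 2560])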
354
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
355
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
356
+ @replace_return_docstrings(output_type=Gemma3OmniCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
357
+ def forward(
358
+ self,
359
+ input_ids: Optional[torch.LongTensor] = None,
360
+ pixel_values: Optional[torch.FloatTensor] = None,
361
+ input_audio_embeds: torch.FloatTensor = None,
362
+ audio_embed_sizes: torch.FloatTensor = None,
363
+ audio_attention_mask: torch.FloatTensor = None,
364
+ attention_mask: Optional[torch.Tensor] = None,
365
+ input_modes: torch.LongTensor = None,
366
+ position_ids: Optional[torch.LongTensor] = None,
367
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
368
+ token_type_ids: Optional[torch.LongTensor] = None,
369
+ cache_position: Optional[torch.LongTensor] = None,
370
+ inputs_embeds: Optional[torch.FloatTensor] = None,
371
+ labels: Optional[torch.LongTensor] = None,
372
+ use_cache: Optional[bool] = None,
373
+ output_attentions: Optional[bool] = None,
374
+ output_hidden_states: Optional[bool] = None,
375
+ return_dict: Optional[bool] = None,
376
+ logits_to_keep: Union[int, torch.Tensor] = 0,
377
+ **lm_kwargs,
378
+ ) -> Union[Tuple, Gemma3OmniCausalLMOutputWithPast]:
379
+ r"""
380
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
381
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
382
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
383
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
384
+
385
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
386
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
387
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
388
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
389
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
390
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
391
+
392
+ Returns:
393
+
394
+ Example:
395
+
396
+ ```python
397
+ >>> from PIL import Image
398
+ >>> import requests
399
+ >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
400
+
401
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
402
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
403
+
404
+ >>> messages = [
405
+ ... {
406
+ ... "role": "system",
407
+ ... "content": [
408
+ ... {"type": "text", "text": "You are a helpful assistant."}
409
+ ... ]
410
+ ... },
411
+ ... {
412
+ ... "role": "user", "content": [
413
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
414
+ ... {"type": "text", "text": "Where is the cat standing?"},
415
+ ... ]
416
+ ... },
417
+ ... ]
418
+
419
+ >>> inputs = processor.apply_chat_template(
420
+ ... messages,
421
+ ... tokenizer=True,
422
+ ... return_dict=True,
423
+ ... return_tensors="pt",
424
+ ... add_generation_prompt=True
425
+ ... )
426
+ >>> # Generate
427
+ >>> generate_ids = model.generate(**inputs)
428
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
429
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
430
+ ```
431
+ """
432
+
433
+ if (input_ids is None) ^ (inputs_embeds is not None):
434
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
435
+
436
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
437
+ output_hidden_states = (
438
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
439
+ )
440
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
441
+
442
+ if isinstance(input_modes, torch.Tensor):
443
+ # len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
444
+ input_modes = input_modes.unique()
445
+ if len(input_modes) != 1:
446
+ raise ValueError("Elements of input_modes should have the same value")
447
+
448
+ input_mode = InputMode(input_modes.item())
449
+
450
+ if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
451
+ self.unset_lora_adapter()
452
+ #self.set_lora_adapter('vision')
453
+ #audio_projection_mode = 'vision'
454
+ elif input_mode == InputMode.SPEECH:
455
+ self.unset_lora_adapter()
456
+ self.set_lora_adapter('speech')
457
+ #audio_projection_mode = 'speech'
458
+ elif input_mode == InputMode.LANGUAGE:
459
+ self.unset_lora_adapter()
460
+ self.set_lora_adapter('text')
461
+
462
+ #audio_projection_mode = 'speech'
463
+ else:
464
+ raise ValueError(f"Invalid input_mode: {input_mode}")
465
+
466
+ is_training = token_type_ids is not None and labels is not None
467
+
468
+ # Replace the image/audio token id with PAD if it is OOV, to avoid index errors
469
+ if input_ids is not None and (self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size):
470
+ special_image_mask = input_ids == self.config.image_token_index
471
+ special_audio_mask = input_ids == self.config.audio_token_index
472
+ llm_input_ids = input_ids.clone()
473
+ llm_input_ids[special_image_mask] = 0
474
+ llm_input_ids[special_audio_mask] = 0
475
+ else:
476
+ llm_input_ids = input_ids
477
+
478
+ if inputs_embeds is None:
479
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
480
+ inputs_embeds = inputs_embeds.to(dtype=self.dtype)
481
+ if cache_position is None:
482
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
483
+ cache_position = torch.arange(
484
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
485
+ )
486
+
487
+ if position_ids is None:
488
+ position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
489
+
490
+ # Merge text and images
491
+ if pixel_values is not None:
492
+ image_features = self.get_image_features(pixel_values)
493
+
494
+ if input_ids is None:
495
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
496
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
497
+ )
498
+ else:
499
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
500
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
501
+
502
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
503
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
504
+ raise ValueError(
505
+ f"Number of images does not match number of special image tokens in the input text. "
506
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
507
+ "tokens from image embeddings."
508
+ )
509
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
510
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
511
+
512
+ # Merge text and audios
513
+ if input_audio_embeds is not None:
514
+ input_audio_embeds=input_audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
515
+ if audio_attention_mask is not None:
516
+ audio_attention_mask=audio_attention_mask.to(inputs_embeds.device, inputs_embeds.dtype)
517
+ audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
518
+ if input_ids is None:
519
+ special_audio_mask = inputs_embeds == self.get_input_embeddings()(
520
+ torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
521
+ )
522
+ else:
523
+ special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
524
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
525
+ masked_audio_features = []
526
+ for i, size in enumerate(audio_embed_sizes):
527
+ masked_audio_features.append(audio_features[i, :size, :])
528
+ masked_audio_features = torch.cat(masked_audio_features, dim=0)
529
+
530
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel():
531
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
532
+ masked_audio_size = audio_embed_sizes#.sum()[0]
533
+ raise ValueError(
534
+ f"Number of audio does not match number of special audio tokens in the input text. "
535
+ f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_size} "
536
+ "tokens from audio embeddings. "
537
+ f"{masked_audio_features.numel()} \n"
538
+ f"{inputs_embeds[special_audio_mask].numel()} \n"
539
+ f"{audio_features} \n"
540
+ f"{inputs_embeds[special_audio_mask]} \n"
541
+ f"{special_audio_mask} \n"
542
+ )
543
+ masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
544
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)
545
+ # mask out pad-token-ids in labels for BC
546
+ if labels is not None and self.pad_token_id in labels:
547
+ logger.warning_once(
548
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
549
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
550
+ )
551
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
552
+
553
+ causal_mask = self._update_causal_mask(
554
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
555
+ )
556
+ outputs = self.language_model(
557
+ attention_mask=causal_mask,
558
+ position_ids=position_ids,
559
+ past_key_values=past_key_values,
560
+ inputs_embeds=inputs_embeds,
561
+ use_cache=use_cache,
562
+ output_attentions=output_attentions,
563
+ output_hidden_states=output_hidden_states,
564
+ return_dict=return_dict,
565
+ cache_position=cache_position,
566
+ logits_to_keep=logits_to_keep,
567
+ **lm_kwargs,
568
+ )
569
+
570
+ logits = outputs.logits
571
+ loss = None
572
+ # print('#############################')
573
+ # print(logits)
574
+ if labels is not None:
575
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
576
+ logits = logits.float()
577
+ shift_logits = logits[..., :-1, :]
578
+ shift_labels = labels[..., 1:]
579
+ if attention_mask is not None:
580
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
581
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
582
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
583
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
584
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
585
+ else:
586
+ shift_logits = shift_logits.contiguous()
587
+ shift_labels = shift_labels.contiguous()
588
+ # Flatten the tokens
589
+ loss_fct = nn.CrossEntropyLoss()
590
+
591
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
592
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
593
+ loss = loss_fct(flat_logits, flat_labels)
594
+ # print('flat logits',flat_logits)
595
+ # print(flat_labels)
596
+ # print(loss)
597
+ if not return_dict:
598
+ output = (logits,) + outputs[1:]
599
+ return (loss,) + output if loss is not None else output
600
+
601
+ return Gemma3OmniCausalLMOutputWithPast(
602
+ loss=loss,
603
+ logits=logits,
604
+ past_key_values=outputs.past_key_values,
605
+ hidden_states=outputs.hidden_states,
606
+ attentions=outputs.attentions,
607
+ image_hidden_states=image_features if pixel_values is not None else None,
608
+ audio_hidden_states=audio_features if input_audio_embeds is not None else None,
609
+ )
610
+
611
+ def prepare_inputs_for_generation(
612
+ self,
613
+ input_ids,
614
+ past_key_values=None,
615
+ input_modes=None,
616
+ inputs_embeds=None,
617
+ cache_position=None,
618
+ position_ids=None,
619
+ pixel_values=None,
620
+ input_audio_embeds=None,
621
+ audio_embed_sizes=None,
622
+ audio_attention_mask=None,
623
+ attention_mask=None,
624
+ token_type_ids=None,
625
+ use_cache=True,
626
+ logits_to_keep=None,
627
+ labels=None,
628
+ **kwargs,
629
+ ):
630
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
631
+ model_inputs = self.language_model.prepare_inputs_for_generation(
632
+ input_ids,
633
+ past_key_values=past_key_values,
634
+ input_modes=input_modes,
635
+ inputs_embeds=inputs_embeds,
636
+ attention_mask=attention_mask,
637
+ position_ids=position_ids,
638
+ cache_position=cache_position,
639
+ use_cache=use_cache,
640
+ logits_to_keep=logits_to_keep,
641
+ token_type_ids=token_type_ids,
642
+ **kwargs,
643
+ )
644
+
645
+ # position_ids in Gemma3 are 1-indexed
646
+ if model_inputs.get("position_ids") is not None:
647
+ model_inputs["position_ids"] += 1
648
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
649
+ # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
650
+ if cache_position[0] == 0:
651
+ model_inputs["pixel_values"] = pixel_values
652
+ model_inputs["input_audio_embeds"] = input_audio_embeds
653
+ model_inputs["audio_embed_sizes"] = audio_embed_sizes
654
+ model_inputs["audio_attention_mask"] = audio_attention_mask
655
+ model_inputs["input_modes"] = input_modes
656
+ is_training = token_type_ids is not None and labels is not None
657
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
658
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
659
+ causal_mask = self._update_causal_mask(
660
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
661
+ )
662
+ model_inputs["attention_mask"] = causal_mask
663
+
664
+ return model_inputs
665
+
666
+ def tie_weights(self):
667
+ return self.language_model.tie_weights()
668
+
preprocessing_gemma3omni.py ADDED
@@ -0,0 +1,444 @@
1
+ import re
2
+ from typing import List, Optional, Union, Tuple
3
+ from math import ceil
4
+
5
+ import numpy as np
6
+ import torch
7
+ import scipy
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from enum import Enum
11
+
12
+ from transformers import AutoFeatureExtractor
13
+ from transformers.feature_extraction_utils import BatchFeature
14
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
15
+ from transformers.image_utils import ImageInput, make_nested_list_of_images
16
+ from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
17
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
+ from transformers.utils import to_py_obj, TensorType
19
+ from transformers.audio_utils import AudioInput
20
+
21
+
22
+ class Gemma3ImagesKwargs(ImagesKwargs):
23
+ do_pan_and_scan: Optional[bool]
24
+ pan_and_scan_min_crop_size: Optional[int]
25
+ pan_and_scan_max_num_crops: Optional[int]
26
+ pan_and_scan_min_ratio_to_activate: Optional[float]
27
+ do_convert_rgb: Optional[bool]
28
+
29
+
30
+ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
31
+ images_kwargs: Gemma3ImagesKwargs
32
+ _defaults = {
33
+ "text_kwargs": {
34
+ "padding": False,
35
+ },
36
+ "images_kwargs": {
37
+ "do_pan_and_scan": False,
38
+ "pan_and_scan_min_crop_size": 256,
39
+ "pan_and_scan_max_num_crops": 4,
40
+ "pan_and_scan_min_ratio_to_activate": 1.2,
41
+ },
42
+ }
43
+
44
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
45
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
46
+
47
+ Args:
48
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
49
+ n_fft (int): FFT size. int > 0 [scalar]
50
+ n_mels (int): Mel filter size. int > 0 [scalar]
51
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
52
+ float >= 0 [scalar]
53
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
54
+ float >= 0 [scalar]
55
+
56
+ Returns
57
+ out (numpy.ndarray): Mel transform matrix
58
+ [shape=(n_mels, 1 + n_fft/2)]
59
+ """
60
+
61
+ bank_width = int(n_fft // 2 + 1)
62
+ if fmax is None:
63
+ fmax = sample_rate / 2
64
+ if fmin is None:
65
+ fmin = 0
66
+ assert fmin >= 0, "fmin cannot be negative"
67
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
68
+
69
+ def mel(f):
70
+ return 1127.0 * np.log(1.0 + f / 700.0)
71
+
72
+ def bin2mel(fft_bin):
73
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
74
+
75
+ def f2bin(f):
76
+ return int((f * n_fft / sample_rate) + 0.5)
77
+
78
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
79
+ klo = f2bin(fmin) + 1
80
+ khi = f2bin(fmax)
81
+
82
+ khi = max(khi, klo)
83
+
84
+ # Spec 2: SpeechLib uses triangles in Mel space
85
+ mlo = mel(fmin)
86
+ mhi = mel(fmax)
87
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
88
+ ms = (mhi - mlo) / (n_mels + 1)
89
+
90
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
91
+ for m in range(0, n_mels):
92
+ left = m_centers[m]
93
+ center = m_centers[m + 1]
94
+ right = m_centers[m + 2]
95
+ for fft_bin in range(klo, khi):
96
+ mbin = bin2mel(fft_bin)
97
+ if left < mbin < right:
98
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
99
+
100
+ return matrix
101
+
102
+
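# --- Illustrative sketch (reviewer note, not part of the uploaded file) ---
# How the matrix returned above is applied in _extract_features below: for
# 16 kHz audio and a 512-point FFT, speechlib_mel(16000, 512, 80) has shape
# (80, 257), so for a hypothetical magnitude spectrogram `spec` of shape (T, 257):
#     fbank = np.clip((spec ** 2).dot(speechlib_mel(16000, 512, 80).T), 1.0, None)
#     log_fbank = np.log(fbank)   # (T, 80) log-Mel features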
103
+ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
104
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
105
+
106
+ def __init__(self, audio_compression_rate=8,
107
+ audio_downsample_rate=1,
108
+ audio_feat_stride=1,
109
+ feature_size = 80,
110
+ sampling_rate = 16000,
111
+ padding_value = 0.0,
112
+ **kwargs):
113
+
114
+ super().__init__(feature_size=feature_size,
115
+ sampling_rate=sampling_rate,
116
+ padding_value=padding_value, **kwargs)
117
+
118
+ self.compression_rate = audio_compression_rate
119
+ self.qformer_compression_rate = audio_downsample_rate
120
+ self.feat_stride = audio_feat_stride
121
+
122
+ self._eightk_method = "fillzero"
123
+ self._mel = speechlib_mel(self.sampling_rate, 512, self.feature_size, fmin=None, fmax=self.sampling_rate//2-self.feature_size-230).T
124
+
125
+ self._hamming400 = np.hamming(400) # for 16k audio
126
+ self._hamming200 = np.hamming(200) # for 8k audio
127
+
128
+ def duration_to_frames(self, duration):
129
+ """duration in s, estimated frames"""
130
+ frame_rate = 10
131
+
132
+ num_frames = duration * 1000 // frame_rate
133
+ return num_frames
134
+
135
+ def __call__(
136
+ self,
137
+ audios: List[AudioInput],
138
+ sampling_rate = 16000,
139
+ return_attention_mask=True,
140
+ padding="max_length",
141
+ return_tensors: Optional[Union[str, TensorType]] = None,
142
+ ):
143
+ # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
144
+ returned_input_audio_embeds = []
145
+ returned_audio_embed_sizes = []
146
+ audio_frames_list = []
147
+
148
+ for audio_data in audios:
149
+ audio_embeds = self._extract_features(audio_data, sampling_rate)
150
+ audio_frames = len(audio_embeds) * self.feat_stride
151
+ audio_embed_size = self._compute_audio_embed_size(audio_frames)
152
+
153
+ returned_input_audio_embeds.append(torch.tensor(audio_embeds))
154
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
155
+ audio_frames_list.append(audio_frames)
156
+
157
+ returned_input_audio_embeds = pad_sequence(
158
+ returned_input_audio_embeds, batch_first=True
159
+ )
160
+ returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
161
+ audio_frames = torch.tensor(audio_frames_list)
162
+ returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
163
+
164
+ data = {
165
+ "input_audio_embeds": returned_input_audio_embeds,
166
+ "audio_embed_sizes": returned_audio_embed_sizes,
167
+ }
168
+ if returned_audio_attention_mask is not None and return_attention_mask:
169
+ data["audio_attention_mask"] = returned_audio_attention_mask
170
+
171
+ return BatchFeature(data=data, tensor_type=return_tensors)
172
+
173
+ def _extract_spectrogram(self, wav, fs):
174
+ """Extract spectrogram features from waveform.
175
+ Args:
176
+ wav (1D array): waveform of the input
177
+ fs (int): sampling rate of the waveform, 16000 or 8000.
178
+ If fs=8000, the waveform will be resampled to 16000Hz.
179
+ Output:
180
+ spec (2D array): a TxF magnitude spectrogram, with F=257 frequency bins
181
+ (8 kHz input is zero-padded to the same layout); T is the number of frames.
182
+ """
183
+ if wav.ndim > 1:
184
+ wav = np.squeeze(wav)
185
+
186
+ # by default, we extract the mean if stereo
187
+ if len(wav.shape) == 2:
188
+ wav = wav.mean(1)
189
+
190
+ # Resample to 16000 or 8000 if needed
191
+ if fs > 16000:
192
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
193
+ fs = 16000
194
+ elif 8000 < fs < 16000:
195
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
196
+ fs = 8000
197
+ elif fs < 8000:
198
+ raise RuntimeError(f"Unsupported sample rate {fs}")
199
+
200
+ if fs == 8000:
201
+ if self._eightk_method == "resample":
202
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
203
+ # extraction
204
+ wav = scipy.signal.resample_poly(wav, 2, 1)
205
+ fs = 16000
206
+ # Do nothing here for fillzero method
207
+ elif fs != 16000:
208
+ # Input audio is not a supported sample rate.
209
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
210
+
211
+ preemphasis = 0.97
212
+
213
+ if fs == 8000:
214
+ n_fft = 256
215
+ win_length = 200
216
+ hop_length = 80
217
+ fft_window = self._hamming200
218
+ elif fs == 16000:
219
+ n_fft = 512
220
+ win_length = 400
221
+ hop_length = 160
222
+ fft_window = self._hamming400
223
+
224
+ # Spec 1: SpeechLib drops trailing samples that do not fill a complete hop
225
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
226
+ # Here we don't use stride_tricks since the input array may not satisfy
227
+ # the memory layout requirement and we need a writeable output.
228
+ # We only build a list of views before copying to the destination,
229
+ # so it is more efficient than broadcasting.
230
+ y_frames = np.array(
231
+ [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
232
+ dtype=np.float32,
233
+ )
234
+
235
+ # Spec 2: SpeechLib applies preemphasis within each batch
236
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
237
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
238
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
239
+
240
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
241
+
242
+ if fs == 8000:
243
+ # Need to pad the output to look like 16 kHz data but with zeros in
244
+ # the 4 to 8 kHz bins.
245
+ frames, bins = S.shape
246
+ padarray = np.zeros((frames, bins))
247
+ S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
248
+
249
+ spec = np.abs(S).astype(np.float32)
250
+ return spec
251
+
252
+ def _extract_features(self, wav, fs):
253
+ """Extract log filterbank features from waveform.
254
+ Args:
255
+ wav (1D array): waveform of the input
256
+ fs (int): sampling rate of the waveform, 16000 or 8000.
257
+ If fs=8000, the waveform will be resampled to 16000Hz.
258
+ Output:
259
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
260
+ D=80, and T is the number of frames.
261
+ """
262
+ spec = self._extract_spectrogram(wav, fs)
263
+ spec_power = spec**2
264
+
265
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
266
+ log_fbank = np.log(fbank_power).astype(np.float32)
267
+
268
+ return log_fbank
269
+
270
+ def _compute_audio_embed_size(self, audio_frames):
271
+ integer = audio_frames // self.compression_rate
272
+ remainder = audio_frames % self.compression_rate
273
+
274
+ result = integer if remainder == 0 else integer + 1
275
+
276
+ integer = result // self.qformer_compression_rate
277
+ remainder = result % self.qformer_compression_rate
278
+ result = integer if remainder == 0 else integer + 1 # qformer compression
279
+
280
+ return result
281
+
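# --- Illustrative worked example (reviewer note, not part of the uploaded file) ---
# A 10 s clip at 16 kHz yields (160000 - 400) // 160 + 1 = 998 mel frames; with
# audio_feat_stride=1, audio_compression_rate=8 and audio_downsample_rate=1 the
# method above returns ceil(998 / 8) = 125, i.e. 125 audio soft tokens are
# reserved in the prompt for that clip.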
282
+ class Gemma3OmniProcessor(ProcessorMixin):
283
+ attributes = ["image_processor", "feature_extractor", "tokenizer"]
284
+ valid_kwargs = ["chat_template", "image_seq_length"]
285
+ image_processor_class = "AutoImageProcessor"
286
+ feature_extractor_class = "Gemma3AudioFeatureExtractor"
287
+ tokenizer_class = "AutoTokenizer"
288
+
289
+ def __init__(
290
+ self,
291
+ image_processor,
292
+ feature_extractor,
293
+ tokenizer,
294
+ chat_template=None,
295
+ image_seq_length: int = 256,
296
+ **kwargs,
297
+ ):
298
+ self.image_seq_length = image_seq_length
299
+ self.image_token_id = tokenizer.image_token_id
300
+ self.boi_token = tokenizer.boi_token
301
+ self.image_token = tokenizer.image_token
302
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
303
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
304
+
305
+ self.audio_token_id = tokenizer.audio_token_id
306
+ self.boa_token = tokenizer.boa_token
307
+ self.eoa_token = tokenizer.eoa_token
308
+ self.audio_token = tokenizer.audio_token
309
+
310
+ super().__init__(
311
+ image_processor=image_processor,
312
+ feature_extractor=feature_extractor,
313
+ tokenizer=tokenizer,
314
+ chat_template=chat_template,
315
+ **kwargs,
316
+ )
317
+
318
+ def __call__(
319
+ self,
320
+ images: ImageInput = None,
321
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
322
+ videos=None,
323
+ audio: List[AudioInput] = None,
324
+ **kwargs: Unpack[Gemma3ProcessorKwargs],
325
+ ) -> BatchFeature:
326
+ if text is None and images is None:
327
+ raise ValueError("Provide at least one of `text` or `images`.")
328
+
329
+ output_kwargs = self._merge_kwargs(
330
+ Gemma3ProcessorKwargs,
331
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
332
+ **kwargs,
333
+ )
334
+
335
+ if isinstance(text, str):
336
+ text = [text]
337
+ elif not isinstance(text, list) and not isinstance(text[0], str):
338
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
339
+
340
+ image_inputs = {}
341
+ if images is not None:
342
+ batched_images = make_nested_list_of_images(images)
343
+ image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
344
+
345
+ # Create empty text to be replaced with placeholders
346
+ if not text:
347
+ text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
348
+
349
+ if len(batched_images) != len(text):
350
+ raise ValueError(
351
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
352
+ )
353
+
354
+ # Replace image tokens by the full expanded sequence
355
+ num_crops = to_py_obj(image_inputs.pop("num_crops"))
356
+ batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
357
+ for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
358
+ image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
359
+
360
+ if len(images) != len(image_indexes):
361
+ raise ValueError(
362
+ f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
363
+ )
364
+
365
+ # Insert additional image tokens for Pan-and-Scan crops
366
+ for num, idx in reversed(list(zip(num_crops, image_indexes))):
367
+ if num:
368
+ formatted_image_text = (
369
+ f"Here is the original image {self.boi_token} and here are some crops to help you see better "
370
+ + " ".join([self.boi_token] * num)
371
+ )
372
+ prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
373
+ text[batch_idx] = prompt
374
+
375
+ # Expand placeholder image tokens to the full image token sequence
376
+ text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
377
+
378
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
379
+
380
+ audio_inputs = {}
381
+ if audio is not None:
382
+ full_audio_sequences = []
383
+ audio_inputs = self.feature_extractor(audio)
384
+ def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
385
+ parts = prompt.split(boa_token)
386
+ result = ""
387
+ for i in range(len(parts) - 1):
388
+ result += parts[i]
389
+ if i < len(audio_sequences):
390
+ result += audio_sequences[i]
391
+ else:
392
+ result += boa_token
393
+ result += parts[-1]
394
+ return result
395
+ for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
396
+ audio_tokens_expanded = "".join([self.audio_token] * embed_size)
397
+ full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
398
+ full_audio_sequences.append(full_audio_sequence)
399
+
400
+ text = [replace_tokens_sequentially(prompt, self.boa_token, [audio_sequences]) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
401
+ #text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
402
+
403
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
404
+
405
+ # Add token type ids manually, as tokenizer can't do arbitrary position token types
406
+ array_ids = text_inputs["input_ids"]
407
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
408
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
409
+ mm_token_type_ids[array_ids == self.audio_token_id] = 2
410
+
411
+ has_vision_ids = np.any(mm_token_type_ids == 1, axis=1)
412
+ has_audio_ids = np.any(mm_token_type_ids == 2, axis=1)
413
+
414
+ input_modes = (has_audio_ids << 1) | has_vision_ids
415
+
416
+ text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
417
+ text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
418
+ text_inputs["input_modes"] = input_modes.tolist()
419
+
420
+ return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
421
+
422
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
423
+ def batch_decode(self, *args, **kwargs):
424
+ """
425
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
426
+ refer to the docstring of this method for more information.
427
+ """
428
+ return self.tokenizer.batch_decode(*args, **kwargs)
429
+
430
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
431
+ def decode(self, *args, **kwargs):
432
+ """
433
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
434
+ the docstring of this method for more information.
435
+ """
436
+ return self.tokenizer.decode(*args, **kwargs)
437
+
438
+ @property
439
+ def model_input_names(self):
440
+ tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
441
+ image_processor_input_names = self.image_processor.model_input_names
442
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
443
+
444
+ AutoFeatureExtractor.register("Gemma3AudioFeatureExtractor", Gemma3AudioFeatureExtractor)
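# --- Illustrative usage sketch (reviewer note, not part of the uploaded file) ---
# Rough outline of how this processor is meant to be called once the repo is
# available locally; the path and the silent waveform are placeholders.
#
#     import numpy as np
#     from transformers import AutoProcessor
#
#     processor = AutoProcessor.from_pretrained("/path/to/gemma-3-4b-it-omni",
#                                                trust_remote_code=True)
#     message = {"role": "user",
#                "content": "<start_of_audio>Transcribe the audio clip into text."}
#     prompt = processor.tokenizer.apply_chat_template(
#         [message], tokenize=False, add_generation_prompt=True)
#     wav = np.zeros(16000, dtype=np.float32)   # 1 s of silence at 16 kHz
#     inputs = processor(text=prompt, audio=[wav], return_tensors="pt")
#     # -> input_ids, token_type_ids, input_modes, input_audio_embeds, audio_embed_sizes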
preprocessor_config.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "audio_compression_rate": 8,
3
+ "audio_downsample_rate": 1,
4
+ "audio_feat_stride": 1,
5
+ "compression_rate": 8,
6
+ "do_convert_rgb": null,
7
+ "do_normalize": true,
8
+ "do_pan_and_scan": null,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feat_stride": 1,
12
+ "feature_extractor_type": "Gemma3AudioFeatureExtractor",
13
+ "feature_size": 80,
14
+ "image_mean": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "image_processor_type": "Gemma3ImageProcessor",
20
+ "image_seq_length": 256,
21
+ "image_std": [
22
+ 0.5,
23
+ 0.5,
24
+ 0.5
25
+ ],
26
+ "padding_side": "right",
27
+ "padding_value": 0.0,
28
+ "pan_and_scan_max_num_crops": null,
29
+ "pan_and_scan_min_crop_size": null,
30
+ "pan_and_scan_min_ratio_to_activate": null,
31
+ "processor_class": "Gemma3OmniProcessor",
32
+ "qformer_compression_rate": 1,
33
+ "resample": 2,
34
+ "rescale_factor": 0.00392156862745098,
35
+ "return_attention_mask": true,
36
+ "sampling_rate": 16000,
37
+ "size": {
38
+ "height": 896,
39
+ "width": 896
40
+ }
41
+ }
processor_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "preprocessing_gemma3omni.Gemma3OmniProcessor"
4
+ },
5
+ "image_seq_length": 256,
6
+ "processor_class": "Gemma3Processor"
7
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "audio_token": "<audio_soft_token>",
3
+ "boa_token": "<start_of_audio>",
4
+ "boi_token": "<start_of_image>",
5
+ "bos_token": {
6
+ "content": "<bos>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "eoa_token": "<end_of_audio>",
13
+ "eoi_token": "<end_of_image>",
14
+ "eos_token": {
15
+ "content": "<eos>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "image_token": "<image_soft_token>",
22
+ "pad_token": {
23
+ "content": "<pad>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false
28
+ },
29
+ "unk_token": {
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ }
36
+ }
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52941f2ba60fdcc48edb940f4252f6d874d0c369323dab293168015122e556be
3
+ size 33384559
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training.py ADDED
@@ -0,0 +1,883 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import sacrebleu
14
+
15
+ from datasets import load_dataset
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ AutoProcessor,
20
+ AutoModel,
21
+ BatchFeature,
22
+ Trainer,
23
+ TrainingArguments,
24
+ StoppingCriteria,
25
+ StoppingCriteriaList,
26
+ )
27
+ from collections import defaultdict
28
+
29
+ import soundfile as sf
30
+ from datasets import Audio
31
+ import random
32
+ ANSWER_SUFFIX = "<end_of_turn>"
33
+ _IGNORE_INDEX = -100
34
+ class BaseAudioDataset(Dataset):
35
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
36
+ self.processor = processor
37
+ self.training = "train" in split or 'other' in split
38
+ self.debug = debug
39
+ self.sampling_rate = sampling_rate
40
+ self.name = ""
41
+
42
+ def set_dataset_name(self, name):
43
+ self.name = name
44
+
45
+ @staticmethod
46
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
47
+ original_size = len(data)
48
+
49
+ data = data.cast_column(audio_field, Audio(decode=False))
50
+
51
+ def identify_corrupted_files(example):
52
+ try:
53
+ sf.read(example[audio_field]["path"])
54
+
55
+ for field in text_fields:
56
+ if field in example and example[field].replace('"', '') == "":
57
+ return False
58
+ return True
59
+ except Exception:
60
+ return False
61
+
62
+ data = data.filter(identify_corrupted_files, num_proc=16)
63
+ validated_size = len(data)
64
+
65
+ # Audio Decoding
66
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
67
+
68
+ if debug:
69
+ print(f"Dataset: {dataset_name}")
70
+ print(f"Original sample count: {original_size}")
71
+ print(f"Sample count after filtering: {validated_size}")
72
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
73
+
74
+ return data
75
+
76
+ @staticmethod
77
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
78
+ original_size = len(data)
79
+
80
+ def filter_audio_by_length(example):
81
+ try:
82
+ audio = example[audio_field]['array']
83
+ channel = 1
84
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
85
+ channel = audio.ndim
86
+ audio = audio.squeeze()
87
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
88
+ return min_sec <= audio_length <= max_sec
89
+ except Exception as e:
90
+ if debug:
91
+ print(f"Error : {str(e)[:100]}... - sample excluded")
92
+ return False
93
+
94
+ data = data.filter(filter_audio_by_length, num_proc=16)
95
+ filtered_size = len(data)
96
+
97
+ if debug:
98
+ print(f"Sample count before length filtering: {original_size}")
99
+ print(f"Sample count after length filtering: {filtered_size}")
100
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
101
+
102
+ return data
103
+
104
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
105
+ user_message = {
106
+ 'role': 'user',
107
+ 'content': '<start_of_audio>' + instruction,
108
+ }
109
+ prompt = self.processor.tokenizer.apply_chat_template(
110
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
111
+ )
112
+
113
+ inputs = self.processor(
114
+ text=prompt,
115
+ audio=[audio_array],
116
+ add_special_tokens=False,
117
+ return_tensors='pt'
118
+ )
119
+
120
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
121
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
122
+
123
+ if self.debug:
124
+ self.debug = False
125
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
126
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
127
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
128
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
129
+
130
+ if self.training:
131
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
132
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
133
+ labels[:, -answer_ids.shape[1]:] = answer_ids
134
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
135
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
136
+ else:
137
+ input_ids = inputs.input_ids
138
+ labels = answer_ids
139
+ token_type_ids = inputs.token_type_ids
140
+
141
+ return {
142
+ 'input_ids': input_ids,
143
+ 'labels': labels,
144
+ 'token_type_ids': token_type_ids,
145
+ 'input_audio_embeds': inputs.input_audio_embeds,
146
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
147
+ 'input_modes': inputs.input_modes,
148
+ }
149
+
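# --- Illustrative sketch (reviewer note, not part of the uploaded file) ---
# What the training branch above does to the labels: prompt positions are
# masked with _IGNORE_INDEX so the loss only covers the answer tokens.
# Token ids below are made up for the example.
_prompt_ids = torch.tensor([[5, 6, 7]])    # stands in for inputs.input_ids
_answer_ids = torch.tensor([[8, 9]])       # stands in for the tokenized answer
_example_input_ids = torch.cat([_prompt_ids, _answer_ids], dim=1)
_example_labels = torch.full_like(_example_input_ids, _IGNORE_INDEX)
_example_labels[:, -_answer_ids.shape[1]:] = _answer_ids
# _example_input_ids -> [[5, 6, 7, 8, 9]];  _example_labels -> [[-100, -100, -100, 8, 9]]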
150
+
151
+ # Libri Speech Dataset Class
152
+ class LibriSpeechDataset(BaseAudioDataset):
153
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
154
+ super().__init__(processor, split, sampling_rate, debug)
155
+
156
+ self.set_dataset_name(f"LibriSpeech_{subset}")
157
+ # only ASR
158
+ self.ast = False
159
+ self.lang = "en"
160
+
161
+ # load dataset
162
+ self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
163
+ subset,
164
+ split=split,
165
+ trust_remote_code=True,
166
+ cache_dir=Path("/mnt/jeff/InCar/data")
167
+ )
168
+
169
+ # (Optional) Audio length Filtering
170
+ self.data = self.filter_by_audio_length(self.data, "audio")
171
+
172
+ # Instruction Setting
173
+ self.instruction = random.choice(INSTRUCTION["asr"])
174
+
175
+ def __len__(self):
176
+ return len(self.data)
177
+
178
+ def __getitem__(self, idx):
179
+ data = self.data[idx]
180
+
181
+ # Libri Speech is only for ASR
182
+ answer_text = data["text"].replace('"', '')
183
+
184
+ return self.prepare_model_inputs(
185
+ data["audio"]["array"],
186
+ self.instruction,
187
+ answer_text
188
+ )
189
+
190
+ # common_voice_16_1 dataset
191
+ class CommonVoiceDataset(BaseAudioDataset):
192
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
193
+ super().__init__(processor, split, sampling_rate, debug)
194
+
195
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
196
+ # only ASR
197
+ self.ast = False
198
+ self.lang=source_lang
199
+
200
+ # load dataset
201
+ if source_lang=="zh-TW":
202
+ data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
203
+ else:
204
+ data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
205
+ self.data = load_dataset(data_path,
206
+ source_lang,
207
+ split=split,
208
+ trust_remote_code=True,
209
+ cache_dir=Path("/mnt/jeff/InCar/data")
210
+ )
211
+ def prepare_dataset(batch):
212
+ """Function to preprocess the dataset with the .map method"""
213
+ transcription = batch["sentence"]
214
+
215
+ if transcription.startswith('"') and transcription.endswith('"'):
216
+ # we can remove surrounding quotation marks as they do not affect the transcription
217
+ transcription = transcription[1:-1]
218
+
219
+ if transcription[-1] not in [".", "?", "!"]:
220
+ # append a full-stop to sentences that do not end in punctuation
221
+ transcription = transcription + "."
222
+
223
+ batch["sentence"] = transcription
224
+
225
+ return batch
226
+
227
+
228
+ import opencc
229
+ converter = opencc.OpenCC('s2tw.json')
230
+ def To_zhTW(batch):
231
+
232
+ transcription = converter.convert(batch["sentence"])
233
+ batch["sentence"] = transcription
234
+
235
+ return batch
236
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
237
+ if source_lang=='zh-CN':
238
+ self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
239
+
240
+
241
+ # (Optional) Audio length Filtering
242
+ self.data = self.filter_by_audio_length(self.data, "audio")
243
+
244
+ if source_lang == "zh-TW" and split=='train':
245
+ import torchaudio
246
+ from torchaudio import transforms
247
+ import copy
248
+ import pickle
249
+ import os
250
+ def subsample(batch):
251
+ batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
252
+ batch['audio']['sampling_rate']=16000
253
+ return batch
254
+ def TW_data_augment_fast(batch):
255
+ speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
256
+ new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
257
+ batch['audio']['array'] = new_array_fast
258
+ return batch
259
+ def TW_data_augment_slow(batch):
260
+ speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
261
+ new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
262
+ batch['audio']['array'] = new_array_slow
263
+ return batch
264
+ # data = self.data.map(subsample, num_proc=1, desc="subsample")
265
+ fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
266
+ if not os.path.exists(fast_path):
267
+ data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
268
+ with open(fast_path,'wb') as f:
269
+ pickle.dump(data_fast,f)
270
+ else:
271
+ with open(fast_path,'rb') as f:
272
+ data_fast=pickle.load(f)
273
+
274
+ slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
275
+ if not os.path.exists(slow_path):
276
+ data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
277
+ with open(slow_path,'wb') as f:
278
+ pickle.dump(data_slow,f)
279
+ else:
280
+ with open(slow_path,'rb') as f:
281
+ data_slow=pickle.load(f)
282
+ self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
283
+
284
+ # Instruction Setting
285
+ self.instruction = random.choice(INSTRUCTION["asr"])
286
+
287
+ def __len__(self):
288
+ return len(self.data)
289
+
290
+ def __getitem__(self, idx):
291
+ data = self.data[idx]
292
+
293
+ answer_text = data["sentence"]
294
+ return self.prepare_model_inputs(
295
+ data["audio"]["array"],
296
+ self.instruction,
297
+ answer_text
298
+ )
299
+
300
+
301
+ # Fleurs Dataset Class
302
+ class FleursDataset(BaseAudioDataset):
303
+ def __init__(self, processor, split, source_lang, target_lang=None,
304
+ mode="asr", sampling_rate=16000, debug=False):
305
+ super().__init__(processor, split, sampling_rate, debug)
306
+
307
+ self.set_dataset_name("Fleurs")
308
+ # Mode Setting (ASR or AST)
309
+ if mode not in ["asr", "ast"]:
310
+ raise ValueError("mode must be 'asr' or 'ast'.")
311
+
312
+ self.mode = mode
313
+ self.ast = (mode == "ast")
314
+ self.source_lang = source_lang
315
+
316
+ # Language name mapping (expand if needed)
317
+ self.lang_names = {
318
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
319
+ }
320
+
321
+ # load dataset - source language dataset
322
+ self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
323
+ source_lang,
324
+ split=split,
325
+ trust_remote_code=True,
326
+ cache_dir=Path("/mnt/jeff/InCar/data")
327
+ )
328
+ import opencc
329
+ converter = opencc.OpenCC('s2tw.json')
330
+ def prepare_dataset(batch):
331
+ transcription = converter.convert(batch["transcription"])
332
+ batch["transcription"] = transcription
333
+
334
+ return batch
335
+ if (source_lang=="cmn_hans_cn"):
336
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
337
+
338
+ # (Optional) Audio length Filtering
339
+ self.data = self.filter_by_audio_length(self.data, "audio")
340
+ self.target_lang_name = ""
341
+ # When AST mode, load target language dataset.
342
+ if self.ast:
343
+ if target_lang is None:
344
+ raise ValueError("AST mode requires target_lang.")
345
+
346
+ self.target_lang = target_lang
347
+ self.lang = f"{source_lang}_{target_lang}"
348
+
349
+ # load dataset - target language dataset (for translation)
350
+ target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
351
+ target_lang,
352
+ split=split,
353
+ trust_remote_code=True,
354
+ cache_dir=Path("/mnt/jeff/InCar/data")
355
+ )
356
+ if target_lang=="cmn_hans_cn":
357
+ target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
358
+ source_dict = {item['id']: item for item in self.data}
359
+ target_dict = {item['id']: item for item in target_data}
360
+
361
+ # only Common ID, add translation fields
362
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
363
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
364
+ self.data = [
365
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
366
+ for id in common_ids
367
+ ]
368
+
369
+ # Instruction Setting - use target language name
370
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
371
+ self.instruction = random.choice(INSTRUCTION["ast"])
372
+ else:
373
+ # ASR mode
374
+ self.lang = source_lang
375
+ self.instruction = random.choice(INSTRUCTION["asr"])
376
+
377
+ if self.debug:
378
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
379
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
380
+ if self.ast:
381
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
382
+ print(f"dataset size: {len(self.data)}")
383
+
384
+ def __len__(self):
385
+ return len(self.data)
386
+
387
+ def __getitem__(self, idx):
388
+ data = self.data[idx]
389
+ audio_array = data["audio"]["array"]
390
+
391
+ if self.ast:
392
+ answer_text = data["translation"]
393
+ else:
394
+ answer_text = data["transcription"]
395
+
396
+ return self.prepare_model_inputs(
397
+ audio_array,
398
+ self.instruction.format(self.target_lang_name),
399
+ answer_text
400
+ )
401
+
402
+ def covost_collate_fn(batch):
403
+ input_ids_list = []
404
+ labels_list = []
405
+ token_type_ids_list = []
406
+ input_audio_embeds_list = []
407
+ audio_embed_sizes_list = []
408
+ audio_attention_mask_list = []
409
+ input_modes_list = []
410
+ for inputs in batch:
411
+ input_ids_list.append(inputs['input_ids'][0])
412
+ labels_list.append(inputs['labels'][0])
413
+ token_type_ids_list.append(inputs['token_type_ids'][0])
414
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
415
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
416
+ audio_attention_mask_list.append(
417
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
418
+ )
419
+ input_modes_list.append(inputs['input_modes'])
420
+
421
+ try:
422
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
423
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
424
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
425
+ audio_attention_mask = (
426
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
427
+ if len(audio_attention_mask_list) > 1
428
+ else None
429
+ )
430
+ except Exception as e:
431
+ print(e)
432
+ print(input_ids_list)
433
+ print(labels_list)
434
+ raise
435
+ attention_mask = (input_ids != 0).long()
436
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
437
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list)
438
+ input_modes = torch.cat(input_modes_list)
439
+
440
+ return BatchFeature(
441
+ {
442
+ 'input_ids': input_ids,
443
+ 'labels': labels,
444
+ 'token_type_ids': token_type_ids,
445
+ 'attention_mask': attention_mask,
446
+ 'input_audio_embeds': input_audio_embeds,
447
+ 'audio_embed_sizes': audio_embed_sizes,
448
+ 'audio_attention_mask': audio_attention_mask,
449
+ 'input_modes': input_modes,
450
+ }
451
+ )
452
+
453
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
454
+ """
455
+ Pad a list of sequences to the same length.
456
+ sequences: list of tensors in [seq_len, *] shape
457
+ """
458
+ assert padding_side in ['right', 'left']
459
+ max_size = sequences[0].size()
460
+ trailing_dims = max_size[1:]
461
+ max_len = max(len(seq) for seq in sequences)
462
+ batch_size = len(sequences)
463
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
464
+ for i, seq in enumerate(sequences):
465
+ length = seq.size(0)
466
+ if padding_side == 'right':
467
+ output.data[i, :length] = seq
468
+ else:
469
+ output.data[i, -length:] = seq
470
+ return output
471
+
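# --- Illustrative usage (reviewer note, not part of the uploaded file) ---
# Left padding as used by covost_collate_fn above for input_ids and labels:
_a = torch.tensor([1, 2, 3])
_b = torch.tensor([4, 5])
_padded = pad_sequence([_a, _b], padding_side='left', padding_value=0)
# _padded -> tensor([[1, 2, 3],
#                    [0, 4, 5]])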
472
+ def cat_with_pad(tensors, dim, padding_value=0):
473
+ """
474
+ cat along dim, while pad to max for all other dims
475
+ """
476
+ ndim = tensors[0].dim()
477
+ assert all(
478
+ t.dim() == ndim for t in tensors[1:]
479
+ ), 'All tensors must have the same number of dimensions'
480
+
481
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
482
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
483
+ output = tensors[0].new_full(out_size, padding_value)
484
+
485
+ index = 0
486
+ for t in tensors:
487
+ # Create a slice list where every dimension except dim is full slice
488
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
489
+ # Update only the concat dimension slice
490
+ slices[dim] = slice(index, index + t.shape[dim])
491
+
492
+ output[slices] = t
493
+ index += t.shape[dim]
494
+
495
+ return output
496
+
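# --- Illustrative usage (reviewer note, not part of the uploaded file) ---
# cat_with_pad concatenates along `dim` and zero-pads all other dims to the
# largest size, e.g. two audio feature tensors with different frame counts:
_x = torch.ones(1, 3, 4)
_y = torch.ones(1, 5, 4)
_z = cat_with_pad([_x, _y], dim=0)
# _z.shape -> torch.Size([2, 5, 4]); the 3-frame tensor is zero-padded to 5 frames.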
497
+ def count_parameters_by_module(model):
498
+ # dictionary of parameter counts per top-level module
499
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
500
+
501
+ # all params
502
+ total_params = 0
503
+ total_trainable_params = 0
504
+
505
+ # Check Embedding Token masks
506
+ embedding_masks = {}
507
+ for name, param in model.named_parameters():
508
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
509
+ # check whether this param has an embedding_grad_mask_hook registered
510
+ for hook_id, hook_fn in param._backward_hooks.items():
511
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
512
+ # Accessing mask variables in the closure of hook functions
513
+ for cell in hook_fn.__closure__ or []:
514
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
515
+ # check mask tensor
516
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
517
+
518
+ # Count params by modules
519
+ for name, param in model.named_parameters():
520
+ # extracts top module_name
521
+ module_name = name.split('.')[0]
522
+ param_count = param.numel()
523
+
524
+ module_params[module_name]["total"] += param_count
525
+ total_params += param_count
526
+
527
+ if param.requires_grad:
528
+ # Only count for real trainable params. (with masks)
529
+ if name in embedding_masks:
530
+ trainable_count = embedding_masks[name].sum().item()
531
+ module_params[module_name]["trainable"] += trainable_count
532
+ total_trainable_params += trainable_count
533
+ else:
534
+ module_params[module_name]["trainable"] += param_count
535
+ total_trainable_params += param_count
536
+
537
+ print(f"All Params: {total_params:,}")
538
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
539
+ print("\nParams by Module:")
540
+
541
+ for module_name, counts in sorted(module_params.items()):
542
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
543
+ total_percentage = counts["total"] / total_params * 100
544
+
545
+ print(f"- {module_name}:")
546
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
547
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
548
+
549
+ return module_params
550
+
551
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
552
+ model = AutoModel.from_pretrained(
553
+ model_name_or_path,
554
+ revision=revision,
555
+ torch_dtype=torch.bfloat16,
556
+ device_map="auto",
557
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
558
+ trust_remote_code=True,
559
+ )
560
+
561
+ # Set use_cache to False after model loaded
562
+ model.config.use_cache = False
563
+
564
+ # Freeze all parameters
565
+ for param in model.parameters():
566
+ param.requires_grad = False
567
+
568
+ model.set_lora_adapter('speech')
569
+ model.to(torch.bfloat16)
570
+
571
+ # (Optional) unfreeze audio_tower parameters
572
+ # for param in model.audio_tower.parameters():
573
+ # param.requires_grad = True
574
+
575
+ # Only unfreeze audio_projector parameters
576
+ for param in model.audio_projector.parameters():
577
+ param.requires_grad = True
578
+
579
+ # (Optional) unfreeze audio embed_tokens
580
+ train_embed = True
581
+ if train_embed:
582
+ embed_tokens = model.language_model.model.model.embed_tokens
583
+
584
+ embed_tokens.weight.requires_grad = False
585
+
586
+ # Added speech token IDs (only these tokens will be trainable)
587
+ trainable_token_ids = [256001, 256002]
588
+
589
+ embed_tokens.weight.requires_grad = True
590
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
591
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
592
+
593
+ # backward hook, with gradient masking
594
+ def embedding_grad_mask_hook(grad):
595
+ return grad.masked_fill(mask, 0)
596
+
597
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
598
+
599
+ model.language_model.model.model.embed_tokens = embed_tokens
600
+
601
+ count_parameters_by_module(model)
602
+
603
+ return model
604
+
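# --- Illustrative check (reviewer note, not part of the uploaded file) ---
# The hook registered in create_model zeroes gradients for every embedding row
# except the two speech-token ids, so after a backward pass one would expect:
#
#     emb = model.language_model.model.model.embed_tokens.weight
#     rows_with_grad = emb.grad.abs().sum(dim=1).nonzero().flatten().tolist()
#     assert set(rows_with_grad) <= {256001, 256002}
#
# (`model` and an already-computed backward pass are assumed here.)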
605
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
606
+
607
+ INSTRUCTION = {
608
+ "ast": [
609
+ "Translate the audio to {0}.",
610
+ "Translate the audio clip into {0}.",
611
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
612
+ "Translate the provided audio file into {0}.",
613
+ "Convert the audio speech to {0} text.",
614
+ "Write an {0} translation of the audio file.",
615
+ "Translate spoken words from the audio into {0}.",
616
+ "Create an {0} version of the audio content.",
617
+ "Produce an accurate {0} translation of the audio.",
618
+ "Extract speech from the audio and translate it to {0}.",
619
+ "Turn the audio into readable {0} text.",
620
+ "Write all spoken content from the audio in {0}.",
621
+ "Generate an {0} translation of the speech in the file.",
622
+ "Convert the recording into {0} text.",
623
+ "Accurately translate the audio recording to {0}.",
624
+ "Write down dialogue from the given audio in {0}.",
625
+ "Translate all speech in this audio file to {0}.",
626
+ "Create an accurate {0} version of the speech.",
627
+ "Perform a complete {0} translation of the audio."
628
+ ],
629
+ "asr": [
630
+ "Transcribe the audio clip into text.",
631
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
632
+ "Transcribe the provided audio file into text.",
633
+ "Convert the audio speech to text.",
634
+ "Write a transcript of the audio file.",
635
+ "Transcribe spoken words from the audio.",
636
+ "Create a text version of the audio content.",
637
+ "Produce a verbatim transcript of the audio.",
638
+ "Extract and transcribe speech from the audio.",
639
+ "Turn the audio into readable text.",
640
+ "Write all spoken words from the audio.",
641
+ "Generate a transcript of the speech in the file.",
642
+ "Convert the recording into a text transcript.",
643
+ "Accurately transcribe the audio recording.",
644
+ "Write down dialogue from the given audio.",
645
+ "Transcribe all speech in this audio file.",
646
+ "Create an accurate text version of the speech.",
647
+ "Perform a complete transcription of the audio."
648
+ ],
649
+ }
650
+
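# --- Illustrative usage (reviewer note, not part of the uploaded file) ---
# How the templates above are consumed by the dataset classes: ASR picks one
# string verbatim, AST additionally formats in the target-language name.
_asr_prompt = random.choice(INSTRUCTION["asr"])
_ast_prompt = random.choice(INSTRUCTION["ast"]).format("English")
# e.g. _ast_prompt -> "Translate the audio clip into English."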
651
+ ANSWER_SUFFIX = "<end_of_turn>"
652
+ _IGNORE_INDEX = -100
653
+
654
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
655
+ use_flash_attention = True
656
+
657
+ output_dir = '../gemma_tmp7'
658
+ batch_size = 128
659
+ batch_size_per_gpu = 16
660
+ learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
661
+ wd = 0.01
662
+ num_train_epochs = 15
663
+
664
+ revision = "main" #"v1.0"
665
+
666
+ processor = AutoProcessor.from_pretrained(
667
+ model_name_or_path,
668
+ revision=revision,
669
+ trust_remote_code=True,
670
+ )
671
+
672
+ model = create_model(
673
+ model_name_or_path,
674
+ revision=revision,
675
+ use_flash_attention=use_flash_attention,
676
+ )
677
+
678
+ train_datasets = []
679
+
680
+ # common voice asr
681
+ commonvoice_speech_tw2 = CommonVoiceDataset(
682
+ processor=processor,
683
+ source_lang="zh-TW",
684
+ split="other[:70%]"
685
+ )
686
+ train_datasets.append(commonvoice_speech_tw2)
687
+
688
+ commonvoice_speech_cn = CommonVoiceDataset(
689
+ processor=processor,
690
+ source_lang="zh-CN",
691
+ split="train[:50%]"
692
+ )
693
+ train_datasets.append(commonvoice_speech_cn)
694
+
695
+
696
+ commonvoice_speech_tw = CommonVoiceDataset(
697
+ processor=processor,
698
+ source_lang="zh-TW",
699
+ split="train"
700
+ )
701
+ train_datasets.append(commonvoice_speech_tw)
702
+
703
+
704
+
705
+
706
+ # Libri Speech Clean ASR mode (English -> English text)
707
+ libri_speech_clean = LibriSpeechDataset(
708
+ processor=processor,
709
+ subset="clean",
710
+ split="train.360[:50%]"
711
+ )
712
+ train_datasets.append(libri_speech_clean)
713
+
714
+
715
+ # Fleurs ASR mode (English -> English text)
716
+ en_asr_fleurs = FleursDataset(
717
+ processor=processor,
718
+ split="train",
719
+ source_lang="en_us", # English
720
+ mode="asr"
721
+ )
722
+ train_datasets.append(en_asr_fleurs)
723
+
724
+
725
+ # en_ch_ast_fleurs = FleursDataset(
726
+ # processor=processor,
727
+ # split="train",
728
+ # source_lang="en_us",
729
+ # target_lang="cmn_hans_cn",
730
+ # mode="ast"
731
+ # )
732
+ # train_datasets.append(en_ch_ast_fleurs)
733
+
734
+
735
+
736
+ ch_asr_fleurs = FleursDataset(
737
+ processor=processor,
738
+ split="train",
739
+ source_lang="cmn_hans_cn",
740
+ mode="asr"
741
+ )
742
+ train_datasets.append(ch_asr_fleurs)
743
+
744
+
745
+ # ch_en_ast_fleurs = FleursDataset(
746
+ # processor=processor,
747
+ # split="train",
748
+ # source_lang="cmn_hans_cn",
749
+ # target_lang="en_us",
750
+ # mode="ast"
751
+ # )
752
+ # train_datasets.append(ch_en_ast_fleurs)
753
+
754
+ print("Number of datasets:", len(train_datasets))
755
+ print([len(dataset) for dataset in train_datasets])
756
+
757
+ # ConcatDataset
758
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
759
+ print("Total number of training samples:", len(train_dataset))
760
+
761
+
762
+
763
+ # Check GPUs
764
+ num_gpus = torch.cuda.device_count()
765
+ print(f'training on {num_gpus} GPUs')
766
+
767
+ assert (
768
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
769
+ ), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
770
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
771
+
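# --- Worked example (reviewer note, not part of the uploaded file) ---
# With batch_size = 128 and batch_size_per_gpu = 16:
#   1 GPU  -> gradient_accumulation_steps = 128 // (1 * 16) = 8
#   8 GPUs -> gradient_accumulation_steps = 128 // (8 * 16) = 1
# so the effective global batch per optimizer step stays at 128 either way.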
772
+ # hard coded training args
773
+ dp_config = {
774
+ "fp16": {
775
+ "enabled": "auto",
776
+ "loss_scale": 0,
777
+ "loss_scale_window": 1000,
778
+ "initial_scale_power": 16,
779
+ "hysteresis": 2,
780
+ "min_loss_scale": 1
781
+ },
782
+ "zero_optimization": {
783
+ "stage": 2,
784
+ "allgather_partitions": True,
785
+ "allgather_bucket_size": 5e8,
786
+ "overlap_comm": False,
787
+ "reduce_scatter": True,
788
+ "reduce_bucket_size": 5e8,
789
+ "contiguous_gradients": True,
790
+ "cpu_offload": True
791
+ },
792
+
793
+ "train_batch_size": "auto",
794
+ "gradient_accumulation_steps": "auto",
795
+ "optimizer": {
796
+ "type": "AdamW",
797
+ "params": {
798
+ "lr": "auto",
799
+ "betas": 'auto',
800
+ "eps": 'auto',
801
+ "weight_decay": "auto"
802
+ }
803
+ },
804
+ "scheduler": {
805
+ "type": "WarmupDecayLR",
806
+ "params": {
807
+ "warmup_min_lr": "auto",
808
+ "warmup_max_lr": "auto",
809
+ "warmup_num_steps": "auto",
810
+ "total_num_steps": "auto"
811
+ }
812
+ },
813
+ "gradient_clipping": 1.0,
814
+ "zero_optimization": {
815
+ "stage": 0
816
+ }
817
+ }
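# Reviewer note (not part of the uploaded file): "zero_optimization" appears twice
# in dp_config; in a Python dict literal the later entry wins, so ZeRO stage 0 is
# what actually takes effect and the stage-2 block above is silently ignored.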
818
+ training_args = TrainingArguments(
819
+ num_train_epochs=num_train_epochs,
820
+ per_device_train_batch_size=batch_size_per_gpu,
821
+ gradient_checkpointing=True,
822
+ gradient_checkpointing_kwargs={'use_reentrant': False},
823
+ gradient_accumulation_steps=gradient_accumulation_steps,
824
+ optim='adamw_torch',
825
+ adam_beta1=0.9,
826
+ adam_beta2=0.95,
827
+ adam_epsilon=1e-7,
828
+ learning_rate=learning_rate,
829
+ weight_decay=wd,
830
+ max_grad_norm=1.0,
831
+ lr_scheduler_type='cosine',
832
+ warmup_steps=50,
833
+ logging_steps=10,
834
+ output_dir=output_dir,
835
+ save_total_limit=10,
836
+ save_only_model=True,
837
+ bf16=True,
838
+ fp16=False,
839
+ remove_unused_columns=False,
840
+ report_to='none',
841
+ deepspeed=dp_config if num_gpus==1 else None,
842
+ disable_tqdm=False,
843
+ dataloader_num_workers=4,
844
+ save_strategy='steps',
845
+ save_steps=1000,
846
+ ddp_find_unused_parameters=True,
847
+
848
+ )
849
+
850
+ out_path = Path(training_args.output_dir)
851
+ out_path.mkdir(parents=True, exist_ok=True)
852
+
853
+ # create optimizer only for trainable params
854
+ optimizer = torch.optim.AdamW(
855
+ filter(lambda p: p.requires_grad, model.parameters()),
856
+ lr=learning_rate,
857
+ weight_decay=wd,
858
+ betas=(0.9, 0.95),
859
+ eps=1e-7,
860
+ )
861
+
862
+ # Trainer Setting
863
+ trainer = Trainer(
864
+ model=model,
865
+ args=training_args,
866
+ data_collator=covost_collate_fn,
867
+ train_dataset=train_dataset,
868
+ optimizers=(optimizer, None)
869
+ )
870
+
871
+ trainer.train()
872
+
873
+
874
+ # # 1. Save LoRA Adapter
875
+ model.language_model.model.save_pretrained(output_dir)
876
+
877
+ # # 1-1. Delete Markdown file
878
+ # markdown_file = os.path.join(output_dir, "README.md")
879
+ # if os.path.exists(markdown_file):
880
+ # os.remove(markdown_file)
881
+
882
+ # 2. Save entire model
883
+ model.save_pretrained(output_dir)
training_multiturn.py ADDED
@@ -0,0 +1,329 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import sacrebleu
14
+
15
+ from datasets import load_dataset
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ AutoProcessor,
20
+ AutoModel,
21
+ BatchFeature,
22
+ Trainer,
23
+ TrainingArguments,
24
+ StoppingCriteria,
25
+ StoppingCriteriaList,
26
+ )
27
+ from collections import defaultdict
28
+
29
+ import soundfile as sf
30
+ from datasets import Audio
31
+ import random
32
+ from ASRDataset import *
33
+
34
+
35
+ def count_parameters_by_module(model):
36
+ # dictionary of parameter counts per top-level module
37
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
38
+
39
+ # all params
40
+ total_params = 0
41
+ total_trainable_params = 0
42
+
43
+ # Check Embedding Token masks
44
+ embedding_masks = {}
45
+ for name, param in model.named_parameters():
46
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
47
+ # check whether this param has an embedding_grad_mask_hook registered
48
+ for hook_id, hook_fn in param._backward_hooks.items():
49
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
50
+ # Accessing mask variables in the closure of hook functions
51
+ for cell in hook_fn.__closure__ or []:
52
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
53
+ # check mask tensor
54
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
55
+
56
+ # Count params by modules
57
+ for name, param in model.named_parameters():
58
+ # extracts top module_name
59
+ module_name = name.split('.')[0]
60
+ param_count = param.numel()
61
+
62
+ module_params[module_name]["total"] += param_count
63
+ total_params += param_count
64
+
65
+ if param.requires_grad:
66
+ # Only count for real trainable params. (with masks)
67
+ if name in embedding_masks:
68
+ trainable_count = embedding_masks[name].sum().item()
69
+ module_params[module_name]["trainable"] += trainable_count
70
+ total_trainable_params += trainable_count
71
+ else:
72
+ module_params[module_name]["trainable"] += param_count
73
+ total_trainable_params += param_count
74
+
75
+ print(f"All Params: {total_params:,}")
76
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
77
+ print("\nParams by Module:")
78
+
79
+ for module_name, counts in sorted(module_params.items()):
80
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
81
+ total_percentage = counts["total"] / total_params * 100
82
+
83
+ print(f"- {module_name}:")
84
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
85
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
86
+
87
+ return module_params
88
+
89
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
90
+ model = AutoModel.from_pretrained(
91
+ model_name_or_path,
92
+ revision=revision,
93
+ torch_dtype=torch.bfloat16,
94
+ device_map="auto",
95
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
96
+ trust_remote_code=True,
97
+ )
98
+
99
+ # Set use_cache to False after model loaded
100
+ model.config.use_cache = False
101
+
102
+ # Freeze all parameters
103
+ for param in model.parameters():
104
+ param.requires_grad = False
105
+
106
+ model.set_lora_adapter('speech')
107
+ model.to(torch.bfloat16)
108
+
109
+ # (Optional) unfreeze audio_tower parameters
110
+ # for param in model.audio_tower.parameters():
111
+ # param.requires_grad = True
112
+
113
+ # Only unfreeze audio_projector parameters
114
+ # for param in model.audio_projector.parameters():
115
+ # param.requires_grad = True
116
+
117
+ # (Optional) unfreeze audio embed_tokens
118
+ train_embed = True
119
+ if train_embed:
120
+ embed_tokens = model.language_model.model.model.embed_tokens
121
+
122
+ embed_tokens.weight.requires_grad = False
123
+
124
+ # Added speech token IDs (only these tokens will be trainable)
125
+ trainable_token_ids = [256001, 256002]
126
+
127
+ embed_tokens.weight.requires_grad = True
128
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
129
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
130
+
131
+ # backward hook, with gradient masking
132
+ def embedding_grad_mask_hook(grad):
133
+ return grad.masked_fill(mask, 0)
134
+
135
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
136
+
137
+ model.language_model.model.model.embed_tokens = embed_tokens
138
+
139
+ count_parameters_by_module(model)
140
+
141
+ return model
142
+
143
+ ANSWER_SUFFIX = "<end_of_turn>"
144
+ _IGNORE_INDEX = -100
145
+
146
+ ANSWER_SUFFIX = "<end_of_turn>"
147
+ _IGNORE_INDEX = -100
148
+
149
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
150
+ use_flash_attention = False
151
+
152
+ output_dir = '../gemma_tmp13'
153
+ batch_size = 24
154
+ batch_size_per_gpu = 8
155
+ learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
156
+ wd = 0.01
157
+ num_train_epochs = 10
158
+
159
+ revision = "main" #"v1.0"
160
+
161
+ processor = AutoProcessor.from_pretrained(
162
+ model_name_or_path,
163
+ revision=revision,
164
+ trust_remote_code=True,
165
+ )
166
+
167
+ model = create_model(
168
+ model_name_or_path,
169
+ revision=revision,
170
+ use_flash_attention=use_flash_attention,
171
+ )
172
+
173
+ train_datasets = []
174
+
175
+ pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
176
+ train_datasets.append(pickup_dataset)
177
+
178
+ # custom_tw_loc = TWCostumData(processor=processor,
179
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
180
+ # train_datasets.append(custom_tw_loc) # 1500
181
+
182
+ # custom_tw_loc2 = TWCostumData(processor=processor,
183
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
184
+ # train_datasets.append(custom_tw_loc2) # 9458
185
+
186
+ # custom_yating_tw_road = TWCostumData(processor=processor,
187
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
188
+ # train_datasets.append(custom_yating_tw_road) # 35224
189
+
190
+ # custom_tw_road = TWCostumData(processor=processor,
191
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
192
+ # train_datasets.append(custom_tw_road) # 1500
193
+
194
+ # custom_tw_road2 = TWCostumData(processor=processor,
195
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
196
+ # train_datasets.append(custom_tw_road2) # 35224
197
+
198
+
199
+
200
+ print("Number of datasets:", len(train_datasets))
201
+ print([len(dataset) for dataset in train_datasets])
202
+
203
+ # ConcatDataset
204
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
205
+ print("Count Length of Datas", len(train_dataset))
206
+
207
+
208
+
209
+ # Check GPUs
210
+ num_gpus = torch.cuda.device_count()
211
+ print(f'training on {num_gpus} GPUs')
212
+
213
+ assert (
214
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
215
+ ), 'Batch size must be divisible by the number of GPUs'
216
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
217
+
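+ # e.g. batch_size=24 with 3 GPUs at batch_size_per_gpu=8 gives 24 // (3 * 8) = 1
+ # accumulation step; on a single GPU it would be 24 // 8 = 3.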
+ # hard coded training args
+ dp_config = {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "allgather_partitions": True,
+         "allgather_bucket_size": 5e8,
+         "overlap_comm": False,
+         "reduce_scatter": True,
+         "reduce_bucket_size": 5e8,
+         "contiguous_gradients": True,
+         "cpu_offload": True
+     },
+
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": 'auto',
+             "eps": 'auto',
+             "weight_decay": "auto"
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "gradient_clipping": 1.0,
+     "zero_optimization": {
+         "stage": 0
+     }
+ }
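+ # Note: "zero_optimization" appears twice in this dict literal, so the later
+ # {"stage": 0} entry silently overrides the stage-2 block above. dp_config is
+ # also never passed to TrainingArguments (deepspeed=None below), so it has no
+ # effect on this run.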
+ training_args = TrainingArguments(
+     num_train_epochs=num_train_epochs,
+     per_device_train_batch_size=batch_size_per_gpu,
+     gradient_checkpointing=True,
+     gradient_checkpointing_kwargs={'use_reentrant': False},
+     gradient_accumulation_steps=gradient_accumulation_steps,
+     optim='adamw_torch',
+     adam_beta1=0.9,
+     adam_beta2=0.95,
+     adam_epsilon=1e-7,
+     learning_rate=learning_rate,
+     weight_decay=wd,
+     max_grad_norm=1.0,
+     lr_scheduler_type='cosine',
+     warmup_steps=50,
+     logging_steps=10,
+     output_dir=output_dir,
+     save_total_limit=10,
+     save_only_model=True,
+     bf16=True,
+     fp16=False,
+     remove_unused_columns=False,
+     report_to='none',
+     deepspeed=None,
+     disable_tqdm=False,
+     dataloader_num_workers=16,
+     save_strategy='epoch',
+     # save_steps=2500,
+     ddp_find_unused_parameters=True,
+
+ )
+
+ out_path = Path(training_args.output_dir)
+ out_path.mkdir(parents=True, exist_ok=True)
+
+ # create optimizer only for trainable params
+ optimizer = torch.optim.AdamW(
+     filter(lambda p: p.requires_grad, model.parameters()),
+     lr=learning_rate,
+     weight_decay=wd,
+     betas=(0.9, 0.95),
+     eps=1e-7,
+ )
+
+ # Trainer Setting
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     data_collator=covost_collate_fn,
+     train_dataset=train_dataset,
+     optimizers=(optimizer, None)
+ )
+
+ trainer.train()
+
+
+ # 1. Save LoRA Adapter
+ model.language_model.model.save_pretrained(output_dir)
+
+ # 1-1. Delete Markdown file
+ # markdown_file = os.path.join(output_dir, "README.md")
+ # if os.path.exists(markdown_file):
+ #     os.remove(markdown_file)
+
+ # 2. Save entire model
+ model.save_pretrained(output_dir)
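+ # Per the comments above, the first save_pretrained call presumably writes only the
+ # adapter weights attached to language_model.model, while the second serializes the
+ # full model; both land in the same output_dir.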
training_multiturn_textonly.py ADDED
@@ -0,0 +1,333 @@
+ import datasets
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
+ import os
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
+
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import sacrebleu
+
+ from datasets import load_dataset
+ from torch.utils.data import Dataset, ConcatDataset
+ from tqdm import tqdm
+ from transformers import (
+     AutoProcessor,
+     AutoModel,
+     BatchFeature,
+     Trainer,
+     TrainingArguments,
+     StoppingCriteria,
+     StoppingCriteriaList,
+ )
+ from collections import defaultdict
+
+ import soundfile as sf
+ from datasets import Audio
+ import random
+ from ASRDataset import *
+
+
+ def count_parameters_by_module(model):
+     # dictionary of parameter counts per module
+     module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
+
+     # totals across all params
+     total_params = 0
+     total_trainable_params = 0
+
+     # Check embedding token masks
+     embedding_masks = {}
+     for name, param in model.named_parameters():
+         if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
+             # check whether this param has embedding_grad_mask_hook registered
+             for hook_id, hook_fn in param._backward_hooks.items():
+                 if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
+                     # access the mask variable captured in the hook function's closure
+                     for cell in hook_fn.__closure__ or []:
+                         if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
+                             # found the mask tensor
+                             embedding_masks[name] = ~cell.cell_contents  # True : trainable
+
+     # Count params by module
+     for name, param in model.named_parameters():
+         # extract the top-level module name
+         module_name = name.split('.')[0]
+         param_count = param.numel()
+
+         module_params[module_name]["total"] += param_count
+         total_params += param_count
+
+         if param.requires_grad:
+             # Only count the truly trainable params (respecting the embedding mask)
+             if name in embedding_masks:
+                 trainable_count = embedding_masks[name].sum().item()
+                 module_params[module_name]["trainable"] += trainable_count
+                 total_trainable_params += trainable_count
+             else:
+                 module_params[module_name]["trainable"] += param_count
+                 total_trainable_params += param_count
+
+     print(f"All Params: {total_params:,}")
+     print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
+     print("\nParams by Module:")
+
+     for module_name, counts in sorted(module_params.items()):
+         trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
+         total_percentage = counts["total"] / total_params * 100
+
+         print(f"- {module_name}:")
+         print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
+         print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
+
+     return module_params
+
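+ # Returns a dict mapping each top-level module name to its {"total", "trainable"}
+ # parameter counts, after printing the overall and per-module summaries.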
+ def create_model(model_name_or_path, revision="main", use_flash_attention=False):
+     model = AutoModel.from_pretrained(
+         model_name_or_path,
+         revision=revision,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         attn_implementation="flash_attention_2" if use_flash_attention else "eager",
+         trust_remote_code=True,
+     )
+
+     # Set use_cache to False after the model is loaded
+     model.config.use_cache = False
+
+     # Freeze all parameters
+     for param in model.parameters():
+         param.requires_grad = False
+
+     model.set_lora_adapter('speech')
+     # model.set_lora_adapter('text')
+     model.to(torch.bfloat16)
+
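+     # Only the 'speech' adapter is selected here (the 'text' line is commented out);
+     # assuming the custom set_lora_adapter re-enables that adapter's weights, they are
+     # the only trainable parameters besides the embedding rows unfrozen below.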
+     # (Optional) unfreeze audio_tower parameters
+     # for param in model.audio_tower.parameters():
+     #     param.requires_grad = True
+
+     # Only unfreeze audio_projector parameters
+     # for param in model.audio_projector.parameters():
+     #     param.requires_grad = True
+
+     # (Optional) unfreeze audio embed_tokens
+     train_embed = True
+     if train_embed:
+         embed_tokens = model.language_model.model.model.embed_tokens
+
+         embed_tokens.weight.requires_grad = False
+
+         # Added speech token IDs (only these tokens will be trainable)
+         trainable_token_ids = [256001, 256002]
+
+         embed_tokens.weight.requires_grad = True
+         mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
+         mask[trainable_token_ids] = False  # trainable token rows are False (unfrozen), all others True (frozen)
+
+         # backward hook with gradient masking
+         def embedding_grad_mask_hook(grad):
+             return grad.masked_fill(mask, 0)
+
+         embed_tokens.weight.register_hook(embedding_grad_mask_hook)
+
+         model.language_model.model.model.embed_tokens = embed_tokens
+
+     count_parameters_by_module(model)
+
+     return model
+
+ ANSWER_SUFFIX = "<end_of_turn>"
+ _IGNORE_INDEX = -100
+
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
+ use_flash_attention = False
+
+ output_dir = '../gemma_tmp14_audio_and_text_speechlora'
+ batch_size = 16
+ batch_size_per_gpu = 1
+ learning_rate = 5.0e-5 # 1.0e-4 for fine-tuning
+ wd = 0.01
+ num_train_epochs = 10
+
+ revision = "main" #"v1.0"
+
+ processor = AutoProcessor.from_pretrained(
+     model_name_or_path,
+     revision=revision,
+     trust_remote_code=True,
+ )
+
+ model = create_model(
+     model_name_or_path,
+     revision=revision,
+     use_flash_attention=use_flash_attention,
+ )
+
+ train_datasets = []
+
+ pickup_dataset = MultiturnAudioDataset(processor=processor,text_only=True,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
+ train_datasets.append(pickup_dataset)
+
+ pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
+ train_datasets.append(pickup_dataset)
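+ # The same multiturn JSON is loaded twice: once with text_only=True and once with
+ # audio, so each dialogue contributes both a text-only and an audio-grounded sample.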
+
+ # custom_tw_loc = TWCostumData(processor=processor,
+ #                              csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
+ # train_datasets.append(custom_tw_loc) # 1500
+
+ # custom_tw_loc2 = TWCostumData(processor=processor,
+ #                               csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
+ # train_datasets.append(custom_tw_loc2) # 9458
+
+ # custom_yating_tw_road = TWCostumData(processor=processor,
+ #                                      csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
+ # train_datasets.append(custom_yating_tw_road) # 35224
+
+ # custom_tw_road = TWCostumData(processor=processor,
+ #                               csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
+ # train_datasets.append(custom_tw_road) # 1500
+
+ # custom_tw_road2 = TWCostumData(processor=processor,
+ #                                csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
+ # train_datasets.append(custom_tw_road2) # 35224
+
+
+
+ print("Count Num of Datasets", len(train_datasets))
+ print([len(dataset) for dataset in train_datasets])
+
+ # ConcatDataset
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
+ print("Count Length of Data", len(train_dataset))
+
+
+
+ # Check GPUs
+ num_gpus = torch.cuda.device_count()
+ print(f'training on {num_gpus} GPUs')
+
+ assert (
+     batch_size % (num_gpus * batch_size_per_gpu) == 0
+ ), 'batch_size must be divisible by num_gpus * batch_size_per_gpu'
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
+
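+ # e.g. batch_size=16 with 2 GPUs at batch_size_per_gpu=1 gives 16 // (2 * 1) = 8
+ # accumulation steps; on a single GPU it would be 16 // 1 = 16.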
+ # hard coded training args
+ dp_config = {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "allgather_partitions": True,
+         "allgather_bucket_size": 5e8,
+         "overlap_comm": False,
+         "reduce_scatter": True,
+         "reduce_bucket_size": 5e8,
+         "contiguous_gradients": True,
+         "cpu_offload": True
+     },
+
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": 'auto',
+             "eps": 'auto',
+             "weight_decay": "auto"
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "gradient_clipping": 1.0,
+     "zero_optimization": {
+         "stage": 0
+     }
+ }
+ training_args = TrainingArguments(
+     num_train_epochs=num_train_epochs,
+     per_device_train_batch_size=batch_size_per_gpu,
+     gradient_checkpointing=True,
+     gradient_checkpointing_kwargs={'use_reentrant': False},
+     gradient_accumulation_steps=gradient_accumulation_steps,
+     optim='adamw_torch',
+     adam_beta1=0.9,
+     adam_beta2=0.95,
+     adam_epsilon=1e-7,
+     learning_rate=learning_rate,
+     weight_decay=wd,
+     max_grad_norm=1.0,
+     lr_scheduler_type='cosine',
+     warmup_steps=50,
+     logging_steps=10,
+     output_dir=output_dir,
+     save_total_limit=10,
+     save_only_model=True,
+     bf16=True,
+     fp16=False,
+     remove_unused_columns=False,
+     report_to='none',
+     deepspeed=None,
+     disable_tqdm=False,
+     dataloader_num_workers=16,
+     save_strategy='epoch',
+     # save_steps=2500,
+     ddp_find_unused_parameters=True,
+
+ )
+
+ out_path = Path(training_args.output_dir)
+ out_path.mkdir(parents=True, exist_ok=True)
+
+ # create optimizer only for trainable params
+ optimizer = torch.optim.AdamW(
+     filter(lambda p: p.requires_grad, model.parameters()),
+     lr=learning_rate,
+     weight_decay=wd,
+     betas=(0.9, 0.95),
+     eps=1e-7,
+ )
+
+ # Trainer Setting
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     data_collator=covost_collate_fn,
+     train_dataset=train_dataset,
+     optimizers=(optimizer, None)
+ )
+
+ trainer.train()
+
+
+ # 1. Save LoRA Adapter
+ model.language_model.model.save_pretrained(output_dir)
+
+ # 1-1. Delete Markdown file
+ # markdown_file = os.path.join(output_dir, "README.md")
+ # if os.path.exists(markdown_file):
+ #     os.remove(markdown_file)
+
+ # 2. Save entire model
+ model.save_pretrained(output_dir)