jva96160 committed on
Commit a2dca42 · verified · 1 Parent(s): 4c1ba5a

Upload 71 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +9 -0
  2. cpp/ASRDataset.py +794 -0
  3. cpp/__pycache__/ASRDataset.cpython-310.pyc +0 -0
  4. cpp/__pycache__/speech_conformer_encoder.cpython-310.pyc +0 -0
  5. cpp/convert_onnx.ipynb +767 -0
  6. cpp/convert_tensorRT.ipynb +0 -0
  7. cpp/gemma_v1/ASRDataset.py +793 -0
  8. cpp/gemma_v1/__pycache__/ASRDataset.cpython-312.pyc +0 -0
  9. cpp/gemma_v1/added_tokens.json +3 -0
  10. cpp/gemma_v1/chat_template.json +3 -0
  11. cpp/gemma_v1/config.json +118 -0
  12. cpp/gemma_v1/configuration_gemma3omni.py +206 -0
  13. cpp/gemma_v1/eval.py +635 -0
  14. cpp/gemma_v1/eval_multiturn.ipynb +0 -0
  15. cpp/gemma_v1/eval_multiturn.py +211 -0
  16. cpp/gemma_v1/merge_lora.ipynb +119 -0
  17. cpp/gemma_v1/model-00001-of-00003.safetensors +3 -0
  18. cpp/gemma_v1/model-00002-of-00003.safetensors +3 -0
  19. cpp/gemma_v1/model-00003-of-00003.safetensors +3 -0
  20. cpp/gemma_v1/model.safetensors.index.json +0 -0
  21. cpp/gemma_v1/modeling_gemma3omni.py +668 -0
  22. cpp/gemma_v1/preprocessing_gemma3omni.py +444 -0
  23. cpp/gemma_v1/preprocessor_config.json +41 -0
  24. cpp/gemma_v1/processor_config.json +7 -0
  25. cpp/gemma_v1/special_tokens_map.json +36 -0
  26. cpp/gemma_v1/speech_conformer_encoder.py +0 -0
  27. cpp/gemma_v1/speech_conformer_encoder_old.py +0 -0
  28. cpp/gemma_v1/tokenizer.json +3 -0
  29. cpp/gemma_v1/tokenizer.model +3 -0
  30. cpp/gemma_v1/tokenizer_config.json +0 -0
  31. cpp/gemma_v1/training.py +883 -0
  32. cpp/gemma_v1/training_multiturn.py +329 -0
  33. cpp/gemma_v1/training_multiturn_textonly.py +333 -0
  34. cpp/inference/audio_encoder_lib.cpp +388 -0
  35. cpp/inference/audio_encoder_lib.h +141 -0
  36. cpp/inference/audio_encoder_lib.o +0 -0
  37. cpp/inference/audio_inference +0 -0
  38. cpp/inference/audio_inference_app +0 -0
  39. cpp/inference/compile.sh +32 -0
  40. cpp/inference/dummy.wav +0 -0
  41. cpp/inference/f0.txt +0 -0
  42. cpp/inference/f_inp.txt +0 -0
  43. cpp/inference/kiss_fft.o +0 -0
  44. cpp/inference/kiss_fftr.o +0 -0
  45. cpp/inference/main_text.cpp +165 -0
  46. cpp/inference/matrix_output.txt +0 -0
  47. cpp/inference/run.sh +7 -0
  48. cpp/inference/test copy 2.cpp +567 -0
  49. cpp/inference/test copy.cpp +301 -0
  50. cpp/inference/test.cpp +702 -0
.gitattributes CHANGED
@@ -34,3 +34,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ cpp/gemma_v1/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.pcm filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.wav filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382475-breezyvoice-01452.wav filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.pcm filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.wav filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382594-breezyvoice-00389.pcm filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382594-breezyvoice-00389.wav filter=lfs diff=lfs merge=lfs -text
+ cpp/sample_data/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.wav filter=lfs diff=lfs merge=lfs -text
cpp/ASRDataset.py ADDED
@@ -0,0 +1,794 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import json
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ import sacrebleu
13
+
14
+ from datasets import load_dataset
15
+ from torch.utils.data import Dataset, ConcatDataset
16
+ from tqdm import tqdm
17
+ from transformers import (
18
+ BatchFeature,
19
+ )
20
+ import pandas as pd
21
+ import soundfile as sf
22
+ from datasets import Audio
23
+ import random
24
+ from copy import deepcopy
25
+ import torchaudio
26
+
27
+ ANSWER_SUFFIX = "<end_of_turn>"
28
+ _IGNORE_INDEX = -100
29
+ class BaseAudioDataset(Dataset):
30
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
31
+ self.processor = processor
32
+ self.training = "train" in split or 'other' in split
33
+ self.debug = debug
34
+ self.sampling_rate = sampling_rate
35
+ self.name = ""
36
+
37
+ def set_dataset_name(self, name):
38
+ self.name = name
39
+
40
+ @staticmethod
41
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
42
+ original_size = len(data)
43
+
44
+ data = data.cast_column(audio_field, Audio(decode=False))
45
+
46
+ def identify_corrupted_files(example):
47
+ try:
48
+ sf.read(example[audio_field]["path"])
49
+
50
+ for field in text_fields:
51
+ if field in example and example[field].replace('"', '') == "":
52
+ return False
53
+ return True
54
+ except Exception:
55
+ return False
56
+
57
+ data = data.filter(identify_corrupted_files, num_proc=16)
58
+ validated_size = len(data)
59
+
60
+ # Audio Decoding
61
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
62
+
63
+ if debug:
64
+ print(f"Dataset: {dataset_name}")
65
+ print(f"Original data nums: {original_size}")
66
+ print(f"After filtering data nums: {validated_size}")
67
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
68
+
69
+ return data
70
+
71
+ @staticmethod
72
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
73
+ original_size = len(data)
74
+
75
+ def filter_audio_by_length(example):
76
+ try:
77
+ audio = example[audio_field]['array']
78
+ channel = 1
79
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
80
+ channel = audio.ndim
81
+ audio = audio.squeeze()
82
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
83
+ return min_sec <= audio_length <= max_sec
84
+ except Exception as e:
85
+ if debug:
86
+ print(f"Error : {str(e)[:100]}... - sample excluded")
87
+ return False
88
+
89
+ data = data.filter(filter_audio_by_length, num_proc=16)
90
+ filtered_size = len(data)
91
+
92
+ if debug:
93
+ print(f"Before Length Filtering data nums: {original_size}")
94
+ print(f"After Length Filtering data nums: {filtered_size}")
95
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
96
+
97
+ return data
98
+
99
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
100
+ user_message = {
101
+ 'role': 'user',
102
+ 'content': '<start_of_audio>' + instruction,
103
+ }
104
+ prompt = self.processor.tokenizer.apply_chat_template(
105
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
106
+ )
107
+
108
+ inputs = self.processor(
109
+ text=prompt,
110
+ audio=[audio_array],
111
+ add_special_tokens=False,
112
+ return_tensors='pt'
113
+ )
114
+
115
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
116
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
117
+
118
+ if self.debug:
119
+ self.debug = False
120
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
121
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
122
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
123
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
124
+
125
+ if self.training:
126
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
127
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
128
+ labels[:, -answer_ids.shape[1]:] = answer_ids
129
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
130
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
131
+ else:
132
+ input_ids = inputs.input_ids
133
+ labels = answer_ids
134
+ token_type_ids = inputs.token_type_ids
135
+
136
+ return {
137
+ 'input_ids': input_ids,
138
+ 'labels': labels,
139
+ 'token_type_ids': token_type_ids,
140
+ 'input_audio_embeds': inputs.input_audio_embeds,
141
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
142
+ 'input_modes': inputs.input_modes,
143
+ }
144
+
145
+ # Libri Speech Dataset Class
146
+ class LibriSpeechDataset(BaseAudioDataset):
147
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
148
+ super().__init__(processor, split, sampling_rate, debug)
149
+
150
+ self.set_dataset_name(f"LibriSpeech_{subset}")
151
+ # only ASR
152
+ self.ast = False
153
+ self.lang = "en"
154
+
155
+ # load dataset
156
+ self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
157
+ subset,
158
+ split=split,
159
+ trust_remote_code=True,
160
+ cache_dir=Path("/mnt/jeff/InCar/data")
161
+ )
162
+
163
+ # (Optional) Audio length Filtering
164
+ self.data = self.filter_by_audio_length(self.data, "audio")
165
+
166
+ # Instruction Setting
167
+ self.instruction = random.choice(INSTRUCTION["asr"])
168
+
169
+ def __len__(self):
170
+ return len(self.data)
171
+
172
+ def __getitem__(self, idx):
173
+ data = self.data[idx]
174
+
175
+ # Libri Speech is only for ASR
176
+ answer_text = data["text"].replace('"', '')
177
+
178
+ return self.prepare_model_inputs(
179
+ data["audio"]["array"],
180
+ self.instruction,
181
+ answer_text
182
+ )
183
+
184
+ # common_voice_16_1 dataset
185
+ class CommonVoiceDataset(BaseAudioDataset):
186
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
187
+ super().__init__(processor, split, sampling_rate, debug)
188
+
189
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
190
+ # only ASR
191
+ self.ast = False
192
+ self.lang=source_lang
193
+
194
+ # load dataset
195
+ if source_lang=="zh-TW":
196
+ data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
197
+ else:
198
+ data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
199
+ self.data = load_dataset(data_path,
200
+ source_lang,
201
+ split=split,
202
+ trust_remote_code=True,
203
+ cache_dir=Path("/mnt/jeff/InCar/data")
204
+ )
205
+ def prepare_dataset(batch):
206
+ """Function to preprocess the dataset with the .map method"""
207
+ transcription = batch["sentence"]
208
+
209
+ if transcription.startswith('"') and transcription.endswith('"'):
210
+ # we can remove trailing quotation marks as they do not affect the transcription
211
+ transcription = transcription[1:-1]
212
+
213
+ if transcription[-1] not in [".", "?", "!"]:
214
+ # append a full-stop to sentences that do not end in punctuation
215
+ transcription = transcription + "."
216
+
217
+ batch["sentence"] = transcription
218
+
219
+ return batch
220
+
221
+
222
+ import opencc
223
+ converter = opencc.OpenCC('s2tw.json')
224
+ def To_zhTW(batch):
225
+
226
+ transcription = converter.convert(batch["sentence"])
227
+ batch["sentence"] = transcription
228
+
229
+ return batch
230
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
231
+ if source_lang=='zh-CN':
232
+ self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
233
+
234
+
235
+ # (Optional) Audio length Filtering
236
+ self.data = self.filter_by_audio_length(self.data, "audio")
237
+
238
+ if source_lang == "zh-TW" and split=='train':
239
+ import torchaudio
240
+ from torchaudio import transforms
241
+ import copy
242
+ import pickle
243
+ import os
244
+ def subsample(batch):
245
+ batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
246
+ batch['audio']['sampling_rate']=16000
247
+ return batch
248
+ def TW_data_augment_fast(batch):
249
+ speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
250
+ new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
251
+ batch['audio']['array'] = new_array_fast
252
+ return batch
253
+ def TW_data_augment_slow(batch):
254
+ speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
255
+ new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
256
+ batch['audio']['array'] = new_array_slow
257
+ return batch
258
+ # data = self.data.map(subsample, num_proc=1, desc="subsample")
259
+ fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
260
+ if not os.path.exists(fast_path):
261
+ data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
262
+ with open(fast_path,'wb') as f:
263
+ pickle.dump(data_fast,f)
264
+ else:
265
+ with open(fast_path,'rb') as f:
266
+ data_fast=pickle.load(f)
267
+
268
+ slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
269
+ if not os.path.exists(slow_path):
270
+ data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
271
+ with open(slow_path,'wb') as f:
272
+ pickle.dump(data_slow,f)
273
+ else:
274
+ with open(slow_path,'rb') as f:
275
+ data_slow=pickle.load(f)
276
+ self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
277
+
278
+ # Instruction Setting
279
+ self.instruction = random.choice(INSTRUCTION["asr"])
280
+
281
+ def __len__(self):
282
+ return len(self.data)
283
+
284
+ def __getitem__(self, idx):
285
+ data = self.data[idx]
286
+
287
+ answer_text = data["sentence"]
288
+ return self.prepare_model_inputs(
289
+ data["audio"]["array"],
290
+ self.instruction,
291
+ answer_text
292
+ )
293
+
294
+
295
+ # Fleurs Dataset Class
296
+ class FleursDataset(BaseAudioDataset):
297
+ def __init__(self, processor, split, source_lang, target_lang=None,
298
+ mode="asr", sampling_rate=16000, debug=False):
299
+ super().__init__(processor, split, sampling_rate, debug)
300
+
301
+ self.set_dataset_name("Fleurs")
302
+ # Mode Setting (ASR or AST)
303
+ if mode not in ["asr", "ast"]:
304
+ raise ValueError("mode must be 'asr' or 'ast'.")
305
+
306
+ self.mode = mode
307
+ self.ast = (mode == "ast")
308
+ self.source_lang = source_lang
309
+
310
+ # Language name mapping (expand if needed)
311
+ self.lang_names = {
312
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
313
+ }
314
+
315
+ # load dataset - source language dataset
316
+ self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
317
+ source_lang,
318
+ split=split,
319
+ trust_remote_code=True,
320
+ cache_dir=Path("/mnt/jeff/InCar/data")
321
+ )
322
+ import opencc
323
+ converter = opencc.OpenCC('s2tw.json')
324
+ def prepare_dataset(batch):
325
+ transcription = converter.convert(batch["transcription"])
326
+ batch["transcription"] = transcription
327
+
328
+ return batch
329
+ if (source_lang=="cmn_hans_cn"):
330
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
331
+
332
+ # (Optional) Audio length Filtering
333
+ self.data = self.filter_by_audio_length(self.data, "audio")
334
+ self.target_lang_name = ""
335
+ # When AST mode, load target language dataset.
336
+ if self.ast:
337
+ if target_lang is None:
338
+ raise ValueError("AST mode requires target_lang.")
339
+
340
+ self.target_lang = target_lang
341
+ self.lang = f"{source_lang}_{target_lang}"
342
+
343
+ # load dataset - target language dataset (for translation)
344
+ target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
345
+ target_lang,
346
+ split=split,
347
+ trust_remote_code=True,
348
+ cache_dir=Path("/mnt/jeff/InCar/data")
349
+ )
350
+ if target_lang=="cmn_hans_cn":
351
+ target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
352
+ source_dict = {item['id']: item for item in self.data}
353
+ target_dict = {item['id']: item for item in target_data}
354
+
355
+ # only Common ID, add translation fields
356
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
357
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
358
+ self.data = [
359
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
360
+ for id in common_ids
361
+ ]
362
+
363
+ # Instruction Setting - use target language name
364
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
365
+ self.instruction = random.choice(INSTRUCTION["ast"])
366
+ else:
367
+ # ASR mode
368
+ self.lang = source_lang
369
+ self.instruction = random.choice(INSTRUCTION["asr"])
370
+
371
+ if self.debug:
372
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
373
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
374
+ if self.ast:
375
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
376
+ print(f"dataset size: {len(self.data)}")
377
+
378
+ def __len__(self):
379
+ return len(self.data)
380
+
381
+ def __getitem__(self, idx):
382
+ data = self.data[idx]
383
+ audio_array = data["audio"]["array"]
384
+
385
+ if self.ast:
386
+ answer_text = data["translation"]
387
+ else:
388
+ answer_text = data["transcription"]
389
+
390
+ return self.prepare_model_inputs(
391
+ audio_array,
392
+ self.instruction.format(self.target_lang_name),
393
+ answer_text
394
+ )
395
+
396
+ class TWCostumData(BaseAudioDataset):
397
+
398
+ def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
399
+ super().__init__(processor, split, sampling_rate, debug)
400
+ import pandas as pd
401
+ from datasets import Dataset, Audio
402
+
403
+
404
+ df = pd.read_csv(csv_path).fillna('')
405
+
406
+
407
+ self.set_dataset_name(f"TWCostumData")
408
+ self.data = Dataset.from_dict(
409
+ {
410
+ "audio": [audio for audio in df['audio']],
411
+ "sentence": [text for text in df['text']]
412
+ }
413
+ ).cast_column("audio", Audio(sampling_rate=16000))
414
+
415
+ # Instruction Setting
416
+ self.instruction = random.choice(INSTRUCTION["asr"])
417
+
418
+ def __len__(self):
419
+ return len(self.data)
420
+
421
+ def __getitem__(self, idx):
422
+ data = self.data[idx]
423
+
424
+ answer_text = data["sentence"]
425
+ return self.prepare_model_inputs(
426
+ data["audio"]["array"],
427
+ self.instruction,
428
+ answer_text
429
+ )
430
+ def covost_collate_fn(batch):
431
+ input_ids_list = []
432
+ labels_list = []
433
+ token_type_ids_list = []
434
+ input_audio_embeds_list = []
435
+ audio_embed_sizes_list = []
436
+ audio_attention_mask_list = []
437
+ input_modes_list = []
438
+ audio_paths = []
439
+ for inputs in batch:
440
+ if 'audio_path' in inputs:
441
+ audio_paths.append(inputs['audio_path'])
442
+ input_ids_list.append(inputs['input_ids'][0])
443
+ labels_list.append(inputs['labels'][0])
444
+ token_type_ids_list.append(inputs['token_type_ids'][0])
445
+ if inputs['input_modes']==2:
446
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
447
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
448
+ audio_attention_mask_list.append(
449
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
450
+ )
451
+ # else:
452
+ # input_audio_embeds_list.append(None)
453
+ # audio_embed_sizes_list.append(None)
454
+ # audio_attention_mask_list.append(None)
455
+ input_modes_list.append(inputs['input_modes'])
456
+ # try:
457
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
458
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
459
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
460
+ audio_attention_mask = (
461
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
462
+ if len(audio_attention_mask_list) > 1
463
+ else None
464
+ )
465
+ # except Exception as e:
466
+ # print(e)
467
+ # print(input_ids_list)
468
+ # print(labels_list)
469
+ # raise
470
+ attention_mask = (input_ids != 0).long()
471
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
472
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
473
+ input_modes = torch.cat(input_modes_list)
474
+ if len(audio_paths)>0:
475
+ return BatchFeature(
476
+ {
477
+ "audio_path": audio_paths,
478
+ 'input_ids': input_ids,
479
+ 'labels': labels,
480
+ 'token_type_ids': token_type_ids,
481
+ 'attention_mask': attention_mask,
482
+ 'input_audio_embeds': input_audio_embeds,
483
+ 'audio_embed_sizes': audio_embed_sizes,
484
+ 'audio_attention_mask': audio_attention_mask,
485
+ 'input_modes': input_modes,
486
+ }
487
+ )
488
+ else:
489
+ return BatchFeature(
490
+ {
491
+ 'input_ids': input_ids,
492
+ 'labels': labels,
493
+ 'token_type_ids': token_type_ids,
494
+ 'attention_mask': attention_mask,
495
+ 'input_audio_embeds': input_audio_embeds,
496
+ 'audio_embed_sizes': audio_embed_sizes,
497
+ 'audio_attention_mask': audio_attention_mask,
498
+ 'input_modes': input_modes,
499
+ }
500
+ )
501
+
502
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
503
+ """
504
+ Pad a list of sequences to the same length.
505
+ sequences: list of tensors in [seq_len, *] shape
506
+ """
507
+ assert padding_side in ['right', 'left']
508
+ max_size = sequences[0].size()
509
+ trailing_dims = max_size[1:]
510
+ max_len = max(len(seq) for seq in sequences)
511
+ batch_size = len(sequences)
512
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
513
+ for i, seq in enumerate(sequences):
514
+ length = seq.size(0)
515
+ if padding_side == 'right':
516
+ output.data[i, :length] = seq
517
+ else:
518
+ output.data[i, -length:] = seq
519
+ return output
520
+
521
+ def cat_with_pad(tensors, dim, padding_value=0):
522
+ """
523
+ cat along dim, while pad to max for all other dims
524
+ """
525
+ ndim = tensors[0].dim()
526
+ assert all(
527
+ t.dim() == ndim for t in tensors[1:]
528
+ ), 'All tensors must have the same number of dimensions'
529
+
530
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
531
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
532
+ output = tensors[0].new_full(out_size, padding_value)
533
+
534
+ index = 0
535
+ for t in tensors:
536
+ # Create a slice list where every dimension except dim is full slice
537
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
538
+ # Update only the concat dimension slice
539
+ slices[dim] = slice(index, index + t.shape[dim])
540
+
541
+ output[slices] = t
542
+ index += t.shape[dim]
543
+
544
+ return output
545
+
546
+
547
+
548
+ class MultiturnAudioDataset(BaseAudioDataset):
549
+ def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
550
+ super().__init__(processor, split, sampling_rate, debug)
551
+ from llamafactory.data.template import Llama2Template,parse_template
552
+ from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
553
+ from llamafactory.data.mm_plugin import get_mm_plugin
554
+ import json
555
+ self.train=False
556
+ self.text_only=text_only
557
+ with open(json_path) as f:
558
+ js_data = json.load(f)
559
+ if split=='train':
560
+ self.train=True
561
+ js_data = js_data[:int(len(js_data)*0.8)]
562
+ else:
563
+ js_data = js_data[-int(len(js_data)*0.2):]
564
+ for conv in js_data:
565
+ for mess in conv['conversations']:
566
+ if 'audio_path' in mess:
567
+ mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
568
+ default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
569
+ self.template=Llama2Template(
570
+ format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
571
+ format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
572
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
573
+ format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
574
+ format_tools = ToolFormatter(tool_format="default"),
575
+ format_observation=StringFormatter(
576
+ slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
577
+ ),
578
+ default_system=default_system,
579
+ thought_words=("<think>", "</think>"),
580
+ efficient_eos=False,
581
+ replace_eos=False,
582
+ replace_jinja_template=False,
583
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
584
+ stop_words=["<end_of_turn>"],
585
+ mm_plugin=get_mm_plugin(name="base"),
586
+ enable_thinking=False
587
+ )
588
+
589
+ self.set_dataset_name(f"MultiturnCostumData")
590
+
591
+
592
+ self.data = []
593
+ self.text_only_data = []
594
+ for conv in js_data:
595
+ tools = conv['tools'] if 'tools' in conv else ""
596
+ system = conv['system'] if 'system' in conv else default_system
597
+ tmp = {
598
+ 'tools':tools,
599
+ 'system':system,
600
+ 'messages':[],
601
+ }
602
+ for i,mess in enumerate(conv['conversations']):
603
+ tmp['messages'].append(mess)
604
+ if mess['from']=='human':
605
+ tmp['messages'].append(conv['conversations'][i+1])
606
+ d = deepcopy(tmp)
607
+ d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
608
+ self.data.append(d)
609
+ if self.text_only:
610
+ self.text_only_data.append(deepcopy(tmp))
611
+ tmp['messages'].pop()
612
+ elif mess['from']=='observation':
613
+ tmp['messages'].append(conv['conversations'][i+1])
614
+ d = deepcopy(tmp)
615
+ self.text_only_data.append(d)
616
+ tmp['messages'].pop()
617
+ if text_only:
618
+ self.data=self.text_only_data
619
+
620
+
621
+ def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
622
+ ANSWER_SUFFIX = "<end_of_turn>"
623
+ prompt = ""
624
+ answer_text = ""
625
+ user_transcribe = ""
626
+ audio_paths = []
627
+ for i, message in enumerate(messages):
628
+ elements = []
629
+
630
+ system_text = ""
631
+ if i == 0:
632
+ elements += self.template.format_prefix.apply()
633
+ if system or tools:
634
+ tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
635
+ system_text = self.template.format_system.apply(content=(system + tool_text))[0]
636
+
637
+ if message["from"] == "human":
638
+ if i==len(messages)-2 and not self.text_only:
639
+ user_transcribe = message["value"]
640
+ elements += self.template.format_user.apply(content=system_text+'<start_of_audio>')
641
+ else:
642
+ elements += self.template.format_user.apply(content=system_text + message["value"])
643
+ audio_paths.append(message['audio_path'])
644
+ elif message["from"] == "gpt":
645
+ elements += self.template.format_assistant.apply(content=message["value"])
646
+ elif message["from"] == "observation":
647
+ elements += self.template.format_observation.apply(content=message["value"])
648
+ elif message["from"] == "function_call":
649
+ elements += self.template.format_function.apply(content=message["value"])
650
+ else:
651
+ raise NotImplementedError("Unexpected role: {}".format(message["from"]))
652
+
653
+
654
+ for elem in elements:
655
+ ele_str = ""
656
+ if isinstance(elem, str):
657
+ ele_str=elem
658
+ elif isinstance(elem, set):
659
+ if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
660
+ ele_str = self.processor.tokenizer.bos_token
661
+ elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
662
+ ele_str = self.processor.tokenizer.eos_token
663
+ if i == len(messages)-1:
664
+ answer_text+=ele_str
665
+ else:
666
+ prompt+=ele_str
667
+
668
+
669
+ if audio_array is not None:
670
+ inputs = self.processor(
671
+ text=prompt,
672
+ audio=[audio_array],
673
+ add_special_tokens=False,
674
+ return_tensors='pt'
675
+ )
676
+ answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
677
+ else:
678
+ inputs = self.processor(
679
+ text=prompt,
680
+ audio=None,
681
+ add_special_tokens=False,
682
+ return_tensors='pt'
683
+ )
684
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
685
+ # print('user_transcribe',user_transcribe)
686
+ # print('answer_text', answer)
687
+ # print('prompt',prompt)
688
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
689
+
690
+ if self.debug:
691
+ self.debug = False
692
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
693
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
694
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
695
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
696
+
697
+ if self.training:
698
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
699
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
700
+ labels[:, -answer_ids.shape[1]:] = answer_ids
701
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
702
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
703
+ else:
704
+ input_ids = inputs.input_ids
705
+ labels = answer_ids
706
+ token_type_ids = inputs.token_type_ids
707
+ if audio_array is not None:
708
+ if not self.train:
709
+ return {
710
+ "audio_path": audio_paths,
711
+ 'input_ids': input_ids,
712
+ 'labels': labels,
713
+ 'token_type_ids': token_type_ids,
714
+ 'input_audio_embeds': inputs.input_audio_embeds,
715
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
716
+ 'input_modes': inputs.input_modes,
717
+ }
718
+ else:
719
+ return {
720
+ 'input_ids': input_ids,
721
+ 'labels': labels,
722
+ 'token_type_ids': token_type_ids,
723
+ 'input_audio_embeds': inputs.input_audio_embeds,
724
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
725
+ 'input_modes': inputs.input_modes,
726
+ }
727
+ else:
728
+ return {
729
+ 'input_ids': input_ids,
730
+ 'labels': labels,
731
+ 'token_type_ids': token_type_ids,
732
+ 'input_audio_embeds': None,
733
+ 'audio_embed_sizes': None,
734
+ 'input_modes': inputs.input_modes,
735
+ }
736
+ def __len__(self):
737
+ return len(self.data)
738
+
739
+ def __getitem__(self, idx):
740
+ data = self.data[idx]
741
+ return self.prepare_multiturn_model_inputs(
742
+ audio_array=data["audio_array"] if "audio_array" in data else None,
743
+ messages=data['messages'],
744
+ system=data["system"],
745
+ tools=data["tools"]
746
+ )
747
+
748
+
749
+
750
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
751
+
752
+ INSTRUCTION = {
753
+ "ast": [
754
+ "Translate the audio to {0}.",
755
+ "Translate the audio clip into {0}.",
756
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
757
+ "Translate the provided audio file into {0}.",
758
+ "Convert the audio speech to {0} text.",
759
+ "Write an {0} translation of the audio file.",
760
+ "Translate spoken words from the audio into {0}.",
761
+ "Create an {0} version of the audio content.",
762
+ "Produce an accurate {0} translation of the audio.",
763
+ "Extract speech from the audio and translate it to {0}.",
764
+ "Turn the audio into readable {0} text.",
765
+ "Write all spoken content from the audio in {0}.",
766
+ "Generate an {0} translation of the speech in the file.",
767
+ "Convert the recording into {0} text.",
768
+ "Accurately translate the audio recording to {0}.",
769
+ "Write down dialogue from the given audio in {0}.",
770
+ "Translate all speech in this audio file to {0}.",
771
+ "Create an accurate {0} version of the speech.",
772
+ "Perform a complete {0} translation of the audio."
773
+ ],
774
+ "asr": [
775
+ "Transcribe the audio clip into text.",
776
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
777
+ "Transcribe the provided audio file into text.",
778
+ "Convert the audio speech to text.",
779
+ "Write a transcript of the audio file.",
780
+ "Transcribe spoken words from the audio.",
781
+ "Create a text version of the audio content.",
782
+ "Produce a verbatim transcript of the audio.",
783
+ "Extract and transcribe speech from the audio.",
784
+ "Turn the audio into readable text.",
785
+ "Write all spoken words from the audio.",
786
+ "Generate a transcript of the speech in the file.",
787
+ "Convert the recording into a text transcript.",
788
+ "Accurately transcribe the audio recording.",
789
+ "Write down dialogue from the given audio.",
790
+ "Transcribe all speech in this audio file.",
791
+ "Create an accurate text version of the speech.",
792
+ "Perform a complete transcription of the audio."
793
+ ],
794
+ }
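
Note: the diff above only adds the dataset and collate utilities; nothing in this commit shows them wired together. As a minimal usage sketch (not part of the commit, and assuming the processor can be loaded from the bundled cpp/gemma_v1 checkpoint and that a CSV with "audio" and "text" columns exists at the hypothetical path below), the classes are intended to feed a standard PyTorch DataLoader through covost_collate_fn:

# Minimal usage sketch, not part of the commit. The gemma_v1 path and the
# CSV path are illustrative assumptions; adjust them to your environment.
from torch.utils.data import DataLoader
from transformers import AutoProcessor

from ASRDataset import TWCostumData, covost_collate_fn  # assumes cpp/ is on PYTHONPATH

processor = AutoProcessor.from_pretrained(
    "cpp/gemma_v1", trust_remote_code=True  # hypothetical local checkpoint path
)

# Any of the dataset classes works here; TWCostumData reads a CSV with
# "audio" (file path) and "text" (transcript) columns.
train_set = TWCostumData(processor, split="train", csv_path="data/train.csv")

loader = DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    collate_fn=covost_collate_fn,  # batches variable-length utterances
)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["input_modes"])

covost_collate_fn left-pads input_ids, labels, and token_type_ids to the longest sequence in the batch and concatenates the per-sample audio features with cat_with_pad, which is why utterances of different lengths can share one batch.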
cpp/__pycache__/ASRDataset.cpython-310.pyc ADDED
Binary file (22.5 kB). View file
 
cpp/__pycache__/speech_conformer_encoder.cpython-310.pyc ADDED
Binary file (79.5 kB). View file
 
cpp/convert_onnx.ipynb ADDED
@@ -0,0 +1,767 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/foxconnhy/miniconda3/envs/llamafactory/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "/home/jeff/.cache/huggingface/modules/transformers_modules/gemma_v1/speech_conformer_encoder.py:2798: FutureWarning: Please specify CheckpointImpl.NO_REENTRANT as CheckpointImpl.REENTRANT will soon be removed as the default and eventually deprecated.\n",
15
+ " lambda i: encoder_checkpoint_wrapper(\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stdout",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "######################## speech lora #############\n",
23
+ "######################## text lora #############\n"
24
+ ]
25
+ },
26
+ {
27
+ "name": "stderr",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00, 1.80it/s]\n",
31
+ "Some weights of Gemma3OmniForConditionalGeneration were not initialized from the model checkpoint at /mnt/data-2t/jeff/codes/llm/cpp/gemma_v1 and are newly initialized: ['language_model.model.base_model.model.layers.0.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_B.text.weight']\n",
32
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
33
+ "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
34
+ ]
35
+ }
36
+ ],
37
+ "source": [
38
+ "from io import BytesIO\n",
39
+ "import torch\n",
40
+ "import numpy as np\n",
41
+ "from transformers import AutoModel, AutoProcessor, BatchFeature,Gemma3ForCausalLM,Gemma3Processor\n",
42
+ "\n",
43
+ "# converter = opencc.OpenCC('s2tw.json')\n",
44
+ "\n",
45
+ "model_id = \"/mnt/data-2t/jeff/codes/llm/cpp/gemma_v1\"\n",
46
+ "revision = \"main\" #\"v1.0\"\n",
47
+ "\n",
48
+ "model = AutoModel.from_pretrained(\n",
49
+ " model_id, device_map=\"cpu\", revision = revision, trust_remote_code=True\n",
50
+ ").eval()\n",
51
+ "\n",
52
+ "processor = AutoProcessor.from_pretrained(\n",
53
+ " model_id, revision = revision, trust_remote_code=True\n",
54
+ ")"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "Sequential(\n",
66
+ " (0): Linear(in_features=1024, out_features=2560, bias=True)\n",
67
+ " (1): GELU(approximate='none')\n",
68
+ " (2): Linear(in_features=2560, out_features=2560, bias=True)\n",
69
+ ")"
70
+ ]
71
+ },
72
+ "execution_count": 2,
73
+ "metadata": {},
74
+ "output_type": "execute_result"
75
+ }
76
+ ],
77
+ "source": [
78
+ "model.audio_tower\n",
79
+ "model.audio_projector"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 179,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "from ASRDataset import *\n",
89
+ "pickup_dataset = MultiturnAudioDataset(split='train',processor=processor,json_path='/mnt/data-2t/jeff/codes/llm/cpp/sample_data/pickup_processed.json')"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 180,
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "name": "stdout",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "torch.Size([1, 256, 80])\n",
102
+ "torch.Size([1, 217, 80])\n",
103
+ "torch.Size([1, 77, 80])\n",
104
+ "torch.Size([1, 580, 80])\n"
105
+ ]
106
+ }
107
+ ],
108
+ "source": [
109
+ "for i in range(len(pickup_dataset)):\n",
110
+ " inp = pickup_dataset.__getitem__(i)\n",
111
+ " print(inp['input_audio_embeds'].shape)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 5,
117
+ "metadata": {},
118
+ "outputs": [
119
+ {
120
+ "data": {
121
+ "text/plain": [
122
+ "torch.Size([1, 100, 2560])"
123
+ ]
124
+ },
125
+ "execution_count": 5,
126
+ "metadata": {},
127
+ "output_type": "execute_result"
128
+ }
129
+ ],
130
+ "source": [
131
+ "inp = pickup_dataset.__getitem__(3)\n",
132
+ "fea,mask = model.audio_tower(inp['input_audio_embeds'],torch.ones(inp['input_audio_embeds'].shape[:2]))\n",
133
+ "model.audio_projector(fea).shape"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "import torch \n",
143
+ "import torch.nn as nn\n",
144
+ "from speech_conformer_encoder import ConformerEncoder\n",
145
+ "class Gemma3AudioEncoder(nn.Module):\n",
146
+ " def __init__(self,):\n",
147
+ " super().__init__()\n",
148
+ " audio_config = model.config.audio_config.to_diff_dict()\n",
149
+ " for item in ['transformers_version', 'model_type', 'torch_dtype']:\n",
150
+ " if item in audio_config:\n",
151
+ " audio_config.pop(item)\n",
152
+ " # self.audio_tower = model.audio_tower\n",
153
+ " # self.audio_projector = model.audio_projector\n",
154
+ " self.audio_tower = ConformerEncoder(**audio_config)#model.audio_tower\n",
155
+ " self.audio_projector = nn.Sequential(\n",
156
+ " nn.Linear(in_features=1024, out_features=2560, bias=True),\n",
157
+ " nn.GELU(approximate='none'),\n",
158
+ " nn.Linear(in_features=2560, out_features=2560, bias=True))#model.audio_projector\n",
159
+ " def forward(self,x,mask):\n",
160
+ " # mask = torch.ones(x.shape[:2])\n",
161
+ " x,_ = self.audio_tower(x,mask)\n",
162
+ " x = self.audio_projector(x)\n",
163
+ " return x\n",
164
+ "audio_encoder = Gemma3AudioEncoder()\n",
165
+ "import copy\n",
166
+ "audio_encoder.audio_tower.encoder_embedding=copy.deepcopy(model.audio_tower.encoder_embedding)\n",
167
+ "audio_encoder.audio_projector.load_state_dict(model.audio_projector.state_dict())\n",
168
+ "audio_encoder.audio_tower.load_state_dict(model.audio_tower.state_dict())"
169
+ ]
170
+ },
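Before exporting, it is worth a quick sanity check that the freshly built Gemma3AudioEncoder reproduces the original model's audio features once the state dicts are copied. A minimal sketch (not part of the original notebook), assuming the cells above have already created model, pickup_dataset and audio_encoder:

import torch

# Compare the standalone wrapper against the original audio_tower + audio_projector path.
inp = pickup_dataset[0]
feats = inp['input_audio_embeds']            # (1, T, 80) log-Mel features
mask = torch.ones(feats.shape[:2])

with torch.no_grad():
    ref, _ = model.audio_tower(feats, mask)  # reference: the original tower...
    ref = model.audio_projector(ref)         # ...followed by the projector
    out = audio_encoder(feats, mask)         # the wrapper that gets exported below

print(torch.allclose(ref, out, atol=1e-5))   # expected: True if the weights copied cleanly

If this prints False, one of the load_state_dict calls or the encoder_embedding deep-copy above did not take effect.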
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "import numpy as np\n",
178
+ "import onnx\n",
179
+ "import onnxruntime as ort\n",
180
+ "import onnxscript\n",
181
+ "import os\n",
182
+ "import requests\n",
183
+ "import shutil\n",
184
+ "import soundfile\n",
185
+ "import subprocess\n",
186
+ "import sys\n",
187
+ "import torch\n",
188
+ "\n",
189
+ "from onnx import helper, numpy_helper, TensorProto\n",
190
+ "from onnxruntime_genai.models.builder import create_model\n",
191
+ "from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper\n",
192
+ "from onnxscript import ir\n",
193
+ "from torch.export import Dim, export\n",
194
+ "def build_speech(outputdir='./onnx_files'):\n",
195
+ " # TorchScript export\n",
196
+ " dummy_inputs = (\n",
197
+ " torch.randn((1,97,80)),\n",
198
+ " torch.ones((1,97))\n",
199
+ " #inputs[\"input_audio_embeds\"], # audio_embeds: torch.FloatTensor\n",
200
+ " #inputs[\"audio_attention_mask\"], # audio_attention_mask: torch.BoolTensor\n",
201
+ " # inputs[\"audio_embed_sizes\"], # audio_sizes: torch.LongTensor\n",
202
+ " # inputs[\"input_mode\"], # audio_projection_mode: int\n",
203
+ " )\n",
204
+ " filename = \"phi-4-mm-speech.onnx\"\n",
205
+ "\n",
206
+ " temp_folder_1 = os.path.join(outputdir, \"speech_init_export\")\n",
207
+ " os.makedirs(temp_folder_1, exist_ok=True)\n",
208
+ "\n",
209
+ " fpath_1 = os.path.join(temp_folder_1, filename)\n",
210
+ " torch._dynamo.config.capture_scalar_outputs = True\n",
211
+ " onnx_program = torch.onnx.export(audio_encoder, dummy_inputs, fpath_1,\n",
212
+ " input_names=[\"audio_embeds\", \"audio_attention_mask\"], \n",
213
+ " output_names=[\"audio_features\"],\n",
214
+ " opset_version=20,\n",
215
+ " dynamic_axes={\n",
216
+ " \"audio_embeds\": {0:'B',1: \"L\"},\n",
217
+ " \"audio_attention_mask\": {0:'B',1: \"L\"},\n",
218
+ " },\n",
219
+ " )\n",
220
+ "\n",
221
+ "build_speech()"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 44,
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": [
230
+ "import onnxruntime as ort\n",
231
+ "import numpy as np\n",
232
+ "ort_sess = ort.InferenceSession(\"/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx\")"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 45,
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "data": {
242
+ "text/plain": [
243
+ "(2, 111, 2560)"
244
+ ]
245
+ },
246
+ "execution_count": 45,
247
+ "metadata": {},
248
+ "output_type": "execute_result"
249
+ }
250
+ ],
251
+ "source": [
252
+ "import warnings\n",
253
+ "warnings.filterwarnings('ignore')\n",
254
+ "from tqdm import tqdm\n",
255
+ "import torch\n",
256
+ "import numpy as np\n",
257
+ "a=[]\n",
258
+ "# for i in tqdm(range(10000)):\n",
259
+ "# try:\n",
260
+ "ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(1,97,80),dtype=np.float32),\n",
261
+ " # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n",
262
+ " }\n",
263
+ " )\n",
264
+ " # print(i)\n",
265
+ " # a.append(i)\n",
266
+ " # except:\n",
267
+ " # pass\n",
268
+ "ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(2,888,80),dtype=np.float32),\n",
269
+ " # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n",
270
+ " }\n",
271
+ " )[0].shape"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "metadata": {},
277
+ "source": [
278
+ "# Python inference time check"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "import time\n",
288
+ "total = 0\n",
289
+ "_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
290
+ "for i in range(100):\n",
291
+ " now = time.time()\n",
292
+ " inp = np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1)#np.array(torch.randn(1,np.random.randint(100,300),80),dtype=np.float32)\n",
293
+ " inp = _extract_features(inp,16000).reshape(1,-1,80)\n",
294
+ " now = time.time()\n",
295
+ " # inp = np.array(torch.randn(1,150,80),dtype=np.float32)\n",
296
+ " ort_sess.run(None, {\"audio_embeds\": inp,\n",
297
+ " # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n",
298
+ " })\n",
299
+ " total += time.time()-now\n",
300
+ " \n",
301
+ "total,total/100"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": 238,
307
+ "metadata": {},
308
+ "outputs": [
309
+ {
310
+ "data": {
311
+ "text/plain": [
312
+ "24240"
313
+ ]
314
+ },
315
+ "execution_count": 238,
316
+ "metadata": {},
317
+ "output_type": "execute_result"
318
+ }
319
+ ],
320
+ "source": [
321
+ "149*160+400"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 233,
327
+ "metadata": {},
328
+ "outputs": [
329
+ {
330
+ "data": {
331
+ "text/plain": [
332
+ "(1, 40917)"
333
+ ]
334
+ },
335
+ "execution_count": 233,
336
+ "metadata": {},
337
+ "output_type": "execute_result"
338
+ }
339
+ ],
340
+ "source": [
341
+ "np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1).shape"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [
349
+ {
350
+ "data": {
351
+ "text/plain": [
352
+ "(10.608245611190796, 0.10608245611190796)"
353
+ ]
354
+ },
355
+ "execution_count": 218,
356
+ "metadata": {},
357
+ "output_type": "execute_result"
358
+ }
359
+ ],
360
+ "source": [
361
+ "import time\n",
362
+ "total = 0\n",
363
+ "for i in range(100):\n",
364
+ " tmp = torch.randn(1,np.random.randint(100,300),80)\n",
365
+ " mask = torch.ones(tmp.shape[:2])\n",
366
+ " now = time.time()\n",
367
+ " audio_encoder(tmp,mask)\n",
368
+ " total += time.time()-now\n",
369
+ "total,total/100"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "markdown",
374
+ "metadata": {},
375
+ "source": [
376
+ "# C++ ERROR check"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 167,
382
+ "metadata": {},
383
+ "outputs": [
384
+ {
385
+ "data": {
386
+ "text/plain": [
387
+ "tensor([[[ 0.3246, 0.0295, 0.1076, ..., -0.1125, -0.0894, -0.3800],\n",
388
+ " [ 0.3267, -0.2442, 0.2653, ..., 0.7783, -0.6049, -1.0858],\n",
389
+ " [ 0.1797, 0.0438, 0.9673, ..., 0.5126, -0.5657, -0.7050],\n",
390
+ " ...,\n",
391
+ " [ 0.0261, -0.0324, 0.0230, ..., -0.1303, 0.0343, 0.1486],\n",
392
+ " [ 0.1655, -0.3327, 0.4232, ..., 0.0513, 0.4222, -0.3645],\n",
393
+ " [ 0.1147, -0.1201, 0.4198, ..., 0.6170, 0.0838, -0.1409]]],\n",
394
+ " grad_fn=<ViewBackward0>)"
395
+ ]
396
+ },
397
+ "execution_count": 167,
398
+ "metadata": {},
399
+ "output_type": "execute_result"
400
+ }
401
+ ],
402
+ "source": [
403
+ "inp = pickup_dataset.__getitem__(3)\n",
404
+ "fea,mask = model.audio_tower(inp['input_audio_embeds'],torch.ones(inp['input_audio_embeds'].shape[:2]))\n",
405
+ "model.audio_projector(fea)"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": 201,
411
+ "metadata": {},
412
+ "outputs": [
413
+ {
414
+ "data": {
415
+ "text/plain": [
416
+ "array([[[ 0.13004433, -0.06643961, 0.01333247, ..., -0.05643693,\n",
417
+ " -0.23922557, 0.569423 ],\n",
418
+ " [-0.75552 , -0.05047493, -0.82725084, ..., 0.32261163,\n",
419
+ " -0.14968234, -0.7078437 ],\n",
420
+ " [-0.6673857 , 0.33906737, -0.6191502 , ..., 0.04259709,\n",
421
+ " -0.01194861, 0.27635992],\n",
422
+ " ...,\n",
423
+ " [ 0.02916821, -0.03163592, 0.02736526, ..., -0.12979224,\n",
424
+ " 0.03317374, 0.15346158],\n",
425
+ " [-0.8559882 , -0.5196625 , 0.2549707 , ..., 0.28192428,\n",
426
+ " 1.4099622 , -0.15940394],\n",
427
+ " [-0.20253824, -0.30478072, -0.6786582 , ..., 0.08860758,\n",
428
+ " -0.12145798, 0.525889 ]]], dtype=float32)"
429
+ ]
430
+ },
431
+ "execution_count": 201,
432
+ "metadata": {},
433
+ "output_type": "execute_result"
434
+ }
435
+ ],
436
+ "source": [
437
+ "inp = pickup_dataset.__getitem__(0)\n",
438
+ "res = ort_sess.run(None, {\"audio_embeds\": np.array(inp['input_audio_embeds'],dtype=np.float32),\n",
439
+ " # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n",
440
+ " }\n",
441
+ " )[0]\n",
442
+ "res"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "execution_count": 208,
448
+ "metadata": {},
449
+ "outputs": [
450
+ {
451
+ "data": {
452
+ "text/plain": [
453
+ "array([[[ 0.130969 , -0.0697925, 0.0150866, ..., -0.0559536,\n",
454
+ " -0.239062 , 0.567436 ],\n",
455
+ " [-0.753288 , -0.0582227, -0.825365 , ..., 0.320587 ,\n",
456
+ " -0.153626 , -0.709664 ],\n",
457
+ " [-0.656874 , 0.342632 , -0.607641 , ..., 0.0383743,\n",
458
+ " -0.0218912, 0.269968 ],\n",
459
+ " ...,\n",
460
+ " [ 0.0291714, -0.0316175, 0.027369 , ..., -0.129825 ,\n",
461
+ " 0.033166 , 0.153453 ],\n",
462
+ " [-0.854555 , -0.530883 , 0.258313 , ..., 0.279057 ,\n",
463
+ " 1.40658 , -0.159066 ],\n",
464
+ " [-0.197598 , -0.306157 , -0.67907 , ..., 0.0915015,\n",
465
+ " -0.124402 , 0.52159 ]]])"
466
+ ]
467
+ },
468
+ "execution_count": 208,
469
+ "metadata": {},
470
+ "output_type": "execute_result"
471
+ }
472
+ ],
473
+ "source": [
474
+ "f = open('/mnt/data-2t/jeff/codes/llm/cpp/inference/f0.txt')\n",
475
+ "content = f.readlines()\n",
476
+ "f.close()\n",
477
+ "audio_fea_cpp = np.array([float(i) for i in content[0].split(',')]).reshape(1,-1,2560)\n",
478
+ "audio_fea_cpp"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": 202,
484
+ "metadata": {},
485
+ "outputs": [
486
+ {
487
+ "data": {
488
+ "text/plain": [
489
+ "(array([[[0.917797, 1.33496 , 1.9894 , ..., 6.60723 , 6.95787 ,\n",
490
+ " 7.20139 ],\n",
491
+ " [0. , 0. , 0. , ..., 5.99914 , 6.11214 ,\n",
492
+ " 6.40908 ],\n",
493
+ " [0. , 0. , 0. , ..., 5.1184 , 5.36291 ,\n",
494
+ " 5.14623 ],\n",
495
+ " ...,\n",
496
+ " [0. , 0. , 0. , ..., 6.25256 , 6.29312 ,\n",
497
+ " 7.05511 ],\n",
498
+ " [0. , 0. , 0. , ..., 6.49829 , 6.7198 ,\n",
499
+ " 7.08144 ],\n",
500
+ " [0. , 0. , 1.08376 , ..., 5.43068 , 5.97577 ,\n",
501
+ " 6.35748 ]]]),\n",
502
+ " tensor([[[0.8826, 1.3054, 1.9652, ..., 6.6069, 6.9578, 7.2011],\n",
503
+ " [0.0000, 0.0000, 0.0000, ..., 5.9991, 6.1121, 6.4091],\n",
504
+ " [0.0000, 0.0000, 0.0000, ..., 5.1147, 5.3624, 5.1428],\n",
505
+ " ...,\n",
506
+ " [0.0000, 0.0000, 0.0000, ..., 6.2526, 6.2931, 7.0548],\n",
507
+ " [0.0000, 0.0000, 0.0000, ..., 6.4981, 6.7198, 7.0807],\n",
508
+ " [0.0000, 0.0000, 1.1479, ..., 5.4311, 5.9743, 6.3568]]]))"
509
+ ]
510
+ },
511
+ "execution_count": 202,
512
+ "metadata": {},
513
+ "output_type": "execute_result"
514
+ }
515
+ ],
516
+ "source": [
517
+ "f = open('/mnt/data-2t/jeff/codes/llm/cpp/inference/matrix_output.txt')\n",
518
+ "txtlines = f.readlines()\n",
519
+ "f.close()\n",
520
+ "inp_emb_cpp = np.array([float(i) for l in txtlines for i in l.split(',')]).reshape(1,-1,80)\n",
521
+ "inp_emb_cpp,pickup_dataset.__getitem__(0)['input_audio_embeds']"
522
+ ]
523
+ },
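Rather than judging agreement from a handful of printed elements, the C++/Python gap can be quantified directly. A small follow-up sketch, assuming res, inp_emb_cpp and audio_fea_cpp from the cells above are still in scope and correspond to the same clip (item 0) with matching shapes:

import numpy as np

# Python-side log-Mel features for the same clip, shape (1, T, 80).
ref_inp = pickup_dataset[0]['input_audio_embeds'].numpy()

# Maximum absolute deviation of the C++ front end vs. the Python preprocessor ...
print("feature max |diff|:", np.abs(inp_emb_cpp - ref_inp).max())
# ... and of the C++ encoder output vs. the Python/ONNX encoder output.
print("encoder max |diff|:", np.abs(audio_fea_cpp - res).max())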
524
+ {
525
+ "cell_type": "markdown",
526
+ "metadata": {},
527
+ "source": [
528
+ "# Python preprocessor"
529
+ ]
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "metadata": {},
535
+ "outputs": [
536
+ {
537
+ "data": {
538
+ "text/plain": [
539
+ "(353, 80)"
540
+ ]
541
+ },
542
+ "execution_count": 66,
543
+ "metadata": {},
544
+ "output_type": "execute_result"
545
+ }
546
+ ],
547
+ "source": [
548
+ "# modify the code : \n",
549
+ "# 1. input model and input pcm from args. \n",
550
+ "# 2. add model input preprocessor by following python code. The wav input of _extract_features which is an audio array\n",
551
+ "# 3. the onnx model input is [batch,frames,feature size] = [-1,-1,80]\n",
552
+ "\n",
553
+ "def _extract_spectrogram(wav, fs):\n",
554
+ " \"\"\"Extract spectrogram features from waveform.\n",
555
+ " Args:\n",
556
+ " wav (1D array): waveform of the input\n",
557
+ " fs (int): sampling rate of the waveform, 16000.\n",
558
+ " Output:\n",
559
+ " log_fbank (2D array): a TxD matrix of log Mel filterbank features.\n",
560
+ " D=80, and T is the number of frames.\n",
561
+ " \"\"\"\n",
562
+ " if wav.ndim > 1:\n",
563
+ " wav = np.squeeze(wav)\n",
564
+ "\n",
565
+ " # by default, we extract the mean if stereo\n",
566
+ " if len(wav.shape) == 2:\n",
567
+ " wav = wav.mean(1)\n",
568
+ "\n",
569
+ " preemphasis = 0.97\n",
570
+ " n_fft = 512\n",
571
+ " win_length = 400\n",
572
+ " hop_length = 160\n",
573
+ " fft_window = np.hamming(400)\n",
574
+ "\n",
575
+ " # Spec 1: SpeechLib cut remaining sample insufficient for a hop\n",
576
+ " n_batch = (wav.shape[0] - win_length) // hop_length + 1\n",
577
+ " # Here we don't use stride_tricks since the input array may not satisfy\n",
578
+ " # memory layout requirement and we need writeable output\n",
579
+ " # Here we only use list of views before copy to desination\n",
580
+ " # so it is more efficient than broadcasting\n",
581
+ " y_frames = np.array(\n",
582
+ " [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],\n",
583
+ " dtype=np.float32,\n",
584
+ " )\n",
585
+ "\n",
586
+ " # Spec 2: SpeechLib applies preemphasis within each batch\n",
587
+ " y_frames_prev = np.roll(y_frames, 1, axis=1)\n",
588
+ " y_frames_prev[:, 0] = y_frames_prev[:, 1]\n",
589
+ " y_frames = (y_frames - preemphasis * y_frames_prev) * 32768\n",
590
+ "\n",
591
+ " S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)\n",
592
+ " spec = np.abs(S).astype(np.float32)\n",
593
+ " return spec\n",
594
+ "def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):\n",
595
+ " \"\"\"Create a Mel filter-bank the same as SpeechLib FbankFC.\n",
596
+ "\n",
597
+ " Args:\n",
598
+ " sample_rate (int): Sample rate in Hz. number > 0 [scalar]\n",
599
+ " n_fft (int): FFT size. int > 0 [scalar]\n",
600
+ " n_mel (int): Mel filter size. int > 0 [scalar]\n",
601
+ " fmin (float): lowest frequency (in Hz). If None use 0.0.\n",
602
+ " float >= 0 [scalar]\n",
603
+ " fmax: highest frequency (in Hz). If None use sample_rate / 2.\n",
604
+ " float >= 0 [scalar]\n",
605
+ "\n",
606
+ " Returns\n",
607
+ " out (numpy.ndarray): Mel transform matrix\n",
608
+ " [shape=(n_mels, 1 + n_fft/2)]\n",
609
+ " \"\"\"\n",
610
+ "\n",
611
+ " bank_width = int(n_fft // 2 + 1)\n",
612
+ " if fmax is None:\n",
613
+ " fmax = sample_rate / 2\n",
614
+ " if fmin is None:\n",
615
+ " fmin = 0\n",
616
+ " assert fmin >= 0, \"fmin cannot be negtive\"\n",
617
+ " assert fmin < fmax <= sample_rate / 2, \"fmax must be between (fmin, samplerate / 2]\"\n",
618
+ "\n",
619
+ " def mel(f):\n",
620
+ " return 1127.0 * np.log(1.0 + f / 700.0)\n",
621
+ "\n",
622
+ " def bin2mel(fft_bin):\n",
623
+ " return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))\n",
624
+ "\n",
625
+ " def f2bin(f):\n",
626
+ " return int((f * n_fft / sample_rate) + 0.5)\n",
627
+ "\n",
628
+ " # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]\n",
629
+ " klo = f2bin(fmin) + 1\n",
630
+ " khi = f2bin(fmax)\n",
631
+ "\n",
632
+ " khi = max(khi, klo)\n",
633
+ "\n",
634
+ " # Spec 2: SpeechLib uses trianges in Mel space\n",
635
+ " mlo = mel(fmin)\n",
636
+ " mhi = mel(fmax)\n",
637
+ " m_centers = np.linspace(mlo, mhi, n_mels + 2)\n",
638
+ " ms = (mhi - mlo) / (n_mels + 1)\n",
639
+ "\n",
640
+ " matrix = np.zeros((n_mels, bank_width), dtype=np.float32)\n",
641
+ " for m in range(0, n_mels):\n",
642
+ " left = m_centers[m]\n",
643
+ " center = m_centers[m + 1]\n",
644
+ " right = m_centers[m + 2]\n",
645
+ " for fft_bin in range(klo, khi):\n",
646
+ " mbin = bin2mel(fft_bin)\n",
647
+ " if left < mbin < right:\n",
648
+ " matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms\n",
649
+ "\n",
650
+ " return matrix\n",
651
+ "\n",
652
+ "def _extract_features(wav, fs):\n",
653
+ " \"\"\"Extract log filterbank features from waveform.\n",
654
+ " Args:\n",
655
+ " wav (1D array): waveform of the input\n",
656
+ " fs (int): sampling rate of the waveform, 16000 or 8000.\n",
657
+ " If fs=8000, the waveform will be resampled to 16000Hz.\n",
658
+ " Output:\n",
659
+ " log_fbank (2D array): a TxD matrix of log Mel filterbank features.\n",
660
+ " D=80, and T is the number of frames.\n",
661
+ " \"\"\"\n",
662
+ " spec = _extract_spectrogram(wav, fs)\n",
663
+ " spec_power = spec**2\n",
664
+ "\n",
665
+ " fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)\n",
666
+ " log_fbank = np.log(fbank_power).astype(np.float32)\n",
667
+ "\n",
668
+ " return log_fbank\n",
669
+ "\n",
670
+ "## example \n",
671
+ "## input shape of arr is [1, 56832], output shape will be (353,80)\n",
672
+ "_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
673
+ "output = _extract_features(arr,16000)"
674
+ ]
675
+ },
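The example at the end of the cell above relies on an arr that is never defined in this notebook (it survived from an earlier session). A minimal sketch of producing it from a wav file; the path is hypothetical and the file is assumed to already be mono-compatible 16 kHz audio:

import numpy as np
import soundfile as sf

wav, sr = sf.read("/path/to/sample_16k.wav", dtype="float32")  # hypothetical path, assumed 16 kHz
if wav.ndim > 1:
    wav = wav.mean(axis=1)                 # downmix stereo to mono
arr = wav.reshape(1, -1)                   # (1, n_samples), matching the shape noted above

_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000 // 2 - 80 - 230).T
log_fbank = _extract_features(arr, 16000)  # (T, 80), with T = (n_samples - 400) // 160 + 1
print(log_fbank.shape)

This mirrors what the C++ preprocessor is expected to produce before feeding the [-1, -1, 80] ONNX input.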
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": 227,
679
+ "metadata": {},
680
+ "outputs": [
681
+ {
682
+ "data": {
683
+ "text/plain": [
684
+ "(256, 80)"
685
+ ]
686
+ },
687
+ "execution_count": 227,
688
+ "metadata": {},
689
+ "output_type": "execute_result"
690
+ }
691
+ ],
692
+ "source": [
693
+ "_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
694
+ "output = _extract_features(arr,16000)\n",
695
+ "output.shape"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "execution_count": null,
701
+ "metadata": {},
702
+ "outputs": [
703
+ {
704
+ "data": {
705
+ "text/plain": [
706
+ "256"
707
+ ]
708
+ },
709
+ "execution_count": 228,
710
+ "metadata": {},
711
+ "output_type": "execute_result"
712
+ }
713
+ ],
714
+ "source": [
715
+ "(41239-400)//160+1 100~300"
716
+ ]
717
+ },
718
+ {
719
+ "cell_type": "code",
720
+ "execution_count": 229,
721
+ "metadata": {},
722
+ "outputs": [
723
+ {
724
+ "data": {
725
+ "text/plain": [
726
+ "(16240, 48240)"
727
+ ]
728
+ },
729
+ "execution_count": 229,
730
+ "metadata": {},
731
+ "output_type": "execute_result"
732
+ }
733
+ ],
734
+ "source": [
735
+ "99*160+400,299*160+400"
736
+ ]
737
+ },
738
+ {
739
+ "cell_type": "code",
740
+ "execution_count": null,
741
+ "metadata": {},
742
+ "outputs": [],
743
+ "source": []
744
+ }
745
+ ],
746
+ "metadata": {
747
+ "kernelspec": {
748
+ "display_name": "llamafactory",
749
+ "language": "python",
750
+ "name": "python3"
751
+ },
752
+ "language_info": {
753
+ "codemirror_mode": {
754
+ "name": "ipython",
755
+ "version": 3
756
+ },
757
+ "file_extension": ".py",
758
+ "mimetype": "text/x-python",
759
+ "name": "python",
760
+ "nbconvert_exporter": "python",
761
+ "pygments_lexer": "ipython3",
762
+ "version": "3.10.16"
763
+ }
764
+ },
765
+ "nbformat": 4,
766
+ "nbformat_minor": 2
767
+ }
cpp/convert_tensorRT.ipynb ADDED
File without changes
cpp/gemma_v1/ASRDataset.py ADDED
@@ -0,0 +1,793 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import torch
12
+ import sacrebleu
13
+
14
+ from datasets import load_dataset
15
+ from torch.utils.data import Dataset, ConcatDataset
16
+ from tqdm import tqdm
17
+ from transformers import (
18
+ BatchFeature,
19
+ )
20
+ import pandas as pd
21
+ import soundfile as sf
22
+ from datasets import Audio
23
+ import random
24
+ from copy import deepcopy
25
+ import torchaudio
26
+
27
+ ANSWER_SUFFIX = "<end_of_turn>"
28
+ _IGNORE_INDEX = -100
29
+ class BaseAudioDataset(Dataset):
30
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
31
+ self.processor = processor
32
+ self.training = "train" in split or 'other' in split
33
+ self.debug = debug
34
+ self.sampling_rate = sampling_rate
35
+ self.name = ""
36
+
37
+ def set_dataset_name(self, name):
38
+ self.name = name
39
+
40
+ @staticmethod
41
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
42
+ original_size = len(data)
43
+
44
+ data = data.cast_column(audio_field, Audio(decode=False))
45
+
46
+ def identify_corrupted_files(example):
47
+ try:
48
+ sf.read(example[audio_field]["path"])
49
+
50
+ for field in text_fields:
51
+ if field in example and example[field].replace('"', '') == "":
52
+ return False
53
+ return True
54
+ except Exception:
55
+ return False
56
+
57
+ data = data.filter(identify_corrupted_files, num_proc=16)
58
+ validated_size = len(data)
59
+
60
+ # Audio Decoding
61
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
62
+
63
+ if debug:
64
+ print(f"Dataset: {dataset_name}")
65
+ print(f"Original data nums: {original_size}")
66
+ print(f"After filtering data nums: {validated_size}")
67
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
68
+
69
+ return data
70
+
71
+ @staticmethod
72
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
73
+ original_size = len(data)
74
+
75
+ def filter_audio_by_length(example):
76
+ try:
77
+ audio = example[audio_field]['array']
78
+ channel = 1
79
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
80
+ channel = audio.ndim
81
+ audio = audio.squeeze()
82
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
83
+ return min_sec <= audio_length <= max_sec
84
+ except Exception as e:
85
+ if debug:
86
+ print(f"Error : {str(e)[:100]}... - sample excluded")
87
+ return False
88
+
89
+ data = data.filter(filter_audio_by_length, num_proc=16)
90
+ filtered_size = len(data)
91
+
92
+ if debug:
93
+ print(f"Before Length Filtering data nums: {original_size}")
94
+ print(f"After Length Filtering data nums: {filtered_size}")
95
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
96
+
97
+ return data
98
+
99
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
100
+ user_message = {
101
+ 'role': 'user',
102
+ 'content': '<start_of_audio>' + instruction,
103
+ }
104
+ prompt = self.processor.tokenizer.apply_chat_template(
105
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
106
+ )
107
+
108
+ inputs = self.processor(
109
+ text=prompt,
110
+ audio=[audio_array],
111
+ add_special_tokens=False,
112
+ return_tensors='pt'
113
+ )
114
+
115
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
116
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
117
+
118
+ if self.debug:
119
+ self.debug = False
120
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
121
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
122
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
123
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
124
+
125
+ if self.training:
126
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
127
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
128
+ labels[:, -answer_ids.shape[1]:] = answer_ids
129
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
130
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
131
+ else:
132
+ input_ids = inputs.input_ids
133
+ labels = answer_ids
134
+ token_type_ids = inputs.token_type_ids
135
+
136
+ return {
137
+ 'input_ids': input_ids,
138
+ 'labels': labels,
139
+ 'token_type_ids': token_type_ids,
140
+ 'input_audio_embeds': inputs.input_audio_embeds,
141
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
142
+ 'input_modes': inputs.input_modes,
143
+ }
144
+
145
+ # Libri Speech Dataset Class
146
+ class LibriSpeechDataset(BaseAudioDataset):
147
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
148
+ super().__init__(processor, split, sampling_rate, debug)
149
+
150
+ self.set_dataset_name(f"LibriSpeech_{subset}")
151
+ # only ASR
152
+ self.ast = False
153
+ self.lang = "en"
154
+
155
+ # load dataset
156
+ self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
157
+ subset,
158
+ split=split,
159
+ trust_remote_code=True,
160
+ cache_dir=Path("/mnt/jeff/InCar/data")
161
+ )
162
+
163
+ # (Optional) Audio length Filtering
164
+ self.data = self.filter_by_audio_length(self.data, "audio")
165
+
166
+ # Instruction Setting
167
+ self.instruction = random.choice(INSTRUCTION["asr"])
168
+
169
+ def __len__(self):
170
+ return len(self.data)
171
+
172
+ def __getitem__(self, idx):
173
+ data = self.data[idx]
174
+
175
+ # Libri Speech is only for ASR
176
+ answer_text = data["text"].replace('"', '')
177
+
178
+ return self.prepare_model_inputs(
179
+ data["audio"]["array"],
180
+ self.instruction,
181
+ answer_text
182
+ )
183
+
184
+ # common_voice_16_1 dataset
185
+ class CommonVoiceDataset(BaseAudioDataset):
186
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
187
+ super().__init__(processor, split, sampling_rate, debug)
188
+
189
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
190
+ # only ASR
191
+ self.ast = False
192
+ self.lang=source_lang
193
+
194
+ # load dataset
195
+ if source_lang=="zh-TW":
196
+ data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
197
+ else:
198
+ data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
199
+ self.data = load_dataset(data_path,
200
+ source_lang,
201
+ split=split,
202
+ trust_remote_code=True,
203
+ cache_dir=Path("/mnt/jeff/InCar/data")
204
+ )
205
+ def prepare_dataset(batch):
206
+ """Function to preprocess the dataset with the .map method"""
207
+ transcription = batch["sentence"]
208
+
209
+ if transcription.startswith('"') and transcription.endswith('"'):
210
+ # we can remove trailing quotation marks as they do not affect the transcription
211
+ transcription = transcription[1:-1]
212
+
213
+ if transcription[-1] not in [".", "?", "!"]:
214
+ # append a full-stop to sentences that do not end in punctuation
215
+ transcription = transcription + "."
216
+
217
+ batch["sentence"] = transcription
218
+
219
+ return batch
220
+
221
+
222
+ import opencc
223
+ converter = opencc.OpenCC('s2tw.json')
224
+ def To_zhTW(batch):
225
+
226
+ transcription = converter.convert(batch["sentence"])
227
+ batch["sentence"] = transcription
228
+
229
+ return batch
230
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
231
+ if source_lang=='zh-CN':
232
+ self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
233
+
234
+
235
+ # (Optional) Audio length Filtering
236
+ self.data = self.filter_by_audio_length(self.data, "audio")
237
+
238
+ if source_lang == "zh-TW" and split=='train':
239
+ import torchaudio
240
+ from torchaudio import transforms
241
+ import copy
242
+ import pickle
243
+ import os
244
+ def subsample(batch):
245
+ batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
246
+ batch['audio']['sampling_rate']=16000
247
+ return batch
248
+ def TW_data_augment_fast(batch):
249
+ speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
250
+ new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
251
+ batch['audio']['array'] = new_array_fast
252
+ return batch
253
+ def TW_data_augment_slow(batch):
254
+ speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
255
+ new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
256
+ batch['audio']['array'] = new_array_slow
257
+ return batch
258
+ # data = self.data.map(subsample, num_proc=1, desc="subsample")
259
+ fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
260
+ if not os.path.exists(fast_path):
261
+ data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
262
+ with open(fast_path,'wb') as f:
263
+ pickle.dump(data_fast,f)
264
+ else:
265
+ with open(fast_path,'rb') as f:
266
+ data_fast=pickle.load(f)
267
+
268
+ slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
269
+ if not os.path.exists(slow_path):
270
+ data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
271
+ with open(slow_path,'wb') as f:
272
+ pickle.dump(data_slow,f)
273
+ else:
274
+ with open(slow_path,'rb') as f:
275
+ data_slow=pickle.load(f)
276
+ self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
277
+
278
+ # Instruction Setting
279
+ self.instruction = random.choice(INSTRUCTION["asr"])
280
+
281
+ def __len__(self):
282
+ return len(self.data)
283
+
284
+ def __getitem__(self, idx):
285
+ data = self.data[idx]
286
+
287
+ answer_text = data["sentence"]
288
+ return self.prepare_model_inputs(
289
+ data["audio"]["array"],
290
+ self.instruction,
291
+ answer_text
292
+ )
293
+
294
+
295
+ # Fleurs Dataset Class
296
+ class FleursDataset(BaseAudioDataset):
297
+ def __init__(self, processor, split, source_lang, target_lang=None,
298
+ mode="asr", sampling_rate=16000, debug=False):
299
+ super().__init__(processor, split, sampling_rate, debug)
300
+
301
+ self.set_dataset_name("Fleurs")
302
+ # Mode Setting (ASR or AST)
303
+ if mode not in ["asr", "ast"]:
304
+ raise ValueError("mode must be 'asr' or 'ast'.")
305
+
306
+ self.mode = mode
307
+ self.ast = (mode == "ast")
308
+ self.source_lang = source_lang
309
+
310
+ # Language name mapping (expand if needed)
311
+ self.lang_names = {
312
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
313
+ }
314
+
315
+ # load dataset - source language dataset
316
+ self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
317
+ source_lang,
318
+ split=split,
319
+ trust_remote_code=True,
320
+ cache_dir=Path("/mnt/jeff/InCar/data")
321
+ )
322
+ import opencc
323
+ converter = opencc.OpenCC('s2tw.json')
324
+ def prepare_dataset(batch):
325
+ transcription = converter.convert(batch["transcription"])
326
+ batch["transcription"] = transcription
327
+
328
+ return batch
329
+ if (source_lang=="cmn_hans_cn"):
330
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
331
+
332
+ # (Optional) Audio length Filtering
333
+ self.data = self.filter_by_audio_length(self.data, "audio")
334
+ self.target_lang_name = ""
335
+ # When AST mode, load target language dataset.
336
+ if self.ast:
337
+ if target_lang is None:
338
+ raise ValueError("AST mode requires target_lang.")
339
+
340
+ self.target_lang = target_lang
341
+ self.lang = f"{source_lang}_{target_lang}"
342
+
343
+ # load dataset - target language dataset (for translation)
344
+ target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
345
+ target_lang,
346
+ split=split,
347
+ trust_remote_code=True,
348
+ cache_dir=Path("/mnt/jeff/InCar/data")
349
+ )
350
+ if target_lang=="cmn_hans_cn":
351
+ target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
352
+ source_dict = {item['id']: item for item in self.data}
353
+ target_dict = {item['id']: item for item in target_data}
354
+
355
+ # only Common ID, add translation fields
356
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
357
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
358
+ self.data = [
359
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
360
+ for id in common_ids
361
+ ]
362
+
363
+ # Instruction Setting - use target language name
364
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
365
+ self.instruction = random.choice(INSTRUCTION["ast"])
366
+ else:
367
+ # ASR mode
368
+ self.lang = source_lang
369
+ self.instruction = random.choice(INSTRUCTION["asr"])
370
+
371
+ if self.debug:
372
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
373
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
374
+ if self.ast:
375
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
376
+ print(f"dataset size: {len(self.data)}")
377
+
378
+ def __len__(self):
379
+ return len(self.data)
380
+
381
+ def __getitem__(self, idx):
382
+ data = self.data[idx]
383
+ audio_array = data["audio"]["array"]
384
+
385
+ if self.ast:
386
+ answer_text = data["translation"]
387
+ else:
388
+ answer_text = data["transcription"]
389
+
390
+ return self.prepare_model_inputs(
391
+ audio_array,
392
+ self.instruction.format(self.target_lang_name),
393
+ answer_text
394
+ )
395
+
396
+ class TWCostumData(BaseAudioDataset):
397
+
398
+ def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
399
+ super().__init__(processor, split, sampling_rate, debug)
400
+ import pandas as pd
401
+ from datasets import Dataset, Audio
402
+
403
+
404
+ df = pd.read_csv(csv_path).fillna('')
405
+
406
+
407
+ self.set_dataset_name(f"TWCostumData")
408
+ self.data = Dataset.from_dict(
409
+ {
410
+ "audio": [audio for audio in df['audio']],
411
+ "sentence": [text for text in df['text']]
412
+ }
413
+ ).cast_column("audio", Audio(sampling_rate=16000))
414
+
415
+ # Instruction Setting
416
+ self.instruction = random.choice(INSTRUCTION["asr"])
417
+
418
+ def __len__(self):
419
+ return len(self.data)
420
+
421
+ def __getitem__(self, idx):
422
+ data = self.data[idx]
423
+
424
+ answer_text = data["sentence"]
425
+ return self.prepare_model_inputs(
426
+ data["audio"]["array"],
427
+ self.instruction,
428
+ answer_text
429
+ )
430
+ def covost_collate_fn(batch):
431
+ input_ids_list = []
432
+ labels_list = []
433
+ token_type_ids_list = []
434
+ input_audio_embeds_list = []
435
+ audio_embed_sizes_list = []
436
+ audio_attention_mask_list = []
437
+ input_modes_list = []
438
+ audio_paths = []
439
+ for inputs in batch:
440
+ if 'audio_path' in inputs:
441
+ audio_paths.append(inputs['audio_path'])
442
+ input_ids_list.append(inputs['input_ids'][0])
443
+ labels_list.append(inputs['labels'][0])
444
+ token_type_ids_list.append(inputs['token_type_ids'][0])
445
+ if inputs['input_modes']==2:
446
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
447
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
448
+ audio_attention_mask_list.append(
449
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
450
+ )
451
+ # else:
452
+ # input_audio_embeds_list.append(None)
453
+ # audio_embed_sizes_list.append(None)
454
+ # audio_attention_mask_list.append(None)
455
+ input_modes_list.append(inputs['input_modes'])
456
+ # try:
457
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
458
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
459
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
460
+ audio_attention_mask = (
461
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
462
+ if len(audio_attention_mask_list) > 1
463
+ else None
464
+ )
465
+ # except Exception as e:
466
+ # print(e)
467
+ # print(input_ids_list)
468
+ # print(labels_list)
469
+ # raise
470
+ attention_mask = (input_ids != 0).long()
471
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
472
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
473
+ input_modes = torch.cat(input_modes_list)
474
+ if len(audio_paths)>0:
475
+ return BatchFeature(
476
+ {
477
+ "audio_path": audio_paths,
478
+ 'input_ids': input_ids,
479
+ 'labels': labels,
480
+ 'token_type_ids': token_type_ids,
481
+ 'attention_mask': attention_mask,
482
+ 'input_audio_embeds': input_audio_embeds,
483
+ 'audio_embed_sizes': audio_embed_sizes,
484
+ 'audio_attention_mask': audio_attention_mask,
485
+ 'input_modes': input_modes,
486
+ }
487
+ )
488
+ else:
489
+ return BatchFeature(
490
+ {
491
+ 'input_ids': input_ids,
492
+ 'labels': labels,
493
+ 'token_type_ids': token_type_ids,
494
+ 'attention_mask': attention_mask,
495
+ 'input_audio_embeds': input_audio_embeds,
496
+ 'audio_embed_sizes': audio_embed_sizes,
497
+ 'audio_attention_mask': audio_attention_mask,
498
+ 'input_modes': input_modes,
499
+ }
500
+ )
501
+
502
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
503
+ """
504
+ Pad a list of sequences to the same length.
505
+ sequences: list of tensors in [seq_len, *] shape
506
+ """
507
+ assert padding_side in ['right', 'left']
508
+ max_size = sequences[0].size()
509
+ trailing_dims = max_size[1:]
510
+ max_len = max(len(seq) for seq in sequences)
511
+ batch_size = len(sequences)
512
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
513
+ for i, seq in enumerate(sequences):
514
+ length = seq.size(0)
515
+ if padding_side == 'right':
516
+ output.data[i, :length] = seq
517
+ else:
518
+ output.data[i, -length:] = seq
519
+ return output
520
+
521
+ def cat_with_pad(tensors, dim, padding_value=0):
522
+ """
523
+ cat along dim, while pad to max for all other dims
524
+ """
525
+ ndim = tensors[0].dim()
526
+ assert all(
527
+ t.dim() == ndim for t in tensors[1:]
528
+ ), 'All tensors must have the same number of dimensions'
529
+
530
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
531
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
532
+ output = tensors[0].new_full(out_size, padding_value)
533
+
534
+ index = 0
535
+ for t in tensors:
536
+ # Create a slice list where every dimension except dim is full slice
537
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
538
+ # Update only the concat dimension slice
539
+ slices[dim] = slice(index, index + t.shape[dim])
540
+
541
+ output[slices] = t
542
+ index += t.shape[dim]
543
+
544
+ return output
545
+
546
+
547
+
548
+ class MultiturnAudioDataset(BaseAudioDataset):
549
+ def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
550
+ super().__init__(processor, split, sampling_rate, debug)
551
+ from llamafactory.data.template import Llama2Template,parse_template
552
+ from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
553
+ from llamafactory.data.mm_plugin import get_mm_plugin
554
+ import json
555
+ self.train=False
556
+ self.text_only=text_only
557
+ with open(json_path) as f:
558
+ js_data = json.load(f)
559
+ if split=='train':
560
+ self.train=True
561
+ js_data = js_data[:int(len(js_data)*0.8)]
562
+ else:
563
+ js_data = js_data[-int(len(js_data)*0.2):]
564
+ for conv in js_data:
565
+ for mess in conv['conversations']:
566
+ if 'audio_path' in mess:
567
+ mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
568
+ default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
569
+ self.template=Llama2Template(
570
+ format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
571
+ format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
572
+ format_system=StringFormatter(slots=["{{content}}\n\n"]),
573
+ format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
574
+ format_tools = ToolFormatter(tool_format="default"),
575
+ format_observation=StringFormatter(
576
+ slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
577
+ ),
578
+ default_system=default_system,
579
+ thought_words=("<think>", "</think>"),
580
+ efficient_eos=False,
581
+ replace_eos=False,
582
+ replace_jinja_template=False,
583
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
584
+ stop_words=["<end_of_turn>"],
585
+ mm_plugin=get_mm_plugin(name="base"),
586
+ )
587
+
588
+ self.set_dataset_name(f"MultiturnCostumData")
589
+
590
+
591
+ self.data = []
592
+ self.text_only_data = []
593
+ for conv in js_data:
594
+ tools = conv['tools'] if 'tools' in conv else ""
595
+ system = conv['system'] if 'system' in conv else default_system
596
+ tmp = {
597
+ 'tools':tools,
598
+ 'system':system,
599
+ 'messages':[],
600
+ }
601
+ for i,mess in enumerate(conv['conversations']):
602
+ tmp['messages'].append(mess)
603
+ if mess['from']=='human':
604
+ tmp['messages'].append(conv['conversations'][i+1])
605
+ d = deepcopy(tmp)
606
+ d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
607
+ self.data.append(d)
608
+ if self.text_only:
609
+ self.text_only_data.append(deepcopy(tmp))
610
+ tmp['messages'].pop()
611
+ elif mess['from']=='observation':
612
+ tmp['messages'].append(conv['conversations'][i+1])
613
+ d = deepcopy(tmp)
614
+ self.text_only_data.append(d)
615
+ tmp['messages'].pop()
616
+ if text_only:
617
+ self.data=self.text_only_data
618
+
619
+
620
+ def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
621
+ ANSWER_SUFFIX = "<end_of_turn>"
622
+ prompt = ""
623
+ answer_text = ""
624
+ user_transcribe = ""
625
+ audio_paths = []
626
+ for i, message in enumerate(messages):
627
+ elements = []
628
+
629
+ system_text = ""
630
+ if i == 0:
631
+ elements += self.template.format_prefix.apply()
632
+ if system or tools:
633
+ tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
634
+ system_text = self.template.format_system.apply(content=(system + tool_text))[0]
635
+
636
+ if message["from"] == "human":
637
+ if i==len(messages)-2 and not self.text_only:
638
+ user_transcribe = message["value"]
639
+ elements += self.template.format_user.apply(content=system_text+'<start_of_audio>')
640
+ else:
641
+ elements += self.template.format_user.apply(content=system_text + message["value"])
642
+ audio_paths.append(message['audio_path'])
643
+ elif message["from"] == "gpt":
644
+ elements += self.template.format_assistant.apply(content=message["value"])
645
+ elif message["from"] == "observation":
646
+ elements += self.template.format_observation.apply(content=message["value"])
647
+ elif message["from"] == "function_call":
648
+ elements += self.template.format_function.apply(content=message["value"])
649
+ else:
650
+ raise NotImplementedError("Unexpected role: {}".format(message["from"]))
651
+
652
+
653
+ for elem in elements:
654
+ ele_str = ""
655
+ if isinstance(elem, str):
656
+ ele_str=elem
657
+ elif isinstance(elem, set):
658
+ if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
659
+ ele_str = self.processor.tokenizer.bos_token
660
+ elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
661
+ ele_str = self.processor.tokenizer.eos_token
662
+ if i == len(messages)-1:
663
+ answer_text+=ele_str
664
+ else:
665
+ prompt+=ele_str
666
+
667
+
668
+ if audio_array is not None:
669
+ inputs = self.processor(
670
+ text=prompt,
671
+ audio=[audio_array],
672
+ add_special_tokens=False,
673
+ return_tensors='pt'
674
+ )
675
+ answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
676
+ else:
677
+ inputs = self.processor(
678
+ text=prompt,
679
+ audio=None,
680
+ add_special_tokens=False,
681
+ return_tensors='pt'
682
+ )
683
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
684
+ # print('user_transcribe',user_transcribe)
685
+ # print('answer_text', answer)
686
+ # print('prompt',prompt)
687
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
688
+
689
+ if self.debug:
690
+ self.debug = False
691
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
692
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
693
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
694
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
695
+
696
+ if self.training:
697
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
698
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
699
+ labels[:, -answer_ids.shape[1]:] = answer_ids
700
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
701
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
702
+ else:
703
+ input_ids = inputs.input_ids
704
+ labels = answer_ids
705
+ token_type_ids = inputs.token_type_ids
706
+ if audio_array is not None:
707
+ if not self.train:
708
+ return {
709
+ "audio_path": audio_paths,
710
+ 'input_ids': input_ids,
711
+ 'labels': labels,
712
+ 'token_type_ids': token_type_ids,
713
+ 'input_audio_embeds': inputs.input_audio_embeds,
714
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
715
+ 'input_modes': inputs.input_modes,
716
+ }
717
+ else:
718
+ return {
719
+ 'input_ids': input_ids,
720
+ 'labels': labels,
721
+ 'token_type_ids': token_type_ids,
722
+ 'input_audio_embeds': inputs.input_audio_embeds,
723
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
724
+ 'input_modes': inputs.input_modes,
725
+ }
726
+ else:
727
+ return {
728
+ 'input_ids': input_ids,
729
+ 'labels': labels,
730
+ 'token_type_ids': token_type_ids,
731
+ 'input_audio_embeds': None,
732
+ 'audio_embed_sizes': None,
733
+ 'input_modes': inputs.input_modes,
734
+ }
735
+ def __len__(self):
736
+ return len(self.data)
737
+
738
+ def __getitem__(self, idx):
739
+ data = self.data[idx]
740
+ return self.prepare_multiturn_model_inputs(
741
+ audio_array=data["audio_array"] if "audio_array" in data else None,
742
+ messages=data['messages'],
743
+ system=data["system"],
744
+ tools=data["tools"]
745
+ )
746
+
747
+
748
+
749
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
750
+
751
+ INSTRUCTION = {
752
+ "ast": [
753
+ "Translate the audio to {0}.",
754
+ "Translate the audio clip into {0}.",
755
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
756
+ "Translate the provided audio file into {0}.",
757
+ "Convert the audio speech to {0} text.",
758
+ "Write an {0} translation of the audio file.",
759
+ "Translate spoken words from the audio into {0}.",
760
+ "Create an {0} version of the audio content.",
761
+ "Produce an accurate {0} translation of the audio.",
762
+ "Extract speech from the audio and translate it to {0}.",
763
+ "Turn the audio into readable {0} text.",
764
+ "Write all spoken content from the audio in {0}.",
765
+ "Generate an {0} translation of the speech in the file.",
766
+ "Convert the recording into {0} text.",
767
+ "Accurately translate the audio recording to {0}.",
768
+ "Write down dialogue from the given audio in {0}.",
769
+ "Translate all speech in this audio file to {0}.",
770
+ "Create an accurate {0} version of the speech.",
771
+ "Perform a complete {0} translation of the audio."
772
+ ],
773
+ "asr": [
774
+ "Transcribe the audio clip into text.",
775
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
776
+ "Transcribe the provided audio file into text.",
777
+ "Convert the audio speech to text.",
778
+ "Write a transcript of the audio file.",
779
+ "Transcribe spoken words from the audio.",
780
+ "Create a text version of the audio content.",
781
+ "Produce a verbatim transcript of the audio.",
782
+ "Extract and transcribe speech from the audio.",
783
+ "Turn the audio into readable text.",
784
+ "Write all spoken words from the audio.",
785
+ "Generate a transcript of the speech in the file.",
786
+ "Convert the recording into a text transcript.",
787
+ "Accurately transcribe the audio recording.",
788
+ "Write down dialogue from the given audio.",
789
+ "Transcribe all speech in this audio file.",
790
+ "Create an accurate text version of the speech.",
791
+ "Perform a complete transcription of the audio."
792
+ ],
793
+ }
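
Usage note (not part of the uploaded file): a minimal sketch of how these dataset classes and covost_collate_fn are meant to be wired together, assuming the BaseAudioDataset.prepare_model_inputs defined earlier in this file. The processor checkpoint path, the CSV path, and the batch size are placeholders; the CSV is expected to have the 'audio' and 'text' columns that TWCostumData reads.

    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    # Placeholder paths, shown only to illustrate the wiring.
    processor = AutoProcessor.from_pretrained("/path/to/gemma_v1", trust_remote_code=True)
    train_set = TWCostumData(processor, split="train", csv_path="/path/to/asr_train.csv")
    loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=covost_collate_fn)

    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["audio_embed_sizes"], batch["input_modes"])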
cpp/gemma_v1/__pycache__/ASRDataset.cpython-312.pyc ADDED
Binary file (36.9 kB).
 
cpp/gemma_v1/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
cpp/gemma_v1/chat_template.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'audio' -%}\n {{ '<start_of_audio>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
3
+ }
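
For orientation, a hedged sketch of how a template like this is typically rendered through apply_chat_template, mirroring the way the eval scripts in this repo build prompts. The checkpoint path is a placeholder, and the commented string is only the expected shape of the output, not a captured result.

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("/path/to/gemma_v1", trust_remote_code=True)  # placeholder path
    messages = [{
        "role": "user",
        "content": [
            {"type": "audio"},
            {"type": "text", "text": "Transcribe the audio clip into text."},
        ],
    }]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Roughly: "<bos><start_of_turn>user\n<start_of_audio>Transcribe the audio clip into text.<end_of_turn>\n<start_of_turn>model\n"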
cpp/gemma_v1/config.json ADDED
@@ -0,0 +1,118 @@
1
+ {
2
+ "architectures": [
3
+ "Gemma3OmniForConditionalGeneration"
4
+ ],
5
+ "audio_config": {
6
+ "activation": "swish",
7
+ "activation_checkpointing": {
8
+ "interval": 1,
9
+ "module": "transformer",
10
+ "offload": false
11
+ },
12
+ "attention_dim": 1024,
13
+ "attention_heads": 16,
14
+ "batch_norm": false,
15
+ "bias_in_glu": true,
16
+ "causal": true,
17
+ "chunk_size": -1,
18
+ "cnn_layer_norm": true,
19
+ "conv_activation": "swish",
20
+ "conv_glu_type": "swish",
21
+ "depthwise_multiplier": 1,
22
+ "depthwise_seperable_out_channel": 1024,
23
+ "dropout_rate": 0.0,
24
+ "encoder_embedding_config": {
25
+ "input_size": 80
26
+ },
27
+ "ext_pw_kernel_size": 1,
28
+ "ext_pw_out_channel": 1024,
29
+ "input_layer": "nemo_conv",
30
+ "input_size": 80,
31
+ "kernel_size": 3,
32
+ "left_chunk": 18,
33
+ "linear_units": 1536,
34
+ "nemo_conv_settings": {
35
+ "conv_channels": 1024
36
+ },
37
+ "num_blocks": 24,
38
+ "relative_attention_bias_args": {
39
+ "t5_bias_max_distance": 500,
40
+ "type": "t5"
41
+ },
42
+ "time_reduction": 8
43
+ },
44
+ "audio_token_index": 262143,
45
+ "auto_map": {
46
+ "AutoConfig": "configuration_gemma3omni.Gemma3OmniConfig",
47
+ "AutoModel": "modeling_gemma3omni.Gemma3OmniForConditionalGeneration"
48
+ },
49
+ "boa_token_index": 256001,
50
+ "boi_token_index": 255999,
51
+ "eoa_token_index": 256002,
52
+ "eoi_token_index": 256000,
53
+ "eos_token_id": [
54
+ 1,
55
+ 106
56
+ ],
57
+ "image_token_index": 262144,
58
+ "initializer_range": 0.02,
59
+ "mm_tokens_per_image": 256,
60
+ "model_type": "gemma3omni",
61
+ "speech_lora": {
62
+ "dp": 0.01,
63
+ "layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
64
+ "lora_alpha": 320,
65
+ "r": 320,
66
+ "use_rslora": true
67
+ },
68
+ "text_lora": {
69
+ "dp": 0.01,
70
+ "layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
71
+ "lora_alpha": 16,
72
+ "r": 8,
73
+ "use_rslora": true
74
+ },
75
+ "text_config": {
76
+ "attention_bias": false,
77
+ "attention_dropout": 0.0,
78
+ "attn_logit_softcapping": null,
79
+ "cache_implementation": "hybrid",
80
+ "final_logit_softcapping": null,
81
+ "head_dim": 256,
82
+ "hidden_activation": "gelu_pytorch_tanh",
83
+ "hidden_size": 2560,
84
+ "initializer_range": 0.02,
85
+ "intermediate_size": 10240,
86
+ "max_position_embeddings": 131072,
87
+ "model_type": "gemma3_text",
88
+ "num_attention_heads": 8,
89
+ "num_hidden_layers": 34,
90
+ "num_key_value_heads": 4,
91
+ "query_pre_attn_scalar": 256,
92
+ "rms_norm_eps": 1e-06,
93
+ "rope_local_base_freq": 10000.0,
94
+ "rope_scaling": {
95
+ "factor": 8.0,
96
+ "rope_type": "linear"
97
+ },
98
+ "rope_theta": 1000000.0,
99
+ "sliding_window": 1024,
100
+ "sliding_window_pattern": 6,
101
+ "torch_dtype": "float",
102
+ "use_cache": true,
103
+ "vocab_size": 262208
104
+ },
105
+ "torch_dtype": "float",
106
+ "transformers_version": "4.51.3",
107
+ "use_cache": false,
108
+ "vision_config": {
109
+ "hidden_size": 1152,
110
+ "image_size": 896,
111
+ "intermediate_size": 4304,
112
+ "model_type": "siglip_vision_model",
113
+ "num_attention_heads": 16,
114
+ "num_hidden_layers": 27,
115
+ "patch_size": 14,
116
+ "vision_use_head": false
117
+ }
118
+ }
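
A hedged back-of-envelope note on the audio settings above: with "time_reduction": 8 the encoder hands roughly one embedding per eight input frames to the language model. The 10 ms frame hop assumed below is not stated in this config and is only a common default for 80-dim log-mel features.

    # Rough sketch, assuming an 80-dim log-mel front end with a 10 ms hop (assumption)
    # and the time_reduction factor of 8 declared in audio_config above.
    def approx_audio_embed_count(duration_sec: float, hop_ms: float = 10.0, time_reduction: int = 8) -> int:
        frames = int(duration_sec * 1000 / hop_ms)   # number of mel frames
        return max(1, frames // time_reduction)      # embeddings handed to the LM

    print(approx_audio_embed_count(8.0))  # ~100 embeddings for an 8-second clip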
cpp/gemma_v1/configuration_gemma3omni.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Optional
2
+
3
+ from transformers import AutoConfig, Gemma3TextConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.modeling_rope_utils import rope_config_validation
6
+ from transformers.utils import logging
7
+ from transformers.models.siglip import SiglipVisionConfig
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ class AudioConfig(PretrainedConfig):
13
+ model_type = "gemma3_audio"
14
+
15
+ def __init__(
16
+ self,
17
+ input_size=80,
18
+ attention_dim=1024,
19
+ attention_heads=16,
20
+ num_blocks=24,
21
+ linear_units=1536,
22
+ dropout_rate=0.0,
23
+ kernel_size=3,
24
+ ext_pw_kernel_size=1,
25
+ ext_pw_out_channel=1024,
26
+ depthwise_seperable_out_channel=1024,
27
+ depthwise_multiplier=1,
28
+ activation="swish",
29
+ conv_activation="swish",
30
+ conv_glu_type="swish",
31
+ bias_in_glu=True,
32
+ causal=True,
33
+ batch_norm=False,
34
+ cnn_layer_norm=True,
35
+ time_reduction=8,
36
+ input_layer="nemo_conv",
37
+ nemo_conv_settings=None,
38
+ chunk_size=-1,
39
+ left_chunk=18,
40
+ relative_attention_bias_args=None,
41
+ activation_checkpointing=None,
42
+ encoder_embedding_config=None,
43
+ **kwargs
44
+ ):
45
+ super().__init__(**kwargs)
46
+
47
+ self.input_size = input_size
48
+ self.attention_dim = attention_dim
49
+ self.attention_heads = attention_heads
50
+ self.num_blocks = num_blocks
51
+ self.linear_units = linear_units
52
+ self.dropout_rate = dropout_rate
53
+ self.kernel_size = kernel_size
54
+ self.ext_pw_kernel_size = ext_pw_kernel_size
55
+ self.ext_pw_out_channel = ext_pw_out_channel
56
+ self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
57
+ self.depthwise_multiplier = depthwise_multiplier
58
+ self.activation = activation
59
+ self.conv_activation = conv_activation
60
+ self.conv_glu_type = conv_glu_type
61
+ self.bias_in_glu = bias_in_glu
62
+ self.causal = causal
63
+ self.batch_norm = batch_norm
64
+ self.cnn_layer_norm = cnn_layer_norm
65
+ self.time_reduction = time_reduction
66
+ self.input_layer = input_layer
67
+
68
+ if nemo_conv_settings is None:
69
+ self.nemo_conv_settings = {"conv_channels": 1024}
70
+ else:
71
+ self.nemo_conv_settings = nemo_conv_settings
72
+
73
+ self.chunk_size = chunk_size
74
+ self.left_chunk = left_chunk
75
+
76
+ if relative_attention_bias_args is None:
77
+ self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
78
+ else:
79
+ self.relative_attention_bias_args = relative_attention_bias_args
80
+
81
+ if activation_checkpointing is None:
82
+ self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
83
+ else:
84
+ self.activation_checkpointing = activation_checkpointing
85
+
86
+ if encoder_embedding_config is None:
87
+ self.encoder_embedding_config = {"input_size": input_size}
88
+ else:
89
+ self.encoder_embedding_config = encoder_embedding_config
90
+
91
+
92
+ class Gemma3OmniConfig(PretrainedConfig):
93
+ r"""
94
+ This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
95
+ Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
96
+ with the defaults will yield a similar configuration to that of the PaliGemma-2B.
97
+
98
+ e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
99
+
100
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
101
+ documentation from [`PretrainedConfig`] for more information.
102
+
103
+ Args:
104
+ text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
105
+ The config object of the text backbone.
106
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
107
+ Custom vision config or dict.
108
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
109
+ Custom audio config or dict.
110
+ mm_tokens_per_image (`int`, *optional*, defaults to 256):
111
+ The number of tokens per image embedding.
112
+ boi_token_index (`int`, *optional*, defaults to 255999):
113
+ The begin-of-image token index to wrap the image prompt.
114
+ eoi_token_index (`int`, *optional*, defaults to 256000):
115
+ The end-of-image token index to wrap the image prompt.
116
+ image_token_index (`int`, *optional*, defaults to 262144):
117
+ The image token index to encode the image prompt.
118
+ audio_token_index (`int`, *optional*, defaults to 262145):
119
+ The audio token index to encode the audio prompt.
120
+ initializer_range (`float`, *optional*, defaults to 0.02):
121
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
122
+
123
+
124
+ Example:
125
+
126
+ ```python
127
+ >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
128
+
129
+ >>> # Initializing a Siglip-like vision config
130
+ >>> vision_config = SiglipVisionConfig()
131
+
132
+ >>> # Initializing a Siglip-like vision config
133
+ >>> audio_config = AudioConfig()
134
+
135
+ >>> # Initializing a Gemma3 Text config
136
+ >>> text_config = Gemma3TextConfig()
137
+
138
+ >>> # Initializing a Gemma3 gemma-3-4b style configuration
139
+ >>> configuration = Gemma3Config(vision_config, text_config)
140
+
141
+ >>> # Initializing a model from the gemma-3-4b style configuration
142
+ >>> model = Gemma3TextConfig(configuration)
143
+
144
+ >>> # Accessing the model configuration
145
+ >>> configuration = model.config
146
+ ```"""
147
+
148
+ model_type = "gemma3omni"
149
+ sub_configs = {
150
+ "text_config": Gemma3TextConfig,
151
+ "vision_config": SiglipVisionConfig,
152
+ "audio_config": AudioConfig,
153
+ }
154
+
155
+ def __init__(
156
+ self,
157
+ text_config: Optional[Gemma3TextConfig] = None,
158
+ vision_config: Optional[SiglipVisionConfig] = None,
159
+ audio_config: Optional[AudioConfig] = None,
160
+ mm_tokens_per_image: int = 256,
161
+ boi_token_index: int = 255_999,
162
+ eoi_token_index: int = 256_000,
163
+ boa_token_index: int = 256_001,
164
+ eoa_token_index: int = 256_002,
165
+ image_token_index: int = 262_144,
166
+ audio_token_index: int = 262_143,
167
+ initializer_range: float = 0.02,
168
+ **kwargs,
169
+ ):
170
+ if text_config is None:
171
+ text_config = Gemma3TextConfig()
172
+ logger.info("text_config is None, using default Gemma3TextConfig vision config.")
173
+ elif isinstance(text_config, dict):
174
+ text_config = Gemma3TextConfig(**text_config)
175
+
176
+ if isinstance(vision_config, dict):
177
+ vision_config = SiglipVisionConfig(**vision_config)
178
+ else:
179
+ vision_config = SiglipVisionConfig()
180
+ logger.info(
181
+ "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
182
+ "to text tasks."
183
+ )
184
+
185
+ if isinstance(audio_config, dict):
186
+ audio_config = AudioConfig(**audio_config)
187
+ else:
188
+ audio_config = AudioConfig()
189
+ logger.info(
190
+ "audio_config is None or incompatible with Gemma3AudioConfig intialization. Gemma3 will be limited "
191
+ "to text tasks."
192
+ )
193
+
194
+ self.text_config = text_config
195
+ self.vision_config = vision_config
196
+ self.audio_config = audio_config
197
+ self.mm_tokens_per_image = mm_tokens_per_image
198
+ self.boi_token_index = boi_token_index
199
+ self.eoi_token_index = eoi_token_index
200
+ self.boa_token_index = boa_token_index
201
+ self.eoa_token_index = eoa_token_index
202
+ self.image_token_index = image_token_index
203
+ self.audio_token_index = audio_token_index
204
+ self.initializer_range = initializer_range
205
+
206
+ super().__init__(**kwargs)
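
A minimal sketch of building this composite config in code; the values mirror config.json above, and anything omitted falls back to the defaults defined in this file. Sub-configs are passed as dicts here because, as written, the constructor only wraps dict inputs into AudioConfig/SiglipVisionConfig.

    # Hedged sketch: sub-configs as dicts, matching how config.json is consumed.
    audio_cfg = {"input_size": 80, "attention_dim": 1024, "num_blocks": 24, "time_reduction": 8}
    text_cfg = {"hidden_size": 2560, "num_hidden_layers": 34, "vocab_size": 262208}
    cfg = Gemma3OmniConfig(
        text_config=text_cfg,
        audio_config=audio_cfg,
        audio_token_index=262143,
        boa_token_index=256001,
        eoa_token_index=256002,
    )
    print(cfg.model_type, cfg.audio_config.time_reduction)  # gemma3omni 8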
cpp/gemma_v1/eval.py ADDED
@@ -0,0 +1,635 @@
1
+ from io import BytesIO
2
+ from urllib.request import urlopen
3
+ import soundfile
4
+ import torch
5
+ from datasets import load_dataset, Audio
6
+ import numpy as np
7
+ from transformers import AutoModel, AutoProcessor, BatchFeature
8
+ from tqdm import tqdm
9
+ import json
10
+ import os
11
+ import time
12
+ from datetime import datetime
13
+ from whisper_normalizer.english import EnglishTextNormalizer
14
+ from whisper_normalizer.basic import BasicTextNormalizer
15
+ import sacrebleu
16
+ from jiwer import cer, wer
17
+ from torch.utils.data import Dataset, DataLoader
18
+ import soundfile as sf
19
+ import re
20
+ from pathlib import Path
21
+ import opencc
22
+ converter = opencc.OpenCC('s2tw.json')
23
+ normalizer = {
24
+ "en_us" : EnglishTextNormalizer(),
25
+ "other" : BasicTextNormalizer()
26
+ }
27
+
28
+ model_id = "/mnt/jeff/gemma_test"
29
+ revision = "main" #"v1.0"
30
+
31
+ model = AutoModel.from_pretrained(
32
+ model_id, device_map="cuda", revision = revision, trust_remote_code=True
33
+ ).eval()
34
+
35
+ processor = AutoProcessor.from_pretrained(
36
+ model_id, revision = revision, trust_remote_code=True
37
+ )
38
+
39
+ results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
40
+ os.makedirs(results_dir, exist_ok=True)
41
+
42
+
43
+ INSTRUCTION = {
44
+ "ast": "Translate the audio to {0}.",
45
+ "asr": "Transcribe the audio clip into text.",
46
+ }
47
+
48
+ class BaseAudioDataset(Dataset):
49
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
50
+ self.processor = processor
51
+ self.training = "train" in split
52
+ self.debug = debug
53
+ self.sampling_rate = sampling_rate
54
+ self.name = ""
55
+
56
+ def set_dataset_name(self, name):
57
+ self.name = name
58
+
59
+ @staticmethod
60
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
61
+ original_size = len(data)
62
+
63
+ data = data.cast_column(audio_field, Audio(decode=False))
64
+
65
+ def identify_corrupted_files(example):
66
+ try:
67
+ sf.read(example[audio_field]["path"])
68
+
69
+ for field in text_fields:
70
+ if example[field].replace('"', '') == "":
71
+ return False
72
+ return True
73
+ except Exception:
74
+ return False
75
+
76
+ data = data.filter(identify_corrupted_files, num_proc=16)
77
+ validated_size = len(data)
78
+
79
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
80
+
81
+
82
+ return data
83
+
84
+ @staticmethod
85
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
86
+ original_size = len(data)
87
+
88
+ def filter_audio_by_length(example):
89
+ try:
90
+ audio = example[audio_field]['array']
91
+ channel = 1
92
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
93
+ channel = audio.ndim
94
+ audio = audio.squeeze()
95
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
96
+ return min_sec <= audio_length <= max_sec
97
+ except Exception as e:
98
+ return False
99
+
100
+ data = data.filter(filter_audio_by_length, num_proc=16)
101
+ filtered_size = len(data)
102
+
103
+ return data
104
+
105
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
106
+ user_message = {
107
+ 'role': 'user',
108
+ 'content': '<start_of_audio>' + instruction,
109
+ }
110
+ prompt = self.processor.tokenizer.apply_chat_template(
111
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
112
+ )
113
+
114
+ inputs = self.processor(
115
+ text=prompt,
116
+ audio=[audio_array],
117
+ add_special_tokens=False,
118
+ return_tensors='pt'
119
+ )
120
+
121
+ input_ids = inputs.input_ids
122
+ token_type_ids = inputs.token_type_ids
123
+
124
+ return {
125
+ 'input_ids': input_ids,
126
+ 'token_type_ids': token_type_ids,
127
+ 'input_audio_embeds': inputs.input_audio_embeds,
128
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
129
+ 'input_modes': inputs.input_modes,
130
+ 'answer': answer_text,
131
+ }
132
+
133
+
134
+ # Libri Speech Dataset Class
135
+ class LibriSpeechDataset(BaseAudioDataset):
136
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
137
+ super().__init__(processor, split, sampling_rate, debug)
138
+
139
+ self.set_dataset_name(f"LibriSpeech_{subset}")
140
+ # only ASR
141
+ self.ast = False
142
+ self.lang = "en"
143
+
144
+ # load dataset
145
+ self.data = load_dataset("openslr/librispeech_asr",
146
+ subset,
147
+ split=split,
148
+ trust_remote_code=True,
149
+ cache_dir=Path("/mnt/jeff/InCar/data")
150
+ )
151
+
152
+ # (Optional) Audio length Filtering
153
+ self.data = self.filter_by_audio_length(self.data, "audio")
154
+
155
+ # Instruction Setting
156
+ self.instruction = INSTRUCTION["asr"]
157
+
158
+ def __len__(self):
159
+ return len(self.data)
160
+
161
+ def __getitem__(self, idx):
162
+ data = self.data[idx]
163
+
164
+ # Libri Speech is only for ASR
165
+ answer_text = data["text"].replace('"', '')
166
+
167
+ return self.prepare_model_inputs(
168
+ data["audio"]["array"],
169
+ INSTRUCTION["asr"],
170
+ answer_text
171
+ )
172
+
173
+ # common_voice_16_1 dataset
174
+ class CommonVoiceDataset(BaseAudioDataset):
175
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
176
+ super().__init__(processor, split, sampling_rate, debug)
177
+
178
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
179
+ # only ASR
180
+ self.ast = False
181
+ self.lang=source_lang
182
+
183
+ # load dataset
184
+ self.data = load_dataset("mozilla-foundation/common_voice_16_1",
185
+ source_lang,
186
+ split=split,
187
+ trust_remote_code=True,
188
+ cache_dir=Path("/mnt/jeff/InCar/data")
189
+ )
190
+ def prepare_dataset(batch):
191
+ """Function to preprocess the dataset with the .map method"""
192
+ transcription = batch["sentence"]
193
+
194
+ if transcription.startswith('"') and transcription.endswith('"'):
195
+ # we can remove trailing quotation marks as they do not affect the transcription
196
+ transcription = transcription[1:-1]
197
+
198
+ if transcription[-1] not in [".", "?", "!"]:
199
+ # append a full-stop to sentences that do not end in punctuation
200
+ transcription = transcription + "."
201
+
202
+ batch["sentence"] = transcription
203
+
204
+ return batch
205
+ self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
206
+
207
+ # (Optional) Audio length Filtering
208
+ self.data = self.filter_by_audio_length(self.data, "audio")
209
+
210
+ # Instruction Setting
211
+ self.instruction = INSTRUCTION["asr"]
212
+
213
+ def __len__(self):
214
+ return len(self.data)
215
+
216
+ def __getitem__(self, idx):
217
+ data = self.data[idx]
218
+
219
+ answer_text = data["sentence"]
220
+ return self.prepare_model_inputs(
221
+ data["audio"]["array"],
222
+ INSTRUCTION["asr"],
223
+ answer_text
224
+ )
225
+
226
+
227
+ # Fleurs Dataset Class
228
+ class FleursDataset(BaseAudioDataset):
229
+ def __init__(self, processor, split, source_lang, target_lang=None,
230
+ mode="asr", sampling_rate=16000, debug=False):
231
+ super().__init__(processor, split, sampling_rate, debug)
232
+
233
+ self.set_dataset_name("Fleurs")
234
+ # Mode Setting (ASR or AST)
235
+ if mode not in ["asr", "ast"]:
236
+ raise ValueError("mode must be 'asr' or 'ast'.")
237
+
238
+ self.mode = mode
239
+ self.ast = (mode == "ast")
240
+ self.source_lang = source_lang
241
+
242
+ # Language name mapping (expand if needed)
243
+ self.lang_names = {
244
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
245
+ }
246
+
247
+ # load dataset - source language dataset
248
+ self.data = load_dataset("google/fleurs",
249
+ source_lang,
250
+ split=split,
251
+ trust_remote_code=True,
252
+ cache_dir=Path("/mnt/jeff/InCar/data")
253
+ )
254
+ def prepare_dataset(batch):
255
+ import opencc
256
+ converter = opencc.OpenCC('s2tw.json')
257
+ if self.ast:
258
+ translation = converter.convert(batch["translation"])
259
+ batch["translation"] = translation
260
+ else:
261
+ transcription = converter.convert(batch["transcription"])
262
+ batch["transcription"] = transcription
263
+
264
+ return batch
265
+ if (source_lang=="cmn_hans_cn" and not self.ast) or (self.ast and target_lang=="cmn_hans_cn"):
266
+ self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
267
+
268
+ # (Optional) Audio length Filtering
269
+ self.data = self.filter_by_audio_length(self.data, "audio")
270
+ self.target_lang_name = ""
271
+ # When AST mode, load target language dataset.
272
+ if self.ast:
273
+ if target_lang is None:
274
+ raise ValueError("AST mode requires target_lang.")
275
+
276
+ self.target_lang = target_lang
277
+ self.lang = f"{source_lang}_{target_lang}"
278
+
279
+ # load dataset - target language dataset (for translation)
280
+ target_data = load_dataset("google/fleurs",
281
+ target_lang,
282
+ split=split,
283
+ trust_remote_code=True,
284
+ cache_dir=Path("/mnt/jeff/InCar/data")
285
+ )
286
+
287
+ source_dict = {item['id']: item for item in self.data}
288
+ target_dict = {item['id']: item for item in target_data}
289
+
290
+ # only Common ID, add translation fields
291
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
292
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
293
+ self.data = [
294
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
295
+ for id in common_ids
296
+ ]
297
+
298
+ # Instruction Setting - use target language name
299
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
300
+ self.instruction = INSTRUCTION["ast"]
301
+ else:
302
+ # ASR mode
303
+ self.lang = source_lang
304
+ self.instruction = INSTRUCTION["asr"]
305
+
306
+ if self.debug:
307
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
308
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
309
+ if self.ast:
310
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
311
+ print(f"dataset size: {len(self.data)}")
312
+
313
+ def __len__(self):
314
+ return len(self.data)
315
+
316
+ def __getitem__(self, idx):
317
+ data = self.data[idx]
318
+ audio_array = data["audio"]["array"]
319
+
320
+ if self.ast:
321
+ answer_text = data["translation"]
322
+ else:
323
+ answer_text = data["transcription"]
324
+
325
+ return self.prepare_model_inputs(
326
+ audio_array,
327
+ self.instruction.format(self.target_lang_name),
328
+ answer_text
329
+ )
330
+
331
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
332
+ """
333
+ Pad a list of sequences to the same length.
334
+ sequences: list of tensors in [seq_len, *] shape
335
+ """
336
+ assert padding_side in ['right', 'left']
337
+ max_size = sequences[0].size()
338
+ trailing_dims = max_size[1:]
339
+ max_len = max(len(seq) for seq in sequences)
340
+ batch_size = len(sequences)
341
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
342
+ for i, seq in enumerate(sequences):
343
+ length = seq.size(0)
344
+ if padding_side == 'right':
345
+ output.data[i, :length] = seq
346
+ else:
347
+ output.data[i, -length:] = seq
348
+ return output
349
+
350
+ def cat_with_pad(tensors, dim, padding_value=0):
351
+ """
352
+ cat along dim, while pad to max for all other dims
353
+ """
354
+ ndim = tensors[0].dim()
355
+ assert all(
356
+ t.dim() == ndim for t in tensors[1:]
357
+ ), 'All tensors must have the same number of dimensions'
358
+
359
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
360
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
361
+ output = tensors[0].new_full(out_size, padding_value)
362
+
363
+ index = 0
364
+ for t in tensors:
365
+ # Create a slice list where every dimension except dim is full slice
366
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
367
+ # Update only the concat dimension slice
368
+ slices[dim] = slice(index, index + t.shape[dim])
369
+
370
+ output[slices] = t
371
+ index += t.shape[dim]
372
+
373
+ return output
374
+
375
+ def covost_collate_fn(batch):
376
+ input_ids_list = []
377
+ input_audio_embeds_list = []
378
+ audio_embed_sizes_list = []
379
+ audio_attention_mask_list = []
380
+ input_modes_list = []
381
+ answer_list = []
382
+ for inputs in batch:
383
+ input_ids_list.append(inputs['input_ids'][0])
384
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
385
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
386
+ audio_attention_mask_list.append(
387
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
388
+ )
389
+ input_modes_list.append(inputs['input_modes'])
390
+ answer_list.append(inputs['answer'])
391
+
392
+ try:
393
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
394
+ audio_attention_mask = (
395
+ pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
396
+ if len(audio_attention_mask_list) > 1
397
+ else None
398
+ )
399
+ except Exception as e:
400
+ print(e)
401
+ print(input_ids_list)
402
+ print(audio_attention_mask)
403
+ raise
404
+ attention_mask = (input_ids != 0).long()
405
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
406
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list)
407
+ input_modes = torch.cat(input_modes_list)
408
+
409
+ return BatchFeature(
410
+ {
411
+ 'input_ids': input_ids,
412
+ 'attention_mask': attention_mask,
413
+ 'input_audio_embeds': input_audio_embeds,
414
+ 'audio_embed_sizes': audio_embed_sizes,
415
+ 'audio_attention_mask': audio_attention_mask,
416
+ 'input_modes': input_modes,
417
+ 'answer': answer_list,
418
+ }
419
+ )
420
+
421
+ def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
422
+ filename = f"{task}_{dataset_name}_{source_lang}"
423
+ if target_lang:
424
+ filename += f"_to_{target_lang}"
425
+ if sample_idx is not None:
426
+ filename += f"_sample_{sample_idx}"
427
+
428
+ filepath = os.path.join(results_dir, f"{filename}.json")
429
+
430
+ results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
431
+
432
+ with open(filepath, 'w', encoding='utf-8') as f:
433
+ json.dump(results, f, ensure_ascii=False, indent=2)
434
+
435
+ return filepath
436
+
437
+ def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
438
+ task_type = "asr" if is_asr else "translation"
439
+ eval_lang = source_lang if is_asr else target_lang
440
+ if eval_lang in normalizer:
441
+ eval_normalizer = normalizer[eval_lang]
442
+ else:
443
+ eval_normalizer = normalizer['other']
444
+ sample_results = []
445
+
446
+ if num_samples > 0 and num_samples < len(dataset):
447
+ indices = np.random.choice(len(dataset), num_samples, replace=False)
448
+ dataset = dataset.select(indices)
449
+
450
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
451
+
452
+ evaluated_samples = {}
453
+
454
+ for batch_idx, batch in enumerate(tqdm(dataloader)):
455
+ batch_references = batch.pop("answer")
456
+
457
+ if torch.cuda.is_available():
458
+ try:
459
+ batch = {k: v.to("cuda") for k, v in batch.items()}
460
+ except:
461
+ print('error')
462
+ break
463
+
464
+ with torch.inference_mode():
465
+ generate_ids = model.generate(**batch,
466
+ max_new_tokens=256,
467
+ #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
468
+ )
469
+
470
+ input_lengths = batch['input_ids'].shape[1]
471
+ generate_ids = generate_ids[:, input_lengths:]
472
+
473
+ batch_predictions = processor.batch_decode(
474
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
475
+ )
476
+
477
+ for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
478
+ idx = batch_idx * batch_size + i
479
+ sample_result = {
480
+ "id": idx,
481
+ "reference": reference,
482
+ "prediction": converter.convert(prediction)
483
+ }
484
+ sample_results.append(sample_result)
485
+
486
+ if (batch_idx + 1) % 10 == 0:
487
+ temp_results = []
488
+
489
+ for item in sample_results:
490
+ sample_id = item["id"]
491
+
492
+ if sample_id in evaluated_samples:
493
+ temp_item = item.copy()
494
+ temp_item.update(evaluated_samples[sample_id])
495
+ temp_results.append(temp_item)
496
+ else:
497
+ temp_item = item.copy()
498
+ try:
499
+ ref = eval_normalizer(item["reference"])
500
+ pred = eval_normalizer(item["prediction"])
501
+
502
+ # BLEU, WER/CER
503
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
504
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
505
+ utt_wer = round(wer(ref, pred) * 100, 2)
506
+
507
+ metrics = {
508
+ "bleu": utt_bleu,
509
+ "cer": min(100,utt_cer),
510
+ "wer": utt_wer
511
+ }
512
+
513
+ evaluated_samples[sample_id] = metrics
514
+ temp_item.update(metrics)
515
+ except Exception as e:
516
+ print(f"Error evaluating sample {sample_id}: {e}")
517
+ metrics = {
518
+ "bleu": 0,
519
+ "cer": 100,
520
+ "wer": 100,
521
+ "error": str(e)
522
+ }
523
+ evaluated_samples[sample_id] = metrics
524
+ temp_item.update(metrics)
525
+
526
+ temp_results.append(temp_item)
527
+
528
+ partial_results = {
529
+ "task": task_type,
530
+ "source_lang": source_lang,
531
+ "target_lang": target_lang,
532
+ "num_samples": len(temp_results),
533
+ "sample_results": temp_results
534
+ }
535
+ save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
536
+
537
+ for item in sample_results:
538
+ ref = eval_normalizer(item["reference"])
539
+ pred = eval_normalizer(item["prediction"])
540
+
541
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
542
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
543
+ utt_wer = round(wer(ref, pred) * 100, 2)
544
+
545
+ item.update({
546
+ "bleu": utt_bleu,
547
+ "cer": min(100,utt_cer),
548
+ "wer": utt_wer
549
+ })
550
+
551
+ avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
552
+ avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
553
+ avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
554
+
555
+ results = {
556
+ "dataset": dataset.name,
557
+ "task": task_type,
558
+ "source_lang": source_lang,
559
+ "target_lang": target_lang,
560
+ "num_samples": len(sample_results),
561
+ "metrics": {
562
+ "bleu": avg_bleu,
563
+ "cer": avg_cer,
564
+ "wer": avg_wer
565
+ },
566
+ "sample_results": sample_results
567
+ }
568
+
569
+ save_results(results, dataset.name, task_type, source_lang, target_lang)
570
+ return results
571
+
572
+
573
+ if __name__ == "__main__":
574
+
575
+ source_languages = [
576
+ ("en_us", "English"),
577
+ ]
578
+
579
+ target_languages = [
580
+ ("zh-TW", "zh-TW"),
581
+ ]
582
+
583
+ num_samples = -1
584
+ batch_size = 32
585
+
586
+ for source_lang, target_lang in zip(source_languages, target_languages):
587
+ print(f"\n===== {source_lang[0]} ASR =====")
588
+
589
+ split = "test"
590
+
591
+ datasets = []
592
+
593
+
594
+
595
+ commonvoice_speech_tw = CommonVoiceDataset(
596
+ processor=processor,
597
+ source_lang="zh-TW",
598
+ split=split
599
+ )
600
+ datasets.append(commonvoice_speech_tw)
601
+ fleurs = FleursDataset(
602
+ processor=processor,
603
+ split=split,
604
+ source_lang="en_us", # English
605
+ mode="asr"
606
+ )
607
+ datasets.append(fleurs)
608
+
609
+ # Libri Speech Clean ASR mode (English -> English text)
610
+ # libri_speech_clean = LibriSpeechDataset(
611
+ # processor=processor,
612
+ # subset="clean",
613
+ # split=split
614
+ # )
615
+ # datasets.append(libri_speech_clean)
616
+
617
+ # # Libri Speech Other ASR mode (English -> English text)
618
+ # libri_speech_other = LibriSpeechDataset(
619
+ # processor=processor,
620
+ # subset="other",
621
+ # split=split
622
+ # )
623
+ # datasets.append(libri_speech_other)
624
+
625
+ # Fleurs ASR mode (English -> English text)
626
+
627
+
628
+ for dataset in datasets:
629
+ # ASR
630
+ asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
631
+
632
+ print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
633
+ print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
634
+ print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
635
+ print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
cpp/gemma_v1/eval_multiturn.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
cpp/gemma_v1/eval_multiturn.py ADDED
@@ -0,0 +1,211 @@
1
+ from io import BytesIO
2
+ from urllib.request import urlopen
3
+ import soundfile
4
+ import torch
5
+ from datasets import load_dataset, Audio
6
+ import numpy as np
7
+ from transformers import AutoModel, AutoProcessor, BatchFeature
8
+ from tqdm import tqdm
9
+ import json
10
+ import os
11
+ import time
12
+ from datetime import datetime
13
+ from whisper_normalizer.english import EnglishTextNormalizer
14
+ from whisper_normalizer.basic import BasicTextNormalizer
15
+ import sacrebleu
16
+ from jiwer import cer, wer
17
+ from torch.utils.data import Dataset, DataLoader
18
+ import soundfile as sf
19
+ import re
20
+ from pathlib import Path
21
+ import opencc
22
+ from ASRDataset import *
23
+
24
+ converter = opencc.OpenCC('s2tw.json')
25
+ normalizer = {
26
+ "en_us" : EnglishTextNormalizer(),
27
+ "other" : BasicTextNormalizer()
28
+ }
29
+
30
+ model_id = "/mnt/jeff/gemma_test"
31
+ revision = "main" #"v1.0"
32
+
33
+ model = AutoModel.from_pretrained(
34
+ model_id, device_map="cuda", revision = revision, trust_remote_code=True
35
+ ).eval()
36
+
37
+ processor = AutoProcessor.from_pretrained(
38
+ model_id, revision = revision, trust_remote_code=True
39
+ )
40
+
41
+ results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
42
+ os.makedirs(results_dir, exist_ok=True)
43
+
44
+
45
+ INSTRUCTION = {
46
+ "ast": "Translate the audio to {0}.",
47
+ "asr": "Transcribe the audio clip into text.",
48
+ }
49
+
50
+
51
+
52
+ def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
53
+ filename = f"{task}_{dataset_name}_{source_lang}"
54
+ if target_lang:
55
+ filename += f"_to_{target_lang}"
56
+ if sample_idx is not None:
57
+ filename += f"_sample_{sample_idx}"
58
+
59
+ filepath = os.path.join(results_dir, f"{filename}.json")
60
+
61
+ results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
62
+
63
+ with open(filepath, 'w', encoding='utf-8') as f:
64
+ json.dump(results, f, ensure_ascii=False, indent=2)
65
+
66
+ return filepath
67
+
68
+ def evaluate_task(dataset, task_type="multiturn", source_lang="zh-TW", target_lang=None, eval_normalizer=normalizer["other"]):
69
+ sample_results = []
70
+
71
+
72
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
73
+
74
+ evaluated_samples = {}
75
+
76
+ for batch_idx, batch in enumerate(tqdm(dataloader)):
77
+
78
+ if torch.cuda.is_available():
79
+ try:
80
+ batch = {k: v.to("cuda") for k, v in batch.items()}
81
+ except:
82
+ print('error')
83
+ break
84
+
85
+ with torch.inference_mode():
86
+ generate_ids = model.generate(**batch,
87
+ max_new_tokens=256,
88
+ #temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
89
+ )
90
+
91
+ input_lengths = batch['input_ids'].shape[1]
92
+ generate_ids = generate_ids[:, input_lengths:]
93
+
94
+ batch_predictions = processor.batch_decode(
95
+ generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
96
+ )
97
+ # references come from the collated labels; generate_ids was already trimmed above
99
+ label_ids = batch['labels']
99
+ batch_references = processor.batch_decode(
100
+ label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
101
+ )
102
+
103
+ for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
104
+ idx = batch_idx + i
105
+ sample_result = {
106
+ "id": idx,
107
+ "reference": reference,
108
+ "prediction": converter.convert(prediction)
109
+ }
110
+ sample_results.append(sample_result)
111
+
112
+ if (batch_idx + 1) % 10 == 0:
113
+ temp_results = []
114
+
115
+ for item in sample_results:
116
+ sample_id = item["id"]
117
+
118
+ if sample_id in evaluated_samples:
119
+ temp_item = item.copy()
120
+ temp_item.update(evaluated_samples[sample_id])
121
+ temp_results.append(temp_item)
122
+ else:
123
+ temp_item = item.copy()
124
+ try:
125
+ ref = eval_normalizer(item["reference"])
126
+ pred = eval_normalizer(item["prediction"])
127
+
128
+ # BLEU, WER/CER
129
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
130
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
131
+ utt_wer = round(wer(ref, pred) * 100, 2)
132
+
133
+ metrics = {
134
+ "bleu": utt_bleu,
135
+ "cer": min(100,utt_cer),
136
+ "wer": utt_wer
137
+ }
138
+
139
+ evaluated_samples[sample_id] = metrics
140
+ temp_item.update(metrics)
141
+ except Exception as e:
142
+ print(f"Error evaluating sample {sample_id}: {e}")
143
+ metrics = {
144
+ "bleu": 0,
145
+ "cer": 100,
146
+ "wer": 100,
147
+ "error": str(e)
148
+ }
149
+ evaluated_samples[sample_id] = metrics
150
+ temp_item.update(metrics)
151
+
152
+ temp_results.append(temp_item)
153
+
154
+ partial_results = {
155
+ "task": task_type,
156
+ "source_lang": source_lang,
157
+ "target_lang": target_lang,
158
+ "num_samples": len(temp_results),
159
+ "sample_results": temp_results
160
+ }
161
+ save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
162
+
163
+ for item in sample_results:
164
+ ref = eval_normalizer(item["reference"])
165
+ pred = eval_normalizer(item["prediction"])
166
+
167
+ utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
168
+ utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
169
+ utt_wer = round(wer(ref, pred) * 100, 2)
170
+
171
+ item.update({
172
+ "bleu": utt_bleu,
173
+ "cer": min(100,utt_cer),
174
+ "wer": utt_wer
175
+ })
176
+
177
+ avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
178
+ avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
179
+ avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
180
+
181
+ results = {
182
+ "dataset": dataset.name,
183
+ "task": task_type,
184
+ "source_lang": source_lang,
185
+ "target_lang": target_lang,
186
+ "num_samples": len(sample_results),
187
+ "metrics": {
188
+ "bleu": avg_bleu,
189
+ "cer": avg_cer,
190
+ "wer": avg_wer
191
+ },
192
+ "sample_results": sample_results
193
+ }
194
+
195
+ save_results(results, dataset.name, task_type, source_lang, target_lang)
196
+ return results
197
+
198
+
199
+ if __name__ == "__main__":
200
+
201
+ datasets = []
202
+ pickup_dataset = MultiturnAudioDataset(split='eval',processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
203
+ datasets.append(pickup_dataset)
204
+ for dataset in datasets:
205
+ # ASR
206
+ asr_results = evaluate_task(dataset)
207
+
208
+ print(f"\n=== {asr_results.get('dataset', 'Dataset')} ===")
209
+ print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
210
+ print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
211
+ print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
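
Note on the scoring loop above: each utterance gets a sentence-level BLEU (sacrebleu) plus WER and CER, with CER computed after stripping whitespace. A minimal, self-contained sketch of that computation, assuming the `sacrebleu` and `jiwer` packages the script relies on (the example strings are illustrative only):

# Minimal sketch of the per-utterance metrics used in evaluate_task().
import re

import sacrebleu
from jiwer import cer, wer


def utterance_metrics(reference: str, prediction: str) -> dict:
    """Sentence BLEU, WER and CER; CER is computed on whitespace-stripped text."""
    bleu = sacrebleu.sentence_bleu(prediction, [reference]).score
    char_err = cer(re.sub(r"\s+", "", reference), re.sub(r"\s+", "", prediction)) * 100
    word_err = wer(reference, prediction) * 100
    return {"bleu": bleu, "cer": min(100.0, round(char_err, 2)), "wer": round(word_err, 2)}


print(utterance_metrics("turn left at the next light", "turn left at next light"))
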
cpp/gemma_v1/merge_lora.ipynb ADDED
@@ -0,0 +1,119 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from safetensors import safe_open\n",
10
+ "\n",
11
+ "lora = {}\n",
12
+ "with safe_open(\"/data2/bjh/diffusion-pipe/cosmos_test/20250327_02-37-25/epoch5/adapter_model.safetensors\", framework=\"pt\", device='cpu') as f:\n",
13
+ " for k in f.keys():\n",
14
+ " lora[k] = f.get_tensor(k)"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "tensors = {}\n",
24
+ "with safe_open(\"/data2/bjh/ComfyUI/models/diffusion_models/Cosmos-1_0-Diffusion-14B-Text2World.safetensors\", framework=\"pt\", device='cpu') as f:\n",
25
+ " for k in f.keys():\n",
26
+ " tensors[k] = f.get_tensor(k)"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 3,
32
+ "metadata": {},
33
+ "outputs": [
34
+ {
35
+ "data": {
36
+ "text/plain": [
37
+ "1152"
38
+ ]
39
+ },
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "len(lora)"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 4,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "name_lis = []\n",
56
+ "for k in lora:\n",
57
+ " a = k.split('.')[1:][:-2]\n",
58
+ " name = '.'.join(a)\n",
59
+ " name_lis.append(name)\n",
60
+ "name_lis=set(name_lis)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 5,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "import torch\n",
70
+ "new_dic = {}\n",
71
+ "for k in tensors:\n",
72
+ " name='.'.join(k.split('.')[1:][:-1])\n",
73
+ " if name in name_lis:\n",
74
+ " a,b = lora['diffusion_model.'+name+'.lora_A.weight'],lora['diffusion_model.'+name+'.lora_B.weight']\n",
75
+ " new_dic[k]=tensors[k]+torch.matmul(b,a)\n",
76
+ " else:\n",
77
+ " new_dic[k]=tensors[k]"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 6,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "from safetensors.torch import save_file\n",
87
+ "save_file(new_dic,'test.safetensors')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": []
96
+ }
97
+ ],
98
+ "metadata": {
99
+ "kernelspec": {
100
+ "display_name": "dp",
101
+ "language": "python",
102
+ "name": "python3"
103
+ },
104
+ "language_info": {
105
+ "codemirror_mode": {
106
+ "name": "ipython",
107
+ "version": 3
108
+ },
109
+ "file_extension": ".py",
110
+ "mimetype": "text/x-python",
111
+ "name": "python",
112
+ "nbconvert_exporter": "python",
113
+ "pygments_lexer": "ipython3",
114
+ "version": "3.12.9"
115
+ }
116
+ },
117
+ "nbformat": 4,
118
+ "nbformat_minor": 2
119
+ }
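
The notebook above folds a LoRA adapter straight into a base checkpoint: for every weight with a matching `lora_A`/`lora_B` pair it stores `W + B @ A` and re-saves the result with safetensors. A minimal sketch of that merge step with toy shapes (the notebook applies no `alpha / r` scaling, so none is applied here either):

# Sketch of the merge performed in merge_lora.ipynb: W_merged = W + B @ A.
import torch


def merge_lora_weight(base: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor) -> torch.Tensor:
    """Fold a rank-r LoRA update into a linear layer's weight (no scaling)."""
    return base + lora_b @ lora_a


base = torch.randn(16, 16)     # (out_features, in_features)
lora_a = torch.randn(4, 16)    # A: (r, in_features)
lora_b = torch.randn(16, 4)    # B: (out_features, r)
print(merge_lora_weight(base, lora_a, lora_b).shape)  # torch.Size([16, 16])
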
cpp/gemma_v1/model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddd3e8916f7ad6ad92651ac288227995c1d34628f0f888eb2dc5b9acb4dc0121
3
+ size 4976361384
cpp/gemma_v1/model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c343a455ca768923cb3b9ab77cbb91c9cd2526a1bee5740cf9cf86bfa85a0a7b
3
+ size 4984907872
cpp/gemma_v1/model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad04449d015f4efbda75d6cc41e06296b4da996cd84053fa6f9791fe16d55d03
3
+ size 732141104
cpp/gemma_v1/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
cpp/gemma_v1/modeling_gemma3omni.py ADDED
@@ -0,0 +1,668 @@
1
+ import copy
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, HybridCache, StaticCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
14
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
15
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
16
+ from transformers.processing_utils import Unpack
17
+ from transformers.utils import (
18
+ add_start_docstrings,
19
+ add_start_docstrings_to_model_forward,
20
+ is_torchdynamo_compiling,
21
+ logging,
22
+ replace_return_docstrings,
23
+ )
24
+ from transformers.utils.deprecation import deprecate_kwarg
25
+ from transformers import AutoModel, AutoModelForCausalLM
26
+
27
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
28
+
29
+ from transformers import AutoConfig, AutoModelForCausalLM
30
+
31
+ from .configuration_gemma3omni import Gemma3OmniConfig
32
+ from .speech_conformer_encoder import ConformerEncoder
33
+ from enum import Enum
34
+ class InputMode(Enum):
35
+ LANGUAGE = 0
36
+ VISION = 1
37
+ SPEECH = 2
38
+ VISION_SPEECH = 3
39
+ logger = logging.get_logger(__name__)
40
+ _CONFIG_FOR_DOC = "Gemma3OmniConfig"
41
+
42
+ @dataclass
43
+ class Gemma3OmniCausalLMOutputWithPast(Gemma3CausalLMOutputWithPast):
44
+ """
45
+ Multimodal version of `Gemma3CausalLMOutputWithPast`.
46
+ Adds audio-specific hidden states.
47
+
48
+ Args:
49
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
50
+ A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
51
+ Audio hidden states produced by the audio encoder.
52
+ """
53
+ audio_hidden_states: Optional[torch.FloatTensor] = None
54
+
55
+
56
+ GEMMA3_START_DOCSTRING = r"""
57
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
58
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
59
+ etc.)
60
+
61
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
62
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
63
+ and behavior.
64
+
65
+ Parameters:
66
+ config ([`Gemma3Config`]):
67
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
68
+ load the weights associated with the model, only the configuration. Check out the
69
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
70
+ """
71
+
72
+
73
+
74
+ GEMMA3_INPUTS_DOCSTRING = r"""
75
+ Args:
76
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
77
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
78
+ it.
79
+
80
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
81
+ [`PreTrainedTokenizer.__call__`] for details.
82
+
83
+ [What are input IDs?](../glossary#input-ids)
84
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
85
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
86
+
87
+ - 1 for tokens that are **not masked**,
88
+ - 0 for tokens that are **masked**.
89
+
90
+ [What are attention masks?](../glossary#attention-mask)
91
+
92
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
93
+ [`PreTrainedTokenizer.__call__`] for details.
94
+
95
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
96
+ `past_key_values`).
97
+
98
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
99
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
100
+ information on the default strategy.
101
+
102
+ - 1 indicates the head is **not masked**,
103
+ - 0 indicates the head is **masked**.
104
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
105
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
106
+ config.n_positions - 1]`.
107
+
108
+ [What are position IDs?](../glossary#position-ids)
109
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
110
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
111
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
112
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
113
+
114
+ Two formats are allowed:
115
+ - a [`~cache_utils.Cache`] instance, see our
116
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
117
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
118
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
119
+ cache format.
120
+
121
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
122
+ legacy cache format will be returned.
123
+
124
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
125
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
126
+ of shape `(batch_size, sequence_length)`.
127
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
128
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
129
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
130
+ model's internal embedding lookup matrix.
131
+ use_cache (`bool`, *optional*):
132
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
133
+ `past_key_values`).
134
+ output_attentions (`bool`, *optional*):
135
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
136
+ tensors for more detail.
137
+ output_hidden_states (`bool`, *optional*):
138
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
139
+ more detail.
140
+ return_dict (`bool`, *optional*):
141
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
142
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
143
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
144
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
145
+ the complete sequence length.
146
+ """
147
+
148
+ @add_start_docstrings(
149
+ "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.",
150
+ GEMMA3_START_DOCSTRING,
151
+ )
152
+ class Gemma3OmniPreTrainedModel(Gemma3PreTrainedModel):
153
+ config_class = Gemma3OmniConfig
154
+
155
+ @add_start_docstrings(
156
+ """The GEMMA3 model which consists of a vision backbone and a language model.""",
157
+ GEMMA3_START_DOCSTRING,
158
+ )
159
+ class Gemma3OmniForConditionalGeneration(Gemma3OmniPreTrainedModel, GenerationMixin):
160
+ def __init__(self, config: Gemma3OmniConfig):
161
+ super().__init__(config)
162
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
163
+ audio_config = config.audio_config.to_diff_dict()
164
+ for item in ['transformers_version', 'model_type', 'torch_dtype']:
165
+ if item in audio_config:
166
+ audio_config.pop(item)
167
+ self.audio_tower = ConformerEncoder(**audio_config)
168
+ self.audio_tower.post_init({})
169
+ self.audio_tower = self.audio_tower.to(dtype=self.dtype)
170
+ self.audio_projector = nn.Sequential(
171
+ nn.Linear(in_features=config.audio_config.attention_dim, out_features=config.text_config.hidden_size, bias=True),
172
+ nn.GELU(approximate='none'),
173
+ nn.Linear(in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, bias=True)
174
+ ).to(dtype=self.dtype)
175
+
176
+ self.multi_modal_projector = Gemma3MultiModalProjector(config)
177
+ self.vocab_size = config.text_config.vocab_size
178
+
179
+ language_model = AutoModelForCausalLM.from_config(config=config.text_config)
180
+
181
+ if language_model._tied_weights_keys is not None:
182
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
183
+ self.language_model = language_model
184
+
185
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
186
+ self.init_lora()
187
+ self.post_init()
188
+
189
+
190
+ def init_lora(self):
191
+ from peft import LoraConfig, get_peft_model
192
+ import warnings
193
+ print('######################## speech lora #############')
194
+ speech_lora_config = LoraConfig(
195
+ r=self.config.speech_lora['r'],
196
+ lora_alpha=self.config.speech_lora['lora_alpha'],
197
+ target_modules=self.config.speech_lora['layer'],
198
+ use_rslora=self.config.speech_lora['use_rslora'],
199
+ lora_dropout=self.config.speech_lora['dp'],
200
+ task_type="CAUSAL_LM",
201
+ )
202
+ self.language_model.model = get_peft_model(self.language_model.model, speech_lora_config, adapter_name="speech")
203
+ print('######################## text lora #############')
204
+ text_lora_config = LoraConfig(
205
+ r=self.config.text_lora['r'],
206
+ lora_alpha=self.config.text_lora['lora_alpha'],
207
+ target_modules=self.config.text_lora['layer'],
208
+ use_rslora=self.config.text_lora['use_rslora'],
209
+ lora_dropout=self.config.text_lora['dp'],
210
+ task_type="CAUSAL_LM",
211
+ )
212
+ self.language_model.model.base_model.active_adapter.append("text")
213
+ self.language_model.model.add_adapter("text", text_lora_config)
214
+
215
+ def set_lora_adapter(self, adapter_name) -> None:
216
+ import warnings
+ from peft.tuners.lora.layer import LoraLayer
217
+ for module in self.modules():
218
+ if isinstance(module, LoraLayer):
219
+ if module.merged:
220
+ warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
221
+ module.unmerge()
222
+ module.set_adapter(adapter_name)
223
+ module._disable_adapters = False
224
+
225
+ def unset_lora_adapter(self) -> None:
226
+ # Ref: peft/tuners/tuners_utils.py - enable_adapters()
227
+ # Ref: peft/tuners/lora/layer.py
228
+ from peft.tuners.lora.layer import LoraLayer
229
+ for module in self.modules():
230
+ if isinstance(module, LoraLayer):
231
+ # disable grads on all adapter layers
232
+ # TODO weijian: may use enable_adapters() instead
233
+ for layer_name in module.adapter_layer_names:
234
+ layer = getattr(module, layer_name)
235
+ layer.requires_grad_(False)
236
+ module._disable_adapters = True
237
+
238
+ def get_input_embeddings(self):
239
+ return self.language_model.get_input_embeddings()
240
+
241
+ def set_input_embeddings(self, value):
242
+ self.language_model.set_input_embeddings(value)
243
+
244
+ def get_output_embeddings(self):
245
+ return self.language_model.get_output_embeddings()
246
+
247
+ def set_output_embeddings(self, new_embeddings):
248
+ self.language_model.set_output_embeddings(new_embeddings)
249
+
250
+ def set_decoder(self, decoder):
251
+ self.language_model.set_decoder(decoder)
252
+
253
+ def get_decoder(self):
254
+ return self.language_model.get_decoder()
255
+
256
+ def _update_causal_mask(
257
+ self,
258
+ attention_mask,
259
+ token_type_ids,
260
+ past_key_values,
261
+ cache_position,
262
+ input_tensor,
263
+ is_training: bool = False,
264
+ ):
265
+ if self.config.text_config._attn_implementation == "flash_attention_2":
266
+ return attention_mask
267
+
268
+ if attention_mask is not None and attention_mask.dim() == 4:
269
+ # In this case we assume that the mask comes already in inverted
270
+ # form and requires no inversion or slicing.
271
+ return attention_mask
272
+
273
+ using_static_cache = isinstance(past_key_values, StaticCache)
274
+ min_dtype = torch.finfo(self.dtype).min
275
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
276
+ if using_static_cache:
277
+ target_length = past_key_values.get_max_cache_shape()
278
+ elif isinstance(past_key_values, HybridCache):
279
+ target_length = past_key_values.get_max_cache_shape()
280
+ else:
281
+ target_length = (
282
+ attention_mask.shape[-1]
283
+ if isinstance(attention_mask, torch.Tensor)
284
+ else cache_position[0] + sequence_length + 1
285
+ )
286
+
287
+ if attention_mask is not None and attention_mask.dim() == 4:
288
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
289
+ return attention_mask
290
+
291
+ causal_mask = torch.full(
292
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
293
+ )
294
+
295
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
296
+ if sequence_length != 1:
297
+ causal_mask = torch.triu(causal_mask, diagonal=1)
298
+
299
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
300
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
301
+
302
+ # Apply bidirectional mask on images if token type ids are provided
303
+ if token_type_ids is not None and sequence_length != 1:
304
+ token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
305
+ token_type_mask[token_type_ids == 0] = False # if text token do not change anything
306
+ token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
307
+ causal_mask = causal_mask.clone()
308
+ causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
309
+ token_type_mask, 0.0
310
+ )
311
+
312
+ if attention_mask is not None:
313
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
314
+ mask_length = attention_mask.shape[-1]
315
+
316
+ # Then apply padding mask (will mask pad tokens)
317
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
318
+ padding_mask = padding_mask == 0
319
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
320
+ padding_mask, min_dtype
321
+ )
322
+
323
+ return causal_mask
324
+
325
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
326
+ """
327
+ Projects the last hidden state from the vision model into language model space.
328
+
329
+ Args:
330
+ pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
331
+ The tensors corresponding to the input images.
332
+ Returns:
333
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
334
+ """
335
+ vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
336
+ image_features = self.multi_modal_projector(vision_outputs)
337
+ return image_features
338
+
339
+ def get_audio_features(self, input_audio_embeds: torch.FloatTensor, audio_attention_mask: torch.FloatTensor, audio_embed_sizes: torch.FloatTensor):
340
+ """
341
+ Projects the last hidden state from the audio model into language model space.
342
+
343
+ Args:
344
+ audio_inputs (`torch.FloatTensor]` of shape `(batch_size, sequence_length, feature_dim)`)
345
+ The tensors corresponding to the input audio features.
346
+
347
+ Returns:
348
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`).
349
+ """
350
+ audio_features, masks = self.audio_tower(input_audio_embeds, audio_attention_mask)
351
+ audio_outputs = self.audio_projector(audio_features)
352
+ return audio_outputs
353
+
354
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
355
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
356
+ @replace_return_docstrings(output_type=Gemma3OmniCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
357
+ def forward(
358
+ self,
359
+ input_ids: Optional[torch.LongTensor] = None,
360
+ pixel_values: Optional[torch.FloatTensor] = None,
361
+ input_audio_embeds: torch.FloatTensor = None,
362
+ audio_embed_sizes: torch.FloatTensor = None,
363
+ audio_attention_mask: torch.FloatTensor = None,
364
+ attention_mask: Optional[torch.Tensor] = None,
365
+ input_modes: torch.LongTensor = None,
366
+ position_ids: Optional[torch.LongTensor] = None,
367
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
368
+ token_type_ids: Optional[torch.LongTensor] = None,
369
+ cache_position: Optional[torch.LongTensor] = None,
370
+ inputs_embeds: Optional[torch.FloatTensor] = None,
371
+ labels: Optional[torch.LongTensor] = None,
372
+ use_cache: Optional[bool] = None,
373
+ output_attentions: Optional[bool] = None,
374
+ output_hidden_states: Optional[bool] = None,
375
+ return_dict: Optional[bool] = None,
376
+ logits_to_keep: Union[int, torch.Tensor] = 0,
377
+ **lm_kwargs,
378
+ ) -> Union[Tuple, Gemma3OmniCausalLMOutputWithPast]:
379
+ r"""
380
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
381
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
382
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
383
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
384
+
385
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
386
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
387
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
388
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
389
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
390
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
391
+
392
+ Returns:
393
+
394
+ Example:
395
+
396
+ ```python
397
+ >>> from PIL import Image
398
+ >>> import requests
399
+ >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
400
+
401
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
402
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
403
+
404
+ >>> messages = [
405
+ ... {
406
+ ... "role": "system",
407
+ ... "content": [
408
+ ... {"type": "text", "text": "You are a helpful assistant."}
409
+ ... ]
410
+ ... },
411
+ ... {
412
+ ... "role": "user", "content": [
413
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
414
+ ... {"type": "text", "text": "Where is the cat standing?"},
415
+ ... ]
416
+ ... },
417
+ ... ]
418
+
419
+ >>> inputs = processor.apply_chat_template(
420
+ ... messages,
421
+ ... tokenizer=True,
422
+ ... return_dict=True,
423
+ ... return_tensors="pt",
424
+ ... add_generation_prompt=True
425
+ ... )
426
+ >>> # Generate
427
+ >>> generate_ids = model.generate(**inputs)
428
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
429
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
430
+ ```
431
+ """
432
+
433
+ if (input_ids is None) ^ (inputs_embeds is not None):
434
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
435
+
436
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
437
+ output_hidden_states = (
438
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
439
+ )
440
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
441
+
442
+ if isinstance(input_modes, torch.Tensor):
443
+ # len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
444
+ input_modes = input_modes.unique()
445
+ if len(input_modes) != 1:
446
+ raise ValueError("Elements of input_modes should have the same value")
447
+
448
+ input_mode = InputMode(input_modes.item())
449
+
450
+ if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
451
+ self.unset_lora_adapter()
452
+ #self.set_lora_adapter('vision')
453
+ #audio_projection_mode = 'vision'
454
+ elif input_mode == InputMode.SPEECH:
455
+ self.unset_lora_adapter()
456
+ self.set_lora_adapter('speech')
457
+ #audio_projection_mode = 'speech'
458
+ elif input_mode == InputMode.LANGUAGE:
459
+ self.unset_lora_adapter()
460
+ self.set_lora_adapter('text')
461
+
462
+ #audio_projection_mode = 'speech'
463
+ else:
464
+ raise ValueError(f"Invalid input_mode: {input_mode}")
465
+
466
+ is_training = token_type_ids is not None and labels is not None
467
+
468
+ # Replace image/audio ids with PAD if the special tokens are OOV, to avoid index errors
469
+ if input_ids is not None and (self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size):
470
+ special_image_mask = input_ids == self.config.image_token_index
471
+ special_audio_mask = input_ids == self.config.audio_token_index
472
+ llm_input_ids = input_ids.clone()
473
+ llm_input_ids[special_image_mask] = 0
474
+ llm_input_ids[special_audio_mask] = 0
475
+ else:
476
+ llm_input_ids = input_ids
477
+
478
+ if inputs_embeds is None:
479
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
480
+ inputs_embeds = inputs_embeds.to(dtype=self.dtype)
481
+ if cache_position is None:
482
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
483
+ cache_position = torch.arange(
484
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
485
+ )
486
+
487
+ if position_ids is None:
488
+ position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
489
+
490
+ # Merge text and images
491
+ if pixel_values is not None:
492
+ image_features = self.get_image_features(pixel_values)
493
+
494
+ if input_ids is None:
495
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
496
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
497
+ )
498
+ else:
499
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
500
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
501
+
502
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
503
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
504
+ raise ValueError(
505
+ f"Number of images does not match number of special image tokens in the input text. "
506
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
507
+ "tokens from image embeddings."
508
+ )
509
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
510
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
511
+
512
+ # Merge text and audios
513
+ if input_audio_embeds is not None:
514
+ input_audio_embeds=input_audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
515
+ if audio_attention_mask is not None:
516
+ audio_attention_mask=audio_attention_mask.to(inputs_embeds.device, inputs_embeds.dtype)
517
+ audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
518
+ if input_ids is None:
519
+ special_audio_mask = inputs_embeds == self.get_input_embeddings()(
520
+ torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
521
+ )
522
+ else:
523
+ special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
524
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
525
+ masked_audio_features = []
526
+ for i, size in enumerate(audio_embed_sizes):
527
+ masked_audio_features.append(audio_features[i, :size, :])
528
+ masked_audio_features = torch.cat(masked_audio_features, dim=0)
529
+
530
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel():
531
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
532
+ masked_audio_size = audio_embed_sizes#.sum()[0]
533
+ raise ValueError(
534
+ f"Number of audio does not match number of special audio tokens in the input text. "
535
+ f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_size} "
536
+ "tokens from audio embeddings. "
537
+ f"{masked_audio_features.numel()} \n"
538
+ f"{inputs_embeds[special_audio_mask].numel()} \n"
539
+ f"{audio_features} \n"
540
+ f"{inputs_embeds[special_audio_mask]} \n"
541
+ f"{special_audio_mask} \n"
542
+ )
543
+ masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
544
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)
545
+ # mask out pad-token-ids in labels for BC
546
+ if labels is not None and self.pad_token_id in labels:
547
+ logger.warning_once(
548
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
549
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
550
+ )
551
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
552
+
553
+ causal_mask = self._update_causal_mask(
554
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
555
+ )
556
+ outputs = self.language_model(
557
+ attention_mask=causal_mask,
558
+ position_ids=position_ids,
559
+ past_key_values=past_key_values,
560
+ inputs_embeds=inputs_embeds,
561
+ use_cache=use_cache,
562
+ output_attentions=output_attentions,
563
+ output_hidden_states=output_hidden_states,
564
+ return_dict=return_dict,
565
+ cache_position=cache_position,
566
+ logits_to_keep=logits_to_keep,
567
+ **lm_kwargs,
568
+ )
569
+
570
+ logits = outputs.logits
571
+ loss = None
572
+ # print('#############################')
573
+ # print(logits)
574
+ if labels is not None:
575
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
576
+ logits = logits.float()
577
+ shift_logits = logits[..., :-1, :]
578
+ shift_labels = labels[..., 1:]
579
+ if attention_mask is not None:
580
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
581
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
582
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
583
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
584
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
585
+ else:
586
+ shift_logits = shift_logits.contiguous()
587
+ shift_labels = shift_labels.contiguous()
588
+ # Flatten the tokens
589
+ loss_fct = nn.CrossEntropyLoss()
590
+
591
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
592
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
593
+ loss = loss_fct(flat_logits, flat_labels)
594
+ # print('flat logits',flat_logits)
595
+ # print(flat_labels)
596
+ # print(loss)
597
+ if not return_dict:
598
+ output = (logits,) + outputs[1:]
599
+ return (loss,) + output if loss is not None else output
600
+
601
+ return Gemma3OmniCausalLMOutputWithPast(
602
+ loss=loss,
603
+ logits=logits,
604
+ past_key_values=outputs.past_key_values,
605
+ hidden_states=outputs.hidden_states,
606
+ attentions=outputs.attentions,
607
+ image_hidden_states=image_features if pixel_values is not None else None,
608
+ audio_hidden_states=audio_features if input_audio_embeds is not None else None,
609
+ )
610
+
611
+ def prepare_inputs_for_generation(
612
+ self,
613
+ input_ids,
614
+ past_key_values=None,
615
+ input_modes=None,
616
+ inputs_embeds=None,
617
+ cache_position=None,
618
+ position_ids=None,
619
+ pixel_values=None,
620
+ input_audio_embeds=None,
621
+ audio_embed_sizes=None,
622
+ audio_attention_mask=None,
623
+ attention_mask=None,
624
+ token_type_ids=None,
625
+ use_cache=True,
626
+ logits_to_keep=None,
627
+ labels=None,
628
+ **kwargs,
629
+ ):
630
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
631
+ model_inputs = self.language_model.prepare_inputs_for_generation(
632
+ input_ids,
633
+ past_key_values=past_key_values,
634
+ input_modes=input_modes,
635
+ inputs_embeds=inputs_embeds,
636
+ attention_mask=attention_mask,
637
+ position_ids=position_ids,
638
+ cache_position=cache_position,
639
+ use_cache=use_cache,
640
+ logits_to_keep=logits_to_keep,
641
+ token_type_ids=token_type_ids,
642
+ **kwargs,
643
+ )
644
+
645
+ # position_ids in Gemma3 are 1-indexed
646
+ if model_inputs.get("position_ids") is not None:
647
+ model_inputs["position_ids"] += 1
648
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
649
+ # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
650
+ if cache_position[0] == 0:
651
+ model_inputs["pixel_values"] = pixel_values
652
+ model_inputs["input_audio_embeds"] = input_audio_embeds
653
+ model_inputs["audio_embed_sizes"] = audio_embed_sizes
654
+ model_inputs["audio_attention_mask"] = audio_attention_mask
655
+ model_inputs["input_modes"] = input_modes
656
+ is_training = token_type_ids is not None and labels is not None
657
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
658
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
659
+ causal_mask = self._update_causal_mask(
660
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
661
+ )
662
+ model_inputs["attention_mask"] = causal_mask
663
+
664
+ return model_inputs
665
+
666
+ def tie_weights(self):
667
+ return self.language_model.tie_weights()
668
+
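
In `forward` above, the projected audio features are written into the text embedding sequence wherever the audio placeholder token appears, via `masked_scatter`. A toy illustration of that mechanism (the token id, hidden size and constant values below are illustrative, not the model's real configuration):

# Toy sketch of the audio-embedding splice used in Gemma3OmniForConditionalGeneration.forward.
import torch

audio_token_id = 999                       # hypothetical placeholder id
input_ids = torch.tensor([[1, audio_token_id, audio_token_id, 5]])
inputs_embeds = torch.zeros(1, 4, 8)       # (batch, seq_len, hidden)
audio_features = torch.ones(2, 8)          # one projected row per placeholder

special_audio_mask = (input_ids == audio_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
print(inputs_embeds[0, :, 0])              # tensor([0., 1., 1., 0.])
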
cpp/gemma_v1/preprocessing_gemma3omni.py ADDED
@@ -0,0 +1,444 @@
1
+ import re
2
+ from typing import List, Optional, Union, Tuple
3
+ from math import ceil
4
+
5
+ import numpy as np
6
+ import torch
7
+ import scipy
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from enum import Enum
11
+
12
+ from transformers import AutoFeatureExtractor
13
+ from transformers.feature_extraction_utils import BatchFeature
14
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
15
+ from transformers.image_utils import ImageInput, make_nested_list_of_images
16
+ from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
17
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
+ from transformers.utils import to_py_obj, TensorType
19
+ from transformers.audio_utils import AudioInput
20
+
21
+
22
+ class Gemma3ImagesKwargs(ImagesKwargs):
23
+ do_pan_and_scan: Optional[bool]
24
+ pan_and_scan_min_crop_size: Optional[int]
25
+ pan_and_scan_max_num_crops: Optional[int]
26
+ pan_and_scan_min_ratio_to_activate: Optional[float]
27
+ do_convert_rgb: Optional[bool]
28
+
29
+
30
+ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
31
+ images_kwargs: Gemma3ImagesKwargs
32
+ _defaults = {
33
+ "text_kwargs": {
34
+ "padding": False,
35
+ },
36
+ "images_kwargs": {
37
+ "do_pan_and_scan": False,
38
+ "pan_and_scan_min_crop_size": 256,
39
+ "pan_and_scan_max_num_crops": 4,
40
+ "pan_and_scan_min_ratio_to_activate": 1.2,
41
+ },
42
+ }
43
+
44
+ def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
45
+ """Create a Mel filter-bank the same as SpeechLib FbankFC.
46
+
47
+ Args:
48
+ sample_rate (int): Sample rate in Hz. number > 0 [scalar]
49
+ n_fft (int): FFT size. int > 0 [scalar]
50
+ n_mels (int): Mel filter size. int > 0 [scalar]
51
+ fmin (float): lowest frequency (in Hz). If None use 0.0.
52
+ float >= 0 [scalar]
53
+ fmax: highest frequency (in Hz). If None use sample_rate / 2.
54
+ float >= 0 [scalar]
55
+
56
+ Returns
57
+ out (numpy.ndarray): Mel transform matrix
58
+ [shape=(n_mels, 1 + n_fft/2)]
59
+ """
60
+
61
+ bank_width = int(n_fft // 2 + 1)
62
+ if fmax is None:
63
+ fmax = sample_rate / 2
64
+ if fmin is None:
65
+ fmin = 0
66
+ assert fmin >= 0, "fmin cannot be negative"
67
+ assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
68
+
69
+ def mel(f):
70
+ return 1127.0 * np.log(1.0 + f / 700.0)
71
+
72
+ def bin2mel(fft_bin):
73
+ return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
74
+
75
+ def f2bin(f):
76
+ return int((f * n_fft / sample_rate) + 0.5)
77
+
78
+ # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
79
+ klo = f2bin(fmin) + 1
80
+ khi = f2bin(fmax)
81
+
82
+ khi = max(khi, klo)
83
+
84
+ # Spec 2: SpeechLib uses triangles in Mel space
85
+ mlo = mel(fmin)
86
+ mhi = mel(fmax)
87
+ m_centers = np.linspace(mlo, mhi, n_mels + 2)
88
+ ms = (mhi - mlo) / (n_mels + 1)
89
+
90
+ matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
91
+ for m in range(0, n_mels):
92
+ left = m_centers[m]
93
+ center = m_centers[m + 1]
94
+ right = m_centers[m + 2]
95
+ for fft_bin in range(klo, khi):
96
+ mbin = bin2mel(fft_bin)
97
+ if left < mbin < right:
98
+ matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
99
+
100
+ return matrix
101
+
102
+
103
+ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
104
+ model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
105
+
106
+ def __init__(self, audio_compression_rate=8,
107
+ audio_downsample_rate=1,
108
+ audio_feat_stride=1,
109
+ feature_size = 80,
110
+ sampling_rate = 16000,
111
+ padding_value = 0.0,
112
+ **kwargs):
113
+
114
+ super().__init__(feature_size=feature_size,
115
+ sampling_rate=sampling_rate,
116
+ padding_value=padding_value, **kwargs)
117
+
118
+ self.compression_rate = audio_compression_rate
119
+ self.qformer_compression_rate = audio_downsample_rate
120
+ self.feat_stride = audio_feat_stride
121
+
122
+ self._eightk_method = "fillzero"
123
+ self._mel = speechlib_mel(self.sampling_rate, 512, self.feature_size, fmin=None, fmax=self.sampling_rate//2-self.feature_size-230).T
124
+
125
+ self._hamming400 = np.hamming(400) # for 16k audio
126
+ self._hamming200 = np.hamming(200) # for 8k audio
127
+
128
+ def duration_to_frames(self, duration):
129
+ """duration in s, estimated frames"""
130
+ frame_rate = 10
131
+
132
+ num_frames = duration * 1000 // frame_rate
133
+ return num_frames
134
+
135
+ def __call__(
136
+ self,
137
+ audios: List[AudioInput],
138
+ sampling_rate = 16000,
139
+ return_attention_mask=True,
140
+ padding="max_length",
141
+ return_tensors: Optional[Union[str, TensorType]] = None,
142
+ ):
143
+ # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
144
+ returned_input_audio_embeds = []
145
+ returned_audio_embed_sizes = []
146
+ audio_frames_list = []
147
+
148
+ for audio_data in audios:
149
+ audio_embeds = self._extract_features(audio_data, sampling_rate)
150
+ audio_frames = len(audio_embeds) * self.feat_stride
151
+ audio_embed_size = self._compute_audio_embed_size(audio_frames)
152
+
153
+ returned_input_audio_embeds.append(torch.tensor(audio_embeds))
154
+ returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
155
+ audio_frames_list.append(audio_frames)
156
+
157
+ returned_input_audio_embeds = pad_sequence(
158
+ returned_input_audio_embeds, batch_first=True
159
+ )
160
+ returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
161
+ audio_frames = torch.tensor(audio_frames_list)
162
+ returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
163
+
164
+ data = {
165
+ "input_audio_embeds": returned_input_audio_embeds,
166
+ "audio_embed_sizes": returned_audio_embed_sizes,
167
+ }
168
+ if returned_audio_attention_mask is not None and return_attention_mask:
169
+ data["audio_attention_mask"] = returned_audio_attention_mask
170
+
171
+ return BatchFeature(data=data, tensor_type=return_tensors)
172
+
173
+ def _extract_spectrogram(self, wav, fs):
174
+ """Extract spectrogram features from waveform.
175
+ Args:
176
+ wav (1D array): waveform of the input
177
+ fs (int): sampling rate of the waveform, 16000 or 8000.
178
+ If fs=8000, the waveform will be resampled to 16000Hz.
179
+ Output:
180
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
181
+ D=80, and T is the number of frames.
182
+ """
183
+ if wav.ndim > 1:
184
+ wav = np.squeeze(wav)
185
+
186
+ # by default, we extract the mean if stereo
187
+ if len(wav.shape) == 2:
188
+ wav = wav.mean(1)
189
+
190
+ # Resample to 16000 or 8000 if needed
191
+ if fs > 16000:
192
+ wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
193
+ fs = 16000
194
+ elif 8000 < fs < 16000:
195
+ wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
196
+ fs = 8000
197
+ elif fs < 8000:
198
+ raise RuntimeError(f"Unsupported sample rate {fs}")
199
+
200
+ if fs == 8000:
201
+ if self._eightk_method == "resample":
202
+ # Input audio is 8 kHz. Convert to 16 kHz before feature
203
+ # extraction
204
+ wav = scipy.signal.resample_poly(wav, 2, 1)
205
+ fs = 16000
206
+ # Do nothing here for fillzero method
207
+ elif fs != 16000:
208
+ # Input audio is not a supported sample rate.
209
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
210
+
211
+ preemphasis = 0.97
212
+
213
+ if fs == 8000:
214
+ n_fft = 256
215
+ win_length = 200
216
+ hop_length = 80
217
+ fft_window = self._hamming200
218
+ elif fs == 16000:
219
+ n_fft = 512
220
+ win_length = 400
221
+ hop_length = 160
222
+ fft_window = self._hamming400
223
+
224
+ # Spec 1: SpeechLib cut remaining sample insufficient for a hop
225
+ n_batch = (wav.shape[0] - win_length) // hop_length + 1
226
+ # Here we don't use stride_tricks since the input array may not satisfy
227
+ # memory layout requirement and we need writeable output
228
+ # Here we only use a list of views before copying to the destination,
229
+ # so it is more efficient than broadcasting
230
+ y_frames = np.array(
231
+ [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
232
+ dtype=np.float32,
233
+ )
234
+
235
+ # Spec 2: SpeechLib applies preemphasis within each batch
236
+ y_frames_prev = np.roll(y_frames, 1, axis=1)
237
+ y_frames_prev[:, 0] = y_frames_prev[:, 1]
238
+ y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
239
+
240
+ S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
241
+
242
+ if fs == 8000:
243
+ # Need to pad the output to look like 16 kHz data but with zeros in
244
+ # the 4 to 8 kHz bins.
245
+ frames, bins = S.shape
246
+ padarray = np.zeros((frames, bins))
247
+ S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
248
+
249
+ spec = np.abs(S).astype(np.float32)
250
+ return spec
251
+
252
+ def _extract_features(self, wav, fs):
253
+ """Extract log filterbank features from waveform.
254
+ Args:
255
+ wav (1D array): waveform of the input
256
+ fs (int): sampling rate of the waveform, 16000 or 8000.
257
+ If fs=8000, the waveform will be resampled to 16000Hz.
258
+ Output:
259
+ log_fbank (2D array): a TxD matrix of log Mel filterbank features.
260
+ D=80, and T is the number of frames.
261
+ """
262
+ spec = self._extract_spectrogram(wav, fs)
263
+ spec_power = spec**2
264
+
265
+ fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
266
+ log_fbank = np.log(fbank_power).astype(np.float32)
267
+
268
+ return log_fbank
269
+
270
+ def _compute_audio_embed_size(self, audio_frames):
271
+ integer = audio_frames // self.compression_rate
272
+ remainder = audio_frames % self.compression_rate
273
+
274
+ result = integer if remainder == 0 else integer + 1
275
+
276
+ integer = result // self.qformer_compression_rate
277
+ remainder = result % self.qformer_compression_rate
278
+ result = integer if remainder == 0 else integer + 1 # qformer compression
279
+
280
+ return result
281
+
282
+ class Gemma3OmniProcessor(ProcessorMixin):
283
+ attributes = ["image_processor", "feature_extractor", "tokenizer"]
284
+ valid_kwargs = ["chat_template", "image_seq_length"]
285
+ image_processor_class = "AutoImageProcessor"
286
+ feature_extractor_class = "Gemma3AudioFeatureExtractor"
287
+ tokenizer_class = "AutoTokenizer"
288
+
289
+ def __init__(
290
+ self,
291
+ image_processor,
292
+ feature_extractor,
293
+ tokenizer,
294
+ chat_template=None,
295
+ image_seq_length: int = 256,
296
+ **kwargs,
297
+ ):
298
+ self.image_seq_length = image_seq_length
299
+ self.image_token_id = tokenizer.image_token_id
300
+ self.boi_token = tokenizer.boi_token
301
+ self.image_token = tokenizer.image_token
302
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
303
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
304
+
305
+ self.audio_token_id = tokenizer.audio_token_id
306
+ self.boa_token = tokenizer.boa_token
307
+ self.eoa_token = tokenizer.eoa_token
308
+ self.audio_token = tokenizer.audio_token
309
+
310
+ super().__init__(
311
+ image_processor=image_processor,
312
+ feature_extractor=feature_extractor,
313
+ tokenizer=tokenizer,
314
+ chat_template=chat_template,
315
+ **kwargs,
316
+ )
317
+
318
+ def __call__(
319
+ self,
320
+ images: ImageInput = None,
321
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
322
+ videos=None,
323
+ audio: List[AudioInput] = None,
324
+ **kwargs: Unpack[Gemma3ProcessorKwargs],
325
+ ) -> BatchFeature:
326
+ if text is None and images is None:
327
+ raise ValueError("Provide at least one of `text` or `images`.")
328
+
329
+ output_kwargs = self._merge_kwargs(
330
+ Gemma3ProcessorKwargs,
331
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
332
+ **kwargs,
333
+ )
334
+
335
+ if isinstance(text, str):
336
+ text = [text]
337
+ elif not isinstance(text, list) and not isinstance(text[0], str):
338
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
339
+
340
+ image_inputs = {}
341
+ if images is not None:
342
+ batched_images = make_nested_list_of_images(images)
343
+ image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
344
+
345
+ # Create empty text to be replaced with placeholders
346
+ if not text:
347
+ text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
348
+
349
+ if len(batched_images) != len(text):
350
+ raise ValueError(
351
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
352
+ )
353
+
354
+ # Replace image tokens by the full expanded sequence
355
+ num_crops = to_py_obj(image_inputs.pop("num_crops"))
356
+ batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
357
+ for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
358
+ image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
359
+
360
+ if len(images) != len(image_indexes):
361
+ raise ValueError(
362
+ f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
363
+ )
364
+
365
+ # Insert additional image tokens for Pan-and-Scan crops
366
+ for num, idx in reversed(list(zip(num_crops, image_indexes))):
367
+ if num:
368
+ formatted_image_text = (
369
+ f"Here is the original image {self.boi_token} and here are some crops to help you see better "
370
+ + " ".join([self.boi_token] * num)
371
+ )
372
+ prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
373
+ text[batch_idx] = prompt
374
+
375
+ # Expand placeholder image tokens to the full image token sequence
376
+ text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
377
+
378
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
379
+
380
+ audio_inputs = {}
381
+ if audio is not None:
382
+ full_audio_sequences = []
383
+ audio_inputs = self.feature_extractor(audio)
384
+ def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
385
+ parts = prompt.split(boa_token)
386
+ result = ""
387
+ for i in range(len(parts) - 1):
388
+ result += parts[i]
389
+ if i < len(audio_sequences):
390
+ result += audio_sequences[i]
391
+ else:
392
+ result += boa_token
393
+ result += parts[-1]
394
+ return result
395
+ for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
396
+ audio_tokens_expanded = "".join([self.audio_token] * embed_size)
397
+ full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
398
+ full_audio_sequences.append(full_audio_sequence)
399
+
400
+ text = [replace_tokens_sequentially(prompt, self.boa_token, [audio_sequences]) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
401
+ #text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
402
+
403
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
404
+
405
+ # Add token type ids manually, as tokenizer can't do arbitrary position token types
406
+ array_ids = text_inputs["input_ids"]
407
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
408
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
409
+ mm_token_type_ids[array_ids == self.audio_token_id] = 2
410
+
411
+ has_vision_ids = np.any(mm_token_type_ids == 1, axis=1)
412
+ has_audio_ids = np.any(mm_token_type_ids == 2, axis=1)
413
+
414
+ input_modes = (has_audio_ids << 1) | has_vision_ids
415
+
416
+ text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
417
+ text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
418
+ text_inputs["input_modes"] = input_modes.tolist()
419
+
420
+ return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
421
+
422
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
423
+ def batch_decode(self, *args, **kwargs):
424
+ """
425
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
426
+ refer to the docstring of this method for more information.
427
+ """
428
+ return self.tokenizer.batch_decode(*args, **kwargs)
429
+
430
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
431
+ def decode(self, *args, **kwargs):
432
+ """
433
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
434
+ the docstring of this method for more information.
435
+ """
436
+ return self.tokenizer.decode(*args, **kwargs)
437
+
438
+ @property
439
+ def model_input_names(self):
440
+ tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
441
+ image_processor_input_names = self.image_processor.model_input_names
442
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
443
+
444
+ AutoFeatureExtractor.register("Gemma3AudioFeatureExtractor", Gemma3AudioFeatureExtractor)
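A minimal usage sketch of the processor defined above, mirroring how the training scripts later in this upload build their inputs; the checkpoint path is a placeholder and dummy.wav stands in for any 16 kHz mono clip.

    import soundfile as sf
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("/path/to/gemma-3-4b-it-omni", trust_remote_code=True)  # placeholder path

    audio, sr = sf.read("dummy.wav")  # any 16 kHz mono clip
    message = {"role": "user", "content": "<start_of_audio>Transcribe the audio clip into text."}
    prompt = processor.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True, add_bos=True)

    inputs = processor(text=prompt, audio=[audio], add_special_tokens=False, return_tensors="pt")
    print(inputs.input_modes)     # 2 -> audio present, no image ((has_audio << 1) | has_vision)
    print(inputs.token_type_ids)  # 2 marks <audio_soft_token> positions, 1 would mark image tokens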
cpp/gemma_v1/preprocessor_config.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "audio_compression_rate": 8,
3
+ "audio_downsample_rate": 1,
4
+ "audio_feat_stride": 1,
5
+ "compression_rate": 8,
6
+ "do_convert_rgb": null,
7
+ "do_normalize": true,
8
+ "do_pan_and_scan": null,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feat_stride": 1,
12
+ "feature_extractor_type": "Gemma3AudioFeatureExtractor",
13
+ "feature_size": 80,
14
+ "image_mean": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "image_processor_type": "Gemma3ImageProcessor",
20
+ "image_seq_length": 256,
21
+ "image_std": [
22
+ 0.5,
23
+ 0.5,
24
+ 0.5
25
+ ],
26
+ "padding_side": "right",
27
+ "padding_value": 0.0,
28
+ "pan_and_scan_max_num_crops": null,
29
+ "pan_and_scan_min_crop_size": null,
30
+ "pan_and_scan_min_ratio_to_activate": null,
31
+ "processor_class": "Gemma3OmniProcessor",
32
+ "qformer_compression_rate": 1,
33
+ "resample": 2,
34
+ "rescale_factor": 0.00392156862745098,
35
+ "return_attention_mask": true,
36
+ "sampling_rate": 16000,
37
+ "size": {
38
+ "height": 896,
39
+ "width": 896
40
+ }
41
+ }
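With a 16 kHz front end and audio_compression_rate 8 (qformer_compression_rate 1), one second of audio should come out to roughly 12-13 <audio_soft_token> slots if the feature extractor uses the same 10 ms hop as the C++ front end further down; that estimate is an assumption, so the quick check below simply reads audio_embed_sizes the same way the processor's __call__ does. The path is a placeholder.

    import numpy as np
    from transformers import AutoProcessor

    proc = AutoProcessor.from_pretrained("/path/to/gemma-3-4b-it-omni", trust_remote_code=True)  # placeholder path
    one_second = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    feats = proc.feature_extractor([one_second])
    print(feats.audio_embed_sizes)  # how many <audio_soft_token> slots this clip expands to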
cpp/gemma_v1/processor_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "preprocessing_gemma3omni.Gemma3OmniProcessor"
4
+ },
5
+ "image_seq_length": 256,
6
+ "processor_class": "Gemma3Processor"
7
+ }
cpp/gemma_v1/special_tokens_map.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "audio_token": "<audio_soft_token>",
3
+ "boa_token": "<start_of_audio>",
4
+ "boi_token": "<start_of_image>",
5
+ "bos_token": {
6
+ "content": "<bos>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "eoa_token": "<end_of_audio>",
13
+ "eoi_token": "<end_of_image>",
14
+ "eos_token": {
15
+ "content": "<eos>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "image_token": "<image_soft_token>",
22
+ "pad_token": {
23
+ "content": "<pad>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false
28
+ },
29
+ "unk_token": {
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ }
36
+ }
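The three audio tokens above are the ones the processor expands around audio clips; the training scripts below additionally unfreeze embedding rows 256001 and 256002 for the newly added speech tokens, so treat that id mapping as an assumption to verify rather than a documented fact. The path is a placeholder.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("/path/to/gemma-3-4b-it-omni")  # placeholder path
    for t in ["<start_of_audio>", "<end_of_audio>", "<audio_soft_token>", "<image_soft_token>"]:
        print(t, tok.convert_tokens_to_ids(t))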
cpp/gemma_v1/speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
cpp/gemma_v1/speech_conformer_encoder_old.py ADDED
The diff for this file is too large to render. See raw diff
 
cpp/gemma_v1/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52941f2ba60fdcc48edb940f4252f6d874d0c369323dab293168015122e556be
3
+ size 33384559
cpp/gemma_v1/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
cpp/gemma_v1/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
cpp/gemma_v1/training.py ADDED
@@ -0,0 +1,883 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import sacrebleu
14
+
15
+ from datasets import load_dataset
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ AutoProcessor,
20
+ AutoModel,
21
+ BatchFeature,
22
+ Trainer,
23
+ TrainingArguments,
24
+ StoppingCriteria,
25
+ StoppingCriteriaList,
26
+ )
27
+ from collections import defaultdict
28
+
29
+ import soundfile as sf
30
+ from datasets import Audio
31
+ import random
32
+ ANSWER_SUFFIX = "<end_of_turn>"
33
+ _IGNORE_INDEX = -100
34
+ class BaseAudioDataset(Dataset):
35
+ def __init__(self, processor, split, sampling_rate=16000, debug=False):
36
+ self.processor = processor
37
+ self.training = "train" in split or 'other' in split
38
+ self.debug = debug
39
+ self.sampling_rate = sampling_rate
40
+ self.name = ""
41
+
42
+ def set_dataset_name(self, name):
43
+ self.name = name
44
+
45
+ @staticmethod
46
+ def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
47
+ original_size = len(data)
48
+
49
+ data = data.cast_column(audio_field, Audio(decode=False))
50
+
51
+ def identify_corrupted_files(example):
52
+ try:
53
+ sf.read(example[audio_field]["path"])
54
+
55
+ for field in text_fields:
56
+ if field in example and example[field].replace('"', '') == "":
57
+ return False
58
+ return True
59
+ except Exception:
60
+ return False
61
+
62
+ data = data.filter(identify_corrupted_files, num_proc=16)
63
+ validated_size = len(data)
64
+
65
+ # Audio Decoding
66
+ data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
67
+
68
+ if debug:
69
+ print(f"Dataset: {dataset_name}")
70
+ print(f"Original data nums: {original_size}")
71
+ print(f"After filtering data nums: {validated_size}")
72
+ print(f"Filtering ratio: {validated_size/original_size:.2%}")
73
+
74
+ return data
75
+
76
+ @staticmethod
77
+ def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
78
+ original_size = len(data)
79
+
80
+ def filter_audio_by_length(example):
81
+ try:
82
+ audio = example[audio_field]['array']
83
+ channel = 1
84
+ if hasattr(audio, 'ndim') and audio.ndim > 1:
85
+ channel = audio.ndim
86
+ audio = audio.squeeze()
87
+ audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
88
+ return min_sec <= audio_length <= max_sec
89
+ except Exception as e:
90
+ if debug:
91
+ print(f"Error : {str(e)[:100]}... - sample excluded")
92
+ return False
93
+
94
+ data = data.filter(filter_audio_by_length, num_proc=16)
95
+ filtered_size = len(data)
96
+
97
+ if debug:
98
+ print(f"Before Length Filtering data nums: {original_size}")
99
+ print(f"After Length Filtering data nums: {filtered_size}")
100
+ print(f"Filtering ratio: {filtered_size/original_size:.2%}")
101
+
102
+ return data
103
+
104
+ def prepare_model_inputs(self, audio_array, instruction, answer_text):
105
+ user_message = {
106
+ 'role': 'user',
107
+ 'content': '<start_of_audio>' + instruction,
108
+ }
109
+ prompt = self.processor.tokenizer.apply_chat_template(
110
+ [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
111
+ )
112
+
113
+ inputs = self.processor(
114
+ text=prompt,
115
+ audio=[audio_array],
116
+ add_special_tokens=False,
117
+ return_tensors='pt'
118
+ )
119
+
120
+ answer = f"{answer_text}{ANSWER_SUFFIX}"
121
+ answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
122
+
123
+ if self.debug:
124
+ self.debug = False
125
+ task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
126
+ lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
127
+ print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
128
+ print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
129
+
130
+ if self.training:
131
+ input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
132
+ labels = torch.full_like(input_ids, _IGNORE_INDEX)
133
+ labels[:, -answer_ids.shape[1]:] = answer_ids
134
+ padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
135
+ token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
136
+ else:
137
+ input_ids = inputs.input_ids
138
+ labels = answer_ids
139
+ token_type_ids = inputs.token_type_ids
140
+
141
+ return {
142
+ 'input_ids': input_ids,
143
+ 'labels': labels,
144
+ 'token_type_ids': token_type_ids,
145
+ 'input_audio_embeds': inputs.input_audio_embeds,
146
+ 'audio_embed_sizes': inputs.audio_embed_sizes,
147
+ 'input_modes': inputs.input_modes,
148
+ }
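A toy illustration of the masking done just above: the prompt span is filled with _IGNORE_INDEX so the loss only covers the answer tokens plus the <end_of_turn> suffix (all ids below are made up).

    import torch

    _IGNORE_INDEX = -100
    prompt_ids = torch.tensor([[10, 11, 12]])   # tokenized prompt (audio placeholders + instruction)
    answer_ids = torch.tensor([[21, 22, 1]])    # answer text + <end_of_turn>
    input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
    labels = torch.full_like(input_ids, _IGNORE_INDEX)
    labels[:, -answer_ids.shape[1]:] = answer_ids
    print(input_ids)  # tensor([[10, 11, 12, 21, 22, 1]])
    print(labels)     # tensor([[-100, -100, -100, 21, 22, 1]])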
149
+
150
+
151
+ # Libri Speech Dataset Class
152
+ class LibriSpeechDataset(BaseAudioDataset):
153
+ def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
154
+ super().__init__(processor, split, sampling_rate, debug)
155
+
156
+ self.set_dataset_name(f"LibriSpeech_{subset}")
157
+ # only ASR
158
+ self.ast = False
159
+ self.lang = "en"
160
+
161
+ # load dataset
162
+ self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
163
+ subset,
164
+ split=split,
165
+ trust_remote_code=True,
166
+ cache_dir=Path("/mnt/jeff/InCar/data")
167
+ )
168
+
169
+ # (Optional) Audio length Filtering
170
+ self.data = self.filter_by_audio_length(self.data, "audio")
171
+
172
+ # Instruction Setting
173
+ self.instruction = random.choice(INSTRUCTION["asr"])
174
+
175
+ def __len__(self):
176
+ return len(self.data)
177
+
178
+ def __getitem__(self, idx):
179
+ data = self.data[idx]
180
+
181
+ # Libri Speech is only for ASR
182
+ answer_text = data["text"].replace('"', '')
183
+
184
+ return self.prepare_model_inputs(
185
+ data["audio"]["array"],
186
+ self.instruction,
187
+ answer_text
188
+ )
189
+
190
+ # common_voice_16_1 dataset
191
+ class CommonVoiceDataset(BaseAudioDataset):
192
+ def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
193
+ super().__init__(processor, split, sampling_rate, debug)
194
+
195
+ self.set_dataset_name(f"CommonVoice_{source_lang}")
196
+ # only ASR
197
+ self.ast = False
198
+ self.lang=source_lang
199
+
200
+ # load dataset
201
+ if source_lang=="zh-TW":
202
+ data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
203
+ else:
204
+ data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
205
+ self.data = load_dataset(data_path,
206
+ source_lang,
207
+ split=split,
208
+ trust_remote_code=True,
209
+ cache_dir=Path("/mnt/jeff/InCar/data")
210
+ )
211
+ def prepare_dataset(batch):
212
+ """Function to preprocess the dataset with the .map method"""
213
+ transcription = batch["sentence"]
214
+
215
+ if transcription.startswith('"') and transcription.endswith('"'):
216
+ # we can remove trailing quotation marks as they do not affect the transcription
217
+ transcription = transcription[1:-1]
218
+
219
+ if transcription[-1] not in [".", "?", "!"]:
220
+ # append a full-stop to sentences that do not end in punctuation
221
+ transcription = transcription + "."
222
+
223
+ batch["sentence"] = transcription
224
+
225
+ return batch
226
+
227
+
228
+ import opencc
229
+ converter = opencc.OpenCC('s2tw.json')
230
+ def To_zhTW(batch):
231
+
232
+ transcription = converter.convert(batch["sentence"])
233
+ batch["sentence"] = transcription
234
+
235
+ return batch
236
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
237
+ if source_lang=='zh-CN':
238
+ self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
239
+
240
+
241
+ # (Optional) Audio length Filtering
242
+ self.data = self.filter_by_audio_length(self.data, "audio")
243
+
244
+ if source_lang == "zh-TW" and split=='train':
245
+ import torchaudio
246
+ from torchaudio import transforms
247
+ import copy
248
+ import pickle
249
+ import os
250
+ def subsample(batch):
251
+ batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
252
+ batch['audio']['sampling_rate']=16000
253
+ return batch
254
+ def TW_data_augment_fast(batch):
255
+ speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
256
+ new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
257
+ batch['audio']['array'] = new_array_fast
258
+ return batch
259
+ def TW_data_augment_slow(batch):
260
+ speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
261
+ new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
262
+ batch['audio']['array'] = new_array_slow
263
+ return batch
264
+ # data = self.data.map(subsample, num_proc=1, desc="subsample")
265
+ fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
266
+ if not os.path.exists(fast_path):
267
+ data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
268
+ with open(fast_path,'wb') as f:
269
+ pickle.dump(data_fast,f)
270
+ else:
271
+ with open(fast_path,'rb') as f:
272
+ data_fast=pickle.load(f)
273
+
274
+ slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
275
+ if not os.path.exists(slow_path):
276
+ data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
277
+ with open(slow_path,'wb') as f:
278
+ pickle.dump(data_slow,f)
279
+ else:
280
+ with open(slow_path,'rb') as f:
281
+ data_slow=pickle.load(f)
282
+ self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
283
+
284
+ # Instruction Setting
285
+ self.instruction = random.choice(INSTRUCTION["asr"])
286
+
287
+ def __len__(self):
288
+ return len(self.data)
289
+
290
+ def __getitem__(self, idx):
291
+ data = self.data[idx]
292
+
293
+ answer_text = data["sentence"]
294
+ return self.prepare_model_inputs(
295
+ data["audio"]["array"],
296
+ self.instruction,
297
+ answer_text
298
+ )
299
+
300
+
301
+ # Fleurs Dataset Class
302
+ class FleursDataset(BaseAudioDataset):
303
+ def __init__(self, processor, split, source_lang, target_lang=None,
304
+ mode="asr", sampling_rate=16000, debug=False):
305
+ super().__init__(processor, split, sampling_rate, debug)
306
+
307
+ self.set_dataset_name("Fleurs")
308
+ # Mode Setting (ASR or AST)
309
+ if mode not in ["asr", "ast"]:
310
+ raise ValueError("mode must be 'asr' or 'ast'.")
311
+
312
+ self.mode = mode
313
+ self.ast = (mode == "ast")
314
+ self.source_lang = source_lang
315
+
316
+ # Language name mapping (expand if needed)
317
+ self.lang_names = {
318
+ 'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
319
+ }
320
+
321
+ # load dataset - source language dataset
322
+ self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
323
+ source_lang,
324
+ split=split,
325
+ trust_remote_code=True,
326
+ cache_dir=Path("/mnt/jeff/InCar/data")
327
+ )
328
+ import opencc
329
+ converter = opencc.OpenCC('s2tw.json')
330
+ def prepare_dataset(batch):
331
+ transcription = converter.convert(batch["transcription"])
332
+ batch["transcription"] = transcription
333
+
334
+ return batch
335
+ if (source_lang=="cmn_hans_cn"):
336
+ self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
337
+
338
+ # (Optional) Audio length Filtering
339
+ self.data = self.filter_by_audio_length(self.data, "audio")
340
+ self.target_lang_name = ""
341
+ # When AST mode, load target language dataset.
342
+ if self.ast:
343
+ if target_lang is None:
344
+ raise ValueError("AST mode requires target_lang.")
345
+
346
+ self.target_lang = target_lang
347
+ self.lang = f"{source_lang}_{target_lang}"
348
+
349
+ # load dataset - target language dataset (for translation)
350
+ target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
351
+ target_lang,
352
+ split=split,
353
+ trust_remote_code=True,
354
+ cache_dir=Path("/mnt/jeff/InCar/data")
355
+ )
356
+ if target_lang=="cmn_hans_cn":
357
+ target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
358
+ source_dict = {item['id']: item for item in self.data}
359
+ target_dict = {item['id']: item for item in target_data}
360
+
361
+ # only Common ID, add translation fields
362
+ common_ids = set(source_dict.keys()) & set(target_dict.keys())
363
+ print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
364
+ self.data = [
365
+ {**source_dict[id], 'translation': target_dict[id]['transcription']}
366
+ for id in common_ids
367
+ ]
368
+
369
+ # Instruction Setting - use target language name
370
+ self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
371
+ self.instruction = random.choice(INSTRUCTION["ast"])
372
+ else:
373
+ # ASR mode
374
+ self.lang = source_lang
375
+ self.instruction = random.choice(INSTRUCTION["asr"])
376
+
377
+ if self.debug:
378
+ print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
379
+ print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
380
+ if self.ast:
381
+ print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
382
+ print(f"dataset size: {len(self.data)}")
383
+
384
+ def __len__(self):
385
+ return len(self.data)
386
+
387
+ def __getitem__(self, idx):
388
+ data = self.data[idx]
389
+ audio_array = data["audio"]["array"]
390
+
391
+ if self.ast:
392
+ answer_text = data["translation"]
393
+ else:
394
+ answer_text = data["transcription"]
395
+
396
+ return self.prepare_model_inputs(
397
+ audio_array,
398
+ self.instruction.format(self.target_lang_name),
399
+ answer_text
400
+ )
401
+
402
+ def covost_collate_fn(batch):
403
+ input_ids_list = []
404
+ labels_list = []
405
+ token_type_ids_list = []
406
+ input_audio_embeds_list = []
407
+ audio_embed_sizes_list = []
408
+ audio_attention_mask_list = []
409
+ input_modes_list = []
410
+ for inputs in batch:
411
+ input_ids_list.append(inputs['input_ids'][0])
412
+ labels_list.append(inputs['labels'][0])
413
+ token_type_ids_list.append(inputs['token_type_ids'][0])
414
+ input_audio_embeds_list.append(inputs['input_audio_embeds'])
415
+ audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
416
+ audio_attention_mask_list.append(
417
+ inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
418
+ )
419
+ input_modes_list.append(inputs['input_modes'])
420
+
421
+ try:
422
+ token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
423
+ input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
424
+ labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
425
+ audio_attention_mask = (
426
+ pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
427
+ if len(audio_attention_mask_list) > 1
428
+ else None
429
+ )
430
+ except Exception as e:
431
+ print(e)
432
+ print(input_ids_list)
433
+ print(labels_list)
434
+ raise
435
+ attention_mask = (input_ids != 0).long()
436
+ input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
437
+ audio_embed_sizes = torch.cat(audio_embed_sizes_list)
438
+ input_modes = torch.cat(input_modes_list)
439
+
440
+ return BatchFeature(
441
+ {
442
+ 'input_ids': input_ids,
443
+ 'labels': labels,
444
+ 'token_type_ids': token_type_ids,
445
+ 'attention_mask': attention_mask,
446
+ 'input_audio_embeds': input_audio_embeds,
447
+ 'audio_embed_sizes': audio_embed_sizes,
448
+ 'audio_attention_mask': audio_attention_mask,
449
+ 'input_modes': input_modes,
450
+ }
451
+ )
452
+
453
+ def pad_sequence(sequences, padding_side='left', padding_value=0):
454
+ """
455
+ Pad a list of sequences to the same length.
456
+ sequences: list of tensors in [seq_len, *] shape
457
+ """
458
+ assert padding_side in ['right', 'left']
459
+ max_size = sequences[0].size()
460
+ trailing_dims = max_size[1:]
461
+ max_len = max(len(seq) for seq in sequences)
462
+ batch_size = len(sequences)
463
+ output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
464
+ for i, seq in enumerate(sequences):
465
+ length = seq.size(0)
466
+ if padding_side == 'right':
467
+ output.data[i, :length] = seq
468
+ else:
469
+ output.data[i, -length:] = seq
470
+ return output
471
+
472
+ def cat_with_pad(tensors, dim, padding_value=0):
473
+ """
474
+ cat along dim, while pad to max for all other dims
475
+ """
476
+ ndim = tensors[0].dim()
477
+ assert all(
478
+ t.dim() == ndim for t in tensors[1:]
479
+ ), 'All tensors must have the same number of dimensions'
480
+
481
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
482
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
483
+ output = tensors[0].new_full(out_size, padding_value)
484
+
485
+ index = 0
486
+ for t in tensors:
487
+ # Create a slice list where every dimension except dim is full slice
488
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
489
+ # Update only the concat dimension slice
490
+ slices[dim] = slice(index, index + t.shape[dim])
491
+
492
+ output[slices] = t
493
+ index += t.shape[dim]
494
+
495
+ return output
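A quick, self-contained look at what cat_with_pad does to the audio features collated above: two clips of different lengths are stacked along dim 0 and the shorter one is zero-padded on its frame axis (shapes are invented, and the snippet assumes it runs next to the helper defined above).

    import torch

    a = torch.ones(1, 5, 80)           # clip 1: 5 feature frames
    b = torch.ones(1, 7, 80)           # clip 2: 7 feature frames
    out = cat_with_pad([a, b], dim=0)  # helper defined above
    print(out.shape)                   # torch.Size([2, 7, 80])
    print(out[0, 5:, 0])               # tensor([0., 0.]) -> zero padding on the shorter clip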
496
+
497
+ def count_parameters_by_module(model):
498
+ # dictionary for parameters number by modules
499
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
500
+
501
+ # all params
502
+ total_params = 0
503
+ total_trainable_params = 0
504
+
505
+ # Check Embedding Token masks
506
+ embedding_masks = {}
507
+ for name, param in model.named_parameters():
508
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
509
+ # check if params has embedding_grad_mask_hook
510
+ for hook_id, hook_fn in param._backward_hooks.items():
511
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
512
+ # Accessing mask variables in the closure of hook functions
513
+ for cell in hook_fn.__closure__ or []:
514
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
515
+ # check mask tensor
516
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
517
+
518
+ # Count params by modules
519
+ for name, param in model.named_parameters():
520
+ # extracts top module_name
521
+ module_name = name.split('.')[0]
522
+ param_count = param.numel()
523
+
524
+ module_params[module_name]["total"] += param_count
525
+ total_params += param_count
526
+
527
+ if param.requires_grad:
528
+ # Only count for real trainable params. (with masks)
529
+ if name in embedding_masks:
530
+ trainable_count = embedding_masks[name].sum().item()
531
+ module_params[module_name]["trainable"] += trainable_count
532
+ total_trainable_params += trainable_count
533
+ else:
534
+ module_params[module_name]["trainable"] += param_count
535
+ total_trainable_params += param_count
536
+
537
+ print(f"All Params: {total_params:,}")
538
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
539
+ print("\nParams by Module:")
540
+
541
+ for module_name, counts in sorted(module_params.items()):
542
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
543
+ total_percentage = counts["total"] / total_params * 100
544
+
545
+ print(f"- {module_name}:")
546
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
547
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
548
+
549
+ return module_params
550
+
551
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
552
+ model = AutoModel.from_pretrained(
553
+ model_name_or_path,
554
+ revision=revision,
555
+ torch_dtype=torch.bfloat16,
556
+ device_map="auto",
557
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
558
+ trust_remote_code=True,
559
+ )
560
+
561
+ # Set use_cache to False after model loaded
562
+ model.config.use_cache = False
563
+
564
+ # Freeze all parameters
565
+ for param in model.parameters():
566
+ param.requires_grad = False
567
+
568
+ model.set_lora_adapter('speech')
569
+ model.to(torch.bfloat16)
570
+
571
+ # (Optional) unfreeze audio_tower parameters
572
+ # for param in model.audio_tower.parameters():
573
+ # param.requires_grad = True
574
+
575
+ # Only unfreeze audio_projector parameters
576
+ for param in model.audio_projector.parameters():
577
+ param.requires_grad = True
578
+
579
+ # (Optional) unfreeze audio embed_tokens
580
+ train_embed = True
581
+ if train_embed:
582
+ embed_tokens = model.language_model.model.model.embed_tokens
583
+
584
+ embed_tokens.weight.requires_grad = False
585
+
586
+ # Newly added speech token IDs (only these embedding rows stay trainable)
587
+ trainable_token_ids = [256001, 256002]
588
+
589
+ embed_tokens.weight.requires_grad = True
590
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
591
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
592
+
593
+ # backward hook, with gradient masking
594
+ def embedding_grad_mask_hook(grad):
595
+ return grad.masked_fill(mask, 0)
596
+
597
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
598
+
599
+ model.language_model.model.model.embed_tokens = embed_tokens
600
+
601
+ count_parameters_by_module(model)
602
+
603
+ return model
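The gradient hook in create_model leaves only two rows of the embedding matrix trainable; the standalone sketch below reproduces that trick on a toy nn.Embedding so the effect is easy to inspect (sizes and row indices are made up).

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    trainable_rows = [7, 8]  # stand-ins for the speech token ids
    mask = torch.ones_like(emb.weight, dtype=torch.bool)
    mask[trainable_rows] = False  # False = keep the gradient, True = zero it out
    emb.weight.register_hook(lambda grad: grad.masked_fill(mask, 0))

    loss = emb(torch.tensor([1, 7, 8])).sum()
    loss.backward()
    print(emb.weight.grad[1].abs().sum().item())  # 0.0 -> frozen row
    print(emb.weight.grad[7].abs().sum().item())  # 4.0 -> trainable row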
604
+
605
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
606
+
607
+ INSTRUCTION = {
608
+ "ast": [
609
+ "Translate the audio to {0}.",
610
+ "Translate the audio clip into {0}.",
611
+ "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
612
+ "Translate the provided audio file into {0}.",
613
+ "Convert the audio speech to {0} text.",
614
+ "Write an {0} translation of the audio file.",
615
+ "Translate spoken words from the audio into {0}.",
616
+ "Create an {0} version of the audio content.",
617
+ "Produce an accurate {0} translation of the audio.",
618
+ "Extract speech from the audio and translate it to {0}.",
619
+ "Turn the audio into readable {0} text.",
620
+ "Write all spoken content from the audio in {0}.",
621
+ "Generate an {0} translation of the speech in the file.",
622
+ "Convert the recording into {0} text.",
623
+ "Accurately translate the audio recording to {0}.",
624
+ "Write down dialogue from the given audio in {0}.",
625
+ "Translate all speech in this audio file to {0}.",
626
+ "Create an accurate {0} version of the speech.",
627
+ "Perform a complete {0} translation of the audio."
628
+ ],
629
+ "asr": [
630
+ "Transcribe the audio clip into text.",
631
+ "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
632
+ "Transcribe the provided audio file into text.",
633
+ "Convert the audio speech to text.",
634
+ "Write a transcript of the audio file.",
635
+ "Transcribe spoken words from the audio.",
636
+ "Create a text version of the audio content.",
637
+ "Produce a verbatim transcript of the audio.",
638
+ "Extract and transcribe speech from the audio.",
639
+ "Turn the audio into readable text.",
640
+ "Write all spoken words from the audio.",
641
+ "Generate a transcript of the speech in the file.",
642
+ "Convert the recording into a text transcript.",
643
+ "Accurately transcribe the audio recording.",
644
+ "Write down dialogue from the given audio.",
645
+ "Transcribe all speech in this audio file.",
646
+ "Create an accurate text version of the speech.",
647
+ "Perform a complete transcription of the audio."
648
+ ],
649
+ }
650
+
651
+ ANSWER_SUFFIX = "<end_of_turn>"
652
+ _IGNORE_INDEX = -100
653
+
654
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
655
+ use_flash_attention = True
656
+
657
+ output_dir = '../gemma_tmp7'
658
+ batch_size = 128
659
+ batch_size_per_gpu = 16
660
+ learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
661
+ wd = 0.01
662
+ num_train_epochs = 15
663
+
664
+ revision = "main" #"v1.0"
665
+
666
+ processor = AutoProcessor.from_pretrained(
667
+ model_name_or_path,
668
+ revision=revision,
669
+ trust_remote_code=True,
670
+ )
671
+
672
+ model = create_model(
673
+ model_name_or_path,
674
+ revision=revision,
675
+ use_flash_attention=use_flash_attention,
676
+ )
677
+
678
+ train_datasets = []
679
+
680
+ # common voice asr
681
+ commonvoice_speech_tw2 = CommonVoiceDataset(
682
+ processor=processor,
683
+ source_lang="zh-TW",
684
+ split="other[:70%]"
685
+ )
686
+ train_datasets.append(commonvoice_speech_tw2)
687
+
688
+ commonvoice_speech_cn = CommonVoiceDataset(
689
+ processor=processor,
690
+ source_lang="zh-CN",
691
+ split="train[:50%]"
692
+ )
693
+ train_datasets.append(commonvoice_speech_cn)
694
+
695
+
696
+ commonvoice_speech_tw = CommonVoiceDataset(
697
+ processor=processor,
698
+ source_lang="zh-TW",
699
+ split="train"
700
+ )
701
+ train_datasets.append(commonvoice_speech_tw)
702
+
703
+
704
+
705
+
706
+ # Libri Speech Clean ASR mode (English -> English text)
707
+ libri_speech_clean = LibriSpeechDataset(
708
+ processor=processor,
709
+ subset="clean",
710
+ split="train.360[:50%]"
711
+ )
712
+ train_datasets.append(libri_speech_clean)
713
+
714
+
715
+ # Fleurs ASR mode (English -> English text)
716
+ en_asr_fleurs = FleursDataset(
717
+ processor=processor,
718
+ split="train",
719
+ source_lang="en_us", # English
720
+ mode="asr"
721
+ )
722
+ train_datasets.append(en_asr_fleurs)
723
+
724
+
725
+ # en_ch_ast_fleurs = FleursDataset(
726
+ # processor=processor,
727
+ # split="train",
728
+ # source_lang="en_us",
729
+ # target_lang="cmn_hans_cn",
730
+ # mode="ast"
731
+ # )
732
+ # train_datasets.append(en_ch_ast_fleurs)
733
+
734
+
735
+
736
+ ch_asr_fleurs = FleursDataset(
737
+ processor=processor,
738
+ split="train",
739
+ source_lang="cmn_hans_cn",
740
+ mode="asr"
741
+ )
742
+ train_datasets.append(ch_asr_fleurs)
743
+
744
+
745
+ # ch_en_ast_fleurs = FleursDataset(
746
+ # processor=processor,
747
+ # split="train",
748
+ # source_lang="cmn_hans_cn",
749
+ # target_lang="en_us",
750
+ # mode="ast"
751
+ # )
752
+ # train_datasets.append(ch_en_ast_fleurs)
753
+
754
+ print("Count Num of Datasets", len(train_datasets))
755
+ print([len(dataset) for dataset in train_datasets])
756
+
757
+ # ConcatDataset
758
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
759
+ print("Count Length of Datas", len(train_dataset))
760
+
761
+
762
+
763
+ # Check GPUs
764
+ num_gpus = torch.cuda.device_count()
765
+ print(f'training on {num_gpus} GPUs')
766
+
767
+ assert (
768
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
769
+ ), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
770
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
771
+
772
+ # hard coded training args (note: dp_config lists "zero_optimization" twice, so the later stage-0 entry overrides the stage-2 settings above it)
773
+ dp_config = {
774
+ "fp16": {
775
+ "enabled": "auto",
776
+ "loss_scale": 0,
777
+ "loss_scale_window": 1000,
778
+ "initial_scale_power": 16,
779
+ "hysteresis": 2,
780
+ "min_loss_scale": 1
781
+ },
782
+ "zero_optimization": {
783
+ "stage": 2,
784
+ "allgather_partitions": True,
785
+ "allgather_bucket_size": 5e8,
786
+ "overlap_comm": False,
787
+ "reduce_scatter": True,
788
+ "reduce_bucket_size": 5e8,
789
+ "contiguous_gradients": True,
790
+ "cpu_offload": True
791
+ },
792
+
793
+ "train_batch_size": "auto",
794
+ "gradient_accumulation_steps": "auto",
795
+ "optimizer": {
796
+ "type": "AdamW",
797
+ "params": {
798
+ "lr": "auto",
799
+ "betas": 'auto',
800
+ "eps": 'auto',
801
+ "weight_decay": "auto"
802
+ }
803
+ },
804
+ "scheduler": {
805
+ "type": "WarmupDecayLR",
806
+ "params": {
807
+ "warmup_min_lr": "auto",
808
+ "warmup_max_lr": "auto",
809
+ "warmup_num_steps": "auto",
810
+ "total_num_steps": "auto"
811
+ }
812
+ },
813
+ "gradient_clipping": 1.0,
814
+ "zero_optimization": {
815
+ "stage": 0
816
+ }
817
+ }
818
+ training_args = TrainingArguments(
819
+ num_train_epochs=num_train_epochs,
820
+ per_device_train_batch_size=batch_size_per_gpu,
821
+ gradient_checkpointing=True,
822
+ gradient_checkpointing_kwargs={'use_reentrant': False},
823
+ gradient_accumulation_steps=gradient_accumulation_steps,
824
+ optim='adamw_torch',
825
+ adam_beta1=0.9,
826
+ adam_beta2=0.95,
827
+ adam_epsilon=1e-7,
828
+ learning_rate=learning_rate,
829
+ weight_decay=wd,
830
+ max_grad_norm=1.0,
831
+ lr_scheduler_type='cosine',
832
+ warmup_steps=50,
833
+ logging_steps=10,
834
+ output_dir=output_dir,
835
+ save_total_limit=10,
836
+ save_only_model=True,
837
+ bf16=True,
838
+ fp16=False,
839
+ remove_unused_columns=False,
840
+ report_to='none',
841
+ deepspeed=dp_config if num_gpus==1 else None,
842
+ disable_tqdm=False,
843
+ dataloader_num_workers=4,
844
+ save_strategy='steps',
845
+ save_steps=1000,
846
+ ddp_find_unused_parameters=True,
847
+
848
+ )
849
+
850
+ out_path = Path(training_args.output_dir)
851
+ out_path.mkdir(parents=True, exist_ok=True)
852
+
853
+ # create optimizer only for trainable params
854
+ optimizer = torch.optim.AdamW(
855
+ filter(lambda p: p.requires_grad, model.parameters()),
856
+ lr=learning_rate,
857
+ weight_decay=wd,
858
+ betas=(0.9, 0.95),
859
+ eps=1e-7,
860
+ )
861
+
862
+ # Trainer Setting
863
+ trainer = Trainer(
864
+ model=model,
865
+ args=training_args,
866
+ data_collator=covost_collate_fn,
867
+ train_dataset=train_dataset,
868
+ optimizers=(optimizer, None)
869
+ )
870
+
871
+ trainer.train()
872
+
873
+
874
+ # # 1. Save LoRA Adapter
875
+ model.language_model.model.save_pretrained(output_dir)
876
+
877
+ # # 1-1. Delete Markdown file
878
+ # markdown_file = os.path.join(output_dir, "README.md")
879
+ # if os.path.exists(markdown_file):
880
+ # os.remove(markdown_file)
881
+
882
+ # 2. Save entire model
883
+ model.save_pretrained(output_dir)
cpp/gemma_v1/training_multiturn.py ADDED
@@ -0,0 +1,329 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import sacrebleu
14
+
15
+ from datasets import load_dataset
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ AutoProcessor,
20
+ AutoModel,
21
+ BatchFeature,
22
+ Trainer,
23
+ TrainingArguments,
24
+ StoppingCriteria,
25
+ StoppingCriteriaList,
26
+ )
27
+ from collections import defaultdict
28
+
29
+ import soundfile as sf
30
+ from datasets import Audio
31
+ import random
32
+ from ASRDataset import *
33
+
34
+
35
+ def count_parameters_by_module(model):
36
+ # dictionary for parameters number by modules
37
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
38
+
39
+ # all params
40
+ total_params = 0
41
+ total_trainable_params = 0
42
+
43
+ # Check Embedding Token masks
44
+ embedding_masks = {}
45
+ for name, param in model.named_parameters():
46
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
47
+ # check if params has embedding_grad_mask_hook
48
+ for hook_id, hook_fn in param._backward_hooks.items():
49
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
50
+ # Accessing mask variables in the closure of hook functions
51
+ for cell in hook_fn.__closure__ or []:
52
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
53
+ # check mask tensor
54
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
55
+
56
+ # Count params by modules
57
+ for name, param in model.named_parameters():
58
+ # extracts top module_name
59
+ module_name = name.split('.')[0]
60
+ param_count = param.numel()
61
+
62
+ module_params[module_name]["total"] += param_count
63
+ total_params += param_count
64
+
65
+ if param.requires_grad:
66
+ # Only count for real trainable params. (with masks)
67
+ if name in embedding_masks:
68
+ trainable_count = embedding_masks[name].sum().item()
69
+ module_params[module_name]["trainable"] += trainable_count
70
+ total_trainable_params += trainable_count
71
+ else:
72
+ module_params[module_name]["trainable"] += param_count
73
+ total_trainable_params += param_count
74
+
75
+ print(f"All Params: {total_params:,}")
76
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
77
+ print("\nParams by Module:")
78
+
79
+ for module_name, counts in sorted(module_params.items()):
80
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
81
+ total_percentage = counts["total"] / total_params * 100
82
+
83
+ print(f"- {module_name}:")
84
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
85
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
86
+
87
+ return module_params
88
+
89
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
90
+ model = AutoModel.from_pretrained(
91
+ model_name_or_path,
92
+ revision=revision,
93
+ torch_dtype=torch.bfloat16,
94
+ device_map="auto",
95
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
96
+ trust_remote_code=True,
97
+ )
98
+
99
+ # Set use_cache to False after model loaded
100
+ model.config.use_cache = False
101
+
102
+ # Freeze all parameters
103
+ for param in model.parameters():
104
+ param.requires_grad = False
105
+
106
+ model.set_lora_adapter('speech')
107
+ model.to(torch.bfloat16)
108
+
109
+ # (Optional) unfreeze audio_tower parameters
110
+ # for param in model.audio_tower.parameters():
111
+ # param.requires_grad = True
112
+
113
+ # Only unfreeze audio_projector parameters
114
+ # for param in model.audio_projector.parameters():
115
+ # param.requires_grad = True
116
+
117
+ # (Optional) unfreeze audio embed_tokens
118
+ train_embed = True
119
+ if train_embed:
120
+ embed_tokens = model.language_model.model.model.embed_tokens
121
+
122
+ embed_tokens.weight.requires_grad = False
123
+
124
+ # Added Speech token IDs (only this tokens be trainable)
125
+ trainable_token_ids = [256001, 256002]
126
+
127
+ embed_tokens.weight.requires_grad = True
128
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
129
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
130
+
131
+ # backward hook, with gradient masking
132
+ def embedding_grad_mask_hook(grad):
133
+ return grad.masked_fill(mask, 0)
134
+
135
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
136
+
137
+ model.language_model.model.model.embed_tokens = embed_tokens
138
+
139
+ count_parameters_by_module(model)
140
+
141
+ return model
142
+
143
+ ANSWER_SUFFIX = "<end_of_turn>"
144
+ _IGNORE_INDEX = -100
145
+
146
+ ANSWER_SUFFIX = "<end_of_turn>"
147
+ _IGNORE_INDEX = -100
148
+
149
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
150
+ use_flash_attention = False
151
+
152
+ output_dir = '../gemma_tmp13'
153
+ batch_size = 24
154
+ batch_size_per_gpu = 8
155
+ learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
156
+ wd = 0.01
157
+ num_train_epochs = 10
158
+
159
+ revision = "main" #"v1.0"
160
+
161
+ processor = AutoProcessor.from_pretrained(
162
+ model_name_or_path,
163
+ revision=revision,
164
+ trust_remote_code=True,
165
+ )
166
+
167
+ model = create_model(
168
+ model_name_or_path,
169
+ revision=revision,
170
+ use_flash_attention=use_flash_attention,
171
+ )
172
+
173
+ train_datasets = []
174
+
175
+ pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
176
+ train_datasets.append(pickup_dataset)
177
+
178
+ # custom_tw_loc = TWCostumData(processor=processor,
179
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
180
+ # train_datasets.append(custom_tw_loc) # 1500
181
+
182
+ # custom_tw_loc2 = TWCostumData(processor=processor,
183
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
184
+ # train_datasets.append(custom_tw_loc2) # 9458
185
+
186
+ # custom_yating_tw_road = TWCostumData(processor=processor,
187
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
188
+ # train_datasets.append(custom_yating_tw_road) # 35224
189
+
190
+ # custom_tw_road = TWCostumData(processor=processor,
191
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
192
+ # train_datasets.append(custom_tw_road) # 1500
193
+
194
+ # custom_tw_road2 = TWCostumData(processor=processor,
195
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
196
+ # train_datasets.append(custom_tw_road2) # 35224
197
+
198
+
199
+
200
+ print("Count Num of Datasets", len(train_datasets))
201
+ print([len(dataset) for dataset in train_datasets])
202
+
203
+ # ConcatDataset
204
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
205
+ print("Count Length of Datas", len(train_dataset))
206
+
207
+
208
+
209
+ # Check GPUs
210
+ num_gpus = torch.cuda.device_count()
211
+ print(f'training on {num_gpus} GPUs')
212
+
213
+ assert (
214
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
215
+ ), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
216
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
217
+
218
+ # hard coded training args (note: dp_config lists "zero_optimization" twice, so the later stage-0 entry overrides the stage-2 settings above it)
219
+ dp_config = {
220
+ "fp16": {
221
+ "enabled": "auto",
222
+ "loss_scale": 0,
223
+ "loss_scale_window": 1000,
224
+ "initial_scale_power": 16,
225
+ "hysteresis": 2,
226
+ "min_loss_scale": 1
227
+ },
228
+ "zero_optimization": {
229
+ "stage": 2,
230
+ "allgather_partitions": True,
231
+ "allgather_bucket_size": 5e8,
232
+ "overlap_comm": False,
233
+ "reduce_scatter": True,
234
+ "reduce_bucket_size": 5e8,
235
+ "contiguous_gradients": True,
236
+ "cpu_offload": True
237
+ },
238
+
239
+ "train_batch_size": "auto",
240
+ "gradient_accumulation_steps": "auto",
241
+ "optimizer": {
242
+ "type": "AdamW",
243
+ "params": {
244
+ "lr": "auto",
245
+ "betas": 'auto',
246
+ "eps": 'auto',
247
+ "weight_decay": "auto"
248
+ }
249
+ },
250
+ "scheduler": {
251
+ "type": "WarmupDecayLR",
252
+ "params": {
253
+ "warmup_min_lr": "auto",
254
+ "warmup_max_lr": "auto",
255
+ "warmup_num_steps": "auto",
256
+ "total_num_steps": "auto"
257
+ }
258
+ },
259
+ "gradient_clipping": 1.0,
260
+ "zero_optimization": {
261
+ "stage": 0
262
+ }
263
+ }
264
+ training_args = TrainingArguments(
265
+ num_train_epochs=num_train_epochs,
266
+ per_device_train_batch_size=batch_size_per_gpu,
267
+ gradient_checkpointing=True,
268
+ gradient_checkpointing_kwargs={'use_reentrant': False},
269
+ gradient_accumulation_steps=gradient_accumulation_steps,
270
+ optim='adamw_torch',
271
+ adam_beta1=0.9,
272
+ adam_beta2=0.95,
273
+ adam_epsilon=1e-7,
274
+ learning_rate=learning_rate,
275
+ weight_decay=wd,
276
+ max_grad_norm=1.0,
277
+ lr_scheduler_type='cosine',
278
+ warmup_steps=50,
279
+ logging_steps=10,
280
+ output_dir=output_dir,
281
+ save_total_limit=10,
282
+ save_only_model=True,
283
+ bf16=True,
284
+ fp16=False,
285
+ remove_unused_columns=False,
286
+ report_to='none',
287
+ deepspeed=None,
288
+ disable_tqdm=False,
289
+ dataloader_num_workers=16,
290
+ save_strategy='epoch',
291
+ # save_steps=2500,
292
+ ddp_find_unused_parameters=True,
293
+
294
+ )
295
+
296
+ out_path = Path(training_args.output_dir)
297
+ out_path.mkdir(parents=True, exist_ok=True)
298
+
299
+ # create optimizer only for trainable params
300
+ optimizer = torch.optim.AdamW(
301
+ filter(lambda p: p.requires_grad, model.parameters()),
302
+ lr=learning_rate,
303
+ weight_decay=wd,
304
+ betas=(0.9, 0.95),
305
+ eps=1e-7,
306
+ )
307
+
308
+ # Trainer Setting
309
+ trainer = Trainer(
310
+ model=model,
311
+ args=training_args,
312
+ data_collator=covost_collate_fn,
313
+ train_dataset=train_dataset,
314
+ optimizers=(optimizer, None)
315
+ )
316
+
317
+ trainer.train()
318
+
319
+
320
+ # # 1. Save LoRA Adapter
321
+ model.language_model.model.save_pretrained(output_dir)
322
+
323
+ # # 1-1. Delete Markdown file
324
+ # markdown_file = os.path.join(output_dir, "README.md")
325
+ # if os.path.exists(markdown_file):
326
+ # os.remove(markdown_file)
327
+
328
+ # 2. Save entire model
329
+ model.save_pretrained(output_dir)
cpp/gemma_v1/training_multiturn_textonly.py ADDED
@@ -0,0 +1,333 @@
1
+ import datasets
2
+ datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
3
+ import os
4
+ os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import sacrebleu
14
+
15
+ from datasets import load_dataset
16
+ from torch.utils.data import Dataset, ConcatDataset
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ AutoProcessor,
20
+ AutoModel,
21
+ BatchFeature,
22
+ Trainer,
23
+ TrainingArguments,
24
+ StoppingCriteria,
25
+ StoppingCriteriaList,
26
+ )
27
+ from collections import defaultdict
28
+
29
+ import soundfile as sf
30
+ from datasets import Audio
31
+ import random
32
+ from ASRDataset import *
33
+
34
+
35
+ def count_parameters_by_module(model):
36
+ # dictionary for parameters number by modules
37
+ module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
38
+
39
+ # all params
40
+ total_params = 0
41
+ total_trainable_params = 0
42
+
43
+ # Check Embedding Token masks
44
+ embedding_masks = {}
45
+ for name, param in model.named_parameters():
46
+ if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
47
+ # check if params has embedding_grad_mask_hook
48
+ for hook_id, hook_fn in param._backward_hooks.items():
49
+ if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
50
+ # Accessing mask variables in the closure of hook functions
51
+ for cell in hook_fn.__closure__ or []:
52
+ if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
53
+ # check mask tensor
54
+ embedding_masks[name] = ~cell.cell_contents # True : Trainable
55
+
56
+ # Count params by modules
57
+ for name, param in model.named_parameters():
58
+ # extracts top module_name
59
+ module_name = name.split('.')[0]
60
+ param_count = param.numel()
61
+
62
+ module_params[module_name]["total"] += param_count
63
+ total_params += param_count
64
+
65
+ if param.requires_grad:
66
+ # Only count for real trainable params. (with masks)
67
+ if name in embedding_masks:
68
+ trainable_count = embedding_masks[name].sum().item()
69
+ module_params[module_name]["trainable"] += trainable_count
70
+ total_trainable_params += trainable_count
71
+ else:
72
+ module_params[module_name]["trainable"] += param_count
73
+ total_trainable_params += param_count
74
+
75
+ print(f"All Params: {total_params:,}")
76
+ print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
77
+ print("\nParams by Module:")
78
+
79
+ for module_name, counts in sorted(module_params.items()):
80
+ trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
81
+ total_percentage = counts["total"] / total_params * 100
82
+
83
+ print(f"- {module_name}:")
84
+ print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
85
+ print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
86
+
87
+ return module_params
88
+
89
+ def create_model(model_name_or_path, revision="main", use_flash_attention = False):
90
+ model = AutoModel.from_pretrained(
91
+ model_name_or_path,
92
+ revision=revision,
93
+ torch_dtype=torch.bfloat16,
94
+ device_map="auto",
95
+ attn_implementation="flash_attention_2" if use_flash_attention else "eager",
96
+ trust_remote_code=True,
97
+ )
98
+
99
+ # Set use_cache to False after model loaded
100
+ model.config.use_cache = False
101
+
102
+ # Freeze all parameters
103
+ for param in model.parameters():
104
+ param.requires_grad = False
105
+
106
+ model.set_lora_adapter('speech')
107
+ # model.set_lora_adapter('text')
108
+ model.to(torch.bfloat16)
109
+
110
+ # (Optional) unfreeze audio_tower parameters
111
+ # for param in model.audio_tower.parameters():
112
+ # param.requires_grad = True
113
+
114
+ # Only unfreeze audio_projector parameters
115
+ # for param in model.audio_projector.parameters():
116
+ # param.requires_grad = True
117
+
118
+ # (Optional) unfreeze audio embed_tokens
119
+ train_embed = True
120
+ if train_embed:
121
+ embed_tokens = model.language_model.model.model.embed_tokens
122
+
123
+ embed_tokens.weight.requires_grad = False
124
+
125
+ # Added Speech token IDs (only this tokens be trainable)
126
+ trainable_token_ids = [256001, 256002]
127
+
128
+ embed_tokens.weight.requires_grad = True
129
+ mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
130
+ mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
131
+
132
+ # backward hook, with gradient masking
133
+ def embedding_grad_mask_hook(grad):
134
+ return grad.masked_fill(mask, 0)
135
+
136
+ embed_tokens.weight.register_hook(embedding_grad_mask_hook)
137
+
138
+ model.language_model.model.model.embed_tokens = embed_tokens
139
+
140
+ count_parameters_by_module(model)
141
+
142
+ return model
143
+
144
+ ANSWER_SUFFIX = "<end_of_turn>"
145
+ _IGNORE_INDEX = -100
146
+
147
+ ANSWER_SUFFIX = "<end_of_turn>"
148
+ _IGNORE_INDEX = -100
149
+
150
+ model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
151
+ use_flash_attention = False
152
+
153
+ output_dir = '../gemma_tmp14_audio_and_text_speechlora'
154
+ batch_size = 16
155
+ batch_size_per_gpu = 1
156
+ learning_rate = 5.0e-5 # 1.0e-4 for fine-tuning
157
+ wd = 0.01
158
+ num_train_epochs = 10
159
+
160
+ revision = "main" #"v1.0"
161
+
162
+ processor = AutoProcessor.from_pretrained(
163
+ model_name_or_path,
164
+ revision=revision,
165
+ trust_remote_code=True,
166
+ )
167
+
168
+ model = create_model(
169
+ model_name_or_path,
170
+ revision=revision,
171
+ use_flash_attention=use_flash_attention,
172
+ )
173
+
174
+ train_datasets = []
175
+
176
+ pickup_dataset = MultiturnAudioDataset(processor=processor,text_only=True,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
177
+ train_datasets.append(pickup_dataset)
178
+
179
+ pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
180
+ train_datasets.append(pickup_dataset)
181
+
182
+ # custom_tw_loc = TWCostumData(processor=processor,
183
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
184
+ # train_datasets.append(custom_tw_loc) # 1500
185
+
186
+ # custom_tw_loc2 = TWCostumData(processor=processor,
187
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
188
+ # train_datasets.append(custom_tw_loc2) # 9458
189
+
190
+ # custom_yating_tw_road = TWCostumData(processor=processor,
191
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
192
+ # train_datasets.append(custom_yating_tw_road) # 35224
193
+
194
+ # custom_tw_road = TWCostumData(processor=processor,
195
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
196
+ # train_datasets.append(custom_tw_road) # 1500
197
+
198
+ # custom_tw_road2 = TWCostumData(processor=processor,
199
+ # csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
200
+ # train_datasets.append(custom_tw_road2) # 35224
201
+
202
+
203
+
204
+ print("Count Num of Datasets", len(train_datasets))
205
+ print([len(dataset) for dataset in train_datasets])
206
+
207
+ # ConcatDataset
208
+ train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
209
+ print("Count Length of Datas", len(train_dataset))
210
+
211
+
212
+
213
+ # Check GPUs
214
+ num_gpus = torch.cuda.device_count()
215
+ print(f'training on {num_gpus} GPUs')
216
+
217
+ assert (
218
+ batch_size % (num_gpus * batch_size_per_gpu) == 0
219
+ ), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
220
+ gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
221
+
222
+ # hard coded training args (note: dp_config lists "zero_optimization" twice, so the later stage-0 entry overrides the stage-2 settings above it)
223
+ dp_config = {
224
+ "fp16": {
225
+ "enabled": "auto",
226
+ "loss_scale": 0,
227
+ "loss_scale_window": 1000,
228
+ "initial_scale_power": 16,
229
+ "hysteresis": 2,
230
+ "min_loss_scale": 1
231
+ },
232
+ "zero_optimization": {
233
+ "stage": 2,
234
+ "allgather_partitions": True,
235
+ "allgather_bucket_size": 5e8,
236
+ "overlap_comm": False,
237
+ "reduce_scatter": True,
238
+ "reduce_bucket_size": 5e8,
239
+ "contiguous_gradients": True,
240
+ "cpu_offload": True
241
+ },
242
+
243
+ "train_batch_size": "auto",
244
+ "gradient_accumulation_steps": "auto",
245
+ "optimizer": {
246
+ "type": "AdamW",
247
+ "params": {
248
+ "lr": "auto",
249
+ "betas": 'auto',
250
+ "eps": 'auto',
251
+ "weight_decay": "auto"
252
+ }
253
+ },
254
+ "scheduler": {
255
+ "type": "WarmupDecayLR",
256
+ "params": {
257
+ "warmup_min_lr": "auto",
258
+ "warmup_max_lr": "auto",
259
+ "warmup_num_steps": "auto",
260
+ "total_num_steps": "auto"
261
+ }
262
+ },
263
+ "gradient_clipping": 1.0,
264
+ "zero_optimization": {
265
+ "stage": 0
266
+ }
267
+ }
268
+ training_args = TrainingArguments(
269
+ num_train_epochs=num_train_epochs,
270
+ per_device_train_batch_size=batch_size_per_gpu,
271
+ gradient_checkpointing=True,
272
+ gradient_checkpointing_kwargs={'use_reentrant': False},
273
+ gradient_accumulation_steps=gradient_accumulation_steps,
274
+ optim='adamw_torch',
275
+ adam_beta1=0.9,
276
+ adam_beta2=0.95,
277
+ adam_epsilon=1e-7,
278
+ learning_rate=learning_rate,
279
+ weight_decay=wd,
280
+ max_grad_norm=1.0,
281
+ lr_scheduler_type='cosine',
282
+ warmup_steps=50,
283
+ logging_steps=10,
284
+ output_dir=output_dir,
285
+ save_total_limit=10,
286
+ save_only_model=True,
287
+ bf16=True,
288
+ fp16=False,
289
+ remove_unused_columns=False,
290
+ report_to='none',
291
+ deepspeed=None,
292
+ disable_tqdm=False,
293
+ dataloader_num_workers=16,
294
+ save_strategy='epoch',
295
+ # save_steps=2500,
296
+ ddp_find_unused_parameters=True,
297
+
298
+ )
299
+
300
+ out_path = Path(training_args.output_dir)
301
+ out_path.mkdir(parents=True, exist_ok=True)
302
+
303
+ # create optimizer only for trainable params
304
+ optimizer = torch.optim.AdamW(
305
+ filter(lambda p: p.requires_grad, model.parameters()),
306
+ lr=learning_rate,
307
+ weight_decay=wd,
308
+ betas=(0.9, 0.95),
309
+ eps=1e-7,
310
+ )
311
+
312
+ # Trainer Setting
313
+ trainer = Trainer(
314
+ model=model,
315
+ args=training_args,
316
+ data_collator=covost_collate_fn,
317
+ train_dataset=train_dataset,
318
+ optimizers=(optimizer, None)
319
+ )
320
+
321
+ trainer.train()
322
+
323
+
324
+ # # 1. Save LoRA Adapter
325
+ model.language_model.model.save_pretrained(output_dir)
326
+
327
+ # # 1-1. Delete Markdown file
328
+ # markdown_file = os.path.join(output_dir, "README.md")
329
+ # if os.path.exists(markdown_file):
330
+ # os.remove(markdown_file)
331
+
332
+ # 2. Save entire model
333
+ model.save_pretrained(output_dir)
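+
+ # Reload sketch (illustrative only; assumes the checkpoint saved in output_dir ships the custom
+ # model/processor code so that trust_remote_code can resolve it):
+ # from transformers import AutoModel, AutoProcessor
+ # model = AutoModel.from_pretrained(output_dir, trust_remote_code=True)
+ # processor = AutoProcessor.from_pretrained(output_dir, trust_remote_code=True)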
cpp/inference/audio_encoder_lib.cpp ADDED
@@ -0,0 +1,388 @@
1
+ #include "audio_encoder_lib.h"
2
+
3
+ #include <iostream>
4
+ #include <fstream>
5
+ #include <cmath>
6
+ #include <numeric>
7
+ #include <algorithm>
8
+ #include <cstring> // For memcpy
9
+
10
+ // Include specific ONNX Runtime headers for implementation
11
+ #include <onnxruntime_cxx_api.h>
12
+
13
+ // Include specific Eigen headers for implementation
14
+ #include <Eigen/Dense>
15
+
16
+ // Include specific KissFFT headers for implementation
17
+ #include <kiss_fft.h>
18
+ #include <kiss_fftr.h>
19
+
20
+ // Define M_PI if it's not already defined
21
+ #ifndef M_PI
22
+ #define M_PI 3.14159265358979323846
23
+ #endif
24
+
25
+ // --- Global parameters for feature extraction (matching Python script) ---
26
+ // These are constants derived from the Python preprocessing script and are
27
+ // internal to the feature extraction logic.
28
+ namespace { // Anonymous namespace for internal linkage
29
+ const float PREEMPHASIS_COEFF = 0.97f;
30
+ const int N_FFT = 512; // FFT size
31
+ const int WIN_LENGTH = 400; // Window length (samples)
32
+ const int HOP_LENGTH = 160; // Hop length (samples)
33
+ const int N_MELS = 80; // Number of Mel filterbank channels
34
+ const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction
35
+ }
36
+
37
+ // --- Implementation of AudioInferenceEngine methods ---
38
+
39
+ AudioInferenceEngine::AudioInferenceEngine(const std::string& modelPath) {
40
+ // 1. Initialize ONNX Runtime Environment
41
+ env_ = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "AudioInferenceEngine");
42
+
43
+ // 2. Configure Session Options
44
+ Ort::SessionOptions session_options;
45
+ session_options.SetIntraOpNumThreads(0);
46
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
47
+
48
+ // 3. Create ONNX Runtime Session
49
+ session_ = std::make_unique<Ort::Session>(*env_, modelPath.c_str(), session_options);
50
+
51
+ // 4. Initialize Allocator
52
+ allocator_ = std::make_unique<Ort::AllocatorWithDefaultOptions>();
53
+
54
+ // 5. Get Input and Output Node Names
55
+ // It's crucial to allocate these names using the allocator and store them
56
+ // as C-style strings for Ort::Session::Run.
57
+ size_t numInputNodes = session_->GetInputCount();
58
+ if (numInputNodes == 0) {
59
+ throw Ort::Exception("ONNX model has no input nodes.", ORT_FAIL);
60
+ }
61
+ input_node_names_.resize(numInputNodes);
62
+ for (size_t i = 0; i < numInputNodes; ++i) {
63
+ input_node_names_[i] = session_->GetInputNameAllocated(i, *allocator_).release(); // release() to manage lifetime
64
+ }
65
+
66
+ size_t numOutputNodes = session_->GetOutputCount();
67
+ if (numOutputNodes == 0) {
68
+ throw Ort::Exception("ONNX model has no output nodes.", ORT_FAIL);
69
+ }
70
+ output_node_names_.resize(numOutputNodes);
71
+ for (size_t i = 0; i < numOutputNodes; ++i) {
72
+ output_node_names_[i] = session_->GetOutputNameAllocated(i, *allocator_).release(); // release() to manage lifetime
73
+ }
74
+
75
+ // 6. Precompute Mel filterbank
76
+ // The Python example uses fmax=16000//2-80-230.
77
+ float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
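+ // i.e. mel_fmax = 16000/2 - 80 - 230 = 7690 Hz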
78
+ mel_filterbank_ = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);
79
+
80
+ if (mel_filterbank_.rows() == 0 || mel_filterbank_.cols() == 0) {
81
+ throw std::runtime_error("Failed to create Mel filterbank during initialization.");
82
+ }
83
+
84
+ std::cout << "AudioInferenceEngine initialized successfully with model: " << modelPath << std::endl;
85
+ }
86
+
87
+ AudioInferenceEngine::~AudioInferenceEngine() {
88
+ // Release allocated names
89
+ for (const char* name : input_node_names_) {
90
+ allocator_->Free(const_cast<void*>(reinterpret_cast<const void*>(name)));
91
+ }
92
+ for (const char* name : output_node_names_) {
93
+ allocator_->Free(const_cast<void*>(reinterpret_cast<const void*>(name)));
94
+ }
95
+ // unique_ptr automatically handles deletion of env_ and session_
96
+ }
97
+
98
+ /**
99
+ * @brief Private helper: Loads audio data from a WAV file.
100
+ */
101
+ std::vector<float> AudioInferenceEngine::loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) {
102
+ std::ifstream file(filename, std::ios::binary);
103
+ if (!file.is_open()) {
104
+ std::cerr << "Error: Could not open WAV file: " << filename << std::endl;
105
+ return {};
106
+ }
107
+
108
+ WavHeader header;
109
+ file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader));
110
+
111
+ if (std::string(header.riff_id, 4) != "RIFF" ||
112
+ std::string(header.wave_id, 4) != "WAVE" ||
113
+ std::string(header.fmt_id, 4) != "fmt ") {
114
+ std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl;
115
+ file.close();
116
+ return {};
117
+ }
118
+
119
+ if (header.audio_format != 1) { // 1 = PCM
120
+ std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl;
121
+ file.close();
122
+ return {};
123
+ }
124
+
125
+ if (header.bits_per_sample != 16) {
126
+ std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl;
127
+ file.close();
128
+ return {};
129
+ }
130
+
131
+ actual_sample_rate = header.sample_rate;
132
+
133
+ WavDataChunk data_chunk;
134
+ bool data_chunk_found = false;
135
+ while (!file.eof()) {
136
+ file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4);
137
+ file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4);
138
+
139
+ if (std::string(data_chunk.data_id, 4) == "data") {
140
+ data_chunk_found = true;
141
+ break;
142
+ } else {
143
+ file.seekg(data_chunk.data_size, std::ios::cur);
144
+ }
145
+ }
146
+
147
+ if (!data_chunk_found) {
148
+ std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl;
149
+ file.close();
150
+ return {};
151
+ }
152
+
153
+ std::vector<float> audioData;
154
+ int16_t sample_buffer;
155
+ long num_samples_to_read = data_chunk.data_size / sizeof(int16_t);
156
+
157
+ for (long i = 0; i < num_samples_to_read; ++i) {
158
+ file.read(reinterpret_cast<char*>(&sample_buffer), sizeof(int16_t));
159
+ float normalized_sample = static_cast<float>(sample_buffer) / 32768.0f;
160
+
161
+ if (header.num_channels == 1) {
162
+ audioData.push_back(normalized_sample);
163
+ } else if (header.num_channels == 2) {
164
+ int16_t right_sample;
165
+ if (file.read(reinterpret_cast<char*>(&right_sample), sizeof(int16_t))) {
166
+ float normalized_right_sample = static_cast<float>(right_sample) / 32768.0f;
167
+ audioData.push_back((normalized_sample + normalized_right_sample) / 2.0f);
168
+ i++;
169
+ } else {
170
+ std::cerr << "Warning: Unexpected end of file while reading stereo data." << std::endl;
171
+ break;
172
+ }
173
+ } else {
174
+ std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl;
175
+ file.close();
176
+ return {};
177
+ }
178
+ }
179
+
180
+ file.close();
181
+ return audioData;
182
+ }
183
+
184
+ /**
185
+ * @brief Private helper: Generates a Hamming window.
186
+ */
187
+ std::vector<float> AudioInferenceEngine::generateHammingWindow(int window_length) {
188
+ std::vector<float> window(window_length);
189
+ for (int i = 0; i < window_length; ++i) {
190
+ window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
191
+ }
192
+ return window;
193
+ }
194
+
195
+ /**
196
+ * @brief Private helper: Extracts spectrogram features.
197
+ */
198
+ Eigen::MatrixXf AudioInferenceEngine::extractSpectrogram(const std::vector<float>& wav, int fs) {
199
+ int n_batch = (wav.size() - WIN_LENGTH) / HOP_LENGTH + 1;
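+ // e.g. a 2 s clip at 16 kHz (32000 samples) yields n_batch = (32000 - 400) / 160 + 1 = 198 frames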
200
+ if (n_batch <= 0) {
201
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
202
+ }
203
+
204
+ std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
205
+
206
+ kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
207
+ if (!fft_cfg) {
208
+ std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
209
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
210
+ }
211
+
212
+ Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);
213
+
214
+ std::vector<float> frame_buffer(WIN_LENGTH);
215
+ kiss_fft_scalar fft_input[N_FFT];
216
+ kiss_fft_cpx fft_output[N_FFT / 2 + 1];
217
+
218
+ for (int i = 0; i < n_batch; ++i) {
219
+ int start_idx = i * HOP_LENGTH;
220
+
221
+ for (int j = 0; j < WIN_LENGTH; ++j) {
222
+ frame_buffer[j] = wav[start_idx + j];
223
+ }
224
+
225
+ // Apply pre-emphasis and scale by 32768
226
+ if (WIN_LENGTH > 0) {
227
+ if (WIN_LENGTH > 1) {
228
+ // Corrected pre-emphasis to match Python's np.roll and then overwrite first element
229
+ // The first element of the frame is pre-emphasized against the second element.
230
+ fft_input[0] = (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]) * 32768.0f;
231
+ for (int j = 1; j < WIN_LENGTH; ++j) {
232
+ fft_input[j] = (frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]) * 32768.0f;
233
+ }
234
+ } else { // WIN_LENGTH == 1
235
+ fft_input[0] = frame_buffer[0] * 32768.0f;
236
+ }
237
+ }
238
+ for (int j = WIN_LENGTH; j < N_FFT; ++j) {
239
+ fft_input[j] = 0.0f;
240
+ }
241
+
242
+ for (int j = 0; j < WIN_LENGTH; ++j) {
243
+ fft_input[j] *= fft_window[j];
244
+ }
245
+
246
+ kiss_fftr(fft_cfg, fft_input, fft_output);
247
+
248
+ for (int j = 0; j <= N_FFT / 2; ++j) {
249
+ spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
250
+ }
251
+ }
252
+
253
+ kiss_fftr_free(fft_cfg);
254
+ return spec_matrix;
255
+ }
256
+
257
+ /**
258
+ * @brief Private helper: Creates a Mel filter-bank matrix.
259
+ */
260
+ Eigen::MatrixXf AudioInferenceEngine::speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
261
+ int bank_width = n_fft / 2 + 1;
262
+ if (fmax == 0.0f) fmax = sample_rate / 2.0f;
263
+ if (fmin == 0.0f) fmin = 0.0f;
264
+
265
+ auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
266
+ auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
267
+ auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };
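+ // Sanity check for this natural-log mel scale: mel(1000 Hz) = 1127 * ln(1 + 1000/700) ≈ 1000 mel.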
268
+
269
+ int klo = f2bin(fmin) + 1;
270
+ int khi = f2bin(fmax);
271
+ khi = std::max(khi, klo);
272
+
273
+ float mlo = mel(fmin);
274
+ float mhi = mel(fmax);
275
+
276
+ std::vector<float> m_centers(n_mels + 2);
277
+ float ms = (mhi - mlo) / (n_mels + 1);
278
+ for (int i = 0; i < n_mels + 2; ++i) {
279
+ m_centers[i] = mlo + i * ms;
280
+ }
281
+
282
+ Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);
283
+
284
+ for (int m = 0; m < n_mels; ++m) {
285
+ float left = m_centers[m];
286
+ float center = m_centers[m + 1];
287
+ float right = m_centers[m + 2];
288
+ for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) {
289
+ float mbin = bin2mel(fft_bin);
290
+ if (left < mbin && mbin < right) {
291
+ matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
292
+ }
293
+ }
294
+ }
295
+ return matrix;
296
+ }
297
+
298
+ /**
299
+ * @brief Public method: Preprocesses an audio WAV file.
300
+ */
301
+ Eigen::MatrixXf AudioInferenceEngine::preprocessAudio(const std::string& wavFilePath) {
302
+ int actual_wav_sample_rate = 0;
303
+ std::vector<float> audioWav = loadWavToFloatArray(wavFilePath, actual_wav_sample_rate);
304
+
305
+ if (audioWav.empty()) {
306
+ std::cerr << "Failed to load audio data from " << wavFilePath << "." << std::endl;
307
+ return Eigen::MatrixXf(0, N_MELS);
308
+ }
309
+
310
+ if (actual_wav_sample_rate != TARGET_SAMPLE_RATE) {
311
+ std::cerr << "Warning: WAV file sample rate (" << actual_wav_sample_rate
312
+ << " Hz) does not match the target sample rate for feature extraction ("
313
+ << TARGET_SAMPLE_RATE << " Hz)." << std::endl;
314
+ std::cerr << "This example does NOT include resampling. Features will be extracted at "
315
+ << TARGET_SAMPLE_RATE << " Hz, which might lead to incorrect results if the WAV file's sample rate is different." << std::endl;
316
+ }
317
+
318
+ Eigen::MatrixXf spec = extractSpectrogram(audioWav, TARGET_SAMPLE_RATE);
319
+ if (spec.rows() == 0) {
320
+ std::cerr << "Error: Spectrogram extraction failed." << std::endl;
321
+ return Eigen::MatrixXf(0, N_MELS);
322
+ }
323
+
324
+ Eigen::MatrixXf spec_power = spec.array().square();
325
+ Eigen::MatrixXf fbank_power = spec_power * mel_filterbank_.transpose(); // Transpose mel_filterbank_ for correct multiplication
326
+
327
+ fbank_power = fbank_power.array().max(1.0f);
328
+ Eigen::MatrixXf log_fbank = fbank_power.array().log();
329
+
330
+ return log_fbank;
331
+ }
332
+
333
+ /**
334
+ * @brief Public method: Runs inference on the loaded ONNX model.
335
+ */
336
+ std::vector<float> AudioInferenceEngine::runInference(const Eigen::MatrixXf& features) {
337
+ if (features.rows() == 0 || features.cols() == 0) {
338
+ std::cerr << "Error: Input features are empty for inference." << std::endl;
339
+ return {};
340
+ }
341
+
342
+ // Prepare Input Tensor Shape: [batch, frames, feature_size]
343
+ std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
344
+
345
+ // Flatten Eigen::MatrixXf into std::vector<float> in row-major order
346
+ std::vector<float> inputTensorData(features.rows() * features.cols());
347
+ for (int r = 0; r < features.rows(); ++r) {
348
+ for (int c = 0; c < features.cols(); ++c) {
349
+ inputTensorData[r * features.cols() + c] = features(r, c);
350
+ }
351
+ }
352
+
353
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
354
+ Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
355
+ inputTensorShape.data(), inputTensorShape.size());
356
+
357
+ if (!inputTensor.IsTensor()) {
358
+ std::cerr << "Error: Created input tensor is not valid!" << std::endl;
359
+ return {};
360
+ }
361
+
362
+ // Run Inference
363
+ std::vector<Ort::Value> outputTensors = session_->Run(Ort::RunOptions{nullptr},
364
+ input_node_names_.data(), &inputTensor, 1,
365
+ output_node_names_.data(), output_node_names_.size());
366
+
367
+ if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
368
+ std::cerr << "Error: No valid output tensors received from the model." << std::endl;
369
+ return {};
370
+ }
371
+
372
+ // Copy output data
373
+ float* outputData = outputTensors[0].GetTensorMutableData<float>();
374
+ Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
375
+ size_t outputSize = outputShapeInfo.GetElementCount();
376
+
377
+ std::vector<float> result(outputData, outputData + outputSize);
378
+ return result;
379
+ }
380
+
381
+ std::vector<Ort::Value> AudioInferenceEngine::runInference_tensor(const Ort::Value& inputTensor) {
382
+ // Run Inference
383
+ std::vector<Ort::Value> outputTensors = session_->Run(Ort::RunOptions{nullptr},
384
+ input_node_names_.data(), &inputTensor, 1,
385
+ output_node_names_.data(), output_node_names_.size());
386
+
387
+ return outputTensors;
388
+ }
cpp/inference/audio_encoder_lib.h ADDED
@@ -0,0 +1,141 @@
1
+ #ifndef AUDIO_INFERENCE_LIBRARY_H
2
+ #define AUDIO_INFERENCE_LIBRARY_H
3
+
4
+ #include <string>
5
+ #include <vector>
6
+ #include <cstdint> // For uint32_t, int16_t
7
+ #include <memory> // For std::unique_ptr
8
+ #include <Eigen/Dense>
9
+ using namespace Eigen;
10
+ // Forward declarations for ONNX Runtime types to avoid including full headers in .h
11
+ namespace Ort {
12
+ struct Env;
13
+ struct Session;
14
+ struct MemoryInfo;
15
+ struct AllocatorWithDefaultOptions;
16
+ struct Value;
17
+ }
18
+
19
+ // (Eigen::MatrixXf is provided by the <Eigen/Dense> include above, so no forward declaration is needed here.)
25
+
26
+ /**
27
+ * @brief Class to handle audio preprocessing and ONNX model inference.
28
+ *
29
+ * This class encapsulates the logic for loading WAV files, extracting Mel filterbank
30
+ * features, and running inference on an ONNX model.
31
+ */
32
+ class AudioInferenceEngine {
33
+ public:
34
+ /**
35
+ * @brief Constructor for AudioInferenceEngine.
36
+ * @param modelPath The file path to the ONNX model.
37
+ * @throws Ort::Exception if ONNX Runtime initialization fails.
38
+ */
39
+ AudioInferenceEngine(const std::string& modelPath);
40
+
41
+ /**
42
+ * @brief Destructor to clean up ONNX Runtime resources.
43
+ */
44
+ ~AudioInferenceEngine();
45
+
46
+ /**
47
+ * @brief Preprocesses an audio WAV file to extract Mel filterbank features.
48
+ *
49
+ * This function loads the WAV file, converts it to a float array, and then
50
+ * applies the spectrogram and Mel filterbank extraction steps.
51
+ *
52
+ * @param wavFilePath The path to the WAV audio file.
53
+ * @return An Eigen::MatrixXf containing the extracted features (frames x N_MELS).
54
+ * Returns an empty matrix if preprocessing fails.
55
+ */
56
+ Eigen::MatrixXf preprocessAudio(const std::string& wavFilePath);
57
+
58
+ /**
59
+ * @brief Runs inference on the loaded ONNX model using the provided features.
60
+ *
61
+ * The input features should be the output of `preprocessAudio`. This function
62
+ * converts the features to an ONNX Runtime tensor and executes the model.
63
+ *
64
+ * @param features An Eigen::MatrixXf containing the preprocessed audio features.
65
+ * Expected shape: (frames, N_MELS).
66
+ * @return A std::vector<float> containing the flattened output of the ONNX model.
67
+ * Returns an empty vector if inference fails.
68
+ */
69
+ std::vector<float> runInference(const Eigen::MatrixXf& features);
70
+ std::vector<Ort::Value> runInference_tensor(const Ort::Value& inputTensor);
71
+
72
+ private:
73
+ // ONNX Runtime members
74
+ std::unique_ptr<Ort::Env> env_;
75
+ std::unique_ptr<Ort::Session> session_;
76
+ std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
77
+ std::vector<const char*> input_node_names_;
78
+ std::vector<const char*> output_node_names_;
79
+
80
+ // Precomputed Mel filterbank matrix
81
+ Eigen::MatrixXf mel_filterbank_;
82
+
83
+ // Private helper functions (implemented in .cpp)
84
+ // WAV file parsing structures
85
+ #pragma pack(push, 1)
86
+ struct WavHeader {
87
+ char riff_id[4];
88
+ uint32_t file_size;
89
+ char wave_id[4];
90
+ char fmt_id[4];
91
+ uint32_t fmt_size;
92
+ uint16_t audio_format;
93
+ uint16_t num_channels;
94
+ uint32_t sample_rate;
95
+ uint32_t byte_rate;
96
+ uint16_t block_align;
97
+ uint16_t bits_per_sample;
98
+ };
99
+
100
+ struct WavDataChunk {
101
+ char data_id[4];
102
+ uint32_t data_size;
103
+ };
104
+ #pragma pack(pop)
105
+
106
+ /**
107
+ * @brief Loads audio data from a WAV file into a float vector.
108
+ * @param filename The path to the WAV audio file.
109
+ * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
110
+ * @return A std::vector<float> containing the normalized mono audio samples.
111
+ */
112
+ std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate);
113
+
114
+ /**
115
+ * @brief Generates a Hamming window.
116
+ * @param window_length The length of the window.
117
+ * @return A std::vector<float> containing the Hamming window coefficients.
118
+ */
119
+ std::vector<float> generateHammingWindow(int window_length);
120
+
121
+ /**
122
+ * @brief Extracts spectrogram features from waveform.
123
+ * @param wav The input waveform.
124
+ * @param fs The sampling rate.
125
+ * @return A 2D Eigen::MatrixXf representing the spectrogram.
126
+ */
127
+ Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs);
128
+
129
+ /**
130
+ * @brief Creates a Mel filter-bank matrix.
131
+ * @param sample_rate Sample rate in Hz.
132
+ * @param n_fft FFT size.
133
+ * @param n_mels Mel filter size.
134
+ * @param fmin Lowest frequency (in Hz).
135
+ * @param fmax Highest frequency (in Hz).
136
+ * @return An Eigen::MatrixXf representing the Mel transform matrix.
137
+ */
138
+ Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax);
139
+ };
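+
+ // Minimal usage sketch (see main_text.cpp for a complete example; the file names below are placeholders):
+ //   AudioInferenceEngine engine("model.onnx");
+ //   Eigen::MatrixXf features = engine.preprocessAudio("audio.wav");
+ //   std::vector<float> output = engine.runInference(features);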
140
+
141
+ #endif // AUDIO_INFERENCE_LIBRARY_H
cpp/inference/audio_encoder_lib.o ADDED
Binary file (85.1 kB). View file
 
cpp/inference/audio_inference ADDED
Binary file (91 kB). View file
 
cpp/inference/audio_inference_app ADDED
Binary file (97.7 kB). View file
 
cpp/inference/compile.sh ADDED
@@ -0,0 +1,32 @@
1
+ export BASE_DIR="/mnt/data-2t/jeff/codes/llm/cpp"
2
+ # g++ test.cpp $BASE_DIR/kissfft/kiss_fft.c $BASE_DIR/kissfft/kiss_fftr.c \
3
+ # -o audio_inference \
4
+ # -I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
5
+ # -I $BASE_DIR/eigen-3.4.0 \
6
+ # -I $BASE_DIR/kissfft \
7
+ # -L $BASE_DIR/kissfft/lib -lkissfft-int16_t-openmp \
8
+ # -L $BASE_DIR/onnxruntime-linux-x64-1.22.0/lib -lonnxruntime -std=c++17 -O2 -DNDEBUG
9
+
10
+ g++ -c audio_encoder_lib.cpp \
11
+ -o audio_encoder_lib.o \
12
+ -I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
13
+ -I $BASE_DIR/eigen-3.4.0 \
14
+ -I $BASE_DIR/kissfft \
15
+ -std=c++17 -O3 -DNDEBUG -fPIC
16
+
17
+ g++ -c $BASE_DIR/kissfft/kiss_fft.c \
18
+ -o kiss_fft.o \
19
+ -I $BASE_DIR/kissfft \
20
+ -std=c++17 -O3 -DNDEBUG -fPIC
21
+
22
+ g++ -c $BASE_DIR/kissfft/kiss_fftr.c \
23
+ -o kiss_fftr.o \
24
+ -I $BASE_DIR/kissfft \
25
+ -std=c++17 -O3 -DNDEBUG -fPIC
26
+
27
+ g++ main_text.cpp audio_encoder_lib.o kiss_fft.o kiss_fftr.o \
28
+ -o audio_inference_app \
29
+ -I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
30
+ -I $BASE_DIR/eigen-3.4.0 \
31
+ -I $BASE_DIR/kissfft \
32
+ -L $BASE_DIR/onnxruntime-linux-x64-1.22.0/lib -lonnxruntime -std=c++17 -O3 -DNDEBUG
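+
+ # Running audio_inference_app requires libonnxruntime.so on LD_LIBRARY_PATH; run.sh (next file) exports it.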
cpp/inference/dummy.wav ADDED
Binary file (57.5 kB). View file
 
cpp/inference/f0.txt ADDED
The diff for this file is too large to render. See raw diff
 
cpp/inference/f_inp.txt ADDED
The diff for this file is too large to render. See raw diff
 
cpp/inference/kiss_fft.o ADDED
Binary file (14.5 kB). View file
 
cpp/inference/kiss_fftr.o ADDED
Binary file (3.9 kB). View file
 
cpp/inference/main_text.cpp ADDED
@@ -0,0 +1,165 @@
1
+ #include <iostream>
2
+ #include <vector>
3
+ #include <fstream>
4
+ #include <string>
5
+ #include <cmath> // For std::sin, M_PI
6
+ #include <cstring> // For std::memcpy
7
+ #include <chrono> // For time measurement
8
+ #include <random> // For random number generation
9
+ #include <ctime> // For seeding random number generator
10
+
11
+ // Include the new library header
12
+ #include "audio_encoder_lib.h"
13
+ #include <onnxruntime_cxx_api.h>
14
+ // Define M_PI if it's not already defined
15
+ #ifndef M_PI
16
+ #define M_PI 3.14159265358979323846
17
+ #endif
18
+
19
+ // --- WAV File Header Structures (for dummy file creation) ---
20
+ #pragma pack(push, 1)
21
+ struct WavHeader {
22
+ char riff_id[4];
23
+ uint32_t file_size;
24
+ char wave_id[4];
25
+ char fmt_id[4];
26
+ uint32_t fmt_size;
27
+ uint16_t audio_format;
28
+ uint16_t num_channels;
29
+ uint32_t sample_rate;
30
+ uint32_t byte_rate;
31
+ uint16_t block_align;
32
+ uint16_t bits_per_sample;
33
+ };
34
+
35
+ struct WavDataChunk {
36
+ char data_id[4];
37
+ uint32_t data_size;
38
+ };
39
+ #pragma pack(pop)
40
+
41
+ // Function to write a dummy WAV file (moved here for example app)
42
+ void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
43
+ std::ofstream file(filename, std::ios::binary);
44
+ if (!file.is_open()) {
45
+ std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
46
+ return;
47
+ }
48
+
49
+ WavHeader header;
50
+ std::memcpy(header.riff_id, "RIFF", 4);
51
+ std::memcpy(header.wave_id, "WAVE", 4);
52
+ std::memcpy(header.fmt_id, "fmt ", 4);
53
+ header.fmt_size = 16;
54
+ header.audio_format = 1; // PCM
55
+ header.num_channels = numChannels;
56
+ header.sample_rate = sampleRate;
57
+ header.bits_per_sample = bitsPerSample;
58
+ header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
59
+ header.block_align = (numChannels * bitsPerSample) / 8;
60
+
61
+ WavDataChunk data_chunk;
62
+ std::memcpy(data_chunk.data_id, "data", 4);
63
+ uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
64
+ data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
65
+ header.file_size = 36 + data_chunk.data_size; // 36 is size of header before data chunk
66
+
67
+ file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
68
+ file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));
69
+
70
+ // Generate a 440 Hz sine wave
71
+ for (uint32_t i = 0; i < num_samples; ++i) {
72
+ int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
73
+ for (int c = 0; c < numChannels; ++c) {
74
+ file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
75
+ }
76
+ }
77
+
78
+ file.close();
79
+ // std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl; // Suppress verbose creation message
80
+ }
81
+
82
+ int main(int argc, char* argv[]) {
83
+ // --- 1. Process command-line arguments ---
84
+ if (argc != 3) {
85
+ std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl;
86
+ std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl;
87
+ return 1;
88
+ }
89
+
90
+ std::string onnxModelPath = argv[1];
91
+ std::string wavFilename = argv[2]; // This will be used as a temporary file
92
+
93
+ // --- Random number generation setup for dummy input frames ---
94
+ std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr))); // Seed with current time
95
+ std::uniform_int_distribution<int> dist_frames(100, 300); // Distribution for frames (100 to 300)
96
+
97
+ // Define fixed parameters for feature extraction to calculate required duration
98
+ const int WIN_LENGTH = 400; // Window length (samples) - must match library's constant
99
+ const int HOP_LENGTH = 160; // Hop length (samples) - must match library's constant
100
+ const int TARGET_SAMPLE_RATE = 16000; // Target sample rate - must match library's constant
101
+
102
+ try {
103
+ // --- 2. Model Initialization ---
104
+ // This will load the ONNX model and precompute the Mel filterbank.
105
+ AudioInferenceEngine engine(onnxModelPath);
106
+ std::cout << "Engine initialized." << std::endl;
107
+
108
+ // --- 3. Model Inference and Time Measurement ---
109
+ std::cout << "\nRunning model inference and measuring time (100 runs with varying input sizes)..." << std::endl;
110
+ int num_runs = 100;
111
+ long long total_inference_time_us = 0; // Use microseconds for finer granularity
112
+
113
+ for (int i = 0; i < num_runs; ++i) {
114
+ // Generate a random number of frames for this run
115
+ int random_frames = dist_frames(rng);
116
+ // Calculate the number of samples needed to produce 'random_frames'
117
+ // frames = (num_samples - WIN_LENGTH) / HOP_LENGTH + 1
118
+ // num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
119
+ long long num_samples_for_frames = static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH;
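+ // e.g. random_frames = 200 gives (200 - 1) * 160 + 400 = 32240 samples, i.e. about 2.015 s at 16 kHz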
120
+ double duration_seconds_for_frames = static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE;
121
+
122
+ // Create a new dummy WAV file for this specific run
123
+ // This ensures the input size changes for each test.
124
+ createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames);
125
+
126
+ // --- Measure the inference time ---
127
+ auto start_time = std::chrono::high_resolution_clock::now();
128
+ Eigen::MatrixXf features = engine.preprocessAudio(wavFilename);
129
+ std::vector<float> model_output = engine.runInference(features);
130
+ auto end_time = std::chrono::high_resolution_clock::now();
131
+
132
+ if (model_output.empty()) {
133
+ std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl;
134
+ return 1;
135
+ }
136
+
137
+ // Calculate duration for this run in microseconds
138
+ auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
139
+ total_inference_time_us += duration.count();
140
+
141
+ // Optionally print output for the first run or specific runs
142
+ if (i == 0) {
143
+ std::cout << "First run (frames=" << features.rows() << ")"<< " take : "<< static_cast<double>(total_inference_time_us) / 1000.0 / 1000.0 <<"s output (first few elements): [";
144
+ for (size_t k = 0; k < std::min((size_t)10, model_output.size()); ++k) {
145
+ std::cout << model_output[k] << (k == std::min((size_t)10, model_output.size()) - 1 ? "" : ", ");
146
+ }
147
+ std::cout << "]" << std::endl;
148
+ }
149
+ }
150
+
151
+ double average_inference_time_s = static_cast<double>(total_inference_time_us) / num_runs / 1000.0 / 1000.0; // Convert microseconds to seconds
152
+ std::cout << "\nAverage ONNX model inference time over " << num_runs << " runs (with varying input frames): "
153
+ << average_inference_time_s << " s" << std::endl;
154
+
155
+ } catch (const Ort::Exception& e) {
156
+ std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
157
+ return 1;
158
+ } catch (const std::exception& e) {
159
+ std::cerr << "Standard Exception: " << e.what() << std::endl;
160
+ return 1;
161
+ }
162
+
163
+ std::cout << "\nProgram finished successfully." << std::endl;
164
+ return 0;
165
+ }
cpp/inference/matrix_output.txt ADDED
The diff for this file is too large to render. See raw diff
 
cpp/inference/run.sh ADDED
@@ -0,0 +1,7 @@
1
+ export ONNXRUNTIME_DIR="/mnt/data-2t/jeff/codes/llm/cpp/onnxruntime-linux-x64-1.22.0"
2
+ export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
3
+
4
+ export MODEL_PATH="/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx"
5
+ export SAMPLE_DATA="/mnt/data-2t/jeff/codes/llm/cpp/inference/dummy.wav"
6
+
7
+ ./audio_inference_app $MODEL_PATH $SAMPLE_DATA
cpp/inference/test copy 2.cpp ADDED
@@ -0,0 +1,567 @@
1
+ #include <iostream> // For standard input/output operations (e.g., std::cout, std::cerr)
2
+ #include <vector> // For dynamic arrays (e.g., std::vector<float>)
3
+ #include <fstream> // For file input/output operations (e.g., std::ifstream, std::ofstream)
4
+ #include <cstdint> // For fixed-width integer types (e.g., int16_t)
5
+ #include <cmath> // For mathematical functions (e.g., std::sin, M_PI, std::log)
6
+ #include <numeric> // For numerical operations (e.g., std::iota)
7
+ #include <algorithm> // For algorithms like std::min, std::max
9
+ // Include the ONNX Runtime C++ API header
10
+ #include <onnxruntime_cxx_api.h>
11
+
12
+ // Include Eigen for powerful matrix operations.
13
+ // You need to download Eigen and set up your include paths.
14
+ // E.g., if Eigen is in 'C:/Libraries/eigen-3.4.0', you'd compile with -I C:/Libraries/eigen-3.4.0
15
+ #include <Eigen/Dense>
16
+
17
+ // Include KissFFT for Fast Fourier Transform.
18
+ // You need to download KissFFT and set up your include paths.
19
+ // E.g., if KissFFT is in 'C:/Libraries/kissfft-1.3.0', you'd compile with -I C:/Libraries/kissfft-1.3.0
20
+ // You also need to compile kiss_fft.c and kiss_fftr.c and link them.
21
+ #include "kiss_fft.h"
22
+ #include "kiss_fftr.h" // For real-valued FFT
23
+
24
+ // Define M_PI if it's not already defined by cmath or your compiler.
25
+ #ifndef M_PI
26
+ #define M_PI 3.14159265358979323846
27
+ #endif
28
+
29
+ // --- Global parameters for feature extraction (matching Python script) ---
30
+ const float PREEMPHASIS_COEFF = 0.97f;
31
+ const int N_FFT = 512; // FFT size
32
+ const int WIN_LENGTH = 400; // Window length (samples)
33
+ const int HOP_LENGTH = 160; // Hop length (samples)
34
+ const int N_MELS = 80; // Number of Mel filterbank channels
35
+ const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction
36
+
37
+ /**
38
+ * @brief Loads raw PCM audio data from a file into a float vector.
39
+ *
40
+ * This function reads 16-bit signed integer PCM samples from the specified file,
41
+ * converts them to floating-point values, and normalizes them to the range [-1.0, 1.0].
42
+ * It assumes the PCM data is little-endian.
43
+ *
44
+ * @param filename The path to the PCM audio file.
45
+ * @return A std::vector<float> containing the normalized audio samples, or an empty
46
+ * vector if the file cannot be opened.
47
+ */
48
+ std::vector<float> loadPcmToFloatArray(const std::string& filename) {
49
+ std::ifstream file(filename, std::ios::binary);
50
+ if (!file.is_open()) {
51
+ std::cerr << "Error: Could not open PCM file: " << filename << std::endl;
52
+ return {};
53
+ }
54
+
55
+ std::vector<float> audioData;
56
+ int16_t sample;
57
+
58
+ while (file.read(reinterpret_cast<char*>(&sample), sizeof(sample))) {
59
+ audioData.push_back(static_cast<float>(sample) / 32768.0f);
60
+ }
61
+
62
+ file.close();
63
+ return audioData;
64
+ }
65
+
66
+ /**
67
+ * @brief Generates a Hamming window.
68
+ * @param window_length The length of the window.
69
+ * @return A std::vector<float> containing the Hamming window coefficients.
70
+ */
71
+ std::vector<float> generateHammingWindow(int window_length) {
72
+ std::vector<float> window(window_length);
73
+ for (int i = 0; i < window_length; ++i) {
74
+ window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
75
+ }
76
+ return window;
77
+ }
78
+
79
+ /**
80
+ * @brief Extracts spectrogram features from waveform, matching Python's _extract_spectrogram.
81
+ *
82
+ * @param wav The input waveform (1D array of floats).
83
+ * @param fs The sampling rate of the waveform (fixed to 16000 Hz for this model).
84
+ * @return A 2D Eigen::MatrixXf representing the spectrogram (frames x (N_FFT/2 + 1)).
85
+ */
86
+ Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs) {
87
+ // Calculate number of frames
88
+ int n_batch = (wav.size() - WIN_LENGTH) / HOP_LENGTH + 1;
89
+ if (n_batch <= 0) {
90
+ std::cerr << "Warning: Input waveform too short for feature extraction. Returning empty spectrogram." << std::endl;
91
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
92
+ }
93
+
94
+ // Generate Hamming window once
95
+ std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
96
+ // Initialize KissFFT for real-valued input
97
+ kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
98
+ if (!fft_cfg) {
99
+ std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
100
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
101
+ }
102
+
103
+ // Output spectrogram matrix: rows = frames, columns = FFT bins
104
+ Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);
105
+
106
+ std::vector<float> frame_buffer(WIN_LENGTH);
107
+ std::vector<float> prev_frame_buffer(WIN_LENGTH);
108
+ kiss_fft_scalar fft_input[N_FFT]; // KissFFT requires input buffer of size N_FFT
109
+ kiss_fft_cpx fft_output[N_FFT / 2 + 1]; // KissFFT real output size
110
+
111
+ for (int i = 0; i < n_batch; ++i) {
112
+ int start_idx = i * HOP_LENGTH;
113
+
114
+ // Extract current frame
115
+ for (int j = 0; j < WIN_LENGTH; ++j) {
116
+ frame_buffer[j] = wav[start_idx + j];
117
+ }
118
+
119
+ // Prepare previous frame for pre-emphasis (np.roll equivalent)
120
+ // y_frames_prev = np.roll(y_frames, 1, axis=1)
121
+ // y_frames_prev[:, 0] = y_frames_prev[:, 1]
122
+ prev_frame_buffer[0] = frame_buffer[0]; // Python's np.roll(..., 1) with axis=1 makes first element wrap around
123
+ // but then it's overwritten by y_frames_prev[:, 1]
124
+ if (WIN_LENGTH > 1) {
125
+ for (int j = 0; j < WIN_LENGTH - 1; ++j) {
126
+ prev_frame_buffer[j + 1] = frame_buffer[j];
127
+ }
128
+ }
129
+ // Correcting the first element as per Python code: y_frames_prev[:, 0] = y_frames_prev[:, 1]
130
+ // This means the first element of the 'previous' frame is actually the second element of the 'current' frame.
131
+ // For the first frame (i=0), prev_frame_buffer[0] should be frame_buffer[1] if WIN_LENGTH > 1.
132
+ // For subsequent frames, this logic applies to the *current* frame's first sample relative to its second.
133
+ // The original Python code effectively does:
134
+ // y_frames_prev = np.concatenate((y_frames[:, 1:2], y_frames[:, :-1]), axis=1)
135
+ // This is a bit tricky. Let's simplify and apply pre-emphasis directly to the current frame elements.
136
+ // The Python code applies pre-emphasis *within* each batch/frame.
137
+ // y_frames = (y_frames - preemphasis * y_frames_prev)
138
+ // y_frames_prev[:, 0] = y_frames_prev[:, 1] means the first element of the previous frame is taken from the second element of the *current* frame.
139
+ // This is equivalent to: frame[j] - preemphasis * (j == 0 ? frame[1] : frame[j-1])
140
+ // Let's use a temporary buffer for pre-emphasized frame.
141
+ std::vector<float> preemphasized_frame(WIN_LENGTH);
142
+ if (WIN_LENGTH > 0) {
143
+ preemphasized_frame[0] = frame_buffer[0]; // First sample is not pre-emphasized against a previous sample
144
+ if (WIN_LENGTH > 1) {
145
+ for (int j = 1; j < WIN_LENGTH; ++j) {
146
+ preemphasized_frame[j] = frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1];
147
+ }
148
+ }
149
+ }
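+ // Note: unlike the np.roll-based derivation above, this scratch version leaves the first sample
+ // un-emphasized and never uses prev_frame_buffer; the corrected handling
+ // (frame[0] - PREEMPHASIS_COEFF * frame[1]) lives in audio_encoder_lib.cpp.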
150
+ // Apply pre-emphasis and scale by 32768 (as in Python)
151
+ for (int j = 0; j < WIN_LENGTH; ++j) {
152
+ fft_input[j] = preemphasized_frame[j] * 32768.0f;
153
+ // (zero-padding of fft_input for j >= WIN_LENGTH is handled by the separate loop below)
157
+ }
158
+ // Zero-pad the rest of the FFT input if WIN_LENGTH < N_FFT
159
+ for (int j = WIN_LENGTH; j < N_FFT; ++j) {
160
+ fft_input[j] = 0.0f;
161
+ }
162
+ // Apply Hamming window
163
+ for (int j = 0; j < WIN_LENGTH; ++j) {
164
+ fft_input[j] *= fft_window[j];
165
+ }
166
+ // Perform real FFT
167
+ kiss_fftr(fft_cfg, fft_input, fft_output);
168
+ // Calculate magnitude spectrogram
169
+ for (int j = 0; j <= N_FFT / 2; ++j) {
170
+ spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
171
+ }
172
+ }
173
+ kiss_fftr_free(fft_cfg); // Free KissFFT configuration
174
+ return spec_matrix;
175
+ }
176
+
177
+ /**
178
+ * @brief Creates a Mel filter-bank matrix, matching Python's speechlib_mel.
179
+ *
180
+ * @param sample_rate Sample rate in Hz.
181
+ * @param n_fft FFT size.
182
+ * @param n_mels Mel filter size.
183
+ * @param fmin Lowest frequency (in Hz).
184
+ * @param fmax Highest frequency (in Hz).
185
+ * @return An Eigen::MatrixXf representing the Mel transform matrix ((1 + n_fft/2) x n_mels), transposed for direct multiplication with the spectrogram.
186
+ */
187
+ Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
188
+ int bank_width = n_fft / 2 + 1;
189
+ if (fmax == 0.0f) fmax = sample_rate / 2.0f; // Use 0.0f as a sentinel for None
190
+ if (fmin == 0.0f) fmin = 0.0f; // Use 0.0f as a sentinel for None
191
+
192
+ // Helper functions for Mel scale conversion
193
+ auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
194
+ auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
195
+ auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };
196
+
197
+ // Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax)]
198
+ int klo = f2bin(fmin) + 1;
199
+ int khi = f2bin(fmax);
200
+ khi = std::max(khi, klo);
201
+
202
+ // Spec 2: SpeechLib uses triangles in Mel space
203
+ float mlo = mel(fmin);
204
+ float mhi = mel(fmax);
205
+
206
+ // Generate Mel centers
207
+ std::vector<float> m_centers(n_mels + 2);
208
+ float ms = (mhi - mlo) / (n_mels + 1);
209
+ for (int i = 0; i < n_mels + 2; ++i) {
210
+ m_centers[i] = mlo + i * ms;
211
+ }
212
+
213
+ Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);
214
+
215
+ for (int m = 0; m < n_mels; ++m) {
216
+ float left = m_centers[m];
217
+ float center = m_centers[m + 1];
218
+ float right = m_centers[m + 2];
219
+ for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) { // Loop up to bank_width-1
220
+ float mbin = bin2mel(fft_bin);
221
+ if (left < mbin && mbin < right) {
222
+ matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
223
+ }
224
+ }
225
+ }
226
+ matrix.transposeInPlace();
227
+ return matrix;
228
+ }
229
+
230
+ /**
231
+ * @brief Extracts log filterbank features from waveform, matching Python's _extract_features.
232
+ *
233
+ * @param wav The input waveform (1D array of floats).
234
+ * @param fs The sampling rate of the waveform (fixed to 16000 Hz).
235
+ * @param mel_filterbank The pre-computed Mel filterbank matrix.
236
+ * @return An Eigen::MatrixXf representing the log Mel filterbank features (frames x N_MELS).
237
+ */
238
+ Eigen::MatrixXf extractFeatures(const std::vector<float>& wav, int fs, const Eigen::MatrixXf& mel_filterbank) {
239
+ // Extract spectrogram
240
+ Eigen::MatrixXf spec = extractSpectrogram(wav, fs);
241
+ if (spec.rows() == 0) {
242
+ return Eigen::MatrixXf(0, N_MELS); // Return empty matrix if spectrogram extraction failed
243
+ }
244
+
245
+ // spec_power = spec**2
246
+ Eigen::MatrixXf spec_power = spec.array().square();
247
+
248
+ // fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)
249
+ // Note: Eigen's matrix multiplication is `*`, not `dot`.
250
+ // The Python `dot` for 2D arrays is matrix multiplication.
251
+ // Python: (frames, N_FFT/2+1) . (N_FFT/2+1, N_MELS) -> (frames, N_MELS)
252
+ // C++ Eigen: spec_power (rows, cols) * mel_filterbank (cols, N_MELS)
253
+ // So, mel_filterbank should be (N_FFT/2+1, N_MELS)
254
+ Eigen::MatrixXf fbank_power = spec_power * mel_filterbank;
255
+
256
+ // Apply clipping: np.clip(..., 1.0, None)
257
+ // This means any value less than 1.0 becomes 1.0.
258
+ fbank_power = fbank_power.array().max(1.0f);
259
+
260
+ // log_fbank = np.log(fbank_power).astype(np.float32)
261
+ Eigen::MatrixXf log_fbank = fbank_power.array().log();
262
+
263
+ return log_fbank;
264
+ }
265
+
266
+
267
+ int main(int argc, char* argv[]) {
268
+ // --- 1. Process command-line arguments ---
269
+ if (argc != 3) {
270
+ std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_pcm_file>" << std::endl;
271
+ std::cerr << "Example: " << argv[0] << " model.onnx audio.pcm" << std::endl;
272
+ return 1;
273
+ }
274
+
275
+ std::string onnxModelPath = argv[1];
276
+ std::string pcmFilename = argv[2];
277
+
278
+ // --- Configuration for Audio and ONNX Model ---
279
+ // These are fixed by the Python preprocessor code and model requirements.
280
+ int bitDepth = 16;
281
+ // numChannels is handled within loadPcmToFloatArray and then implicitly by feature extraction
282
+ // which squeezes to 1D and takes mean if stereo. For simplicity, we assume mono PCM input.
283
+ // If your PCM is stereo, you'd need to adjust loadPcmToFloatArray to handle channel interleaving
284
+ // and then average or select a channel before passing to extractSpectrogram.
285
+ int numChannels = 1;
286
+
287
+ // --- Create a dummy PCM file if it doesn't exist for demonstration ---
288
+ // This is helpful for initial testing without needing an actual PCM file.
289
+ std::ifstream pcmCheck(pcmFilename, std::ios::binary);
290
+ if (!pcmCheck.is_open()) {
291
+ std::cerr << "PCM file '" << pcmFilename << "' not found. Creating a dummy one for demonstration." << std::endl;
292
+ std::ofstream dummyPcmFile(pcmFilename, std::ios::binary);
293
+ if (dummyPcmFile.is_open()) {
294
+ std::cout << "Creating a dummy PCM file: " << pcmFilename << " ("
295
+ << (TARGET_SAMPLE_RATE * 2 * sizeof(int16_t)) / 1024 << " KB)" << std::endl;
296
+ for (int i = 0; i < TARGET_SAMPLE_RATE * 2; ++i) { // Generate 2 seconds of audio
297
+ int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(TARGET_SAMPLE_RATE)));
298
+ dummyPcmFile.write(reinterpret_cast<char*>(&sample), sizeof(sample));
299
+ }
300
+ dummyPcmFile.close();
301
+ } else {
302
+ std::cerr << "Error: Could not create dummy PCM file '" << pcmFilename
303
+ << "'. Please ensure the directory is writable." << std::endl;
304
+ return 1;
305
+ }
306
+ } else {
307
+ pcmCheck.close();
308
+ }
309
+
310
+
311
+ // --- 2. Load PCM audio data into a float array ---
312
+ std::vector<float> audioWav = loadPcmToFloatArray(pcmFilename);
313
+
314
+ if (audioWav.empty()) {
315
+ std::cerr << "Failed to load audio data from " << pcmFilename << ". Exiting." << std::endl;
316
+ return 1;
317
+ }
318
+
319
+ std::cout << "Successfully loaded " << audioWav.size() << " samples from " << pcmFilename << std::endl;
320
+
321
+ // --- 3. Precompute Mel filterbank (as it's constant for a given sample rate/FFT size) ---
322
+ // The Python example uses fmax=16000//2-80-230. This translates to TARGET_SAMPLE_RATE/2 - 80 - 230.
323
+ // Using 0.0f for fmin as sentinel for None.
324
+ float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
325
+ Eigen::MatrixXf mel_filterbank = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);
326
+
327
+ if (mel_filterbank.rows() == 0 || mel_filterbank.cols() == 0) {
328
+ std::cerr << "Error: Failed to create Mel filterbank. Exiting." << std::endl;
329
+ return 1;
330
+ }
331
+ std::cout << "Mel filterbank created with shape: [" << mel_filterbank.rows() << ", " << mel_filterbank.cols() << "]" << std::endl;
332
+
333
+
334
+ // --- 4. Apply feature extraction (preprocessor) ---
335
+ std::cout << "Extracting features from audio..." << std::endl;
336
+ Eigen::MatrixXf features = extractFeatures(audioWav, TARGET_SAMPLE_RATE, mel_filterbank);
337
+
338
+ std::ofstream outputFile("matrix_output.txt");
339
+ // Check if the file was opened successfully
340
+ if (outputFile.is_open()) {
341
+ // Iterate through rows and columns to write elements
342
+ for (int i = 0; i < features.rows(); ++i) {
343
+ for (int j = 0; j < features.cols(); ++j) {
344
+ outputFile << features(i, j); // Write the element
345
+ if (j < features.cols() - 1) {
346
+ outputFile << ","; // Add a space separator between elements in a row
347
+ }
348
+ }
349
+ outputFile << std::endl; // Move to the next line after each row
350
+ }
351
+ outputFile.close(); // Close the file
352
+ std::cout << "Matrix successfully written to matrix_output.txt" << std::endl;
353
+ }
354
+
355
+
356
+ if (features.rows() == 0 || features.cols() == 0) {
357
+ std::cerr << "Error: Feature extraction resulted in an empty matrix. Exiting." << std::endl;
358
+ return 1;
359
+ }
360
+ std::cout << "Features extracted with shape: [" << features.rows() << ", " << features.cols() << "]" << std::endl;
361
+ std::cout << "First few feature values (first frame): [";
362
+ for (int i = 0; i < std::min((int)features.cols(), 5); ++i) {
363
+ std::cout << features(0, i) << (i == std::min((int)features.cols(), 5) - 1 ? "" : ", ");
364
+ }
365
+ std::cout << "]" << std::endl;
366
+
367
+ // --- 5. Check for ONNX model existence and provide guidance if missing ---
368
+ std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
369
+ if (!onnxModelCheck.is_open()) {
370
+ std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
371
+ std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
372
+ << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
373
+ std::cerr << "```python" << std::endl;
374
+ std::cerr << "import torch" << std::endl;
375
+ std::cerr << "import torch.nn as nn" << std::endl;
376
+ std::cerr << "" << std::endl;
377
+ std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
378
+ std::cerr << " def __init__(self, input_frames, feature_size, output_size):" << std::endl;
379
+ std::cerr << " super(SimpleAudioModel, self).__init__()" << std::endl;
380
+ std::cerr << " # This model expects input of shape [batch_size, frames, feature_size]" << std::endl;
381
+ std::cerr << " # Example: a simple linear layer that flattens input and processes it." << std::endl;
382
+ std::cerr << " self.flatten = nn.Flatten()" << std::endl;
383
+ std::cerr << " self.linear = nn.Linear(input_frames * feature_size, output_size)" << std::endl;
384
+ std::cerr << "" << std::endl;
385
+ std::cerr << " def forward(self, x):" << std::endl;
386
+ std::cerr << " x = self.flatten(x)" << std::endl;
387
+ std::cerr << " return self.linear(x)" << std::endl;
388
+ std::cerr << "" << std::endl;
389
+ std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
390
+ std::cerr << "# The C++ preprocessor will produce features of shape [frames, 80]." << std::endl;
391
+ std::cerr << "# For a dummy model, we need to provide a fixed 'frames' value for ONNX export." << std::endl;
392
+ std::cerr << "# A typical audio segment might be 2 seconds at 16kHz, which is 32000 samples." << std::endl;
393
+ std::cerr << "# Frames = (32000 - 400) / 160 + 1 = 198.75 + 1 = 199 frames (approx)" << std::endl;
394
+ std::cerr << "# Let's use a representative number of frames, e.g., 200 for a dummy input." << std::endl;
395
+ std::cerr << "DUMMY_INPUT_FRAMES = 200 # This should be representative of your typical audio segment's frames" << std::endl;
396
+ std::cerr << "DUMMY_FEATURE_SIZE = 80 # Fixed by the Mel filterbank (N_MELS)" << std::endl;
397
+ std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
398
+ std::cerr << "" << std::endl;
399
+ std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
400
+ std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE) # Batch size 1" << std::endl;
401
+ std::cerr << "" << std::endl;
402
+ std::cerr << "torch.onnx.export(" << std::endl;
403
+ std::cerr << " model," << std::endl;
404
+ std::cerr << " dummy_input_tensor," << std::endl;
405
+ std::cerr << " \"model.onnx\"," << std::endl;
406
+ std::cerr << " verbose=True," << std::endl;
407
+ std::cerr << " input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
408
+ std::cerr << " output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
409
+ std::cerr << " # Define dynamic axes for batch_size and frames" << std::endl;
410
+ std::cerr << " dynamic_axes={'input': {0: 'batch_size', 1: 'frames'}, 'output': {0: 'batch_size'}}" << std::endl;
411
+ std::cerr << ")" << std::endl;
412
+ std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_FRAMES in this script to match the expected number of frames from your audio segments.\")" << std::endl;
413
+ std::cerr << "```" << std::endl;
414
+ return 1;
415
+ }
416
+ onnxModelCheck.close();
417
+ std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;
418
+
419
+
420
+ // --- 6. ONNX Runtime Inference ---
421
+ try {
422
+ Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
423
+ Ort::SessionOptions session_options;
424
+ session_options.SetIntraOpNumThreads(1);
425
+ // session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
426
+
427
+ Ort::Session session(env, onnxModelPath.c_str(), session_options);
428
+ std::cout << "Model loaded successfully from: " << onnxModelPath << std::endl;
429
+ Ort::AllocatorWithDefaultOptions allocator;
430
+
431
+ // --- Get Input Node Information ---
432
+ size_t numInputNodes = session.GetInputCount();
433
+ std::vector<const char*> inputNodeNames(numInputNodes);
434
+
435
+ std::cout << "\n--- Model Input Information ---" << std::endl;
436
+ if (numInputNodes == 0) {
437
+ std::cerr << "Error: Model has no input nodes. Exiting." << std::endl;
438
+ return 1;
439
+ }
440
+
441
+ // Assuming a single input node for simplicity
442
+ inputNodeNames[0] = "audio_embeds";
443
+ Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
444
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
445
+ std::vector<int64_t> actualInputShape = tensor_info.GetShape();
446
+
447
+ std::cout << " Input 0 : Name='" << inputNodeNames[0] << "', Shape=[";
448
+ for (size_t j = 0; j < actualInputShape.size(); ++j) {
449
+ // Print -1 for dynamic dimensions
450
+ if (actualInputShape[j] == -1) {
451
+ std::cout << "-1";
452
+ } else {
453
+ std::cout << actualInputShape[j];
454
+ }
455
+ std::cout << (j == actualInputShape.size() - 1 ? "" : ", ");
456
+ }
457
+ std::cout << "]" << std::endl;
458
+
459
+ // --- Prepare Input Tensor Shape ---
460
+ // The ONNX model input is [batch, frames, feature_size] = [-1, -1, 80]
461
+ // Our extracted features are [frames, 80]. We need to add a batch dimension of 1.
462
+ std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
463
+ std::cout << " Preparing input tensor with shape: [" << inputTensorShape[0] << ", "
464
+ << inputTensorShape[1] << ", " << inputTensorShape[2] << "]" << std::endl;
465
+
466
+ // Flatten the Eigen::MatrixXf into a std::vector<float> for ONNX Runtime.
+ // Eigen::MatrixXf is column-major by default, so copying features.data() directly would
+ // scramble the [frames, 80] layout; copy row by row to produce row-major data instead.
+ std::vector<float> inputTensorData(features.rows() * features.cols());
+ for (int r = 0; r < features.rows(); ++r)
+ for (int c = 0; c < features.cols(); ++c)
+ inputTensorData[r * features.cols() + c] = features(r, c);
468
+
469
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
470
+ Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
471
+ inputTensorShape.data(), inputTensorShape.size());
472
+
473
+ if (!inputTensor.IsTensor()) {
474
+ std::cerr << "Error: Created input tensor is not valid! Exiting." << std::endl;
475
+ return 1;
476
+ }
477
+
478
+ // --- Get Output Node Information ---
479
+ size_t numOutputNodes = session.GetOutputCount();
480
+ std::vector<const char*> outputNodeNames(numOutputNodes);
481
+
482
+ std::cout << "\n--- Model Output Information ---" << std::endl;
483
+ for (size_t k = 0; k < numOutputNodes; ++k) {
484
+ outputNodeNames[k] = "audio_features";
485
+ Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
486
+ auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
487
+ std::vector<int64_t> outputShape = tensor_info_out.GetShape();
488
+ std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
489
+ for (size_t l = 0; l < outputShape.size(); ++l) {
490
+ if (outputShape[l] == -1) {
491
+ std::cout << "-1";
492
+ } else {
493
+ std::cout << outputShape[l];
494
+ }
495
+ std::cout << (l == outputShape.size() - 1 ? "" : ", ");
496
+ }
497
+ std::cout << "]" << std::endl;
498
+ }
499
+
500
+ // --- Run Inference ---
501
+ std::cout << "\nRunning ONNX model inference..." << std::endl;
502
+ std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
503
+ inputNodeNames.data(), &inputTensor, 1,
504
+ outputNodeNames.data(), numOutputNodes);
505
+ std::ofstream output_file("f0.txt");
506
+ for (auto& ort_value : outputTensors) {
507
+ // Example: Assuming Ort::Value contains a float tensor
508
+ if (ort_value.IsTensor()) {
509
+ float* data = ort_value.GetTensorMutableData<float>();
510
+ Ort::TensorTypeAndShapeInfo info = ort_value.GetTensorTypeAndShapeInfo();
511
+ size_t num_elements = info.GetElementCount();
512
+
513
+ for (size_t i = 0; i < num_elements; ++i) {
514
+ output_file << data[i];
515
+ if (i < num_elements - 1) {
516
+ output_file << ","; // Comma separator between elements
517
+ }
518
+ }
519
+ output_file << std::endl; // Newline after each Ort::Value's content
520
+ } else {
521
+ // Handle other Ort::Value types if necessary (e.g., sequences, maps)
522
+ output_file << "Non-tensor Ort::Value" << std::endl;
523
+ }
524
+ }
525
+
526
+ output_file.close();
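+ // Note: f0.txt is only a debugging dump -- each output Ort::Value is flattened to one line
+ // of comma-separated floats so the raw encoder output can be inspected or compared externally.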
527
+
528
+
529
+ // --- Process Output ---
530
+ if (outputTensors.empty()) {
531
+ std::cerr << "Error: No output tensors received from the model." << std::endl;
532
+ return 1;
533
+ }
534
+
535
+ if (outputTensors[0].IsTensor()) {
536
+ float* outputData = outputTensors[0].GetTensorMutableData<float>();
537
+ Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
538
+ std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
539
+ size_t outputSize = outputShapeInfo.GetElementCount();
540
+
541
+ std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
542
+ for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
543
+ std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
544
+ }
545
+ std::cout << std::endl;
546
+
547
+ std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
548
+ std::cout << "Full output tensor shape: [";
549
+ for (size_t k = 0; k < outputShape.size(); ++k) {
550
+ std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
551
+ }
552
+ std::cout << "]" << std::endl;
553
+ } else {
554
+ std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
555
+ }
556
+
557
+ } catch (const Ort::Exception& e) {
558
+ std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
559
+ return 1;
560
+ } catch (const std::exception& e) {
561
+ std::cerr << "Standard Exception: " << e.what() << std::endl;
562
+ return 1;
563
+ }
564
+
565
+ std::cout << "\nProgram finished successfully." << std::endl;
566
+ return 0;
567
+ }
cpp/inference/test copy.cpp ADDED
@@ -0,0 +1,301 @@
1
+ #include <iostream>
2
+ #include <vector>
3
+ #include <fstream> // For file input/output operations (e.g., std::ifstream, std::ofstream)
4
+ #include <cstdint> // For fixed-width integer types (e.g., int16_t)
5
+ #include <cmath> // For mathematical functions (e.g., std::sin, M_PI)
6
+ #include <numeric> // For numerical operations (not strictly used in this version but often useful)
7
+ #include <algorithm> // For algorithms like std::min
8
+
9
+ // Include the ONNX Runtime C++ API header
10
+ // You need to have ONNX Runtime installed and linked correctly in your build system.
11
+ // For example, using CMake, you might add:
12
+ // find_package(ONNXRuntime REQUIRED)
13
+ // target_link_libraries(your_executable PRIVATE ONNXRuntime::onnxruntime_cxx_api)
14
+ #include <onnxruntime_cxx_api.h>
15
+
16
+ // Define M_PI if it's not already defined by cmath or your compiler.
17
+ // This is common on Windows with MSVC unless _USE_MATH_DEFINES is set.
18
+ #ifndef M_PI
19
+ #define M_PI 3.14159265358979323846
20
+ #endif
21
+
22
+
23
+ std::vector<float> loadPcmToFloatArray(const std::string& filename, int bitDepth, int numChannels) {
24
+ // Open the PCM file in binary mode for reading
25
+ std::ifstream file(filename, std::ios::binary);
26
+ if (!file.is_open()) {
27
+ std::cerr << "Error: Could not open PCM file: " << filename << std::endl;
28
+ return {}; // Return empty vector on failure
29
+ }
30
+
31
+ std::vector<float> audioData; // Vector to store the normalized float audio samples
32
+
33
+ // Check if the bit depth is supported (this example only handles 16-bit)
34
+ if (bitDepth == 16) {
35
+ int16_t sample; // Buffer to read a single 16-bit sample
36
+
37
+ // Read samples until the end of the file
38
+ while (file.read(reinterpret_cast<char*>(&sample), sizeof(sample))) {
39
+ // Normalize 16-bit signed integer to float in range [-1.0, 1.0]
40
+ // The maximum positive value for int16_t is 32767.
41
+ // Dividing by 32768.0f (which is 2^15) ensures that 32767 maps to
42
+ // slightly less than 1.0, and -32768 maps to -1.0, maintaining
43
+ // the full dynamic range and avoiding overflow for -32768.
44
+ audioData.push_back(static_cast<float>(sample) / 32768.0f);
45
+ }
46
+ } else {
47
+ std::cerr << "Error: Unsupported bit depth: " << bitDepth << ". This example only supports 16-bit PCM." << std::endl;
48
+ return {}; // Return empty vector for unsupported bit depth
49
+ }
50
+
51
+ file.close(); // Close the file
52
+ return audioData; // Return the loaded audio data
53
+ }
54
+
55
+ int main() {
56
+ // --- Configuration for Audio and ONNX Model ---
57
+ std::string pcmFilename = "/mnt/data-2t/jeff/codes/llm/cpp/sample_data/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.pcm"; // Name of the PCM audio file to load
58
+ int bitDepth = 16; // Bit depth of the PCM data (e.g., 16-bit)
59
+ int numChannels = 1; // Number of audio channels (e.g., 1 for mono)
60
+ int sampleRate = 16000; // Sample rate of the audio (e.g., 16000 Hz)
61
+ std::string onnxModelPath = "/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx"; // Path to your ONNX model file
62
+
63
+ // --- 2. Load PCM audio data into a float array ---
64
+ std::vector<float> audioInput = loadPcmToFloatArray(pcmFilename, bitDepth, numChannels);
65
+
66
+ if (audioInput.empty()) {
67
+ std::cerr << "Failed to load audio data from " << pcmFilename << ". Exiting." << std::endl;
68
+ return 1; // Exit if audio data loading failed
69
+ }
70
+
71
+ std::cout << "Successfully loaded " << audioInput.size() << " samples from " << pcmFilename << std::endl;
72
+
73
+ // --- 3. Check for ONNX model existence and provide guidance if missing ---
74
+ // This step is critical. You need a valid ONNX model.
75
+ std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
76
+ if (!onnxModelCheck.is_open()) {
77
+ std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
78
+ std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
79
+ << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
80
+ std::cerr << "```python" << std::endl;
81
+ std::cerr << "import torch" << std::endl;
82
+ std::cerr << "import torch.nn as nn" << std::endl;
83
+ std::cerr << "" << std::endl;
84
+ std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
85
+ std::cerr << " def __init__(self, input_size, output_size):" << std::endl;
86
+ std::cerr << " super(SimpleAudioModel, self).__init__()" << std::endl;
87
+ std::cerr << " # This is a very simple linear layer. Your actual model will be more complex." << std::endl;
88
+ std::cerr << " # This model expects input of shape [batch_size, input_size]" << std::endl;
89
+ std::cerr << " self.linear = nn.Linear(input_size, output_size)" << std::endl;
90
+ std::cerr << "" << std::endl;
91
+ std::cerr << " def forward(self, x):" << std::endl;
92
+ std::cerr << " # If your model expects a different input shape (e.g., [batch_size, channels, samples])," << std::endl;
93
+ std::cerr << " # you might need to reshape 'x' here before passing it to your layers (e.g., x.view(x.size(0), 1, -1))." << std::endl;
94
+ std::cerr << " return self.linear(x)" << std::endl;
95
+ std::cerr << "" << std::endl;
96
+ std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
97
+ std::cerr << "# For this dummy model, we'll assume an input size matching our 2-second, 44.1kHz mono audio." << std::endl;
98
+ std::cerr << "DUMMY_INPUT_SIZE = " << (sampleRate * 2) << " # Corresponds to " << (sampleRate * 2) / static_cast<float>(sampleRate) << " seconds of audio at " << sampleRate << " Hz mono" << std::endl;
99
+ std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
100
+ std::cerr << "" << std::endl;
101
+ std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
102
+ std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_SIZE) # Batch size 1, DUMMY_INPUT_SIZE features" << std::endl;
103
+ std::cerr << "" << std::endl;
104
+ std::cerr << "torch.onnx.export(" << std::endl;
105
+ std::cerr << " model," << std::endl;
106
+ std::cerr << " dummy_input_tensor," << std::endl;
107
+ std::cerr << " \"model.onnx\"," << std::endl;
108
+ std::cerr << " verbose=True," << std::endl;
109
+ std::cerr << " input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
110
+ std::cerr << " output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
111
+ std::cerr << " # Optional: Define dynamic axes if your batch size or sequence length can vary" << std::endl;
112
+ std::cerr << " dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}" << std::endl;
113
+ std::cerr << ")" << std::endl;
114
+ std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_SIZE in this script to match the length of your audio data or ensure your C++ code pads/truncates the audio data to the model's expected input size.\")" << std::endl;
115
+ std::cerr << "```" << std::endl;
116
+ return 1; // Exit if the ONNX model is not found
117
+ }
118
+ onnxModelCheck.close();
119
+ std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;
120
+
121
+
122
+ // --- 4. ONNX Runtime Inference ---
123
+ try {
124
+ // Create an ONNX Runtime environment. This is the entry point for all ONNX Runtime operations.
125
+ // ORT_LOGGING_LEVEL_WARNING suppresses verbose output unless there's a warning or error.
126
+ Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
127
+
128
+ // Configure session options.
129
+ Ort::SessionOptions session_options;
130
+ session_options.SetIntraOpNumThreads(1); // Use 1 thread for operations within a single node
131
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED); // Apply all available graph optimizations
132
+
133
+ // Create an ONNX Runtime session by loading the model.
134
+ Ort::Session session(env, onnxModelPath.c_str(), session_options);
135
+
136
+ // Get model input and output names and shapes.
137
+ // An allocator is needed to manage memory for allocated strings (like node names).
138
+ Ort::AllocatorWithDefaultOptions allocator;
139
+
140
+ // --- Get Input Node Information ---
141
+ size_t numInputNodes = session.GetInputCount();
142
+ std::vector<const char*> inputNodeNames(numInputNodes); // To store input node names
143
+
144
+ std::cout << "\n--- Model Input Information ---" << std::endl;
145
+ // Iterate through all input nodes (models usually have one main input)
146
+ for (size_t i = 0; i < numInputNodes; ++i) {
147
+ // Get the input node name
148
+ // Keep the Ort::AllocatedStringPtr alive; calling .get() on the temporary would leave a dangling pointer.
+ inputNameHolders.push_back(session.GetInputNameAllocated(i, allocator));
+ inputNodeNames[i] = inputNameHolders.back().get();
149
+
150
+ // Get the type and shape information for the input tensor
151
+ Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
152
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
153
+ std::vector<int64_t> actualInputShape = tensor_info.GetShape(); // Get the shape the model *expects*
154
+
155
+ std::cout << " Input " << i << " : Name='" << inputNodeNames[i] << "', Shape=[";
156
+ for (size_t j = 0; j < actualInputShape.size(); ++j) {
157
+ std::cout << actualInputShape[j] << (j == actualInputShape.size() - 1 ? "" : ", ");
158
+ }
159
+ std::cout << "]" << std::endl;
160
+
161
+ // --- Prepare Input Tensor Shape ---
162
+ // This is a CRITICAL step. The `audioInput` vector must be reshaped
163
+ // to precisely match the ONNX model's expected input tensor shape.
164
+ // The dummy Python model provided above creates an input of shape [1, DUMMY_INPUT_SIZE].
165
+ // We need to ensure `audioInput` matches `DUMMY_INPUT_SIZE` or pad/truncate it.
166
+ std::vector<int64_t> inputTensorShape; // This will be the shape of the tensor we create
167
+
168
+ if (actualInputShape.size() == 2 && actualInputShape[0] == 1) {
169
+ // Case: Model expects a 2D input with batch size 1 (e.g., [1, num_features])
170
+ int64_t expected_length = actualInputShape[1]; // The expected number of features/samples
171
+
172
+ // Check if the loaded audio data size matches the model's expected input length
173
+ if (audioInput.size() != expected_length) {
174
+ std::cout << " Warning: Loaded audio input size (" << audioInput.size()
175
+ << ") does not match model's expected input length (" << expected_length << ")." << std::endl;
176
+ std::cout << " Padding/truncating audio data to match model input size." << std::endl;
177
+ audioInput.resize(expected_length, 0.0f); // Pad with zeros or truncate the audio data
178
+ }
179
+ inputTensorShape = {1, expected_length}; // Set the tensor shape for ONNX Runtime
180
+ } else if (actualInputShape.size() == 1) {
181
+ // Case: Model expects a 1D input (e.g., [num_features])
182
+ int64_t expected_length = actualInputShape[0];
183
+
184
+ if (audioInput.size() != expected_length) {
185
+ std::cout << " Warning: Loaded audio input size (" << audioInput.size()
186
+ << ") does not match model's expected input length (" << expected_length << ")." << std::endl;
187
+ std::cout << " Padding/truncating audio data to match model input size." << std::endl;
188
+ audioInput.resize(expected_length, 0.0f); // Pad with zeros or truncate
189
+ }
190
+ inputTensorShape = {expected_length}; // Set the tensor shape for ONNX Runtime
191
+ } else {
192
+ std::cerr << "Error: Model input shape is not supported by this example ([N] or [1, N]). "
193
+ << "Please adjust the input tensor shape creation logic in C++ to match your model's specific requirements." << std::endl;
194
+ return 1; // Exit if the input shape is not handled
195
+ }
196
+
197
+ // Create an ONNX Runtime memory info object for CPU memory.
198
+ // This specifies where the tensor data is located (CPU in this case).
199
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
200
+
201
+ // Create the input tensor from the audio data.
202
+ // `audioInput.data()` provides a pointer to the raw float data.
203
+ // `audioInput.size()` is the total number of elements.
204
+ // `inputTensorShape.data()` provides the shape array.
205
+ // `inputTensorShape.size()` is the number of dimensions.
206
+ Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, audioInput.data(), audioInput.size(),
207
+ inputTensorShape.data(), inputTensorShape.size());
208
+
209
+ // Verify that the created input tensor is valid
210
+ if (!inputTensor.IsTensor()) {
211
+ std::cerr << "Error: Created input tensor is not valid! This might indicate a shape mismatch or data issue." << std::endl;
212
+ return 1; // Exit if the tensor is invalid
213
+ }
214
+
215
+ // At this point, `inputTensor` is ready to be fed into the model.
216
+ // For simplicity, we assume there's only one input to the model.
217
+ // If your model has multiple inputs, you'd need to create multiple Ort::Value objects.
218
+
219
+ // --- Get Output Node Information ---
220
+ size_t numOutputNodes = session.GetOutputCount();
221
+ std::vector<const char*> outputNodeNames(numOutputNodes); // To store output node names
+ std::vector<Ort::AllocatedStringPtr> outputNameHolders; // Keeps the allocated name strings alive
222
+
223
+ std::cout << "\n--- Model Output Information ---" << std::endl;
224
+ // Iterate through all output nodes
225
+ for (size_t k = 0; k < numOutputNodes; ++k) {
226
+ outputNameHolders.push_back(session.GetOutputNameAllocated(k, allocator));
+ outputNodeNames[k] = outputNameHolders.back().get();
227
+ Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
228
+ auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
229
+ std::vector<int64_t> outputShape = tensor_info_out.GetShape();
230
+ std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
231
+ for (size_t l = 0; l < outputShape.size(); ++l) {
232
+ std::cout << outputShape[l] << (l == outputShape.size() - 1 ? "" : ", ");
233
+ }
234
+ std::cout << "]" << std::endl;
235
+ }
236
+
237
+ // --- Run Inference ---
238
+ std::cout << "\nRunning ONNX model inference..." << std::endl;
239
+ // The `session.Run` method executes the model.
240
+ // Arguments:
241
+ // - Ort::RunOptions{nullptr}: Default run options.
242
+ // - inputNodeNames.data(): Array of C-style strings for input names.
243
+ // - &inputTensor: Pointer to the array of input tensors (here, just one).
244
+ // - 1: Number of input tensors.
245
+ // - outputNodeNames.data(): Array of C-style strings for output names.
246
+ // - numOutputNodes: Number of output tensors expected.
247
+ std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
248
+ inputNodeNames.data(), &inputTensor, 1,
249
+ outputNodeNames.data(), numOutputNodes);
250
+
251
+ // --- Process Output ---
252
+ if (outputTensors.empty()) {
253
+ std::cerr << "Error: No output tensors received from the model." << std::endl;
254
+ return 1; // Exit if no output
255
+ }
256
+
257
+ // Assuming the first output is a float tensor (common for most models)
258
+ if (outputTensors[0].IsTensor()) {
259
+ // Get a mutable pointer to the raw data of the output tensor
260
+ float* outputData = outputTensors[0].GetTensorMutableData<float>();
261
+ Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
262
+ std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
263
+ size_t outputSize = outputShapeInfo.GetElementCount(); // Total number of elements in the output tensor
264
+
265
+ std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
266
+ // Print the first 10 elements of the output (or fewer if output is smaller)
267
+ for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
268
+ std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
269
+ }
270
+ std::cout << std::endl;
271
+
272
+ std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
273
+ std::cout << "Full output tensor shape: [";
274
+ for (size_t k = 0; k < outputShape.size(); ++k) {
275
+ std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
276
+ }
277
+ std::cout << "]" << std::endl;
278
+
279
+ // Here you would typically interpret the model's output based on its purpose.
280
+ // For example:
281
+ // - For classification: Find the index of the maximum value (highest probability).
282
+ // - For regression: Use the numerical output directly.
283
+ // - For feature extraction: Use the output vector as features for further processing.
284
+ } else {
285
+ std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
286
+ }
287
+ } // End of loop for input nodes (assuming single input for simplicity in this example)
288
+
289
+ } catch (const Ort::Exception& e) {
290
+ // Catch ONNX Runtime specific exceptions
291
+ std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
292
+ return 1;
293
+ } catch (const std::exception& e) {
294
+ // Catch other standard exceptions
295
+ std::cerr << "Standard Exception: " << e.what() << std::endl;
296
+ return 1;
297
+ }
298
+
299
+ std::cout << "\nProgram finished successfully." << std::endl;
300
+ return 0;
301
+ }
cpp/inference/test.cpp ADDED
@@ -0,0 +1,702 @@
1
+ #include <iostream> // For standard input/output operations (e.g., std::cout, std::cerr)
2
+ #include <vector> // For dynamic arrays (e.g., std::vector<float>)
3
+ #include <fstream> // For file input/output operations (e.g., std::ifstream, std::ofstream)
4
+ #include <cstdint> // For fixed-width integer types (e.g., int16_t, uint32_t)
5
+ #include <cmath> // For mathematical functions (e.g., std::sin, M_PI, std::log)
6
+ #include <numeric> // For numerical operations (e.g., std::iota)
7
+ #include <algorithm> // For algorithms like std::min, std::max
8
+ #include <string> // For std::string
+ #include <cstring> // For std::memcpy (used when writing the dummy WAV header)
9
+
10
+ // Include the ONNX Runtime C++ API header
11
+ #include <onnxruntime_cxx_api.h>
12
+
13
+ // Include Eigen for powerful matrix operations.
14
+ // You need to download Eigen and set up your include paths.
15
+ // E.g., if Eigen is in 'C:/Libraries/eigen-3.4.0', you'd compile with -I C:/Libraries/eigen-3.4.0
16
+ #include <Eigen/Dense>
17
+
18
+ // Include KissFFT for Fast Fourier Transform.
19
+ // You need to download KissFFT and set up your include paths.
20
+ // E.g., if KissFFT is in 'C:/Libraries/kissfft-1.3.0', you'd compile with -I C:/Libraries/kissfft-1.3.0
21
+ // You also need to compile kiss_fft.c and kiss_fftr.c and link them.
22
+ #include <kiss_fft.h>
23
+ #include <kiss_fftr.h> // For real-valued FFT
24
+
25
+ // Define M_PI if it's not already defined by cmath or your compiler.
26
+ #ifndef M_PI
27
+ #define M_PI 3.14159265358979323846
28
+ #endif
29
+
30
+ // --- Global parameters for feature extraction (matching Python script) ---
31
+ const float PREEMPHASIS_COEFF = 0.97f;
32
+ const int N_FFT = 512; // FFT size
33
+ const int WIN_LENGTH = 400; // Window length (samples)
34
+ const int HOP_LENGTH = 160; // Hop length (samples)
35
+ const int N_MELS = 80; // Number of Mel filterbank channels
36
+ const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction
37
+
38
+ // --- WAV File Header Structures ---
39
+ // These structures are for parsing the WAV file format.
40
+ // They assume little-endian byte order, which is standard for WAV files on most systems.
41
+ #pragma pack(push, 1) // Ensure no padding for these structures
42
+
43
+ struct WavHeader {
44
+ char riff_id[4]; // Contains "RIFF"
45
+ uint32_t file_size; // Size of the overall file - 8 bytes
46
+ char wave_id[4]; // Contains "WAVE"
47
+ char fmt_id[4]; // Contains "fmt " (note the space)
48
+ uint32_t fmt_size; // Size of the fmt chunk (16 for PCM)
49
+ uint16_t audio_format; // Audio format (1 for PCM)
50
+ uint16_t num_channels; // Number of channels (1 for mono, 2 for stereo)
51
+ uint32_t sample_rate; // Sample rate (e.g., 44100 Hz)
52
+ uint32_t byte_rate; // (SampleRate * NumChannels * BitsPerSample) / 8
53
+ uint16_t block_align; // (NumChannels * BitsPerSample) / 8
54
+ uint16_t bits_per_sample;// Bits per sample (e.g., 16)
55
+ };
56
+
57
+ struct WavDataChunk {
58
+ char data_id[4]; // Contains "data"
59
+ uint32_t data_size; // Size of the data chunk
60
+ };
61
+
62
+ #pragma pack(pop) // Restore default packing alignment
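+ // With the packed layout above, WavHeader is 36 bytes and WavDataChunk is 8 bytes, so a
+ // canonical PCM WAV file has a 44-byte preamble before the samples; files carrying extra
+ // chunks (e.g. LIST/INFO) are handled by the data-chunk search in the loader below.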
63
+
64
+ /**
65
+ * @brief Loads audio data from a WAV file into a float vector.
66
+ *
67
+ * This function reads a WAV file, parses its header, extracts 16-bit signed
68
+ * integer PCM samples, converts them to floating-point values, and normalizes
69
+ * them to the range [-1.0, 1.0]. It supports mono and stereo (converting stereo to mono
70
+ * by averaging channels).
71
+ *
72
+ * @param filename The path to the WAV audio file.
73
+ * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
74
+ * @return A std::vector<float> containing the normalized mono audio samples, or an empty
75
+ * vector if the file cannot be opened or is not a supported WAV format.
76
+ */
77
+ std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) {
78
+ std::ifstream file(filename, std::ios::binary);
79
+ if (!file.is_open()) {
80
+ std::cerr << "Error: Could not open WAV file: " << filename << std::endl;
81
+ return {};
82
+ }
83
+
84
+ WavHeader header;
85
+ file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader));
86
+
87
+ // Basic header validation
88
+ if (std::string(header.riff_id, 4) != "RIFF" ||
89
+ std::string(header.wave_id, 4) != "WAVE" ||
90
+ std::string(header.fmt_id, 4) != "fmt ") {
91
+ std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl;
92
+ file.close();
93
+ return {};
94
+ }
95
+
96
+ if (header.audio_format != 1) { // 1 = PCM
97
+ std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl;
98
+ file.close();
99
+ return {};
100
+ }
101
+
102
+ if (header.bits_per_sample != 16) {
103
+ std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl;
104
+ file.close();
105
+ return {};
106
+ }
107
+
108
+ actual_sample_rate = header.sample_rate;
109
+ std::cout << "WAV file info: Sample Rate=" << header.sample_rate
110
+ << ", Channels=" << header.num_channels
111
+ << ", Bit Depth=" << header.bits_per_sample << std::endl;
112
+
113
+ // Find the "data" chunk
114
+ WavDataChunk data_chunk;
115
+ bool data_chunk_found = false;
116
+ while (!file.eof()) {
117
+ file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4);
118
+ file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4);
119
+
120
+ if (std::string(data_chunk.data_id, 4) == "data") {
121
+ data_chunk_found = true;
122
+ break;
123
+ } else {
124
+ // Skip unknown chunks
125
+ file.seekg(data_chunk.data_size, std::ios::cur);
126
+ }
127
+ }
128
+
129
+ if (!data_chunk_found) {
130
+ std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl;
131
+ file.close();
132
+ return {};
133
+ }
134
+
135
+ std::vector<float> audioData;
136
+ int16_t sample_buffer;
137
+ long num_samples_to_read = data_chunk.data_size / sizeof(int16_t);
138
+
139
+ for (long i = 0; i < num_samples_to_read; ++i) {
140
+ file.read(reinterpret_cast<char*>(&sample_buffer), sizeof(int16_t));
141
+ float normalized_sample = static_cast<float>(sample_buffer) / 32768.0f;
142
+
143
+ if (header.num_channels == 1) {
144
+ audioData.push_back(normalized_sample);
145
+ } else if (header.num_channels == 2) {
146
+ // For stereo, read both left and right, then average for mono output
147
+ // Read next sample (right channel)
148
+ int16_t right_sample;
149
+ if (file.read(reinterpret_cast<char*>(&right_sample), sizeof(int16_t))) {
150
+ float normalized_right_sample = static_cast<float>(right_sample) / 32768.0f;
151
+ audioData.push_back((normalized_sample + normalized_right_sample) / 2.0f);
152
+ i++; // Increment i again as we read two samples
153
+ } else {
154
+ std::cerr << "Warning: Unexpected end of file while reading stereo data." << std::endl;
155
+ break;
156
+ }
157
+ } else {
158
+ std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl;
159
+ file.close();
160
+ return {};
161
+ }
162
+ }
163
+
164
+ file.close();
165
+ return audioData;
166
+ }
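+ // Normalization example: an int16 sample of 16384 maps to 16384 / 32768 = 0.5,
+ // 32767 maps to ~0.99997, and -32768 maps to exactly -1.0.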
167
+
168
+ /**
169
+ * @brief Generates a Hamming window.
170
+ * @param window_length The length of the window.
171
+ * @return A std::vector<float> containing the Hamming window coefficients.
172
+ */
173
+ std::vector<float> generateHammingWindow(int window_length) {
174
+ std::vector<float> window(window_length);
175
+ for (int i = 0; i < window_length; ++i) {
176
+ window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
177
+ }
178
+ return window;
179
+ }
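+ // Endpoint check: w[0] = w[WIN_LENGTH - 1] = 0.54 - 0.46 = 0.08, rising to (almost) 1.0 at
+ // the centre of the 400-sample frame.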
180
+
181
+ /**
182
+ * @brief Extracts spectrogram features from waveform, matching Python's _extract_spectrogram.
183
+ *
184
+ * @param wav The input waveform (1D array of floats).
185
+ * @param fs The sampling rate of the waveform (fixed to 16000 Hz for this model).
186
+ * @return A 2D Eigen::MatrixXf representing the spectrogram (frames x (N_FFT/2 + 1)).
187
+ */
188
+ Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs) {
189
+ // Calculate number of frames
190
+ int n_batch = (wav.size() - WIN_LENGTH) / HOP_LENGTH + 1;
191
+ if (n_batch <= 0) {
192
+ std::cerr << "Warning: Input waveform too short for feature extraction. Returning empty spectrogram." << std::endl;
193
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
194
+ }
195
+
196
+ // Generate Hamming window once
197
+ std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
198
+
199
+ // Initialize KissFFT for real-valued input
200
+ kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
201
+ if (!fft_cfg) {
202
+ std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
203
+ return Eigen::MatrixXf(0, N_FFT / 2 + 1);
204
+ }
205
+
206
+ // Output spectrogram matrix: rows = frames, columns = FFT bins
207
+ Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);
208
+
209
+ std::vector<float> frame_buffer(WIN_LENGTH);
210
+ kiss_fft_scalar fft_input[N_FFT]; // KissFFT requires input buffer of size N_FFT
211
+ kiss_fft_cpx fft_output[N_FFT / 2 + 1]; // KissFFT real output size
212
+
213
+ for (int i = 0; i < n_batch; ++i) {
214
+ int start_idx = i * HOP_LENGTH;
215
+
216
+ // Extract current frame
217
+ for (int j = 0; j < WIN_LENGTH; ++j) {
218
+ frame_buffer[j] = wav[start_idx + j];
219
+ }
220
+
221
+ // Apply pre-emphasis and scale by 32768, matching the Python preprocessor:
+ //   y_frames_prev = np.roll(y_frames, 1, axis=1); y_frames_prev[:, 0] = y_frames_prev[:, 1]
+ //   y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
+ // Per sample: the first sample of each frame uses its *next* sample as the "previous" value
+ // (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]); every later sample j uses
+ // frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]. If WIN_LENGTH == 1 there is
+ // nothing to subtract, so only the 32768 scaling is applied.
239
+ if (WIN_LENGTH > 0) {
240
+ if (WIN_LENGTH > 1) {
241
+ fft_input[0] = (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]) * 32768.0f;
242
+ for (int j = 1; j < WIN_LENGTH; ++j) {
243
+ fft_input[j] = (frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]) * 32768.0f;
244
+ }
245
+ } else { // WIN_LENGTH == 1
246
+ fft_input[0] = frame_buffer[0] * 32768.0f;
247
+ }
248
+ }
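+ // Worked example: with PREEMPHASIS_COEFF = 0.97, a frame of constant value 0.5 becomes
+ // (0.5 - 0.97 * 0.5) * 32768 = 0.015 * 32768 = 491.52 for every sample, including sample 0
+ // (whose "previous" value is taken from sample 1).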
249
+ // Zero-pad the rest of the FFT input if WIN_LENGTH < N_FFT
250
+ for (int j = WIN_LENGTH; j < N_FFT; ++j) {
251
+ fft_input[j] = 0.0f;
252
+ }
253
+
254
+ // Apply Hamming window
255
+ for (int j = 0; j < WIN_LENGTH; ++j) {
256
+ fft_input[j] *= fft_window[j];
257
+ }
258
+
259
+ // Perform real FFT
260
+ kiss_fftr(fft_cfg, fft_input, fft_output);
261
+
262
+ // Calculate magnitude spectrogram
263
+ for (int j = 0; j <= N_FFT / 2; ++j) {
264
+ spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
265
+ }
266
+ }
267
+
268
+ kiss_fftr_free(fft_cfg); // Free KissFFT configuration
269
+ return spec_matrix;
270
+ }
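+ // Shape check: N_FFT = 512 gives 512 / 2 + 1 = 257 bins per frame, and a 2-second, 16 kHz
+ // clip (32000 samples) gives (32000 - 400) / 160 + 1 = 198 frames, so the returned
+ // spectrogram is 198 x 257.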
271
+
272
+ /**
273
+ * @brief Creates a Mel filter-bank matrix, matching Python's speechlib_mel.
274
+ *
275
+ * @param sample_rate Sample rate in Hz.
276
+ * @param n_fft FFT size.
277
+ * @param n_mels Mel filter size.
278
+ * @param fmin Lowest frequency (in Hz).
279
+ * @param fmax Highest frequency (in Hz).
280
+ * @return An Eigen::MatrixXf representing the Mel transform matrix (n_mels x (1 + n_fft/2)).
281
+ */
282
+ Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
283
+ int bank_width = n_fft / 2 + 1;
284
+ if (fmax == 0.0f) fmax = sample_rate / 2.0f; // Use 0.0f as a sentinel for None
285
+ // fmin already uses 0.0f as its "None" default, so no adjustment is needed here.
286
+
287
+ // Helper functions for Mel scale conversion
288
+ auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
289
+ auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
290
+ auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };
291
+
292
+ // Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax)]
293
+ int klo = f2bin(fmin) + 1;
294
+ int khi = f2bin(fmax);
295
+ khi = std::max(khi, klo);
296
+
297
+ // Spec 2: SpeechLib uses triangles in Mel space
298
+ float mlo = mel(fmin);
299
+ float mhi = mel(fmax);
300
+
301
+ // Generate Mel centers
302
+ std::vector<float> m_centers(n_mels + 2);
303
+ float ms = (mhi - mlo) / (n_mels + 1);
304
+ for (int i = 0; i < n_mels + 2; ++i) {
305
+ m_centers[i] = mlo + i * ms;
306
+ }
307
+
308
+ Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);
309
+
310
+ for (int m = 0; m < n_mels; ++m) {
311
+ float left = m_centers[m];
312
+ float center = m_centers[m + 1];
313
+ float right = m_centers[m + 2];
314
+ for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) { // Loop up to bank_width-1
315
+ float mbin = bin2mel(fft_bin);
316
+ if (left < mbin && mbin < right) {
317
+ matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
318
+ }
319
+ }
320
+ }
321
+ return matrix;
322
+ }
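+ // Example values with the constants above: mel(700) = 1127 * ln(2) ~ 781.2; with fmin = 0 and
+ // the fmax used in main() (7690 Hz), the triangles are spread over mel 0 .. ~2799 in
+ // N_MELS + 1 equal steps.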
323
+
324
+ /**
325
+ * @brief Extracts log filterbank features from waveform, matching Python's _extract_features.
326
+ *
327
+ * @param wav The input waveform (1D array of floats).
328
+ * @param fs The sampling rate of the waveform (fixed to 16000 Hz).
329
+ * @param mel_filterbank The pre-computed Mel filterbank matrix.
330
+ * @return An Eigen::MatrixXf representing the log Mel filterbank features (frames x N_MELS).
331
+ */
332
+ Eigen::MatrixXf extractFeatures(const std::vector<float>& wav, int fs, const Eigen::MatrixXf& mel_filterbank) {
333
+ // Extract spectrogram
334
+ Eigen::MatrixXf spec = extractSpectrogram(wav, fs);
335
+ if (spec.rows() == 0) {
336
+ return Eigen::MatrixXf(0, N_MELS); // Return empty matrix if spectrogram extraction failed
337
+ }
338
+
339
+ // spec_power = spec**2
340
+ Eigen::MatrixXf spec_power = spec.array().square();
341
+
342
+ // fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)
+ // Python's `dot` on 2D arrays is matrix multiplication:
+ //   (frames, N_FFT/2+1) . (N_FFT/2+1, N_MELS) -> (frames, N_MELS)
+ // speechlibMel() returns the filterbank as (N_MELS, N_FFT/2+1), so transpose it here to get
+ // the (N_FFT/2+1, N_MELS) operand that this product requires.
+ Eigen::MatrixXf fbank_power = spec_power * mel_filterbank.transpose();
349
+
350
+ // Apply clipping: np.clip(..., 1.0, None)
351
+ // This means any value less than 1.0 becomes 1.0.
352
+ fbank_power = fbank_power.array().max(1.0f);
353
+
354
+ // log_fbank = np.log(fbank_power).astype(np.float32)
355
+ Eigen::MatrixXf log_fbank = fbank_power.array().log();
356
+
357
+ return log_fbank;
358
+ }
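+ // The returned matrix is [frames x N_MELS] (N_MELS = 80). Because fbank_power is clipped at
+ // 1.0 before the log, every log-mel value is >= log(1.0) = 0, matching the Python
+ // preprocessor's np.clip(..., 1.0, None).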
359
+
360
+ // Function to write a dummy WAV file
361
+ void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
362
+ std::ofstream file(filename, std::ios::binary);
363
+ if (!file.is_open()) {
364
+ std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
365
+ return;
366
+ }
367
+
368
+ WavHeader header;
369
+ std::memcpy(header.riff_id, "RIFF", 4);
370
+ std::memcpy(header.wave_id, "WAVE", 4);
371
+ std::memcpy(header.fmt_id, "fmt ", 4);
372
+ header.fmt_size = 16;
373
+ header.audio_format = 1; // PCM
374
+ header.num_channels = numChannels;
375
+ header.sample_rate = sampleRate;
376
+ header.bits_per_sample = bitsPerSample;
377
+ header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
378
+ header.block_align = (numChannels * bitsPerSample) / 8;
379
+
380
+ WavDataChunk data_chunk;
381
+ std::memcpy(data_chunk.data_id, "data", 4);
382
+ uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
383
+ data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
384
+ header.file_size = 36 + data_chunk.data_size; // 36 is size of header before data chunk
385
+
386
+ file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
387
+ file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));
388
+
389
+ // Generate a 440 Hz sine wave
390
+ for (uint32_t i = 0; i < num_samples; ++i) {
391
+ int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
392
+ for (int c = 0; c < numChannels; ++c) {
393
+ file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
394
+ }
395
+ }
396
+
397
+ file.close();
398
+ std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl;
399
+ }
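+ // For the 2-second, 16 kHz, mono, 16-bit dummy file created in main(): data_size =
+ // 32000 samples * 2 bytes = 64000 bytes and file_size = 36 + 64000 = 64036 bytes.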
400
+
401
+
402
+ int main(int argc, char* argv[]) {
403
+ // --- 1. Process command-line arguments ---
404
+ if (argc != 3) {
405
+ std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file>" << std::endl;
406
+ std::cerr << "Example: " << argv[0] << " model.onnx audio.wav" << std::endl;
407
+ return 1;
408
+ }
409
+
410
+ std::string onnxModelPath = argv[1];
411
+ std::string wavFilename = argv[2]; // Changed to wavFilename
412
+
413
+ // --- Configuration for Audio and ONNX Model ---
414
+ // These are fixed by the Python preprocessor code and model requirements.
415
+ // The actual sample rate will be read from the WAV file.
416
+ int actual_wav_sample_rate = 0;
417
+
418
+ // --- Create a dummy WAV file if it doesn't exist for demonstration ---
419
+ std::ifstream wavCheck(wavFilename, std::ios::binary);
420
+ if (!wavCheck.is_open()) {
421
+ std::cerr << "WAV file '" << wavFilename << "' not found. Creating a dummy one for demonstration." << std::endl;
422
+ // Create a 2-second, 16kHz, mono, 16-bit WAV file
423
+ createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, 2.0);
424
+ } else {
425
+ wavCheck.close();
426
+ }
427
+
428
+ // --- 2. Load WAV audio data into a float array ---
429
+ std::vector<float> audioWav = loadWavToFloatArray(wavFilename, actual_wav_sample_rate);
430
+
431
+ if (audioWav.empty()) {
432
+ std::cerr << "Failed to load audio data from " << wavFilename << ". Exiting." << std::endl;
433
+ return 1;
434
+ }
435
+
436
+ std::cout << "Successfully loaded " << audioWav.size() << " samples from " << wavFilename << std::endl;
437
+
438
+ // --- Validate WAV sample rate against target sample rate ---
439
+ if (actual_wav_sample_rate != TARGET_SAMPLE_RATE) {
440
+ std::cerr << "Warning: WAV file sample rate (" << actual_wav_sample_rate
441
+ << " Hz) does not match the target sample rate for feature extraction ("
442
+ << TARGET_SAMPLE_RATE << " Hz)." << std::endl;
443
+ std::cerr << "This example does NOT include resampling. Features will be extracted at "
444
+ << TARGET_SAMPLE_RATE << " Hz, which might lead to incorrect results if the WAV file's sample rate is different." << std::endl;
445
+ // In a real application, you would implement resampling here (e.g., using libsamplerate).
446
+ }
447
+
448
+
449
+ // --- 3. Precompute Mel filterbank (as it's constant for a given sample rate/FFT size) ---
450
+ // The Python example uses fmax=16000//2-80-230. This translates to TARGET_SAMPLE_RATE/2 - 80 - 230.
451
+ // Using 0.0f for fmin as sentinel for None.
452
+ float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
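+ // With TARGET_SAMPLE_RATE = 16000 this evaluates to 8000 - 80 - 230 = 7690 Hz.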
453
+ Eigen::MatrixXf mel_filterbank = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);
454
+
455
+ if (mel_filterbank.rows() == 0 || mel_filterbank.cols() == 0) {
456
+ std::cerr << "Error: Failed to create Mel filterbank. Exiting." << std::endl;
457
+ return 1;
458
+ }
459
+ std::cout << "Mel filterbank created with shape: [" << mel_filterbank.rows() << ", " << mel_filterbank.cols() << "]" << std::endl;
460
+
461
+
462
+ // --- 4. Apply feature extraction (preprocessor) ---
463
+ std::cout << "Extracting features from audio..." << std::endl;
464
+ Eigen::MatrixXf features = extractFeatures(audioWav, TARGET_SAMPLE_RATE, mel_filterbank);
465
+
466
+ ///// check input
467
+ // std::ofstream outputFile("matrix_output.txt");
468
+ // // Check if the file was opened successfully
469
+ // if (outputFile.is_open()) {
470
+ // // Iterate through rows and columns to write elements
471
+ // for (int i = 0; i < features.rows(); ++i) {
472
+ // for (int j = 0; j < features.cols(); ++j) {
473
+ // outputFile << features(i, j); // Write the element
474
+ // if (j < features.cols() - 1) {
475
+ // outputFile << ","; // Add a space separator between elements in a row
476
+ // }
477
+ // }
478
+ // outputFile << std::endl; // Move to the next line after each row
479
+ // }
480
+ // outputFile.close(); // Close the file
481
+ // std::cout << "Matrix successfully written to matrix_output.txt" << std::endl;
482
+ // }
483
+
484
+ if (features.rows() == 0 || features.cols() == 0) {
485
+ std::cerr << "Error: Feature extraction resulted in an empty matrix. Exiting." << std::endl;
486
+ return 1;
487
+ }
488
+ std::cout << "Features extracted with shape: [" << features.rows() << ", " << features.cols() << "]" << std::endl;
489
+ std::cout << "First few feature values (first frame): [";
490
+ for (int i = 0; i < std::min((int)features.cols(), 5); ++i) {
491
+ std::cout << features(0, i) << (i == std::min((int)features.cols(), 5) - 1 ? "" : ", ");
492
+ }
493
+ std::cout << "]" << std::endl;
494
+
495
+ // --- 5. Check for ONNX model existence and provide guidance if missing ---
496
+ std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
497
+ if (!onnxModelCheck.is_open()) {
498
+ std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
499
+ std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
500
+ << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
501
+ std::cerr << "```python" << std::endl;
502
+ std::cerr << "import torch" << std::endl;
503
+ std::cerr << "import torch.nn as nn" << std::endl;
504
+ std::cerr << "" << std::endl;
505
+ std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
506
+ std::cerr << " def __init__(self, input_frames, feature_size, output_size):" << std::endl;
507
+ std::cerr << " super(SimpleAudioModel, self).__init__()" << std::endl;
508
+ std::cerr << " # This model expects input of shape [batch_size, frames, feature_size]" << std::endl;
509
+ std::cerr << " # Example: a simple linear layer that flattens input and processes it." << std::endl;
510
+ std::cerr << " self.flatten = nn.Flatten()" << std::endl;
511
+ std::cerr << " self.linear = nn.Linear(input_frames * feature_size, output_size)" << std::endl;
512
+ std::cerr << "" << std::endl;
513
+ std::cerr << " def forward(self, x):" << std::endl;
514
+ std::cerr << " x = self.flatten(x)" << std::endl;
515
+ std::cerr << " return self.linear(x)" << std::endl;
516
+ std::cerr << "" << std::endl;
517
+ std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
518
+ std::cerr << "# The C++ preprocessor will produce features of shape [frames, 80]." << std::endl;
519
+ std::cerr << "# For a dummy model, we need to provide a fixed 'frames' value for ONNX export." << std::endl;
520
+ std::cerr << "# A typical audio segment might be 2 seconds at 16kHz, which is 32000 samples." << std::endl;
521
+ std::cerr << "# Frames = (32000 - 400) / 160 + 1 = 197.5 + 1 -> 198 frames (integer division)" << std::endl;
522
+ std::cerr << "# Let's use a representative number of frames, e.g., 200 for a dummy input." << std::endl;
523
+ std::cerr << "DUMMY_INPUT_FRAMES = 200 # This should be representative of your typical audio segment's frames" << std::endl;
524
+ std::cerr << "DUMMY_FEATURE_SIZE = 80 # Fixed by the Mel filterbank (N_MELS)" << std::endl;
525
+ std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
526
+ std::cerr << "" << std::endl;
527
+ std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
528
+ std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE) # Batch size 1" << std::endl;
529
+ std::cerr << "" << std::endl;
530
+ std::cerr << "torch.onnx.export(" << std::endl;
531
+ std::cerr << " model," << std::endl;
532
+ std::cerr << " dummy_input_tensor," << std::endl;
533
+ std::cerr << " \"model.onnx\"," << std::endl;
534
+ std::cerr << " verbose=True," << std::endl;
535
+ std::cerr << " input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
536
+ std::cerr << " output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
537
+ std::cerr << " # Define dynamic axes for batch_size and frames" << std::endl;
538
+ std::cerr << " dynamic_axes={'input': {0: 'batch_size', 1: 'frames'}, 'output': {0: 'batch_size'}}" << std::endl;
539
+ std::cerr << ")" << std::endl;
540
+ std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_FRAMES in this script to match the expected number of frames from your audio segments.\")" << std::endl;
541
+ std::cerr << "```" << std::endl;
542
+ return 1;
543
+ }
544
+ onnxModelCheck.close();
545
+ std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;
546
+
547
+
548
+ // --- 6. ONNX Runtime Inference ---
549
+ try {
550
+ Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
551
+ Ort::SessionOptions session_options;
552
+ session_options.SetIntraOpNumThreads(1);
553
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
554
+
555
+ Ort::Session session(env, onnxModelPath.c_str(), session_options);
556
+ Ort::AllocatorWithDefaultOptions allocator;
557
+
558
+ // --- Get Input Node Information ---
559
+ size_t numInputNodes = session.GetInputCount();
560
+ std::vector<const char*> inputNodeNames(numInputNodes);
561
+
562
+ std::cout << "\n--- Model Input Information ---" << std::endl;
563
+ if (numInputNodes == 0) {
564
+ std::cerr << "Error: Model has no input nodes. Exiting." << std::endl;
565
+ return 1;
566
+ }
567
+
568
+ // Assuming a single input node for simplicity
569
+ inputNodeNames[0] = "audio_embeds";
570
+ Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
571
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
572
+ std::vector<int64_t> actualInputShape = tensor_info.GetShape();
573
+
574
+ std::cout << " Input 0 : Name='" << inputNodeNames[0] << "', Shape=[";
575
+ for (size_t j = 0; j < actualInputShape.size(); ++j) {
576
+ // Print -1 for dynamic dimensions
577
+ if (actualInputShape[j] == -1) {
578
+ std::cout << "-1";
579
+ } else {
580
+ std::cout << actualInputShape[j];
581
+ }
582
+ std::cout << (j == actualInputShape.size() - 1 ? "" : ", ");
583
+ }
584
+ std::cout << "]" << std::endl;
585
+
586
+ // --- Prepare Input Tensor Shape ---
587
+ // The ONNX model input is [batch, frames, feature_size] = [-1, -1, 80]
588
+ // Our extracted features are [frames, 80]. We need to add a batch dimension of 1.
589
+ std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
590
+ std::cout << " Preparing input tensor with shape: [" << inputTensorShape[0] << ", "
591
+ << inputTensorShape[1] << ", " << inputTensorShape[2] << "]" << std::endl;
592
+
593
+ // Flatten the Eigen::MatrixXf into a std::vector<float> for ONNX Runtime
594
+ // Eigen stores in column-major order by default. ONNX Runtime expects row-major
595
+ // for flattened 2D data when reshaped to 3D [1, frames, features].
596
+ // We need to copy elements row by row to ensure correct order.
597
+ std::vector<float> inputTensorData(features.rows() * features.cols());
598
+ for (int r = 0; r < features.rows(); ++r) {
599
+ for (int c = 0; c < features.cols(); ++c) {
600
+ inputTensorData[r * features.cols() + c] = features(r, c);
601
+ }
602
+ }
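+ // Equivalent sketch (not used here): assigning to a row-major Eigen matrix performs the same
+ // reordering, since Eigen converts storage order element-wise on assignment:
+ //   Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> rowMajor = features;
+ //   std::vector<float> flat(rowMajor.data(), rowMajor.data() + rowMajor.size());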
603
+
604
+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
605
+ Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
606
+ inputTensorShape.data(), inputTensorShape.size());
607
+
608
+ if (!inputTensor.IsTensor()) {
609
+ std::cerr << "Error: Created input tensor is not valid! Exiting." << std::endl;
610
+ return 1;
611
+ }
612
+
613
+ // --- Get Output Node Information ---
614
+ size_t numOutputNodes = session.GetOutputCount();
615
+ std::vector<const char*> outputNodeNames(numOutputNodes);
616
+
617
+ std::cout << "\n--- Model Output Information ---" << std::endl;
618
+ for (size_t k = 0; k < numOutputNodes; ++k) {
619
+ outputNodeNames[k] = "audio_features";
620
+ Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
621
+ auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
622
+ std::vector<int64_t> outputShape = tensor_info_out.GetShape();
623
+ std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
624
+ for (size_t l = 0; l < outputShape.size(); ++l) {
625
+ if (outputShape[l] == -1) {
626
+ std::cout << "-1";
627
+ } else {
628
+ std::cout << outputShape[l];
629
+ }
630
+ std::cout << (l == outputShape.size() - 1 ? "" : ", ");
631
+ }
632
+ std::cout << "]" << std::endl;
633
+ }
634
+
635
+ // --- Run Inference ---
636
+ std::cout << "\nRunning ONNX model inference..." << std::endl;
637
+ std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
638
+ inputNodeNames.data(), &inputTensor, 1,
639
+ outputNodeNames.data(), numOutputNodes);
640
+
641
+
642
+ // std::ofstream output_file("f0.txt");
643
+ // for (auto& ort_value : outputTensors) {
644
+ // // Example: Assuming Ort::Value contains a float tensor
645
+ // if (ort_value.IsTensor()) {
646
+ // float* data = ort_value.GetTensorMutableData<float>();
647
+ // Ort::TensorTypeAndShapeInfo info = ort_value.GetTensorTypeAndShapeInfo();
648
+ // size_t num_elements = info.GetElementCount();
649
+
650
+ // for (size_t i = 0; i < num_elements; ++i) {
651
+ // output_file << data[i];
652
+ // if (i < num_elements - 1) {
653
+ // output_file << ","; // Space separator between elements
654
+ // }
655
+ // }
656
+ // output_file << std::endl; // Newline after each Ort::Value's content
657
+ // } else {
658
+ // // Handle other Ort::Value types if necessary (e.g., sequences, maps)
659
+ // output_file << "Non-tensor Ort::Value" << std::endl;
660
+ // }
661
+ // }
662
+ // output_file.close();
663
+
664
+ // --- Process Output ---
665
+ if (outputTensors.empty()) {
666
+ std::cerr << "Error: No output tensors received from the model." << std::endl;
667
+ return 1;
668
+ }
669
+
670
+ if (outputTensors[0].IsTensor()) {
671
+ float* outputData = outputTensors[0].GetTensorMutableData<float>();
672
+ Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
673
+ std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
674
+ size_t outputSize = outputShapeInfo.GetElementCount();
675
+
676
+ std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
677
+ for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
678
+ std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
679
+ }
680
+ std::cout << std::endl;
681
+
682
+ std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
683
+ std::cout << "Full output tensor shape: [";
684
+ for (size_t k = 0; k < outputShape.size(); ++k) {
685
+ std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
686
+ }
687
+ std::cout << "]" << std::endl;
688
+ } else {
689
+ std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
690
+ }
691
+
692
+ } catch (const Ort::Exception& e) {
693
+ std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
694
+ return 1;
695
+ } catch (const std::exception& e) {
696
+ std::cerr << "Standard Exception: " << e.what() << std::endl;
697
+ return 1;
698
+ }
699
+
700
+ std::cout << "\nProgram finished successfully." << std::endl;
701
+ return 0;
702
+ }