Upload 71 files
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +9 -0
- cpp/ASRDataset.py +794 -0
- cpp/__pycache__/ASRDataset.cpython-310.pyc +0 -0
- cpp/__pycache__/speech_conformer_encoder.cpython-310.pyc +0 -0
- cpp/convert_onnx.ipynb +767 -0
- cpp/convert_tensorRT.ipynb +0 -0
- cpp/gemma_v1/ASRDataset.py +793 -0
- cpp/gemma_v1/__pycache__/ASRDataset.cpython-312.pyc +0 -0
- cpp/gemma_v1/added_tokens.json +3 -0
- cpp/gemma_v1/chat_template.json +3 -0
- cpp/gemma_v1/config.json +118 -0
- cpp/gemma_v1/configuration_gemma3omni.py +206 -0
- cpp/gemma_v1/eval.py +635 -0
- cpp/gemma_v1/eval_multiturn.ipynb +0 -0
- cpp/gemma_v1/eval_multiturn.py +211 -0
- cpp/gemma_v1/merge_lora.ipynb +119 -0
- cpp/gemma_v1/model-00001-of-00003.safetensors +3 -0
- cpp/gemma_v1/model-00002-of-00003.safetensors +3 -0
- cpp/gemma_v1/model-00003-of-00003.safetensors +3 -0
- cpp/gemma_v1/model.safetensors.index.json +0 -0
- cpp/gemma_v1/modeling_gemma3omni.py +668 -0
- cpp/gemma_v1/preprocessing_gemma3omni.py +444 -0
- cpp/gemma_v1/preprocessor_config.json +41 -0
- cpp/gemma_v1/processor_config.json +7 -0
- cpp/gemma_v1/special_tokens_map.json +36 -0
- cpp/gemma_v1/speech_conformer_encoder.py +0 -0
- cpp/gemma_v1/speech_conformer_encoder_old.py +0 -0
- cpp/gemma_v1/tokenizer.json +3 -0
- cpp/gemma_v1/tokenizer.model +3 -0
- cpp/gemma_v1/tokenizer_config.json +0 -0
- cpp/gemma_v1/training.py +883 -0
- cpp/gemma_v1/training_multiturn.py +329 -0
- cpp/gemma_v1/training_multiturn_textonly.py +333 -0
- cpp/inference/audio_encoder_lib.cpp +388 -0
- cpp/inference/audio_encoder_lib.h +141 -0
- cpp/inference/audio_encoder_lib.o +0 -0
- cpp/inference/audio_inference +0 -0
- cpp/inference/audio_inference_app +0 -0
- cpp/inference/compile.sh +32 -0
- cpp/inference/dummy.wav +0 -0
- cpp/inference/f0.txt +0 -0
- cpp/inference/f_inp.txt +0 -0
- cpp/inference/kiss_fft.o +0 -0
- cpp/inference/kiss_fftr.o +0 -0
- cpp/inference/main_text.cpp +165 -0
- cpp/inference/matrix_output.txt +0 -0
- cpp/inference/run.sh +7 -0
- cpp/inference/test copy 2.cpp +567 -0
- cpp/inference/test copy.cpp +301 -0
- cpp/inference/test.cpp +702 -0
.gitattributes
CHANGED
@@ -34,3 +34,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 tokenizer.json filter=lfs diff=lfs merge=lfs -text
+cpp/gemma_v1/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.pcm filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.wav filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382475-breezyvoice-01452.wav filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.pcm filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.wav filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382594-breezyvoice-00389.pcm filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data_old/pickup_breezy-common_voice_zh-TW_17382594-breezyvoice-00389.wav filter=lfs diff=lfs merge=lfs -text
+cpp/sample_data/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.wav filter=lfs diff=lfs merge=lfs -text
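The patterns above route the tokenizer and the sample-audio files through Git LFS, so a plain clone without LFS support only yields pointer files. As a minimal sketch (not part of this commit), one of the tracked samples could also be fetched directly with huggingface_hub; the repo id below is a placeholder for whatever repository this commit belongs to.

# Sketch only: fetch one LFS-backed sample audio file from the Hub cache.
from huggingface_hub import hf_hub_download

wav_path = hf_hub_download(
    repo_id="<user>/<repo>",  # placeholder: the repository this commit was pushed to
    filename="cpp/sample_data/pickup_breezy-common_voice_zh-TW_17382570-breezyvoice-01041.wav",
)
print(wav_path)  # local path of the downloaded audio file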
cpp/ASRDataset.py
ADDED
@@ -0,0 +1,794 @@
import datasets
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
import os
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'

import json
import os
from pathlib import Path

import numpy as np
import torch
import sacrebleu

from datasets import load_dataset
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from transformers import (
    BatchFeature,
)
import pandas as pd
import soundfile as sf
from datasets import Audio
import random
from copy import deepcopy
import torchaudio

ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100

class BaseAudioDataset(Dataset):
    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        self.processor = processor
        self.training = "train" in split or 'other' in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        self.name = name

    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
        original_size = len(data)

        data = data.cast_column(audio_field, Audio(decode=False))

        def identify_corrupted_files(example):
            try:
                sf.read(example[audio_field]["path"])

                for field in text_fields:
                    if field in example and example[field].replace('"', '') == "":
                        return False
                return True
            except Exception:
                return False

        data = data.filter(identify_corrupted_files, num_proc=16)
        validated_size = len(data)

        # Audio Decoding
        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

        if debug:
            print(f"Dataset: {dataset_name}")
            print(f"Original data nums: {original_size}")
            print(f"After filtering data nums: {validated_size}")
            print(f"Filtering ratio: {validated_size/original_size:.2%}")

        return data

    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
        original_size = len(data)

        def filter_audio_by_length(example):
            try:
                audio = example[audio_field]['array']
                channel = 1
                if hasattr(audio, 'ndim') and audio.ndim > 1:
                    channel = audio.ndim
                    audio = audio.squeeze()
                audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
                return min_sec <= audio_length <= max_sec
            except Exception as e:
                if debug:
                    print(f"Error : {str(e)[:100]}... - sample excluded")
                return False

        data = data.filter(filter_audio_by_length, num_proc=16)
        filtered_size = len(data)

        if debug:
            print(f"Before Length Filtering data nums: {original_size}")
            print(f"After Length Filtering data nums: {filtered_size}")
            print(f"Filtering ratio: {filtered_size/original_size:.2%}")

        return data

    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        user_message = {
            'role': 'user',
            'content': '<start_of_audio>' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )

        inputs = self.processor(
            text=prompt,
            audio=[audio_array],
            add_special_tokens=False,
            return_tensors='pt'
        )

        answer = f"{answer_text}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids

        if self.debug:
            self.debug = False
            task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
            lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
            print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
            print(f"INPUT_MODE: {inputs.input_modes[0].item()}")

        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
            token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
            token_type_ids = inputs.token_type_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'token_type_ids': token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
        }

# Libri Speech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"LibriSpeech_{subset}")
        # only ASR
        self.ast = False
        self.lang = "en"

        # load dataset
        self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
                                 subset,
                                 split=split,
                                 trust_remote_code=True,
                                 cache_dir=Path("/mnt/jeff/InCar/data")
                                 )

        # (Optional) Audio length Filtering
        self.data = self.filter_by_audio_length(self.data, "audio")

        # Instruction Setting
        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        # Libri Speech is only for ASR
        answer_text = data["text"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )

# common_voice_16_1 dataset
class CommonVoiceDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"CommonVoice_{source_lang}")
        # only ASR
        self.ast = False
        self.lang = source_lang

        # load dataset
        if source_lang == "zh-TW":
            data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
        else:
            data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
        self.data = load_dataset(data_path,
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True,
                                 cache_dir=Path("/mnt/jeff/InCar/data")
                                 )

        def prepare_dataset(batch):
            """Function to preprocess the dataset with the .map method"""
            transcription = batch["sentence"]

            if transcription.startswith('"') and transcription.endswith('"'):
                # we can remove trailing quotation marks as they do not affect the transcription
                transcription = transcription[1:-1]

            if transcription[-1] not in [".", "?", "!"]:
                # append a full-stop to sentences that do not end in punctuation
                transcription = transcription + "."

            batch["sentence"] = transcription

            return batch

        import opencc
        converter = opencc.OpenCC('s2tw.json')
        def To_zhTW(batch):

            transcription = converter.convert(batch["sentence"])
            batch["sentence"] = transcription

            return batch

        self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
        if source_lang == 'zh-CN':
            self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")

        # (Optional) Audio length Filtering
        self.data = self.filter_by_audio_length(self.data, "audio")

        if source_lang == "zh-TW" and split == 'train':
            import torchaudio
            from torchaudio import transforms
            import copy
            import pickle
            import os
            def subsample(batch):
                batch['audio']['array'] = torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
                batch['audio']['sampling_rate'] = 16000
                return batch
            def TW_data_augment_fast(batch):
                speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
                new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
                batch['audio']['array'] = new_array_fast
                return batch
            def TW_data_augment_slow(batch):
                speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
                new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
                batch['audio']['array'] = new_array_slow
                return batch
            # data = self.data.map(subsample, num_proc=1, desc="subsample")
            fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
            if not os.path.exists(fast_path):
                data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
                with open(fast_path, 'wb') as f:
                    pickle.dump(data_fast, f)
            else:
                with open(fast_path, 'rb') as f:
                    data_fast = pickle.load(f)

            slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
            if not os.path.exists(slow_path):
                data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
                with open(slow_path, 'wb') as f:
                    pickle.dump(data_slow, f)
            else:
                with open(slow_path, 'rb') as f:
                    data_slow = pickle.load(f)
            self.data = [d for d in self.data] + [d for d in data_fast] + [d for d in data_slow]

        # Instruction Setting
        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        answer_text = data["sentence"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )


# Fleurs Dataset Class
class FleursDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, target_lang=None,
                 mode="asr", sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("Fleurs")
        # Mode Setting (ASR or AST)
        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")

        self.mode = mode
        self.ast = (mode == "ast")
        self.source_lang = source_lang

        # Language name mapping (expand if needed)
        self.lang_names = {
            'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
        }

        # load dataset - source language dataset
        self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True,
                                 cache_dir=Path("/mnt/jeff/InCar/data")
                                 )
        import opencc
        converter = opencc.OpenCC('s2tw.json')
        def prepare_dataset(batch):
            transcription = converter.convert(batch["transcription"])
            batch["transcription"] = transcription

            return batch
        if source_lang == "cmn_hans_cn":
            self.data = self.data.map(prepare_dataset, desc="preprocess dataset")

        # (Optional) Audio length Filtering
        self.data = self.filter_by_audio_length(self.data, "audio")
        self.target_lang_name = ""
        # When AST mode, load target language dataset.
        if self.ast:
            if target_lang is None:
                raise ValueError("AST mode requires target_lang.")

            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"

            # load dataset - target language dataset (for translation)
            target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
                                       target_lang,
                                       split=split,
                                       trust_remote_code=True,
                                       cache_dir=Path("/mnt/jeff/InCar/data")
                                       )
            if target_lang == "cmn_hans_cn":
                target_data = target_data.map(prepare_dataset, desc="preprocess dataset")
            source_dict = {item['id']: item for item in self.data}
            target_dict = {item['id']: item for item in target_data}

            # only Common ID, add translation fields
            common_ids = set(source_dict.keys()) & set(target_dict.keys())
            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
            self.data = [
                {**source_dict[id], 'translation': target_dict[id]['transcription']}
                for id in common_ids
            ]

            # Instruction Setting - use target language name
            self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = random.choice(INSTRUCTION["ast"])
        else:
            # ASR mode
            self.lang = source_lang
            self.instruction = random.choice(INSTRUCTION["asr"])

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        audio_array = data["audio"]["array"]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["transcription"]

        return self.prepare_model_inputs(
            audio_array,
            self.instruction.format(self.target_lang_name),
            answer_text
        )

class TWCostumData(BaseAudioDataset):

    def __init__(self, processor, split="train", sampling_rate=16000, csv_path="", debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        import pandas as pd
        from datasets import Dataset, Audio

        df = pd.read_csv(csv_path).fillna('')

        self.set_dataset_name(f"TWCostumData")
        self.data = Dataset.from_dict(
            {
                "audio": [audio for audio in df['audio']],
                "sentence": [text for text in df['text']]
            }
        ).cast_column("audio", Audio(sampling_rate=16000))

        # Instruction Setting
        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        answer_text = data["sentence"]
        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )

def covost_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    token_type_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    audio_paths = []
    for inputs in batch:
        if 'audio_path' in inputs:
            audio_paths.append(inputs['audio_path'])
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        token_type_ids_list.append(inputs['token_type_ids'][0])
        if inputs['input_modes'] == 2:
            input_audio_embeds_list.append(inputs['input_audio_embeds'])
            audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
            audio_attention_mask_list.append(
                inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
            )
        # else:
        #     input_audio_embeds_list.append(None)
        #     audio_embed_sizes_list.append(None)
        #     audio_attention_mask_list.append(None)
        input_modes_list.append(inputs['input_modes'])
    # try:
    token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
    input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
    labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
    audio_attention_mask = (
        pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
        if len(audio_attention_mask_list) > 1
        else None
    )
    # except Exception as e:
    #     print(e)
    #     print(input_ids_list)
    #     print(labels_list)
    #     raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list) > 0 else None
    audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list) > 0 else None
    input_modes = torch.cat(input_modes_list)
    if len(audio_paths) > 0:
        return BatchFeature(
            {
                "audio_path": audio_paths,
                'input_ids': input_ids,
                'labels': labels,
                'token_type_ids': token_type_ids,
                'attention_mask': attention_mask,
                'input_audio_embeds': input_audio_embeds,
                'audio_embed_sizes': audio_embed_sizes,
                'audio_attention_mask': audio_attention_mask,
                'input_modes': input_modes,
            }
        )
    else:
        return BatchFeature(
            {
                'input_ids': input_ids,
                'labels': labels,
                'token_type_ids': token_type_ids,
                'attention_mask': attention_mask,
                'input_audio_embeds': input_audio_embeds,
                'audio_embed_sizes': audio_embed_sizes,
                'audio_attention_mask': audio_attention_mask,
                'input_modes': input_modes,
            }
        )

def pad_sequence(sequences, padding_side='left', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output

def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        output[slices] = t
        index += t.shape[dim]

    return output


class MultiturnAudioDataset(BaseAudioDataset):
    def __init__(self, processor, split="train", sampling_rate=16000, json_path="", text_only=False, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        from llamafactory.data.template import Llama2Template, parse_template
        from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
        from llamafactory.data.mm_plugin import get_mm_plugin
        import json
        self.train = False
        self.text_only = text_only
        with open(json_path) as f:
            js_data = json.load(f)
        if split == 'train':
            self.train = True
            js_data = js_data[:int(len(js_data)*0.8)]
        else:
            js_data = js_data[-int(len(js_data)*0.2):]
        for conv in js_data:
            for mess in conv['conversations']:
                if 'audio_path' in mess:
                    mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/', '/mnt/jeff/InCar/data/multiturn_data/')
        default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
        self.template = Llama2Template(
            format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
            format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
            format_system=StringFormatter(slots=["{{content}}\n\n"]),
            format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
            format_tools=ToolFormatter(tool_format="default"),
            format_observation=StringFormatter(
                slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
            ),
            default_system=default_system,
            thought_words=("<think>", "</think>"),
            efficient_eos=False,
            replace_eos=False,
            replace_jinja_template=False,
            format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
            stop_words=["<end_of_turn>"],
            mm_plugin=get_mm_plugin(name="base"),
            enable_thinking=False
        )

        self.set_dataset_name(f"MultiturnCostumData")

        self.data = []
        self.text_only_data = []
        for conv in js_data:
            tools = conv['tools'] if 'tools' in conv else ""
            system = conv['system'] if 'system' in conv else default_system
            tmp = {
                'tools': tools,
                'system': system,
                'messages': [],
            }
            for i, mess in enumerate(conv['conversations']):
                tmp['messages'].append(mess)
                if mess['from'] == 'human':
                    tmp['messages'].append(conv['conversations'][i+1])
                    d = deepcopy(tmp)
                    d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
                    self.data.append(d)
                    if self.text_only:
                        self.text_only_data.append(deepcopy(tmp))
                    tmp['messages'].pop()
                elif mess['from'] == 'observation':
                    tmp['messages'].append(conv['conversations'][i+1])
                    d = deepcopy(tmp)
                    self.text_only_data.append(d)
                    tmp['messages'].pop()
        if text_only:
            self.data = self.text_only_data

    def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
        ANSWER_SUFFIX = "<end_of_turn>"
        prompt = ""
        answer_text = ""
        user_transcribe = ""
        audio_paths = []
        for i, message in enumerate(messages):
            elements = []

            system_text = ""
            if i == 0:
                elements += self.template.format_prefix.apply()
                if system or tools:
                    tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
                    system_text = self.template.format_system.apply(content=(system + tool_text))[0]

            if message["from"] == "human":
                if i == len(messages)-2 and not self.text_only:
                    user_transcribe = message["value"]
                    elements += self.template.format_user.apply(content=system_text + '<start_of_audio>')
                else:
                    elements += self.template.format_user.apply(content=system_text + message["value"])
                audio_paths.append(message['audio_path'])
            elif message["from"] == "gpt":
                elements += self.template.format_assistant.apply(content=message["value"])
            elif message["from"] == "observation":
                elements += self.template.format_observation.apply(content=message["value"])
            elif message["from"] == "function_call":
                elements += self.template.format_function.apply(content=message["value"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["from"]))

            for elem in elements:
                ele_str = ""
                if isinstance(elem, str):
                    ele_str = elem
                elif isinstance(elem, set):
                    if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
                        ele_str = self.processor.tokenizer.bos_token
                    elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
                        ele_str = self.processor.tokenizer.eos_token
                if i == len(messages)-1:
                    answer_text += ele_str
                else:
                    prompt += ele_str

        if type(audio_array) != type(None):
            inputs = self.processor(
                text=prompt,
                audio=[audio_array],
                add_special_tokens=False,
                return_tensors='pt'
            )
            answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe, answer_text, ANSWER_SUFFIX)
        else:
            inputs = self.processor(
                text=prompt,
                audio=None,
                add_special_tokens=False,
                return_tensors='pt'
            )
            answer = f"{answer_text}{ANSWER_SUFFIX}"
        # print('user_transcribe',user_transcribe)
        # print('answer_text', answer)
        # print('prompt',prompt)
        answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids

        if self.debug:
            self.debug = False
            task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
            lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
            print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
            print(f"INPUT_MODE: {inputs.input_modes[0].item()}")

        if self.training:
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
            token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
            token_type_ids = inputs.token_type_ids
        if type(audio_array) != type(None):
            if not self.train:
                return {
                    "audio_path": audio_paths,
                    'input_ids': input_ids,
                    'labels': labels,
                    'token_type_ids': token_type_ids,
                    'input_audio_embeds': inputs.input_audio_embeds,
                    'audio_embed_sizes': inputs.audio_embed_sizes,
                    'input_modes': inputs.input_modes,
                }
            else:
                return {
                    'input_ids': input_ids,
                    'labels': labels,
                    'token_type_ids': token_type_ids,
                    'input_audio_embeds': inputs.input_audio_embeds,
                    'audio_embed_sizes': inputs.audio_embed_sizes,
                    'input_modes': inputs.input_modes,
                }
        else:
            return {
                'input_ids': input_ids,
                'labels': labels,
                'token_type_ids': token_type_ids,
                'input_audio_embeds': None,
                'audio_embed_sizes': None,
                'input_modes': inputs.input_modes,
            }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        return self.prepare_multiturn_model_inputs(
            audio_array=data["audio_array"] if "audio_array" in data else None,
            messages=data['messages'],
            system=data["system"],
            tools=data["tools"]
        )


os.environ["TOKENIZERS_PARALLELISM"] = "false"

INSTRUCTION = {
    "ast": [
        "Translate the audio to {0}.",
        "Translate the audio clip into {0}.",
        "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
        "Translate the provided audio file into {0}.",
        "Convert the audio speech to {0} text.",
        "Write an {0} translation of the audio file.",
        "Translate spoken words from the audio into {0}.",
        "Create an {0} version of the audio content.",
        "Produce an accurate {0} translation of the audio.",
        "Extract speech from the audio and translate it to {0}.",
        "Turn the audio into readable {0} text.",
        "Write all spoken content from the audio in {0}.",
        "Generate an {0} translation of the speech in the file.",
        "Convert the recording into {0} text.",
        "Accurately translate the audio recording to {0}.",
        "Write down dialogue from the given audio in {0}.",
        "Translate all speech in this audio file to {0}.",
        "Create an accurate {0} version of the speech.",
        "Perform a complete {0} translation of the audio."
    ],
    "asr": [
        "Transcribe the audio clip into text.",
        "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
        "Transcribe the provided audio file into text.",
        "Convert the audio speech to text.",
        "Write a transcript of the audio file.",
        "Transcribe spoken words from the audio.",
        "Create a text version of the audio content.",
        "Produce a verbatim transcript of the audio.",
        "Extract and transcribe speech from the audio.",
        "Turn the audio into readable text.",
        "Write all spoken words from the audio.",
        "Generate a transcript of the speech in the file.",
        "Convert the recording into a text transcript.",
        "Accurately transcribe the audio recording.",
        "Write down dialogue from the given audio.",
        "Transcribe all speech in this audio file.",
        "Create an accurate text version of the speech.",
        "Perform a complete transcription of the audio."
    ],
}
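For context, the sketch below shows one way the dataset classes and covost_collate_fn from ASRDataset.py might be wired into a training DataLoader. It is an illustrative sketch, not part of the uploaded file: loading the processor with AutoProcessor.from_pretrained("gemma_v1", trust_remote_code=True), the LibriSpeech subset/split names, and running from the cpp/ directory are assumptions, and the hard-coded /mnt/jeff/... data paths inside the classes must exist for it to run.

# Illustrative sketch only (assumes cpp/ is the working directory and that the
# gemma_v1/ folder supplies the custom audio-text processor via trust_remote_code).
from torch.utils.data import ConcatDataset, DataLoader
from transformers import AutoProcessor

from ASRDataset import CommonVoiceDataset, LibriSpeechDataset, covost_collate_fn

processor = AutoProcessor.from_pretrained("gemma_v1", trust_remote_code=True)

# Mix an English ASR corpus with zh-TW Common Voice, as the classes above support.
train_set = ConcatDataset([
    LibriSpeechDataset(processor, subset="clean", split="train.100"),  # subset/split names are assumptions
    CommonVoiceDataset(processor, split="train", source_lang="zh-TW"),
])

# covost_collate_fn left-pads input_ids / labels / token_type_ids and concatenates
# the audio features with cat_with_pad, so variable-length samples batch cleanly.
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=covost_collate_fn)

batch = next(iter(loader))
print(batch["input_ids"].shape, batch["input_modes"])

Note that for text-only batches covost_collate_fn returns None for input_audio_embeds, audio_embed_sizes, and audio_attention_mask, so downstream code has to tolerate missing audio tensors.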
cpp/__pycache__/ASRDataset.cpython-310.pyc
ADDED
Binary file (22.5 kB)

cpp/__pycache__/speech_conformer_encoder.cpython-310.pyc
ADDED
Binary file (79.5 kB)

cpp/convert_onnx.ipynb
ADDED
@@ -0,0 +1,767 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/foxconnhy/miniconda3/envs/llamafactory/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/jeff/.cache/huggingface/modules/transformers_modules/gemma_v1/speech_conformer_encoder.py:2798: FutureWarning: Please specify CheckpointImpl.NO_REENTRANT as CheckpointImpl.REENTRANT will soon be removed as the default and eventually deprecated.\n",
      " lambda i: encoder_checkpoint_wrapper(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "######################## speech lora #############\n",
      "######################## text lora #############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00, 1.80it/s]\n",
"Some weights of Gemma3OmniForConditionalGeneration were not initialized from the model checkpoint at /mnt/data-2t/jeff/codes/llm/cpp/gemma_v1 and are newly initialized: ['language_model.model.base_model.model.layers.0.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.0.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.1.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.10.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.10.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.11.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.12.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.o_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.13.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.14.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.15.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.16.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.16.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.17.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.18.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.k_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.19.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.2.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.20.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.21.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.21.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.22.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.23.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.mlp.up_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.24.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.25.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.26.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.27.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.27.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.28.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.29.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.gate_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.3.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.30.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.31.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_A.text.weight', 
'language_model.model.base_model.model.layers.32.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.32.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.33.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.4.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.down_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.5.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.6.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.7.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.down_proj.lora_B.text.weight', 
'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.8.self_attn.v_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.down_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.gate_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.mlp.up_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.k_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.o_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.q_proj.lora_B.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_A.text.weight', 'language_model.model.base_model.model.layers.9.self_attn.v_proj.lora_B.text.weight']\n",
|
| 32 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
| 33 |
+
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n"
|
| 34 |
+
]
|
| 35 |
+
}
|
| 36 |
+
],
|
| 37 |
+
"source": [
|
| 38 |
+
"from io import BytesIO\n",
|
| 39 |
+
"import torch\n",
|
| 40 |
+
"import numpy as np\n",
|
| 41 |
+
"from transformers import AutoModel, AutoProcessor, BatchFeature,Gemma3ForCausalLM,Gemma3Processor\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"# converter = opencc.OpenCC('s2tw.json')\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"model_id = \"/mnt/data-2t/jeff/codes/llm/cpp/gemma_v1\"\n",
|
| 46 |
+
"revision = \"main\" #\"v1.0\"\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"model = AutoModel.from_pretrained(\n",
|
| 49 |
+
" model_id, device_map=\"cpu\", revision = revision, trust_remote_code=True\n",
|
| 50 |
+
").eval()\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"processor = AutoProcessor.from_pretrained(\n",
|
| 53 |
+
" model_id, revision = revision, trust_remote_code=True\n",
|
| 54 |
+
")"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 2,
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [
|
| 62 |
+
{
|
| 63 |
+
"data": {
|
| 64 |
+
"text/plain": [
|
| 65 |
+
"Sequential(\n",
|
| 66 |
+
" (0): Linear(in_features=1024, out_features=2560, bias=True)\n",
|
| 67 |
+
" (1): GELU(approximate='none')\n",
|
| 68 |
+
" (2): Linear(in_features=2560, out_features=2560, bias=True)\n",
|
| 69 |
+
")"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
"execution_count": 2,
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"output_type": "execute_result"
|
| 75 |
+
}
|
| 76 |
+
],
|
| 77 |
+
"source": [
|
| 78 |
+
"model.audio_tower\n",
|
| 79 |
+
"model.audio_projector"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": 179,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"from ASRDataset import *\n",
|
| 89 |
+
"pickup_dataset = MultiturnAudioDataset(split='train',processor=processor,json_path='/mnt/data-2t/jeff/codes/llm/cpp/sample_data/pickup_processed.json')"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": 180,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [
|
| 97 |
+
{
|
| 98 |
+
"name": "stdout",
|
| 99 |
+
"output_type": "stream",
|
| 100 |
+
"text": [
|
| 101 |
+
"torch.Size([1, 256, 80])\n",
|
| 102 |
+
"torch.Size([1, 217, 80])\n",
|
| 103 |
+
"torch.Size([1, 77, 80])\n",
|
| 104 |
+
"torch.Size([1, 580, 80])\n"
|
| 105 |
+
]
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"source": [
|
| 109 |
+
"for i in range(len(pickup_dataset)):\n",
|
| 110 |
+
" inp = pickup_dataset.__getitem__(i)\n",
|
| 111 |
+
" print(inp['input_audio_embeds'].shape)"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "code",
|
| 116 |
+
"execution_count": 5,
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [
|
| 119 |
+
{
|
| 120 |
+
"data": {
|
| 121 |
+
"text/plain": [
|
| 122 |
+
"torch.Size([1, 100, 2560])"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
"execution_count": 5,
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"output_type": "execute_result"
|
| 128 |
+
}
|
| 129 |
+
],
|
| 130 |
+
"source": [
|
| 131 |
+
"inp = pickup_dataset.__getitem__(3)\n",
|
| 132 |
+
"fea,mask = model.audio_tower(inp['input_audio_embeds'],torch.ones(inp['input_audio_embeds'].shape[:2]))\n",
|
| 133 |
+
"model.audio_projector(fea).shape"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [],
|
| 141 |
+
"source": [
|
| 142 |
+
"import torch \n",
|
| 143 |
+
"import torch.nn as nn\n",
|
| 144 |
+
"from speech_conformer_encoder import ConformerEncoder\n",
|
| 145 |
+
"class Gemma3AudioEncoder(nn.Module):\n",
|
| 146 |
+
" def __init__(self,):\n",
|
| 147 |
+
" super().__init__()\n",
|
| 148 |
+
" audio_config = model.config.audio_config.to_diff_dict()\n",
|
| 149 |
+
" for item in ['transformers_version', 'model_type', 'torch_dtype']:\n",
|
| 150 |
+
" if item in audio_config:\n",
|
| 151 |
+
" audio_config.pop(item)\n",
|
| 152 |
+
" # self.audio_tower = model.audio_tower\n",
|
| 153 |
+
" # self.audio_projector = model.audio_projector\n",
|
| 154 |
+
" self.audio_tower = ConformerEncoder(**audio_config)#model.audio_tower\n",
|
| 155 |
+
" self.audio_projector = nn.Sequential(\n",
|
| 156 |
+
" nn.Linear(in_features=1024, out_features=2560, bias=True),\n",
|
| 157 |
+
" nn.GELU(approximate='none'),\n",
|
| 158 |
+
" nn.Linear(in_features=2560, out_features=2560, bias=True))#model.audio_projector\n",
|
| 159 |
+
" def forward(self,x,mask):\n",
|
| 160 |
+
" # mask = torch.ones(x.shape[:2])\n",
|
| 161 |
+
" x,_ = self.audio_tower(x,mask)\n",
|
| 162 |
+
" x = self.audio_projector(x)\n",
|
| 163 |
+
" return x\n",
|
| 164 |
+
"audio_encoder = Gemma3AudioEncoder()\n",
|
| 165 |
+
"import copy\n",
|
| 166 |
+
"audio_encoder.audio_tower.encoder_embedding=copy.deepcopy(model.audio_tower.encoder_embedding)\n",
|
| 167 |
+
"audio_encoder.audio_projector.load_state_dict(model.audio_projector.state_dict())\n",
|
| 168 |
+
"audio_encoder.audio_tower.load_state_dict(model.audio_tower.state_dict())"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "code",
|
| 173 |
+
"execution_count": null,
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"outputs": [],
|
| 176 |
+
"source": [
|
| 177 |
+
"import numpy as np\n",
|
| 178 |
+
"import onnx\n",
|
| 179 |
+
"import onnxruntime as ort\n",
|
| 180 |
+
"import onnxscript\n",
|
| 181 |
+
"import os\n",
|
| 182 |
+
"import requests\n",
|
| 183 |
+
"import shutil\n",
|
| 184 |
+
"import soundfile\n",
|
| 185 |
+
"import subprocess\n",
|
| 186 |
+
"import sys\n",
|
| 187 |
+
"import torch\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"from onnx import helper, numpy_helper, TensorProto\n",
|
| 190 |
+
"from onnxruntime_genai.models.builder import create_model\n",
|
| 191 |
+
"from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper\n",
|
| 192 |
+
"from onnxscript import ir\n",
|
| 193 |
+
"from torch.export import Dim, export\n",
|
| 194 |
+
"def build_speech(outputdir='./onnx_files'):\n",
|
| 195 |
+
" # TorchScript export\n",
|
| 196 |
+
" dummy_inputs = (\n",
|
| 197 |
+
" torch.randn((1,97,80)),\n",
|
| 198 |
+
" torch.ones((1,97))\n",
|
| 199 |
+
" #inputs[\"input_audio_embeds\"], # audio_embeds: torch.FloatTensor\n",
|
| 200 |
+
" #inputs[\"audio_attention_mask\"], # audio_attention_mask: torch.BoolTensor\n",
|
| 201 |
+
" # inputs[\"audio_embed_sizes\"], # audio_sizes: torch.LongTensor\n",
|
| 202 |
+
" # inputs[\"input_mode\"], # audio_projection_mode: int\n",
|
| 203 |
+
" )\n",
|
| 204 |
+
" filename = \"phi-4-mm-speech.onnx\"\n",
|
| 205 |
+
"\n",
|
| 206 |
+
" temp_folder_1 = os.path.join(outputdir, \"speech_init_export\")\n",
|
| 207 |
+
" os.makedirs(temp_folder_1, exist_ok=True)\n",
|
| 208 |
+
"\n",
|
| 209 |
+
" fpath_1 = os.path.join(temp_folder_1, filename)\n",
|
| 210 |
+
" torch._dynamo.config.capture_scalar_outputs = True\n",
|
| 211 |
+
" onnx_program = torch.onnx.export(audio_encoder, dummy_inputs, fpath_1,\n",
|
| 212 |
+
" input_names=[\"audio_embeds\", \"audio_attention_mask\"], \n",
|
| 213 |
+
" output_names=[\"audio_features\"],\n",
|
| 214 |
+
" opset_version=20,\n",
|
| 215 |
+
" dynamic_axes={\n",
|
| 216 |
+
" \"audio_embeds\": {0:'B',1: \"L\"},\n",
|
| 217 |
+
" \"audio_attention_mask\": {0:'B',1: \"L\"},\n",
|
| 218 |
+
" },\n",
|
| 219 |
+
" )\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"build_speech()"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"cell_type": "code",
|
| 226 |
+
"execution_count": 44,
|
| 227 |
+
"metadata": {},
|
| 228 |
+
"outputs": [],
|
| 229 |
+
"source": [
|
| 230 |
+
"import onnxruntime as ort\n",
|
| 231 |
+
"import numpy as np\n",
|
| 232 |
+
"ort_sess = ort.InferenceSession(\"/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx\")"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": 45,
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [
|
| 240 |
+
{
|
| 241 |
+
"data": {
|
| 242 |
+
"text/plain": [
|
| 243 |
+
"(2, 111, 2560)"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
"execution_count": 45,
|
| 247 |
+
"metadata": {},
|
| 248 |
+
"output_type": "execute_result"
|
| 249 |
+
}
|
| 250 |
+
],
|
| 251 |
+
"source": [
|
| 252 |
+
"import warnings\n",
|
| 253 |
+
"warnings.filterwarnings('ignore')\n",
|
| 254 |
+
"from tqdm import tqdm\n",
|
| 255 |
+
"import torch\n",
|
| 256 |
+
"import numpy as np\n",
|
| 257 |
+
"a=[]\n",
|
| 258 |
+
"# for i in tqdm(range(10000)):\n",
|
| 259 |
+
"# try:\n",
|
| 260 |
+
"ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(1,97,80),dtype=np.float32),\n",
|
| 261 |
+
" # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n",
|
| 262 |
+
" }\n",
|
| 263 |
+
" )\n",
|
| 264 |
+
" # print(i)\n",
|
| 265 |
+
" # a.append(i)\n",
|
| 266 |
+
" # except:\n",
|
| 267 |
+
" # pass\n",
|
| 268 |
+
"ort_sess.run(None, {\"audio_embeds\": np.array(torch.randn(2,888,80),dtype=np.float32),\n",
|
| 269 |
+
" # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n",
|
| 270 |
+
" }\n",
|
| 271 |
+
" )[0].shape"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "markdown",
|
| 276 |
+
"metadata": {},
|
| 277 |
+
"source": [
|
| 278 |
+
"# Python inference time check"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"cell_type": "code",
|
| 283 |
+
"execution_count": null,
|
| 284 |
+
"metadata": {},
|
| 285 |
+
"outputs": [],
|
| 286 |
+
"source": [
|
| 287 |
+
"import time\n",
|
| 288 |
+
"total = 0\n",
|
| 289 |
+
"_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
|
| 290 |
+
"for i in range(100):\n",
|
| 291 |
+
" now = time.time()\n",
|
| 292 |
+
" inp = np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1)#np.array(torch.randn(1,np.random.randint(100,300),80),dtype=np.float32)\n",
|
| 293 |
+
" inp = _extract_features(inp,16000).reshape(1,-1,80)\n",
|
| 294 |
+
" now = time.time()\n",
|
| 295 |
+
" # inp = np.array(torch.randn(1,150,80),dtype=np.float32)\n",
|
| 296 |
+
" ort_sess.run(None, {\"audio_embeds\": inp,\n",
|
| 297 |
+
" # \"audio_attention_mask\":np.ones((1,97),dtype=np.float32)\n",
|
| 298 |
+
" })\n",
|
| 299 |
+
" total += time.time()-now\n",
|
| 300 |
+
" \n",
|
| 301 |
+
"total,total/100"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"cell_type": "code",
|
| 306 |
+
"execution_count": 238,
|
| 307 |
+
"metadata": {},
|
| 308 |
+
"outputs": [
|
| 309 |
+
{
|
| 310 |
+
"data": {
|
| 311 |
+
"text/plain": [
|
| 312 |
+
"24240"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
"execution_count": 238,
|
| 316 |
+
"metadata": {},
|
| 317 |
+
"output_type": "execute_result"
|
| 318 |
+
}
|
| 319 |
+
],
|
| 320 |
+
"source": [
|
| 321 |
+
"149*160+400"
|
| 322 |
+
]
|
| 323 |
+
},
|
| 324 |
+
{
|
| 325 |
+
"cell_type": "code",
|
| 326 |
+
"execution_count": 233,
|
| 327 |
+
"metadata": {},
|
| 328 |
+
"outputs": [
|
| 329 |
+
{
|
| 330 |
+
"data": {
|
| 331 |
+
"text/plain": [
|
| 332 |
+
"(1, 40917)"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
"execution_count": 233,
|
| 336 |
+
"metadata": {},
|
| 337 |
+
"output_type": "execute_result"
|
| 338 |
+
}
|
| 339 |
+
],
|
| 340 |
+
"source": [
|
| 341 |
+
"np.random.randn(np.random.randint(16240, 48240)).reshape(1,-1).shape"
|
| 342 |
+
]
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"cell_type": "code",
|
| 346 |
+
"execution_count": null,
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"outputs": [
|
| 349 |
+
{
|
| 350 |
+
"data": {
|
| 351 |
+
"text/plain": [
|
| 352 |
+
"(10.608245611190796, 0.10608245611190796)"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
"execution_count": 218,
|
| 356 |
+
"metadata": {},
|
| 357 |
+
"output_type": "execute_result"
|
| 358 |
+
}
|
| 359 |
+
],
|
| 360 |
+
"source": [
|
| 361 |
+
"import time\n",
|
| 362 |
+
"total = 0\n",
|
| 363 |
+
"for i in range(100):\n",
|
| 364 |
+
" tmp = torch.randn(1,np.random.randint(100,300),80)\n",
|
| 365 |
+
" mask = torch.ones(tmp.shape[:2])\n",
|
| 366 |
+
" now = time.time()\n",
|
| 367 |
+
" audio_encoder(tmp,mask)\n",
|
| 368 |
+
" total += time.time()-now\n",
|
| 369 |
+
"total,total/100"
|
| 370 |
+
]
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
"cell_type": "markdown",
|
| 374 |
+
"metadata": {},
|
| 375 |
+
"source": [
|
| 376 |
+
"# C++ ERROR check"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": 167,
|
| 382 |
+
"metadata": {},
|
| 383 |
+
"outputs": [
|
| 384 |
+
{
|
| 385 |
+
"data": {
|
| 386 |
+
"text/plain": [
|
| 387 |
+
"tensor([[[ 0.3246, 0.0295, 0.1076, ..., -0.1125, -0.0894, -0.3800],\n",
|
| 388 |
+
" [ 0.3267, -0.2442, 0.2653, ..., 0.7783, -0.6049, -1.0858],\n",
|
| 389 |
+
" [ 0.1797, 0.0438, 0.9673, ..., 0.5126, -0.5657, -0.7050],\n",
|
| 390 |
+
" ...,\n",
|
| 391 |
+
" [ 0.0261, -0.0324, 0.0230, ..., -0.1303, 0.0343, 0.1486],\n",
|
| 392 |
+
" [ 0.1655, -0.3327, 0.4232, ..., 0.0513, 0.4222, -0.3645],\n",
|
| 393 |
+
" [ 0.1147, -0.1201, 0.4198, ..., 0.6170, 0.0838, -0.1409]]],\n",
|
| 394 |
+
" grad_fn=<ViewBackward0>)"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
"execution_count": 167,
|
| 398 |
+
"metadata": {},
|
| 399 |
+
"output_type": "execute_result"
|
| 400 |
+
}
|
| 401 |
+
],
|
| 402 |
+
"source": [
|
| 403 |
+
"inp = pickup_dataset.__getitem__(3)\n",
|
| 404 |
+
"fea,mask = model.audio_tower(inp['input_audio_embeds'],torch.ones(inp['input_audio_embeds'].shape[:2]))\n",
|
| 405 |
+
"model.audio_projector(fea)"
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"cell_type": "code",
|
| 410 |
+
"execution_count": 201,
|
| 411 |
+
"metadata": {},
|
| 412 |
+
"outputs": [
|
| 413 |
+
{
|
| 414 |
+
"data": {
|
| 415 |
+
"text/plain": [
|
| 416 |
+
"array([[[ 0.13004433, -0.06643961, 0.01333247, ..., -0.05643693,\n",
|
| 417 |
+
" -0.23922557, 0.569423 ],\n",
|
| 418 |
+
" [-0.75552 , -0.05047493, -0.82725084, ..., 0.32261163,\n",
|
| 419 |
+
" -0.14968234, -0.7078437 ],\n",
|
| 420 |
+
" [-0.6673857 , 0.33906737, -0.6191502 , ..., 0.04259709,\n",
|
| 421 |
+
" -0.01194861, 0.27635992],\n",
|
| 422 |
+
" ...,\n",
|
| 423 |
+
" [ 0.02916821, -0.03163592, 0.02736526, ..., -0.12979224,\n",
|
| 424 |
+
" 0.03317374, 0.15346158],\n",
|
| 425 |
+
" [-0.8559882 , -0.5196625 , 0.2549707 , ..., 0.28192428,\n",
|
| 426 |
+
" 1.4099622 , -0.15940394],\n",
|
| 427 |
+
" [-0.20253824, -0.30478072, -0.6786582 , ..., 0.08860758,\n",
|
| 428 |
+
" -0.12145798, 0.525889 ]]], dtype=float32)"
|
| 429 |
+
]
|
| 430 |
+
},
|
| 431 |
+
"execution_count": 201,
|
| 432 |
+
"metadata": {},
|
| 433 |
+
"output_type": "execute_result"
|
| 434 |
+
}
|
| 435 |
+
],
|
| 436 |
+
"source": [
|
| 437 |
+
"inp = pickup_dataset.__getitem__(0)\n",
|
| 438 |
+
"res = ort_sess.run(None, {\"audio_embeds\": np.array(inp['input_audio_embeds'],dtype=np.float32),\n",
|
| 439 |
+
" # \"audio_attention_mask\":np.ones((2,97),dtype=np.float32)\n",
|
| 440 |
+
" }\n",
|
| 441 |
+
" )[0]\n",
|
| 442 |
+
"res"
|
| 443 |
+
]
|
| 444 |
+
},
|
| 445 |
+
{
|
| 446 |
+
"cell_type": "code",
|
| 447 |
+
"execution_count": 208,
|
| 448 |
+
"metadata": {},
|
| 449 |
+
"outputs": [
|
| 450 |
+
{
|
| 451 |
+
"data": {
|
| 452 |
+
"text/plain": [
|
| 453 |
+
"array([[[ 0.130969 , -0.0697925, 0.0150866, ..., -0.0559536,\n",
|
| 454 |
+
" -0.239062 , 0.567436 ],\n",
|
| 455 |
+
" [-0.753288 , -0.0582227, -0.825365 , ..., 0.320587 ,\n",
|
| 456 |
+
" -0.153626 , -0.709664 ],\n",
|
| 457 |
+
" [-0.656874 , 0.342632 , -0.607641 , ..., 0.0383743,\n",
|
| 458 |
+
" -0.0218912, 0.269968 ],\n",
|
| 459 |
+
" ...,\n",
|
| 460 |
+
" [ 0.0291714, -0.0316175, 0.027369 , ..., -0.129825 ,\n",
|
| 461 |
+
" 0.033166 , 0.153453 ],\n",
|
| 462 |
+
" [-0.854555 , -0.530883 , 0.258313 , ..., 0.279057 ,\n",
|
| 463 |
+
" 1.40658 , -0.159066 ],\n",
|
| 464 |
+
" [-0.197598 , -0.306157 , -0.67907 , ..., 0.0915015,\n",
|
| 465 |
+
" -0.124402 , 0.52159 ]]])"
|
| 466 |
+
]
|
| 467 |
+
},
|
| 468 |
+
"execution_count": 208,
|
| 469 |
+
"metadata": {},
|
| 470 |
+
"output_type": "execute_result"
|
| 471 |
+
}
|
| 472 |
+
],
|
| 473 |
+
"source": [
|
| 474 |
+
"f = open('/mnt/data-2t/jeff/codes/llm/cpp/inference/f0.txt')\n",
|
| 475 |
+
"content = f.readlines()\n",
|
| 476 |
+
"f.close()\n",
|
| 477 |
+
"audio_fea_cpp = np.array([float(i) for i in content[0].split(',')]).reshape(1,-1,2560)\n",
|
| 478 |
+
"audio_fea_cpp"
|
| 479 |
+
]
|
| 480 |
+
},
|
| 481 |
+
{
|
| 482 |
+
"cell_type": "code",
|
| 483 |
+
"execution_count": 202,
|
| 484 |
+
"metadata": {},
|
| 485 |
+
"outputs": [
|
| 486 |
+
{
|
| 487 |
+
"data": {
|
| 488 |
+
"text/plain": [
|
| 489 |
+
"(array([[[0.917797, 1.33496 , 1.9894 , ..., 6.60723 , 6.95787 ,\n",
|
| 490 |
+
" 7.20139 ],\n",
|
| 491 |
+
" [0. , 0. , 0. , ..., 5.99914 , 6.11214 ,\n",
|
| 492 |
+
" 6.40908 ],\n",
|
| 493 |
+
" [0. , 0. , 0. , ..., 5.1184 , 5.36291 ,\n",
|
| 494 |
+
" 5.14623 ],\n",
|
| 495 |
+
" ...,\n",
|
| 496 |
+
" [0. , 0. , 0. , ..., 6.25256 , 6.29312 ,\n",
|
| 497 |
+
" 7.05511 ],\n",
|
| 498 |
+
" [0. , 0. , 0. , ..., 6.49829 , 6.7198 ,\n",
|
| 499 |
+
" 7.08144 ],\n",
|
| 500 |
+
" [0. , 0. , 1.08376 , ..., 5.43068 , 5.97577 ,\n",
|
| 501 |
+
" 6.35748 ]]]),\n",
|
| 502 |
+
" tensor([[[0.8826, 1.3054, 1.9652, ..., 6.6069, 6.9578, 7.2011],\n",
|
| 503 |
+
" [0.0000, 0.0000, 0.0000, ..., 5.9991, 6.1121, 6.4091],\n",
|
| 504 |
+
" [0.0000, 0.0000, 0.0000, ..., 5.1147, 5.3624, 5.1428],\n",
|
| 505 |
+
" ...,\n",
|
| 506 |
+
" [0.0000, 0.0000, 0.0000, ..., 6.2526, 6.2931, 7.0548],\n",
|
| 507 |
+
" [0.0000, 0.0000, 0.0000, ..., 6.4981, 6.7198, 7.0807],\n",
|
| 508 |
+
" [0.0000, 0.0000, 1.1479, ..., 5.4311, 5.9743, 6.3568]]]))"
|
| 509 |
+
]
|
| 510 |
+
},
|
| 511 |
+
"execution_count": 202,
|
| 512 |
+
"metadata": {},
|
| 513 |
+
"output_type": "execute_result"
|
| 514 |
+
}
|
| 515 |
+
],
|
| 516 |
+
"source": [
|
| 517 |
+
"f = open('/mnt/data-2t/jeff/codes/llm/cpp/inference/matrix_output.txt')\n",
|
| 518 |
+
"txtlines = f.readlines()\n",
|
| 519 |
+
"f.close()\n",
|
| 520 |
+
"inp_emb_cpp = np.array([float(i) for l in txtlines for i in l.split(',')]).reshape(1,-1,80)\n",
|
| 521 |
+
"inp_emb_cpp,pickup_dataset.__getitem__(0)['input_audio_embeds']"
|
| 522 |
+
]
|
| 523 |
+
},
|
| 524 |
+
{
|
| 525 |
+
"cell_type": "markdown",
|
| 526 |
+
"metadata": {},
|
| 527 |
+
"source": [
|
| 528 |
+
"# Python preprocessor"
|
| 529 |
+
]
|
| 530 |
+
},
|
| 531 |
+
{
|
| 532 |
+
"cell_type": "code",
|
| 533 |
+
"execution_count": null,
|
| 534 |
+
"metadata": {},
|
| 535 |
+
"outputs": [
|
| 536 |
+
{
|
| 537 |
+
"data": {
|
| 538 |
+
"text/plain": [
|
| 539 |
+
"(353, 80)"
|
| 540 |
+
]
|
| 541 |
+
},
|
| 542 |
+
"execution_count": 66,
|
| 543 |
+
"metadata": {},
|
| 544 |
+
"output_type": "execute_result"
|
| 545 |
+
}
|
| 546 |
+
],
|
| 547 |
+
"source": [
|
| 548 |
+
"# modify the code : \n",
|
| 549 |
+
"# 1. input model and input pcm from args. \n",
|
| 550 |
+
"# 2. add model input preprocessor by following python code. The wav input of _extract_features which is an audio array\n",
|
| 551 |
+
"# 3. the onnx model input is [batch,frames,feature size] = [-1,-1,80]\n",
|
| 552 |
+
"\n",
|
| 553 |
+
"def _extract_spectrogram(wav, fs):\n",
|
| 554 |
+
" \"\"\"Extract spectrogram features from waveform.\n",
|
| 555 |
+
" Args:\n",
|
| 556 |
+
" wav (1D array): waveform of the input\n",
|
| 557 |
+
" fs (int): sampling rate of the waveform, 16000.\n",
|
| 558 |
+
" Output:\n",
|
| 559 |
+
" log_fbank (2D array): a TxD matrix of log Mel filterbank features.\n",
|
| 560 |
+
" D=80, and T is the number of frames.\n",
|
| 561 |
+
" \"\"\"\n",
|
| 562 |
+
" if wav.ndim > 1:\n",
|
| 563 |
+
" wav = np.squeeze(wav)\n",
|
| 564 |
+
"\n",
|
| 565 |
+
" # by default, we extract the mean if stereo\n",
|
| 566 |
+
" if len(wav.shape) == 2:\n",
|
| 567 |
+
" wav = wav.mean(1)\n",
|
| 568 |
+
"\n",
|
| 569 |
+
" preemphasis = 0.97\n",
|
| 570 |
+
" n_fft = 512\n",
|
| 571 |
+
" win_length = 400\n",
|
| 572 |
+
" hop_length = 160\n",
|
| 573 |
+
" fft_window = np.hamming(400)\n",
|
| 574 |
+
"\n",
|
| 575 |
+
" # Spec 1: SpeechLib cut remaining sample insufficient for a hop\n",
|
| 576 |
+
" n_batch = (wav.shape[0] - win_length) // hop_length + 1\n",
|
| 577 |
+
" # Here we don't use stride_tricks since the input array may not satisfy\n",
|
| 578 |
+
" # memory layout requirement and we need writeable output\n",
|
| 579 |
+
" # Here we only use list of views before copy to desination\n",
|
| 580 |
+
" # so it is more efficient than broadcasting\n",
|
| 581 |
+
" y_frames = np.array(\n",
|
| 582 |
+
" [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],\n",
|
| 583 |
+
" dtype=np.float32,\n",
|
| 584 |
+
" )\n",
|
| 585 |
+
"\n",
|
| 586 |
+
" # Spec 2: SpeechLib applies preemphasis within each batch\n",
|
| 587 |
+
" y_frames_prev = np.roll(y_frames, 1, axis=1)\n",
|
| 588 |
+
" y_frames_prev[:, 0] = y_frames_prev[:, 1]\n",
|
| 589 |
+
" y_frames = (y_frames - preemphasis * y_frames_prev) * 32768\n",
|
| 590 |
+
"\n",
|
| 591 |
+
" S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)\n",
|
| 592 |
+
" spec = np.abs(S).astype(np.float32)\n",
|
| 593 |
+
" return spec\n",
|
| 594 |
+
"def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):\n",
|
| 595 |
+
" \"\"\"Create a Mel filter-bank the same as SpeechLib FbankFC.\n",
|
| 596 |
+
"\n",
|
| 597 |
+
" Args:\n",
|
| 598 |
+
" sample_rate (int): Sample rate in Hz. number > 0 [scalar]\n",
|
| 599 |
+
" n_fft (int): FFT size. int > 0 [scalar]\n",
|
| 600 |
+
" n_mel (int): Mel filter size. int > 0 [scalar]\n",
|
| 601 |
+
" fmin (float): lowest frequency (in Hz). If None use 0.0.\n",
|
| 602 |
+
" float >= 0 [scalar]\n",
|
| 603 |
+
" fmax: highest frequency (in Hz). If None use sample_rate / 2.\n",
|
| 604 |
+
" float >= 0 [scalar]\n",
|
| 605 |
+
"\n",
|
| 606 |
+
" Returns\n",
|
| 607 |
+
" out (numpy.ndarray): Mel transform matrix\n",
|
| 608 |
+
" [shape=(n_mels, 1 + n_fft/2)]\n",
|
| 609 |
+
" \"\"\"\n",
|
| 610 |
+
"\n",
|
| 611 |
+
" bank_width = int(n_fft // 2 + 1)\n",
|
| 612 |
+
" if fmax is None:\n",
|
| 613 |
+
" fmax = sample_rate / 2\n",
|
| 614 |
+
" if fmin is None:\n",
|
| 615 |
+
" fmin = 0\n",
|
| 616 |
+
" assert fmin >= 0, \"fmin cannot be negtive\"\n",
|
| 617 |
+
" assert fmin < fmax <= sample_rate / 2, \"fmax must be between (fmin, samplerate / 2]\"\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" def mel(f):\n",
|
| 620 |
+
" return 1127.0 * np.log(1.0 + f / 700.0)\n",
|
| 621 |
+
"\n",
|
| 622 |
+
" def bin2mel(fft_bin):\n",
|
| 623 |
+
" return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))\n",
|
| 624 |
+
"\n",
|
| 625 |
+
" def f2bin(f):\n",
|
| 626 |
+
" return int((f * n_fft / sample_rate) + 0.5)\n",
|
| 627 |
+
"\n",
|
| 628 |
+
" # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]\n",
|
| 629 |
+
" klo = f2bin(fmin) + 1\n",
|
| 630 |
+
" khi = f2bin(fmax)\n",
|
| 631 |
+
"\n",
|
| 632 |
+
" khi = max(khi, klo)\n",
|
| 633 |
+
"\n",
|
| 634 |
+
" # Spec 2: SpeechLib uses trianges in Mel space\n",
|
| 635 |
+
" mlo = mel(fmin)\n",
|
| 636 |
+
" mhi = mel(fmax)\n",
|
| 637 |
+
" m_centers = np.linspace(mlo, mhi, n_mels + 2)\n",
|
| 638 |
+
" ms = (mhi - mlo) / (n_mels + 1)\n",
|
| 639 |
+
"\n",
|
| 640 |
+
" matrix = np.zeros((n_mels, bank_width), dtype=np.float32)\n",
|
| 641 |
+
" for m in range(0, n_mels):\n",
|
| 642 |
+
" left = m_centers[m]\n",
|
| 643 |
+
" center = m_centers[m + 1]\n",
|
| 644 |
+
" right = m_centers[m + 2]\n",
|
| 645 |
+
" for fft_bin in range(klo, khi):\n",
|
| 646 |
+
" mbin = bin2mel(fft_bin)\n",
|
| 647 |
+
" if left < mbin < right:\n",
|
| 648 |
+
" matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms\n",
|
| 649 |
+
"\n",
|
| 650 |
+
" return matrix\n",
|
| 651 |
+
"\n",
|
| 652 |
+
"def _extract_features(wav, fs):\n",
|
| 653 |
+
" \"\"\"Extract log filterbank features from waveform.\n",
|
| 654 |
+
" Args:\n",
|
| 655 |
+
" wav (1D array): waveform of the input\n",
|
| 656 |
+
" fs (int): sampling rate of the waveform, 16000 or 8000.\n",
|
| 657 |
+
" If fs=8000, the waveform will be resampled to 16000Hz.\n",
|
| 658 |
+
" Output:\n",
|
| 659 |
+
" log_fbank (2D array): a TxD matrix of log Mel filterbank features.\n",
|
| 660 |
+
" D=80, and T is the number of frames.\n",
|
| 661 |
+
" \"\"\"\n",
|
| 662 |
+
" spec = _extract_spectrogram(wav, fs)\n",
|
| 663 |
+
" spec_power = spec**2\n",
|
| 664 |
+
"\n",
|
| 665 |
+
" fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)\n",
|
| 666 |
+
" log_fbank = np.log(fbank_power).astype(np.float32)\n",
|
| 667 |
+
"\n",
|
| 668 |
+
" return log_fbank\n",
|
| 669 |
+
"\n",
|
| 670 |
+
"## example \n",
|
| 671 |
+
"## input shape of arr is [1, 56832], output shape will be (353,80)\n",
|
| 672 |
+
"_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
|
| 673 |
+
"output = _extract_features(arr,16000)"
|
| 674 |
+
]
|
| 675 |
+
},
|
| 676 |
+
{
|
| 677 |
+
"cell_type": "code",
|
| 678 |
+
"execution_count": 227,
|
| 679 |
+
"metadata": {},
|
| 680 |
+
"outputs": [
|
| 681 |
+
{
|
| 682 |
+
"data": {
|
| 683 |
+
"text/plain": [
|
| 684 |
+
"(256, 80)"
|
| 685 |
+
]
|
| 686 |
+
},
|
| 687 |
+
"execution_count": 227,
|
| 688 |
+
"metadata": {},
|
| 689 |
+
"output_type": "execute_result"
|
| 690 |
+
}
|
| 691 |
+
],
|
| 692 |
+
"source": [
|
| 693 |
+
"_mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=16000//2-80-230).T\n",
|
| 694 |
+
"output = _extract_features(arr,16000)\n",
|
| 695 |
+
"output.shape"
|
| 696 |
+
]
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"cell_type": "code",
|
| 700 |
+
"execution_count": null,
|
| 701 |
+
"metadata": {},
|
| 702 |
+
"outputs": [
|
| 703 |
+
{
|
| 704 |
+
"data": {
|
| 705 |
+
"text/plain": [
|
| 706 |
+
"256"
|
| 707 |
+
]
|
| 708 |
+
},
|
| 709 |
+
"execution_count": 228,
|
| 710 |
+
"metadata": {},
|
| 711 |
+
"output_type": "execute_result"
|
| 712 |
+
}
|
| 713 |
+
],
|
| 714 |
+
"source": [
|
| 715 |
+
"(41239-400)//160+1 100~300"
|
| 716 |
+
]
|
| 717 |
+
},
|
| 718 |
+
{
|
| 719 |
+
"cell_type": "code",
|
| 720 |
+
"execution_count": 229,
|
| 721 |
+
"metadata": {},
|
| 722 |
+
"outputs": [
|
| 723 |
+
{
|
| 724 |
+
"data": {
|
| 725 |
+
"text/plain": [
|
| 726 |
+
"(16240, 48240)"
|
| 727 |
+
]
|
| 728 |
+
},
|
| 729 |
+
"execution_count": 229,
|
| 730 |
+
"metadata": {},
|
| 731 |
+
"output_type": "execute_result"
|
| 732 |
+
}
|
| 733 |
+
],
|
| 734 |
+
"source": [
|
| 735 |
+
"99*160+400,299*160+400"
|
| 736 |
+
]
|
| 737 |
+
},
|
| 738 |
+
{
|
| 739 |
+
"cell_type": "code",
|
| 740 |
+
"execution_count": null,
|
| 741 |
+
"metadata": {},
|
| 742 |
+
"outputs": [],
|
| 743 |
+
"source": []
|
| 744 |
+
}
|
| 745 |
+
],
|
| 746 |
+
"metadata": {
|
| 747 |
+
"kernelspec": {
|
| 748 |
+
"display_name": "llamafactory",
|
| 749 |
+
"language": "python",
|
| 750 |
+
"name": "python3"
|
| 751 |
+
},
|
| 752 |
+
"language_info": {
|
| 753 |
+
"codemirror_mode": {
|
| 754 |
+
"name": "ipython",
|
| 755 |
+
"version": 3
|
| 756 |
+
},
|
| 757 |
+
"file_extension": ".py",
|
| 758 |
+
"mimetype": "text/x-python",
|
| 759 |
+
"name": "python",
|
| 760 |
+
"nbconvert_exporter": "python",
|
| 761 |
+
"pygments_lexer": "ipython3",
|
| 762 |
+
"version": "3.10.16"
|
| 763 |
+
}
|
| 764 |
+
},
|
| 765 |
+
"nbformat": 4,
|
| 766 |
+
"nbformat_minor": 2
|
| 767 |
+
}
|
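The frame counts computed in the cells above follow the usual sliding-window relation for a 400-sample window and a 160-sample hop at 16 kHz; those values are implied by the numbers in the cells rather than stated explicitly. A minimal sketch, assuming exactly that framing:

# Minimal sketch, assuming win_length=400 (25 ms) and hop_length=160 (10 ms) at 16 kHz,
# which is what the arithmetic in the cells above implies.
WIN_LENGTH = 400
HOP_LENGTH = 160

def num_frames(num_samples: int) -> int:
    # same formula as the in-notebook check: (41239 - 400) // 160 + 1
    return (num_samples - WIN_LENGTH) // HOP_LENGTH + 1

# 56832 samples -> 353 frames, matching the (353, 80) output shape noted in the example comment
assert num_frames(56832) == 353
# 41239 samples -> 256 frames, matching the executed cell output above
assert num_frames(41239) == 256
# Frame k spans samples [k*HOP_LENGTH, k*HOP_LENGTH + WIN_LENGTH); frames 99 and 299 end at:
assert 99 * HOP_LENGTH + WIN_LENGTH == 16240
assert 299 * HOP_LENGTH + WIN_LENGTH == 48240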
cpp/convert_tensorRT.ipynb
ADDED
|
File without changes
|
cpp/gemma_v1/ASRDataset.py
ADDED
|
@@ -0,0 +1,793 @@
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import sacrebleu
|
| 13 |
+
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from transformers import (
|
| 18 |
+
BatchFeature,
|
| 19 |
+
)
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
from datasets import Audio
|
| 23 |
+
import random
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
import torchaudio
|
| 26 |
+
|
| 27 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 28 |
+
_IGNORE_INDEX = -100
|
| 29 |
+
class BaseAudioDataset(Dataset):
|
| 30 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 31 |
+
self.processor = processor
|
| 32 |
+
self.training = "train" in split or 'other' in split
|
| 33 |
+
self.debug = debug
|
| 34 |
+
self.sampling_rate = sampling_rate
|
| 35 |
+
self.name = ""
|
| 36 |
+
|
| 37 |
+
def set_dataset_name(self, name):
|
| 38 |
+
self.name = name
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 42 |
+
original_size = len(data)
|
| 43 |
+
|
| 44 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 45 |
+
|
| 46 |
+
def identify_corrupted_files(example):
|
| 47 |
+
try:
|
| 48 |
+
sf.read(example[audio_field]["path"])
|
| 49 |
+
|
| 50 |
+
for field in text_fields:
|
| 51 |
+
if field in example and example[field].replace('"', '') == "":
|
| 52 |
+
return False
|
| 53 |
+
return True
|
| 54 |
+
except Exception:
|
| 55 |
+
return False
|
| 56 |
+
|
| 57 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 58 |
+
validated_size = len(data)
|
| 59 |
+
|
| 60 |
+
# Audio Decoding
|
| 61 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 62 |
+
|
| 63 |
+
if debug:
|
| 64 |
+
print(f"Dataset: {dataset_name}")
|
| 65 |
+
print(f"Original data nums: {original_size}")
|
| 66 |
+
print(f"After filtering data nums: {validated_size}")
|
| 67 |
+
print(f"Filtering ratio: {validated_size/original_size:.2%}")
|
| 68 |
+
|
| 69 |
+
return data
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 73 |
+
original_size = len(data)
|
| 74 |
+
|
| 75 |
+
def filter_audio_by_length(example):
|
| 76 |
+
try:
|
| 77 |
+
audio = example[audio_field]['array']
|
| 78 |
+
channel = 1
|
| 79 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 80 |
+
channel = audio.ndim
|
| 81 |
+
audio = audio.squeeze()
|
| 82 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 83 |
+
return min_sec <= audio_length <= max_sec
|
| 84 |
+
except Exception as e:
|
| 85 |
+
if debug:
|
| 86 |
+
print(f"Error : {str(e)[:100]}... - sample excluded")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 90 |
+
filtered_size = len(data)
|
| 91 |
+
|
| 92 |
+
if debug:
|
| 93 |
+
print(f"Before Length Filtering data nums: {original_size}")
|
| 94 |
+
print(f"After Length Filtering data nums: {filtered_size}")
|
| 95 |
+
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
|
| 96 |
+
|
| 97 |
+
return data
|
| 98 |
+
|
| 99 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 100 |
+
user_message = {
|
| 101 |
+
'role': 'user',
|
| 102 |
+
'content': '<start_of_audio>' + instruction,
|
| 103 |
+
}
|
| 104 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 105 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
inputs = self.processor(
|
| 109 |
+
text=prompt,
|
| 110 |
+
audio=[audio_array],
|
| 111 |
+
add_special_tokens=False,
|
| 112 |
+
return_tensors='pt'
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 116 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 117 |
+
|
| 118 |
+
if self.debug:
|
| 119 |
+
self.debug = False
|
| 120 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 121 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 122 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 123 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 124 |
+
|
| 125 |
+
if self.training:
|
| 126 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 127 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 128 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 129 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 130 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 131 |
+
else:
|
| 132 |
+
input_ids = inputs.input_ids
|
| 133 |
+
labels = answer_ids
|
| 134 |
+
token_type_ids = inputs.token_type_ids
|
| 135 |
+
|
| 136 |
+
return {
|
| 137 |
+
'input_ids': input_ids,
|
| 138 |
+
'labels': labels,
|
| 139 |
+
'token_type_ids': token_type_ids,
|
| 140 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 141 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 142 |
+
'input_modes': inputs.input_modes,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
# Libri Speech Dataset Class
|
| 146 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 147 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 148 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 149 |
+
|
| 150 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 151 |
+
# only ASR
|
| 152 |
+
self.ast = False
|
| 153 |
+
self.lang = "en"
|
| 154 |
+
|
| 155 |
+
# load dataset
|
| 156 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
|
| 157 |
+
subset,
|
| 158 |
+
split=split,
|
| 159 |
+
trust_remote_code=True,
|
| 160 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# (Optional) Audio length Filtering
|
| 164 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 165 |
+
|
| 166 |
+
# Instruction Setting
|
| 167 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 168 |
+
|
| 169 |
+
def __len__(self):
|
| 170 |
+
return len(self.data)
|
| 171 |
+
|
| 172 |
+
def __getitem__(self, idx):
|
| 173 |
+
data = self.data[idx]
|
| 174 |
+
|
| 175 |
+
# Libri Speech is only for ASR
|
| 176 |
+
answer_text = data["text"].replace('"', '')
|
| 177 |
+
|
| 178 |
+
return self.prepare_model_inputs(
|
| 179 |
+
data["audio"]["array"],
|
| 180 |
+
self.instruction,
|
| 181 |
+
answer_text
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# common_voice_16_1 dataset
|
| 185 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 186 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 187 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 188 |
+
|
| 189 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 190 |
+
# only ASR
|
| 191 |
+
self.ast = False
|
| 192 |
+
self.lang=source_lang
|
| 193 |
+
|
| 194 |
+
# load dataset
|
| 195 |
+
if source_lang=="zh-TW":
|
| 196 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
|
| 197 |
+
else:
|
| 198 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
|
| 199 |
+
self.data = load_dataset(data_path,
|
| 200 |
+
source_lang,
|
| 201 |
+
split=split,
|
| 202 |
+
trust_remote_code=True,
|
| 203 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 204 |
+
)
|
| 205 |
+
def prepare_dataset(batch):
|
| 206 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 207 |
+
transcription = batch["sentence"]
|
| 208 |
+
|
| 209 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 210 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 211 |
+
transcription = transcription[1:-1]
|
| 212 |
+
|
| 213 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 214 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 215 |
+
transcription = transcription + "."
|
| 216 |
+
|
| 217 |
+
batch["sentence"] = transcription
|
| 218 |
+
|
| 219 |
+
return batch
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
import opencc
|
| 223 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 224 |
+
def To_zhTW(batch):
|
| 225 |
+
|
| 226 |
+
transcription = converter.convert(batch["sentence"])
|
| 227 |
+
batch["sentence"] = transcription
|
| 228 |
+
|
| 229 |
+
return batch
|
| 230 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 231 |
+
if source_lang=='zh-CN':
|
| 232 |
+
self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# (Optional) Audio length Filtering
|
| 236 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 237 |
+
|
| 238 |
+
if source_lang == "zh-TW" and split=='train':
|
| 239 |
+
import torchaudio
|
| 240 |
+
from torchaudio import transforms
|
| 241 |
+
import copy
|
| 242 |
+
import pickle
|
| 243 |
+
import os
|
| 244 |
+
def subsample(batch):
|
| 245 |
+
batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
|
| 246 |
+
batch['audio']['sampling_rate']=16000
|
| 247 |
+
return batch
|
| 248 |
+
def TW_data_augment_fast(batch):
|
| 249 |
+
speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
|
| 250 |
+
new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
|
| 251 |
+
batch['audio']['array'] = new_array_fast
|
| 252 |
+
return batch
|
| 253 |
+
def TW_data_augment_slow(batch):
|
| 254 |
+
speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
|
| 255 |
+
new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
|
| 256 |
+
batch['audio']['array'] = new_array_slow
|
| 257 |
+
return batch
|
| 258 |
+
# data = self.data.map(subsample, num_proc=1, desc="subsample")
|
| 259 |
+
fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
|
| 260 |
+
if not os.path.exists(fast_path):
|
| 261 |
+
data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
|
| 262 |
+
with open(fast_path,'wb') as f:
|
| 263 |
+
pickle.dump(data_fast,f)
|
| 264 |
+
else:
|
| 265 |
+
with open(fast_path,'rb') as f:
|
| 266 |
+
data_fast=pickle.load(f)
|
| 267 |
+
|
| 268 |
+
slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
|
| 269 |
+
if not os.path.exists(slow_path):
|
| 270 |
+
data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
|
| 271 |
+
with open(slow_path,'wb') as f:
|
| 272 |
+
pickle.dump(data_slow,f)
|
| 273 |
+
else:
|
| 274 |
+
with open(slow_path,'rb') as f:
|
| 275 |
+
data_slow=pickle.load(f)
|
| 276 |
+
self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
|
| 277 |
+
|
| 278 |
+
# Instruction Setting
|
| 279 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 280 |
+
|
| 281 |
+
def __len__(self):
|
| 282 |
+
return len(self.data)
|
| 283 |
+
|
| 284 |
+
def __getitem__(self, idx):
|
| 285 |
+
data = self.data[idx]
|
| 286 |
+
|
| 287 |
+
answer_text = data["sentence"]
|
| 288 |
+
return self.prepare_model_inputs(
|
| 289 |
+
data["audio"]["array"],
|
| 290 |
+
self.instruction,
|
| 291 |
+
answer_text
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Fleurs Dataset Class
|
| 296 |
+
class FleursDataset(BaseAudioDataset):
|
| 297 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 298 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 299 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 300 |
+
|
| 301 |
+
self.set_dataset_name("Fleurs")
|
| 302 |
+
# Mode Setting (ASR or AST)
|
| 303 |
+
if mode not in ["asr", "ast"]:
|
| 304 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 305 |
+
|
| 306 |
+
self.mode = mode
|
| 307 |
+
self.ast = (mode == "ast")
|
| 308 |
+
self.source_lang = source_lang
|
| 309 |
+
|
| 310 |
+
# Language name mapping (expand if needed)
|
| 311 |
+
self.lang_names = {
|
| 312 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
# load dataset - source language dataset
|
| 316 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 317 |
+
source_lang,
|
| 318 |
+
split=split,
|
| 319 |
+
trust_remote_code=True,
|
| 320 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 321 |
+
)
|
| 322 |
+
import opencc
|
| 323 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 324 |
+
def prepare_dataset(batch):
|
| 325 |
+
transcription = converter.convert(batch["transcription"])
|
| 326 |
+
batch["transcription"] = transcription
|
| 327 |
+
|
| 328 |
+
return batch
|
| 329 |
+
if (source_lang=="cmn_hans_cn"):
|
| 330 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 331 |
+
|
| 332 |
+
# (Optional) Audio length Filtering
|
| 333 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 334 |
+
self.target_lang_name = ""
|
| 335 |
+
# When AST mode, load target language dataset.
|
| 336 |
+
if self.ast:
|
| 337 |
+
if target_lang is None:
|
| 338 |
+
raise ValueError("AST mode requires target_lang.")
|
| 339 |
+
|
| 340 |
+
self.target_lang = target_lang
|
| 341 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 342 |
+
|
| 343 |
+
# load dataset - target language dataset (for translation)
|
| 344 |
+
target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 345 |
+
target_lang,
|
| 346 |
+
split=split,
|
| 347 |
+
trust_remote_code=True,
|
| 348 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 349 |
+
)
|
| 350 |
+
if target_lang=="cmn_hans_cn":
|
| 351 |
+
target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
|
| 352 |
+
source_dict = {item['id']: item for item in self.data}
|
| 353 |
+
target_dict = {item['id']: item for item in target_data}
|
| 354 |
+
|
| 355 |
+
# only Common ID, add translation fields
|
| 356 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 357 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 358 |
+
self.data = [
|
| 359 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 360 |
+
for id in common_ids
|
| 361 |
+
]
|
| 362 |
+
|
| 363 |
+
# Instruction Setting - use target language name
|
| 364 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 365 |
+
self.instruction = random.choice(INSTRUCTION["ast"])
|
| 366 |
+
else:
|
| 367 |
+
# ASR mode
|
| 368 |
+
self.lang = source_lang
|
| 369 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 370 |
+
|
| 371 |
+
if self.debug:
|
| 372 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 373 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 374 |
+
if self.ast:
|
| 375 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 376 |
+
print(f"dataset size: {len(self.data)}")
|
| 377 |
+
|
| 378 |
+
def __len__(self):
|
| 379 |
+
return len(self.data)
|
| 380 |
+
|
| 381 |
+
def __getitem__(self, idx):
|
| 382 |
+
data = self.data[idx]
|
| 383 |
+
audio_array = data["audio"]["array"]
|
| 384 |
+
|
| 385 |
+
if self.ast:
|
| 386 |
+
answer_text = data["translation"]
|
| 387 |
+
else:
|
| 388 |
+
answer_text = data["transcription"]
|
| 389 |
+
|
| 390 |
+
return self.prepare_model_inputs(
|
| 391 |
+
audio_array,
|
| 392 |
+
self.instruction.format(self.target_lang_name),
|
| 393 |
+
answer_text
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
class TWCostumData(BaseAudioDataset):
|
| 397 |
+
|
| 398 |
+
def __init__(self, processor, split="train", sampling_rate=16000,csv_path="", debug=False):
|
| 399 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 400 |
+
import pandas as pd
|
| 401 |
+
from datasets import Dataset, Audio
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
df = pd.read_csv(csv_path).fillna('')
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
self.set_dataset_name(f"TWCostumData")
|
| 408 |
+
self.data = Dataset.from_dict(
|
| 409 |
+
{
|
| 410 |
+
"audio": [audio for audio in df['audio']],
|
| 411 |
+
"sentence": [text for text in df['text']]
|
| 412 |
+
}
|
| 413 |
+
).cast_column("audio", Audio(sampling_rate=16000))
|
| 414 |
+
|
| 415 |
+
# Instruction Setting
|
| 416 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 417 |
+
|
| 418 |
+
def __len__(self):
|
| 419 |
+
return len(self.data)
|
| 420 |
+
|
| 421 |
+
def __getitem__(self, idx):
|
| 422 |
+
data = self.data[idx]
|
| 423 |
+
|
| 424 |
+
answer_text = data["sentence"]
|
| 425 |
+
return self.prepare_model_inputs(
|
| 426 |
+
data["audio"]["array"],
|
| 427 |
+
self.instruction,
|
| 428 |
+
answer_text
|
| 429 |
+
)
|
| 430 |
+
def covost_collate_fn(batch):
|
| 431 |
+
input_ids_list = []
|
| 432 |
+
labels_list = []
|
| 433 |
+
token_type_ids_list = []
|
| 434 |
+
input_audio_embeds_list = []
|
| 435 |
+
audio_embed_sizes_list = []
|
| 436 |
+
audio_attention_mask_list = []
|
| 437 |
+
input_modes_list = []
|
| 438 |
+
audio_paths = []
|
| 439 |
+
for inputs in batch:
|
| 440 |
+
if 'audio_path' in inputs:
|
| 441 |
+
audio_paths.append(inputs['audio_path'])
|
| 442 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 443 |
+
labels_list.append(inputs['labels'][0])
|
| 444 |
+
token_type_ids_list.append(inputs['token_type_ids'][0])
|
| 445 |
+
if inputs['input_modes']==2:
|
| 446 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 447 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 448 |
+
audio_attention_mask_list.append(
|
| 449 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 450 |
+
)
|
| 451 |
+
# else:
|
| 452 |
+
# input_audio_embeds_list.append(None)
|
| 453 |
+
# audio_embed_sizes_list.append(None)
|
| 454 |
+
# audio_attention_mask_list.append(None)
|
| 455 |
+
input_modes_list.append(inputs['input_modes'])
|
| 456 |
+
# try:
|
| 457 |
+
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
|
| 458 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 459 |
+
labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
|
| 460 |
+
audio_attention_mask = (
|
| 461 |
+
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
|
| 462 |
+
if len(audio_attention_mask_list) > 1
|
| 463 |
+
else None
|
| 464 |
+
)
|
| 465 |
+
# except Exception as e:
|
| 466 |
+
# print(e)
|
| 467 |
+
# print(input_ids_list)
|
| 468 |
+
# print(labels_list)
|
| 469 |
+
# raise
|
| 470 |
+
attention_mask = (input_ids != 0).long()
|
| 471 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) if len(input_audio_embeds_list)>0 else None
|
| 472 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list) if len(audio_embed_sizes_list)>0 else None
|
| 473 |
+
input_modes = torch.cat(input_modes_list)
|
| 474 |
+
if len(audio_paths)>0:
|
| 475 |
+
return BatchFeature(
|
| 476 |
+
{
|
| 477 |
+
"audio_path": audio_paths,
|
| 478 |
+
'input_ids': input_ids,
|
| 479 |
+
'labels': labels,
|
| 480 |
+
'token_type_ids': token_type_ids,
|
| 481 |
+
'attention_mask': attention_mask,
|
| 482 |
+
'input_audio_embeds': input_audio_embeds,
|
| 483 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 484 |
+
'audio_attention_mask': audio_attention_mask,
|
| 485 |
+
'input_modes': input_modes,
|
| 486 |
+
}
|
| 487 |
+
)
|
| 488 |
+
else:
|
| 489 |
+
return BatchFeature(
|
| 490 |
+
{
|
| 491 |
+
'input_ids': input_ids,
|
| 492 |
+
'labels': labels,
|
| 493 |
+
'token_type_ids': token_type_ids,
|
| 494 |
+
'attention_mask': attention_mask,
|
| 495 |
+
'input_audio_embeds': input_audio_embeds,
|
| 496 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 497 |
+
'audio_attention_mask': audio_attention_mask,
|
| 498 |
+
'input_modes': input_modes,
|
| 499 |
+
}
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 503 |
+
"""
|
| 504 |
+
Pad a list of sequences to the same length.
|
| 505 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 506 |
+
"""
|
| 507 |
+
assert padding_side in ['right', 'left']
|
| 508 |
+
max_size = sequences[0].size()
|
| 509 |
+
trailing_dims = max_size[1:]
|
| 510 |
+
max_len = max(len(seq) for seq in sequences)
|
| 511 |
+
batch_size = len(sequences)
|
| 512 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 513 |
+
for i, seq in enumerate(sequences):
|
| 514 |
+
length = seq.size(0)
|
| 515 |
+
if padding_side == 'right':
|
| 516 |
+
output.data[i, :length] = seq
|
| 517 |
+
else:
|
| 518 |
+
output.data[i, -length:] = seq
|
| 519 |
+
return output
|
| 520 |
+
|
| 521 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 522 |
+
"""
|
| 523 |
+
cat along dim, while pad to max for all other dims
|
| 524 |
+
"""
|
| 525 |
+
ndim = tensors[0].dim()
|
| 526 |
+
assert all(
|
| 527 |
+
t.dim() == ndim for t in tensors[1:]
|
| 528 |
+
), 'All tensors must have the same number of dimensions'
|
| 529 |
+
|
| 530 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 531 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 532 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 533 |
+
|
| 534 |
+
index = 0
|
| 535 |
+
for t in tensors:
|
| 536 |
+
# Create a slice list where every dimension except dim is full slice
|
| 537 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 538 |
+
# Update only the concat dimension slice
|
| 539 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 540 |
+
|
| 541 |
+
output[slices] = t
|
| 542 |
+
index += t.shape[dim]
|
| 543 |
+
|
| 544 |
+
return output
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class MultiturnAudioDataset(BaseAudioDataset):
|
| 549 |
+
def __init__(self, processor, split="train", sampling_rate=16000,json_path="",text_only=False, debug=False):
|
| 550 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 551 |
+
from llamafactory.data.template import Llama2Template,parse_template
|
| 552 |
+
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
| 553 |
+
from llamafactory.data.mm_plugin import get_mm_plugin
|
| 554 |
+
import json
|
| 555 |
+
self.train=False
|
| 556 |
+
self.text_only=text_only
|
| 557 |
+
with open(json_path) as f:
|
| 558 |
+
js_data = json.load(f)
|
| 559 |
+
if split=='train':
|
| 560 |
+
self.train=True
|
| 561 |
+
js_data = js_data[:int(len(js_data)*0.8)]
|
| 562 |
+
else:
|
| 563 |
+
js_data = js_data[-int(len(js_data)*0.2):]
|
| 564 |
+
for conv in js_data:
|
| 565 |
+
for mess in conv['conversations']:
|
| 566 |
+
if 'audio_path' in mess:
|
| 567 |
+
mess['audio_path'] = mess['audio_path'].replace('/home/jeff/codes/llm/InCar/srdc_generate_tts/','/mnt/jeff/InCar/data/multiturn_data/')
|
| 568 |
+
default_system = ""#"""You are a helpful assistant that determines how to solve problems based on user needs and converts user speech into text.\n"""
|
| 569 |
+
self.template=Llama2Template(
|
| 570 |
+
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
| 571 |
+
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
|
| 572 |
+
format_system=StringFormatter(slots=["{{content}}\n\n"]),
|
| 573 |
+
format_function=FunctionFormatter(slots=["{{content}}", {"eos_token"}], tool_format="default"),
|
| 574 |
+
format_tools = ToolFormatter(tool_format="default"),
|
| 575 |
+
format_observation=StringFormatter(
|
| 576 |
+
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
| 577 |
+
),
|
| 578 |
+
default_system=default_system,
|
| 579 |
+
thought_words=("<think>", "</think>"),
|
| 580 |
+
efficient_eos=False,
|
| 581 |
+
replace_eos=False,
|
| 582 |
+
replace_jinja_template=False,
|
| 583 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
| 584 |
+
stop_words=["<end_of_turn>"],
|
| 585 |
+
mm_plugin=get_mm_plugin(name="base"),
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
self.set_dataset_name(f"MultiturnCostumData")
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
self.data = []
|
| 592 |
+
self.text_only_data = []
|
| 593 |
+
for conv in js_data:
|
| 594 |
+
tools = conv['tools'] if 'tools' in conv else ""
|
| 595 |
+
system = conv['system'] if 'system' in conv else default_system
|
| 596 |
+
tmp = {
|
| 597 |
+
'tools':tools,
|
| 598 |
+
'system':system,
|
| 599 |
+
'messages':[],
|
| 600 |
+
}
|
| 601 |
+
for i,mess in enumerate(conv['conversations']):
|
| 602 |
+
tmp['messages'].append(mess)
|
| 603 |
+
if mess['from']=='human':
|
| 604 |
+
tmp['messages'].append(conv['conversations'][i+1])
|
| 605 |
+
d = deepcopy(tmp)
|
| 606 |
+
d['audio_array'] = torchaudio.load(mess['audio_path'])[0][0]
|
| 607 |
+
self.data.append(d)
|
| 608 |
+
if self.text_only:
|
| 609 |
+
self.text_only_data.append(deepcopy(tmp))
|
| 610 |
+
tmp['messages'].pop()
|
| 611 |
+
elif mess['from']=='observation':
|
| 612 |
+
tmp['messages'].append(conv['conversations'][i+1])
|
| 613 |
+
d = deepcopy(tmp)
|
| 614 |
+
self.text_only_data.append(d)
|
| 615 |
+
tmp['messages'].pop()
|
| 616 |
+
if text_only:
|
| 617 |
+
self.data=self.text_only_data
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def prepare_multiturn_model_inputs(self, audio_array, messages, system="", tools=""):
|
| 621 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 622 |
+
prompt = ""
|
| 623 |
+
answer_text = ""
|
| 624 |
+
user_transcribe = ""
|
| 625 |
+
audio_paths = []
|
| 626 |
+
for i, message in enumerate(messages):
|
| 627 |
+
elements = []
|
| 628 |
+
|
| 629 |
+
system_text = ""
|
| 630 |
+
if i == 0:
|
| 631 |
+
elements += self.template.format_prefix.apply()
|
| 632 |
+
if system or tools:
|
| 633 |
+
tool_text = self.template.format_tools.apply(content=tools)[0] if tools else ""
|
| 634 |
+
system_text = self.template.format_system.apply(content=(system + tool_text))[0]
|
| 635 |
+
|
| 636 |
+
if message["from"] == "human":
|
| 637 |
+
if i==len(messages)-2 and not self.text_only:
|
| 638 |
+
user_transcribe = message["value"]
|
| 639 |
+
elements += self.template.format_user.apply(content=system_text+'<start_of_audio>')
|
| 640 |
+
else:
|
| 641 |
+
elements += self.template.format_user.apply(content=system_text + message["value"])
|
| 642 |
+
audio_paths.append(message['audio_path'])
|
| 643 |
+
elif message["from"] == "gpt":
|
| 644 |
+
elements += self.template.format_assistant.apply(content=message["value"])
|
| 645 |
+
elif message["from"] == "observation":
|
| 646 |
+
elements += self.template.format_observation.apply(content=message["value"])
|
| 647 |
+
elif message["from"] == "function_call":
|
| 648 |
+
elements += self.template.format_function.apply(content=message["value"])
|
| 649 |
+
else:
|
| 650 |
+
raise NotImplementedError("Unexpected role: {}".format(message["from"]))
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
for elem in elements:
|
| 654 |
+
ele_str = ""
|
| 655 |
+
if isinstance(elem, str):
|
| 656 |
+
ele_str=elem
|
| 657 |
+
elif isinstance(elem, set):
|
| 658 |
+
if "bos_token" in elem and self.processor.tokenizer.bos_token_id is not None:
|
| 659 |
+
ele_str = self.processor.tokenizer.bos_token
|
| 660 |
+
elif "eos_token" in elem and self.processor.tokenizer.eos_token_id is not None:
|
| 661 |
+
ele_str = self.processor.tokenizer.eos_token
|
| 662 |
+
if i == len(messages)-1:
|
| 663 |
+
answer_text+=ele_str
|
| 664 |
+
else:
|
| 665 |
+
prompt+=ele_str
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
if audio_array is not None:
|
| 669 |
+
inputs = self.processor(
|
| 670 |
+
text=prompt,
|
| 671 |
+
audio=[audio_array],
|
| 672 |
+
add_special_tokens=False,
|
| 673 |
+
return_tensors='pt'
|
| 674 |
+
)
|
| 675 |
+
answer = "\nUser transcribe is : {};\nGPT output is : {}{}".format(user_transcribe,answer_text,ANSWER_SUFFIX)
|
| 676 |
+
else:
|
| 677 |
+
inputs = self.processor(
|
| 678 |
+
text=prompt,
|
| 679 |
+
audio=None,
|
| 680 |
+
add_special_tokens=False,
|
| 681 |
+
return_tensors='pt'
|
| 682 |
+
)
|
| 683 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 684 |
+
# print('user_transcribe',user_transcribe)
|
| 685 |
+
# print('answer_text', answer)
|
| 686 |
+
# print('prompt',prompt)
|
| 687 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 688 |
+
|
| 689 |
+
if self.debug:
|
| 690 |
+
self.debug = False
|
| 691 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 692 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 693 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 694 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 695 |
+
|
| 696 |
+
if self.training:
|
| 697 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 698 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 699 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 700 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 701 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 702 |
+
else:
|
| 703 |
+
input_ids = inputs.input_ids
|
| 704 |
+
labels = answer_ids
|
| 705 |
+
token_type_ids = inputs.token_type_ids
|
| 706 |
+
if audio_array is not None:
|
| 707 |
+
if not self.train:
|
| 708 |
+
return {
|
| 709 |
+
"audio_path": audio_paths,
|
| 710 |
+
'input_ids': input_ids,
|
| 711 |
+
'labels': labels,
|
| 712 |
+
'token_type_ids': token_type_ids,
|
| 713 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 714 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 715 |
+
'input_modes': inputs.input_modes,
|
| 716 |
+
}
|
| 717 |
+
else:
|
| 718 |
+
return {
|
| 719 |
+
'input_ids': input_ids,
|
| 720 |
+
'labels': labels,
|
| 721 |
+
'token_type_ids': token_type_ids,
|
| 722 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 723 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 724 |
+
'input_modes': inputs.input_modes,
|
| 725 |
+
}
|
| 726 |
+
else:
|
| 727 |
+
return {
|
| 728 |
+
'input_ids': input_ids,
|
| 729 |
+
'labels': labels,
|
| 730 |
+
'token_type_ids': token_type_ids,
|
| 731 |
+
'input_audio_embeds': None,
|
| 732 |
+
'audio_embed_sizes': None,
|
| 733 |
+
'input_modes': inputs.input_modes,
|
| 734 |
+
}
|
| 735 |
+
def __len__(self):
|
| 736 |
+
return len(self.data)
|
| 737 |
+
|
| 738 |
+
def __getitem__(self, idx):
|
| 739 |
+
data = self.data[idx]
|
| 740 |
+
return self.prepare_multiturn_model_inputs(
|
| 741 |
+
audio_array=data["audio_array"] if "audio_array" in data else None,
|
| 742 |
+
messages=data['messages'],
|
| 743 |
+
system=data["system"],
|
| 744 |
+
tools=data["tools"]
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 750 |
+
|
| 751 |
+
INSTRUCTION = {
|
| 752 |
+
"ast": [
|
| 753 |
+
"Translate the audio to {0}.",
|
| 754 |
+
"Translate the audio clip into {0}.",
|
| 755 |
+
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
|
| 756 |
+
"Translate the provided audio file into {0}.",
|
| 757 |
+
"Convert the audio speech to {0} text.",
|
| 758 |
+
"Write an {0} translation of the audio file.",
|
| 759 |
+
"Translate spoken words from the audio into {0}.",
|
| 760 |
+
"Create an {0} version of the audio content.",
|
| 761 |
+
"Produce an accurate {0} translation of the audio.",
|
| 762 |
+
"Extract speech from the audio and translate it to {0}.",
|
| 763 |
+
"Turn the audio into readable {0} text.",
|
| 764 |
+
"Write all spoken content from the audio in {0}.",
|
| 765 |
+
"Generate an {0} translation of the speech in the file.",
|
| 766 |
+
"Convert the recording into {0} text.",
|
| 767 |
+
"Accurately translate the audio recording to {0}.",
|
| 768 |
+
"Write down dialogue from the given audio in {0}.",
|
| 769 |
+
"Translate all speech in this audio file to {0}.",
|
| 770 |
+
"Create an accurate {0} version of the speech.",
|
| 771 |
+
"Perform a complete {0} translation of the audio."
|
| 772 |
+
],
|
| 773 |
+
"asr": [
|
| 774 |
+
"Transcribe the audio clip into text.",
|
| 775 |
+
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
|
| 776 |
+
"Transcribe the provided audio file into text.",
|
| 777 |
+
"Convert the audio speech to text.",
|
| 778 |
+
"Write a transcript of the audio file.",
|
| 779 |
+
"Transcribe spoken words from the audio.",
|
| 780 |
+
"Create a text version of the audio content.",
|
| 781 |
+
"Produce a verbatim transcript of the audio.",
|
| 782 |
+
"Extract and transcribe speech from the audio.",
|
| 783 |
+
"Turn the audio into readable text.",
|
| 784 |
+
"Write all spoken words from the audio.",
|
| 785 |
+
"Generate a transcript of the speech in the file.",
|
| 786 |
+
"Convert the recording into a text transcript.",
|
| 787 |
+
"Accurately transcribe the audio recording.",
|
| 788 |
+
"Write down dialogue from the given audio.",
|
| 789 |
+
"Transcribe all speech in this audio file.",
|
| 790 |
+
"Create an accurate text version of the speech.",
|
| 791 |
+
"Perform a complete transcription of the audio."
|
| 792 |
+
],
|
| 793 |
+
}
|
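A minimal usage sketch for the pieces defined in this file (not part of the upload): the processor, LibriSpeech subset/split, and CSV path below are placeholders, and this is only one plausible way to combine the datasets with covost_collate_fn.

# Sketch only: assumes this file is importable as ASRDataset and that `processor` is a
# Gemma3Omni-style processor compatible with BaseAudioDataset; paths and subsets are placeholders.
from torch.utils.data import ConcatDataset, DataLoader
from ASRDataset import LibriSpeechDataset, TWCostumData, covost_collate_fn

def build_train_loader(processor, batch_size=4):
    libri = LibriSpeechDataset(processor, subset="clean", split="train.100")
    custom = TWCostumData(processor, split="train", csv_path="/path/to/train.csv")  # placeholder path
    train_set = ConcatDataset([libri, custom])
    return DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=covost_collate_fn,  # left-pads input_ids/labels/token_type_ids and stacks audio features
    )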
cpp/gemma_v1/__pycache__/ASRDataset.cpython-312.pyc
ADDED
|
Binary file (36.9 kB)
|
|
|
cpp/gemma_v1/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
cpp/gemma_v1/chat_template.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
{
|
| 2 |
+
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'audio' -%}\n {{ '<start_of_audio>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n"
|
| 3 |
+
}
|
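For reference, the template above should render a single audio-bearing user turn roughly as sketched below. This is a guess at usage, not code from the repo; the local path is a placeholder.

# Sketch: load the tokenizer shipped in this directory and render one user turn.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cpp/gemma_v1")  # placeholder local path
messages = [{"role": "user", "content": "<start_of_audio>Transcribe the audio clip into text."}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Per the template, `prompt` should look like:
#   <bos><start_of_turn>user
#   <start_of_audio>Transcribe the audio clip into text.<end_of_turn>
#   <start_of_turn>model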
cpp/gemma_v1/config.json
ADDED
|
@@ -0,0 +1,118 @@
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Gemma3OmniForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"audio_config": {
|
| 6 |
+
"activation": "swish",
|
| 7 |
+
"activation_checkpointing": {
|
| 8 |
+
"interval": 1,
|
| 9 |
+
"module": "transformer",
|
| 10 |
+
"offload": false
|
| 11 |
+
},
|
| 12 |
+
"attention_dim": 1024,
|
| 13 |
+
"attention_heads": 16,
|
| 14 |
+
"batch_norm": false,
|
| 15 |
+
"bias_in_glu": true,
|
| 16 |
+
"causal": true,
|
| 17 |
+
"chunk_size": -1,
|
| 18 |
+
"cnn_layer_norm": true,
|
| 19 |
+
"conv_activation": "swish",
|
| 20 |
+
"conv_glu_type": "swish",
|
| 21 |
+
"depthwise_multiplier": 1,
|
| 22 |
+
"depthwise_seperable_out_channel": 1024,
|
| 23 |
+
"dropout_rate": 0.0,
|
| 24 |
+
"encoder_embedding_config": {
|
| 25 |
+
"input_size": 80
|
| 26 |
+
},
|
| 27 |
+
"ext_pw_kernel_size": 1,
|
| 28 |
+
"ext_pw_out_channel": 1024,
|
| 29 |
+
"input_layer": "nemo_conv",
|
| 30 |
+
"input_size": 80,
|
| 31 |
+
"kernel_size": 3,
|
| 32 |
+
"left_chunk": 18,
|
| 33 |
+
"linear_units": 1536,
|
| 34 |
+
"nemo_conv_settings": {
|
| 35 |
+
"conv_channels": 1024
|
| 36 |
+
},
|
| 37 |
+
"num_blocks": 24,
|
| 38 |
+
"relative_attention_bias_args": {
|
| 39 |
+
"t5_bias_max_distance": 500,
|
| 40 |
+
"type": "t5"
|
| 41 |
+
},
|
| 42 |
+
"time_reduction": 8
|
| 43 |
+
},
|
| 44 |
+
"audio_token_index": 262143,
|
| 45 |
+
"auto_map": {
|
| 46 |
+
"AutoConfig": "configuration_gemma3omni.Gemma3OmniConfig",
|
| 47 |
+
"AutoModel": "modeling_gemma3omni.Gemma3OmniForConditionalGeneration"
|
| 48 |
+
},
|
| 49 |
+
"boa_token_index": 256001,
|
| 50 |
+
"boi_token_index": 255999,
|
| 51 |
+
"eoa_token_index": 256002,
|
| 52 |
+
"eoi_token_index": 256000,
|
| 53 |
+
"eos_token_id": [
|
| 54 |
+
1,
|
| 55 |
+
106
|
| 56 |
+
],
|
| 57 |
+
"image_token_index": 262144,
|
| 58 |
+
"initializer_range": 0.02,
|
| 59 |
+
"mm_tokens_per_image": 256,
|
| 60 |
+
"model_type": "gemma3omni",
|
| 61 |
+
"speech_lora": {
|
| 62 |
+
"dp": 0.01,
|
| 63 |
+
"layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
|
| 64 |
+
"lora_alpha": 320,
|
| 65 |
+
"r": 320,
|
| 66 |
+
"use_rslora": true
|
| 67 |
+
},
|
| 68 |
+
"text_lora": {
|
| 69 |
+
"dp": 0.01,
|
| 70 |
+
"layer": "((layers.*self_attn\\.(q|k|v|o)_proj)|(layers.*mlp\\.(gate|up|down)_proj))",
|
| 71 |
+
"lora_alpha": 16,
|
| 72 |
+
"r": 8,
|
| 73 |
+
"use_rslora": true
|
| 74 |
+
},
|
| 75 |
+
"text_config": {
|
| 76 |
+
"attention_bias": false,
|
| 77 |
+
"attention_dropout": 0.0,
|
| 78 |
+
"attn_logit_softcapping": null,
|
| 79 |
+
"cache_implementation": "hybrid",
|
| 80 |
+
"final_logit_softcapping": null,
|
| 81 |
+
"head_dim": 256,
|
| 82 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 83 |
+
"hidden_size": 2560,
|
| 84 |
+
"initializer_range": 0.02,
|
| 85 |
+
"intermediate_size": 10240,
|
| 86 |
+
"max_position_embeddings": 131072,
|
| 87 |
+
"model_type": "gemma3_text",
|
| 88 |
+
"num_attention_heads": 8,
|
| 89 |
+
"num_hidden_layers": 34,
|
| 90 |
+
"num_key_value_heads": 4,
|
| 91 |
+
"query_pre_attn_scalar": 256,
|
| 92 |
+
"rms_norm_eps": 1e-06,
|
| 93 |
+
"rope_local_base_freq": 10000.0,
|
| 94 |
+
"rope_scaling": {
|
| 95 |
+
"factor": 8.0,
|
| 96 |
+
"rope_type": "linear"
|
| 97 |
+
},
|
| 98 |
+
"rope_theta": 1000000.0,
|
| 99 |
+
"sliding_window": 1024,
|
| 100 |
+
"sliding_window_pattern": 6,
|
| 101 |
+
"torch_dtype": "float",
|
| 102 |
+
"use_cache": true,
|
| 103 |
+
"vocab_size": 262208
|
| 104 |
+
},
|
| 105 |
+
"torch_dtype": "float",
|
| 106 |
+
"transformers_version": "4.51.3",
|
| 107 |
+
"use_cache": false,
|
| 108 |
+
"vision_config": {
|
| 109 |
+
"hidden_size": 1152,
|
| 110 |
+
"image_size": 896,
|
| 111 |
+
"intermediate_size": 4304,
|
| 112 |
+
"model_type": "siglip_vision_model",
|
| 113 |
+
"num_attention_heads": 16,
|
| 114 |
+
"num_hidden_layers": 27,
|
| 115 |
+
"patch_size": 14,
|
| 116 |
+
"vision_use_head": false
|
| 117 |
+
}
|
| 118 |
+
}
|
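Because config.json maps AutoConfig to configuration_gemma3omni.Gemma3OmniConfig through auto_map, it is presumably loaded with trust_remote_code. A minimal sketch; the local path is a placeholder.

# Sketch: load the custom config through the auto_map entry above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("cpp/gemma_v1", trust_remote_code=True)  # placeholder path
print(cfg.model_type)                   # "gemma3omni"
print(cfg.text_config.hidden_size)      # 2560
print(cfg.audio_config.time_reduction)  # 8 (8x temporal downsampling in the audio encoder)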
cpp/gemma_v1/configuration_gemma3omni.py
ADDED
|
@@ -0,0 +1,206 @@
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from transformers import AutoConfig, Gemma3TextConfig
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 6 |
+
from transformers.utils import logging
|
| 7 |
+
from transformers.models.siglip import SiglipVisionConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
class AudioConfig(PretrainedConfig):
|
| 13 |
+
model_type = "gemma3_audio"
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
input_size=80,
|
| 18 |
+
attention_dim=1024,
|
| 19 |
+
attention_heads=16,
|
| 20 |
+
num_blocks=24,
|
| 21 |
+
linear_units=1536,
|
| 22 |
+
dropout_rate=0.0,
|
| 23 |
+
kernel_size=3,
|
| 24 |
+
ext_pw_kernel_size=1,
|
| 25 |
+
ext_pw_out_channel=1024,
|
| 26 |
+
depthwise_seperable_out_channel=1024,
|
| 27 |
+
depthwise_multiplier=1,
|
| 28 |
+
activation="swish",
|
| 29 |
+
conv_activation="swish",
|
| 30 |
+
conv_glu_type="swish",
|
| 31 |
+
bias_in_glu=True,
|
| 32 |
+
causal=True,
|
| 33 |
+
batch_norm=False,
|
| 34 |
+
cnn_layer_norm=True,
|
| 35 |
+
time_reduction=8,
|
| 36 |
+
input_layer="nemo_conv",
|
| 37 |
+
nemo_conv_settings=None,
|
| 38 |
+
chunk_size=-1,
|
| 39 |
+
left_chunk=18,
|
| 40 |
+
relative_attention_bias_args=None,
|
| 41 |
+
activation_checkpointing=None,
|
| 42 |
+
encoder_embedding_config=None,
|
| 43 |
+
**kwargs
|
| 44 |
+
):
|
| 45 |
+
super().__init__(**kwargs)
|
| 46 |
+
|
| 47 |
+
self.input_size = input_size
|
| 48 |
+
self.attention_dim = attention_dim
|
| 49 |
+
self.attention_heads = attention_heads
|
| 50 |
+
self.num_blocks = num_blocks
|
| 51 |
+
self.linear_units = linear_units
|
| 52 |
+
self.dropout_rate = dropout_rate
|
| 53 |
+
self.kernel_size = kernel_size
|
| 54 |
+
self.ext_pw_kernel_size = ext_pw_kernel_size
|
| 55 |
+
self.ext_pw_out_channel = ext_pw_out_channel
|
| 56 |
+
self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
|
| 57 |
+
self.depthwise_multiplier = depthwise_multiplier
|
| 58 |
+
self.activation = activation
|
| 59 |
+
self.conv_activation = conv_activation
|
| 60 |
+
self.conv_glu_type = conv_glu_type
|
| 61 |
+
self.bias_in_glu = bias_in_glu
|
| 62 |
+
self.causal = causal
|
| 63 |
+
self.batch_norm = batch_norm
|
| 64 |
+
self.cnn_layer_norm = cnn_layer_norm
|
| 65 |
+
self.time_reduction = time_reduction
|
| 66 |
+
self.input_layer = input_layer
|
| 67 |
+
|
| 68 |
+
if nemo_conv_settings is None:
|
| 69 |
+
self.nemo_conv_settings = {"conv_channels": 1024}
|
| 70 |
+
else:
|
| 71 |
+
self.nemo_conv_settings = nemo_conv_settings
|
| 72 |
+
|
| 73 |
+
self.chunk_size = chunk_size
|
| 74 |
+
self.left_chunk = left_chunk
|
| 75 |
+
|
| 76 |
+
if relative_attention_bias_args is None:
|
| 77 |
+
self.relative_attention_bias_args = {"type": "t5", "t5_bias_max_distance": 500}
|
| 78 |
+
else:
|
| 79 |
+
self.relative_attention_bias_args = relative_attention_bias_args
|
| 80 |
+
|
| 81 |
+
if activation_checkpointing is None:
|
| 82 |
+
self.activation_checkpointing = {"interval": 1, "module": "transformer", "offload": False}
|
| 83 |
+
else:
|
| 84 |
+
self.activation_checkpointing = activation_checkpointing
|
| 85 |
+
|
| 86 |
+
if encoder_embedding_config is None:
|
| 87 |
+
self.encoder_embedding_config = {"input_size": input_size}
|
| 88 |
+
else:
|
| 89 |
+
self.encoder_embedding_config = encoder_embedding_config
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Gemma3OmniConfig(PretrainedConfig):
|
| 93 |
+
r"""
|
| 94 |
+
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
| 95 |
+
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 96 |
+
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
| 97 |
+
|
| 98 |
+
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
| 99 |
+
|
| 100 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 101 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
| 105 |
+
The config object of the text backbone.
|
| 106 |
+
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
| 107 |
+
Custom vision config or dict.
|
| 108 |
+
audio_config (`Union[AutoConfig, dict]`, *optional*):
|
| 109 |
+
Custom audio config or dict.
|
| 110 |
+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
| 111 |
+
The number of tokens per image embedding.
|
| 112 |
+
boi_token_index (`int`, *optional*, defaults to 255999):
|
| 113 |
+
The begin-of-image token index to wrap the image prompt.
|
| 114 |
+
eoi_token_index (`int`, *optional*, defaults to 256000):
|
| 115 |
+
The end-of-image token index to wrap the image prompt.
|
| 116 |
+
image_token_index (`int`, *optional*, defaults to 262144):
|
| 117 |
+
The image token index to encode the image prompt.
|
| 118 |
+
audio_token_index (`int`, *optional*, defaults to 262145):
|
| 119 |
+
The audio token index to encode the audio prompt.
|
| 120 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 121 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
Example:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
| 128 |
+
|
| 129 |
+
>>> # Initializing a Siglip-like vision config
|
| 130 |
+
>>> vision_config = SiglipVisionConfig()
|
| 131 |
+
|
| 132 |
+
>>> # Initializing a Siglip-like vision config
|
| 133 |
+
>>> audio_config = AudioConfig()
|
| 134 |
+
|
| 135 |
+
>>> # Initializing a Gemma3 Text config
|
| 136 |
+
>>> text_config = Gemma3TextConfig()
|
| 137 |
+
|
| 138 |
+
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
| 139 |
+
>>> configuration = Gemma3Config(vision_config, text_config)
|
| 140 |
+
|
| 141 |
+
>>> # Initializing a model from the gemma-3-4b style configuration
|
| 142 |
+
>>> model = Gemma3TextConfig(configuration)
|
| 143 |
+
|
| 144 |
+
>>> # Accessing the model configuration
|
| 145 |
+
>>> configuration = model.config
|
| 146 |
+
```"""
|
| 147 |
+
|
| 148 |
+
model_type = "gemma3omni"
|
| 149 |
+
sub_configs = {
|
| 150 |
+
"text_config": Gemma3TextConfig,
|
| 151 |
+
"vision_config": SiglipVisionConfig,
|
| 152 |
+
"audio_config": AudioConfig,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
text_config: Optional[Gemma3TextConfig] = None,
|
| 158 |
+
vision_config: Optional[SiglipVisionConfig] = None,
|
| 159 |
+
audio_config: Optional[AudioConfig] = None,
|
| 160 |
+
mm_tokens_per_image: int = 256,
|
| 161 |
+
boi_token_index: int = 255_999,
|
| 162 |
+
eoi_token_index: int = 256_000,
|
| 163 |
+
boa_token_index: int = 256_001,
|
| 164 |
+
eoa_token_index: int = 256_002,
|
| 165 |
+
image_token_index: int = 262_144,
|
| 166 |
+
audio_token_index: int = 262_143,
|
| 167 |
+
initializer_range: float = 0.02,
|
| 168 |
+
**kwargs,
|
| 169 |
+
):
|
| 170 |
+
if text_config is None:
|
| 171 |
+
text_config = Gemma3TextConfig()
|
| 172 |
+
logger.info("text_config is None, using default Gemma3TextConfig vision config.")
|
| 173 |
+
elif isinstance(text_config, dict):
|
| 174 |
+
text_config = Gemma3TextConfig(**text_config)
|
| 175 |
+
|
| 176 |
+
if isinstance(vision_config, dict):
|
| 177 |
+
vision_config = SiglipVisionConfig(**vision_config)
|
| 178 |
+
else:
|
| 179 |
+
vision_config = SiglipVisionConfig()
|
| 180 |
+
logger.info(
|
| 181 |
+
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
| 182 |
+
"to text tasks."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if isinstance(audio_config, dict):
|
| 186 |
+
audio_config = AudioConfig(**audio_config)
|
| 187 |
+
else:
|
| 188 |
+
audio_config = AudioConfig()
|
| 189 |
+
logger.info(
|
| 190 |
+
"audio_config is None or incompatible with Gemma3AudioConfig intialization. Gemma3 will be limited "
|
| 191 |
+
"to text tasks."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.text_config = text_config
|
| 195 |
+
self.vision_config = vision_config
|
| 196 |
+
self.audio_config = audio_config
|
| 197 |
+
self.mm_tokens_per_image = mm_tokens_per_image
|
| 198 |
+
self.boi_token_index = boi_token_index
|
| 199 |
+
self.eoi_token_index = eoi_token_index
|
| 200 |
+
self.boa_token_index = boa_token_index
|
| 201 |
+
self.eoa_token_index = eoa_token_index
|
| 202 |
+
self.image_token_index = image_token_index
|
| 203 |
+
self.audio_token_index = audio_token_index
|
| 204 |
+
self.initializer_range = initializer_range
|
| 205 |
+
|
| 206 |
+
super().__init__(**kwargs)
|
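A short construction sketch (not in the upload) showing how the dict branches of __init__ are exercised, for example when the values in config.json are parsed; it assumes this file is importable as configuration_gemma3omni.

# Sketch: build the config from sub-config dicts, mirroring how config.json is consumed.
from configuration_gemma3omni import AudioConfig, Gemma3OmniConfig

cfg = Gemma3OmniConfig(
    text_config={"hidden_size": 2560, "num_hidden_layers": 34},  # dict -> Gemma3TextConfig(**dict)
    audio_config={"input_size": 80, "time_reduction": 8},        # dict -> AudioConfig(**dict)
)
assert isinstance(cfg.audio_config, AudioConfig)
assert cfg.audio_token_index == 262_143  # default from the signature above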
cpp/gemma_v1/eval.py
ADDED
|
@@ -0,0 +1,635 @@
| 1 |
+
from io import BytesIO
|
| 2 |
+
from urllib.request import urlopen
|
| 3 |
+
import soundfile
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset, Audio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import AutoModel, AutoProcessor, BatchFeature
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 14 |
+
from whisper_normalizer.basic import BasicTextNormalizer
|
| 15 |
+
import sacrebleu
|
| 16 |
+
from jiwer import cer, wer
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import soundfile as sf
|
| 19 |
+
import re
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import opencc
|
| 22 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 23 |
+
normalizer = {
|
| 24 |
+
"en_us" : EnglishTextNormalizer(),
|
| 25 |
+
"other" : BasicTextNormalizer()
|
| 26 |
+
}
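Both normalizers reduce references and hypotheses to a comparable form before scoring: the English one lowercases, expands contractions and abbreviations, and drops punctuation, while the basic one lowercases and strips punctuation for other languages. A small illustration (outputs shown as comments reflect whisper_normalizer's usual behaviour, not values verified against this exact version):

print(normalizer["en_us"]("Mr. Smith isn't here!"))  # roughly: "mister smith is not here"
print(normalizer["other"]("Hello, World!"))          # roughly: "hello world"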
|
| 27 |
+
|
| 28 |
+
model_id = "/mnt/jeff/gemma_test"
|
| 29 |
+
revision = "main" #"v1.0"
|
| 30 |
+
|
| 31 |
+
model = AutoModel.from_pretrained(
|
| 32 |
+
model_id, device_map="cuda", revision = revision, trust_remote_code=True
|
| 33 |
+
).eval()
|
| 34 |
+
|
| 35 |
+
processor = AutoProcessor.from_pretrained(
|
| 36 |
+
model_id, revision = revision, trust_remote_code=True
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 40 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
INSTRUCTION = {
|
| 44 |
+
"ast": "Translate the audio to {0}.",
|
| 45 |
+
"asr": "Transcribe the audio clip into text.",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
class BaseAudioDataset(Dataset):
|
| 49 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 50 |
+
self.processor = processor
|
| 51 |
+
self.training = "train" in split
|
| 52 |
+
self.debug = debug
|
| 53 |
+
self.sampling_rate = sampling_rate
|
| 54 |
+
self.name = ""
|
| 55 |
+
|
| 56 |
+
def set_dataset_name(self, name):
|
| 57 |
+
self.name = name
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 61 |
+
original_size = len(data)
|
| 62 |
+
|
| 63 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 64 |
+
|
| 65 |
+
def identify_corrupted_files(example):
|
| 66 |
+
try:
|
| 67 |
+
sf.read(example[audio_field]["path"])
|
| 68 |
+
|
| 69 |
+
for field in text_fields:
|
| 70 |
+
if example[field].replace('"', '') == "":
|
| 71 |
+
return False
|
| 72 |
+
return True
|
| 73 |
+
except Exception:
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 77 |
+
validated_size = len(data)
|
| 78 |
+
|
| 79 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
return data
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 86 |
+
original_size = len(data)
|
| 87 |
+
|
| 88 |
+
def filter_audio_by_length(example):
|
| 89 |
+
try:
|
| 90 |
+
audio = example[audio_field]['array']
|
| 91 |
+
channel = 1
|
| 92 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 93 |
+
channel = audio.ndim
|
| 94 |
+
audio = audio.squeeze()
|
| 95 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 96 |
+
return min_sec <= audio_length <= max_sec
|
| 97 |
+
except Exception as e:
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 101 |
+
filtered_size = len(data)
|
| 102 |
+
|
| 103 |
+
return data
|
| 104 |
+
|
| 105 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 106 |
+
user_message = {
|
| 107 |
+
'role': 'user',
|
| 108 |
+
'content': '<start_of_audio>' + instruction,
|
| 109 |
+
}
|
| 110 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 111 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
inputs = self.processor(
|
| 115 |
+
text=prompt,
|
| 116 |
+
audio=[audio_array],
|
| 117 |
+
add_special_tokens=False,
|
| 118 |
+
return_tensors='pt'
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
input_ids = inputs.input_ids
|
| 122 |
+
token_type_ids = inputs.token_type_ids
|
| 123 |
+
|
| 124 |
+
return {
|
| 125 |
+
'input_ids': input_ids,
|
| 126 |
+
'token_type_ids': token_type_ids,
|
| 127 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 128 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 129 |
+
'input_modes': inputs.input_modes,
|
| 130 |
+
'answer': answer_text,
|
| 131 |
+
}
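prepare_model_inputs wraps a single audio clip and an instruction into the processor's chat format. A rough standalone sketch of the same flow with a synthetic one-second clip, reusing the `processor` and `INSTRUCTION` objects defined at the top of this script:

import numpy as np

sr = 16000
audio_array = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)  # 1 s of a 440 Hz tone

user_message = {"role": "user", "content": "<start_of_audio>" + INSTRUCTION["asr"]}
prompt = processor.tokenizer.apply_chat_template(
    [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
)
inputs = processor(text=prompt, audio=[audio_array], add_special_tokens=False, return_tensors="pt")
print(inputs.input_ids.shape, inputs.audio_embed_sizes)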
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Libri Speech Dataset Class
|
| 135 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 136 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 137 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 138 |
+
|
| 139 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 140 |
+
# only ASR
|
| 141 |
+
self.ast = False
|
| 142 |
+
self.lang = "en"
|
| 143 |
+
|
| 144 |
+
# load dataset
|
| 145 |
+
self.data = load_dataset("openslr/librispeech_asr",
|
| 146 |
+
subset,
|
| 147 |
+
split=split,
|
| 148 |
+
trust_remote_code=True,
|
| 149 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# (Optional) Audio length Filtering
|
| 153 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 154 |
+
|
| 155 |
+
# Instruction Setting
|
| 156 |
+
self.instruction = INSTRUCTION["asr"]
|
| 157 |
+
|
| 158 |
+
def __len__(self):
|
| 159 |
+
return len(self.data)
|
| 160 |
+
|
| 161 |
+
def __getitem__(self, idx):
|
| 162 |
+
data = self.data[idx]
|
| 163 |
+
|
| 164 |
+
# Libri Speech is only for ASR
|
| 165 |
+
answer_text = data["text"].replace('"', '')
|
| 166 |
+
|
| 167 |
+
return self.prepare_model_inputs(
|
| 168 |
+
data["audio"]["array"],
|
| 169 |
+
INSTRUCTION["asr"],
|
| 170 |
+
answer_text
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# common_voice_16_1 dataset
|
| 174 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 175 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 176 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 177 |
+
|
| 178 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 179 |
+
# only ASR
|
| 180 |
+
self.ast = False
|
| 181 |
+
self.lang=source_lang
|
| 182 |
+
|
| 183 |
+
# load dataset
|
| 184 |
+
self.data = load_dataset("mozilla-foundation/common_voice_16_1",
|
| 185 |
+
source_lang,
|
| 186 |
+
split=split,
|
| 187 |
+
trust_remote_code=True,
|
| 188 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 189 |
+
)
|
| 190 |
+
def prepare_dataset(batch):
|
| 191 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 192 |
+
transcription = batch["sentence"]
|
| 193 |
+
|
| 194 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 195 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 196 |
+
transcription = transcription[1:-1]
|
| 197 |
+
|
| 198 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 199 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 200 |
+
transcription = transcription + "."
|
| 201 |
+
|
| 202 |
+
batch["sentence"] = transcription
|
| 203 |
+
|
| 204 |
+
return batch
|
| 205 |
+
self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 206 |
+
|
| 207 |
+
# (Optional) Audio length Filtering
|
| 208 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 209 |
+
|
| 210 |
+
# Instruction Setting
|
| 211 |
+
self.instruction = INSTRUCTION["asr"]
|
| 212 |
+
|
| 213 |
+
def __len__(self):
|
| 214 |
+
return len(self.data)
|
| 215 |
+
|
| 216 |
+
def __getitem__(self, idx):
|
| 217 |
+
data = self.data[idx]
|
| 218 |
+
|
| 219 |
+
answer_text = data["sentence"]
|
| 220 |
+
return self.prepare_model_inputs(
|
| 221 |
+
data["audio"]["array"],
|
| 222 |
+
INSTRUCTION["asr"],
|
| 223 |
+
answer_text
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Fleurs Dataset Class
|
| 228 |
+
class FleursDataset(BaseAudioDataset):
|
| 229 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 230 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 231 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 232 |
+
|
| 233 |
+
self.set_dataset_name("Fleurs")
|
| 234 |
+
# Mode Setting (ASR or AST)
|
| 235 |
+
if mode not in ["asr", "ast"]:
|
| 236 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 237 |
+
|
| 238 |
+
self.mode = mode
|
| 239 |
+
self.ast = (mode == "ast")
|
| 240 |
+
self.source_lang = source_lang
|
| 241 |
+
|
| 242 |
+
# Language name mapping (expand if needed)
|
| 243 |
+
self.lang_names = {
|
| 244 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
# load dataset - source language dataset
|
| 248 |
+
self.data = load_dataset("google/fleurs",
|
| 249 |
+
source_lang,
|
| 250 |
+
split=split,
|
| 251 |
+
trust_remote_code=True,
|
| 252 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 253 |
+
)
|
| 254 |
+
def prepare_dataset(batch):
|
| 255 |
+
import opencc
|
| 256 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 257 |
+
if self.ast:
|
| 258 |
+
translation = converter.convert(batch["translation"])
|
| 259 |
+
batch["translation"] = translation
|
| 260 |
+
else:
|
| 261 |
+
transcription = converter.convert(batch["transcription"])
|
| 262 |
+
batch["transcription"] = transcription
|
| 263 |
+
|
| 264 |
+
return batch
|
| 265 |
+
if (source_lang=="cmn_hans_cn" and not self.ast) or (self.ast and target_lang=="cmn_hans_cn"):
|
| 266 |
+
self.data=self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 267 |
+
|
| 268 |
+
# (Optional) Audio length Filtering
|
| 269 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 270 |
+
self.target_lang_name = ""
|
| 271 |
+
# When AST mode, load target language dataset.
|
| 272 |
+
if self.ast:
|
| 273 |
+
if target_lang is None:
|
| 274 |
+
raise ValueError("AST mode requires target_lang.")
|
| 275 |
+
|
| 276 |
+
self.target_lang = target_lang
|
| 277 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 278 |
+
|
| 279 |
+
# load dataset - target language dataset (for translation)
|
| 280 |
+
target_data = load_dataset("google/fleurs",
|
| 281 |
+
target_lang,
|
| 282 |
+
split=split,
|
| 283 |
+
trust_remote_code=True,
|
| 284 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
source_dict = {item['id']: item for item in self.data}
|
| 288 |
+
target_dict = {item['id']: item for item in target_data}
|
| 289 |
+
|
| 290 |
+
# only Common ID, add translation fields
|
| 291 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 292 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 293 |
+
self.data = [
|
| 294 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 295 |
+
for id in common_ids
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
# Instruction Setting - use target language name
|
| 299 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 300 |
+
self.instruction = INSTRUCTION["ast"]
|
| 301 |
+
else:
|
| 302 |
+
# ASR mode
|
| 303 |
+
self.lang = source_lang
|
| 304 |
+
self.instruction = INSTRUCTION["asr"]
|
| 305 |
+
|
| 306 |
+
if self.debug:
|
| 307 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 308 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 309 |
+
if self.ast:
|
| 310 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 311 |
+
print(f"dataset size: {len(self.data)}")
|
| 312 |
+
|
| 313 |
+
def __len__(self):
|
| 314 |
+
return len(self.data)
|
| 315 |
+
|
| 316 |
+
def __getitem__(self, idx):
|
| 317 |
+
data = self.data[idx]
|
| 318 |
+
audio_array = data["audio"]["array"]
|
| 319 |
+
|
| 320 |
+
if self.ast:
|
| 321 |
+
answer_text = data["translation"]
|
| 322 |
+
else:
|
| 323 |
+
answer_text = data["transcription"]
|
| 324 |
+
|
| 325 |
+
return self.prepare_model_inputs(
|
| 326 |
+
audio_array,
|
| 327 |
+
self.instruction.format(self.target_lang_name),
|
| 328 |
+
answer_text
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 332 |
+
"""
|
| 333 |
+
Pad a list of sequences to the same length.
|
| 334 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 335 |
+
"""
|
| 336 |
+
assert padding_side in ['right', 'left']
|
| 337 |
+
max_size = sequences[0].size()
|
| 338 |
+
trailing_dims = max_size[1:]
|
| 339 |
+
max_len = max(len(seq) for seq in sequences)
|
| 340 |
+
batch_size = len(sequences)
|
| 341 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 342 |
+
for i, seq in enumerate(sequences):
|
| 343 |
+
length = seq.size(0)
|
| 344 |
+
if padding_side == 'right':
|
| 345 |
+
output.data[i, :length] = seq
|
| 346 |
+
else:
|
| 347 |
+
output.data[i, -length:] = seq
|
| 348 |
+
return output
|
| 349 |
+
|
| 350 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 351 |
+
"""
|
| 352 |
+
cat along dim, while pad to max for all other dims
|
| 353 |
+
"""
|
| 354 |
+
ndim = tensors[0].dim()
|
| 355 |
+
assert all(
|
| 356 |
+
t.dim() == ndim for t in tensors[1:]
|
| 357 |
+
), 'All tensors must have the same number of dimensions'
|
| 358 |
+
|
| 359 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 360 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 361 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 362 |
+
|
| 363 |
+
index = 0
|
| 364 |
+
for t in tensors:
|
| 365 |
+
# Create a slice list where every dimension except dim is full slice
|
| 366 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 367 |
+
# Update only the concat dimension slice
|
| 368 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 369 |
+
|
| 370 |
+
output[slices] = t
|
| 371 |
+
index += t.shape[dim]
|
| 372 |
+
|
| 373 |
+
return output
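A toy illustration of the two padding helpers defined above: pad_sequence left-pads token-id rows to a common length, while cat_with_pad concatenates feature tensors along one dimension and pads every other dimension to the per-batch maximum.

import torch

ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad_sequence(ids, padding_side='left', padding_value=0))
# tensor([[1, 2, 3],
#         [0, 4, 5]])

feats = [torch.ones(1, 10, 80), torch.ones(1, 7, 80)]
print(cat_with_pad(feats, dim=0).shape)  # torch.Size([2, 10, 80])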
|
| 374 |
+
|
| 375 |
+
def covost_collate_fn(batch):
|
| 376 |
+
input_ids_list = []
|
| 377 |
+
input_audio_embeds_list = []
|
| 378 |
+
audio_embed_sizes_list = []
|
| 379 |
+
audio_attention_mask_list = []
|
| 380 |
+
input_modes_list = []
|
| 381 |
+
answer_list = []
|
| 382 |
+
for inputs in batch:
|
| 383 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 384 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 385 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 386 |
+
audio_attention_mask_list.append(
|
| 387 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 388 |
+
)
|
| 389 |
+
input_modes_list.append(inputs['input_modes'])
|
| 390 |
+
answer_list.append(inputs['answer'])
|
| 391 |
+
|
| 392 |
+
try:
|
| 393 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 394 |
+
audio_attention_mask = (
|
| 395 |
+
pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
|
| 396 |
+
if len(audio_attention_mask_list) > 1
|
| 397 |
+
else None
|
| 398 |
+
)
|
| 399 |
+
except Exception as e:
|
| 400 |
+
print(e)
|
| 401 |
+
print(input_ids_list)
|
| 402 |
+
print(audio_attention_mask)
|
| 403 |
+
raise
|
| 404 |
+
attention_mask = (input_ids != 0).long()
|
| 405 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
|
| 406 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
|
| 407 |
+
input_modes = torch.cat(input_modes_list)
|
| 408 |
+
|
| 409 |
+
return BatchFeature(
|
| 410 |
+
{
|
| 411 |
+
'input_ids': input_ids,
|
| 412 |
+
'attention_mask': attention_mask,
|
| 413 |
+
'input_audio_embeds': input_audio_embeds,
|
| 414 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 415 |
+
'audio_attention_mask': audio_attention_mask,
|
| 416 |
+
'input_modes': input_modes,
|
| 417 |
+
'answer': answer_list,
|
| 418 |
+
}
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
|
| 422 |
+
filename = f"{task}_{dataset_name}_{source_lang}"
|
| 423 |
+
if target_lang:
|
| 424 |
+
filename += f"_to_{target_lang}"
|
| 425 |
+
if sample_idx is not None:
|
| 426 |
+
filename += f"_sample_{sample_idx}"
|
| 427 |
+
|
| 428 |
+
filepath = os.path.join(results_dir, f"{filename}.json")
|
| 429 |
+
|
| 430 |
+
results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 431 |
+
|
| 432 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 433 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 434 |
+
|
| 435 |
+
return filepath
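For example, an ASR run over CommonVoice zh-TW without a target language lands in a path like the following (the timestamp in results_dir is illustrative):

save_results(results, "CommonVoice_zh-TW", "asr", "zh-TW")
# -> evaluation_results_20250101_120000/asr_CommonVoice_zh-TW_zh-TW.json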
|
| 436 |
+
|
| 437 |
+
def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 4, is_asr=True):
|
| 438 |
+
task_type = "asr" if is_asr else "translation"
|
| 439 |
+
eval_lang = source_lang if is_asr else target_lang
|
| 440 |
+
if eval_lang in normalizer:
|
| 441 |
+
eval_normalizer = normalizer[eval_lang]
|
| 442 |
+
else:
|
| 443 |
+
eval_normalizer = normalizer['other']
|
| 444 |
+
sample_results = []
|
| 445 |
+
|
| 446 |
+
if num_samples > 0 and num_samples < len(dataset):
|
| 447 |
+
indices = np.random.choice(len(dataset), num_samples, replace=False)
|
| 448 |
+
dataset = dataset.select(indices)
|
| 449 |
+
|
| 450 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
|
| 451 |
+
|
| 452 |
+
evaluated_samples = {}
|
| 453 |
+
|
| 454 |
+
for batch_idx, batch in enumerate(tqdm(dataloader)):
|
| 455 |
+
batch_references = batch.pop("answer")
|
| 456 |
+
|
| 457 |
+
if torch.cuda.is_available():
|
| 458 |
+
try:
|
| 459 |
+
batch = {k: v.to("cuda") for k, v in batch.items()}
|
| 460 |
+
except Exception as e:
|
| 461 |
+
print(f"failed to move batch to cuda: {e}")
|
| 462 |
+
break
|
| 463 |
+
|
| 464 |
+
with torch.inference_mode():
|
| 465 |
+
generate_ids = model.generate(**batch,
|
| 466 |
+
max_new_tokens=256,
|
| 467 |
+
#temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 471 |
+
generate_ids = generate_ids[:, input_lengths:]
|
| 472 |
+
|
| 473 |
+
batch_predictions = processor.batch_decode(
|
| 474 |
+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
|
| 478 |
+
idx = batch_idx * batch_size + i
|
| 479 |
+
sample_result = {
|
| 480 |
+
"id": idx,
|
| 481 |
+
"reference": reference,
|
| 482 |
+
"prediction": converter.convert(prediction)
|
| 483 |
+
}
|
| 484 |
+
sample_results.append(sample_result)
|
| 485 |
+
|
| 486 |
+
if (batch_idx + 1) % 10 == 0:
|
| 487 |
+
temp_results = []
|
| 488 |
+
|
| 489 |
+
for item in sample_results:
|
| 490 |
+
sample_id = item["id"]
|
| 491 |
+
|
| 492 |
+
if sample_id in evaluated_samples:
|
| 493 |
+
temp_item = item.copy()
|
| 494 |
+
temp_item.update(evaluated_samples[sample_id])
|
| 495 |
+
temp_results.append(temp_item)
|
| 496 |
+
else:
|
| 497 |
+
temp_item = item.copy()
|
| 498 |
+
try:
|
| 499 |
+
ref = eval_normalizer(item["reference"])
|
| 500 |
+
pred = eval_normalizer(item["prediction"])
|
| 501 |
+
|
| 502 |
+
# BLEU, WER/CER
|
| 503 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 504 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 505 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 506 |
+
|
| 507 |
+
metrics = {
|
| 508 |
+
"bleu": utt_bleu,
|
| 509 |
+
"cer": min(100,utt_cer),
|
| 510 |
+
"wer": utt_wer
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
evaluated_samples[sample_id] = metrics
|
| 514 |
+
temp_item.update(metrics)
|
| 515 |
+
except Exception as e:
|
| 516 |
+
print(f"Error evaluating sample {sample_id}: {e}")
|
| 517 |
+
metrics = {
|
| 518 |
+
"bleu": 0,
|
| 519 |
+
"cer": 100,
|
| 520 |
+
"wer": 100,
|
| 521 |
+
"error": str(e)
|
| 522 |
+
}
|
| 523 |
+
evaluated_samples[sample_id] = metrics
|
| 524 |
+
temp_item.update(metrics)
|
| 525 |
+
|
| 526 |
+
temp_results.append(temp_item)
|
| 527 |
+
|
| 528 |
+
partial_results = {
|
| 529 |
+
"task": task_type,
|
| 530 |
+
"source_lang": source_lang,
|
| 531 |
+
"target_lang": target_lang,
|
| 532 |
+
"num_samples": len(temp_results),
|
| 533 |
+
"sample_results": temp_results
|
| 534 |
+
}
|
| 535 |
+
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
|
| 536 |
+
|
| 537 |
+
for item in sample_results:
|
| 538 |
+
ref = eval_normalizer(item["reference"])
|
| 539 |
+
pred = eval_normalizer(item["prediction"])
|
| 540 |
+
|
| 541 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 542 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 543 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 544 |
+
|
| 545 |
+
item.update({
|
| 546 |
+
"bleu": utt_bleu,
|
| 547 |
+
"cer": min(100,utt_cer),
|
| 548 |
+
"wer": utt_wer
|
| 549 |
+
})
|
| 550 |
+
|
| 551 |
+
avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
|
| 552 |
+
avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
|
| 553 |
+
avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
|
| 554 |
+
|
| 555 |
+
results = {
|
| 556 |
+
"dataset": dataset.name,
|
| 557 |
+
"task": task_type,
|
| 558 |
+
"source_lang": source_lang,
|
| 559 |
+
"target_lang": target_lang,
|
| 560 |
+
"num_samples": len(sample_results),
|
| 561 |
+
"metrics": {
|
| 562 |
+
"bleu": avg_bleu,
|
| 563 |
+
"cer": avg_cer,
|
| 564 |
+
"wer": avg_wer
|
| 565 |
+
},
|
| 566 |
+
"sample_results": sample_results
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
save_results(results, dataset.name, task_type, source_lang, target_lang)
|
| 570 |
+
return results
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
if __name__ == "__main__":
|
| 574 |
+
|
| 575 |
+
source_languages = [
|
| 576 |
+
("en_us", "English"),
|
| 577 |
+
]
|
| 578 |
+
|
| 579 |
+
target_languages = [
|
| 580 |
+
("zh-TW", "zh-TW"),
|
| 581 |
+
]
|
| 582 |
+
|
| 583 |
+
num_samples = -1
|
| 584 |
+
batch_size = 32
|
| 585 |
+
|
| 586 |
+
for source_lang, target_lang in zip(source_languages, target_languages):
|
| 587 |
+
print(f"\n===== {source_lang[0]} ASR =====")
|
| 588 |
+
|
| 589 |
+
split = "test"
|
| 590 |
+
|
| 591 |
+
datasets = []
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
commonvoice_speech_tw = CommonVoiceDataset(
|
| 596 |
+
processor=processor,
|
| 597 |
+
source_lang="zh-TW",
|
| 598 |
+
split=split
|
| 599 |
+
)
|
| 600 |
+
datasets.append(commonvoice_speech_tw)
|
| 601 |
+
fleurs = FleursDataset(
|
| 602 |
+
processor=processor,
|
| 603 |
+
split=split,
|
| 604 |
+
source_lang="en_us", # English
|
| 605 |
+
mode="asr"
|
| 606 |
+
)
|
| 607 |
+
datasets.append(fleurs)
|
| 608 |
+
|
| 609 |
+
# Libri Speech Clean ASR mode (English -> English text)
|
| 610 |
+
# libri_speech_clean = LibriSpeechDataset(
|
| 611 |
+
# processor=processor,
|
| 612 |
+
# subset="clean",
|
| 613 |
+
# split=split
|
| 614 |
+
# )
|
| 615 |
+
# datasets.append(libri_speech_clean)
|
| 616 |
+
|
| 617 |
+
# # Libri Speech Other ASR mode (English -> English text)
|
| 618 |
+
# libri_speech_other = LibriSpeechDataset(
|
| 619 |
+
# processor=processor,
|
| 620 |
+
# subset="other",
|
| 621 |
+
# split=split
|
| 622 |
+
# )
|
| 623 |
+
# datasets.append(libri_speech_other)
|
| 624 |
+
|
| 625 |
+
# Fleurs ASR mode (English -> English text)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
for dataset in datasets:
|
| 629 |
+
# ASR
|
| 630 |
+
asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
|
| 631 |
+
|
| 632 |
+
print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR===")
|
| 633 |
+
print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
|
| 634 |
+
print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
|
| 635 |
+
print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
|
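Each run writes per-sample metrics to JSON, so the corpus-level numbers can also be recomputed offline. A minimal sketch, assuming only the directory layout produced by save_results above:

import glob, json

for path in glob.glob("evaluation_results_*/asr_*.json"):
    with open(path, encoding="utf-8") as f:
        res = json.load(f)
    samples = res["sample_results"]
    if samples and "wer" in samples[0]:
        wer_avg = sum(s["wer"] for s in samples) / len(samples)
        print(path, round(wer_avg, 2))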
cpp/gemma_v1/eval_multiturn.ipynb
ADDED
The diff for this file is too large to render.
cpp/gemma_v1/eval_multiturn.py
ADDED
|
@@ -0,0 +1,211 @@
| 1 |
+
from io import BytesIO
|
| 2 |
+
from urllib.request import urlopen
|
| 3 |
+
import soundfile
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset, Audio
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import AutoModel, AutoProcessor, BatchFeature
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from whisper_normalizer.english import EnglishTextNormalizer
|
| 14 |
+
from whisper_normalizer.basic import BasicTextNormalizer
|
| 15 |
+
import sacrebleu
|
| 16 |
+
from jiwer import cer, wer
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import soundfile as sf
|
| 19 |
+
import re
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import opencc
|
| 22 |
+
from ASRDataset import *
|
| 23 |
+
|
| 24 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 25 |
+
normalizer = {
|
| 26 |
+
"en_us" : EnglishTextNormalizer(),
|
| 27 |
+
"other" : BasicTextNormalizer()
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
model_id = "/mnt/jeff/gemma_test"
|
| 31 |
+
revision = "main" #"v1.0"
|
| 32 |
+
|
| 33 |
+
model = AutoModel.from_pretrained(
|
| 34 |
+
model_id, device_map="cuda", revision = revision, trust_remote_code=True
|
| 35 |
+
).eval()
|
| 36 |
+
|
| 37 |
+
processor = AutoProcessor.from_pretrained(
|
| 38 |
+
model_id, revision = revision, trust_remote_code=True
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 42 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
INSTRUCTION = {
|
| 46 |
+
"ast": "Translate the audio to {0}.",
|
| 47 |
+
"asr": "Transcribe the audio clip into text.",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
|
| 53 |
+
filename = f"{task}_{dataset_name}_{source_lang}"
|
| 54 |
+
if target_lang:
|
| 55 |
+
filename += f"_to_{target_lang}"
|
| 56 |
+
if sample_idx is not None:
|
| 57 |
+
filename += f"_sample_{sample_idx}"
|
| 58 |
+
|
| 59 |
+
filepath = os.path.join(results_dir, f"{filename}.json")
|
| 60 |
+
|
| 61 |
+
results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 62 |
+
|
| 63 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 64 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 65 |
+
|
| 66 |
+
return filepath
|
| 67 |
+
|
| 68 |
+
def evaluate_task(dataset):
# NOTE: task_type, source_lang, target_lang and eval_normalizer are used below but
# were never defined in this script; minimal defaults are assumed here so it can run.
task_type = "asr"
source_lang, target_lang = "multiturn", None
eval_normalizer = normalizer["other"]
sample_results = []
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
|
| 73 |
+
|
| 74 |
+
evaluated_samples = {}
|
| 75 |
+
|
| 76 |
+
for batch_idx, batch in enumerate(tqdm(dataloader)):
|
| 77 |
+
|
| 78 |
+
if torch.cuda.is_available():
|
| 79 |
+
try:
|
| 80 |
+
batch = {k: v.to("cuda") for k, v in batch.items()}
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"failed to move batch to cuda: {e}")
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
with torch.inference_mode():
|
| 86 |
+
generate_ids = model.generate(**batch,
|
| 87 |
+
max_new_tokens=256,
|
| 88 |
+
#temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 92 |
+
generate_ids = generate_ids[:, input_lengths:]
|
| 93 |
+
|
| 94 |
+
batch_predictions = processor.batch_decode(
|
| 95 |
+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 96 |
+
)
|
| 97 |
+
input_lengths = batch['input_ids'].shape[1]
|
| 98 |
+
label_ids = generate_ids[:, input_lengths:]
|
| 99 |
+
batch_references = processor.batch_decode(
|
| 100 |
+
label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
|
| 104 |
+
idx = batch_idx + i
|
| 105 |
+
sample_result = {
|
| 106 |
+
"id": idx,
|
| 107 |
+
"reference": reference,
|
| 108 |
+
"prediction": converter.convert(prediction)
|
| 109 |
+
}
|
| 110 |
+
sample_results.append(sample_result)
|
| 111 |
+
|
| 112 |
+
if (batch_idx + 1) % 10 == 0:
|
| 113 |
+
temp_results = []
|
| 114 |
+
|
| 115 |
+
for item in sample_results:
|
| 116 |
+
sample_id = item["id"]
|
| 117 |
+
|
| 118 |
+
if sample_id in evaluated_samples:
|
| 119 |
+
temp_item = item.copy()
|
| 120 |
+
temp_item.update(evaluated_samples[sample_id])
|
| 121 |
+
temp_results.append(temp_item)
|
| 122 |
+
else:
|
| 123 |
+
temp_item = item.copy()
|
| 124 |
+
try:
|
| 125 |
+
ref = eval_normalizer(item["reference"])
|
| 126 |
+
pred = eval_normalizer(item["prediction"])
|
| 127 |
+
|
| 128 |
+
# BLEU, WER/CER
|
| 129 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 130 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 131 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 132 |
+
|
| 133 |
+
metrics = {
|
| 134 |
+
"bleu": utt_bleu,
|
| 135 |
+
"cer": min(100,utt_cer),
|
| 136 |
+
"wer": utt_wer
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
evaluated_samples[sample_id] = metrics
|
| 140 |
+
temp_item.update(metrics)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
print(f"Error evaluating sample {sample_id}: {e}")
|
| 143 |
+
metrics = {
|
| 144 |
+
"bleu": 0,
|
| 145 |
+
"cer": 100,
|
| 146 |
+
"wer": 100,
|
| 147 |
+
"error": str(e)
|
| 148 |
+
}
|
| 149 |
+
evaluated_samples[sample_id] = metrics
|
| 150 |
+
temp_item.update(metrics)
|
| 151 |
+
|
| 152 |
+
temp_results.append(temp_item)
|
| 153 |
+
|
| 154 |
+
partial_results = {
|
| 155 |
+
"task": task_type,
|
| 156 |
+
"source_lang": source_lang,
|
| 157 |
+
"target_lang": target_lang,
|
| 158 |
+
"num_samples": len(temp_results),
|
| 159 |
+
"sample_results": temp_results
|
| 160 |
+
}
|
| 161 |
+
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
|
| 162 |
+
|
| 163 |
+
for item in sample_results:
|
| 164 |
+
ref = eval_normalizer(item["reference"])
|
| 165 |
+
pred = eval_normalizer(item["prediction"])
|
| 166 |
+
|
| 167 |
+
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
|
| 168 |
+
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
|
| 169 |
+
utt_wer = round(wer(ref, pred) * 100, 2)
|
| 170 |
+
|
| 171 |
+
item.update({
|
| 172 |
+
"bleu": utt_bleu,
|
| 173 |
+
"cer": min(100,utt_cer),
|
| 174 |
+
"wer": utt_wer
|
| 175 |
+
})
|
| 176 |
+
|
| 177 |
+
avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
|
| 178 |
+
avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
|
| 179 |
+
avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
|
| 180 |
+
|
| 181 |
+
results = {
|
| 182 |
+
"dataset": dataset.name,
|
| 183 |
+
"task": task_type,
|
| 184 |
+
"source_lang": source_lang,
|
| 185 |
+
"target_lang": target_lang,
|
| 186 |
+
"num_samples": len(sample_results),
|
| 187 |
+
"metrics": {
|
| 188 |
+
"bleu": avg_bleu,
|
| 189 |
+
"cer": avg_cer,
|
| 190 |
+
"wer": avg_wer
|
| 191 |
+
},
|
| 192 |
+
"sample_results": sample_results
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
save_results(results, dataset.name, task_type, source_lang, target_lang)
|
| 196 |
+
return results
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
|
| 201 |
+
datasets = []
|
| 202 |
+
pickup_dataset = MultiturnAudioDataset(split='eval',processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 203 |
+
datasets.append(pickup_dataset)
|
| 204 |
+
for dataset in datasets:
|
| 205 |
+
# ASR
|
| 206 |
+
asr_results = evaluate_task(dataset)
|
| 207 |
+
|
| 208 |
+
print(f"\n=== {asr_results.get('dataset', 'Dataset')}")
|
| 209 |
+
print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
|
| 210 |
+
print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
|
| 211 |
+
print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
|
cpp/gemma_v1/merge_lora.ipynb
ADDED
|
@@ -0,0 +1,119 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from safetensors import safe_open\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"lora = {}\n",
|
| 12 |
+
"with safe_open(\"/data2/bjh/diffusion-pipe/cosmos_test/20250327_02-37-25/epoch5/adapter_model.safetensors\", framework=\"pt\", device='cpu') as f:\n",
|
| 13 |
+
" for k in f.keys():\n",
|
| 14 |
+
" lora[k] = f.get_tensor(k)"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 2,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"tensors = {}\n",
|
| 24 |
+
"with safe_open(\"/data2/bjh/ComfyUI/models/diffusion_models/Cosmos-1_0-Diffusion-14B-Text2World.safetensors\", framework=\"pt\", device='cpu') as f:\n",
|
| 25 |
+
" for k in f.keys():\n",
|
| 26 |
+
" tensors[k] = f.get_tensor(k)"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": 3,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [
|
| 34 |
+
{
|
| 35 |
+
"data": {
|
| 36 |
+
"text/plain": [
|
| 37 |
+
"1152"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
"execution_count": 3,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"output_type": "execute_result"
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"source": [
|
| 46 |
+
"len(lora)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 4,
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"name_lis = []\n",
|
| 56 |
+
"for k in lora:\n",
|
| 57 |
+
" a = k.split('.')[1:][:-2]\n",
|
| 58 |
+
" name = '.'.join(a)\n",
|
| 59 |
+
" name_lis.append(name)\n",
|
| 60 |
+
"name_lis=set(name_lis)"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 5,
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"import torch\n",
|
| 70 |
+
"new_dic = {}\n",
|
| 71 |
+
"for k in tensors:\n",
|
| 72 |
+
" name='.'.join(k.split('.')[1:][:-1])\n",
|
| 73 |
+
" if name in name_lis:\n",
|
| 74 |
+
" a,b = lora['diffusion_model.'+name+'.lora_A.weight'],lora['diffusion_model.'+name+'.lora_B.weight']\n",
|
| 75 |
+
" new_dic[k]=tensors[k]+torch.matmul(b,a)\n",
|
| 76 |
+
" else:\n",
|
| 77 |
+
" new_dic[k]=tensors[k]"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 6,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"from safetensors.torch import save_file\n",
|
| 87 |
+
"save_file(new_dic,'test.safetensors')"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": []
|
| 96 |
+
}
|
| 97 |
+
],
|
| 98 |
+
"metadata": {
|
| 99 |
+
"kernelspec": {
|
| 100 |
+
"display_name": "dp",
|
| 101 |
+
"language": "python",
|
| 102 |
+
"name": "python3"
|
| 103 |
+
},
|
| 104 |
+
"language_info": {
|
| 105 |
+
"codemirror_mode": {
|
| 106 |
+
"name": "ipython",
|
| 107 |
+
"version": 3
|
| 108 |
+
},
|
| 109 |
+
"file_extension": ".py",
|
| 110 |
+
"mimetype": "text/x-python",
|
| 111 |
+
"name": "python",
|
| 112 |
+
"nbconvert_exporter": "python",
|
| 113 |
+
"pygments_lexer": "ipython3",
|
| 114 |
+
"version": "3.12.9"
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"nbformat": 4,
|
| 118 |
+
"nbformat_minor": 2
|
| 119 |
+
}
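The notebook above folds the adapter into the base weights as W + B·A. A merge normally also applies the LoRA scaling factor (lora_alpha / r, or lora_alpha / sqrt(r) for rsLoRA); a minimal sketch of that variant, with the scaling constants assumed rather than read from the adapter config:

import torch

def merge_lora_weight(w, lora_a, lora_b, lora_alpha=16, r=16):
    # W' = W + (alpha / r) * B @ A
    return w + (lora_alpha / r) * torch.matmul(lora_b, lora_a)

w = torch.randn(64, 64)
a, b = torch.randn(16, 64), torch.randn(64, 16)
print(merge_lora_weight(w, a, b).shape)  # torch.Size([64, 64])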
|
cpp/gemma_v1/model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ddd3e8916f7ad6ad92651ac288227995c1d34628f0f888eb2dc5b9acb4dc0121
|
| 3 |
+
size 4976361384
|
cpp/gemma_v1/model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c343a455ca768923cb3b9ab77cbb91c9cd2526a1bee5740cf9cf86bfa85a0a7b
|
| 3 |
+
size 4984907872
|
cpp/gemma_v1/model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad04449d015f4efbda75d6cc41e06296b4da996cd84053fa6f9791fe16d55d03
|
| 3 |
+
size 732141104
|
cpp/gemma_v1/model.safetensors.index.json
ADDED
The diff for this file is too large to render.
cpp/gemma_v1/modeling_gemma3omni.py
ADDED
|
@@ -0,0 +1,668 @@
| 1 |
+
import copy
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from transformers.activations import ACT2FN
|
| 10 |
+
from transformers.cache_utils import Cache, HybridCache, StaticCache
|
| 11 |
+
from transformers.generation import GenerationMixin
|
| 12 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
| 14 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 15 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 16 |
+
from transformers.processing_utils import Unpack
|
| 17 |
+
from transformers.utils import (
|
| 18 |
+
add_start_docstrings,
|
| 19 |
+
add_start_docstrings_to_model_forward,
|
| 20 |
+
is_torchdynamo_compiling,
|
| 21 |
+
logging,
|
| 22 |
+
replace_return_docstrings,
|
| 23 |
+
)
|
| 24 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 25 |
+
from transformers import AutoModel, AutoModelForCausalLM
|
| 26 |
+
|
| 27 |
+
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
|
| 28 |
+
|
| 29 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 30 |
+
|
| 31 |
+
from .configuration_gemma3omni import Gemma3OmniConfig
|
| 32 |
+
from .speech_conformer_encoder import ConformerEncoder
|
| 33 |
+
from enum import Enum
|
| 34 |
+
class InputMode(Enum):
|
| 35 |
+
LANGUAGE = 0
|
| 36 |
+
VISION = 1
|
| 37 |
+
SPEECH = 2
|
| 38 |
+
VISION_SPEECH = 3
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
_CONFIG_FOR_DOC = "Gemma3OmniConfig"
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class Gemma3OmniCausalLMOutputWithPast(Gemma3CausalLMOutputWithPast):
|
| 44 |
+
"""
|
| 45 |
+
Multimodal version of `Gemma3CausalLMOutputWithPast`.
|
| 46 |
+
Adds audio-specific hidden states.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
audio_hidden_states (`torch.FloatTensor`, *optional*):
|
| 50 |
+
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
|
| 51 |
+
Audio hidden states produced by the audio encoder.
|
| 52 |
+
"""
|
| 53 |
+
audio_hidden_states: Optional[torch.FloatTensor] = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
GEMMA3_START_DOCSTRING = r"""
|
| 57 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 58 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 59 |
+
etc.)
|
| 60 |
+
|
| 61 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 62 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 63 |
+
and behavior.
|
| 64 |
+
|
| 65 |
+
Parameters:
|
| 66 |
+
config ([`Gemma3Config`]):
|
| 67 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 68 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 69 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
GEMMA3_INPUTS_DOCSTRING = r"""
|
| 75 |
+
Args:
|
| 76 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 77 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 78 |
+
it.
|
| 79 |
+
|
| 80 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 81 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 82 |
+
|
| 83 |
+
[What are input IDs?](../glossary#input-ids)
|
| 84 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 85 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 86 |
+
|
| 87 |
+
- 1 for tokens that are **not masked**,
|
| 88 |
+
- 0 for tokens that are **masked**.
|
| 89 |
+
|
| 90 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 91 |
+
|
| 92 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 93 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 94 |
+
|
| 95 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 96 |
+
`past_key_values`).
|
| 97 |
+
|
| 98 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 99 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 100 |
+
information on the default strategy.
|
| 101 |
+
|
| 102 |
+
- 1 indicates the head is **not masked**,
|
| 103 |
+
- 0 indicates the head is **masked**.
|
| 104 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 105 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 106 |
+
config.n_positions - 1]`.
|
| 107 |
+
|
| 108 |
+
[What are position IDs?](../glossary#position-ids)
|
| 109 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 110 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 111 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 112 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 113 |
+
|
| 114 |
+
Two formats are allowed:
|
| 115 |
+
- a [`~cache_utils.Cache`] instance, see our
|
| 116 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
| 117 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 118 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 119 |
+
cache format.
|
| 120 |
+
|
| 121 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 122 |
+
legacy cache format will be returned.
|
| 123 |
+
|
| 124 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 125 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 126 |
+
of shape `(batch_size, sequence_length)`.
|
| 127 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 128 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 129 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 130 |
+
model's internal embedding lookup matrix.
|
| 131 |
+
use_cache (`bool`, *optional*):
|
| 132 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 133 |
+
`past_key_values`).
|
| 134 |
+
output_attentions (`bool`, *optional*):
|
| 135 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 136 |
+
tensors for more detail.
|
| 137 |
+
output_hidden_states (`bool`, *optional*):
|
| 138 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 139 |
+
more detail.
|
| 140 |
+
return_dict (`bool`, *optional*):
|
| 141 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 142 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 143 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 144 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 145 |
+
the complete sequence length.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
@add_start_docstrings(
|
| 149 |
+
"The bare Gemma3 Model outputting raw hidden-states without any specific head on top.",
|
| 150 |
+
GEMMA3_START_DOCSTRING,
|
| 151 |
+
)
|
| 152 |
+
class Gemma3OmniPreTrainedModel(Gemma3PreTrainedModel):
|
| 153 |
+
config_class = Gemma3OmniConfig
|
| 154 |
+
|
| 155 |
+
@add_start_docstrings(
|
| 156 |
+
"""The GEMMA3 model which consists of a vision backbone and a language model.""",
|
| 157 |
+
GEMMA3_START_DOCSTRING,
|
| 158 |
+
)
|
| 159 |
+
class Gemma3OmniForConditionalGeneration(Gemma3OmniPreTrainedModel, GenerationMixin):
|
| 160 |
+
def __init__(self, config: Gemma3OmniConfig):
|
| 161 |
+
super().__init__(config)
|
| 162 |
+
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
| 163 |
+
audio_config = config.audio_config.to_diff_dict()
|
| 164 |
+
for item in ['transformers_version', 'model_type', 'torch_dtype']:
|
| 165 |
+
if item in audio_config:
|
| 166 |
+
audio_config.pop(item)
|
| 167 |
+
self.audio_tower = ConformerEncoder(**audio_config)
|
| 168 |
+
self.audio_tower.post_init({})
|
| 169 |
+
self.audio_tower = self.audio_tower.to(dtype=self.dtype)
|
| 170 |
+
self.audio_projector = nn.Sequential(
|
| 171 |
+
nn.Linear(in_features=config.audio_config.attention_dim, out_features=config.text_config.hidden_size, bias=True),
|
| 172 |
+
nn.GELU(approximate='none'),
|
| 173 |
+
nn.Linear(in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, bias=True)
|
| 174 |
+
).to(dtype=self.dtype)
|
| 175 |
+
|
| 176 |
+
self.multi_modal_projector = Gemma3MultiModalProjector(config)
|
| 177 |
+
self.vocab_size = config.text_config.vocab_size
|
| 178 |
+
|
| 179 |
+
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
|
| 180 |
+
|
| 181 |
+
if language_model._tied_weights_keys is not None:
|
| 182 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
| 183 |
+
self.language_model = language_model
|
| 184 |
+
|
| 185 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 186 |
+
self.init_lora()
|
| 187 |
+
self.post_init()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def init_lora(self):
|
| 191 |
+
from peft import LoraConfig, get_peft_model
|
| 192 |
+
import warnings
|
| 193 |
+
print('######################## speech lora #############')
|
| 194 |
+
speech_lora_config = LoraConfig(
|
| 195 |
+
r=self.config.speech_lora['r'],
|
| 196 |
+
lora_alpha=self.config.speech_lora['lora_alpha'],
|
| 197 |
+
target_modules=self.config.speech_lora['layer'],
|
| 198 |
+
use_rslora=self.config.speech_lora['use_rslora'],
|
| 199 |
+
lora_dropout=self.config.speech_lora['dp'],
|
| 200 |
+
task_type="CAUSAL_LM",
|
| 201 |
+
)
|
| 202 |
+
self.language_model.model = get_peft_model(self.language_model.model, speech_lora_config, adapter_name="speech")
|
| 203 |
+
print('######################## text lora #############')
|
| 204 |
+
text_lora_config = LoraConfig(
|
| 205 |
+
r=self.config.text_lora['r'],
|
| 206 |
+
lora_alpha=self.config.text_lora['lora_alpha'],
|
| 207 |
+
target_modules=self.config.text_lora['layer'],
|
| 208 |
+
use_rslora=self.config.text_lora['use_rslora'],
|
| 209 |
+
lora_dropout=self.config.text_lora['dp'],
|
| 210 |
+
task_type="CAUSAL_LM",
|
| 211 |
+
)
|
| 212 |
+
self.language_model.model.base_model.active_adapter.append("text")
|
| 213 |
+
self.language_model.model.add_adapter("text", text_lora_config)
|
| 214 |
+
|
| 215 |
+
def set_lora_adapter(self, adapter_name) -> None:
|
| 216 |
+
from peft.tuners.lora.layer import LoraLayer
|
| 217 |
+
for module in self.modules():
|
| 218 |
+
if isinstance(module, LoraLayer):
|
| 219 |
+
if module.merged:
|
| 220 |
+
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
|
| 221 |
+
module.unmerge()
|
| 222 |
+
module.set_adapter(adapter_name)
|
| 223 |
+
module._disable_adapters = False
|
| 224 |
+
|
| 225 |
+
def unset_lora_adapter(self) -> None:
|
| 226 |
+
# Ref: peft/tuners/tuners_utils.py - enable_adapters()
|
| 227 |
+
# Ref: peft/tuners/lora/layer.py
|
| 228 |
+
from peft.tuners.lora.layer import LoraLayer
|
| 229 |
+
for module in self.modules():
|
| 230 |
+
if isinstance(module, LoraLayer):
|
| 231 |
+
# disable grads on all adapter layers
|
| 232 |
+
# TODO weijian: may use enable_adapters() instead
|
| 233 |
+
for layer_name in module.adapter_layer_names:
|
| 234 |
+
layer = getattr(module, layer_name)
|
| 235 |
+
layer.requires_grad_(False)
|
| 236 |
+
module._disable_adapters = True
|
| 237 |
+
|
| 238 |
+
def get_input_embeddings(self):
|
| 239 |
+
return self.language_model.get_input_embeddings()
|
| 240 |
+
|
| 241 |
+
def set_input_embeddings(self, value):
|
| 242 |
+
self.language_model.set_input_embeddings(value)
|
| 243 |
+
|
| 244 |
+
def get_output_embeddings(self):
|
| 245 |
+
return self.language_model.get_output_embeddings()
|
| 246 |
+
|
| 247 |
+
def set_output_embeddings(self, new_embeddings):
|
| 248 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 249 |
+
|
| 250 |
+
def set_decoder(self, decoder):
|
| 251 |
+
self.language_model.set_decoder(decoder)
|
| 252 |
+
|
| 253 |
+
def get_decoder(self):
|
| 254 |
+
return self.language_model.get_decoder()
|
| 255 |
+
|
| 256 |
+
def _update_causal_mask(
|
| 257 |
+
self,
|
| 258 |
+
attention_mask,
|
| 259 |
+
token_type_ids,
|
| 260 |
+
past_key_values,
|
| 261 |
+
cache_position,
|
| 262 |
+
input_tensor,
|
| 263 |
+
is_training: bool = False,
|
| 264 |
+
):
|
| 265 |
+
if self.config.text_config._attn_implementation == "flash_attention_2":
|
| 266 |
+
return attention_mask
|
| 267 |
+
|
| 268 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 269 |
+
# In this case we assume that the mask comes already in inverted
|
| 270 |
+
# form and requires no inversion or slicing.
|
| 271 |
+
return attention_mask
|
| 272 |
+
|
| 273 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 274 |
+
min_dtype = torch.finfo(self.dtype).min
|
| 275 |
+
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
| 276 |
+
if using_static_cache:
|
| 277 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 278 |
+
elif isinstance(past_key_values, HybridCache):
|
| 279 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 280 |
+
else:
|
| 281 |
+
target_length = (
|
| 282 |
+
attention_mask.shape[-1]
|
| 283 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 284 |
+
else cache_position[0] + sequence_length + 1
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 288 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 289 |
+
return attention_mask
|
| 290 |
+
|
| 291 |
+
causal_mask = torch.full(
|
| 292 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
| 296 |
+
if sequence_length != 1:
|
| 297 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 298 |
+
|
| 299 |
+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
| 300 |
+
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
| 301 |
+
|
| 302 |
+
# Apply bidirectional mask on images if token type ids are provided
|
| 303 |
+
if token_type_ids is not None and sequence_length != 1:
|
| 304 |
+
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
| 305 |
+
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
| 306 |
+
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
| 307 |
+
causal_mask = causal_mask.clone()
|
| 308 |
+
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
| 309 |
+
token_type_mask, 0.0
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
if attention_mask is not None:
|
| 313 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 314 |
+
mask_length = attention_mask.shape[-1]
|
| 315 |
+
|
| 316 |
+
# Then apply padding mask (will mask pad tokens)
|
| 317 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
| 318 |
+
padding_mask = padding_mask == 0
|
| 319 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 320 |
+
padding_mask, min_dtype
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
return causal_mask
|
| 324 |
+
|
| 325 |
+
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 326 |
+
"""
|
| 327 |
+
Projects the last hidden state from the vision model into language model space.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
| 331 |
+
The tensors corresponding to the input images.
|
| 332 |
+
Returns:
|
| 333 |
+
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
| 334 |
+
"""
|
| 335 |
+
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
| 336 |
+
image_features = self.multi_modal_projector(vision_outputs)
|
| 337 |
+
return image_features
|
| 338 |
+
|
| 339 |
+
def get_audio_features(self, input_audio_embeds: torch.FloatTensor, audio_attention_mask: torch.FloatTensor, audio_embed_sizes: torch.FloatTensor):
|
| 340 |
+
"""
|
| 341 |
+
Projects the last hidden state from the audio model into language model space.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
audio_inputs (`torch.FloatTensor]` of shape `(batch_size, sequence_length, feature_dim)`)
|
| 345 |
+
The tensors corresponding to the input audio features.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`).
|
| 349 |
+
"""
|
| 350 |
+
audio_features, masks = self.audio_tower(input_audio_embeds, audio_attention_mask)
|
| 351 |
+
audio_outputs = self.audio_projector(audio_features)
|
| 352 |
+
return audio_outputs
|
| 353 |
+
|
| 354 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
| 355 |
+
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
| 356 |
+
@replace_return_docstrings(output_type=Gemma3OmniCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 357 |
+
def forward(
|
| 358 |
+
self,
|
| 359 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 360 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 361 |
+
input_audio_embeds: torch.FloatTensor = None,
|
| 362 |
+
audio_embed_sizes: torch.FloatTensor = None,
|
| 363 |
+
audio_attention_mask: torch.FloatTensor = None,
|
| 364 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 365 |
+
input_modes: torch.LongTensor = None,
|
| 366 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 367 |
+
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
| 368 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 369 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 370 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 371 |
+
labels: Optional[torch.LongTensor] = None,
|
| 372 |
+
use_cache: Optional[bool] = None,
|
| 373 |
+
output_attentions: Optional[bool] = None,
|
| 374 |
+
output_hidden_states: Optional[bool] = None,
|
| 375 |
+
return_dict: Optional[bool] = None,
|
| 376 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 377 |
+
**lm_kwargs,
|
| 378 |
+
) -> Union[Tuple, Gemma3OmniCausalLMOutputWithPast]:
|
| 379 |
+
r"""
|
| 380 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 381 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 382 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 383 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 384 |
+
|
| 385 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 386 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 387 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 388 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 389 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 390 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
|
| 394 |
+
Example:
|
| 395 |
+
|
| 396 |
+
```python
|
| 397 |
+
>>> from PIL import Image
|
| 398 |
+
>>> import requests
|
| 399 |
+
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 400 |
+
|
| 401 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 402 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 403 |
+
|
| 404 |
+
>>> messages = [
|
| 405 |
+
... {
|
| 406 |
+
... "role": "system",
|
| 407 |
+
... "content": [
|
| 408 |
+
... {"type": "text", "text": "You are a helpful assistant."}
|
| 409 |
+
... ]
|
| 410 |
+
... },
|
| 411 |
+
... {
|
| 412 |
+
... "role": "user", "content": [
|
| 413 |
+
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 414 |
+
... {"type": "text", "text": "Where is the cat standing?"},
|
| 415 |
+
... ]
|
| 416 |
+
... },
|
| 417 |
+
... ]
|
| 418 |
+
|
| 419 |
+
>>> inputs = processor.apply_chat_template(
|
| 420 |
+
... messages,
|
| 421 |
+
... tokenizer=True,
|
| 422 |
+
... return_dict=True,
|
| 423 |
+
... return_tensors="pt",
|
| 424 |
+
... add_generation_prompt=True
|
| 425 |
+
... )
|
| 426 |
+
>>> # Generate
|
| 427 |
+
>>> generate_ids = model.generate(**inputs)
|
| 428 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 429 |
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 430 |
+
```
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 434 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 435 |
+
|
| 436 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 437 |
+
output_hidden_states = (
|
| 438 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 439 |
+
)
|
| 440 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 441 |
+
|
| 442 |
+
if isinstance(input_modes, torch.Tensor):
|
| 443 |
+
# len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
|
| 444 |
+
input_modes = input_modes.unique()
|
| 445 |
+
if len(input_modes) != 1:
|
| 446 |
+
raise ValueError("Elements of input_modes should have the same value")
|
| 447 |
+
|
| 448 |
+
input_mode = InputMode(input_modes.item())
|
| 449 |
+
|
| 450 |
+
if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
|
| 451 |
+
self.unset_lora_adapter()
|
| 452 |
+
#self.set_lora_adapter('vision')
|
| 453 |
+
#audio_projection_mode = 'vision'
|
| 454 |
+
elif input_mode == InputMode.SPEECH:
|
| 455 |
+
self.unset_lora_adapter()
|
| 456 |
+
self.set_lora_adapter('speech')
|
| 457 |
+
#audio_projection_mode = 'speech'
|
| 458 |
+
elif input_mode == InputMode.LANGUAGE:
|
| 459 |
+
self.unset_lora_adapter()
|
| 460 |
+
self.set_lora_adapter('text')
|
| 461 |
+
|
| 462 |
+
#audio_projection_mode = 'speech'
|
| 463 |
+
else:
|
| 464 |
+
raise ValueError(f"Invalid input_mode: {input_mode}")
|
| 465 |
+
|
| 466 |
+
is_training = token_type_ids is not None and labels is not None
|
| 467 |
+
|
| 468 |
+
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
| 469 |
+
if input_ids is not None and self.config.image_token_index >= self.vocab_size or self.config.audio_token_index >= self.vocab_size:
|
| 470 |
+
special_image_mask = input_ids == self.config.image_token_index
|
| 471 |
+
special_audio_mask = input_ids == self.config.audio_token_index
|
| 472 |
+
llm_input_ids = input_ids.clone()
|
| 473 |
+
llm_input_ids[special_image_mask] = 0
|
| 474 |
+
llm_input_ids[special_audio_mask] = 0
|
| 475 |
+
else:
|
| 476 |
+
llm_input_ids = input_ids
|
| 477 |
+
|
| 478 |
+
if inputs_embeds is None:
|
| 479 |
+
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
| 480 |
+
inputs_embeds = inputs_embeds.to(dtype=self.dtype)
|
| 481 |
+
if cache_position is None:
|
| 482 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 483 |
+
cache_position = torch.arange(
|
| 484 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if position_ids is None:
|
| 488 |
+
position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
|
| 489 |
+
|
| 490 |
+
# Merge text and images
|
| 491 |
+
if pixel_values is not None:
|
| 492 |
+
image_features = self.get_image_features(pixel_values)
|
| 493 |
+
|
| 494 |
+
if input_ids is None:
|
| 495 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 496 |
+
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
| 497 |
+
)
|
| 498 |
+
else:
|
| 499 |
+
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
| 500 |
+
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 501 |
+
|
| 502 |
+
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 503 |
+
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
| 504 |
+
raise ValueError(
|
| 505 |
+
f"Number of images does not match number of special image tokens in the input text. "
|
| 506 |
+
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
| 507 |
+
"tokens from image embeddings."
|
| 508 |
+
)
|
| 509 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 510 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
| 511 |
+
|
| 512 |
+
# Merge text and audios
|
| 513 |
+
if input_audio_embeds is not None:
|
| 514 |
+
input_audio_embeds=input_audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 515 |
+
if audio_attention_mask is not None:
|
| 516 |
+
audio_attention_mask=audio_attention_mask.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 517 |
+
audio_features = self.get_audio_features(input_audio_embeds, audio_attention_mask, audio_embed_sizes)
|
| 518 |
+
if input_ids is None:
|
| 519 |
+
special_audio_mask = inputs_embeds == self.get_input_embeddings()(
|
| 520 |
+
torch.tensor(self.config.audio_token_index, dtype=torch.long, device=inputs_embeds.device)
|
| 521 |
+
)
|
| 522 |
+
else:
|
| 523 |
+
special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
|
| 524 |
+
special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 525 |
+
masked_audio_features = []
|
| 526 |
+
for i, size in enumerate(audio_embed_sizes):
|
| 527 |
+
masked_audio_features.append(audio_features[i, :size, :])
|
| 528 |
+
masked_audio_features = torch.cat(masked_audio_features, dim=0)
|
| 529 |
+
|
| 530 |
+
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != masked_audio_features.numel():
|
| 531 |
+
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
|
| 532 |
+
masked_audio_size = audio_embed_sizes#.sum()[0]
|
| 533 |
+
raise ValueError(
|
| 534 |
+
f"Number of audio does not match number of special audio tokens in the input text. "
|
| 535 |
+
f"Got {audio_tokens_in_text} audio tokens in the text but {masked_audio_size} "
|
| 536 |
+
"tokens from audio embeddings. "
|
| 537 |
+
f"{masked_audio_features.numel()} \n"
|
| 538 |
+
f"{inputs_embeds[special_audio_mask].numel()} \n"
|
| 539 |
+
f"{audio_features} \n"
|
| 540 |
+
f"{inputs_embeds[special_audio_mask]} \n"
|
| 541 |
+
f"{special_audio_mask} \n"
|
| 542 |
+
)
|
| 543 |
+
masked_audio_features = masked_audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 544 |
+
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, masked_audio_features)
|
| 545 |
+
# mask out pad-token-ids in labels for BC
|
| 546 |
+
if labels is not None and self.pad_token_id in labels:
|
| 547 |
+
logger.warning_once(
|
| 548 |
+
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
| 549 |
+
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
| 550 |
+
)
|
| 551 |
+
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
|
| 552 |
+
|
| 553 |
+
causal_mask = self._update_causal_mask(
|
| 554 |
+
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
| 555 |
+
)
|
| 556 |
+
outputs = self.language_model(
|
| 557 |
+
attention_mask=causal_mask,
|
| 558 |
+
position_ids=position_ids,
|
| 559 |
+
past_key_values=past_key_values,
|
| 560 |
+
inputs_embeds=inputs_embeds,
|
| 561 |
+
use_cache=use_cache,
|
| 562 |
+
output_attentions=output_attentions,
|
| 563 |
+
output_hidden_states=output_hidden_states,
|
| 564 |
+
return_dict=return_dict,
|
| 565 |
+
cache_position=cache_position,
|
| 566 |
+
logits_to_keep=logits_to_keep,
|
| 567 |
+
**lm_kwargs,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
logits = outputs.logits
|
| 571 |
+
loss = None
|
| 572 |
+
# print('#############################')
|
| 573 |
+
# print(logits)
|
| 574 |
+
if labels is not None:
|
| 575 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
| 576 |
+
logits = logits.float()
|
| 577 |
+
shift_logits = logits[..., :-1, :]
|
| 578 |
+
shift_labels = labels[..., 1:]
|
| 579 |
+
if attention_mask is not None:
|
| 580 |
+
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
| 581 |
+
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
| 582 |
+
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
| 583 |
+
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
| 584 |
+
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
| 585 |
+
else:
|
| 586 |
+
shift_logits = shift_logits.contiguous()
|
| 587 |
+
shift_labels = shift_labels.contiguous()
|
| 588 |
+
# Flatten the tokens
|
| 589 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 590 |
+
|
| 591 |
+
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 592 |
+
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
| 593 |
+
loss = loss_fct(flat_logits, flat_labels)
|
| 594 |
+
# print('flat logits',flat_logits)
|
| 595 |
+
# print(flat_labels)
|
| 596 |
+
# print(loss)
|
| 597 |
+
if not return_dict:
|
| 598 |
+
output = (logits,) + outputs[1:]
|
| 599 |
+
return (loss,) + output if loss is not None else output
|
| 600 |
+
|
| 601 |
+
return Gemma3OmniCausalLMOutputWithPast(
|
| 602 |
+
loss=loss,
|
| 603 |
+
logits=logits,
|
| 604 |
+
past_key_values=outputs.past_key_values,
|
| 605 |
+
hidden_states=outputs.hidden_states,
|
| 606 |
+
attentions=outputs.attentions,
|
| 607 |
+
image_hidden_states=image_features if pixel_values is not None else None,
|
| 608 |
+
audio_hidden_states=audio_features if input_audio_embeds is not None else None,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
def prepare_inputs_for_generation(
|
| 612 |
+
self,
|
| 613 |
+
input_ids,
|
| 614 |
+
past_key_values=None,
|
| 615 |
+
input_modes=None,
|
| 616 |
+
inputs_embeds=None,
|
| 617 |
+
cache_position=None,
|
| 618 |
+
position_ids=None,
|
| 619 |
+
pixel_values=None,
|
| 620 |
+
input_audio_embeds=None,
|
| 621 |
+
audio_embed_sizes=None,
|
| 622 |
+
audio_attention_mask=None,
|
| 623 |
+
attention_mask=None,
|
| 624 |
+
token_type_ids=None,
|
| 625 |
+
use_cache=True,
|
| 626 |
+
logits_to_keep=None,
|
| 627 |
+
labels=None,
|
| 628 |
+
**kwargs,
|
| 629 |
+
):
|
| 630 |
+
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
| 631 |
+
model_inputs = self.language_model.prepare_inputs_for_generation(
|
| 632 |
+
input_ids,
|
| 633 |
+
past_key_values=past_key_values,
|
| 634 |
+
input_modes=input_modes,
|
| 635 |
+
inputs_embeds=inputs_embeds,
|
| 636 |
+
attention_mask=attention_mask,
|
| 637 |
+
position_ids=position_ids,
|
| 638 |
+
cache_position=cache_position,
|
| 639 |
+
use_cache=use_cache,
|
| 640 |
+
logits_to_keep=logits_to_keep,
|
| 641 |
+
token_type_ids=token_type_ids,
|
| 642 |
+
**kwargs,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# position_ids in Gemma3 are 1-indexed
|
| 646 |
+
if model_inputs.get("position_ids") is not None:
|
| 647 |
+
model_inputs["position_ids"] += 1
|
| 648 |
+
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 649 |
+
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
| 650 |
+
if cache_position[0] == 0:
|
| 651 |
+
model_inputs["pixel_values"] = pixel_values
|
| 652 |
+
model_inputs["input_audio_embeds"] = input_audio_embeds
|
| 653 |
+
model_inputs["audio_embed_sizes"] = audio_embed_sizes
|
| 654 |
+
model_inputs["audio_attention_mask"] = audio_attention_mask
|
| 655 |
+
model_inputs["input_modes"] = input_modes
|
| 656 |
+
is_training = token_type_ids is not None and labels is not None
|
| 657 |
+
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
| 658 |
+
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
| 659 |
+
causal_mask = self._update_causal_mask(
|
| 660 |
+
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
| 661 |
+
)
|
| 662 |
+
model_inputs["attention_mask"] = causal_mask
|
| 663 |
+
|
| 664 |
+
return model_inputs
|
| 665 |
+
|
| 666 |
+
def tie_weights(self):
|
| 667 |
+
return self.language_model.tie_weights()
|
| 668 |
+
|
cpp/gemma_v1/preprocessing_gemma3omni.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List, Optional, Union, Tuple
|
| 3 |
+
from math import ceil
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import scipy
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
|
| 10 |
+
from enum import Enum
|
| 11 |
+
|
| 12 |
+
from transformers import AutoFeatureExtractor
|
| 13 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 14 |
+
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
| 15 |
+
from transformers.image_utils import ImageInput, make_nested_list_of_images
|
| 16 |
+
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
|
| 17 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 18 |
+
from transformers.utils import to_py_obj, TensorType
|
| 19 |
+
from transformers.audio_utils import AudioInput
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Gemma3ImagesKwargs(ImagesKwargs):
|
| 23 |
+
do_pan_and_scan: Optional[bool]
|
| 24 |
+
pan_and_scan_min_crop_size: Optional[int]
|
| 25 |
+
pan_and_scan_max_num_crops: Optional[int]
|
| 26 |
+
pan_and_scan_min_ratio_to_activate: Optional[float]
|
| 27 |
+
do_convert_rgb: Optional[bool]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
| 31 |
+
images_kwargs: Gemma3ImagesKwargs
|
| 32 |
+
_defaults = {
|
| 33 |
+
"text_kwargs": {
|
| 34 |
+
"padding": False,
|
| 35 |
+
},
|
| 36 |
+
"images_kwargs": {
|
| 37 |
+
"do_pan_and_scan": False,
|
| 38 |
+
"pan_and_scan_min_crop_size": 256,
|
| 39 |
+
"pan_and_scan_max_num_crops": 4,
|
| 40 |
+
"pan_and_scan_min_ratio_to_activate": 1.2,
|
| 41 |
+
},
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
|
| 45 |
+
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
|
| 49 |
+
n_fft (int): FFT size. int > 0 [scalar]
|
| 50 |
+
n_mel (int): Mel filter size. int > 0 [scalar]
|
| 51 |
+
fmin (float): lowest frequency (in Hz). If None use 0.0.
|
| 52 |
+
float >= 0 [scalar]
|
| 53 |
+
fmax: highest frequency (in Hz). If None use sample_rate / 2.
|
| 54 |
+
float >= 0 [scalar]
|
| 55 |
+
|
| 56 |
+
Returns
|
| 57 |
+
out (numpy.ndarray): Mel transform matrix
|
| 58 |
+
[shape=(n_mels, 1 + n_fft/2)]
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
bank_width = int(n_fft // 2 + 1)
|
| 62 |
+
if fmax is None:
|
| 63 |
+
fmax = sample_rate / 2
|
| 64 |
+
if fmin is None:
|
| 65 |
+
fmin = 0
|
| 66 |
+
assert fmin >= 0, "fmin cannot be negtive"
|
| 67 |
+
assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
|
| 68 |
+
|
| 69 |
+
def mel(f):
|
| 70 |
+
return 1127.0 * np.log(1.0 + f / 700.0)
|
| 71 |
+
|
| 72 |
+
def bin2mel(fft_bin):
|
| 73 |
+
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
|
| 74 |
+
|
| 75 |
+
def f2bin(f):
|
| 76 |
+
return int((f * n_fft / sample_rate) + 0.5)
|
| 77 |
+
|
| 78 |
+
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
|
| 79 |
+
klo = f2bin(fmin) + 1
|
| 80 |
+
khi = f2bin(fmax)
|
| 81 |
+
|
| 82 |
+
khi = max(khi, klo)
|
| 83 |
+
|
| 84 |
+
# Spec 2: SpeechLib uses trianges in Mel space
|
| 85 |
+
mlo = mel(fmin)
|
| 86 |
+
mhi = mel(fmax)
|
| 87 |
+
m_centers = np.linspace(mlo, mhi, n_mels + 2)
|
| 88 |
+
ms = (mhi - mlo) / (n_mels + 1)
|
| 89 |
+
|
| 90 |
+
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
|
| 91 |
+
for m in range(0, n_mels):
|
| 92 |
+
left = m_centers[m]
|
| 93 |
+
center = m_centers[m + 1]
|
| 94 |
+
right = m_centers[m + 2]
|
| 95 |
+
for fft_bin in range(klo, khi):
|
| 96 |
+
mbin = bin2mel(fft_bin)
|
| 97 |
+
if left < mbin < right:
|
| 98 |
+
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
|
| 99 |
+
|
| 100 |
+
return matrix
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
|
| 104 |
+
model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
|
| 105 |
+
|
| 106 |
+
def __init__(self, audio_compression_rate=8,
|
| 107 |
+
audio_downsample_rate=1,
|
| 108 |
+
audio_feat_stride=1,
|
| 109 |
+
feature_size = 80,
|
| 110 |
+
sampling_rate = 16000,
|
| 111 |
+
padding_value = 0.0,
|
| 112 |
+
**kwargs):
|
| 113 |
+
|
| 114 |
+
super().__init__(feature_size=feature_size,
|
| 115 |
+
sampling_rate=sampling_rate,
|
| 116 |
+
padding_value=padding_value, **kwargs)
|
| 117 |
+
|
| 118 |
+
self.compression_rate = audio_compression_rate
|
| 119 |
+
self.qformer_compression_rate = audio_downsample_rate
|
| 120 |
+
self.feat_stride = audio_feat_stride
|
| 121 |
+
|
| 122 |
+
self._eightk_method = "fillzero"
|
| 123 |
+
self._mel = speechlib_mel(self.sampling_rate, 512, self.feature_size, fmin=None, fmax=self.sampling_rate//2-self.feature_size-230).T
|
| 124 |
+
|
| 125 |
+
self._hamming400 = np.hamming(400) # for 16k audio
|
| 126 |
+
self._hamming200 = np.hamming(200) # for 8k audio
|
| 127 |
+
|
| 128 |
+
def duration_to_frames(self, duration):
|
| 129 |
+
"""duration in s, estimated frames"""
|
| 130 |
+
frame_rate = 10
|
| 131 |
+
|
| 132 |
+
num_frames = duration * 1000 // frame_rate
|
| 133 |
+
return num_frames
|
| 134 |
+
|
| 135 |
+
def __call__(
|
| 136 |
+
self,
|
| 137 |
+
audios: List[AudioInput],
|
| 138 |
+
sampling_rate = 16000,
|
| 139 |
+
return_attention_mask=True,
|
| 140 |
+
padding="max_length",
|
| 141 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 142 |
+
):
|
| 143 |
+
# Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
|
| 144 |
+
returned_input_audio_embeds = []
|
| 145 |
+
returned_audio_embed_sizes = []
|
| 146 |
+
audio_frames_list = []
|
| 147 |
+
|
| 148 |
+
for audio_data in audios:
|
| 149 |
+
audio_embeds = self._extract_features(audio_data, sampling_rate)
|
| 150 |
+
audio_frames = len(audio_embeds) * self.feat_stride
|
| 151 |
+
audio_embed_size = self._compute_audio_embed_size(audio_frames)
|
| 152 |
+
|
| 153 |
+
returned_input_audio_embeds.append(torch.tensor(audio_embeds))
|
| 154 |
+
returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
|
| 155 |
+
audio_frames_list.append(audio_frames)
|
| 156 |
+
|
| 157 |
+
returned_input_audio_embeds = pad_sequence(
|
| 158 |
+
returned_input_audio_embeds, batch_first=True
|
| 159 |
+
)
|
| 160 |
+
returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
|
| 161 |
+
audio_frames = torch.tensor(audio_frames_list)
|
| 162 |
+
returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
|
| 163 |
+
|
| 164 |
+
data = {
|
| 165 |
+
"input_audio_embeds": returned_input_audio_embeds,
|
| 166 |
+
"audio_embed_sizes": returned_audio_embed_sizes,
|
| 167 |
+
}
|
| 168 |
+
if returned_audio_attention_mask is not None and return_attention_mask:
|
| 169 |
+
data["audio_attention_mask"] = returned_audio_attention_mask
|
| 170 |
+
|
| 171 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 172 |
+
|
| 173 |
+
def _extract_spectrogram(self, wav, fs):
|
| 174 |
+
"""Extract spectrogram features from waveform.
|
| 175 |
+
Args:
|
| 176 |
+
wav (1D array): waveform of the input
|
| 177 |
+
fs (int): sampling rate of the waveform, 16000 or 8000.
|
| 178 |
+
If fs=8000, the waveform will be resampled to 16000Hz.
|
| 179 |
+
Output:
|
| 180 |
+
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
|
| 181 |
+
D=80, and T is the number of frames.
|
| 182 |
+
"""
|
| 183 |
+
if wav.ndim > 1:
|
| 184 |
+
wav = np.squeeze(wav)
|
| 185 |
+
|
| 186 |
+
# by default, we extract the mean if stereo
|
| 187 |
+
if len(wav.shape) == 2:
|
| 188 |
+
wav = wav.mean(1)
|
| 189 |
+
|
| 190 |
+
# Resample to 16000 or 8000 if needed
|
| 191 |
+
if fs > 16000:
|
| 192 |
+
wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
|
| 193 |
+
fs = 16000
|
| 194 |
+
elif 8000 < fs < 16000:
|
| 195 |
+
wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
|
| 196 |
+
fs = 8000
|
| 197 |
+
elif fs < 8000:
|
| 198 |
+
raise RuntimeError(f"Unsupported sample rate {fs}")
|
| 199 |
+
|
| 200 |
+
if fs == 8000:
|
| 201 |
+
if self._eightk_method == "resample":
|
| 202 |
+
# Input audio is 8 kHz. Convert to 16 kHz before feature
|
| 203 |
+
# extraction
|
| 204 |
+
wav = scipy.signal.resample_poly(wav, 2, 1)
|
| 205 |
+
fs = 16000
|
| 206 |
+
# Do nothing here for fillzero method
|
| 207 |
+
elif fs != 16000:
|
| 208 |
+
# Input audio is not a supported sample rate.
|
| 209 |
+
raise RuntimeError(f"Input data using an unsupported sample rate: {fs}")
|
| 210 |
+
|
| 211 |
+
preemphasis = 0.97
|
| 212 |
+
|
| 213 |
+
if fs == 8000:
|
| 214 |
+
n_fft = 256
|
| 215 |
+
win_length = 200
|
| 216 |
+
hop_length = 80
|
| 217 |
+
fft_window = self._hamming200
|
| 218 |
+
elif fs == 16000:
|
| 219 |
+
n_fft = 512
|
| 220 |
+
win_length = 400
|
| 221 |
+
hop_length = 160
|
| 222 |
+
fft_window = self._hamming400
|
| 223 |
+
|
| 224 |
+
# Spec 1: SpeechLib cut remaining sample insufficient for a hop
|
| 225 |
+
n_batch = (wav.shape[0] - win_length) // hop_length + 1
|
| 226 |
+
# Here we don't use stride_tricks since the input array may not satisfy
|
| 227 |
+
# memory layout requirement and we need writeable output
|
| 228 |
+
# Here we only use list of views before copy to desination
|
| 229 |
+
# so it is more efficient than broadcasting
|
| 230 |
+
y_frames = np.array(
|
| 231 |
+
[wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)],
|
| 232 |
+
dtype=np.float32,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Spec 2: SpeechLib applies preemphasis within each batch
|
| 236 |
+
y_frames_prev = np.roll(y_frames, 1, axis=1)
|
| 237 |
+
y_frames_prev[:, 0] = y_frames_prev[:, 1]
|
| 238 |
+
y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
|
| 239 |
+
|
| 240 |
+
S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64)
|
| 241 |
+
|
| 242 |
+
if fs == 8000:
|
| 243 |
+
# Need to pad the output to look like 16 kHz data but with zeros in
|
| 244 |
+
# the 4 to 8 kHz bins.
|
| 245 |
+
frames, bins = S.shape
|
| 246 |
+
padarray = np.zeros((frames, bins))
|
| 247 |
+
S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero
|
| 248 |
+
|
| 249 |
+
spec = np.abs(S).astype(np.float32)
|
| 250 |
+
return spec
|
| 251 |
+
|
| 252 |
+
def _extract_features(self, wav, fs):
|
| 253 |
+
"""Extract log filterbank features from waveform.
|
| 254 |
+
Args:
|
| 255 |
+
wav (1D array): waveform of the input
|
| 256 |
+
fs (int): sampling rate of the waveform, 16000 or 8000.
|
| 257 |
+
If fs=8000, the waveform will be resampled to 16000Hz.
|
| 258 |
+
Output:
|
| 259 |
+
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
|
| 260 |
+
D=80, and T is the number of frames.
|
| 261 |
+
"""
|
| 262 |
+
spec = self._extract_spectrogram(wav, fs)
|
| 263 |
+
spec_power = spec**2
|
| 264 |
+
|
| 265 |
+
fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
|
| 266 |
+
log_fbank = np.log(fbank_power).astype(np.float32)
|
| 267 |
+
|
| 268 |
+
return log_fbank
|
| 269 |
+
|
| 270 |
+
def _compute_audio_embed_size(self, audio_frames):
|
| 271 |
+
integer = audio_frames // self.compression_rate
|
| 272 |
+
remainder = audio_frames % self.compression_rate
|
| 273 |
+
|
| 274 |
+
result = integer if remainder == 0 else integer + 1
|
| 275 |
+
|
| 276 |
+
integer = result // self.qformer_compression_rate
|
| 277 |
+
remainder = result % self.qformer_compression_rate
|
| 278 |
+
result = integer if remainder == 0 else integer + 1 # qformer compression
|
| 279 |
+
|
| 280 |
+
return result
|
| 281 |
+
|
| 282 |
+
class Gemma3OmniProcessor(ProcessorMixin):
|
| 283 |
+
attributes = ["image_processor", "feature_extractor", "tokenizer"]
|
| 284 |
+
valid_kwargs = ["chat_template", "image_seq_length"]
|
| 285 |
+
image_processor_class = "AutoImageProcessor"
|
| 286 |
+
feature_extractor_class = "Gemma3AudioFeatureExtractor"
|
| 287 |
+
tokenizer_class = "AutoTokenizer"
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
image_processor,
|
| 292 |
+
feature_extractor,
|
| 293 |
+
tokenizer,
|
| 294 |
+
chat_template=None,
|
| 295 |
+
image_seq_length: int = 256,
|
| 296 |
+
**kwargs,
|
| 297 |
+
):
|
| 298 |
+
self.image_seq_length = image_seq_length
|
| 299 |
+
self.image_token_id = tokenizer.image_token_id
|
| 300 |
+
self.boi_token = tokenizer.boi_token
|
| 301 |
+
self.image_token = tokenizer.image_token
|
| 302 |
+
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
|
| 303 |
+
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
| 304 |
+
|
| 305 |
+
self.audio_token_id = tokenizer.audio_token_id
|
| 306 |
+
self.boa_token = tokenizer.boa_token
|
| 307 |
+
self.eoa_token = tokenizer.eoa_token
|
| 308 |
+
self.audio_token = tokenizer.audio_token
|
| 309 |
+
|
| 310 |
+
super().__init__(
|
| 311 |
+
image_processor=image_processor,
|
| 312 |
+
feature_extractor=feature_extractor,
|
| 313 |
+
tokenizer=tokenizer,
|
| 314 |
+
chat_template=chat_template,
|
| 315 |
+
**kwargs,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def __call__(
|
| 319 |
+
self,
|
| 320 |
+
images: ImageInput = None,
|
| 321 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 322 |
+
videos=None,
|
| 323 |
+
audio: List[AudioInput] = None,
|
| 324 |
+
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
| 325 |
+
) -> BatchFeature:
|
| 326 |
+
if text is None and images is None:
|
| 327 |
+
raise ValueError("Provide at least one of `text` or `images`.")
|
| 328 |
+
|
| 329 |
+
output_kwargs = self._merge_kwargs(
|
| 330 |
+
Gemma3ProcessorKwargs,
|
| 331 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 332 |
+
**kwargs,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if isinstance(text, str):
|
| 336 |
+
text = [text]
|
| 337 |
+
elif not isinstance(text, list) and not isinstance(text[0], str):
|
| 338 |
+
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
| 339 |
+
|
| 340 |
+
image_inputs = {}
|
| 341 |
+
if images is not None:
|
| 342 |
+
batched_images = make_nested_list_of_images(images)
|
| 343 |
+
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
|
| 344 |
+
|
| 345 |
+
# Create empty text to be replaced with placeholders
|
| 346 |
+
if not text:
|
| 347 |
+
text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
|
| 348 |
+
|
| 349 |
+
if len(batched_images) != len(text):
|
| 350 |
+
raise ValueError(
|
| 351 |
+
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Replace image tokens by the full expanded sequence
|
| 355 |
+
num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
| 356 |
+
batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
|
| 357 |
+
for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
|
| 358 |
+
image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
|
| 359 |
+
|
| 360 |
+
if len(images) != len(image_indexes):
|
| 361 |
+
raise ValueError(
|
| 362 |
+
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Insert additional image tokens for Pan-and-Scan crops
|
| 366 |
+
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
| 367 |
+
if num:
|
| 368 |
+
formatted_image_text = (
|
| 369 |
+
f"Here is the original image {self.boi_token} and here are some crops to help you see better "
|
| 370 |
+
+ " ".join([self.boi_token] * num)
|
| 371 |
+
)
|
| 372 |
+
prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
|
| 373 |
+
text[batch_idx] = prompt
|
| 374 |
+
|
| 375 |
+
# Expand placeholder image tokens to the full image token sequence
|
| 376 |
+
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
| 377 |
+
|
| 378 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 379 |
+
|
| 380 |
+
audio_inputs = {}
|
| 381 |
+
if audio is not None:
|
| 382 |
+
full_audio_sequences = []
|
| 383 |
+
audio_inputs = self.feature_extractor(audio)
|
| 384 |
+
def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
|
| 385 |
+
parts = prompt.split(boa_token)
|
| 386 |
+
result = ""
|
| 387 |
+
for i in range(len(parts) - 1):
|
| 388 |
+
result += parts[i]
|
| 389 |
+
if i < len(audio_sequences):
|
| 390 |
+
result += audio_sequences[i]
|
| 391 |
+
else:
|
| 392 |
+
result += boa_token
|
| 393 |
+
result += parts[-1]
|
| 394 |
+
return result
|
| 395 |
+
for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
|
| 396 |
+
audio_tokens_expanded = "".join([self.audio_token] * embed_size)
|
| 397 |
+
full_audio_sequence = f"\n\n{self.boa_token}{audio_tokens_expanded}{self.eoa_token}\n\n"
|
| 398 |
+
full_audio_sequences.append(full_audio_sequence)
|
| 399 |
+
|
| 400 |
+
text = [replace_tokens_sequentially(prompt, self.boa_token, [audio_sequences]) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
|
| 401 |
+
#text = [prompt.replace(self.boa_token, audio_sequences) for (prompt, audio_sequences) in zip(text, full_audio_sequences)]
|
| 402 |
+
|
| 403 |
+
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
| 404 |
+
|
| 405 |
+
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
| 406 |
+
array_ids = text_inputs["input_ids"]
|
| 407 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 408 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 409 |
+
mm_token_type_ids[array_ids == self.audio_token_id] = 2
|
| 410 |
+
|
| 411 |
+
has_vision_ids = np.any(mm_token_type_ids == 1, axis=1)
|
| 412 |
+
has_audio_ids = np.any(mm_token_type_ids == 2, axis=1)
|
| 413 |
+
|
| 414 |
+
input_modes = (has_audio_ids << 1) | has_vision_ids
|
| 415 |
+
|
| 416 |
+
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
| 417 |
+
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
| 418 |
+
text_inputs["input_modes"] = input_modes.tolist()
|
| 419 |
+
|
| 420 |
+
return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
|
| 421 |
+
|
| 422 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
| 423 |
+
def batch_decode(self, *args, **kwargs):
|
| 424 |
+
"""
|
| 425 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 426 |
+
refer to the docstring of this method for more information.
|
| 427 |
+
"""
|
| 428 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 429 |
+
|
| 430 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
| 431 |
+
def decode(self, *args, **kwargs):
|
| 432 |
+
"""
|
| 433 |
+
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 434 |
+
the docstring of this method for more information.
|
| 435 |
+
"""
|
| 436 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 437 |
+
|
| 438 |
+
@property
|
| 439 |
+
def model_input_names(self):
|
| 440 |
+
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 441 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 442 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 443 |
+
|
| 444 |
+
AutoFeatureExtractor.register("Gemma3AudioFeatureExtractor", Gemma3AudioFeatureExtractor)
|
cpp/gemma_v1/preprocessor_config.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"audio_compression_rate": 8,
|
| 3 |
+
"audio_downsample_rate": 1,
|
| 4 |
+
"audio_feat_stride": 1,
|
| 5 |
+
"compression_rate": 8,
|
| 6 |
+
"do_convert_rgb": null,
|
| 7 |
+
"do_normalize": true,
|
| 8 |
+
"do_pan_and_scan": null,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"feat_stride": 1,
|
| 12 |
+
"feature_extractor_type": "Gemma3AudioFeatureExtractor",
|
| 13 |
+
"feature_size": 80,
|
| 14 |
+
"image_mean": [
|
| 15 |
+
0.5,
|
| 16 |
+
0.5,
|
| 17 |
+
0.5
|
| 18 |
+
],
|
| 19 |
+
"image_processor_type": "Gemma3ImageProcessor",
|
| 20 |
+
"image_seq_length": 256,
|
| 21 |
+
"image_std": [
|
| 22 |
+
0.5,
|
| 23 |
+
0.5,
|
| 24 |
+
0.5
|
| 25 |
+
],
|
| 26 |
+
"padding_side": "right",
|
| 27 |
+
"padding_value": 0.0,
|
| 28 |
+
"pan_and_scan_max_num_crops": null,
|
| 29 |
+
"pan_and_scan_min_crop_size": null,
|
| 30 |
+
"pan_and_scan_min_ratio_to_activate": null,
|
| 31 |
+
"processor_class": "Gemma3OmniProcessor",
|
| 32 |
+
"qformer_compression_rate": 1,
|
| 33 |
+
"resample": 2,
|
| 34 |
+
"rescale_factor": 0.00392156862745098,
|
| 35 |
+
"return_attention_mask": true,
|
| 36 |
+
"sampling_rate": 16000,
|
| 37 |
+
"size": {
|
| 38 |
+
"height": 896,
|
| 39 |
+
"width": 896
|
| 40 |
+
}
|
| 41 |
+
}
|
cpp/gemma_v1/processor_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "preprocessing_gemma3omni.Gemma3OmniProcessor"
|
| 4 |
+
},
|
| 5 |
+
"image_seq_length": 256,
|
| 6 |
+
"processor_class": "Gemma3Processor"
|
| 7 |
+
}
|
cpp/gemma_v1/special_tokens_map.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"audio_token": "<audio_soft_token>",
|
| 3 |
+
"boa_token": "<start_of_audio>",
|
| 4 |
+
"boi_token": "<start_of_image>",
|
| 5 |
+
"bos_token": {
|
| 6 |
+
"content": "<bos>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"eoa_token": "<end_of_audio>",
|
| 13 |
+
"eoi_token": "<end_of_image>",
|
| 14 |
+
"eos_token": {
|
| 15 |
+
"content": "<eos>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"image_token": "<image_soft_token>",
|
| 22 |
+
"pad_token": {
|
| 23 |
+
"content": "<pad>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false
|
| 28 |
+
},
|
| 29 |
+
"unk_token": {
|
| 30 |
+
"content": "<unk>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false
|
| 35 |
+
}
|
| 36 |
+
}
|
cpp/gemma_v1/speech_conformer_encoder.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cpp/gemma_v1/speech_conformer_encoder_old.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cpp/gemma_v1/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52941f2ba60fdcc48edb940f4252f6d874d0c369323dab293168015122e556be
|
| 3 |
+
size 33384559
|
cpp/gemma_v1/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
|
| 3 |
+
size 4689074
|
cpp/gemma_v1/tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
cpp/gemma_v1/training.py
ADDED
|
@@ -0,0 +1,883 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 33 |
+
_IGNORE_INDEX = -100
|
| 34 |
+
class BaseAudioDataset(Dataset):
|
| 35 |
+
def __init__(self, processor, split, sampling_rate=16000, debug=False):
|
| 36 |
+
self.processor = processor
|
| 37 |
+
self.training = "train" in split or 'other' in split
|
| 38 |
+
self.debug = debug
|
| 39 |
+
self.sampling_rate = sampling_rate
|
| 40 |
+
self.name = ""
|
| 41 |
+
|
| 42 |
+
def set_dataset_name(self, name):
|
| 43 |
+
self.name = name
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
|
| 47 |
+
original_size = len(data)
|
| 48 |
+
|
| 49 |
+
data = data.cast_column(audio_field, Audio(decode=False))
|
| 50 |
+
|
| 51 |
+
def identify_corrupted_files(example):
|
| 52 |
+
try:
|
| 53 |
+
sf.read(example[audio_field]["path"])
|
| 54 |
+
|
| 55 |
+
for field in text_fields:
|
| 56 |
+
if field in example and example[field].replace('"', '') == "":
|
| 57 |
+
return False
|
| 58 |
+
return True
|
| 59 |
+
except Exception:
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
data = data.filter(identify_corrupted_files, num_proc=16)
|
| 63 |
+
validated_size = len(data)
|
| 64 |
+
|
| 65 |
+
# Audio Decoding
|
| 66 |
+
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
|
| 67 |
+
|
| 68 |
+
if debug:
|
| 69 |
+
print(f"Dataset: {dataset_name}")
|
| 70 |
+
print(f"Original data nums: {original_size}")
|
| 71 |
+
print(f"After filtering data nums: {validated_size}")
|
| 72 |
+
print(f"Filtering ratio: {validated_size/original_size:.2%}")
|
| 73 |
+
|
| 74 |
+
return data
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
|
| 78 |
+
original_size = len(data)
|
| 79 |
+
|
| 80 |
+
def filter_audio_by_length(example):
|
| 81 |
+
try:
|
| 82 |
+
audio = example[audio_field]['array']
|
| 83 |
+
channel = 1
|
| 84 |
+
if hasattr(audio, 'ndim') and audio.ndim > 1:
|
| 85 |
+
channel = audio.ndim
|
| 86 |
+
audio = audio.squeeze()
|
| 87 |
+
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
|
| 88 |
+
return min_sec <= audio_length <= max_sec
|
| 89 |
+
except Exception as e:
|
| 90 |
+
if debug:
|
| 91 |
+
print(f"Error : {str(e)[:100]}... - sample excluded")
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
data = data.filter(filter_audio_by_length, num_proc=16)
|
| 95 |
+
filtered_size = len(data)
|
| 96 |
+
|
| 97 |
+
if debug:
|
| 98 |
+
print(f"Before Length Filtering data nums: {original_size}")
|
| 99 |
+
print(f"After Length Filtering data nums: {filtered_size}")
|
| 100 |
+
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
|
| 101 |
+
|
| 102 |
+
return data
|
| 103 |
+
|
| 104 |
+
def prepare_model_inputs(self, audio_array, instruction, answer_text):
|
| 105 |
+
user_message = {
|
| 106 |
+
'role': 'user',
|
| 107 |
+
'content': '<start_of_audio>' + instruction,
|
| 108 |
+
}
|
| 109 |
+
prompt = self.processor.tokenizer.apply_chat_template(
|
| 110 |
+
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
inputs = self.processor(
|
| 114 |
+
text=prompt,
|
| 115 |
+
audio=[audio_array],
|
| 116 |
+
add_special_tokens=False,
|
| 117 |
+
return_tensors='pt'
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
answer = f"{answer_text}{ANSWER_SUFFIX}"
|
| 121 |
+
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
|
| 122 |
+
|
| 123 |
+
if self.debug:
|
| 124 |
+
self.debug = False
|
| 125 |
+
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
|
| 126 |
+
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
|
| 127 |
+
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
|
| 128 |
+
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
|
| 129 |
+
|
| 130 |
+
if self.training:
|
| 131 |
+
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
|
| 132 |
+
labels = torch.full_like(input_ids, _IGNORE_INDEX)
|
| 133 |
+
labels[:, -answer_ids.shape[1]:] = answer_ids
|
| 134 |
+
padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
|
| 135 |
+
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
|
| 136 |
+
else:
|
| 137 |
+
input_ids = inputs.input_ids
|
| 138 |
+
labels = answer_ids
|
| 139 |
+
token_type_ids = inputs.token_type_ids
|
| 140 |
+
|
| 141 |
+
return {
|
| 142 |
+
'input_ids': input_ids,
|
| 143 |
+
'labels': labels,
|
| 144 |
+
'token_type_ids': token_type_ids,
|
| 145 |
+
'input_audio_embeds': inputs.input_audio_embeds,
|
| 146 |
+
'audio_embed_sizes': inputs.audio_embed_sizes,
|
| 147 |
+
'input_modes': inputs.input_modes,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
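As a quick illustration of how prepare_model_inputs above builds training targets (prompt ids and answer ids are concatenated, and every prompt position is masked with _IGNORE_INDEX so only the answer is scored), here is a minimal sketch on toy tensors. It assumes only torch and is not part of the uploaded training.py; the token ids are made up.

import torch

_IGNORE_INDEX = -100

# Stand-ins for processor output: a tokenized prompt and a tokenized answer.
prompt_ids = torch.tensor([[7, 8, 9, 10]])   # shape (1, prompt_len)
answer_ids = torch.tensor([[21, 22, 5]])     # shape (1, answer_len)

# Same recipe as the training branch of prepare_model_inputs:
input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
labels = torch.full_like(input_ids, _IGNORE_INDEX)
labels[:, -answer_ids.shape[1]:] = answer_ids

print(input_ids)  # tensor([[ 7,  8,  9, 10, 21, 22,  5]])
print(labels)     # tensor([[-100, -100, -100, -100,   21,   22,    5]])
# Only the answer tokens contribute to the loss; the prompt positions are ignored.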
| 151 |
+
# Libri Speech Dataset Class
|
| 152 |
+
class LibriSpeechDataset(BaseAudioDataset):
|
| 153 |
+
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
|
| 154 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 155 |
+
|
| 156 |
+
self.set_dataset_name(f"LibriSpeech_{subset}")
|
| 157 |
+
# only ASR
|
| 158 |
+
self.ast = False
|
| 159 |
+
self.lang = "en"
|
| 160 |
+
|
| 161 |
+
# load dataset
|
| 162 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/librispeech_asr",
|
| 163 |
+
subset,
|
| 164 |
+
split=split,
|
| 165 |
+
trust_remote_code=True,
|
| 166 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# (Optional) Audio length Filtering
|
| 170 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 171 |
+
|
| 172 |
+
# Instruction Setting
|
| 173 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 174 |
+
|
| 175 |
+
def __len__(self):
|
| 176 |
+
return len(self.data)
|
| 177 |
+
|
| 178 |
+
def __getitem__(self, idx):
|
| 179 |
+
data = self.data[idx]
|
| 180 |
+
|
| 181 |
+
# Libri Speech is only for ASR
|
| 182 |
+
answer_text = data["text"].replace('"', '')
|
| 183 |
+
|
| 184 |
+
return self.prepare_model_inputs(
|
| 185 |
+
data["audio"]["array"],
|
| 186 |
+
self.instruction,
|
| 187 |
+
answer_text
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# common_voice_16_1 dataset
|
| 191 |
+
class CommonVoiceDataset(BaseAudioDataset):
|
| 192 |
+
def __init__(self, processor, split, source_lang, sampling_rate=16000, debug=False):
|
| 193 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 194 |
+
|
| 195 |
+
self.set_dataset_name(f"CommonVoice_{source_lang}")
|
| 196 |
+
# only ASR
|
| 197 |
+
self.ast = False
|
| 198 |
+
self.lang=source_lang
|
| 199 |
+
|
| 200 |
+
# load dataset
|
| 201 |
+
if source_lang=="zh-TW":
|
| 202 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_16_1"
|
| 203 |
+
else:
|
| 204 |
+
data_path = "/mnt/jeff/InCar/data/common_voice_17_0"
|
| 205 |
+
self.data = load_dataset(data_path,
|
| 206 |
+
source_lang,
|
| 207 |
+
split=split,
|
| 208 |
+
trust_remote_code=True,
|
| 209 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 210 |
+
)
|
| 211 |
+
def prepare_dataset(batch):
|
| 212 |
+
"""Function to preprocess the dataset with the .map method"""
|
| 213 |
+
transcription = batch["sentence"]
|
| 214 |
+
|
| 215 |
+
if transcription.startswith('"') and transcription.endswith('"'):
|
| 216 |
+
# we can remove trailing quotation marks as they do not affect the transcription
|
| 217 |
+
transcription = transcription[1:-1]
|
| 218 |
+
|
| 219 |
+
if transcription[-1] not in [".", "?", "!"]:
|
| 220 |
+
# append a full-stop to sentences that do not end in punctuation
|
| 221 |
+
transcription = transcription + "."
|
| 222 |
+
|
| 223 |
+
batch["sentence"] = transcription
|
| 224 |
+
|
| 225 |
+
return batch
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
import opencc
|
| 229 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 230 |
+
def To_zhTW(batch):
|
| 231 |
+
|
| 232 |
+
transcription = converter.convert(batch["sentence"])
|
| 233 |
+
batch["sentence"] = transcription
|
| 234 |
+
|
| 235 |
+
return batch
|
| 236 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 237 |
+
if source_lang=='zh-CN':
|
| 238 |
+
self.data = self.data.map(To_zhTW, desc="preprocess dataset To_zhTW")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# (Optional) Audio length Filtering
|
| 242 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 243 |
+
|
| 244 |
+
if source_lang == "zh-TW" and split=='train':
|
| 245 |
+
import torchaudio
|
| 246 |
+
from torchaudio import transforms
|
| 247 |
+
import copy
|
| 248 |
+
import pickle
|
| 249 |
+
import os
|
| 250 |
+
def subsample(batch):
|
| 251 |
+
batch['audio']['array']=torchaudio.functional.resample(torch.FloatTensor(batch['audio']['array']), orig_freq=batch['audio']['sampling_rate'], new_freq=16000)
|
| 252 |
+
batch['audio']['sampling_rate']=16000
|
| 253 |
+
return batch
|
| 254 |
+
def TW_data_augment_fast(batch):
|
| 255 |
+
speed_perturb_fast = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [1.1])
|
| 256 |
+
new_array_fast = speed_perturb_fast(torch.FloatTensor(batch['audio']['array']))[0]
|
| 257 |
+
batch['audio']['array'] = new_array_fast
|
| 258 |
+
return batch
|
| 259 |
+
def TW_data_augment_slow(batch):
|
| 260 |
+
speed_perturb_slow = transforms.SpeedPerturbation(batch['audio']['sampling_rate'], [0.9])
|
| 261 |
+
new_array_slow = speed_perturb_slow(torch.FloatTensor(batch['audio']['array']))[0]
|
| 262 |
+
batch['audio']['array'] = new_array_slow
|
| 263 |
+
return batch
|
| 264 |
+
# data = self.data.map(subsample, num_proc=1, desc="subsample")
|
| 265 |
+
fast_path = '/mnt/jeff/InCar/data/tw_fast.pkl'
|
| 266 |
+
if not os.path.exists(fast_path):
|
| 267 |
+
data_fast = self.data.map(TW_data_augment_fast, num_proc=1, desc="augment fast")
|
| 268 |
+
with open(fast_path,'wb') as f:
|
| 269 |
+
pickle.dump(data_fast,f)
|
| 270 |
+
else:
|
| 271 |
+
with open(fast_path,'rb') as f:
|
| 272 |
+
data_fast=pickle.load(f)
|
| 273 |
+
|
| 274 |
+
slow_path = '/mnt/jeff/InCar/data/data_slow.pkl'
|
| 275 |
+
if not os.path.exists(slow_path):
|
| 276 |
+
data_slow = self.data.map(TW_data_augment_slow, num_proc=1, desc="augment slow")
|
| 277 |
+
with open(slow_path,'wb') as f:
|
| 278 |
+
pickle.dump(data_slow,f)
|
| 279 |
+
else:
|
| 280 |
+
with open(slow_path,'rb') as f:
|
| 281 |
+
data_slow=pickle.load(f)
|
| 282 |
+
self.data = [d for d in self.data]+[d for d in data_fast]+[d for d in data_slow]
|
| 283 |
+
|
| 284 |
+
# Instruction Setting
|
| 285 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 286 |
+
|
| 287 |
+
def __len__(self):
|
| 288 |
+
return len(self.data)
|
| 289 |
+
|
| 290 |
+
def __getitem__(self, idx):
|
| 291 |
+
data = self.data[idx]
|
| 292 |
+
|
| 293 |
+
answer_text = data["sentence"]
|
| 294 |
+
return self.prepare_model_inputs(
|
| 295 |
+
data["audio"]["array"],
|
| 296 |
+
self.instruction,
|
| 297 |
+
answer_text
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
|
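The zh-TW training branch above triples the data by speed perturbation (factors 1.1 and 0.9) and caches the augmented copies as pickles. Below is a minimal, self-contained sketch of the same transform on a synthetic waveform, assuming a torchaudio version that provides transforms.SpeedPerturbation; it is illustrative only and not part of the uploaded script.

import torch
from torchaudio import transforms

sample_rate = 16000
waveform = torch.randn(1, sample_rate * 2)  # two seconds of synthetic mono audio

speed_up = transforms.SpeedPerturbation(sample_rate, [1.1])   # fewer samples out
slow_down = transforms.SpeedPerturbation(sample_rate, [0.9])  # more samples out

fast, _ = speed_up(waveform)
slow, _ = slow_down(waveform)
print(waveform.shape, fast.shape, slow.shape)
# e.g. torch.Size([1, 32000]) torch.Size([1, 29091]) torch.Size([1, 35556])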
| 301 |
+
# Fleurs Dataset Class
|
| 302 |
+
class FleursDataset(BaseAudioDataset):
|
| 303 |
+
def __init__(self, processor, split, source_lang, target_lang=None,
|
| 304 |
+
mode="asr", sampling_rate=16000, debug=False):
|
| 305 |
+
super().__init__(processor, split, sampling_rate, debug)
|
| 306 |
+
|
| 307 |
+
self.set_dataset_name("Fleurs")
|
| 308 |
+
# Mode Setting (ASR or AST)
|
| 309 |
+
if mode not in ["asr", "ast"]:
|
| 310 |
+
raise ValueError("mode must be 'asr' or 'ast'.")
|
| 311 |
+
|
| 312 |
+
self.mode = mode
|
| 313 |
+
self.ast = (mode == "ast")
|
| 314 |
+
self.source_lang = source_lang
|
| 315 |
+
|
| 316 |
+
# Language name mapping (expand if needed)
|
| 317 |
+
self.lang_names = {
|
| 318 |
+
'en_us': 'English', 'cmn_hans': 'Mandarin Chinese'
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
# load dataset - source language dataset
|
| 322 |
+
self.data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 323 |
+
source_lang,
|
| 324 |
+
split=split,
|
| 325 |
+
trust_remote_code=True,
|
| 326 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 327 |
+
)
|
| 328 |
+
import opencc
|
| 329 |
+
converter = opencc.OpenCC('s2tw.json')
|
| 330 |
+
def prepare_dataset(batch):
|
| 331 |
+
transcription = converter.convert(batch["transcription"])
|
| 332 |
+
batch["transcription"] = transcription
|
| 333 |
+
|
| 334 |
+
return batch
|
| 335 |
+
if (source_lang=="cmn_hans_cn"):
|
| 336 |
+
self.data = self.data.map(prepare_dataset, desc="preprocess dataset")
|
| 337 |
+
|
| 338 |
+
# (Optional) Audio length Filtering
|
| 339 |
+
self.data = self.filter_by_audio_length(self.data, "audio")
|
| 340 |
+
self.target_lang_name = ""
|
| 341 |
+
# When AST mode, load target language dataset.
|
| 342 |
+
if self.ast:
|
| 343 |
+
if target_lang is None:
|
| 344 |
+
raise ValueError("AST mode requires target_lang.")
|
| 345 |
+
|
| 346 |
+
self.target_lang = target_lang
|
| 347 |
+
self.lang = f"{source_lang}_{target_lang}"
|
| 348 |
+
|
| 349 |
+
# load dataset - target language dataset (for translation)
|
| 350 |
+
target_data = load_dataset("/mnt/jeff/InCar/data/fleurs",
|
| 351 |
+
target_lang,
|
| 352 |
+
split=split,
|
| 353 |
+
trust_remote_code=True,
|
| 354 |
+
cache_dir=Path("/mnt/jeff/InCar/data")
|
| 355 |
+
)
|
| 356 |
+
if target_lang=="cmn_hans_cn":
|
| 357 |
+
target_data=target_data.map(prepare_dataset, desc="preprocess dataset")
|
| 358 |
+
source_dict = {item['id']: item for item in self.data}
|
| 359 |
+
target_dict = {item['id']: item for item in target_data}
|
| 360 |
+
|
| 361 |
+
# only Common ID, add translation fields
|
| 362 |
+
common_ids = set(source_dict.keys()) & set(target_dict.keys())
|
| 363 |
+
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
|
| 364 |
+
self.data = [
|
| 365 |
+
{**source_dict[id], 'translation': target_dict[id]['transcription']}
|
| 366 |
+
for id in common_ids
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
# Instruction Setting - use target language name
|
| 370 |
+
self.target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
|
| 371 |
+
self.instruction = random.choice(INSTRUCTION["ast"])
|
| 372 |
+
else:
|
| 373 |
+
# ASR mode
|
| 374 |
+
self.lang = source_lang
|
| 375 |
+
self.instruction = random.choice(INSTRUCTION["asr"])
|
| 376 |
+
|
| 377 |
+
if self.debug:
|
| 378 |
+
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
|
| 379 |
+
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
|
| 380 |
+
if self.ast:
|
| 381 |
+
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
|
| 382 |
+
print(f"dataset size: {len(self.data)}")
|
| 383 |
+
|
| 384 |
+
def __len__(self):
|
| 385 |
+
return len(self.data)
|
| 386 |
+
|
| 387 |
+
def __getitem__(self, idx):
|
| 388 |
+
data = self.data[idx]
|
| 389 |
+
audio_array = data["audio"]["array"]
|
| 390 |
+
|
| 391 |
+
if self.ast:
|
| 392 |
+
answer_text = data["translation"]
|
| 393 |
+
else:
|
| 394 |
+
answer_text = data["transcription"]
|
| 395 |
+
|
| 396 |
+
return self.prepare_model_inputs(
|
| 397 |
+
audio_array,
|
| 398 |
+
self.instruction.format(self.target_lang_name),
|
| 399 |
+
answer_text
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
def covost_collate_fn(batch):
|
| 403 |
+
input_ids_list = []
|
| 404 |
+
labels_list = []
|
| 405 |
+
token_type_ids_list = []
|
| 406 |
+
input_audio_embeds_list = []
|
| 407 |
+
audio_embed_sizes_list = []
|
| 408 |
+
audio_attention_mask_list = []
|
| 409 |
+
input_modes_list = []
|
| 410 |
+
for inputs in batch:
|
| 411 |
+
input_ids_list.append(inputs['input_ids'][0])
|
| 412 |
+
labels_list.append(inputs['labels'][0])
|
| 413 |
+
token_type_ids_list.append(inputs['token_type_ids'][0])
|
| 414 |
+
input_audio_embeds_list.append(inputs['input_audio_embeds'])
|
| 415 |
+
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
|
| 416 |
+
audio_attention_mask_list.append(
|
| 417 |
+
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
|
| 418 |
+
)
|
| 419 |
+
input_modes_list.append(inputs['input_modes'])
|
| 420 |
+
|
| 421 |
+
try:
|
| 422 |
+
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
|
| 423 |
+
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
|
| 424 |
+
labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
|
| 425 |
+
audio_attention_mask = (
|
| 426 |
+
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
|
| 427 |
+
if len(audio_attention_mask_list) > 1
|
| 428 |
+
else None
|
| 429 |
+
)
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(e)
|
| 432 |
+
print(input_ids_list)
|
| 433 |
+
print(labels_list)
|
| 434 |
+
raise
|
| 435 |
+
attention_mask = (input_ids != 0).long()
|
| 436 |
+
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
|
| 437 |
+
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
|
| 438 |
+
input_modes = torch.cat(input_modes_list)
|
| 439 |
+
|
| 440 |
+
return BatchFeature(
|
| 441 |
+
{
|
| 442 |
+
'input_ids': input_ids,
|
| 443 |
+
'labels': labels,
|
| 444 |
+
'token_type_ids': token_type_ids,
|
| 445 |
+
'attention_mask': attention_mask,
|
| 446 |
+
'input_audio_embeds': input_audio_embeds,
|
| 447 |
+
'audio_embed_sizes': audio_embed_sizes,
|
| 448 |
+
'audio_attention_mask': audio_attention_mask,
|
| 449 |
+
'input_modes': input_modes,
|
| 450 |
+
}
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
def pad_sequence(sequences, padding_side='left', padding_value=0):
|
| 454 |
+
"""
|
| 455 |
+
Pad a list of sequences to the same length.
|
| 456 |
+
sequences: list of tensors in [seq_len, *] shape
|
| 457 |
+
"""
|
| 458 |
+
assert padding_side in ['right', 'left']
|
| 459 |
+
max_size = sequences[0].size()
|
| 460 |
+
trailing_dims = max_size[1:]
|
| 461 |
+
max_len = max(len(seq) for seq in sequences)
|
| 462 |
+
batch_size = len(sequences)
|
| 463 |
+
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
|
| 464 |
+
for i, seq in enumerate(sequences):
|
| 465 |
+
length = seq.size(0)
|
| 466 |
+
if padding_side == 'right':
|
| 467 |
+
output.data[i, :length] = seq
|
| 468 |
+
else:
|
| 469 |
+
output.data[i, -length:] = seq
|
| 470 |
+
return output
|
| 471 |
+
|
| 472 |
+
def cat_with_pad(tensors, dim, padding_value=0):
|
| 473 |
+
"""
|
| 474 |
+
cat along dim, while pad to max for all other dims
|
| 475 |
+
"""
|
| 476 |
+
ndim = tensors[0].dim()
|
| 477 |
+
assert all(
|
| 478 |
+
t.dim() == ndim for t in tensors[1:]
|
| 479 |
+
), 'All tensors must have the same number of dimensions'
|
| 480 |
+
|
| 481 |
+
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
|
| 482 |
+
out_size[dim] = sum(t.shape[dim] for t in tensors)
|
| 483 |
+
output = tensors[0].new_full(out_size, padding_value)
|
| 484 |
+
|
| 485 |
+
index = 0
|
| 486 |
+
for t in tensors:
|
| 487 |
+
# Create a slice list where every dimension except dim is full slice
|
| 488 |
+
slices = [slice(0, t.shape[d]) for d in range(ndim)]
|
| 489 |
+
# Update only the concat dimension slice
|
| 490 |
+
slices[dim] = slice(index, index + t.shape[dim])
|
| 491 |
+
|
| 492 |
+
output[slices] = t
|
| 493 |
+
index += t.shape[dim]
|
| 494 |
+
|
| 495 |
+
return output
|
| 496 |
+
|
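A small usage sketch for the two collator helpers above: pad_sequence left-pads variable-length id tensors into one batch, and cat_with_pad concatenates audio-feature tensors along one dimension while zero-padding the others to the largest size. Illustrative only; it assumes the pad_sequence and cat_with_pad defined in this file are in scope.

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
batch_ids = pad_sequence([a, b], padding_side='left', padding_value=0)
print(batch_ids)
# tensor([[1, 2, 3],
#         [0, 4, 5]])

x = torch.ones(1, 4, 2)       # (batch, time, feature)
y = torch.ones(1, 6, 2) * 2   # longer in time
stacked = cat_with_pad([x, y], dim=0)
print(stacked.shape)  # torch.Size([2, 6, 2]); the shorter tensor is zero-padded in time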
| 497 |
+
def count_parameters_by_module(model):
|
| 498 |
+
# dictionary for parameters number by modules
|
| 499 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 500 |
+
|
| 501 |
+
# all params
|
| 502 |
+
total_params = 0
|
| 503 |
+
total_trainable_params = 0
|
| 504 |
+
|
| 505 |
+
# Check Embedding Token masks
|
| 506 |
+
embedding_masks = {}
|
| 507 |
+
for name, param in model.named_parameters():
|
| 508 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 509 |
+
# check if params has embedding_grad_mask_hook
|
| 510 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 511 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 512 |
+
# Accessing mask variables in the closure of hook functions
|
| 513 |
+
for cell in hook_fn.__closure__ or []:
|
| 514 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 515 |
+
# check mask tensor
|
| 516 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 517 |
+
|
| 518 |
+
# Count params by modules
|
| 519 |
+
for name, param in model.named_parameters():
|
| 520 |
+
# extracts top module_name
|
| 521 |
+
module_name = name.split('.')[0]
|
| 522 |
+
param_count = param.numel()
|
| 523 |
+
|
| 524 |
+
module_params[module_name]["total"] += param_count
|
| 525 |
+
total_params += param_count
|
| 526 |
+
|
| 527 |
+
if param.requires_grad:
|
| 528 |
+
# Only count for real trainable params. (with masks)
|
| 529 |
+
if name in embedding_masks:
|
| 530 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 531 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 532 |
+
total_trainable_params += trainable_count
|
| 533 |
+
else:
|
| 534 |
+
module_params[module_name]["trainable"] += param_count
|
| 535 |
+
total_trainable_params += param_count
|
| 536 |
+
|
| 537 |
+
print(f"All Params: {total_params:,}")
|
| 538 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 539 |
+
print("\nParams by Module:")
|
| 540 |
+
|
| 541 |
+
for module_name, counts in sorted(module_params.items()):
|
| 542 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 543 |
+
total_percentage = counts["total"] / total_params * 100
|
| 544 |
+
|
| 545 |
+
print(f"- {module_name}:")
|
| 546 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 547 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 548 |
+
|
| 549 |
+
return module_params
|
| 550 |
+
|
| 551 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 552 |
+
model = AutoModel.from_pretrained(
|
| 553 |
+
model_name_or_path,
|
| 554 |
+
revision=revision,
|
| 555 |
+
torch_dtype=torch.bfloat16,
|
| 556 |
+
device_map="auto",
|
| 557 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 558 |
+
trust_remote_code=True,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# Set use_cache to False after model loaded
|
| 562 |
+
model.config.use_cache = False
|
| 563 |
+
|
| 564 |
+
# Freeze all parameters
|
| 565 |
+
for param in model.parameters():
|
| 566 |
+
param.requires_grad = False
|
| 567 |
+
|
| 568 |
+
model.set_lora_adapter('speech')
|
| 569 |
+
model.to(torch.bfloat16)
|
| 570 |
+
|
| 571 |
+
# (Optional) unfreeze audio_tower parameters
|
| 572 |
+
# for param in model.audio_tower.parameters():
|
| 573 |
+
# param.requires_grad = True
|
| 574 |
+
|
| 575 |
+
# Only unfreeze audio_projector parameters
|
| 576 |
+
for param in model.audio_projector.parameters():
|
| 577 |
+
param.requires_grad = True
|
| 578 |
+
|
| 579 |
+
# (Optional) unfreeze audio embed_tokens
|
| 580 |
+
train_embed = True
|
| 581 |
+
if train_embed:
|
| 582 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 583 |
+
|
| 584 |
+
embed_tokens.weight.requires_grad = False
|
| 585 |
+
|
| 586 |
+
# Added Speech token IDs (only this tokens be trainable)
|
| 587 |
+
trainable_token_ids = [256001, 256002]
|
| 588 |
+
|
| 589 |
+
embed_tokens.weight.requires_grad = True
|
| 590 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 591 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 592 |
+
|
| 593 |
+
# backward hook, with gradient masking
|
| 594 |
+
def embedding_grad_mask_hook(grad):
|
| 595 |
+
return grad.masked_fill(mask, 0)
|
| 596 |
+
|
| 597 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 598 |
+
|
| 599 |
+
model.language_model.model.model.embed_tokens = embed_tokens
|
| 600 |
+
|
| 601 |
+
count_parameters_by_module(model)
|
| 602 |
+
|
| 603 |
+
return model
|
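create_model above keeps the full embedding matrix marked trainable but registers a backward hook that zeroes the gradient of every row except the added speech tokens (ids 256001 and 256002), so only those rows actually move. The sketch below shows the same trick on a toy embedding table using plain torch; it is not part of the uploaded script, and the vocabulary size and ids are made up.

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)        # tiny vocabulary of 10 tokens
trainable_token_ids = [2, 5]       # stand-ins for the added speech tokens

mask = torch.ones_like(embed.weight, dtype=torch.bool)
mask[trainable_token_ids] = False  # False = row keeps its gradient

embed.weight.register_hook(lambda grad: grad.masked_fill(mask, 0))

loss = embed(torch.tensor([1, 2, 5, 7])).sum()
loss.backward()

rows_with_grad = (embed.weight.grad.abs().sum(dim=1) > 0).nonzero().flatten()
print(rows_with_grad)  # tensor([2, 5]): only the unmasked rows received gradient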
| 604 |
+
|
| 605 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 606 |
+
|
| 607 |
+
INSTRUCTION = {
|
| 608 |
+
"ast": [
|
| 609 |
+
"Translate the audio to {0}.",
|
| 610 |
+
"Translate the audio clip into {0}.",
|
| 611 |
+
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
|
| 612 |
+
"Translate the provided audio file into {0}.",
|
| 613 |
+
"Convert the audio speech to {0} text.",
|
| 614 |
+
"Write an {0} translation of the audio file.",
|
| 615 |
+
"Translate spoken words from the audio into {0}.",
|
| 616 |
+
"Create an {0} version of the audio content.",
|
| 617 |
+
"Produce an accurate {0} translation of the audio.",
|
| 618 |
+
"Extract speech from the audio and translate it to {0}.",
|
| 619 |
+
"Turn the audio into readable {0} text.",
|
| 620 |
+
"Write all spoken content from the audio in {0}.",
|
| 621 |
+
"Generate an {0} translation of the speech in the file.",
|
| 622 |
+
"Convert the recording into {0} text.",
|
| 623 |
+
"Accurately translate the audio recording to {0}.",
|
| 624 |
+
"Write down dialogue from the given audio in {0}.",
|
| 625 |
+
"Translate all speech in this audio file to {0}.",
|
| 626 |
+
"Create an accurate {0} version of the speech.",
|
| 627 |
+
"Perform a complete {0} translation of the audio."
|
| 628 |
+
],
|
| 629 |
+
"asr": [
|
| 630 |
+
"Transcribe the audio clip into text.",
|
| 631 |
+
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
|
| 632 |
+
"Transcribe the provided audio file into text.",
|
| 633 |
+
"Convert the audio speech to text.",
|
| 634 |
+
"Write a transcript of the audio file.",
|
| 635 |
+
"Transcribe spoken words from the audio.",
|
| 636 |
+
"Create a text version of the audio content.",
|
| 637 |
+
"Produce a verbatim transcript of the audio.",
|
| 638 |
+
"Extract and transcribe speech from the audio.",
|
| 639 |
+
"Turn the audio into readable text.",
|
| 640 |
+
"Write all spoken words from the audio.",
|
| 641 |
+
"Generate a transcript of the speech in the file.",
|
| 642 |
+
"Convert the recording into a text transcript.",
|
| 643 |
+
"Accurately transcribe the audio recording.",
|
| 644 |
+
"Write down dialogue from the given audio.",
|
| 645 |
+
"Transcribe all speech in this audio file.",
|
| 646 |
+
"Create an accurate text version of the speech.",
|
| 647 |
+
"Perform a complete transcription of the audio."
|
| 648 |
+
],
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 652 |
+
_IGNORE_INDEX = -100
|
| 653 |
+
|
| 654 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 655 |
+
use_flash_attention = True
|
| 656 |
+
|
| 657 |
+
output_dir = '../gemma_tmp7'
|
| 658 |
+
batch_size = 128
|
| 659 |
+
batch_size_per_gpu = 16
|
| 660 |
+
learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
|
| 661 |
+
wd = 0.01
|
| 662 |
+
num_train_epochs = 15
|
| 663 |
+
|
| 664 |
+
revision = "main" #"v1.0"
|
| 665 |
+
|
| 666 |
+
processor = AutoProcessor.from_pretrained(
|
| 667 |
+
model_name_or_path,
|
| 668 |
+
revision=revision,
|
| 669 |
+
trust_remote_code=True,
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
model = create_model(
|
| 673 |
+
model_name_or_path,
|
| 674 |
+
revision=revision,
|
| 675 |
+
use_flash_attention=use_flash_attention,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
train_datasets = []
|
| 679 |
+
|
| 680 |
+
# common voice asr
|
| 681 |
+
commonvoice_speech_tw2 = CommonVoiceDataset(
|
| 682 |
+
processor=processor,
|
| 683 |
+
source_lang="zh-TW",
|
| 684 |
+
split="other[:70%]"
|
| 685 |
+
)
|
| 686 |
+
train_datasets.append(commonvoice_speech_tw2)
|
| 687 |
+
|
| 688 |
+
commonvoice_speech_cn = CommonVoiceDataset(
|
| 689 |
+
processor=processor,
|
| 690 |
+
source_lang="zh-CN",
|
| 691 |
+
split="train[:50%]"
|
| 692 |
+
)
|
| 693 |
+
train_datasets.append(commonvoice_speech_cn)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
commonvoice_speech_tw = CommonVoiceDataset(
|
| 697 |
+
processor=processor,
|
| 698 |
+
source_lang="zh-TW",
|
| 699 |
+
split="train"
|
| 700 |
+
)
|
| 701 |
+
train_datasets.append(commonvoice_speech_tw)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
# Libri Speech Clean ASR mode (English -> English text)
|
| 707 |
+
libri_speech_clean = LibriSpeechDataset(
|
| 708 |
+
processor=processor,
|
| 709 |
+
subset="clean",
|
| 710 |
+
split="train.360[:50%]"
|
| 711 |
+
)
|
| 712 |
+
train_datasets.append(libri_speech_clean)
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Fleurs ASR mode (English -> English text)
|
| 716 |
+
en_asr_fleurs = FleursDataset(
|
| 717 |
+
processor=processor,
|
| 718 |
+
split="train",
|
| 719 |
+
source_lang="en_us", # English
|
| 720 |
+
mode="asr"
|
| 721 |
+
)
|
| 722 |
+
train_datasets.append(en_asr_fleurs)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
# en_ch_ast_fleurs = FleursDataset(
|
| 726 |
+
# processor=processor,
|
| 727 |
+
# split="train",
|
| 728 |
+
# source_lang="en_us",
|
| 729 |
+
# target_lang="cmn_hans_cn",
|
| 730 |
+
# mode="ast"
|
| 731 |
+
# )
|
| 732 |
+
# train_datasets.append(en_ch_ast_fleurs)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
ch_asr_fleurs = FleursDataset(
|
| 737 |
+
processor=processor,
|
| 738 |
+
split="train",
|
| 739 |
+
source_lang="cmn_hans_cn",
|
| 740 |
+
mode="asr"
|
| 741 |
+
)
|
| 742 |
+
train_datasets.append(ch_asr_fleurs)
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
# ch_en_ast_fleurs = FleursDataset(
|
| 746 |
+
# processor=processor,
|
| 747 |
+
# split="train",
|
| 748 |
+
# source_lang="cmn_hans_cn",
|
| 749 |
+
# target_lang="en_us",
|
| 750 |
+
# mode="ast"
|
| 751 |
+
# )
|
| 752 |
+
# train_datasets.append(ch_en_ast_fleurs)
|
| 753 |
+
|
| 754 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 755 |
+
print([len(dataset) for dataset in train_datasets])
|
| 756 |
+
|
| 757 |
+
# ConcatDataset
|
| 758 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 759 |
+
print("Count Length of Datas", len(train_dataset))
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# Check GPUs
|
| 764 |
+
num_gpus = torch.cuda.device_count()
|
| 765 |
+
print(f'training on {num_gpus} GPUs')
|
| 766 |
+
|
| 767 |
+
assert (
|
| 768 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 769 |
+
), 'Batch size must be divisible by the number of GPUs'
|
| 770 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
|
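A quick check of the effective-batch arithmetic above: batch_size is the global batch, so the number of accumulation steps shrinks as more GPUs (or a larger per-GPU batch) are used. Illustrative only, using the values hard coded in this script.

batch_size, batch_size_per_gpu = 128, 16
for num_gpus in (1, 2, 4, 8):
    print(num_gpus, batch_size // (num_gpus * batch_size_per_gpu))
# 1 -> 8, 2 -> 4, 4 -> 2, 8 -> 1 accumulation steps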
| 771 |
+
|
| 772 |
+
# hard coded training args
|
| 773 |
+
dp_config = {
|
| 774 |
+
"fp16": {
|
| 775 |
+
"enabled": "auto",
|
| 776 |
+
"loss_scale": 0,
|
| 777 |
+
"loss_scale_window": 1000,
|
| 778 |
+
"initial_scale_power": 16,
|
| 779 |
+
"hysteresis": 2,
|
| 780 |
+
"min_loss_scale": 1
|
| 781 |
+
},
|
| 782 |
+
"zero_optimization": {
|
| 783 |
+
"stage": 2,
|
| 784 |
+
"allgather_partitions": True,
|
| 785 |
+
"allgather_bucket_size": 5e8,
|
| 786 |
+
"overlap_comm": False,
|
| 787 |
+
"reduce_scatter": True,
|
| 788 |
+
"reduce_bucket_size": 5e8,
|
| 789 |
+
"contiguous_gradients": True,
|
| 790 |
+
"cpu_offload": True
|
| 791 |
+
},
|
| 792 |
+
|
| 793 |
+
"train_batch_size": "auto",
|
| 794 |
+
"gradient_accumulation_steps": "auto",
|
| 795 |
+
"optimizer": {
|
| 796 |
+
"type": "AdamW",
|
| 797 |
+
"params": {
|
| 798 |
+
"lr": "auto",
|
| 799 |
+
"betas": 'auto',
|
| 800 |
+
"eps": 'auto',
|
| 801 |
+
"weight_decay": "auto"
|
| 802 |
+
}
|
| 803 |
+
},
|
| 804 |
+
"scheduler": {
|
| 805 |
+
"type": "WarmupDecayLR",
|
| 806 |
+
"params": {
|
| 807 |
+
"warmup_min_lr": "auto",
|
| 808 |
+
"warmup_max_lr": "auto",
|
| 809 |
+
"warmup_num_steps": "auto",
|
| 810 |
+
"total_num_steps": "auto"
|
| 811 |
+
}
|
| 812 |
+
},
|
| 813 |
+
"gradient_clipping": 1.0,
|
| 814 |
+
"zero_optimization": {  # NOTE: duplicate key; this later entry overrides the stage-2/cpu_offload block above, so DeepSpeed actually sees stage 0
|
| 815 |
+
"stage": 0
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
training_args = TrainingArguments(
|
| 819 |
+
num_train_epochs=num_train_epochs,
|
| 820 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 821 |
+
gradient_checkpointing=True,
|
| 822 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 823 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 824 |
+
optim='adamw_torch',
|
| 825 |
+
adam_beta1=0.9,
|
| 826 |
+
adam_beta2=0.95,
|
| 827 |
+
adam_epsilon=1e-7,
|
| 828 |
+
learning_rate=learning_rate,
|
| 829 |
+
weight_decay=wd,
|
| 830 |
+
max_grad_norm=1.0,
|
| 831 |
+
lr_scheduler_type='cosine',
|
| 832 |
+
warmup_steps=50,
|
| 833 |
+
logging_steps=10,
|
| 834 |
+
output_dir=output_dir,
|
| 835 |
+
save_total_limit=10,
|
| 836 |
+
save_only_model=True,
|
| 837 |
+
bf16=True,
|
| 838 |
+
fp16=False,
|
| 839 |
+
remove_unused_columns=False,
|
| 840 |
+
report_to='none',
|
| 841 |
+
deepspeed=dp_config if num_gpus==1 else None,
|
| 842 |
+
disable_tqdm=False,
|
| 843 |
+
dataloader_num_workers=4,
|
| 844 |
+
save_strategy='steps',
|
| 845 |
+
save_steps=1000,
|
| 846 |
+
ddp_find_unused_parameters=True,
|
| 847 |
+
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
out_path = Path(training_args.output_dir)
|
| 851 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 852 |
+
|
| 853 |
+
# create optimizer only for trainable params
|
| 854 |
+
optimizer = torch.optim.AdamW(
|
| 855 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 856 |
+
lr=learning_rate,
|
| 857 |
+
weight_decay=wd,
|
| 858 |
+
betas=(0.9, 0.95),
|
| 859 |
+
eps=1e-7,
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
# Trainer Setting
|
| 863 |
+
trainer = Trainer(
|
| 864 |
+
model=model,
|
| 865 |
+
args=training_args,
|
| 866 |
+
data_collator=covost_collate_fn,
|
| 867 |
+
train_dataset=train_dataset,
|
| 868 |
+
optimizers=(optimizer, None)
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
trainer.train()
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
# # 1. Save LoRA Adapter
|
| 875 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 876 |
+
|
| 877 |
+
# # 1-1. Delete Markdown file
|
| 878 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 879 |
+
# if os.path.exists(markdown_file):
|
| 880 |
+
# os.remove(markdown_file)
|
| 881 |
+
|
| 882 |
+
# 2. Save entire model
|
| 883 |
+
model.save_pretrained(output_dir)
|
cpp/gemma_v1/training_multiturn.py
ADDED
|
@@ -0,0 +1,329 @@
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
from ASRDataset import *
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def count_parameters_by_module(model):
|
| 36 |
+
# dictionary for parameters number by modules
|
| 37 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 38 |
+
|
| 39 |
+
# all params
|
| 40 |
+
total_params = 0
|
| 41 |
+
total_trainable_params = 0
|
| 42 |
+
|
| 43 |
+
# Check Embedding Token masks
|
| 44 |
+
embedding_masks = {}
|
| 45 |
+
for name, param in model.named_parameters():
|
| 46 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 47 |
+
# check if params has embedding_grad_mask_hook
|
| 48 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 49 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 50 |
+
# Accessing mask variables in the closure of hook functions
|
| 51 |
+
for cell in hook_fn.__closure__ or []:
|
| 52 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 53 |
+
# check mask tensor
|
| 54 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 55 |
+
|
| 56 |
+
# Count params by modules
|
| 57 |
+
for name, param in model.named_parameters():
|
| 58 |
+
# extracts top module_name
|
| 59 |
+
module_name = name.split('.')[0]
|
| 60 |
+
param_count = param.numel()
|
| 61 |
+
|
| 62 |
+
module_params[module_name]["total"] += param_count
|
| 63 |
+
total_params += param_count
|
| 64 |
+
|
| 65 |
+
if param.requires_grad:
|
| 66 |
+
# Only count for real trainable params. (with masks)
|
| 67 |
+
if name in embedding_masks:
|
| 68 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 69 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 70 |
+
total_trainable_params += trainable_count
|
| 71 |
+
else:
|
| 72 |
+
module_params[module_name]["trainable"] += param_count
|
| 73 |
+
total_trainable_params += param_count
|
| 74 |
+
|
| 75 |
+
print(f"All Params: {total_params:,}")
|
| 76 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 77 |
+
print("\nParams by Module:")
|
| 78 |
+
|
| 79 |
+
for module_name, counts in sorted(module_params.items()):
|
| 80 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 81 |
+
total_percentage = counts["total"] / total_params * 100
|
| 82 |
+
|
| 83 |
+
print(f"- {module_name}:")
|
| 84 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 85 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 86 |
+
|
| 87 |
+
return module_params
|
| 88 |
+
|
| 89 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 90 |
+
model = AutoModel.from_pretrained(
|
| 91 |
+
model_name_or_path,
|
| 92 |
+
revision=revision,
|
| 93 |
+
torch_dtype=torch.bfloat16,
|
| 94 |
+
device_map="auto",
|
| 95 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Set use_cache to False after model loaded
|
| 100 |
+
model.config.use_cache = False
|
| 101 |
+
|
| 102 |
+
# Freeze all parameters
|
| 103 |
+
for param in model.parameters():
|
| 104 |
+
param.requires_grad = False
|
| 105 |
+
|
| 106 |
+
model.set_lora_adapter('speech')
|
| 107 |
+
model.to(torch.bfloat16)
|
| 108 |
+
|
| 109 |
+
# (Optional) unfreeze audio_tower parameters
|
| 110 |
+
# for param in model.audio_tower.parameters():
|
| 111 |
+
# param.requires_grad = True
|
| 112 |
+
|
| 113 |
+
# Only unfreeze audio_projector parameters
|
| 114 |
+
# for param in model.audio_projector.parameters():
|
| 115 |
+
# param.requires_grad = True
|
| 116 |
+
|
| 117 |
+
# (Optional) unfreeze audio embed_tokens
|
| 118 |
+
train_embed = True
|
| 119 |
+
if train_embed:
|
| 120 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 121 |
+
|
| 122 |
+
embed_tokens.weight.requires_grad = False
|
| 123 |
+
|
| 124 |
+
# Added Speech token IDs (only this tokens be trainable)
|
| 125 |
+
trainable_token_ids = [256001, 256002]
|
| 126 |
+
|
| 127 |
+
embed_tokens.weight.requires_grad = True
|
| 128 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 129 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 130 |
+
|
| 131 |
+
# backward hook, with gradient masking
|
| 132 |
+
def embedding_grad_mask_hook(grad):
|
| 133 |
+
return grad.masked_fill(mask, 0)
|
| 134 |
+
|
| 135 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 136 |
+
|
| 137 |
+
model.language_model.model.model.embed_tokens = embed_tokens
|
| 138 |
+
|
| 139 |
+
count_parameters_by_module(model)
|
| 140 |
+
|
| 141 |
+
return model
|
| 142 |
+
|
| 143 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 144 |
+
_IGNORE_INDEX = -100
|
| 145 |
+
|
| 146 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 147 |
+
_IGNORE_INDEX = -100
|
| 148 |
+
|
| 149 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 150 |
+
use_flash_attention = False
|
| 151 |
+
|
| 152 |
+
output_dir = '../gemma_tmp13'
|
| 153 |
+
batch_size = 24
|
| 154 |
+
batch_size_per_gpu = 8
|
| 155 |
+
learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
|
| 156 |
+
wd = 0.01
|
| 157 |
+
num_train_epochs = 10
|
| 158 |
+
|
| 159 |
+
revision = "main" #"v1.0"
|
| 160 |
+
|
| 161 |
+
processor = AutoProcessor.from_pretrained(
|
| 162 |
+
model_name_or_path,
|
| 163 |
+
revision=revision,
|
| 164 |
+
trust_remote_code=True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
model = create_model(
|
| 168 |
+
model_name_or_path,
|
| 169 |
+
revision=revision,
|
| 170 |
+
use_flash_attention=use_flash_attention,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
train_datasets = []
|
| 174 |
+
|
| 175 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 176 |
+
train_datasets.append(pickup_dataset)
|
| 177 |
+
|
| 178 |
+
# custom_tw_loc = TWCostumData(processor=processor,
|
| 179 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 180 |
+
# train_datasets.append(custom_tw_loc) # 1500
|
| 181 |
+
|
| 182 |
+
# custom_tw_loc2 = TWCostumData(processor=processor,
|
| 183 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 184 |
+
# train_datasets.append(custom_tw_loc2) # 9458
|
| 185 |
+
|
| 186 |
+
# custom_yating_tw_road = TWCostumData(processor=processor,
|
| 187 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
|
| 188 |
+
# train_datasets.append(custom_yating_tw_road) # 35224
|
| 189 |
+
|
| 190 |
+
# custom_tw_road = TWCostumData(processor=processor,
|
| 191 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 192 |
+
# train_datasets.append(custom_tw_road) # 1500
|
| 193 |
+
|
| 194 |
+
# custom_tw_road2 = TWCostumData(processor=processor,
|
| 195 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 196 |
+
# train_datasets.append(custom_tw_road2) # 35224
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 201 |
+
print([len(dataset) for dataset in train_datasets])
|
| 202 |
+
|
| 203 |
+
# ConcatDataset
|
| 204 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 205 |
+
print("Count Length of Datas", len(train_dataset))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Check GPUs
|
| 210 |
+
num_gpus = torch.cuda.device_count()
|
| 211 |
+
print(f'training on {num_gpus} GPUs')
|
| 212 |
+
|
| 213 |
+
assert (
|
| 214 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 215 |
+
), 'Batch size must be divisible by the number of GPUs'
|
| 216 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
|
| 217 |
+
|
| 218 |
+
# hard coded training args
|
| 219 |
+
dp_config = {
|
| 220 |
+
"fp16": {
|
| 221 |
+
"enabled": "auto",
|
| 222 |
+
"loss_scale": 0,
|
| 223 |
+
"loss_scale_window": 1000,
|
| 224 |
+
"initial_scale_power": 16,
|
| 225 |
+
"hysteresis": 2,
|
| 226 |
+
"min_loss_scale": 1
|
| 227 |
+
},
|
| 228 |
+
"zero_optimization": {
|
| 229 |
+
"stage": 2,
|
| 230 |
+
"allgather_partitions": True,
|
| 231 |
+
"allgather_bucket_size": 5e8,
|
| 232 |
+
"overlap_comm": False,
|
| 233 |
+
"reduce_scatter": True,
|
| 234 |
+
"reduce_bucket_size": 5e8,
|
| 235 |
+
"contiguous_gradients": True,
|
| 236 |
+
"cpu_offload": True
|
| 237 |
+
},
|
| 238 |
+
|
| 239 |
+
"train_batch_size": "auto",
|
| 240 |
+
"gradient_accumulation_steps": "auto",
|
| 241 |
+
"optimizer": {
|
| 242 |
+
"type": "AdamW",
|
| 243 |
+
"params": {
|
| 244 |
+
"lr": "auto",
|
| 245 |
+
"betas": 'auto',
|
| 246 |
+
"eps": 'auto',
|
| 247 |
+
"weight_decay": "auto"
|
| 248 |
+
}
|
| 249 |
+
},
|
| 250 |
+
"scheduler": {
|
| 251 |
+
"type": "WarmupDecayLR",
|
| 252 |
+
"params": {
|
| 253 |
+
"warmup_min_lr": "auto",
|
| 254 |
+
"warmup_max_lr": "auto",
|
| 255 |
+
"warmup_num_steps": "auto",
|
| 256 |
+
"total_num_steps": "auto"
|
| 257 |
+
}
|
| 258 |
+
},
|
| 259 |
+
"gradient_clipping": 1.0,
|
| 260 |
+
"zero_optimization": {  # NOTE: duplicate key; this later entry overrides the stage-2/cpu_offload block above, so DeepSpeed actually sees stage 0
|
| 261 |
+
"stage": 0
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
training_args = TrainingArguments(
|
| 265 |
+
num_train_epochs=num_train_epochs,
|
| 266 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 267 |
+
gradient_checkpointing=True,
|
| 268 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 269 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 270 |
+
optim='adamw_torch',
|
| 271 |
+
adam_beta1=0.9,
|
| 272 |
+
adam_beta2=0.95,
|
| 273 |
+
adam_epsilon=1e-7,
|
| 274 |
+
learning_rate=learning_rate,
|
| 275 |
+
weight_decay=wd,
|
| 276 |
+
max_grad_norm=1.0,
|
| 277 |
+
lr_scheduler_type='cosine',
|
| 278 |
+
warmup_steps=50,
|
| 279 |
+
logging_steps=10,
|
| 280 |
+
output_dir=output_dir,
|
| 281 |
+
save_total_limit=10,
|
| 282 |
+
save_only_model=True,
|
| 283 |
+
bf16=True,
|
| 284 |
+
fp16=False,
|
| 285 |
+
remove_unused_columns=False,
|
| 286 |
+
report_to='none',
|
| 287 |
+
deepspeed=None,
|
| 288 |
+
disable_tqdm=False,
|
| 289 |
+
dataloader_num_workers=16,
|
| 290 |
+
save_strategy='epoch',
|
| 291 |
+
# save_steps=2500,
|
| 292 |
+
ddp_find_unused_parameters=True,
|
| 293 |
+
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
out_path = Path(training_args.output_dir)
|
| 297 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 298 |
+
|
| 299 |
+
# create optimizer only for trainable params
|
| 300 |
+
optimizer = torch.optim.AdamW(
|
| 301 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 302 |
+
lr=learning_rate,
|
| 303 |
+
weight_decay=wd,
|
| 304 |
+
betas=(0.9, 0.95),
|
| 305 |
+
eps=1e-7,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Trainer Setting
|
| 309 |
+
trainer = Trainer(
|
| 310 |
+
model=model,
|
| 311 |
+
args=training_args,
|
| 312 |
+
data_collator=covost_collate_fn,
|
| 313 |
+
train_dataset=train_dataset,
|
| 314 |
+
optimizers=(optimizer, None)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
trainer.train()
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# # 1. Save LoRA Adapter
|
| 321 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 322 |
+
|
| 323 |
+
# # 1-1. Delete Markdown file
|
| 324 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 325 |
+
# if os.path.exists(markdown_file):
|
| 326 |
+
# os.remove(markdown_file)
|
| 327 |
+
|
| 328 |
+
# 2. Save entire model
|
| 329 |
+
model.save_pretrained(output_dir)
|
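In this multiturn script the audio_projector unfreeze is commented out, so after create_model only the 'speech' LoRA weights and the two added embedding rows are meant to update. Below is a short, generic check that can be run before trainer.train() to confirm what is trainable; it is illustrative only, where model is whatever create_model returned.

def print_trainable(model, max_rows=20):
    """Print up to max_rows trainable parameter names and the total trainable count."""
    total, shown = 0, 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total += param.numel()
            if shown < max_rows:
                print(name, tuple(param.shape))
                shown += 1
    print(f"trainable parameters: {total:,}")

# print_trainable(model)
# Note: embed_tokens.weight is listed whole here; the gradient hook installed in
# create_model restricts actual updates to the two added speech-token rows.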
cpp/gemma_v1/training_multiturn_textonly.py
ADDED
|
@@ -0,0 +1,333 @@
|
| 1 |
+
import datasets
|
| 2 |
+
datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"
|
| 3 |
+
import os
|
| 4 |
+
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import sacrebleu
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
AutoProcessor,
|
| 20 |
+
AutoModel,
|
| 21 |
+
BatchFeature,
|
| 22 |
+
Trainer,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
StoppingCriteria,
|
| 25 |
+
StoppingCriteriaList,
|
| 26 |
+
)
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
import soundfile as sf
|
| 30 |
+
from datasets import Audio
|
| 31 |
+
import random
|
| 32 |
+
from ASRDataset import *
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def count_parameters_by_module(model):
|
| 36 |
+
# dictionary for parameters number by modules
|
| 37 |
+
module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
|
| 38 |
+
|
| 39 |
+
# all params
|
| 40 |
+
total_params = 0
|
| 41 |
+
total_trainable_params = 0
|
| 42 |
+
|
| 43 |
+
# Check Embedding Token masks
|
| 44 |
+
embedding_masks = {}
|
| 45 |
+
for name, param in model.named_parameters():
|
| 46 |
+
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
|
| 47 |
+
# check if params has embedding_grad_mask_hook
|
| 48 |
+
for hook_id, hook_fn in param._backward_hooks.items():
|
| 49 |
+
if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
|
| 50 |
+
# Accessing mask variables in the closure of hook functions
|
| 51 |
+
for cell in hook_fn.__closure__ or []:
|
| 52 |
+
if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
|
| 53 |
+
# check mask tensor
|
| 54 |
+
embedding_masks[name] = ~cell.cell_contents # True : Trainable
|
| 55 |
+
|
| 56 |
+
# Count params by modules
|
| 57 |
+
for name, param in model.named_parameters():
|
| 58 |
+
# extracts top module_name
|
| 59 |
+
module_name = name.split('.')[0]
|
| 60 |
+
param_count = param.numel()
|
| 61 |
+
|
| 62 |
+
module_params[module_name]["total"] += param_count
|
| 63 |
+
total_params += param_count
|
| 64 |
+
|
| 65 |
+
if param.requires_grad:
|
| 66 |
+
# Only count for real trainable params. (with masks)
|
| 67 |
+
if name in embedding_masks:
|
| 68 |
+
trainable_count = embedding_masks[name].sum().item()
|
| 69 |
+
module_params[module_name]["trainable"] += trainable_count
|
| 70 |
+
total_trainable_params += trainable_count
|
| 71 |
+
else:
|
| 72 |
+
module_params[module_name]["trainable"] += param_count
|
| 73 |
+
total_trainable_params += param_count
|
| 74 |
+
|
| 75 |
+
print(f"All Params: {total_params:,}")
|
| 76 |
+
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
|
| 77 |
+
print("\nParams by Module:")
|
| 78 |
+
|
| 79 |
+
for module_name, counts in sorted(module_params.items()):
|
| 80 |
+
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
|
| 81 |
+
total_percentage = counts["total"] / total_params * 100
|
| 82 |
+
|
| 83 |
+
print(f"- {module_name}:")
|
| 84 |
+
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
|
| 85 |
+
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
|
| 86 |
+
|
| 87 |
+
return module_params
|
| 88 |
+
|
| 89 |
+
def create_model(model_name_or_path, revision="main", use_flash_attention = False):
|
| 90 |
+
model = AutoModel.from_pretrained(
|
| 91 |
+
model_name_or_path,
|
| 92 |
+
revision=revision,
|
| 93 |
+
torch_dtype=torch.bfloat16,
|
| 94 |
+
device_map="auto",
|
| 95 |
+
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
|
| 96 |
+
trust_remote_code=True,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Set use_cache to False after model loaded
|
| 100 |
+
model.config.use_cache = False
|
| 101 |
+
|
| 102 |
+
# Freeze all parameters
|
| 103 |
+
for param in model.parameters():
|
| 104 |
+
param.requires_grad = False
|
| 105 |
+
|
| 106 |
+
model.set_lora_adapter('speech')
|
| 107 |
+
# model.set_lora_adapter('text')
|
| 108 |
+
model.to(torch.bfloat16)
|
| 109 |
+
|
| 110 |
+
# (Optional) unfreeze audio_tower parameters
|
| 111 |
+
# for param in model.audio_tower.parameters():
|
| 112 |
+
# param.requires_grad = True
|
| 113 |
+
|
| 114 |
+
# Only unfreeze audio_projector parameters
|
| 115 |
+
# for param in model.audio_projector.parameters():
|
| 116 |
+
# param.requires_grad = True
|
| 117 |
+
|
| 118 |
+
# (Optional) unfreeze audio embed_tokens
|
| 119 |
+
train_embed = True
|
| 120 |
+
if train_embed:
|
| 121 |
+
embed_tokens = model.language_model.model.model.embed_tokens
|
| 122 |
+
|
| 123 |
+
embed_tokens.weight.requires_grad = False  # frozen by default; re-enabled just below with a gradient mask
|
| 124 |
+
|
| 125 |
+
# Added speech token IDs (only these tokens remain trainable)
|
| 126 |
+
trainable_token_ids = [256001, 256002]
|
| 127 |
+
|
| 128 |
+
embed_tokens.weight.requires_grad = True
|
| 129 |
+
mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
|
| 130 |
+
mask[trainable_token_ids] = False # Trainable Tokens are False (unfreeze), else True (freeze)
|
| 131 |
+
|
| 132 |
+
# backward hook, with gradient masking
|
| 133 |
+
def embedding_grad_mask_hook(grad):
|
| 134 |
+
return grad.masked_fill(mask, 0)
|
| 135 |
+
|
| 136 |
+
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
|
| 137 |
+
|
| 138 |
+
model.language_model.model.model.embed_tokens = embed_tokens
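For reference, the same row-wise gradient-masking idea as a minimal, self-contained Python sketch (assuming a small toy nn.Embedding instead of the Gemma embed_tokens; the hook zeroes gradients for every row except the listed token IDs):

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)              # stand-in for embed_tokens
trainable_token_ids = [7, 8]             # rows that should stay trainable

mask = torch.ones_like(embed.weight, dtype=torch.bool)
mask[trainable_token_ids] = False        # False = keep gradient, True = zero it

embed.weight.register_hook(lambda grad: grad.masked_fill(mask, 0))

loss = embed(torch.tensor([7, 3])).sum()
loss.backward()
print(embed.weight.grad[7])              # non-zero: this row still learns
print(embed.weight.grad[3])              # all zeros: masked out by the hook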
|
| 139 |
+
|
| 140 |
+
count_parameters_by_module(model)
|
| 141 |
+
|
| 142 |
+
return model
|
| 143 |
+
|
| 144 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 145 |
+
_IGNORE_INDEX = -100
|
| 146 |
+
|
| 147 |
+
ANSWER_SUFFIX = "<end_of_turn>"
|
| 148 |
+
_IGNORE_INDEX = -100
|
| 149 |
+
|
| 150 |
+
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
|
| 151 |
+
use_flash_attention = False
|
| 152 |
+
|
| 153 |
+
output_dir = '../gemma_tmp14_audio_and_text_speechlora'
|
| 154 |
+
batch_size = 16
|
| 155 |
+
batch_size_per_gpu = 1
|
| 156 |
+
learning_rate = 5.0e-5 # 1.0e-4 for fine-tuning
|
| 157 |
+
wd = 0.01
|
| 158 |
+
num_train_epochs = 10
|
| 159 |
+
|
| 160 |
+
revision = "main" #"v1.0"
|
| 161 |
+
|
| 162 |
+
processor = AutoProcessor.from_pretrained(
|
| 163 |
+
model_name_or_path,
|
| 164 |
+
revision=revision,
|
| 165 |
+
trust_remote_code=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
model = create_model(
|
| 169 |
+
model_name_or_path,
|
| 170 |
+
revision=revision,
|
| 171 |
+
use_flash_attention=use_flash_attention,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
train_datasets = []
|
| 175 |
+
|
| 176 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,text_only=True,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 177 |
+
train_datasets.append(pickup_dataset)
|
| 178 |
+
|
| 179 |
+
pickup_dataset = MultiturnAudioDataset(processor=processor,json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
|
| 180 |
+
train_datasets.append(pickup_dataset)
|
| 181 |
+
|
| 182 |
+
# custom_tw_loc = TWCostumData(processor=processor,
|
| 183 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 184 |
+
# train_datasets.append(custom_tw_loc) # 1500
|
| 185 |
+
|
| 186 |
+
# custom_tw_loc2 = TWCostumData(processor=processor,
|
| 187 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 188 |
+
# train_datasets.append(custom_tw_loc2) # 9458
|
| 189 |
+
|
| 190 |
+
# custom_yating_tw_road = TWCostumData(processor=processor,
|
| 191 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
|
| 192 |
+
# train_datasets.append(custom_yating_tw_road) # 35224
|
| 193 |
+
|
| 194 |
+
# custom_tw_road = TWCostumData(processor=processor,
|
| 195 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
|
| 196 |
+
# train_datasets.append(custom_tw_road) # 1500
|
| 197 |
+
|
| 198 |
+
# custom_tw_road2 = TWCostumData(processor=processor,
|
| 199 |
+
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
|
| 200 |
+
# train_datasets.append(custom_tw_road2) # 35224
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
print("Count Num of Datasets", len(train_datasets))
|
| 205 |
+
print([len(dataset) for dataset in train_datasets])
|
| 206 |
+
|
| 207 |
+
# ConcatDataset
|
| 208 |
+
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
|
| 209 |
+
print("Count Length of Datas", len(train_dataset))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# Check GPUs
|
| 214 |
+
num_gpus = torch.cuda.device_count()
|
| 215 |
+
print(f'training on {num_gpus} GPUs')
|
| 216 |
+
|
| 217 |
+
assert (
|
| 218 |
+
batch_size % (num_gpus * batch_size_per_gpu) == 0
|
| 219 |
+
), 'Batch size must be divisible by num_gpus * batch_size_per_gpu'
|
| 220 |
+
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
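As a quick numeric check (assuming, for illustration, 2 visible GPUs, which the script does not pin down): with batch_size = 16 and batch_size_per_gpu = 1 this yields 16 // (2 * 1) = 8 accumulation steps per optimizer update.

batch_size, batch_size_per_gpu, num_gpus = 16, 1, 2   # num_gpus = 2 is an assumed example value
assert batch_size % (num_gpus * batch_size_per_gpu) == 0
print(batch_size // (num_gpus * batch_size_per_gpu))  # -> 8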
|
| 221 |
+
|
| 222 |
+
# Hard-coded DeepSpeed config (reference only; not passed to the Trainer since deepspeed=None below)
|
| 223 |
+
dp_config = {
|
| 224 |
+
"fp16": {
|
| 225 |
+
"enabled": "auto",
|
| 226 |
+
"loss_scale": 0,
|
| 227 |
+
"loss_scale_window": 1000,
|
| 228 |
+
"initial_scale_power": 16,
|
| 229 |
+
"hysteresis": 2,
|
| 230 |
+
"min_loss_scale": 1
|
| 231 |
+
},
|
| 232 |
+
"zero_optimization": {
|
| 233 |
+
"stage": 2,
|
| 234 |
+
"allgather_partitions": True,
|
| 235 |
+
"allgather_bucket_size": 5e8,
|
| 236 |
+
"overlap_comm": False,
|
| 237 |
+
"reduce_scatter": True,
|
| 238 |
+
"reduce_bucket_size": 5e8,
|
| 239 |
+
"contiguous_gradients": True,
|
| 240 |
+
"cpu_offload": True
|
| 241 |
+
},
|
| 242 |
+
|
| 243 |
+
"train_batch_size": "auto",
|
| 244 |
+
"gradient_accumulation_steps": "auto",
|
| 245 |
+
"optimizer": {
|
| 246 |
+
"type": "AdamW",
|
| 247 |
+
"params": {
|
| 248 |
+
"lr": "auto",
|
| 249 |
+
"betas": 'auto',
|
| 250 |
+
"eps": 'auto',
|
| 251 |
+
"weight_decay": "auto"
|
| 252 |
+
}
|
| 253 |
+
},
|
| 254 |
+
"scheduler": {
|
| 255 |
+
"type": "WarmupDecayLR",
|
| 256 |
+
"params": {
|
| 257 |
+
"warmup_min_lr": "auto",
|
| 258 |
+
"warmup_max_lr": "auto",
|
| 259 |
+
"warmup_num_steps": "auto",
|
| 260 |
+
"total_num_steps": "auto"
|
| 261 |
+
}
|
| 262 |
+
},
|
| 263 |
+
"gradient_clipping": 1.0,
|
| 264 |
+
"zero_optimization": {
|
| 265 |
+
"stage": 0
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
training_args = TrainingArguments(
|
| 269 |
+
num_train_epochs=num_train_epochs,
|
| 270 |
+
per_device_train_batch_size=batch_size_per_gpu,
|
| 271 |
+
gradient_checkpointing=True,
|
| 272 |
+
gradient_checkpointing_kwargs={'use_reentrant': False},
|
| 273 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 274 |
+
optim='adamw_torch',
|
| 275 |
+
adam_beta1=0.9,
|
| 276 |
+
adam_beta2=0.95,
|
| 277 |
+
adam_epsilon=1e-7,
|
| 278 |
+
learning_rate=learning_rate,
|
| 279 |
+
weight_decay=wd,
|
| 280 |
+
max_grad_norm=1.0,
|
| 281 |
+
lr_scheduler_type='cosine',
|
| 282 |
+
warmup_steps=50,
|
| 283 |
+
logging_steps=10,
|
| 284 |
+
output_dir=output_dir,
|
| 285 |
+
save_total_limit=10,
|
| 286 |
+
save_only_model=True,
|
| 287 |
+
bf16=True,
|
| 288 |
+
fp16=False,
|
| 289 |
+
remove_unused_columns=False,
|
| 290 |
+
report_to='none',
|
| 291 |
+
deepspeed=None,
|
| 292 |
+
disable_tqdm=False,
|
| 293 |
+
dataloader_num_workers=16,
|
| 294 |
+
save_strategy='epoch',
|
| 295 |
+
# save_steps=2500,
|
| 296 |
+
ddp_find_unused_parameters=True,
|
| 297 |
+
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
out_path = Path(training_args.output_dir)
|
| 301 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 302 |
+
|
| 303 |
+
# create optimizer only for trainable params
|
| 304 |
+
optimizer = torch.optim.AdamW(
|
| 305 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 306 |
+
lr=learning_rate,
|
| 307 |
+
weight_decay=wd,
|
| 308 |
+
betas=(0.9, 0.95),
|
| 309 |
+
eps=1e-7,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Trainer Setting
|
| 313 |
+
trainer = Trainer(
|
| 314 |
+
model=model,
|
| 315 |
+
args=training_args,
|
| 316 |
+
data_collator=covost_collate_fn,
|
| 317 |
+
train_dataset=train_dataset,
|
| 318 |
+
optimizers=(optimizer, None)
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
trainer.train()
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# # 1. Save LoRA Adapter
|
| 325 |
+
model.language_model.model.save_pretrained(output_dir)
|
| 326 |
+
|
| 327 |
+
# # 1-1. Delete Markdown file
|
| 328 |
+
# markdown_file = os.path.join(output_dir, "README.md")
|
| 329 |
+
# if os.path.exists(markdown_file):
|
| 330 |
+
# os.remove(markdown_file)
|
| 331 |
+
|
| 332 |
+
# 2. Save entire model
|
| 333 |
+
model.save_pretrained(output_dir)
|
cpp/inference/audio_encoder_lib.cpp
ADDED
|
@@ -0,0 +1,388 @@
|
| 1 |
+
#include "audio_encoder_lib.h"
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include <fstream>
|
| 5 |
+
#include <cmath>
|
| 6 |
+
#include <numeric>
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cstring> // For memcpy
|
| 9 |
+
|
| 10 |
+
// Include specific ONNX Runtime headers for implementation
|
| 11 |
+
#include <onnxruntime_cxx_api.h>
|
| 12 |
+
|
| 13 |
+
// Include specific Eigen headers for implementation
|
| 14 |
+
#include <Eigen/Dense>
|
| 15 |
+
|
| 16 |
+
// Include specific KissFFT headers for implementation
|
| 17 |
+
#include <kiss_fft.h>
|
| 18 |
+
#include <kiss_fftr.h>
|
| 19 |
+
|
| 20 |
+
// Define M_PI if it's not already defined
|
| 21 |
+
#ifndef M_PI
|
| 22 |
+
#define M_PI 3.14159265358979323846
|
| 23 |
+
#endif
|
| 24 |
+
|
| 25 |
+
// --- Global parameters for feature extraction (matching Python script) ---
|
| 26 |
+
// These are constants derived from the Python preprocessing script and are
|
| 27 |
+
// internal to the feature extraction logic.
|
| 28 |
+
namespace { // Anonymous namespace for internal linkage
|
| 29 |
+
const float PREEMPHASIS_COEFF = 0.97f;
|
| 30 |
+
const int N_FFT = 512; // FFT size
|
| 31 |
+
const int WIN_LENGTH = 400; // Window length (samples)
|
| 32 |
+
const int HOP_LENGTH = 160; // Hop length (samples)
|
| 33 |
+
const int N_MELS = 80; // Number of Mel filterbank channels
|
| 34 |
+
const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// --- Implementation of AudioInferenceEngine methods ---
|
| 38 |
+
|
| 39 |
+
AudioInferenceEngine::AudioInferenceEngine(const std::string& modelPath) {
|
| 40 |
+
// 1. Initialize ONNX Runtime Environment
|
| 41 |
+
env_ = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "AudioInferenceEngine");
|
| 42 |
+
|
| 43 |
+
// 2. Configure Session Options
|
| 44 |
+
Ort::SessionOptions session_options;
|
| 45 |
+
session_options.SetIntraOpNumThreads(0); // 0 lets ONNX Runtime pick the thread count
|
| 46 |
+
session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
|
| 47 |
+
|
| 48 |
+
// 3. Create ONNX Runtime Session
|
| 49 |
+
session_ = std::make_unique<Ort::Session>(*env_, modelPath.c_str(), session_options);
|
| 50 |
+
|
| 51 |
+
// 4. Initialize Allocator
|
| 52 |
+
allocator_ = std::make_unique<Ort::AllocatorWithDefaultOptions>();
|
| 53 |
+
|
| 54 |
+
// 5. Get Input and Output Node Names
|
| 55 |
+
// It's crucial to allocate these names using the allocator and store them
|
| 56 |
+
// as C-style strings for Ort::Session::Run.
|
| 57 |
+
size_t numInputNodes = session_->GetInputCount();
|
| 58 |
+
if (numInputNodes == 0) {
|
| 59 |
+
throw Ort::Exception("ONNX model has no input nodes.", ORT_FAIL);
|
| 60 |
+
}
|
| 61 |
+
input_node_names_.resize(numInputNodes);
|
| 62 |
+
for (size_t i = 0; i < numInputNodes; ++i) {
|
| 63 |
+
input_node_names_[i] = session_->GetInputNameAllocated(i, *allocator_).release(); // release() to manage lifetime
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
size_t numOutputNodes = session_->GetOutputCount();
|
| 67 |
+
if (numOutputNodes == 0) {
|
| 68 |
+
throw Ort::Exception("ONNX model has no output nodes.", ORT_FAIL);
|
| 69 |
+
}
|
| 70 |
+
output_node_names_.resize(numOutputNodes);
|
| 71 |
+
for (size_t i = 0; i < numOutputNodes; ++i) {
|
| 72 |
+
output_node_names_[i] = session_->GetOutputNameAllocated(i, *allocator_).release(); // release() to manage lifetime
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// 6. Precompute Mel filterbank
|
| 76 |
+
// The Python example uses fmax = 16000//2 - 80 - 230 = 7690 Hz.
|
| 77 |
+
float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
|
| 78 |
+
mel_filterbank_ = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);
|
| 79 |
+
|
| 80 |
+
if (mel_filterbank_.rows() == 0 || mel_filterbank_.cols() == 0) {
|
| 81 |
+
throw std::runtime_error("Failed to create Mel filterbank during initialization.");
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
std::cout << "AudioInferenceEngine initialized successfully with model: " << modelPath << std::endl;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
AudioInferenceEngine::~AudioInferenceEngine() {
|
| 88 |
+
// Release allocated names
|
| 89 |
+
for (const char* name : input_node_names_) {
|
| 90 |
+
allocator_->Free(const_cast<void*>(reinterpret_cast<const void*>(name)));
|
| 91 |
+
}
|
| 92 |
+
for (const char* name : output_node_names_) {
|
| 93 |
+
allocator_->Free(const_cast<void*>(reinterpret_cast<const void*>(name)));
|
| 94 |
+
}
|
| 95 |
+
// unique_ptr automatically handles deletion of env_ and session_
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
/**
|
| 99 |
+
* @brief Private helper: Loads audio data from a WAV file.
|
| 100 |
+
*/
|
| 101 |
+
std::vector<float> AudioInferenceEngine::loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) {
|
| 102 |
+
std::ifstream file(filename, std::ios::binary);
|
| 103 |
+
if (!file.is_open()) {
|
| 104 |
+
std::cerr << "Error: Could not open WAV file: " << filename << std::endl;
|
| 105 |
+
return {};
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
WavHeader header;
|
| 109 |
+
file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader));
|
| 110 |
+
|
| 111 |
+
if (std::string(header.riff_id, 4) != "RIFF" ||
|
| 112 |
+
std::string(header.wave_id, 4) != "WAVE" ||
|
| 113 |
+
std::string(header.fmt_id, 4) != "fmt ") {
|
| 114 |
+
std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl;
|
| 115 |
+
file.close();
|
| 116 |
+
return {};
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if (header.audio_format != 1) { // 1 = PCM
|
| 120 |
+
std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl;
|
| 121 |
+
file.close();
|
| 122 |
+
return {};
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
if (header.bits_per_sample != 16) {
|
| 126 |
+
std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl;
|
| 127 |
+
file.close();
|
| 128 |
+
return {};
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
actual_sample_rate = header.sample_rate;
|
| 132 |
+
|
| 133 |
+
WavDataChunk data_chunk;
|
| 134 |
+
bool data_chunk_found = false;
|
| 135 |
+
while (!file.eof()) {
|
| 136 |
+
file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4);
|
| 137 |
+
file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4);
|
| 138 |
+
|
| 139 |
+
if (std::string(data_chunk.data_id, 4) == "data") {
|
| 140 |
+
data_chunk_found = true;
|
| 141 |
+
break;
|
| 142 |
+
} else {
|
| 143 |
+
file.seekg(data_chunk.data_size, std::ios::cur);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
if (!data_chunk_found) {
|
| 148 |
+
std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl;
|
| 149 |
+
file.close();
|
| 150 |
+
return {};
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
std::vector<float> audioData;
|
| 154 |
+
int16_t sample_buffer;
|
| 155 |
+
long num_samples_to_read = data_chunk.data_size / sizeof(int16_t);
|
| 156 |
+
|
| 157 |
+
for (long i = 0; i < num_samples_to_read; ++i) {
|
| 158 |
+
file.read(reinterpret_cast<char*>(&sample_buffer), sizeof(int16_t));
|
| 159 |
+
float normalized_sample = static_cast<float>(sample_buffer) / 32768.0f;
|
| 160 |
+
|
| 161 |
+
if (header.num_channels == 1) {
|
| 162 |
+
audioData.push_back(normalized_sample);
|
| 163 |
+
} else if (header.num_channels == 2) {
|
| 164 |
+
int16_t right_sample;
|
| 165 |
+
if (file.read(reinterpret_cast<char*>(&right_sample), sizeof(int16_t))) {
|
| 166 |
+
float normalized_right_sample = static_cast<float>(right_sample) / 32768.0f;
|
| 167 |
+
audioData.push_back((normalized_sample + normalized_right_sample) / 2.0f);
|
| 168 |
+
i++;
|
| 169 |
+
} else {
|
| 170 |
+
std::cerr << "Warning: Unexpected end of file while reading stereo data." << std::endl;
|
| 171 |
+
break;
|
| 172 |
+
}
|
| 173 |
+
} else {
|
| 174 |
+
std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl;
|
| 175 |
+
file.close();
|
| 176 |
+
return {};
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
file.close();
|
| 181 |
+
return audioData;
|
| 182 |
+
}
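For cross-checking loadWavToFloatArray, the same normalisation and stereo-to-mono averaging can be reproduced with Python's standard wave module (a sketch; it assumes 16-bit PCM, as the C++ does):

import wave
import numpy as np

def load_wav(path: str):
    with wave.open(path, "rb") as w:
        sr = w.getframerate()
        pcm = np.frombuffer(w.readframes(w.getnframes()), dtype=np.int16)
        if w.getnchannels() == 2:
            pcm = pcm.reshape(-1, 2).mean(axis=1)   # average stereo to mono, as in the C++
    return pcm.astype(np.float32) / 32768.0, sr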
|
| 183 |
+
|
| 184 |
+
/**
|
| 185 |
+
* @brief Private helper: Generates a Hamming window.
|
| 186 |
+
*/
|
| 187 |
+
std::vector<float> AudioInferenceEngine::generateHammingWindow(int window_length) {
|
| 188 |
+
std::vector<float> window(window_length);
|
| 189 |
+
for (int i = 0; i < window_length; ++i) {
|
| 190 |
+
window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
|
| 191 |
+
}
|
| 192 |
+
return window;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
/**
|
| 196 |
+
* @brief Private helper: Extracts spectrogram features.
|
| 197 |
+
*/
|
| 198 |
+
Eigen::MatrixXf AudioInferenceEngine::extractSpectrogram(const std::vector<float>& wav, int fs) {
|
| 199 |
+
int n_batch = (wav.size() - WIN_LENGTH) / HOP_LENGTH + 1;
|
| 200 |
+
if (n_batch <= 0) {
|
| 201 |
+
return Eigen::MatrixXf(0, N_FFT / 2 + 1);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
|
| 205 |
+
|
| 206 |
+
kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
|
| 207 |
+
if (!fft_cfg) {
|
| 208 |
+
std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
|
| 209 |
+
return Eigen::MatrixXf(0, N_FFT / 2 + 1);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);
|
| 213 |
+
|
| 214 |
+
std::vector<float> frame_buffer(WIN_LENGTH);
|
| 215 |
+
kiss_fft_scalar fft_input[N_FFT];
|
| 216 |
+
kiss_fft_cpx fft_output[N_FFT / 2 + 1];
|
| 217 |
+
|
| 218 |
+
for (int i = 0; i < n_batch; ++i) {
|
| 219 |
+
int start_idx = i * HOP_LENGTH;
|
| 220 |
+
|
| 221 |
+
for (int j = 0; j < WIN_LENGTH; ++j) {
|
| 222 |
+
frame_buffer[j] = wav[start_idx + j];
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
// Apply pre-emphasis and scale by 32768
|
| 226 |
+
if (WIN_LENGTH > 0) {
|
| 227 |
+
if (WIN_LENGTH > 1) {
|
| 228 |
+
// Corrected pre-emphasis to match Python's np.roll and then overwrite first element
|
| 229 |
+
// The first element of the frame is pre-emphasized against the second element.
|
| 230 |
+
fft_input[0] = (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]) * 32768.0f;
|
| 231 |
+
for (int j = 1; j < WIN_LENGTH; ++j) {
|
| 232 |
+
fft_input[j] = (frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]) * 32768.0f;
|
| 233 |
+
}
|
| 234 |
+
} else { // WIN_LENGTH == 1
|
| 235 |
+
fft_input[0] = frame_buffer[0] * 32768.0f;
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
for (int j = WIN_LENGTH; j < N_FFT; ++j) {
|
| 239 |
+
fft_input[j] = 0.0f;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
for (int j = 0; j < WIN_LENGTH; ++j) {
|
| 243 |
+
fft_input[j] *= fft_window[j];
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
kiss_fftr(fft_cfg, fft_input, fft_output);
|
| 247 |
+
|
| 248 |
+
for (int j = 0; j <= N_FFT / 2; ++j) {
|
| 249 |
+
spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
kiss_fftr_free(fft_cfg);
|
| 254 |
+
return spec_matrix;
|
| 255 |
+
}
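A NumPy reference for this framing / pre-emphasis / windowing / magnitude-FFT pipeline is useful when validating the C++ output (a sketch under the same constants; the pre-emphasis of the first sample against the second, the 32768 scaling and the Hamming coefficients mirror the code above; `wav` is the normalised float waveform):

import numpy as np

N_FFT, WIN_LENGTH, HOP_LENGTH = 512, 400, 160
PREEMPH = 0.97

def extract_spectrogram(wav: np.ndarray) -> np.ndarray:
    n_frames = (len(wav) - WIN_LENGTH) // HOP_LENGTH + 1
    window = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(WIN_LENGTH) / (WIN_LENGTH - 1))
    spec = np.empty((n_frames, N_FFT // 2 + 1), dtype=np.float32)
    for i in range(n_frames):
        frame = wav[i * HOP_LENGTH : i * HOP_LENGTH + WIN_LENGTH].astype(np.float64)
        prev = np.empty_like(frame)
        prev[0], prev[1:] = frame[1], frame[:-1]   # first sample differenced against the second
        frame = (frame - PREEMPH * prev) * 32768.0
        frame *= window
        spec[i] = np.abs(np.fft.rfft(frame, n=N_FFT))   # zero-padded to N_FFT, magnitude
    return spec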
|
| 256 |
+
|
| 257 |
+
/**
|
| 258 |
+
* @brief Private helper: Creates a Mel filter-bank matrix.
|
| 259 |
+
*/
|
| 260 |
+
Eigen::MatrixXf AudioInferenceEngine::speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
|
| 261 |
+
int bank_width = n_fft / 2 + 1;
|
| 262 |
+
if (fmax == 0.0f) fmax = sample_rate / 2.0f;
|
| 263 |
+
if (fmin == 0.0f) fmin = 0.0f; // no-op; kept only to mirror the reference Python signature
|
| 264 |
+
|
| 265 |
+
auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
|
| 266 |
+
auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
|
| 267 |
+
auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };
|
| 268 |
+
|
| 269 |
+
int klo = f2bin(fmin) + 1;
|
| 270 |
+
int khi = f2bin(fmax);
|
| 271 |
+
khi = std::max(khi, klo);
|
| 272 |
+
|
| 273 |
+
float mlo = mel(fmin);
|
| 274 |
+
float mhi = mel(fmax);
|
| 275 |
+
|
| 276 |
+
std::vector<float> m_centers(n_mels + 2);
|
| 277 |
+
float ms = (mhi - mlo) / (n_mels + 1);
|
| 278 |
+
for (int i = 0; i < n_mels + 2; ++i) {
|
| 279 |
+
m_centers[i] = mlo + i * ms;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);
|
| 283 |
+
|
| 284 |
+
for (int m = 0; m < n_mels; ++m) {
|
| 285 |
+
float left = m_centers[m];
|
| 286 |
+
float center = m_centers[m + 1];
|
| 287 |
+
float right = m_centers[m + 2];
|
| 288 |
+
for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) {
|
| 289 |
+
float mbin = bin2mel(fft_bin);
|
| 290 |
+
if (left < mbin && mbin < right) {
|
| 291 |
+
matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
return matrix;
|
| 296 |
+
}
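The same triangular mel filter bank can be rebuilt in NumPy for unit-testing (a sketch that transliterates the construction above; it is not the original speechlib source, only an equivalent of this C++):

import numpy as np

def speechlib_mel(sample_rate, n_fft, n_mels, fmin=0.0, fmax=None):
    if fmax is None:
        fmax = sample_rate / 2.0
    bank_width = n_fft // 2 + 1
    mel = lambda f: 1127.0 * np.log(1.0 + f / 700.0)
    bin2mel = lambda k: 1127.0 * np.log(1.0 + k * sample_rate / (n_fft * 700.0))
    f2bin = lambda f: int(f * n_fft / sample_rate + 0.5)

    klo = f2bin(fmin) + 1
    mlo, mhi = mel(fmin), mel(fmax)
    ms = (mhi - mlo) / (n_mels + 1)
    centers = mlo + ms * np.arange(n_mels + 2)

    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    for m in range(n_mels):
        left, center, right = centers[m], centers[m + 1], centers[m + 2]
        for k in range(klo, bank_width):
            mbin = bin2mel(k)
            if left < mbin < right:
                matrix[m, k] = 1.0 - abs(center - mbin) / ms
    return matrix

fb = speechlib_mel(16000, 512, 80, 0.0, 16000 // 2 - 80 - 230)
print(fb.shape)   # (80, 257)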
|
| 297 |
+
|
| 298 |
+
/**
|
| 299 |
+
* @brief Public method: Preprocesses an audio WAV file.
|
| 300 |
+
*/
|
| 301 |
+
Eigen::MatrixXf AudioInferenceEngine::preprocessAudio(const std::string& wavFilePath) {
|
| 302 |
+
int actual_wav_sample_rate = 0;
|
| 303 |
+
std::vector<float> audioWav = loadWavToFloatArray(wavFilePath, actual_wav_sample_rate);
|
| 304 |
+
|
| 305 |
+
if (audioWav.empty()) {
|
| 306 |
+
std::cerr << "Failed to load audio data from " << wavFilePath << "." << std::endl;
|
| 307 |
+
return Eigen::MatrixXf(0, N_MELS);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
if (actual_wav_sample_rate != TARGET_SAMPLE_RATE) {
|
| 311 |
+
std::cerr << "Warning: WAV file sample rate (" << actual_wav_sample_rate
|
| 312 |
+
<< " Hz) does not match the target sample rate for feature extraction ("
|
| 313 |
+
<< TARGET_SAMPLE_RATE << " Hz)." << std::endl;
|
| 314 |
+
std::cerr << "This example does NOT include resampling. Features will be extracted at "
|
| 315 |
+
<< TARGET_SAMPLE_RATE << " Hz, which might lead to incorrect results if the WAV file's sample rate is different." << std::endl;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
Eigen::MatrixXf spec = extractSpectrogram(audioWav, TARGET_SAMPLE_RATE);
|
| 319 |
+
if (spec.rows() == 0) {
|
| 320 |
+
std::cerr << "Error: Spectrogram extraction failed." << std::endl;
|
| 321 |
+
return Eigen::MatrixXf(0, N_MELS);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
Eigen::MatrixXf spec_power = spec.array().square();
|
| 325 |
+
Eigen::MatrixXf fbank_power = spec_power * mel_filterbank_.transpose(); // Transpose mel_filterbank_ for correct multiplication
|
| 326 |
+
|
| 327 |
+
fbank_power = fbank_power.array().max(1.0f);
|
| 328 |
+
Eigen::MatrixXf log_fbank = fbank_power.array().log();
|
| 329 |
+
|
| 330 |
+
return log_fbank;
|
| 331 |
+
}
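The remaining post-spectrogram steps (power, mel projection, clamping at 1.0, natural log) are compact enough to state as a reference snippet (a sketch; `spec` is assumed to be the (frames, 257) magnitude spectrogram and `mel_fb` the (80, 257) filter bank from above):

import numpy as np

def log_mel(spec: np.ndarray, mel_fb: np.ndarray) -> np.ndarray:
    # Mirrors preprocessAudio: |S|^2 -> mel filter bank -> clamp at 1 -> ln
    fbank_power = spec ** 2 @ mel_fb.T          # (frames, 80)
    return np.log(np.maximum(fbank_power, 1.0))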
|
| 332 |
+
|
| 333 |
+
/**
|
| 334 |
+
* @brief Public method: Runs inference on the loaded ONNX model.
|
| 335 |
+
*/
|
| 336 |
+
std::vector<float> AudioInferenceEngine::runInference(const Eigen::MatrixXf& features) {
|
| 337 |
+
if (features.rows() == 0 || features.cols() == 0) {
|
| 338 |
+
std::cerr << "Error: Input features are empty for inference." << std::endl;
|
| 339 |
+
return {};
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// Prepare Input Tensor Shape: [batch, frames, feature_size]
|
| 343 |
+
std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
|
| 344 |
+
|
| 345 |
+
// Flatten Eigen::MatrixXf into std::vector<float> in row-major order
|
| 346 |
+
std::vector<float> inputTensorData(features.rows() * features.cols());
|
| 347 |
+
for (int r = 0; r < features.rows(); ++r) {
|
| 348 |
+
for (int c = 0; c < features.cols(); ++c) {
|
| 349 |
+
inputTensorData[r * features.cols() + c] = features(r, c);
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
| 354 |
+
Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
|
| 355 |
+
inputTensorShape.data(), inputTensorShape.size());
|
| 356 |
+
|
| 357 |
+
if (!inputTensor.IsTensor()) {
|
| 358 |
+
std::cerr << "Error: Created input tensor is not valid!" << std::endl;
|
| 359 |
+
return {};
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
// Run Inference
|
| 363 |
+
std::vector<Ort::Value> outputTensors = session_->Run(Ort::RunOptions{nullptr},
|
| 364 |
+
input_node_names_.data(), &inputTensor, 1,
|
| 365 |
+
output_node_names_.data(), output_node_names_.size());
|
| 366 |
+
|
| 367 |
+
if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
|
| 368 |
+
std::cerr << "Error: No valid output tensors received from the model." << std::endl;
|
| 369 |
+
return {};
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
// Copy output data
|
| 373 |
+
float* outputData = outputTensors[0].GetTensorMutableData<float>();
|
| 374 |
+
Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
|
| 375 |
+
size_t outputSize = outputShapeInfo.GetElementCount();
|
| 376 |
+
|
| 377 |
+
std::vector<float> result(outputData, outputData + outputSize);
|
| 378 |
+
return result;
|
| 379 |
+
}
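The exported encoder can also be driven from Python with onnxruntime to verify this C++ path end to end (a sketch; the (1, frames, 80) layout follows the tensor shape built above, and the model filename is the one referenced in run.sh):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("phi-4-mm-speech.onnx")        # path as used in run.sh
features = np.random.rand(1, 200, 80).astype(np.float32)   # (batch, frames, n_mels)

input_name = sess.get_inputs()[0].name
outputs = sess.run(None, {input_name: features})
print(outputs[0].shape)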
|
| 380 |
+
|
| 381 |
+
std::vector<Ort::Value> AudioInferenceEngine::runInference_tensor(const Ort::Value& inputTensor) {
|
| 382 |
+
// Run Inference
|
| 383 |
+
std::vector<Ort::Value> outputTensors = session_->Run(Ort::RunOptions{nullptr},
|
| 384 |
+
input_node_names_.data(), &inputTensor, 1,
|
| 385 |
+
output_node_names_.data(), output_node_names_.size());
|
| 386 |
+
|
| 387 |
+
return outputTensors;
|
| 388 |
+
}
|
cpp/inference/audio_encoder_lib.h
ADDED
|
@@ -0,0 +1,141 @@
|
| 1 |
+
#ifndef AUDIO_INFERENCE_LIBRARY_H
|
| 2 |
+
#define AUDIO_INFERENCE_LIBRARY_H
|
| 3 |
+
|
| 4 |
+
#include <string>
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <cstdint> // For uint32_t, int16_t
|
| 7 |
+
#include <memory> // For std::unique_ptr
|
| 8 |
+
#include <Eigen/Dense>
|
| 9 |
+
using namespace Eigen;
|
| 10 |
+
// Forward declarations for ONNX Runtime types to avoid including full headers in .h
|
| 11 |
+
namespace Ort {
|
| 12 |
+
struct Env;
|
| 13 |
+
struct Session;
|
| 14 |
+
struct MemoryInfo;
|
| 15 |
+
struct AllocatorWithDefaultOptions;
|
| 16 |
+
struct Value;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
// Forward declaration for Eigen Matrix
|
| 20 |
+
namespace Eigen {
|
| 21 |
+
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
|
| 22 |
+
class Matrix;
|
| 23 |
+
typedef Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor, Eigen::Dynamic, Eigen::Dynamic> MatrixXf;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
/**
|
| 27 |
+
* @brief Class to handle audio preprocessing and ONNX model inference.
|
| 28 |
+
*
|
| 29 |
+
* This class encapsulates the logic for loading WAV files, extracting Mel filterbank
|
| 30 |
+
* features, and running inference on an ONNX model.
|
| 31 |
+
*/
|
| 32 |
+
class AudioInferenceEngine {
|
| 33 |
+
public:
|
| 34 |
+
/**
|
| 35 |
+
* @brief Constructor for AudioInferenceEngine.
|
| 36 |
+
* @param modelPath The file path to the ONNX model.
|
| 37 |
+
* @throws Ort::Exception if ONNX Runtime initialization fails.
|
| 38 |
+
*/
|
| 39 |
+
AudioInferenceEngine(const std::string& modelPath);
|
| 40 |
+
|
| 41 |
+
/**
|
| 42 |
+
* @brief Destructor to clean up ONNX Runtime resources.
|
| 43 |
+
*/
|
| 44 |
+
~AudioInferenceEngine();
|
| 45 |
+
|
| 46 |
+
/**
|
| 47 |
+
* @brief Preprocesses an audio WAV file to extract Mel filterbank features.
|
| 48 |
+
*
|
| 49 |
+
* This function loads the WAV file, converts it to a float array, and then
|
| 50 |
+
* applies the spectrogram and Mel filterbank extraction steps.
|
| 51 |
+
*
|
| 52 |
+
* @param wavFilePath The path to the WAV audio file.
|
| 53 |
+
* @return An Eigen::MatrixXf containing the extracted features (frames x N_MELS).
|
| 54 |
+
* Returns an empty matrix if preprocessing fails.
|
| 55 |
+
*/
|
| 56 |
+
Eigen::MatrixXf preprocessAudio(const std::string& wavFilePath);
|
| 57 |
+
|
| 58 |
+
/**
|
| 59 |
+
* @brief Runs inference on the loaded ONNX model using the provided features.
|
| 60 |
+
*
|
| 61 |
+
* The input features should be the output of `preprocessAudio`. This function
|
| 62 |
+
* converts the features to an ONNX Runtime tensor and executes the model.
|
| 63 |
+
*
|
| 64 |
+
* @param features An Eigen::MatrixXf containing the preprocessed audio features.
|
| 65 |
+
* Expected shape: (frames, N_MELS).
|
| 66 |
+
* @return A std::vector<float> containing the flattened output of the ONNX model.
|
| 67 |
+
* Returns an empty vector if inference fails.
|
| 68 |
+
*/
|
| 69 |
+
std::vector<float> runInference(const Eigen::MatrixXf& features);
|
| 70 |
+
std::vector<Ort::Value> runInference_tensor(const Ort::Value& inputTensor);
|
| 71 |
+
|
| 72 |
+
private:
|
| 73 |
+
// ONNX Runtime members
|
| 74 |
+
std::unique_ptr<Ort::Env> env_;
|
| 75 |
+
std::unique_ptr<Ort::Session> session_;
|
| 76 |
+
std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
|
| 77 |
+
std::vector<const char*> input_node_names_;
|
| 78 |
+
std::vector<const char*> output_node_names_;
|
| 79 |
+
|
| 80 |
+
// Precomputed Mel filterbank matrix
|
| 81 |
+
Eigen::MatrixXf mel_filterbank_;
|
| 82 |
+
|
| 83 |
+
// Private helper functions (implemented in .cpp)
|
| 84 |
+
// WAV file parsing structures
|
| 85 |
+
#pragma pack(push, 1)
|
| 86 |
+
struct WavHeader {
|
| 87 |
+
char riff_id[4];
|
| 88 |
+
uint32_t file_size;
|
| 89 |
+
char wave_id[4];
|
| 90 |
+
char fmt_id[4];
|
| 91 |
+
uint32_t fmt_size;
|
| 92 |
+
uint16_t audio_format;
|
| 93 |
+
uint16_t num_channels;
|
| 94 |
+
uint32_t sample_rate;
|
| 95 |
+
uint32_t byte_rate;
|
| 96 |
+
uint16_t block_align;
|
| 97 |
+
uint16_t bits_per_sample;
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
struct WavDataChunk {
|
| 101 |
+
char data_id[4];
|
| 102 |
+
uint32_t data_size;
|
| 103 |
+
};
|
| 104 |
+
#pragma pack(pop)
|
| 105 |
+
|
| 106 |
+
/**
|
| 107 |
+
* @brief Loads audio data from a WAV file into a float vector.
|
| 108 |
+
* @param filename The path to the WAV audio file.
|
| 109 |
+
* @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
|
| 110 |
+
* @return A std::vector<float> containing the normalized mono audio samples.
|
| 111 |
+
*/
|
| 112 |
+
std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate);
|
| 113 |
+
|
| 114 |
+
/**
|
| 115 |
+
* @brief Generates a Hamming window.
|
| 116 |
+
* @param window_length The length of the window.
|
| 117 |
+
* @return A std::vector<float> containing the Hamming window coefficients.
|
| 118 |
+
*/
|
| 119 |
+
std::vector<float> generateHammingWindow(int window_length);
|
| 120 |
+
|
| 121 |
+
/**
|
| 122 |
+
* @brief Extracts spectrogram features from waveform.
|
| 123 |
+
* @param wav The input waveform.
|
| 124 |
+
* @param fs The sampling rate.
|
| 125 |
+
* @return A 2D Eigen::MatrixXf representing the spectrogram.
|
| 126 |
+
*/
|
| 127 |
+
Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs);
|
| 128 |
+
|
| 129 |
+
/**
|
| 130 |
+
* @brief Creates a Mel filter-bank matrix.
|
| 131 |
+
* @param sample_rate Sample rate in Hz.
|
| 132 |
+
* @param n_fft FFT size.
|
| 133 |
+
* @param n_mels Mel filter size.
|
| 134 |
+
* @param fmin Lowest frequency (in Hz).
|
| 135 |
+
* @param fmax Highest frequency (in Hz).
|
| 136 |
+
* @return An Eigen::MatrixXf representing the Mel transform matrix.
|
| 137 |
+
*/
|
| 138 |
+
Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax);
|
| 139 |
+
};
|
| 140 |
+
|
| 141 |
+
#endif // AUDIO_INFERENCE_LIBRARY_H
|
cpp/inference/audio_encoder_lib.o
ADDED
|
Binary file (85.1 kB).
|
|
|
cpp/inference/audio_inference
ADDED
|
Binary file (91 kB).
|
|
|
cpp/inference/audio_inference_app
ADDED
|
Binary file (97.7 kB).
|
|
|
cpp/inference/compile.sh
ADDED
|
@@ -0,0 +1,32 @@
|
| 1 |
+
export BASE_DIR="/mnt/data-2t/jeff/codes/llm/cpp"
|
| 2 |
+
# g++ test.cpp $BASE_DIR/kissfft/kiss_fft.c $BASE_DIR/kissfft/kiss_fftr.c \
|
| 3 |
+
# -o audio_inference \
|
| 4 |
+
# -I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
|
| 5 |
+
# -I $BASE_DIR/eigen-3.4.0 \
|
| 6 |
+
# -I $BASE_DIR/kissfft \
|
| 7 |
+
# -L $BASE_DIR/kissfft/lib -lkissfft-int16_t-openmp \
|
| 8 |
+
# -L $BASE_DIR/onnxruntime-linux-x64-1.22.0/lib -lonnxruntime -std=c++17 -O2 -DNDEBUG
|
| 9 |
+
|
| 10 |
+
g++ -c audio_encoder_lib.cpp \
|
| 11 |
+
-o audio_encoder_lib.o \
|
| 12 |
+
-I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
|
| 13 |
+
-I $BASE_DIR/eigen-3.4.0 \
|
| 14 |
+
-I $BASE_DIR/kissfft \
|
| 15 |
+
-std=c++17 -O3 -DNDEBUG -fPIC
|
| 16 |
+
|
| 17 |
+
g++ -c $BASE_DIR/kissfft/kiss_fft.c \
|
| 18 |
+
-o kiss_fft.o \
|
| 19 |
+
-I $BASE_DIR/kissfft \
|
| 20 |
+
-std=c++17 -O3 -DNDEBUG -fPIC
|
| 21 |
+
|
| 22 |
+
g++ -c $BASE_DIR/kissfft/kiss_fftr.c \
|
| 23 |
+
-o kiss_fftr.o \
|
| 24 |
+
-I $BASE_DIR/kissfft \
|
| 25 |
+
-std=c++17 -O3 -DNDEBUG -fPIC
|
| 26 |
+
|
| 27 |
+
g++ main_text.cpp audio_encoder_lib.o kiss_fft.o kiss_fftr.o \
|
| 28 |
+
-o audio_inference_app \
|
| 29 |
+
-I $BASE_DIR/onnxruntime-linux-x64-1.22.0/include \
|
| 30 |
+
-I $BASE_DIR/eigen-3.4.0 \
|
| 31 |
+
-I $BASE_DIR/kissfft \
|
| 32 |
+
-L $BASE_DIR/onnxruntime-linux-x64-1.22.0/lib -lonnxruntime -std=c++17 -O3 -DNDEBUG
|
cpp/inference/dummy.wav
ADDED
|
Binary file (57.5 kB).
|
|
|
cpp/inference/f0.txt
ADDED
|
The diff for this file is too large to render.
|
|
|
cpp/inference/f_inp.txt
ADDED
|
The diff for this file is too large to render.
|
|
|
cpp/inference/kiss_fft.o
ADDED
|
Binary file (14.5 kB).
|
|
|
cpp/inference/kiss_fftr.o
ADDED
|
Binary file (3.9 kB).
|
|
|
cpp/inference/main_text.cpp
ADDED
|
@@ -0,0 +1,165 @@
|
| 1 |
+
#include <iostream>
|
| 2 |
+
#include <vector>
|
| 3 |
+
#include <fstream>
|
| 4 |
+
#include <string>
|
| 5 |
+
#include <cmath> // For std::sin, M_PI
|
| 6 |
+
#include <cstring> // For std::memcpy
|
| 7 |
+
#include <chrono> // For time measurement
|
| 8 |
+
#include <random> // For random number generation
|
| 9 |
+
#include <ctime> // For seeding random number generator
|
| 10 |
+
|
| 11 |
+
// Include the new library header
|
| 12 |
+
#include "audio_encoder_lib.h"
|
| 13 |
+
#include <onnxruntime_cxx_api.h>
|
| 14 |
+
// Define M_PI if it's not already defined
|
| 15 |
+
#ifndef M_PI
|
| 16 |
+
#define M_PI 3.14159265358979323846
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
// --- WAV File Header Structures (for dummy file creation) ---
|
| 20 |
+
#pragma pack(push, 1)
|
| 21 |
+
struct WavHeader {
|
| 22 |
+
char riff_id[4];
|
| 23 |
+
uint32_t file_size;
|
| 24 |
+
char wave_id[4];
|
| 25 |
+
char fmt_id[4];
|
| 26 |
+
uint32_t fmt_size;
|
| 27 |
+
uint16_t audio_format;
|
| 28 |
+
uint16_t num_channels;
|
| 29 |
+
uint32_t sample_rate;
|
| 30 |
+
uint32_t byte_rate;
|
| 31 |
+
uint16_t block_align;
|
| 32 |
+
uint16_t bits_per_sample;
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
struct WavDataChunk {
|
| 36 |
+
char data_id[4];
|
| 37 |
+
uint32_t data_size;
|
| 38 |
+
};
|
| 39 |
+
#pragma pack(pop)
|
| 40 |
+
|
| 41 |
+
// Function to write a dummy WAV file (moved here for example app)
|
| 42 |
+
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
|
| 43 |
+
std::ofstream file(filename, std::ios::binary);
|
| 44 |
+
if (!file.is_open()) {
|
| 45 |
+
std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
|
| 46 |
+
return;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
WavHeader header;
|
| 50 |
+
std::memcpy(header.riff_id, "RIFF", 4);
|
| 51 |
+
std::memcpy(header.wave_id, "WAVE", 4);
|
| 52 |
+
std::memcpy(header.fmt_id, "fmt ", 4);
|
| 53 |
+
header.fmt_size = 16;
|
| 54 |
+
header.audio_format = 1; // PCM
|
| 55 |
+
header.num_channels = numChannels;
|
| 56 |
+
header.sample_rate = sampleRate;
|
| 57 |
+
header.bits_per_sample = bitsPerSample;
|
| 58 |
+
header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
|
| 59 |
+
header.block_align = (numChannels * bitsPerSample) / 8;
|
| 60 |
+
|
| 61 |
+
WavDataChunk data_chunk;
|
| 62 |
+
std::memcpy(data_chunk.data_id, "data", 4);
|
| 63 |
+
uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
|
| 64 |
+
data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
|
| 65 |
+
header.file_size = 36 + data_chunk.data_size; // 36 is size of header before data chunk
|
| 66 |
+
|
| 67 |
+
file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
|
| 68 |
+
file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));
|
| 69 |
+
|
| 70 |
+
// Generate a 440 Hz sine wave
|
| 71 |
+
for (uint32_t i = 0; i < num_samples; ++i) {
|
| 72 |
+
int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
|
| 73 |
+
for (int c = 0; c < numChannels; ++c) {
|
| 74 |
+
file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
file.close();
|
| 79 |
+
// std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl; // Suppress verbose creation message
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
int main(int argc, char* argv[]) {
|
| 83 |
+
// --- 1. Process command-line arguments ---
|
| 84 |
+
if (argc != 3) {
|
| 85 |
+
std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file_for_temp_use>" << std::endl;
|
| 86 |
+
std::cerr << "Example: " << argv[0] << " model.onnx temp_audio.wav" << std::endl;
|
| 87 |
+
return 1;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
std::string onnxModelPath = argv[1];
|
| 91 |
+
std::string wavFilename = argv[2]; // This will be used as a temporary file
|
| 92 |
+
|
| 93 |
+
// --- Random number generation setup for dummy input frames ---
|
| 94 |
+
std::mt19937 rng(static_cast<unsigned int>(std::time(nullptr))); // Seed with current time
|
| 95 |
+
std::uniform_int_distribution<int> dist_frames(100, 300); // Distribution for frames (100 to 300)
|
| 96 |
+
|
| 97 |
+
// Define fixed parameters for feature extraction to calculate required duration
|
| 98 |
+
const int WIN_LENGTH = 400; // Window length (samples) - must match library's constant
|
| 99 |
+
const int HOP_LENGTH = 160; // Hop length (samples) - must match library's constant
|
| 100 |
+
const int TARGET_SAMPLE_RATE = 16000; // Target sample rate - must match library's constant
|
| 101 |
+
|
| 102 |
+
try {
|
| 103 |
+
// --- 2. Model Initialization ---
|
| 104 |
+
// This will load the ONNX model and precompute the Mel filterbank.
|
| 105 |
+
AudioInferenceEngine engine(onnxModelPath);
|
| 106 |
+
std::cout << "Engine initialized." << std::endl;
|
| 107 |
+
|
| 108 |
+
// --- 3. Model Inference and Time Measurement ---
|
| 109 |
+
std::cout << "\nRunning model inference and measuring time (100 runs with varying input sizes)..." << std::endl;
|
| 110 |
+
int num_runs = 100;
|
| 111 |
+
long long total_inference_time_us = 0; // Use microseconds for finer granularity
|
| 112 |
+
|
| 113 |
+
for (int i = 0; i < num_runs; ++i) {
|
| 114 |
+
// Generate a random number of frames for this run
|
| 115 |
+
int random_frames = dist_frames(rng);
|
| 116 |
+
// Calculate the number of samples needed to produce 'random_frames'
|
| 117 |
+
// frames = (num_samples - WIN_LENGTH) / HOP_LENGTH + 1
|
| 118 |
+
// num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
|
| 119 |
+
long long num_samples_for_frames = static_cast<long long>(random_frames - 1) * HOP_LENGTH + WIN_LENGTH;
|
| 120 |
+
double duration_seconds_for_frames = static_cast<double>(num_samples_for_frames) / TARGET_SAMPLE_RATE;
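A quick numeric check of this frames-to-samples inversion, using 200 frames as an example value:

WIN_LENGTH, HOP_LENGTH, SR = 400, 160, 16000
frames = 200                                    # example value in the 100-300 range
num_samples = (frames - 1) * HOP_LENGTH + WIN_LENGTH
print(num_samples, num_samples / SR)            # 32240 samples, ~2.015 s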
|
| 121 |
+
|
| 122 |
+
// Create a new dummy WAV file for this specific run
|
| 123 |
+
// This ensures the input size changes for each test.
|
| 124 |
+
createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, duration_seconds_for_frames);
|
| 125 |
+
|
| 126 |
+
// --- Measure the inference time ---
|
| 127 |
+
auto start_time = std::chrono::high_resolution_clock::now();
|
| 128 |
+
Eigen::MatrixXf features = engine.preprocessAudio(wavFilename);
|
| 129 |
+
std::vector<float> model_output = engine.runInference(features);
|
| 130 |
+
auto end_time = std::chrono::high_resolution_clock::now();
|
| 131 |
+
|
| 132 |
+
if (model_output.empty()) {
|
| 133 |
+
std::cerr << "Error: Model inference failed for run " << i + 1 << ". Exiting." << std::endl;
|
| 134 |
+
return 1;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// Calculate duration for this run in microseconds
|
| 138 |
+
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
|
| 139 |
+
total_inference_time_us += duration.count();
|
| 140 |
+
|
| 141 |
+
// Optionally print output for the first run or specific runs
|
| 142 |
+
if (i == 0) {
|
| 143 |
+
std::cout << "First run (frames=" << features.rows() << ")"<< " take : "<< static_cast<double>(total_inference_time_us) / 1000.0 / 1000.0 <<"s output (first few elements): [";
|
| 144 |
+
for (size_t k = 0; k < std::min((size_t)10, model_output.size()); ++k) {
|
| 145 |
+
std::cout << model_output[k] << (k == std::min((size_t)10, model_output.size()) - 1 ? "" : ", ");
|
| 146 |
+
}
|
| 147 |
+
std::cout << "]" << std::endl;
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
double average_inference_time_s = static_cast<double>(total_inference_time_us) / num_runs / 1000.0 / 1000.0; // Convert microseconds to seconds
|
| 152 |
+
std::cout << "\nAverage ONNX model inference time over " << num_runs << " runs (with varying input frames): "
|
| 153 |
+
<< average_inference_time_s << " s" << std::endl;
|
| 154 |
+
|
| 155 |
+
} catch (const Ort::Exception& e) {
|
| 156 |
+
std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
|
| 157 |
+
return 1;
|
| 158 |
+
} catch (const std::exception& e) {
|
| 159 |
+
std::cerr << "Standard Exception: " << e.what() << std::endl;
|
| 160 |
+
return 1;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
std::cout << "\nProgram finished successfully." << std::endl;
|
| 164 |
+
return 0;
|
| 165 |
+
}
|
cpp/inference/matrix_output.txt
ADDED
|
The diff for this file is too large to render.
|
|
|
cpp/inference/run.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export ONNXRUNTIME_DIR="/mnt/data-2t/jeff/codes/llm/cpp/onnxruntime-linux-x64-1.22.0"
|
| 2 |
+
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
|
| 3 |
+
|
| 4 |
+
export MODEL_PATH="/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx"
|
| 5 |
+
export SAMPLE_DATA="/mnt/data-2t/jeff/codes/llm/cpp/inference/dummy.wav"
|
| 6 |
+
|
| 7 |
+
./audio_inference_app $MODEL_PATH $SAMPLE_DATA
|
cpp/inference/test copy 2.cpp
ADDED
|
@@ -0,0 +1,567 @@
|
| 1 |
+
#include <iostream> // For standard input/output operations (e.g., std::cout, std::cerr)
|
| 2 |
+
#include <vector> // For dynamic arrays (e.g., std::vector<float>)
|
| 3 |
+
#include <fstream> // For file input/output operations (e.g., std::ifstream, std::ofstream)
|
| 4 |
+
#include <cstdint> // For fixed-width integer types (e.g., int16_t)
|
| 5 |
+
#include <cmath> // For mathematical functions (e.g., std::sin, M_PI, std::log)
|
| 6 |
+
#include <numeric> // For numerical operations (e.g., std::iota)
|
| 7 |
+
#include <algorithm> // For algorithms like std::min, std::max
|
| 8 |
+
#include <fstream>
|
| 9 |
+
// Include the ONNX Runtime C++ API header
|
| 10 |
+
#include <onnxruntime_cxx_api.h>
|
| 11 |
+
|
| 12 |
+
// Include Eigen for powerful matrix operations.
|
| 13 |
+
// You need to download Eigen and set up your include paths.
|
| 14 |
+
// E.g., if Eigen is in 'C:/Libraries/eigen-3.4.0', you'd compile with -I C:/Libraries/eigen-3.4.0
|
| 15 |
+
#include <Eigen/Dense>
|
| 16 |
+
|
| 17 |
+
// Include KissFFT for Fast Fourier Transform.
|
| 18 |
+
// You need to download KissFFT and set up your include paths.
|
| 19 |
+
// E.g., if KissFFT is in 'C:/Libraries/kissfft-1.3.0', you'd compile with -I C:/Libraries/kissfft-1.3.0
|
| 20 |
+
// You also need to compile kiss_fft.c and kiss_fftr.c and link them.
|
| 21 |
+
#include "kiss_fft.h"
|
| 22 |
+
#include "kiss_fftr.h" // For real-valued FFT
|
| 23 |
+
|
| 24 |
+
// Define M_PI if it's not already defined by cmath or your compiler.
|
| 25 |
+
#ifndef M_PI
|
| 26 |
+
#define M_PI 3.14159265358979323846
|
| 27 |
+
#endif
|
| 28 |
+
|
| 29 |
+
// --- Global parameters for feature extraction (matching Python script) ---
|
| 30 |
+
const float PREEMPHASIS_COEFF = 0.97f;
|
| 31 |
+
const int N_FFT = 512; // FFT size
|
| 32 |
+
const int WIN_LENGTH = 400; // Window length (samples)
|
| 33 |
+
const int HOP_LENGTH = 160; // Hop length (samples)
|
| 34 |
+
const int N_MELS = 80; // Number of Mel filterbank channels
|
| 35 |
+
const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction
|
| 36 |
+
|
| 37 |
+
/**
|
| 38 |
+
* @brief Loads raw PCM audio data from a file into a float vector.
|
| 39 |
+
*
|
| 40 |
+
* This function reads 16-bit signed integer PCM samples from the specified file,
|
| 41 |
+
* converts them to floating-point values, and normalizes them to the range [-1.0, 1.0].
|
| 42 |
+
* It assumes the PCM data is little-endian.
|
| 43 |
+
*
|
| 44 |
+
* @param filename The path to the PCM audio file.
|
| 45 |
+
* @return A std::vector<float> containing the normalized audio samples, or an empty
|
| 46 |
+
* vector if the file cannot be opened.
|
| 47 |
+
*/
|
| 48 |
+
std::vector<float> loadPcmToFloatArray(const std::string& filename) {
|
| 49 |
+
std::ifstream file(filename, std::ios::binary);
|
| 50 |
+
if (!file.is_open()) {
|
| 51 |
+
std::cerr << "Error: Could not open PCM file: " << filename << std::endl;
|
| 52 |
+
return {};
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
std::vector<float> audioData;
|
| 56 |
+
int16_t sample;
|
| 57 |
+
|
| 58 |
+
while (file.read(reinterpret_cast<char*>(&sample), sizeof(sample))) {
|
| 59 |
+
audioData.push_back(static_cast<float>(sample) / 32768.0f);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
file.close();
|
| 63 |
+
return audioData;
|
| 64 |
+
}
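For comparison, the equivalent raw-PCM load and normalisation in Python is a few lines with NumPy (a sketch; assumes 16-bit little-endian mono samples, matching this function):

import numpy as np

def load_pcm(path: str) -> np.ndarray:
    pcm = np.fromfile(path, dtype="<i2")       # 16-bit little-endian samples
    return pcm.astype(np.float32) / 32768.0    # normalise to [-1, 1)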
|
| 65 |
+
|
| 66 |
+
/**
|
| 67 |
+
* @brief Generates a Hamming window.
|
| 68 |
+
* @param window_length The length of the window.
|
| 69 |
+
* @return A std::vector<float> containing the Hamming window coefficients.
|
| 70 |
+
*/
|
| 71 |
+
std::vector<float> generateHammingWindow(int window_length) {
|
| 72 |
+
std::vector<float> window(window_length);
|
| 73 |
+
for (int i = 0; i < window_length; ++i) {
|
| 74 |
+
window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
|
| 75 |
+
}
|
| 76 |
+
return window;
|
| 77 |
+
}

/**
 * @brief Extracts spectrogram features from waveform, matching Python's _extract_spectrogram.
 *
 * @param wav The input waveform (1D array of floats).
 * @param fs The sampling rate of the waveform (fixed to 16000 Hz for this model).
 * @return A 2D Eigen::MatrixXf representing the spectrogram (frames x (N_FFT/2 + 1)).
 */
Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs) {
    // Calculate number of frames (cast to int so short inputs do not wrap around)
    int n_batch = (static_cast<int>(wav.size()) - WIN_LENGTH) / HOP_LENGTH + 1;
    if (n_batch <= 0) {
        std::cerr << "Warning: Input waveform too short for feature extraction. Returning empty spectrogram." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }

    // Generate Hamming window once
    std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);
    // Initialize KissFFT for real-valued input
    kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
    if (!fft_cfg) {
        std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }

    // Output spectrogram matrix: rows = frames, columns = FFT bins
    Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);

    std::vector<float> frame_buffer(WIN_LENGTH);
    std::vector<float> preemphasized_frame(WIN_LENGTH);
    kiss_fft_scalar fft_input[N_FFT];       // KissFFT requires an input buffer of size N_FFT
    kiss_fft_cpx fft_output[N_FFT / 2 + 1]; // KissFFT real-FFT output size

    for (int i = 0; i < n_batch; ++i) {
        int start_idx = i * HOP_LENGTH;

        // Extract current frame
        for (int j = 0; j < WIN_LENGTH; ++j) {
            frame_buffer[j] = wav[start_idx + j];
        }

        // Pre-emphasis within the frame. The Python reference builds a shifted copy
        // (y_frames_prev = np.roll(y_frames, 1, axis=1); y_frames_prev[:, 0] = y_frames_prev[:, 1])
        // and computes y_frames - preemphasis * y_frames_prev, i.e.
        // frame[j] - preemphasis * frame[j - 1] for j >= 1; the first sample is left
        // unmodified here as a simplification.
        preemphasized_frame[0] = frame_buffer[0];
        for (int j = 1; j < WIN_LENGTH; ++j) {
            preemphasized_frame[j] = frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1];
        }

        // Scale by 32768 (as in Python)
        for (int j = 0; j < WIN_LENGTH; ++j) {
            fft_input[j] = preemphasized_frame[j] * 32768.0f;
        }
        // Zero-pad the rest of the FFT input if WIN_LENGTH < N_FFT
        for (int j = WIN_LENGTH; j < N_FFT; ++j) {
            fft_input[j] = 0.0f;
        }
        // Apply Hamming window
        for (int j = 0; j < WIN_LENGTH; ++j) {
            fft_input[j] *= fft_window[j];
        }

        // Perform real FFT
        kiss_fftr(fft_cfg, fft_input, fft_output);

        // Calculate magnitude spectrogram
        for (int j = 0; j <= N_FFT / 2; ++j) {
            spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
        }
    }

    kiss_fftr_free(fft_cfg); // Free KissFFT configuration
    return spec_matrix;
}
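
// Worked example (added alongside this write-up): with WIN_LENGTH = 400 and HOP_LENGTH = 160,
// a 2-second clip at 16 kHz (32000 samples) yields (32000 - 400) / 160 + 1 = 198 frames,
// each holding 512 / 2 + 1 = 257 magnitude bins.
static_assert((32000 - 400) / 160 + 1 == 198, "frame count for a 2 s, 16 kHz clip");
static_assert(512 / 2 + 1 == 257, "FFT bins per frame");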

/**
 * @brief Creates a Mel filter-bank matrix, matching Python's speechlib_mel.
 *
 * @param sample_rate Sample rate in Hz.
 * @param n_fft FFT size.
 * @param n_mels Mel filter size.
 * @param fmin Lowest frequency (in Hz); pass 0.0f as a sentinel for None.
 * @param fmax Highest frequency (in Hz); pass 0.0f as a sentinel for None (defaults to Nyquist).
 * @return An Eigen::MatrixXf representing the Mel transform matrix ((1 + n_fft/2) x n_mels).
 */
Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
    int bank_width = n_fft / 2 + 1;
    if (fmax == 0.0f) fmax = sample_rate / 2.0f; // 0.0f acts as a sentinel for None

    // Helper functions for Mel scale conversion
    auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
    auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
    auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };

    // Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax)]
    int klo = f2bin(fmin) + 1;
    int khi = f2bin(fmax);
    khi = std::max(khi, klo); // kept for parity with the Python reference; the loop below runs to bank_width

    // Spec 2: SpeechLib uses triangles in Mel space
    float mlo = mel(fmin);
    float mhi = mel(fmax);

    // Generate Mel centers
    std::vector<float> m_centers(n_mels + 2);
    float ms = (mhi - mlo) / (n_mels + 1);
    for (int i = 0; i < n_mels + 2; ++i) {
        m_centers[i] = mlo + i * ms;
    }

    Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);

    for (int m = 0; m < n_mels; ++m) {
        float left = m_centers[m];
        float center = m_centers[m + 1];
        float right = m_centers[m + 2];
        for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) {
            float mbin = bin2mel(fft_bin);
            if (left < mbin && mbin < right) {
                matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
            }
        }
    }

    // Transpose so the result can be right-multiplied by the power spectrogram:
    // (frames x bank_width) * (bank_width x n_mels)
    matrix.transposeInPlace();
    return matrix;
}
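
// Illustrative sketch (added alongside this write-up): for the values this program uses
// later (fmin = 0 Hz, fmax = 16000/2 - 80 - 230 = 7690 Hz, N_FFT = 512, N_MELS = 80),
// the returned filterbank has shape (257 x 80), ready to right-multiply a
// (frames x 257) power spectrogram.
static void melFilterbankShapeExample() {
    Eigen::MatrixXf fb = speechlibMel(16000, 512, 80, 0.0f, 7690.0f);
    (void)fb; // expected: fb.rows() == 257 && fb.cols() == 80
}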

/**
 * @brief Extracts log filterbank features from waveform, matching Python's _extract_features.
 *
 * @param wav The input waveform (1D array of floats).
 * @param fs The sampling rate of the waveform (fixed to 16000 Hz).
 * @param mel_filterbank The pre-computed Mel filterbank matrix.
 * @return An Eigen::MatrixXf representing the log Mel filterbank features (frames x N_MELS).
 */
Eigen::MatrixXf extractFeatures(const std::vector<float>& wav, int fs, const Eigen::MatrixXf& mel_filterbank) {
    // Extract spectrogram
    Eigen::MatrixXf spec = extractSpectrogram(wav, fs);
    if (spec.rows() == 0) {
        return Eigen::MatrixXf(0, N_MELS); // Return empty matrix if spectrogram extraction failed
    }

    // spec_power = spec**2
    Eigen::MatrixXf spec_power = spec.array().square();

    // fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)
    // Python's 2D `dot` is matrix multiplication, which is `*` in Eigen:
    // (frames x (N_FFT/2 + 1)) * ((N_FFT/2 + 1) x N_MELS) -> (frames x N_MELS),
    // so mel_filterbank must have shape ((N_FFT/2 + 1) x N_MELS).
    Eigen::MatrixXf fbank_power = spec_power * mel_filterbank;

    // np.clip(..., 1.0, None): any value less than 1.0 becomes 1.0
    fbank_power = fbank_power.array().max(1.0f);

    // log_fbank = np.log(fbank_power).astype(np.float32)
    Eigen::MatrixXf log_fbank = fbank_power.array().log();

    return log_fbank;
}
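
// Illustrative sketch (added alongside this write-up): because fbank_power is clamped to a
// minimum of 1.0 before the logarithm, every value returned by extractFeatures is >= 0,
// so silent frames map to (near-)zero log-Mel energy rather than large negative values.
static void logMelLowerBoundExample() {
    Eigen::MatrixXf p(1, 2);
    p << 0.25f, 100.0f;                                          // one value below the clip threshold
    Eigen::MatrixXf logged = p.array().max(1.0f).log().matrix(); // same clip-then-log as above
    (void)logged;                                                // expected: {0.0f, log(100) ≈ 4.605f}
}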


int main(int argc, char* argv[]) {
    // --- 1. Process command-line arguments ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_pcm_file>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx audio.pcm" << std::endl;
        return 1;
    }

    std::string onnxModelPath = argv[1];
    std::string pcmFilename = argv[2];

    // --- Configuration for Audio and ONNX Model ---
    // These are fixed by the Python preprocessor code and model requirements.
    int bitDepth = 16;
    // numChannels is handled within loadPcmToFloatArray and then implicitly by feature extraction.
    // For simplicity, we assume mono PCM input. If your PCM is stereo, adjust loadPcmToFloatArray
    // to handle channel interleaving and then average or select a channel before passing the
    // result to extractSpectrogram.
    int numChannels = 1;

    // --- Create a dummy PCM file if it doesn't exist for demonstration ---
    // This is helpful for initial testing without needing an actual PCM file.
    std::ifstream pcmCheck(pcmFilename, std::ios::binary);
    if (!pcmCheck.is_open()) {
        std::cerr << "PCM file '" << pcmFilename << "' not found. Creating a dummy one for demonstration." << std::endl;
        std::ofstream dummyPcmFile(pcmFilename, std::ios::binary);
        if (dummyPcmFile.is_open()) {
            std::cout << "Creating a dummy PCM file: " << pcmFilename << " ("
                      << (TARGET_SAMPLE_RATE * 2 * sizeof(int16_t)) / 1024 << " KB)" << std::endl;
            for (int i = 0; i < TARGET_SAMPLE_RATE * 2; ++i) { // Generate 2 seconds of audio
                int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(TARGET_SAMPLE_RATE)));
                dummyPcmFile.write(reinterpret_cast<char*>(&sample), sizeof(sample));
            }
            dummyPcmFile.close();
        } else {
            std::cerr << "Error: Could not create dummy PCM file '" << pcmFilename
                      << "'. Please ensure the directory is writable." << std::endl;
            return 1;
        }
    } else {
        pcmCheck.close();
    }


    // --- 2. Load PCM audio data into a float array ---
    std::vector<float> audioWav = loadPcmToFloatArray(pcmFilename);

    if (audioWav.empty()) {
        std::cerr << "Failed to load audio data from " << pcmFilename << ". Exiting." << std::endl;
        return 1;
    }

    std::cout << "Successfully loaded " << audioWav.size() << " samples from " << pcmFilename << std::endl;

    // --- 3. Precompute Mel filterbank (as it's constant for a given sample rate/FFT size) ---
    // The Python example uses fmax=16000//2-80-230. This translates to TARGET_SAMPLE_RATE/2 - 80 - 230.
    // Using 0.0f for fmin as sentinel for None.
    float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
    Eigen::MatrixXf mel_filterbank = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);

    if (mel_filterbank.rows() == 0 || mel_filterbank.cols() == 0) {
        std::cerr << "Error: Failed to create Mel filterbank. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Mel filterbank created with shape: [" << mel_filterbank.rows() << ", " << mel_filterbank.cols() << "]" << std::endl;


    // --- 4. Apply feature extraction (preprocessor) ---
    std::cout << "Extracting features from audio..." << std::endl;
    Eigen::MatrixXf features = extractFeatures(audioWav, TARGET_SAMPLE_RATE, mel_filterbank);

    std::ofstream outputFile("matrix_output.txt");
    // Check if the file was opened successfully
    if (outputFile.is_open()) {
        // Iterate through rows and columns to write elements
        for (int i = 0; i < features.rows(); ++i) {
            for (int j = 0; j < features.cols(); ++j) {
                outputFile << features(i, j); // Write the element
                if (j < features.cols() - 1) {
                    outputFile << ","; // Comma separator between elements in a row
                }
            }
            outputFile << std::endl; // Move to the next line after each row
        }
        outputFile.close(); // Close the file
        std::cout << "Matrix successfully written to matrix_output.txt" << std::endl;
    }


    if (features.rows() == 0 || features.cols() == 0) {
        std::cerr << "Error: Feature extraction resulted in an empty matrix. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Features extracted with shape: [" << features.rows() << ", " << features.cols() << "]" << std::endl;
    std::cout << "First few feature values (first frame): [";
    for (int i = 0; i < std::min((int)features.cols(), 5); ++i) {
        std::cout << features(0, i) << (i == std::min((int)features.cols(), 5) - 1 ? "" : ", ");
    }
    std::cout << "]" << std::endl;

    // --- 5. Check for ONNX model existence and provide guidance if missing ---
    std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
    if (!onnxModelCheck.is_open()) {
        std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
        std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
                  << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
        std::cerr << "```python" << std::endl;
        std::cerr << "import torch" << std::endl;
        std::cerr << "import torch.nn as nn" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
        std::cerr << "    def __init__(self, input_frames, feature_size, output_size):" << std::endl;
        std::cerr << "        super(SimpleAudioModel, self).__init__()" << std::endl;
        std::cerr << "        # This model expects input of shape [batch_size, frames, feature_size]" << std::endl;
        std::cerr << "        # Example: a simple linear layer that flattens input and processes it." << std::endl;
        std::cerr << "        self.flatten = nn.Flatten()" << std::endl;
        std::cerr << "        self.linear = nn.Linear(input_frames * feature_size, output_size)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "    def forward(self, x):" << std::endl;
        std::cerr << "        x = self.flatten(x)" << std::endl;
        std::cerr << "        return self.linear(x)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
        std::cerr << "# The C++ preprocessor will produce features of shape [frames, 80]." << std::endl;
        std::cerr << "# For a dummy model, we need to provide a fixed 'frames' value for ONNX export." << std::endl;
        std::cerr << "# A typical audio segment might be 2 seconds at 16kHz, which is 32000 samples." << std::endl;
        std::cerr << "# Frames = (32000 - 400) / 160 + 1 = 197 + 1 = 198 frames (with integer division)" << std::endl;
        std::cerr << "# Let's use a representative number of frames, e.g., 200 for a dummy input." << std::endl;
        std::cerr << "DUMMY_INPUT_FRAMES = 200 # This should be representative of your typical audio segment's frames" << std::endl;
        std::cerr << "DUMMY_FEATURE_SIZE = 80 # Fixed by the Mel filterbank (N_MELS)" << std::endl;
        std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
        std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE) # Batch size 1" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "torch.onnx.export(" << std::endl;
        std::cerr << "    model," << std::endl;
        std::cerr << "    dummy_input_tensor," << std::endl;
        std::cerr << "    \"model.onnx\"," << std::endl;
        std::cerr << "    verbose=True," << std::endl;
        std::cerr << "    input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
        std::cerr << "    output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
        std::cerr << "    # Define dynamic axes for batch_size and frames" << std::endl;
        std::cerr << "    dynamic_axes={'input': {0: 'batch_size', 1: 'frames'}, 'output': {0: 'batch_size'}}" << std::endl;
        std::cerr << ")" << std::endl;
        std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_FRAMES in this script to match the expected number of frames from your audio segments.\")" << std::endl;
        std::cerr << "```" << std::endl;
        return 1;
    }
    onnxModelCheck.close();
    std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;


    // --- 6. ONNX Runtime Inference ---
    try {
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
        Ort::SessionOptions session_options;
        session_options.SetIntraOpNumThreads(1);
        // session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);

        Ort::Session session(env, onnxModelPath.c_str(), session_options);
        std::cout << "Model loaded successfully from: " << onnxModelPath << std::endl;
        Ort::AllocatorWithDefaultOptions allocator;

        // --- Get Input Node Information ---
        size_t numInputNodes = session.GetInputCount();
        std::vector<const char*> inputNodeNames(numInputNodes);

        std::cout << "\n--- Model Input Information ---" << std::endl;
        if (numInputNodes == 0) {
            std::cerr << "Error: Model has no input nodes. Exiting." << std::endl;
            return 1;
        }

        // Assuming a single input node for simplicity
        inputNodeNames[0] = "audio_embeds";
        Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        std::vector<int64_t> actualInputShape = tensor_info.GetShape();

        std::cout << " Input 0 : Name='" << inputNodeNames[0] << "', Shape=[";
        for (size_t j = 0; j < actualInputShape.size(); ++j) {
            // Print -1 for dynamic dimensions
            if (actualInputShape[j] == -1) {
                std::cout << "-1";
            } else {
                std::cout << actualInputShape[j];
            }
            std::cout << (j == actualInputShape.size() - 1 ? "" : ", ");
        }
        std::cout << "]" << std::endl;

        // --- Prepare Input Tensor Shape ---
        // The ONNX model input is [batch, frames, feature_size] = [-1, -1, 80].
        // Our extracted features are [frames, 80]; we add a batch dimension of 1.
        std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
        std::cout << " Preparing input tensor with shape: [" << inputTensorShape[0] << ", "
                  << inputTensorShape[1] << ", " << inputTensorShape[2] << "]" << std::endl;

        // Flatten the Eigen::MatrixXf into a std::vector<float> for ONNX Runtime.
        // Note: Eigen::MatrixXf is column-major by default; if the model expects row-major
        // (C-order) data, copy through an Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
        // Eigen::RowMajor> before flattening.
        std::vector<float> inputTensorData(features.data(), features.data() + features.size());

        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
                                                                 inputTensorShape.data(), inputTensorShape.size());

        if (!inputTensor.IsTensor()) {
            std::cerr << "Error: Created input tensor is not valid! Exiting." << std::endl;
            return 1;
        }

        // --- Get Output Node Information ---
        size_t numOutputNodes = session.GetOutputCount();
        std::vector<const char*> outputNodeNames(numOutputNodes);

        std::cout << "\n--- Model Output Information ---" << std::endl;
        for (size_t k = 0; k < numOutputNodes; ++k) {
            outputNodeNames[k] = "audio_features";
            Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
            auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = tensor_info_out.GetShape();
            std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
            for (size_t l = 0; l < outputShape.size(); ++l) {
                if (outputShape[l] == -1) {
                    std::cout << "-1";
                } else {
                    std::cout << outputShape[l];
                }
                std::cout << (l == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        }

        // --- Run Inference ---
        std::cout << "\nRunning ONNX model inference..." << std::endl;
        std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
                                                            inputNodeNames.data(), &inputTensor, 1,
                                                            outputNodeNames.data(), numOutputNodes);
        std::ofstream output_file("f0.txt");
        for (auto& ort_value : outputTensors) {
            // Example: Assuming Ort::Value contains a float tensor
            if (ort_value.IsTensor()) {
                float* data = ort_value.GetTensorMutableData<float>();
                Ort::TensorTypeAndShapeInfo info = ort_value.GetTensorTypeAndShapeInfo();
                size_t num_elements = info.GetElementCount();

                for (size_t i = 0; i < num_elements; ++i) {
                    output_file << data[i];
                    if (i < num_elements - 1) {
                        output_file << ","; // Comma separator between elements
                    }
                }
                output_file << std::endl; // Newline after each Ort::Value's content
            } else {
                // Handle other Ort::Value types if necessary (e.g., sequences, maps)
                output_file << "Non-tensor Ort::Value" << std::endl;
            }
        }

        output_file.close();


        // --- Process Output ---
        if (outputTensors.empty()) {
            std::cerr << "Error: No output tensors received from the model." << std::endl;
            return 1;
        }

        if (outputTensors[0].IsTensor()) {
            float* outputData = outputTensors[0].GetTensorMutableData<float>();
            Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
            size_t outputSize = outputShapeInfo.GetElementCount();

            std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
            for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
                std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
            }
            std::cout << std::endl;

            std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
            std::cout << "Full output tensor shape: [";
            for (size_t k = 0; k < outputShape.size(); ++k) {
                std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        } else {
            std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
        }

    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}
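
A quick way to sanity-check this preprocessor is to diff matrix_output.txt against a dump written by the Python-side feature extractor in the same comma-separated layout. The sketch below is illustrative only and not part of this commit; "python_features.csv" is a hypothetical file name for that Python dump.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

// Reads a comma-separated matrix dump (one row per line) into a flat vector.
static std::vector<float> loadCsv(const std::string& path) {
    std::vector<float> values;
    std::ifstream in(path);
    std::string line, cell;
    while (std::getline(in, line)) {
        std::stringstream ss(line);
        while (std::getline(ss, cell, ',')) {
            values.push_back(std::stof(cell));
        }
    }
    return values;
}

int main() {
    std::vector<float> cpp_feats = loadCsv("matrix_output.txt");
    std::vector<float> py_feats = loadCsv("python_features.csv"); // hypothetical Python dump
    if (cpp_feats.size() != py_feats.size()) {
        std::printf("size mismatch: %zu vs %zu\n", cpp_feats.size(), py_feats.size());
        return 1;
    }
    float max_abs_diff = 0.0f;
    for (size_t i = 0; i < cpp_feats.size(); ++i) {
        max_abs_diff = std::max(max_abs_diff, std::abs(cpp_feats[i] - py_feats[i]));
    }
    std::printf("max |diff| = %g over %zu values\n", max_abs_diff, cpp_feats.size());
    return 0;
}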
cpp/inference/test copy.cpp
ADDED
@@ -0,0 +1,301 @@
#include <iostream>
#include <vector>
#include <fstream>   // For file input/output operations (e.g., std::ifstream, std::ofstream)
#include <cstdint>   // For fixed-width integer types (e.g., int16_t)
#include <cmath>     // For mathematical functions (e.g., std::sin, M_PI)
#include <numeric>   // For numerical operations (not strictly used in this version but often useful)
#include <algorithm> // For algorithms like std::min

// Include the ONNX Runtime C++ API header.
// You need to have ONNX Runtime installed and linked correctly in your build system.
// For example, using CMake, you might add:
//   find_package(ONNXRuntime REQUIRED)
//   target_link_libraries(your_executable PRIVATE ONNXRuntime::onnxruntime_cxx_api)
#include <onnxruntime_cxx_api.h>

// Define M_PI if it's not already defined by cmath or your compiler.
// This is common on Windows with MSVC unless _USE_MATH_DEFINES is set.
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif


std::vector<float> loadPcmToFloatArray(const std::string& filename, int bitDepth, int numChannels) {
    // Open the PCM file in binary mode for reading
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not open PCM file: " << filename << std::endl;
        return {}; // Return empty vector on failure
    }

    std::vector<float> audioData; // Vector to store the normalized float audio samples

    // Check if the bit depth is supported (this example only handles 16-bit)
    if (bitDepth == 16) {
        int16_t sample; // Buffer to read a single 16-bit sample

        // Read samples until the end of the file.
        // Dividing by 32768.0f (2^15) normalizes the signed 16-bit range to [-1.0, 1.0):
        // -32768 maps to -1.0 and 32767 maps to slightly less than 1.0, preserving the
        // full dynamic range without overflow.
        while (file.read(reinterpret_cast<char*>(&sample), sizeof(sample))) {
            audioData.push_back(static_cast<float>(sample) / 32768.0f);
        }
    } else {
        std::cerr << "Error: Unsupported bit depth: " << bitDepth << ". This example only supports 16-bit PCM." << std::endl;
        return {}; // Return empty vector for unsupported bit depth
    }

    file.close();
    return audioData;
}

int main() {
    // --- Configuration for Audio and ONNX Model ---
    std::string pcmFilename = "/mnt/data-2t/jeff/codes/llm/cpp/sample_data/pickup_breezy-common_voice_zh-TW_17376838-breezyvoice-00818.pcm"; // PCM audio file to load
    int bitDepth = 16;      // Bit depth of the PCM data
    int numChannels = 1;    // Number of audio channels (1 for mono)
    int sampleRate = 16000; // Sample rate of the audio in Hz
    std::string onnxModelPath = "/mnt/data-2t/jeff/codes/llm/cpp/onnx_files/speech_init_export/phi-4-mm-speech.onnx"; // Path to your ONNX model file

    // --- 2. Load PCM audio data into a float array ---
    std::vector<float> audioInput = loadPcmToFloatArray(pcmFilename, bitDepth, numChannels);

    if (audioInput.empty()) {
        std::cerr << "Failed to load audio data from " << pcmFilename << ". Exiting." << std::endl;
        return 1;
    }

    std::cout << "Successfully loaded " << audioInput.size() << " samples from " << pcmFilename << std::endl;

    // --- 3. Check for ONNX model existence and provide guidance if missing ---
    // This step is critical. You need a valid ONNX model.
    std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
    if (!onnxModelCheck.is_open()) {
        std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
        std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
                  << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
        std::cerr << "```python" << std::endl;
        std::cerr << "import torch" << std::endl;
        std::cerr << "import torch.nn as nn" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
        std::cerr << "    def __init__(self, input_size, output_size):" << std::endl;
        std::cerr << "        super(SimpleAudioModel, self).__init__()" << std::endl;
        std::cerr << "        # This is a very simple linear layer. Your actual model will be more complex." << std::endl;
        std::cerr << "        # This model expects input of shape [batch_size, input_size]" << std::endl;
        std::cerr << "        self.linear = nn.Linear(input_size, output_size)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "    def forward(self, x):" << std::endl;
        std::cerr << "        # If your model expects a different input shape (e.g., [batch_size, channels, samples])," << std::endl;
        std::cerr << "        # you might need to reshape 'x' here before passing it to your layers (e.g., x.view(x.size(0), 1, -1))." << std::endl;
        std::cerr << "        return self.linear(x)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
        std::cerr << "# For this dummy model, we assume an input size matching 2 seconds of mono audio at the sample rate above." << std::endl;
        std::cerr << "DUMMY_INPUT_SIZE = " << (sampleRate * 2) << " # Corresponds to " << (sampleRate * 2) / static_cast<float>(sampleRate) << " seconds of audio at " << sampleRate << " Hz mono" << std::endl;
        std::cerr << "DUMMY_OUTPUT_SIZE = 10 # Example: 10 classification scores or features" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
        std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_SIZE) # Batch size 1, DUMMY_INPUT_SIZE features" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "torch.onnx.export(" << std::endl;
        std::cerr << "    model," << std::endl;
        std::cerr << "    dummy_input_tensor," << std::endl;
        std::cerr << "    \"model.onnx\"," << std::endl;
        std::cerr << "    verbose=True," << std::endl;
        std::cerr << "    input_names=['input'], # Name of the input tensor in the ONNX graph" << std::endl;
        std::cerr << "    output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
        std::cerr << "    # Optional: Define dynamic axes if your batch size or sequence length can vary" << std::endl;
        std::cerr << "    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}" << std::endl;
        std::cerr << ")" << std::endl;
        std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_SIZE in this script to match the length of your audio data or ensure your C++ code pads/truncates the audio data to the model's expected input size.\")" << std::endl;
        std::cerr << "```" << std::endl;
        return 1; // Exit if the ONNX model is not found
    }
    onnxModelCheck.close();
    std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;


    // --- 4. ONNX Runtime Inference ---
    try {
        // Create an ONNX Runtime environment. This is the entry point for all ONNX Runtime operations.
        // ORT_LOGGING_LEVEL_WARNING suppresses verbose output unless there's a warning or error.
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");

        // Configure session options.
        Ort::SessionOptions session_options;
        session_options.SetIntraOpNumThreads(1);                        // Use 1 thread for operations within a single node
        session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED); // Apply all available graph optimizations

        // Create an ONNX Runtime session by loading the model.
        Ort::Session session(env, onnxModelPath.c_str(), session_options);

        // An allocator is needed to manage memory for allocated strings (like node names).
        Ort::AllocatorWithDefaultOptions allocator;

        // --- Get Input Node Information ---
        size_t numInputNodes = session.GetInputCount();
        std::vector<Ort::AllocatedStringPtr> inputNameOwners; // keeps the name strings alive
        std::vector<const char*> inputNodeNames(numInputNodes);

        std::cout << "\n--- Model Input Information ---" << std::endl;
        // Iterate through all input nodes (models usually have one main input)
        for (size_t i = 0; i < numInputNodes; ++i) {
            // Get the input node name. The AllocatedStringPtr must outlive the raw pointer,
            // so it is stored in inputNameOwners rather than discarded immediately.
            inputNameOwners.push_back(session.GetInputNameAllocated(i, allocator));
            inputNodeNames[i] = inputNameOwners.back().get();

            // Get the type and shape information for the input tensor
            Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
            auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
            std::vector<int64_t> actualInputShape = tensor_info.GetShape(); // The shape the model *expects*

            std::cout << " Input " << i << " : Name='" << inputNodeNames[i] << "', Shape=[";
            for (size_t j = 0; j < actualInputShape.size(); ++j) {
                std::cout << actualInputShape[j] << (j == actualInputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;

            // --- Prepare Input Tensor Shape ---
            // This is a CRITICAL step: `audioInput` must be reshaped to precisely match the
            // ONNX model's expected input tensor shape. The dummy Python model above creates
            // an input of shape [1, DUMMY_INPUT_SIZE], so the audio is padded/truncated to fit.
            std::vector<int64_t> inputTensorShape; // The shape of the tensor we create

            if (actualInputShape.size() == 2 && actualInputShape[0] == 1) {
                // Case: Model expects a 2D input with batch size 1 (e.g., [1, num_features])
                int64_t expected_length = actualInputShape[1]; // The expected number of features/samples

                // Check if the loaded audio data size matches the model's expected input length
                if (static_cast<int64_t>(audioInput.size()) != expected_length) {
                    std::cout << " Warning: Loaded audio input size (" << audioInput.size()
                              << ") does not match model's expected input length (" << expected_length << ")." << std::endl;
                    std::cout << " Padding/truncating audio data to match model input size." << std::endl;
                    audioInput.resize(expected_length, 0.0f); // Pad with zeros or truncate the audio data
                }
                inputTensorShape = {1, expected_length}; // Set the tensor shape for ONNX Runtime
            } else if (actualInputShape.size() == 1) {
                // Case: Model expects a 1D input (e.g., [num_features])
                int64_t expected_length = actualInputShape[0];

                if (static_cast<int64_t>(audioInput.size()) != expected_length) {
                    std::cout << " Warning: Loaded audio input size (" << audioInput.size()
                              << ") does not match model's expected input length (" << expected_length << ")." << std::endl;
                    std::cout << " Padding/truncating audio data to match model input size." << std::endl;
                    audioInput.resize(expected_length, 0.0f); // Pad with zeros or truncate
                }
                inputTensorShape = {expected_length}; // Set the tensor shape for ONNX Runtime
            } else {
                std::cerr << "Error: Model input shape is not supported by this example ([N] or [1, N]). "
                          << "Please adjust the input tensor shape creation logic in C++ to match your model's specific requirements." << std::endl;
                return 1; // Exit if the input shape is not handled
            }

            // Create an ONNX Runtime memory info object for CPU memory.
            // This specifies where the tensor data is located (CPU in this case).
            Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

            // Create the input tensor from the audio data: raw float pointer, element count,
            // shape array, and number of dimensions.
            Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, audioInput.data(), audioInput.size(),
                                                                     inputTensorShape.data(), inputTensorShape.size());

            // Verify that the created input tensor is valid
            if (!inputTensor.IsTensor()) {
                std::cerr << "Error: Created input tensor is not valid! This might indicate a shape mismatch or data issue." << std::endl;
                return 1;
            }

            // At this point, `inputTensor` is ready to be fed into the model. For simplicity,
            // we assume there's only one input; a model with multiple inputs would need
            // multiple Ort::Value objects.

            // --- Get Output Node Information ---
            size_t numOutputNodes = session.GetOutputCount();
            std::vector<Ort::AllocatedStringPtr> outputNameOwners; // keeps the name strings alive
            std::vector<const char*> outputNodeNames(numOutputNodes);

            std::cout << "\n--- Model Output Information ---" << std::endl;
            // Iterate through all output nodes
            for (size_t k = 0; k < numOutputNodes; ++k) {
                outputNameOwners.push_back(session.GetOutputNameAllocated(k, allocator));
                outputNodeNames[k] = outputNameOwners.back().get();
                Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
                auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
                std::vector<int64_t> outputShape = tensor_info_out.GetShape();
                std::cout << " Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
                for (size_t l = 0; l < outputShape.size(); ++l) {
                    std::cout << outputShape[l] << (l == outputShape.size() - 1 ? "" : ", ");
                }
                std::cout << "]" << std::endl;
            }

            // --- Run Inference ---
            std::cout << "\nRunning ONNX model inference..." << std::endl;
            // session.Run arguments: default run options, array of input names, pointer to the
            // input tensor(s) and their count, array of output names, and the expected output count.
            std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
                                                                inputNodeNames.data(), &inputTensor, 1,
                                                                outputNodeNames.data(), numOutputNodes);

            // --- Process Output ---
            if (outputTensors.empty()) {
                std::cerr << "Error: No output tensors received from the model." << std::endl;
                return 1;
            }

            // Assuming the first output is a float tensor (common for most models)
            if (outputTensors[0].IsTensor()) {
                // Get a mutable pointer to the raw data of the output tensor
                float* outputData = outputTensors[0].GetTensorMutableData<float>();
                Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
                std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
                size_t outputSize = outputShapeInfo.GetElementCount(); // Total number of elements in the output tensor

                std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
                // Print the first 10 elements of the output (or fewer if output is smaller)
                for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
                    std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
                }
                std::cout << std::endl;

                std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
                std::cout << "Full output tensor shape: [";
                for (size_t k = 0; k < outputShape.size(); ++k) {
                    std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
                }
                std::cout << "]" << std::endl;

                // Here you would typically interpret the model's output based on its purpose:
                // classification (index of the maximum score), regression (use the values directly),
                // or feature extraction (pass the vector on to further processing).
            } else {
                std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
            }
        } // End of loop for input nodes (assuming single input for simplicity in this example)

    } catch (const Ort::Exception& e) {
        // Catch ONNX Runtime specific exceptions
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        // Catch other standard exceptions
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}
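
Note on the name handling above: GetInputNameAllocated and GetOutputNameAllocated return Ort::AllocatedStringPtr owners, so the raw const char* pointers handed to session.Run must not outlive those owners. A minimal sketch of the ownership pattern, with collectInputNames as an illustrative helper name:

#include <onnxruntime_cxx_api.h>
#include <vector>

void collectInputNames(Ort::Session& session, Ort::AllocatorWithDefaultOptions& allocator,
                       std::vector<Ort::AllocatedStringPtr>& owners,
                       std::vector<const char*>& names) {
    size_t count = session.GetInputCount();
    for (size_t i = 0; i < count; ++i) {
        owners.push_back(session.GetInputNameAllocated(i, allocator)); // keeps the string alive
        names.push_back(owners.back().get());                          // raw pointer stays valid
    }
}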
cpp/inference/test.cpp
ADDED
@@ -0,0 +1,702 @@
#include <iostream>  // For standard input/output operations (e.g., std::cout, std::cerr)
#include <vector>    // For dynamic arrays (e.g., std::vector<float>)
#include <fstream>   // For file input/output operations (e.g., std::ifstream, std::ofstream)
#include <cstdint>   // For fixed-width integer types (e.g., int16_t, uint32_t)
#include <cmath>     // For mathematical functions (e.g., std::sin, M_PI, std::log)
#include <numeric>   // For numerical operations (e.g., std::iota)
#include <algorithm> // For algorithms like std::min, std::max
#include <string>    // For std::string

// Include the ONNX Runtime C++ API header
#include <onnxruntime_cxx_api.h>

// Include Eigen for matrix operations.
// You need to download Eigen and set up your include paths,
// e.g. if Eigen is in 'C:/Libraries/eigen-3.4.0', compile with -I C:/Libraries/eigen-3.4.0
#include <Eigen/Dense>

// Include KissFFT for the Fast Fourier Transform.
// You need to download KissFFT and set up your include paths,
// e.g. if KissFFT is in 'C:/Libraries/kissfft-1.3.0', compile with -I C:/Libraries/kissfft-1.3.0.
// You also need to compile kiss_fft.c and kiss_fftr.c and link them.
#include <kiss_fft.h>
#include <kiss_fftr.h> // For real-valued FFT

// Define M_PI if it's not already defined by cmath or your compiler.
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

// --- Global parameters for feature extraction (matching Python script) ---
const float PREEMPHASIS_COEFF = 0.97f;
const int N_FFT = 512;                // FFT size
const int WIN_LENGTH = 400;           // Window length (samples)
const int HOP_LENGTH = 160;           // Hop length (samples)
const int N_MELS = 80;                // Number of Mel filterbank channels
const int TARGET_SAMPLE_RATE = 16000; // Target sample rate for feature extraction

// --- WAV File Header Structures ---
// These structures are for parsing the WAV file format.
// They assume little-endian byte order, which is standard for WAV files on most systems.
#pragma pack(push, 1) // Ensure no padding for these structures

struct WavHeader {
    char riff_id[4];          // Contains "RIFF"
    uint32_t file_size;       // Size of the overall file - 8 bytes
    char wave_id[4];          // Contains "WAVE"
    char fmt_id[4];           // Contains "fmt " (note the space)
    uint32_t fmt_size;        // Size of the fmt chunk (16 for PCM)
    uint16_t audio_format;    // Audio format (1 for PCM)
    uint16_t num_channels;    // Number of channels (1 for mono, 2 for stereo)
    uint32_t sample_rate;     // Sample rate (e.g., 44100 Hz)
    uint32_t byte_rate;       // (SampleRate * NumChannels * BitsPerSample) / 8
    uint16_t block_align;     // (NumChannels * BitsPerSample) / 8
    uint16_t bits_per_sample; // Bits per sample (e.g., 16)
};

struct WavDataChunk {
    char data_id[4];    // Contains "data"
    uint32_t data_size; // Size of the data chunk
};

#pragma pack(pop) // Restore default packing alignment

/**
 * @brief Loads audio data from a WAV file into a float vector.
 *
 * This function reads a WAV file, parses its header, extracts 16-bit signed
 * integer PCM samples, converts them to floating-point values, and normalizes
 * them to the range [-1.0, 1.0]. It supports mono and stereo (converting stereo
 * to mono by averaging channels).
 *
 * @param filename The path to the WAV audio file.
 * @param actual_sample_rate Output parameter to store the sample rate read from the WAV file.
 * @return A std::vector<float> containing the normalized mono audio samples, or an empty
 *         vector if the file cannot be opened or is not a supported WAV format.
 */
std::vector<float> loadWavToFloatArray(const std::string& filename, int& actual_sample_rate) {
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not open WAV file: " << filename << std::endl;
        return {};
    }

    WavHeader header;
    file.read(reinterpret_cast<char*>(&header), sizeof(WavHeader));

    // Basic header validation
    if (std::string(header.riff_id, 4) != "RIFF" ||
        std::string(header.wave_id, 4) != "WAVE" ||
        std::string(header.fmt_id, 4) != "fmt ") {
        std::cerr << "Error: Invalid WAV header (RIFF, WAVE, or fmt chunk missing/invalid)." << std::endl;
        file.close();
        return {};
    }

    if (header.audio_format != 1) { // 1 = PCM
        std::cerr << "Error: Only PCM audio format (1) is supported. Found: " << header.audio_format << std::endl;
        file.close();
        return {};
    }

    if (header.bits_per_sample != 16) {
        std::cerr << "Error: Only 16-bit PCM is supported. Found: " << header.bits_per_sample << " bits per sample." << std::endl;
        file.close();
        return {};
    }

    actual_sample_rate = header.sample_rate;
    std::cout << "WAV file info: Sample Rate=" << header.sample_rate
              << ", Channels=" << header.num_channels
              << ", Bit Depth=" << header.bits_per_sample << std::endl;

    // Find the "data" chunk
    WavDataChunk data_chunk;
    bool data_chunk_found = false;
    while (!file.eof()) {
        file.read(reinterpret_cast<char*>(&data_chunk.data_id), 4);
        file.read(reinterpret_cast<char*>(&data_chunk.data_size), 4);

        if (std::string(data_chunk.data_id, 4) == "data") {
            data_chunk_found = true;
            break;
        } else {
            // Skip unknown chunks
            file.seekg(data_chunk.data_size, std::ios::cur);
        }
    }

    if (!data_chunk_found) {
        std::cerr << "Error: 'data' chunk not found in WAV file." << std::endl;
        file.close();
        return {};
    }

    std::vector<float> audioData;
    int16_t sample_buffer;
    long num_samples_to_read = data_chunk.data_size / sizeof(int16_t);

    for (long i = 0; i < num_samples_to_read; ++i) {
        file.read(reinterpret_cast<char*>(&sample_buffer), sizeof(int16_t));
        float normalized_sample = static_cast<float>(sample_buffer) / 32768.0f;

        if (header.num_channels == 1) {
            audioData.push_back(normalized_sample);
        } else if (header.num_channels == 2) {
            // For stereo, read the right channel as well and average the pair down to mono
            int16_t right_sample;
            if (file.read(reinterpret_cast<char*>(&right_sample), sizeof(int16_t))) {
                float normalized_right_sample = static_cast<float>(right_sample) / 32768.0f;
                audioData.push_back((normalized_sample + normalized_right_sample) / 2.0f);
                i++; // Increment i again as we read two samples
            } else {
                std::cerr << "Warning: Unexpected end of file while reading stereo data." << std::endl;
                break;
            }
        } else {
            std::cerr << "Error: Unsupported number of channels: " << header.num_channels << std::endl;
            file.close();
            return {};
        }
    }

    file.close();
    return audioData;
}
|
| 167 |
+
|
| 168 |
+
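// Illustrative sketch (not part of the original pipeline): minimal standalone use of
// loadWavToFloatArray. The path "sample.wav" is a hypothetical placeholder for any
// 16-bit PCM WAV file; the returned samples are mono and normalized to [-1.0, 1.0].
inline void exampleLoadWavUsage() {
    int sr = 0;
    std::vector<float> samples = loadWavToFloatArray("sample.wav", sr);
    if (!samples.empty()) {
        std::cout << "Loaded " << samples.size() << " samples at " << sr << " Hz" << std::endl;
    }
}
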
/**
 * @brief Generates a Hamming window.
 * @param window_length The length of the window.
 * @return A std::vector<float> containing the Hamming window coefficients.
 */
std::vector<float> generateHammingWindow(int window_length) {
    std::vector<float> window(window_length);
    for (int i = 0; i < window_length; ++i) {
        window[i] = 0.54f - 0.46f * std::cos(2 * M_PI * i / static_cast<float>(window_length - 1));
    }
    return window;
}

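// Illustrative sketch (my addition, assuming <algorithm> is available as std::max/std::min
// are already used in this file): generateHammingWindow() implements the symmetric Hamming
// window w[i] = 0.54 - 0.46*cos(2*pi*i/(L-1)), so both endpoints evaluate to 0.08 and no
// coefficient exceeds 1.0. This helper only checks those properties numerically.
inline bool exampleCheckHammingEndpoints(int window_length = WIN_LENGTH) {
    std::vector<float> w = generateHammingWindow(window_length);
    bool ends_ok = std::fabs(w.front() - 0.08f) < 1e-4f && std::fabs(w.back() - 0.08f) < 1e-4f;
    bool peak_ok = *std::max_element(w.begin(), w.end()) <= 1.0f + 1e-4f;
    return ends_ok && peak_ok;
}
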
/**
 * @brief Extracts spectrogram features from waveform, matching Python's _extract_spectrogram.
 *
 * @param wav The input waveform (1D array of floats).
 * @param fs The sampling rate of the waveform (fixed to 16000 Hz for this model).
 * @return A 2D Eigen::MatrixXf representing the spectrogram (frames x (N_FFT/2 + 1)).
 */
Eigen::MatrixXf extractSpectrogram(const std::vector<float>& wav, int fs) {
    // Calculate the number of frames. The length check comes first so that the
    // unsigned subtraction below cannot underflow for very short inputs.
    if (wav.size() < static_cast<size_t>(WIN_LENGTH)) {
        std::cerr << "Warning: Input waveform too short for feature extraction. Returning empty spectrogram." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }
    int n_batch = static_cast<int>((wav.size() - WIN_LENGTH) / HOP_LENGTH) + 1;

    // Generate Hamming window once
    std::vector<float> fft_window = generateHammingWindow(WIN_LENGTH);

    // Initialize KissFFT for real-valued input
    kiss_fftr_cfg fft_cfg = kiss_fftr_alloc(N_FFT, 0 /* is_inverse_fft */, nullptr, nullptr);
    if (!fft_cfg) {
        std::cerr << "Error: Failed to allocate KissFFT configuration." << std::endl;
        return Eigen::MatrixXf(0, N_FFT / 2 + 1);
    }

    // Output spectrogram matrix: rows = frames, columns = FFT bins
    Eigen::MatrixXf spec_matrix(n_batch, N_FFT / 2 + 1);

    std::vector<float> frame_buffer(WIN_LENGTH);
    kiss_fft_scalar fft_input[N_FFT];       // KissFFT requires an input buffer of size N_FFT
    kiss_fft_cpx fft_output[N_FFT / 2 + 1]; // KissFFT real FFT output size

    for (int i = 0; i < n_batch; ++i) {
        int start_idx = i * HOP_LENGTH;

        // Extract current frame
        for (int j = 0; j < WIN_LENGTH; ++j) {
            frame_buffer[j] = wav[start_idx + j];
        }

        // Apply pre-emphasis and scale by 32768, matching the Python preprocessor:
        //   y_frames_prev = np.roll(y_frames, 1, axis=1); y_frames_prev[:, 0] = y_frames_prev[:, 1]
        //   y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
        // i.e. sample 0 of each frame is pre-emphasized against sample 1, and every later
        // sample j against sample j-1.
        if (WIN_LENGTH > 0) {
            if (WIN_LENGTH > 1) {
                fft_input[0] = (frame_buffer[0] - PREEMPHASIS_COEFF * frame_buffer[1]) * 32768.0f;
                for (int j = 1; j < WIN_LENGTH; ++j) {
                    fft_input[j] = (frame_buffer[j] - PREEMPHASIS_COEFF * frame_buffer[j - 1]) * 32768.0f;
                }
            } else { // WIN_LENGTH == 1: nothing to pre-emphasize against
                fft_input[0] = frame_buffer[0] * 32768.0f;
            }
        }
        // Zero-pad the rest of the FFT input if WIN_LENGTH < N_FFT
        for (int j = WIN_LENGTH; j < N_FFT; ++j) {
            fft_input[j] = 0.0f;
        }

        // Apply Hamming window
        for (int j = 0; j < WIN_LENGTH; ++j) {
            fft_input[j] *= fft_window[j];
        }

        // Perform real FFT
        kiss_fftr(fft_cfg, fft_input, fft_output);

        // Calculate magnitude spectrogram
        for (int j = 0; j <= N_FFT / 2; ++j) {
            spec_matrix(i, j) = std::sqrt(fft_output[j].r * fft_output[j].r + fft_output[j].i * fft_output[j].i);
        }
    }

    kiss_fftr_free(fft_cfg); // Free KissFFT configuration
    return spec_matrix;
}

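// Illustrative sketch (my addition): the frame count produced by extractSpectrogram() is
// floor((num_samples - WIN_LENGTH) / HOP_LENGTH) + 1. With the 400/160 framing used by the
// guidance text in main(), 2 s of 16 kHz audio (32000 samples) gives
// floor(31600 / 160) + 1 = 197 + 1 = 198 frames.
inline long exampleExpectedFrameCount(long num_samples) {
    if (num_samples < WIN_LENGTH) return 0;
    return (num_samples - WIN_LENGTH) / HOP_LENGTH + 1;
}
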
/**
 * @brief Creates a Mel filter-bank matrix, matching Python's speechlib_mel.
 *
 * @param sample_rate Sample rate in Hz.
 * @param n_fft FFT size.
 * @param n_mels Mel filter size.
 * @param fmin Lowest frequency (in Hz).
 * @param fmax Highest frequency (in Hz).
 * @return An Eigen::MatrixXf representing the Mel transform matrix (n_mels x (1 + n_fft/2)).
 */
Eigen::MatrixXf speechlibMel(int sample_rate, int n_fft, int n_mels, float fmin, float fmax) {
    int bank_width = n_fft / 2 + 1;
    // 0.0f is used as the sentinel for Python's None: fmax falls back to the Nyquist
    // frequency, and an fmin of 0.0f already means "start at 0 Hz".
    if (fmax == 0.0f) fmax = sample_rate / 2.0f;

    // Helper functions for Mel scale conversion
    auto mel = [](float f) { return 1127.0f * std::log(1.0f + f / 700.0f); };
    auto bin2mel = [&](int fft_bin) { return 1127.0f * std::log(1.0f + static_cast<float>(fft_bin) * sample_rate / (static_cast<float>(n_fft) * 700.0f)); };
    auto f2bin = [&](float f) { return static_cast<int>((f * n_fft / sample_rate) + 0.5f); };

    // Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax)]
    int klo = f2bin(fmin) + 1;
    int khi = f2bin(fmax);
    khi = std::max(khi, klo);

    // Spec 2: SpeechLib uses triangles in Mel space
    float mlo = mel(fmin);
    float mhi = mel(fmax);

    // Generate Mel centers
    std::vector<float> m_centers(n_mels + 2);
    float ms = (mhi - mlo) / (n_mels + 1);
    for (int i = 0; i < n_mels + 2; ++i) {
        m_centers[i] = mlo + i * ms;
    }

    Eigen::MatrixXf matrix = Eigen::MatrixXf::Zero(n_mels, bank_width);

    for (int m = 0; m < n_mels; ++m) {
        float left = m_centers[m];
        float center = m_centers[m + 1];
        float right = m_centers[m + 2];
        for (int fft_bin = klo; fft_bin < bank_width; ++fft_bin) { // Loop up to bank_width-1
            float mbin = bin2mel(fft_bin);
            if (left < mbin && mbin < right) {
                matrix(m, fft_bin) = 1.0f - std::abs(center - mbin) / ms;
            }
        }
    }
    return matrix;
}

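// Illustrative sketch (my addition): speechlibMel() returns an (N_MELS x (N_FFT/2 + 1))
// matrix; for N_MELS = 80 and, say, N_FFT = 512 (the exact value comes from the constants
// at the top of this file) that is 80 x 257, which extractFeatures() multiplies transposed.
// The fmax expression below mirrors the one used later in main().
inline void exampleReportMelShape() {
    float fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
    Eigen::MatrixXf mel = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, fmax);
    std::cout << "Mel filterbank shape: [" << mel.rows() << ", " << mel.cols() << "]" << std::endl;
}
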
/**
 * @brief Extracts log filterbank features from waveform, matching Python's _extract_features.
 *
 * @param wav The input waveform (1D array of floats).
 * @param fs The sampling rate of the waveform (fixed to 16000 Hz).
 * @param mel_filterbank The pre-computed Mel filterbank matrix.
 * @return An Eigen::MatrixXf representing the log Mel filterbank features (frames x N_MELS).
 */
Eigen::MatrixXf extractFeatures(const std::vector<float>& wav, int fs, const Eigen::MatrixXf& mel_filterbank) {
    // Extract spectrogram
    Eigen::MatrixXf spec = extractSpectrogram(wav, fs);
    if (spec.rows() == 0) {
        return Eigen::MatrixXf(0, N_MELS); // Return empty matrix if spectrogram extraction failed
    }

    // spec_power = spec**2
    Eigen::MatrixXf spec_power = spec.array().square();

    // fbank_power = np.clip(spec_power.dot(_mel), 1.0, None)
    // Python's dot on 2D arrays is matrix multiplication: (frames, N_FFT/2+1) . (N_FFT/2+1, N_MELS).
    // speechlibMel() returns (N_MELS, N_FFT/2+1), so it is transposed here.
    Eigen::MatrixXf fbank_power = spec_power * mel_filterbank.transpose();

    // Apply clipping: np.clip(..., 1.0, None), i.e. any value below 1.0 becomes 1.0.
    fbank_power = fbank_power.array().max(1.0f);

    // log_fbank = np.log(fbank_power).astype(np.float32)
    Eigen::MatrixXf log_fbank = fbank_power.array().log();

    return log_fbank;
}

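// Illustrative sketch (my addition): end-to-end shape walk-through for extractFeatures().
// A waveform of n samples becomes a spectrogram of [frames, N_FFT/2 + 1], is squared,
// projected through the transposed Mel filterbank to [frames, N_MELS], clipped at 1.0 and
// log-compressed. The helper below only prints the resulting shape for a given waveform.
inline void exampleReportFeatureShape(const std::vector<float>& wav, const Eigen::MatrixXf& mel_filterbank) {
    Eigen::MatrixXf feats = extractFeatures(wav, TARGET_SAMPLE_RATE, mel_filterbank);
    std::cout << "Log-Mel feature shape: [" << feats.rows() << ", " << feats.cols() << "]" << std::endl;
}
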
// Function to write a dummy WAV file
void createDummyWavFile(const std::string& filename, int sampleRate, int numChannels, int bitsPerSample, double durationSeconds) {
    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Could not create dummy WAV file: " << filename << std::endl;
        return;
    }

    WavHeader header;
    std::memcpy(header.riff_id, "RIFF", 4);
    std::memcpy(header.wave_id, "WAVE", 4);
    std::memcpy(header.fmt_id, "fmt ", 4);
    header.fmt_size = 16;
    header.audio_format = 1; // PCM
    header.num_channels = numChannels;
    header.sample_rate = sampleRate;
    header.bits_per_sample = bitsPerSample;
    header.byte_rate = (sampleRate * numChannels * bitsPerSample) / 8;
    header.block_align = (numChannels * bitsPerSample) / 8;

    WavDataChunk data_chunk;
    std::memcpy(data_chunk.data_id, "data", 4);
    uint32_t num_samples = static_cast<uint32_t>(sampleRate * durationSeconds);
    data_chunk.data_size = num_samples * numChannels * (bitsPerSample / 8);
    header.file_size = 36 + data_chunk.data_size; // 36 bytes of header precede the data chunk

    file.write(reinterpret_cast<const char*>(&header), sizeof(WavHeader));
    file.write(reinterpret_cast<const char*>(&data_chunk), sizeof(WavDataChunk));

    // Generate a 440 Hz sine wave
    for (uint32_t i = 0; i < num_samples; ++i) {
        int16_t sample = static_cast<int16_t>(30000 * std::sin(2 * M_PI * 440 * i / static_cast<double>(sampleRate)));
        for (int c = 0; c < numChannels; ++c) {
            file.write(reinterpret_cast<const char*>(&sample), sizeof(int16_t));
        }
    }

    file.close();
    std::cout << "Dummy WAV file '" << filename << "' created successfully." << std::endl;
}

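// Illustrative sketch (my addition): createDummyWavFile() can also generate short test
// clips outside of main(); the file names here are hypothetical placeholders.
inline void exampleCreateTestClips() {
    createDummyWavFile("test_mono.wav", TARGET_SAMPLE_RATE, 1, 16, 1.0);   // 1 s mono
    createDummyWavFile("test_stereo.wav", TARGET_SAMPLE_RATE, 2, 16, 1.0); // 1 s stereo
}
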
int main(int argc, char* argv[]) {
    // --- 1. Process command-line arguments ---
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <path_to_onnx_model> <path_to_wav_file>" << std::endl;
        std::cerr << "Example: " << argv[0] << " model.onnx audio.wav" << std::endl;
        return 1;
    }

    std::string onnxModelPath = argv[1];
    std::string wavFilename = argv[2];

    // --- Configuration for Audio and ONNX Model ---
    // These are fixed by the Python preprocessor code and model requirements.
    // The actual sample rate will be read from the WAV file.
    int actual_wav_sample_rate = 0;

    // --- Create a dummy WAV file if it doesn't exist for demonstration ---
    std::ifstream wavCheck(wavFilename, std::ios::binary);
    if (!wavCheck.is_open()) {
        std::cerr << "WAV file '" << wavFilename << "' not found. Creating a dummy one for demonstration." << std::endl;
        // Create a 2-second, 16 kHz, mono, 16-bit WAV file
        createDummyWavFile(wavFilename, TARGET_SAMPLE_RATE, 1, 16, 2.0);
    } else {
        wavCheck.close();
    }

    // --- 2. Load WAV audio data into a float array ---
    std::vector<float> audioWav = loadWavToFloatArray(wavFilename, actual_wav_sample_rate);

    if (audioWav.empty()) {
        std::cerr << "Failed to load audio data from " << wavFilename << ". Exiting." << std::endl;
        return 1;
    }

    std::cout << "Successfully loaded " << audioWav.size() << " samples from " << wavFilename << std::endl;

    // --- Validate WAV sample rate against target sample rate ---
    if (actual_wav_sample_rate != TARGET_SAMPLE_RATE) {
        std::cerr << "Warning: WAV file sample rate (" << actual_wav_sample_rate
                  << " Hz) does not match the target sample rate for feature extraction ("
                  << TARGET_SAMPLE_RATE << " Hz)." << std::endl;
        std::cerr << "This example does NOT include resampling. Features will be extracted at "
                  << TARGET_SAMPLE_RATE << " Hz, which may lead to incorrect results for a mismatched WAV file." << std::endl;
        // In a real application, you would implement resampling here (e.g., using libsamplerate).
    }

    // --- 3. Precompute Mel filterbank (it is constant for a given sample rate / FFT size) ---
    // The Python example uses fmax = 16000 // 2 - 80 - 230, i.e. TARGET_SAMPLE_RATE / 2 - 80 - 230.
    // 0.0f is passed for fmin as the sentinel for None.
    float mel_fmax = static_cast<float>(TARGET_SAMPLE_RATE) / 2.0f - 80.0f - 230.0f;
    Eigen::MatrixXf mel_filterbank = speechlibMel(TARGET_SAMPLE_RATE, N_FFT, N_MELS, 0.0f, mel_fmax);

    if (mel_filterbank.rows() == 0 || mel_filterbank.cols() == 0) {
        std::cerr << "Error: Failed to create Mel filterbank. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Mel filterbank created with shape: [" << mel_filterbank.rows() << ", " << mel_filterbank.cols() << "]" << std::endl;

    // --- 4. Apply feature extraction (preprocessor) ---
    std::cout << "Extracting features from audio..." << std::endl;
    Eigen::MatrixXf features = extractFeatures(audioWav, TARGET_SAMPLE_RATE, mel_filterbank);

    ///// Optional debug dump of the extracted features (kept from development)
    // std::ofstream outputFile("matrix_output.txt");
    // // Check if the file was opened successfully
    // if (outputFile.is_open()) {
    //     // Iterate through rows and columns to write elements
    //     for (int i = 0; i < features.rows(); ++i) {
    //         for (int j = 0; j < features.cols(); ++j) {
    //             outputFile << features(i, j); // Write the element
    //             if (j < features.cols() - 1) {
    //                 outputFile << ","; // Comma separator between elements in a row
    //             }
    //         }
    //         outputFile << std::endl; // Move to the next line after each row
    //     }
    //     outputFile.close(); // Close the file
    //     std::cout << "Matrix successfully written to matrix_output.txt" << std::endl;
    // }

    if (features.rows() == 0 || features.cols() == 0) {
        std::cerr << "Error: Feature extraction resulted in an empty matrix. Exiting." << std::endl;
        return 1;
    }
    std::cout << "Features extracted with shape: [" << features.rows() << ", " << features.cols() << "]" << std::endl;
    std::cout << "First few feature values (first frame): [";
    for (int i = 0; i < std::min((int)features.cols(), 5); ++i) {
        std::cout << features(0, i) << (i == std::min((int)features.cols(), 5) - 1 ? "" : ", ");
    }
    std::cout << "]" << std::endl;

    // --- 5. Check for ONNX model existence and provide guidance if missing ---
    std::ifstream onnxModelCheck(onnxModelPath, std::ios::binary);
    if (!onnxModelCheck.is_open()) {
        std::cerr << "\nError: ONNX model file '" << onnxModelPath << "' not found." << std::endl;
        std::cerr << "Please provide a valid ONNX model file. If you need a simple dummy one for testing, "
                  << "you can create it using Python (e.g., with PyTorch) like this:" << std::endl;
        std::cerr << "```python" << std::endl;
        std::cerr << "import torch" << std::endl;
        std::cerr << "import torch.nn as nn" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "class SimpleAudioModel(nn.Module):" << std::endl;
        std::cerr << "    def __init__(self, input_frames, feature_size, output_size):" << std::endl;
        std::cerr << "        super(SimpleAudioModel, self).__init__()" << std::endl;
        std::cerr << "        # This model expects input of shape [batch_size, frames, feature_size]" << std::endl;
        std::cerr << "        # Example: a simple linear layer that flattens input and processes it." << std::endl;
        std::cerr << "        self.flatten = nn.Flatten()" << std::endl;
        std::cerr << "        self.linear = nn.Linear(input_frames * feature_size, output_size)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "    def forward(self, x):" << std::endl;
        std::cerr << "        x = self.flatten(x)" << std::endl;
        std::cerr << "        return self.linear(x)" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "# --- IMPORTANT: Define model input and output sizes. Adjust these to match your actual model's requirements. ---" << std::endl;
        std::cerr << "# The C++ preprocessor will produce features of shape [frames, 80]." << std::endl;
        std::cerr << "# For a dummy model, we need to provide a fixed 'frames' value for ONNX export." << std::endl;
        std::cerr << "# A typical audio segment might be 2 seconds at 16kHz, which is 32000 samples." << std::endl;
        std::cerr << "# Frames = (32000 - 400) / 160 + 1 = 197.5 + 1 -> about 198 frames" << std::endl;
        std::cerr << "# Let's use a representative number of frames, e.g., 200 for a dummy input." << std::endl;
        std::cerr << "DUMMY_INPUT_FRAMES = 200  # This should be representative of your typical audio segment's frames" << std::endl;
        std::cerr << "DUMMY_FEATURE_SIZE = 80   # Fixed by the Mel filterbank (N_MELS)" << std::endl;
        std::cerr << "DUMMY_OUTPUT_SIZE = 10    # Example: 10 classification scores or features" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "model = SimpleAudioModel(DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE, DUMMY_OUTPUT_SIZE)" << std::endl;
        std::cerr << "dummy_input_tensor = torch.randn(1, DUMMY_INPUT_FRAMES, DUMMY_FEATURE_SIZE)  # Batch size 1" << std::endl;
        std::cerr << "" << std::endl;
        std::cerr << "torch.onnx.export(" << std::endl;
        std::cerr << "    model," << std::endl;
        std::cerr << "    dummy_input_tensor," << std::endl;
        std::cerr << "    \"model.onnx\"," << std::endl;
        std::cerr << "    verbose=True," << std::endl;
        std::cerr << "    input_names=['input'],   # Name of the input tensor in the ONNX graph" << std::endl;
        std::cerr << "    output_names=['output'], # Name of the output tensor in the ONNX graph" << std::endl;
        std::cerr << "    # Define dynamic axes for batch_size and frames" << std::endl;
        std::cerr << "    dynamic_axes={'input': {0: 'batch_size', 1: 'frames'}, 'output': {0: 'batch_size'}}" << std::endl;
        std::cerr << ")" << std::endl;
        std::cerr << "print(\"Dummy model.onnx created successfully. Remember to adjust DUMMY_INPUT_FRAMES in this script to match the expected number of frames from your audio segments.\")" << std::endl;
        std::cerr << "```" << std::endl;
        return 1;
    }
    onnxModelCheck.close();
    std::cout << "ONNX model '" << onnxModelPath << "' found. Proceeding with inference." << std::endl;

    // --- 6. ONNX Runtime Inference ---
    try {
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "AudioInference");
        Ort::SessionOptions session_options;
        session_options.SetIntraOpNumThreads(1);
        session_options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);

        Ort::Session session(env, onnxModelPath.c_str(), session_options);
        Ort::AllocatorWithDefaultOptions allocator;

        // --- Get Input Node Information ---
        size_t numInputNodes = session.GetInputCount();
        std::vector<const char*> inputNodeNames(numInputNodes);

        std::cout << "\n--- Model Input Information ---" << std::endl;
        if (numInputNodes == 0) {
            std::cerr << "Error: Model has no input nodes. Exiting." << std::endl;
            return 1;
        }

        // Assuming a single input node for simplicity
        inputNodeNames[0] = "audio_embeds";
        Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        std::vector<int64_t> actualInputShape = tensor_info.GetShape();

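        // Optional sketch (not from the original code; assumes ONNX Runtime >= 1.13):
        // rather than hardcoding "audio_embeds" above, the input name can be queried
        // from the session itself, roughly like this:
        // Ort::AllocatedStringPtr input_name = session.GetInputNameAllocated(0, allocator);
        // std::cout << "Model reports input 0 name: " << input_name.get() << std::endl;
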
std::cout << " Input 0 : Name='" << inputNodeNames[0] << "', Shape=[";
|
| 575 |
+
for (size_t j = 0; j < actualInputShape.size(); ++j) {
|
| 576 |
+
// Print -1 for dynamic dimensions
|
| 577 |
+
if (actualInputShape[j] == -1) {
|
| 578 |
+
std::cout << "-1";
|
| 579 |
+
} else {
|
| 580 |
+
std::cout << actualInputShape[j];
|
| 581 |
+
}
|
| 582 |
+
std::cout << (j == actualInputShape.size() - 1 ? "" : ", ");
|
| 583 |
+
}
|
| 584 |
+
std::cout << "]" << std::endl;
|
| 585 |
+
|
| 586 |
+
// --- Prepare Input Tensor Shape ---
|
| 587 |
+
// The ONNX model input is [batch, frames, feature_size] = [-1, -1, 80]
|
| 588 |
+
// Our extracted features are [frames, 80]. We need to add a batch dimension of 1.
|
| 589 |
+
std::vector<int64_t> inputTensorShape = {1, features.rows(), features.cols()};
|
| 590 |
+
std::cout << " Preparing input tensor with shape: [" << inputTensorShape[0] << ", "
|
| 591 |
+
<< inputTensorShape[1] << ", " << inputTensorShape[2] << "]" << std::endl;
|
| 592 |
+
|
| 593 |
+
        // Flatten the Eigen::MatrixXf into a std::vector<float> for ONNX Runtime.
        // Eigen stores matrices in column-major order by default, while ONNX Runtime expects
        // row-major data for the flattened 2D matrix reshaped to 3D [1, frames, features],
        // so elements are copied row by row to ensure the correct order.
        std::vector<float> inputTensorData(features.rows() * features.cols());
        for (int r = 0; r < features.rows(); ++r) {
            for (int c = 0; c < features.cols(); ++c) {
                inputTensorData[r * features.cols() + c] = features(r, c);
            }
        }

        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memory_info, inputTensorData.data(), inputTensorData.size(),
                                                                 inputTensorShape.data(), inputTensorShape.size());

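        // Optional sketch (equivalent alternative, not used by the original code): the copy
        // loop above can be expressed with an Eigen::Map over a row-major view, which makes
        // the row-major flattening explicit:
        // Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        //     inputTensorData.data(), features.rows(), features.cols()) = features;
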
        if (!inputTensor.IsTensor()) {
            std::cerr << "Error: Created input tensor is not valid! Exiting." << std::endl;
            return 1;
        }

        // --- Get Output Node Information ---
        size_t numOutputNodes = session.GetOutputCount();
        std::vector<const char*> outputNodeNames(numOutputNodes);

        std::cout << "\n--- Model Output Information ---" << std::endl;
        for (size_t k = 0; k < numOutputNodes; ++k) {
            outputNodeNames[k] = "audio_features";
            Ort::TypeInfo type_info_out = session.GetOutputTypeInfo(k);
            auto tensor_info_out = type_info_out.GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = tensor_info_out.GetShape();
            std::cout << "  Output " << k << " : Name='" << outputNodeNames[k] << "', Shape=[";
            for (size_t l = 0; l < outputShape.size(); ++l) {
                if (outputShape[l] == -1) {
                    std::cout << "-1";
                } else {
                    std::cout << outputShape[l];
                }
                std::cout << (l == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        }

        // --- Run Inference ---
        std::cout << "\nRunning ONNX model inference..." << std::endl;
        std::vector<Ort::Value> outputTensors = session.Run(Ort::RunOptions{nullptr},
                                                            inputNodeNames.data(), &inputTensor, 1,
                                                            outputNodeNames.data(), numOutputNodes);

        // Optional debug dump of the raw output tensors (kept from development)
        // std::ofstream output_file("f0.txt");
        // for (auto& ort_value : outputTensors) {
        //     // Example: Assuming Ort::Value contains a float tensor
        //     if (ort_value.IsTensor()) {
        //         float* data = ort_value.GetTensorMutableData<float>();
        //         Ort::TensorTypeAndShapeInfo info = ort_value.GetTensorTypeAndShapeInfo();
        //         size_t num_elements = info.GetElementCount();
        //
        //         for (size_t i = 0; i < num_elements; ++i) {
        //             output_file << data[i];
        //             if (i < num_elements - 1) {
        //                 output_file << ","; // Comma separator between elements
        //             }
        //         }
        //         output_file << std::endl; // Newline after each Ort::Value's content
        //     } else {
        //         // Handle other Ort::Value types if necessary (e.g., sequences, maps)
        //         output_file << "Non-tensor Ort::Value" << std::endl;
        //     }
        // }
        // output_file.close();

        // --- Process Output ---
        if (outputTensors.empty()) {
            std::cerr << "Error: No output tensors received from the model." << std::endl;
            return 1;
        }

        if (outputTensors[0].IsTensor()) {
            float* outputData = outputTensors[0].GetTensorMutableData<float>();
            Ort::TensorTypeAndShapeInfo outputShapeInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
            std::vector<int64_t> outputShape = outputShapeInfo.GetShape();
            size_t outputSize = outputShapeInfo.GetElementCount();

            std::cout << "\n--- Model Inference Result (first few elements) ---" << std::endl;
            for (size_t k = 0; k < std::min((size_t)10, outputSize); ++k) {
                std::cout << outputData[k] << (k == std::min((size_t)10, outputSize) - 1 ? "" : ", ");
            }
            std::cout << std::endl;

            std::cout << "Full output tensor size: " << outputSize << " elements." << std::endl;
            std::cout << "Full output tensor shape: [";
            for (size_t k = 0; k < outputShape.size(); ++k) {
                std::cout << outputShape[k] << (k == outputShape.size() - 1 ? "" : ", ");
            }
            std::cout << "]" << std::endl;
        } else {
            std::cerr << "Error: First output tensor is not of the expected type (float tensor)." << std::endl;
        }

    } catch (const Ort::Exception& e) {
        std::cerr << "ONNX Runtime Exception: " << e.what() << std::endl;
        return 1;
    } catch (const std::exception& e) {
        std::cerr << "Standard Exception: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "\nProgram finished successfully." << std::endl;
    return 0;
}