# training_multiturn_textonly.py
import os

# Point the Hugging Face cache at the shared mount. HF_HOME must be set before importing
# `datasets`, otherwise the default cache location has already been resolved.
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'

import datasets

datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"

import argparse
import json
from pathlib import Path
import numpy as np
import torch
import sacrebleu
from datasets import load_dataset
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from transformers import (
AutoProcessor,
AutoModel,
BatchFeature,
Trainer,
TrainingArguments,
StoppingCriteria,
StoppingCriteriaList,
)
from collections import defaultdict
import soundfile as sf
from datasets import Audio
import random
from ASRDataset import *
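
# ASRDataset is a local module (not shown here); it is assumed to provide MultiturnAudioDataset,
# TWCostumData, and covost_collate_fn, which are used below.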

def count_parameters_by_module(model):
    # Per-module parameter counters.
    module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
    # Overall totals.
    total_params = 0
    total_trainable_params = 0
    # Gradient masks found on embedding weights (registered in create_model below).
    embedding_masks = {}
    for name, param in model.named_parameters():
        if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
            # Check whether this parameter carries the embedding_grad_mask_hook.
            for hook_id, hook_fn in param._backward_hooks.items():
                if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
                    # Recover the boolean mask tensor from the hook function's closure.
                    for cell in hook_fn.__closure__ or []:
                        if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
                            # Invert: in embedding_masks, True marks a trainable entry.
                            embedding_masks[name] = ~cell.cell_contents

    # Count parameters per module.
    for name, param in model.named_parameters():
        # Top-level module name (first component of the parameter path).
        module_name = name.split('.')[0]
        param_count = param.numel()
        module_params[module_name]["total"] += param_count
        total_params += param_count
        if param.requires_grad:
            # For masked embeddings, count only the entries that actually receive gradients.
            if name in embedding_masks:
                trainable_count = embedding_masks[name].sum().item()
                module_params[module_name]["trainable"] += trainable_count
                total_trainable_params += trainable_count
            else:
                module_params[module_name]["trainable"] += param_count
                total_trainable_params += param_count
print(f"All Params: {total_params:,}")
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
print("\nParams by Module:")
for module_name, counts in sorted(module_params.items()):
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
total_percentage = counts["total"] / total_params * 100
print(f"- {module_name}:")
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
return module_params

def create_model(model_name_or_path, revision="main", use_flash_attention=False):
model = AutoModel.from_pretrained(
model_name_or_path,
revision=revision,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
trust_remote_code=True,
)
    # Disable the KV cache after loading; it is not needed for training and is
    # incompatible with gradient checkpointing (enabled in TrainingArguments below).
model.config.use_cache = False
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
model.set_lora_adapter('speech')
# model.set_lora_adapter('text')
model.to(torch.bfloat16)
# (Optional) unfreeze audio_tower parameters
# for param in model.audio_tower.parameters():
# param.requires_grad = True
    # (Optional) unfreeze only the audio_projector parameters
# for param in model.audio_projector.parameters():
# param.requires_grad = True
    # Unfreeze the token embeddings, but only let the newly added speech-token rows train.
    train_embed = True
    if train_embed:
        embed_tokens = model.language_model.model.model.embed_tokens
        # IDs of the added speech tokens (only these rows should remain trainable).
        trainable_token_ids = [256001, 256002]
        embed_tokens.weight.requires_grad = True
        # mask is True for frozen rows and False for the trainable speech-token rows.
        mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
        mask[trainable_token_ids] = False

        # Backward hook that zeroes gradients everywhere except the speech-token rows.
        def embedding_grad_mask_hook(grad):
            return grad.masked_fill(mask, 0)

        embed_tokens.weight.register_hook(embedding_grad_mask_hook)
        model.language_model.model.model.embed_tokens = embed_tokens
count_parameters_by_module(model)
return model
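

# Optional sanity check (a minimal sketch, not called anywhere in this script): after a single
# backward pass, only the two speech-token rows of the embedding table should carry non-zero
# gradients if embedding_grad_mask_hook is working as intended.
def _check_embedding_grad_mask(model, trainable_token_ids=(256001, 256002)):
    embed_weight = model.language_model.model.model.embed_tokens.weight
    if embed_weight.grad is None:
        return  # no backward pass has been run yet
    nonzero_rows = embed_weight.grad.abs().sum(dim=-1).nonzero().flatten().tolist()
    assert set(nonzero_rows) <= set(trainable_token_ids), nonzero_rows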

ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100
model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
use_flash_attention = False
output_dir = '../gemma_tmp14_audio_and_text_speechlora'
batch_size = 16
batch_size_per_gpu = 1
learning_rate = 5.0e-5 # 1.0e-4 for fine-tuning
wd = 0.01
num_train_epochs = 10
revision = "main" #"v1.0"
processor = AutoProcessor.from_pretrained(
model_name_or_path,
revision=revision,
trust_remote_code=True,
)
model = create_model(
model_name_or_path,
revision=revision,
use_flash_attention=use_flash_attention,
)
train_datasets = []

# The same multiturn "pickup" data is added twice: once as text-only dialogues, once with audio.
pickup_text_dataset = MultiturnAudioDataset(processor=processor, text_only=True, json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
train_datasets.append(pickup_text_dataset)
pickup_audio_dataset = MultiturnAudioDataset(processor=processor, json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json')
train_datasets.append(pickup_audio_dataset)
# custom_tw_loc = TWCostumData(processor=processor,
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_loc) # 1500
# custom_tw_loc2 = TWCostumData(processor=processor,
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_loc2) # 9458
# custom_yating_tw_road = TWCostumData(processor=processor,
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
# train_datasets.append(custom_yating_tw_road) # 35224
# custom_tw_road = TWCostumData(processor=processor,
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_road) # 1500
# custom_tw_road2 = TWCostumData(processor=processor,
# csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_road2) # 35224
print("Count Num of Datasets", len(train_datasets))
print([len(dataset) for dataset in train_datasets])
# ConcatDataset
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
print("Count Length of Datas", len(train_dataset))
# Check GPUs
num_gpus = torch.cuda.device_count()
print(f'training on {num_gpus} GPUs')
assert (
    batch_size % (num_gpus * batch_size_per_gpu) == 0
), 'batch_size must be divisible by num_gpus * batch_size_per_gpu'
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
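# Example: with batch_size=16 and batch_size_per_gpu=1 on a single GPU,
# gradient_accumulation_steps = 16 // (1 * 1) = 16, i.e. 16 micro-batches per optimizer step.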
# DeepSpeed config. NOTE: it is not used in this run, since `deepspeed=None` is passed to
# TrainingArguments below; it is kept here for reference.
dp_config = {
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": False,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"cpu_offload": True
},
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": 'auto',
"eps": 'auto',
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 0
}
}
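
# To actually enable DeepSpeed (not done here), the dict above could be passed directly via
# TrainingArguments(deepspeed=dp_config) or written to a JSON file and referenced by path;
# the "auto" entries would then be resolved from the TrainingArguments values.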
training_args = TrainingArguments(
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size_per_gpu,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={'use_reentrant': False},
gradient_accumulation_steps=gradient_accumulation_steps,
optim='adamw_torch',
adam_beta1=0.9,
adam_beta2=0.95,
adam_epsilon=1e-7,
learning_rate=learning_rate,
weight_decay=wd,
max_grad_norm=1.0,
lr_scheduler_type='cosine',
warmup_steps=50,
logging_steps=10,
output_dir=output_dir,
save_total_limit=10,
save_only_model=True,
bf16=True,
fp16=False,
remove_unused_columns=False,
report_to='none',
deepspeed=None,
disable_tqdm=False,
dataloader_num_workers=16,
save_strategy='epoch',
# save_steps=2500,
ddp_find_unused_parameters=True,
)
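
# Note: bf16=True matches the bfloat16 weights loaded in create_model, and save_only_model=True
# stores only model weights at each epoch checkpoint (no optimizer/scheduler state, so training
# cannot be resumed exactly from these checkpoints).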
out_path = Path(training_args.output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# create optimizer only for trainable params
optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=wd,
betas=(0.9, 0.95),
eps=1e-7,
)
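
# Passing (optimizer, None) to the Trainer below means the Trainer uses this optimizer but still
# builds its own LR scheduler from lr_scheduler_type='cosine' and warmup_steps above.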
# Trainer setup
trainer = Trainer(
model=model,
args=training_args,
data_collator=covost_collate_fn,
train_dataset=train_dataset,
optimizers=(optimizer, None)
)
trainer.train()

# 1. Save the LoRA adapter.
model.language_model.model.save_pretrained(output_dir)
# 1-1. (Optional) delete the README.md created in the output directory.
# markdown_file = os.path.join(output_dir, "README.md")
# if os.path.exists(markdown_file):
#     os.remove(markdown_file)

# 2. Save the entire model.
model.save_pretrained(output_dir)
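
# Reloading the result for inference (a minimal sketch; the from_pretrained kwargs simply mirror
# the training setup above, and the variable names are illustrative):
# reloaded_model = AutoModel.from_pretrained(output_dir, torch_dtype=torch.bfloat16, trust_remote_code=True)
# reloaded_processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)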