import os

# HF_HOME must be set before importing the Hugging Face libraries so the cache
# location is picked up at import time.
os.environ['HF_HOME'] = '/mnt/jeff/huggingface'

import datasets

datasets.config.DOWNLOADED_DATASETS_PATH = "/mnt/jeff/huggingface/data"

import argparse
import json
import random
from collections import defaultdict
from pathlib import Path

import numpy as np
import sacrebleu
import soundfile as sf
import torch
from datasets import Audio, load_dataset
from torch.utils.data import ConcatDataset, Dataset
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    AutoModel,
    BatchFeature,
    Trainer,
    TrainingArguments,
    StoppingCriteria,
    StoppingCriteriaList,
)

from ASRDataset import *


def count_parameters_by_module(model):
    # parameter counts per top-level module
    module_params = defaultdict(lambda: {"total": 0, "trainable": 0})

    # global counters
    total_params = 0
    total_trainable_params = 0

    # Collect the boolean masks registered on embedding tokens by
    # embedding_grad_mask_hook (see create_model below).
    embedding_masks = {}
    for name, param in model.named_parameters():
        if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
            # check whether the param carries an embedding_grad_mask_hook
            for hook_id, hook_fn in param._backward_hooks.items():
                if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
                    # read the mask tensor captured in the hook's closure
                    for cell in hook_fn.__closure__ or []:
                        if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
                            embedding_masks[name] = ~cell.cell_contents  # True = trainable row

    # Count parameters per module
    for name, param in model.named_parameters():
        # top-level module name
        module_name = name.split('.')[0]
        param_count = param.numel()

        module_params[module_name]["total"] += param_count
        total_params += param_count

        if param.requires_grad:
            # Only count the rows that are actually trainable (mask-aware)
            if name in embedding_masks:
                trainable_count = embedding_masks[name].sum().item()
                module_params[module_name]["trainable"] += trainable_count
                total_trainable_params += trainable_count
            else:
                module_params[module_name]["trainable"] += param_count
                total_trainable_params += param_count

    print(f"All Params: {total_params:,}")
    print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")

    print("\nParams by Module:")
    for module_name, counts in sorted(module_params.items()):
        trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
        total_percentage = counts["total"] / total_params * 100
        print(f"- {module_name}:")
        print(f"  Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
        print(f"  Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")

    return module_params


def create_model(model_name_or_path, revision="main", use_flash_attention=False):
    model = AutoModel.from_pretrained(
        model_name_or_path,
        revision=revision,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2" if use_flash_attention else "eager",
        trust_remote_code=True,
    )

    # Disable the KV cache after the model is loaded (not needed for training)
    model.config.use_cache = False

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    model.set_lora_adapter('speech')
    # model.set_lora_adapter('text')
    model.to(torch.bfloat16)

    # (Optional) unfreeze audio_tower parameters
    # for param in model.audio_tower.parameters():
    #     param.requires_grad = True

    # (Optional) unfreeze audio_projector parameters only
    # for param in model.audio_projector.parameters():
    #     param.requires_grad = True

    # (Optional) unfreeze the embeddings of the added speech tokens
    train_embed = True
    if train_embed:
        embed_tokens = model.language_model.model.model.embed_tokens

        # Added speech token IDs (only these rows stay trainable)
        trainable_token_ids = [256001, 256002]
        embed_tokens.weight.requires_grad = True
        mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
        mask[trainable_token_ids] = False  # False = trainable (unfrozen), True = frozen

        # backward hook that zeroes gradients on the frozen rows
        def embedding_grad_mask_hook(grad):
            return grad.masked_fill(mask, 0)

        embed_tokens.weight.register_hook(embedding_grad_mask_hook)
        model.language_model.model.model.embed_tokens = embed_tokens

    count_parameters_by_module(model)
    return model
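
# Optional sanity check (sketch, not called during training): verifies that the
# gradient-mask hook registered in create_model() zeroes every embedding row
# except the added speech tokens. The module path and token IDs mirror the ones
# used above; adjust them if the model layout differs.
def check_embedding_grad_mask(model, token_ids=(256001, 256002)):
    embed = model.language_model.model.model.embed_tokens
    loss = embed.weight.sum()  # gradient is all-ones before the hook runs
    loss.backward()
    grad = embed.weight.grad
    frozen = torch.ones(grad.shape[0], dtype=torch.bool, device=grad.device)
    frozen[list(token_ids)] = False
    assert torch.all(grad[frozen] == 0), "frozen embedding rows received gradient"
    assert torch.all(grad[list(token_ids)] != 0), "speech-token rows received no gradient"
    embed.weight.grad = None  # reset so training starts with clean gradients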

ANSWER_SUFFIX = ""
_IGNORE_INDEX = -100

model_name_or_path = '/mnt/jeff/gemma-3-4b-it-omni'
use_flash_attention = False
output_dir = '../gemma_tmp14_audio_and_text_speechlora'
batch_size = 16
batch_size_per_gpu = 1
learning_rate = 5.0e-5  # 1.0e-4 for fine-tuning
wd = 0.01
num_train_epochs = 10
revision = "main"  # "v1.0"

processor = AutoProcessor.from_pretrained(
    model_name_or_path,
    revision=revision,
    trust_remote_code=True,
)
model = create_model(
    model_name_or_path,
    revision=revision,
    use_flash_attention=use_flash_attention,
)

train_datasets = []
pickup_dataset = MultiturnAudioDataset(
    processor=processor,
    text_only=True,
    json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
)
train_datasets.append(pickup_dataset)
pickup_dataset = MultiturnAudioDataset(
    processor=processor,
    json_path='/mnt/jeff/InCar/data/multiturn_data/pickup_processed.json',
)
train_datasets.append(pickup_dataset)
# custom_tw_loc = TWCostumData(processor=processor,
#     csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250509-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_loc)  # 1500
# custom_tw_loc2 = TWCostumData(processor=processor,
#     csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_location-srdc_tts-20250529-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_loc2)  # 9458
# custom_yating_tw_road = TWCostumData(processor=processor,
#     csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250430-yating-1-2s-breezyvoice.csv')
# train_datasets.append(custom_yating_tw_road)  # 35224
# custom_tw_road = TWCostumData(processor=processor,
#     csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250509-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_road)  # 1500
# custom_tw_road2 = TWCostumData(processor=processor,
#     csv_path='/mnt/jeff/InCar/data/tw_data/taiwan_road-srdc_tts-20250529-common_voice_16_1-TW.csv')
# train_datasets.append(custom_tw_road2)  # 35224

print("Count Num of Datasets", len(train_datasets))
print([len(dataset) for dataset in train_datasets])

# ConcatDataset
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
print("Count Length of Datas", len(train_dataset))

# Check GPUs
num_gpus = torch.cuda.device_count()
print(f'training on {num_gpus} GPUs')
assert (
    batch_size % (num_gpus * batch_size_per_gpu) == 0
), 'batch_size must be divisible by num_gpus * batch_size_per_gpu'
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
# e.g. with 2 GPUs, batch_size=16 and batch_size_per_gpu=1 -> 16 // (2 * 1) = 8

# hard-coded DeepSpeed config (only used if passed to TrainingArguments via `deepspeed=`)
dp_config = {
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "overlap_comm": False,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": True,
        "cpu_offload": True
    },
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "gradient_clipping": 1.0,
    # Duplicate key: in a dict literal the last value wins, so this entry
    # overrides the stage-2 block defined above.
    "zero_optimization": {
        "stage": 0
    }
}

training_args = TrainingArguments(
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size_per_gpu,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim='adamw_torch',
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-7,
    learning_rate=learning_rate,
    weight_decay=wd,
    max_grad_norm=1.0,
    lr_scheduler_type='cosine',
    warmup_steps=50,
    logging_steps=10,
    output_dir=output_dir,
    save_total_limit=10,
    save_only_model=True,
    bf16=True,
    fp16=False,
    remove_unused_columns=False,
    report_to='none',
    deepspeed=None,
    disable_tqdm=False,
    dataloader_num_workers=16,
    save_strategy='epoch',
    # save_steps=2500,
    ddp_find_unused_parameters=True,
)

out_path = Path(training_args.output_dir)
out_path.mkdir(parents=True, exist_ok=True)

# create an optimizer over the trainable parameters only
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=learning_rate,
    weight_decay=wd,
    betas=(0.9, 0.95),
    eps=1e-7,
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=covost_collate_fn,
    train_dataset=train_dataset,
    optimizers=(optimizer, None),
)

trainer.train()

# 1. Save the LoRA adapter
model.language_model.model.save_pretrained(output_dir)

# 1-1. Delete the auto-generated Markdown file
# markdown_file = os.path.join(output_dir, "README.md")
# if os.path.exists(markdown_file):
#     os.remove(markdown_file)

# 2. Save the entire model
model.save_pretrained(output_dir)