import torch import gc from datasets import load_dataset from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig #from transformers.trainer_utils import get_parameter_names, ALL_LAYERNORM_LAYERS from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, logger, is_accelerate_available, is_datasets_available from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from trl import SFTConfig, SFTTrainer from huggingface_hub import notebook_login from loader import jsonl_to_dataset_hf_parallel import PIL import torch.distributed as dist from torch.utils.data import Dataset from torch import nn import logging logger = logging.getLogger(__name__) class LazyVisualDataset(Dataset): def __init__(self, data, processor): self.data = data self.processor = processor def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] # Load image on-demand if isinstance(sample['image'], str): image = PIL.Image.open(sample['image']).convert("RGB") else: image = sample['image'].convert("RGB") return { "images": [image], "messages": [ {"role": "system", "content": [{"type": "text", "text": system_message}]}, {"role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": sample["query"] if "query" in sample else sample['question']}]}, {"role": "assistant", "content": [{"type": "text", "text": sample["label"][0] if "label" in sample else sample['gt_answer']}]}, ], } class CustomLFM2VLTrainer(SFTTrainer): """Custom trainer with different learning rates for vision and language components.""" def create_optimizer(self): """ Setup the optimizer with different learning rates for different model components. - multi_modal_projector: 1/10 of base learning rate - vision_tower: 1/10 of base learning rate - language_model: base learning rate """ if self.optimizer is None: opt_model = self.model # Get parameters that should have weight decay (exclude bias and layer norm) decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] # Set up learning rate mapper for different components base_lr = self.args.learning_rate lr_mapper = { "multi_modal_projector": base_lr / 10.0, # 1/10 of base learning rate "vision_tower": base_lr / 10.0, # 1/10 of base learning rate # language_model will use the base learning rate (not specified in mapper) } print(f"Base learning rate: {base_lr}") print(f"Multi-modal projector learning rate: {lr_mapper['multi_modal_projector']}") print(f"Vision tower learning rate: {lr_mapper['vision_tower']}") print(f"Language model learning rate: {base_lr}") # Find parameters that belong to special learning rate modules special_lr_parameters = [] for name, _ in opt_model.named_parameters(): for module_keyword in lr_mapper: if module_keyword in name: special_lr_parameters.append(name) break # Create parameter groups optimizer_grouped_parameters = [] # Regular parameters (language model) with weight decay optimizer_grouped_parameters.append({ "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)], "weight_decay": self.args.weight_decay, "lr": base_lr, }) # Regular parameters (language model) without weight decay optimizer_grouped_parameters.append({ "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)], "weight_decay": 0.0, "lr": base_lr, }) # Add special learning rate parameter groups for module_keyword, lr in lr_mapper.items(): module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name] # Parameters with weight decay optimizer_grouped_parameters.append({ "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)], "weight_decay": self.args.weight_decay, "lr": lr, }) # Parameters without weight decay optimizer_grouped_parameters.append({ "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)], "weight_decay": 0.0, "lr": lr, }) # Print parameter group info for i, group in enumerate(optimizer_grouped_parameters): param_count = sum(p.numel() for p in group["params"]) print(f"Parameter group {i}: {param_count} parameters, " f"lr={group['lr']}, weight_decay={group['weight_decay']}") # Get optimizer class and kwargs optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args) # Create optimizer self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) # Handle 8bit optimizer if needed if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") return self.optimizer def format_data(sample, system_message=None): """Formats the dataset into a structured format for the model.""" if not isinstance(sample['image'], PIL.Image.Image): sample['image'] = PIL.Image.open(sample['image']) if system_message is None: return { "images": [sample['image'].convert("RGB")], "messages": [ {"role": "user", "content": [ {"type": "image", "image": sample["image"].convert("RGB")}, {"type": "text", "text": sample["query"] if "query" in sample else sample['question']}]}, {"role": "assistant", "content": [{"type": "text", "text": sample["label"][0] if "label" in sample else sample['gt_answer']}]}, ], } return { "images": [sample['image'].convert("RGB")], "messages": [ {"role": "system", "content": [{"type": "text", "text": system_message}]}, {"role": "user", "content": [ {"type": "image", "image": sample["image"].convert("RGB")}, {"type": "text", "text": sample["query"] if "query" in sample else sample['question']}]}, {"role": "assistant", "content": [{"type": "text", "text": sample["label"][0] if "label" in sample else sample['gt_answer']}]}, ], } # Step 2: Load and format the dataset dataset_id = "HuggingFaceM4/ChartQA" train_dataset = jsonl_to_dataset_hf_parallel( jsonl_file="/lustre/scratch/client/movian/research/users/chitb/mmawe/VLM_data/FINETUNE_JSONL/QUALCOMM/finetuning/categories/output_easy_hard/hard_ins.jsonl", image_root="/lustre/scratch/client/movian/research/users/chitb/mmawe/VLM_data/VAI_VLM_IMAGES" ) eval_dataset = jsonl_to_dataset_hf_parallel( jsonl_file="/lustre/scratch/client/movian/research/users/chitb/mmawe/VLM_data/FINETUNE_JSONL/QUALCOMM/pretraining/iaocr_short.jsonl", image_root="/lustre/scratch/client/movian/research/users/chitb/mmawe/VLM_data/VAI_VLM_IMAGES" ) model_id = "LiquidAI/LFM2-VL-450M" # Load original model for initial inference system_message = """You are a Vision Language Model specialized in interpreting visual data from images and follow closely to the instructions.""" processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) train_dataset = LazyVisualDataset(train_dataset, processor) eval_dataset = LazyVisualDataset(eval_dataset, processor) # Define a helper function for inference def generate_answer(model, processor, sample_messages): """Generates an answer from a model given a list of messages.""" inputs = processor.apply_chat_template( sample_messages, tokenize=True, return_tensors='pt', return_dict=True, add_generation_prompt=True ) inputs = {k: v.to(model.device) for k, v in inputs.items()} generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=256) generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)] output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return output_text # --- FINE-TUNING PROCESS --- print("--- Starting Fine-Tuning Process ---") # Step 3: Load the model for training model = AutoModelForImageTextToText.from_pretrained( model_id, torch_dtype=torch.bfloat16, trust_remote_code=True ).cuda() # Step 5: Configure and run the training output_model_name = "lfm2vl_custom_no_lora_hard_ins" training_args = SFTConfig( output_dir=output_model_name, num_train_epochs=1, deepspeed="./zero1.json", per_device_train_batch_size=2, per_device_eval_batch_size=4, gradient_accumulation_steps=1, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, max_length=8192, optim="adamw_torch", learning_rate=5e-6, # Base learning rate for language model save_steps=500, logging_steps=1, save_total_limit=6, save_strategy="steps", bf16=True, warmup_ratio=0.01, report_to='none' ) # Use the custom trainer instead of SFTTrainer trainer = CustomLFM2VLTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processor, ) trainer.train() trainer.save_model(output_model_name) print("\n--- Fine-tuning complete and model saved. ---\n") # --- POST-TUNING INFERENCE --- print("--- Starting Inference with Fine-Tuned Model ---") # Clear memory del trainer, model gc.collect() torch.cuda.empty_cache() if dist.is_initialized(): if dist.get_rank() == 0: # Original model inference original_model = AutoModelForImageTextToText.from_pretrained( model_id, trust_remote_code=True, torch_dtype=torch.bfloat16 ).cuda() processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) eval_sample = eval_dataset[5] prompt_messages = eval_sample['messages'][:2] if len(eval_sample['messages']) == 3 else eval_sample['messages'][:1] prompt_messages[1]['content'][0]['image'].save("./eval_sample.jpg") print(prompt_messages) original_model_output = generate_answer(original_model, processor, prompt_messages) del original_model, processor gc.collect() torch.cuda.empty_cache() print("\n--- Original Model Inference Complete ---\n") if dist.is_initialized(): if dist.get_rank() == 0: eval_sample = eval_dataset[5] prompt_messages = eval_sample['messages'][:2] if len(eval_sample['messages']) == 3 else eval_sample['messages'][:1] processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) model = AutoModelForImageTextToText.from_pretrained( output_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True ).cuda() finetuned_model_output = generate_answer(model, processor, prompt_messages) print("\n--- Fine-Tuned Model Inference Complete ---\n") # --- FINAL COMPARISON --- print("="*50) print(" MODEL PERFORMANCE COMPARISON") print("="*50) print(f"Query: {eval_sample['messages'][1]['content'][1]['text']}\n") print(f"Ground Truth: {eval_sample['messages'][2]['content'][0]['text']}") print(f"----------------------Original Model: {original_model_output}") print(f"----------------------Fine-Tuned Model: {finetuned_model_output}") print("="*50)