# Standard library imports
import logging
import logging.config
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional

# Third-party imports
import torch
import torch.amp
import transformers
from tqdm import tqdm
from transformers import HfArgumentParser, set_seed
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

# Local imports
from lpm_kernel.L2.utils import (
    create_and_prepare_model,
    formatting_prompts_func,
    create_chat_data,
    release_ollama_models_early,
)
from lpm_kernel.L2.memory_manager import get_memory_manager
from lpm_kernel.configs.logging import LOGGING_CONFIG, get_train_process_logger

logger = get_train_process_logger()

# Configure how tqdm displays in logs
class LogTqdm(tqdm):
    def __init__(self, *args, **kwargs):
        # Throttle refreshes and force ASCII bars so progress lines render
        # cleanly when captured in a log file instead of a TTY
        kwargs.setdefault("mininterval", 1.0)
        kwargs.setdefault("ascii", True)
        super().__init__(*args, **kwargs)


# Replace the default tqdm
sys.modules["tqdm"].tqdm = LogTqdm
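# Note: rebinding sys.modules["tqdm"].tqdm only affects code that looks tqdm up
# *after* this point (e.g. a later `from tqdm import tqdm` or `tqdm.tqdm(...)`
# attribute access). Modules that bound the original class before this line
# keep it, so the patch must run before any module that captures tqdm at import time.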


# Debug callback for logging training progress
class DebugCallback(transformers.TrainerCallback):
    def __init__(self):
        self.total_time = 0
        self.last_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:
            current_time = time.time()
            # Elapsed wall time since the last logged step (i.e. the last 10 steps)
            step_time = current_time - self.last_time
            self.total_time += step_time
            self.last_time = current_time
            logger.info(
                f"Step {state.global_step}: {step_time:.2f}s - "
                f"Total training time: {self.total_time:.2f}s"
            )

    def on_epoch_end(self, args, state, control, **kwargs):
        logger.info(f"Epoch {state.epoch} completed")


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={
            "help": "Comma-separated list of target modules to apply LoRA layers to"
        },
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type, fp4 or nf4"},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient checkpointing parameter. Refer to the related docs."},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Unsloth for training."},
    )
    use_cuda: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables CUDA GPU acceleration for training and inference when available."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, appends `eos_token_id` at the end of each sample being packed."
        },
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, the tokenizer adds special tokens to each sample being packed."
        },
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )
    is_sequential: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is sequential."},
    )
    is_cot: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is a chain-of-thought (CoT) dataset."},
    )
    user_name: Optional[str] = field(
        default="User",
        metadata={"help": "The name of the user."},
    )


def main(model_args, data_args, training_args):
    logger.info(f"Python version: {sys.version}")

    # Configure logging
    logging.config.dictConfig(LOGGING_CONFIG)
    logger.info("Begin training...")
    # Ensure logs are flushed immediately
    for handler in logging.getLogger().handlers:
        handler.flush()

    # Get memory manager for optimization
    memory_manager = get_memory_manager()
    memory_manager.cleanup_memory(force=True)

    # Release Ollama models if they exist to free up VRAM
    if torch.cuda.is_available() and model_args.use_cuda:
        release_ollama_models_early()

    logger.info("Initializing training with memory optimizations")
    set_seed(training_args.seed)

    # Apply PyTorch memory optimizations to training arguments
    logger.info("Applying memory optimizations to training configuration")
    training_args = memory_manager.optimize_training_args(training_args)

    # --- Accelerate optimizer-state offloading logic ---
    # Offload optimizer state to CPU when VRAM is low and DeepSpeed is not in use
    vram_total = memory_manager.get_memory_info().get("vram_total_gb", 0)
    if torch.cuda.is_available() and model_args.use_cuda and 0 < vram_total < 16:
        # Only set if not already using DeepSpeed
        if getattr(training_args, "deepspeed", None) is None:
            logger.info(
                "Enabling Hugging Face Accelerate optimizer state offload to CPU for low-VRAM GPUs"
            )
            # Advisory config: TrainingArguments has no `accelerate_config`
            # field, so this attribute is read by downstream tooling (if any)
            # rather than by the Trainer itself.
            training_args.accelerate_config = {
                "compute_environment": "LOCAL_MACHINE",
                "deepspeed_config": None,
                "distributed_type": "NO",
                "downcast_bf16": False,
                "fsdp_config": {},
                "main_training_function": "main",
                "mixed_precision": "no",
                "num_machines": 1,
                "num_processes": 1,
                "use_cpu": False,
                "zero3_init_flag": False,
                "offload_optimizer_device": "cpu",
                "offload_param_device": "none",
            }
    # Model loading with device_map-based automatic offloading
    logger.info(
        f"Loading model with automatic memory management from {model_args.model_name_or_path}"
    )

    # Base model kwargs; the quantized paths below switch device_map to "auto"
    model_kwargs = {
        # Don't use "auto" device_map initially, to avoid meta-tensor issues
        "device_map": None,
        "trust_remote_code": True,
    }

    # Configure quantization if requested
    if model_args.use_4bit_quantization:
        from transformers import BitsAndBytesConfig

        compute_dtype = getattr(torch, model_args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, model_args.bnb_4bit_quant_storage_dtype)
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=model_args.use_4bit_quantization,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
        # For 4-bit models, device_map="auto" is safe
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit quantization for memory efficiency")
    elif model_args.use_8bit_quantization:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=model_args.use_8bit_quantization
        )
        # For 8-bit models, device_map="auto" is safe
        model_kwargs["device_map"] = "auto"
        logger.info("Using 8-bit quantization for memory efficiency")

    # Flash Attention for memory efficiency when supported
    if model_args.use_flash_attn and torch.cuda.is_available() and model_args.use_cuda:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2 for memory efficiency")

    # Load model with built-in memory management features
    model, peft_config, tokenizer = create_and_prepare_model(
        model_args, data_args, training_args, model_kwargs=model_kwargs
    )
    # If the model still holds meta tensors, materialize them on a real device
    if hasattr(model, "is_meta") and model.is_meta:
        logger.info("Model has meta tensors, using to_empty() to properly initialize")
        device = "cuda" if torch.cuda.is_available() and model_args.use_cuda else "cpu"
        model = model.to_empty(device=device)

    # Apply gradient checkpointing for memory efficiency
    if training_args.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        logger.info("Enabling gradient checkpointing for memory efficiency")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False

    # Cap per-process VRAM allocation on small GPUs to reduce OOM risk
    if torch.cuda.is_available() and memory_manager.get_memory_info().get("vram_total_gb", 0) < 8:
        torch.cuda.set_per_process_memory_fraction(0.9)
        logger.info("Setting memory fraction limit to avoid OOM errors")
    # Dataset and collator
    train_dataset = create_chat_data(
        data_args,
        tokenizer,
    )
    # ChatML-style marker for the start of the assistant turn; the collator
    # masks everything before it so loss is computed only on completions
    response_template = "\n<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }
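    # `dataset_kwargs` is an SFTConfig field, so SFTTrainer reads these flags
    # during its internal dataset preparation. Illustrative (not executed) view
    # of the collator's effect, assuming ChatML-formatted text:
    #     out = collator([tokenizer("<|im_start|>user\nHi<|im_end|>\n"
    #                               "<|im_start|>assistant\nHello!<|im_end|>")])
    #     # out["labels"] is -100 up to and including the assistant marker,
    #     # so the loss covers only the completion tokens.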
    # Use DeepSpeed to handle meta tensors if available
    try:
        # Only configure DeepSpeed if meta tensors are present and DeepSpeed is available
        if hasattr(model, "is_meta") and model.is_meta:
            logger.info("Model has meta tensors, checking DeepSpeed availability")
            # First verify DeepSpeed is properly installed and importable
            try:
                import deepspeed  # noqa: F401

                logger.info("DeepSpeed is available, configuring for meta tensor handling")
                # ZeRO-3 with CPU offload, expressed in the DeepSpeed config
                # JSON schema that the HF Trainer expects
                training_args.deepspeed = {
                    "zero_optimization": {
                        "stage": 3,
                        "offload_optimizer": {"device": "cpu"},
                        "offload_param": {"device": "cpu"},
                    },
                    "train_micro_batch_size_per_gpu": "auto",
                    "gradient_accumulation_steps": "auto",
                }
                logger.info("DeepSpeed configured for meta tensor handling")
            except ImportError:
                logger.warning("DeepSpeed is not available, meta tensors will be handled differently")
                # Without DeepSpeed, materialize the meta tensors directly;
                # to_empty() allocates uninitialized storage on the target device
                if torch.cuda.is_available() and model_args.use_cuda:
                    logger.info("Initializing meta tensors on GPU")
                    model = model.to_empty(device="cuda")
                else:
                    logger.info("Initializing meta tensors on CPU")
                    model = model.to_empty(device="cpu")
    except Exception as e:
        logger.warning(f"Could not configure meta tensor handling: {e}")
        logger.warning(traceback.format_exc())
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )

    # Print model details
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()
    # Memory usage tracking callback
    class MemoryMonitorCallback(transformers.TrainerCallback):
        def __init__(self):
            self.memory_manager = get_memory_manager()

        def on_step_end(self, args, state, control, **kwargs):
            # Check memory every 5 steps
            if state.global_step % 5 == 0 and torch.cuda.is_available():
                info = self.memory_manager.get_memory_info()
                total_gb = info.get("vram_total_gb") or 1  # avoid division by zero
                vram_usage_pct = info.get("vram_used_gb", 0) / total_gb * 100
                if vram_usage_pct > 90:
                    logger.info(f"VRAM usage high ({vram_usage_pct:.1f}%), cleaning cache")
                    self.memory_manager.cleanup_memory()

        def on_save(self, args, state, control, **kwargs):
            # Free up memory before saving
            self.memory_manager.cleanup_memory(force=True)

    # Add memory monitoring
    trainer.add_callback(MemoryMonitorCallback())
    # Add existing debug callback
    trainer.add_callback(DebugCallback())
    # Resume from checkpoint if specified
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

    # Training with automatic memory management
    try:
        logger.info("Starting training with memory-optimized configuration")
        trainer.train(resume_from_checkpoint=checkpoint)
    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise

    # Save the model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    # Clean up before saving
    memory_manager.cleanup_memory(force=True)
    trainer.save_model()
    logger.info("Training completed successfully")


# Compatibility patch for CPU autocast: torch.cpu.amp.autocast is deprecated in
# newer PyTorch in favor of torch.amp.autocast("cpu", ...). Route the old entry
# point to whichever API this PyTorch build provides.
def get_autocast():
    if hasattr(torch.cpu, "amp") and hasattr(torch.cpu.amp, "autocast"):
        # Older PyTorch: dedicated CPU autocast context manager
        return torch.cpu.amp.autocast
    # Newer PyTorch: unified autocast with an explicit device type
    return lambda **kwargs: torch.amp.autocast("cpu", **kwargs)


# Replace the original torch.cpu.amp.autocast with the compatible function,
# guarding against builds where the torch.cpu.amp namespace is absent
if hasattr(torch.cpu, "amp"):
    torch.cpu.amp.autocast = get_autocast()
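# Illustrative call site (hypothetical) showing that either spelling now resolves:
#     with torch.cpu.amp.autocast(dtype=torch.bfloat16):
#         loss = model(**batch).loss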

if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If the only argument is a path to a JSON file, parse it to get our arguments
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
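# Example invocations (illustrative; the script filename and model name are
# placeholders, and flags map to the dataclass fields above plus SFTConfig):
#     python train_sft.py config.json
#     python train_sft.py \
#         --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
#         --output_dir ./output \
#         --use_peft_lora True --use_4bit_quantization True --use_cuda True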