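"""Memory-optimized supervised fine-tuning (SFT) entry point.

Loads a causal language model (optionally 4-/8-bit quantized via
bitsandbytes, optionally with LoRA adapters through PEFT), prepares a
chat-formatted dataset, and trains it with TRL's SFTTrainer under
VRAM-aware memory management.
"""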
# Standard library imports
import logging
import logging.config
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional

# Third-party imports
import torch
import torch.amp
import transformers
from tqdm import tqdm
from transformers import HfArgumentParser, set_seed
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

# Local imports
from lpm_kernel.L2.memory_manager import get_memory_manager
from lpm_kernel.L2.utils import (
    create_and_prepare_model,
    formatting_prompts_func,
    create_chat_data,
    release_ollama_models_early,
)
from lpm_kernel.configs.logging import LOGGING_CONFIG, get_train_process_logger

logger = get_train_process_logger()


# Configure how tqdm displays in logs
class LogTqdm(tqdm):
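    """tqdm subclass with log-friendly defaults: ASCII bars and at most one refresh per second."""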
    def __init__(self, *args, **kwargs):
        kwargs.setdefault("mininterval", 1.0)
        kwargs.setdefault("ascii", True)
        super().__init__(*args, **kwargs)

# Replace the default tqdm class so all progress bars pick up the log-friendly settings
sys.modules["tqdm"].tqdm = LogTqdm

# Debug callback for logging training progress
class DebugCallback(transformers.TrainerCallback):
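    """Logs step timing at a fixed interval and reports epoch completion."""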
    def __init__(self):
        self.total_time = 0
        self.last_time = time.time()
        
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:
            current_time = time.time()
            elapsed = current_time - self.last_time
            self.total_time += elapsed
            self.last_time = current_time

            # Log elapsed time since the last log line (spans the last 10 steps)
            logger.info(f"Step {state.global_step}: {elapsed:.2f}s for last 10 steps - Total training time: {self.total_time:.2f}s")
            
    def on_epoch_end(self, args, state, control, **kwargs):
        logger.info(f"Epoch {state.epoch} completed")


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={
            "help": "Comma-separated list of target modules to apply LoRA layers to"
        },
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient checkpointing parameter. Refer to the related docs."},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Unsloth for training."},
    )
    use_cuda: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables CUDA GPU acceleration for training and inference when available."},
    )


@dataclass
class DataTrainingArguments:
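    """
    Arguments pertaining to the training data and how it is formatted.
    """
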
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, appends `eos_token_id` at the end of each sample being packed."
        },
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, the tokenizer adds special tokens to each sample being packed."
        },
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )
    is_sequential: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is sequential."},
    )
    is_cot: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is a chain-of-thought (CoT) dataset."},
    )
    user_name: Optional[str] = field(
        default="User",
        metadata={"help": "The name of the user."},
    )
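
# Example invocation (illustrative only: the script name, model id, and
# output directory below are placeholders, not values taken from this repo):
#   python train_script.py \
#       --model_name_or_path <hf-model-id-or-local-path> \
#       --use_peft_lora True --use_4bit_quantization True \
#       --output_dir ./output --num_train_epochs 1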


def main(model_args, data_args, training_args):
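    """Run memory-optimized SFT: load the (optionally quantized) model,
    prepare the chat dataset and completion-only collator, and train
    with TRL's SFTTrainer.
    """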
    logger.info(f"Python version--------------------: {sys.version}")

    # Configure logging
    logging.config.dictConfig(LOGGING_CONFIG)

    logger.info("Begin training...")

    # Ensure logs are flushed immediately
    for handler in logging.getLogger().handlers:
        handler.flush()

    # Get memory manager for optimization
    memory_manager = get_memory_manager()
    memory_manager.cleanup_memory(force=True)
    
    # Release Ollama models if they exist to free up VRAM
    if torch.cuda.is_available() and model_args.use_cuda:
        release_ollama_models_early()
    
    logger.info("Initializing training with memory optimizations")
    set_seed(training_args.seed)
    
    # Apply PyTorch memory optimizations to training arguments
    logger.info("Applying memory optimizations to training configuration")
    training_args = memory_manager.optimize_training_args(training_args)

    # --- Accelerate optimizer state offloading logic ---
    # Enable optimizer state offload to CPU if VRAM is low and not using DeepSpeed
    vram_total = memory_manager.get_memory_info().get("vram_total_gb", 0)
    use_accelerate_offload = False
    if torch.cuda.is_available() and model_args.use_cuda and vram_total > 0 and vram_total < 16:
        # Only set if not already using DeepSpeed
        if getattr(training_args, "deepspeed", None) is None:
            logger.info("Enabling Hugging Face Accelerate optimizer state offload to CPU for low VRAM GPUs")
            accelerate_config = {
                "compute_environment": "LOCAL_MACHINE",
                "deepspeed_config": None,
                "distributed_type": "NO",
                "downcast_bf16": False,
                "fsdp_config": {},
                "main_training_function": "main",
                "mixed_precision": "no",
                "num_machines": 1,
                "num_processes": 1,
                "use_cpu": False,
                "zero3_init_flag": False,
                "offload_optimizer_device": "cpu",
                "offload_param_device": "none"
            }
            training_args.accelerate_config = accelerate_config
            use_accelerate_offload = True

    # Model loading with device_map="auto" for automatic offloading
    logger.info(f"Loading model with automatic memory management from {model_args.model_name_or_path}")
    
    # Create model arguments dict with automatic offloading
    model_kwargs = {
        # Don't use "auto" device_map initially to avoid meta tensor issues
        "device_map": None,
        "trust_remote_code": True
    }
    
    # Configure quantization if requested
    if model_args.use_4bit_quantization:
        from transformers import BitsAndBytesConfig
        compute_dtype = getattr(torch, model_args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, model_args.bnb_4bit_quant_storage_dtype)
        
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=model_args.use_4bit_quantization,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
        # For 4-bit models, we can use device_map="auto"
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit quantization for memory efficiency")
    elif model_args.use_8bit_quantization:
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=model_args.use_8bit_quantization
        )
        # For 8-bit models, we can use device_map="auto"
        model_kwargs["device_map"] = "auto"
        logger.info("Using 8-bit quantization for memory efficiency")
    
    # Flash attention for memory efficiency when supported
    if model_args.use_flash_attn and torch.cuda.is_available() and model_args.use_cuda:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2 for memory efficiency")
    
    # Load model with built-in memory management features
    model, peft_config, tokenizer = create_and_prepare_model(
        model_args, data_args, training_args, model_kwargs=model_kwargs
    )
    
    # If model has meta tensors, handle them properly
    if hasattr(model, "is_meta") and model.is_meta:
        logger.info("Model has meta tensors, using to_empty() to properly initialize")
        device = "cuda" if torch.cuda.is_available() and model_args.use_cuda else "cpu"
        model = model.to_empty(device=device)
    
    # Apply gradient checkpointing for memory efficiency
    if training_args.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        logger.info("Enabling gradient checkpointing for memory efficiency")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    
    # Cap the CUDA caching allocator at 90% of total VRAM on small GPUs,
    # leaving headroom for non-PyTorch allocations and reducing OOM risk
    if torch.cuda.is_available() and memory_manager.get_memory_info().get("vram_total_gb", 0) < 8:
        torch.cuda.set_per_process_memory_fraction(0.9)
        logger.info("Capping CUDA memory fraction at 0.9 to avoid OOM errors")

    # datasets
    train_dataset = create_chat_data(
        data_args,
        tokenizer,
    )
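
    # DataCollatorForCompletionOnlyLM masks every token up to and including
    # the response template below, so the loss is computed only on the
    # assistant's completion rather than on the full prompt.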
    
    response_template = "\n<|im_start|>assistant\n"
    
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
    
    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # Use DeepSpeed to handle meta tensors if available
    try:
        # Only configure DeepSpeed if meta tensors are present and DeepSpeed is available
        if hasattr(model, "is_meta") and model.is_meta:
            logger.info("Model has meta tensors, checking DeepSpeed availability")
            # First verify DeepSpeed is properly installed and importable
            try:
                import deepspeed
                logger.info("DeepSpeed is available, configuring for meta tensor handling")
                
                # Configure with appropriate settings for meta tensors
                training_args.deepspeed = {
                    "zero_stage": 3,
                    "offload_optimizer": {
                        "device": "cpu"
                    },
                    "offload_param": {
                        "device": "cpu"
                    },
                    "zero3_init_flag": True,
                    "zero_force_ds_cpu_optimizer": False
                }
                logger.info("DeepSpeed configured for meta tensor handling")
            except ImportError:
                logger.warning("DeepSpeed is not available, meta tensors will be handled differently")
                # If DeepSpeed isn't available, use alternative approach to handle meta tensors
                if torch.cuda.is_available() and model_args.use_cuda:
                    logger.info("Initializing meta tensors on GPU")
                    # to_empty() materializes meta tensors in place on the
                    # target device; wrapping it in init_empty_weights() would
                    # re-create the weights on the meta device, so it is not used
                    model.to_empty(device="cuda")
                else:
                    logger.info("Initializing meta tensors on CPU")
                    model.to_empty(device="cpu")
    except Exception as e:
        logger.warning(f"Could not configure meta tensor handling: {e}")
        logger.warning(traceback.format_exc())

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )
    
    # Print model details
    trainer.accelerator.print(f"{trainer.model}")
    
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()
    
    # Memory usage tracking callback
    class MemoryMonitorCallback(transformers.TrainerCallback):
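        """Checks VRAM every 5 steps, frees cache above 90% usage, and cleans up before checkpoint saves."""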
        def __init__(self):
            self.memory_manager = get_memory_manager()
        
        def on_step_end(self, args, state, control, **kwargs):
            # Check memory every 5 steps
            if state.global_step % 5 == 0 and torch.cuda.is_available():
                info = self.memory_manager.get_memory_info()
                vram_usage_pct = info.get("vram_used_gb", 0) / info.get("vram_total_gb", 1) * 100
                
                if vram_usage_pct > 90:
                    logger.info(f"VRAM usage high ({vram_usage_pct:.1f}%), cleaning cache")
                    self.memory_manager.cleanup_memory()
        
        def on_save(self, args, state, control, **kwargs):
            # Free up memory before saving
            self.memory_manager.cleanup_memory(force=True)

    # Add memory monitoring
    trainer.add_callback(MemoryMonitorCallback())
    
    # Add existing debug callback
    trainer.add_callback(DebugCallback())

    # Resume from checkpoint if specified
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

    # Training with automatic memory management
    try:
        logger.info("Starting training with memory-optimized configuration")
        trainer.train(resume_from_checkpoint=checkpoint)
    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise

    # Save the model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    
    # Clean up before saving
    memory_manager.cleanup_memory(force=True)
    
    trainer.save_model()
    logger.info("Training completed successfully")


# Patch for autocast API compatibility across torch versions
def get_autocast():
    if hasattr(torch.cpu, "amp") and hasattr(torch.cpu.amp, "autocast"):
        # Older torch versions expose torch.cpu.amp.autocast directly
        return torch.cpu.amp.autocast
    else:
        # Newer torch versions deprecate it in favor of torch.amp.autocast("cpu", ...)
        return lambda **kwargs: torch.amp.autocast("cpu", **kwargs)


# Replace torch.cpu.amp.autocast with the version-appropriate callable
torch.cpu.amp.autocast = get_autocast()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)