panikos committed
Commit 52e3179 · verified · 1 Parent(s): 93912b6

Upload production_training_llama_qlora.py with huggingface_hub

Files changed (1)
  1. production_training_llama_qlora.py +108 -0
production_training_llama_qlora.py ADDED
@@ -0,0 +1,108 @@
+ # /// script
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers>=4.40.0", "datasets>=2.18.0", "accelerate>=0.28.0", "bitsandbytes>=0.41.0"]
+ # ///
+
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ from transformers import BitsAndBytesConfig, AutoModelForCausalLM
+ import torch
+ import trackio
+
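+ # The "# /// script" block above is PEP 723 inline script metadata, so a runner
+ # such as `uv run production_training_llama_qlora.py` can install the listed
+ # dependencies before executing the script.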
+ print("=" * 80)
+ print("PRODUCTION: Biomedical Llama Fine-Tuning with QLoRA (Full Dataset)")
+ print("=" * 80)
+
+ print("\n[1/6] Loading dataset...")
+ dataset = load_dataset("panikos/biomedical-llama-training")
+
+ train_dataset = dataset["train"]
+ eval_dataset = dataset["validation"]
+
+ print(f" Train: {len(train_dataset)} examples")
+ print(f" Eval: {len(eval_dataset)} examples")
+
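+ # NF4 ("4-bit NormalFloat") stores the frozen base weights in 4 bits; double
+ # quantization also quantizes the quantization constants (saving roughly another
+ # 0.4 bits per parameter), while matmuls are computed in bfloat16.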
+ print("\n[2/6] Configuring 4-bit quantization...")
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+ print(" Quantization: 4-bit NF4")
+ print(" Compute dtype: bfloat16")
+ print(" Double quantization: enabled")
+
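+ # LoRA adds trainable low-rank updates (rank r=16, scaled by alpha/r = 32/16 = 2)
+ # to every attention and MLP projection listed below; the 4-bit base model stays
+ # frozen, so only the small adapter matrices receive gradients.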
+ print("\n[3/6] Configuring LoRA...")
+ lora_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM"
+ )
+ print(" LoRA rank: 16, alpha: 32")
+
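+ # meta-llama/Llama-3.1-8B-Instruct is a gated Hub repo: this call assumes the
+ # license has been accepted and the environment is authenticated (e.g. via
+ # `huggingface-cli login` or an HF_TOKEN environment variable). device_map="auto"
+ # places the quantized weights across the available devices.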
+ print("\n[4/6] Loading quantized model...")
+ model = AutoModelForCausalLM.from_pretrained(
+     "meta-llama/Llama-3.1-8B-Instruct",
+     quantization_config=bnb_config,
+     device_map="auto"
+ )
+
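+ # SFTTrainer applies lora_config to the quantized model (preparing it for k-bit
+ # training) and, when no tokenizer is passed, should load one from the model's repo.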
+ print("\n[5/6] Initializing trainer...")
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     peft_config=lora_config,
+     args=SFTConfig(
+         output_dir="llama-biomedical-production-qlora",
+         num_train_epochs=3,
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=4,
+         learning_rate=2e-4,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.1,
+         logging_steps=50,
+         eval_strategy="steps",
+         eval_steps=200,
+         save_strategy="epoch",
+         save_total_limit=2,
+         push_to_hub=True,
+         hub_model_id="panikos/llama-biomedical-production-qlora",
+         hub_private_repo=True,
+         bf16=True,
+         gradient_checkpointing=True,
+         report_to="trackio",
+         project="biomedical-llama-training",
+         run_name="production-full-dataset-qlora-v1"
+     )
+ )
+
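+ # Effective batch size: 2 per device x 4 accumulation steps = 8 (on a single GPU);
+ # gradient checkpointing recomputes activations in the backward pass to save memory.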
+ print("\n[6/6] Starting training...")
+ print(" Model: meta-llama/Llama-3.1-8B-Instruct")
+ print(" Method: QLoRA (4-bit) with LoRA adapters")
+ print(" Epochs: 3")
+ print(" Training examples: 17,008")
+ print(" Validation examples: 896")
+ print(" Batch size: 2 x 4 = 8 (effective)")
+ print(" Estimated steps: ~6,378 (2,126 per epoch)")
+ print(" Gradient checkpointing: ENABLED")
+ print(" Memory: ~5-6GB (optimized with QLoRA)")
+ print()
+
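+ # Step estimate: 17,008 examples / 8 effective batch = 2,126 steps per epoch,
+ # x 3 epochs = 6,378 total optimizer steps.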
+ trainer.train()
+
+ print("\n" + "=" * 80)
+ print("Pushing model to Hub...")
+ print("=" * 80)
+ trainer.push_to_hub()
+
+ print("\n" + "=" * 80)
+ print("PRODUCTION TRAINING COMPLETE!")
+ print("=" * 80)
+ print("\nModel: https://huggingface.co/panikos/llama-biomedical-production-qlora")
+ print("Dashboard: https://panikos-trackio.hf.space/")
+ print("\nYour biomedical Llama model is ready!")