Update train_lora_mistral.py
Browse files
- train_lora_mistral.py +18 -7
train_lora_mistral.py
CHANGED
|
@@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse
|
|
| 4 |
from datetime import datetime
|
| 5 |
from datasets import load_dataset
|
| 6 |
from huggingface_hub import HfApi
|
| 7 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
| 8 |
from peft import get_peft_model, LoraConfig, TaskType
|
| 9 |
import torch
|
| 10 |
|
|
@@ -36,7 +36,6 @@ def run_health_server():
|
|
| 36 |
threading.Thread(target=run_health_server, daemon=True).start()
|
| 37 |
|
| 38 |
# === Log
|
| 39 |
-
|
| 40 |
def log(message):
|
| 41 |
timestamp = datetime.now().strftime("%H:%M:%S")
|
| 42 |
print(f"[{timestamp}] {message}")
|
|
@@ -55,8 +54,11 @@ base_model.config.pad_token_id = tokenizer.pad_token_id
|
|
| 55 |
log("🎯 LoRA adapter uygulanıyor...")
|
| 56 |
peft_config = LoraConfig(
|
| 57 |
task_type=TaskType.CAUSAL_LM,
|
| 58 |
-
r=64,
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
model = get_peft_model(base_model, peft_config)
|
| 62 |
model.print_trainable_parameters()
|
|
@@ -65,6 +67,7 @@ log("📦 Parquet dosyaları listeleniyor...")
|
|
| 65 |
api = HfApi()
|
| 66 |
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
|
| 67 |
selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
|
|
|
|
| 68 |
if not selected_files:
|
| 69 |
log("⚠️ Parquet bulunamadı. Eğitim iptal.")
|
| 70 |
exit(0)
|
|
@@ -84,6 +87,8 @@ training_args = TrainingArguments(
|
|
| 84 |
fp16=False
|
| 85 |
)
|
| 86 |
|
|
|
|
|
|
|
| 87 |
for file in selected_files:
|
| 88 |
try:
|
| 89 |
log(f"\n📄 Yükleniyor: {file}")
|
|
@@ -97,12 +102,18 @@ for file in selected_files:
|
|
| 97 |
if len(dataset) == 0:
|
| 98 |
continue
|
| 99 |
|
| 100 |
-
#
|
|
|
|
| 101 |
first_row = dataset[0]
|
| 102 |
decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
|
| 103 |
-
log(f"📌 Örnek prompt: {decoded_prompt}")
|
| 104 |
|
| 105 |
-
trainer = Trainer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
log("🚀 Eğitim başlıyor...")
|
| 107 |
trainer.train()
|
| 108 |
log("✅ Eğitim tamam.")
|
|
|
|
| 4 |
from datetime import datetime
|
| 5 |
from datasets import load_dataset
|
| 6 |
from huggingface_hub import HfApi
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
|
| 8 |
from peft import get_peft_model, LoraConfig, TaskType
|
| 9 |
import torch
|
| 10 |
|
|
|
|
| 36 |
threading.Thread(target=run_health_server, daemon=True).start()
|
| 37 |
|
| 38 |
# === Log
|
|
|
|
| 39 |
def log(message):
|
| 40 |
timestamp = datetime.now().strftime("%H:%M:%S")
|
| 41 |
print(f"[{timestamp}] {message}")
|
|
|
|
| 54 |
log("🎯 LoRA adapter uygulanıyor...")
|
| 55 |
peft_config = LoraConfig(
|
| 56 |
task_type=TaskType.CAUSAL_LM,
|
| 57 |
+
r=64,
|
| 58 |
+
lora_alpha=16,
|
| 59 |
+
lora_dropout=0.1,
|
| 60 |
+
bias="none",
|
| 61 |
+
fan_in_fan_out=False
|
| 62 |
)
|
| 63 |
model = get_peft_model(base_model, peft_config)
|
| 64 |
model.print_trainable_parameters()
|
|
|
|
| 67 |
api = HfApi()
|
| 68 |
files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
|
| 69 |
selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
|
| 70 |
+
|
| 71 |
if not selected_files:
|
| 72 |
log("⚠️ Parquet bulunamadı. Eğitim iptal.")
|
| 73 |
exit(0)
|
|
|
|
| 87 |
fp16=False
|
| 88 |
)
|
| 89 |
|
| 90 |
+
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
| 91 |
+
|
| 92 |
for file in selected_files:
|
| 93 |
try:
|
| 94 |
log(f"\n📄 Yükleniyor: {file}")
|
|
|
|
| 102 |
if len(dataset) == 0:
|
| 103 |
continue
|
| 104 |
|
| 105 |
+
# prompt tanımı: tokenize edilmiş dataset içinde input_ids zaten var
|
| 106 |
+
# sadece örnek bir tanesini loglayalım
|
| 107 |
first_row = dataset[0]
|
| 108 |
decoded_prompt = tokenizer.decode(first_row["input_ids"], skip_special_tokens=True)
|
| 109 |
+
log(f"📌 Örnek prompt: {decoded_prompt[:200]}...")
|
| 110 |
|
| 111 |
+
trainer = Trainer(
|
| 112 |
+
model=model,
|
| 113 |
+
args=training_args,
|
| 114 |
+
train_dataset=dataset,
|
| 115 |
+
data_collator=collator
|
| 116 |
+
)
|
| 117 |
log("🚀 Eğitim başlıyor...")
|
| 118 |
trainer.train()
|
| 119 |
log("✅ Eğitim tamam.")
|