ciyidogan commited on
Commit
126bdfd
·
verified ·
1 Parent(s): 28d6858

Update train_lora_mistral.py

Browse files
Files changed (1) hide show
  1. train_lora_mistral.py +142 -0
train_lora_mistral.py CHANGED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, zipfile, shutil, time, traceback, threading, uvicorn
2
+ from fastapi import FastAPI
3
+ from fastapi.responses import JSONResponse
4
+ from datetime import datetime
5
+ from datasets import load_dataset
6
+ from huggingface_hub import HfApi
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
8
+ from peft import get_peft_model, LoraConfig, TaskType
9
+ import torch
10
+
11
+ # === Sabitler ===
12
+ START_NUMBER = 0
13
+ END_NUMBER = 9
14
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
15
+ TOKENIZED_DATASET_ID = "UcsTurkey/turkish-general-culture-tokenized"
16
+ ZIP_UPLOAD_REPO = "UcsTurkey/trained-zips"
17
+ HF_TOKEN = os.environ.get("HF_TOKEN")
18
+ BATCH_SIZE = 1
19
+ EPOCHS = 2
20
+ MAX_LENGTH = 2048
21
+ OUTPUT_DIR = "/data/output"
22
+ ZIP_FOLDER = "/data/zip_temp"
23
+ zip_name = f"trained_model_{START_NUMBER:03d}_{END_NUMBER:03d}.zip"
24
+ ZIP_PATH = os.path.join(ZIP_FOLDER, zip_name)
25
+
26
+ # === Health check
27
+ app = FastAPI()
28
+
29
+ @app.get("/")
30
+ def health():
31
+ return JSONResponse(content={"status": "ok"})
32
+
33
+ def run_health_server():
34
+ uvicorn.run(app, host="0.0.0.0", port=7860)
35
+
36
+ threading.Thread(target=run_health_server, daemon=True).start()
37
+
38
+ # === Log
39
+ def log(message):
40
+ timestamp = datetime.now().strftime("%H:%M:%S")
41
+ print(f"[{timestamp}] {message}")
42
+ sys.stdout.flush()
43
+
44
+ # === Eğitim Başlıyor
45
+ log("🛠️ Ortam hazırlanıyor...")
46
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
47
+ if tokenizer.pad_token is None:
48
+ tokenizer.pad_token = tokenizer.eos_token
49
+
50
+ log("🧠 Model indiriliyor...")
51
+ base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
52
+ base_model.config.pad_token_id = tokenizer.pad_token_id
53
+
54
+ log("🎯 LoRA adapter uygulanıyor...")
55
+ peft_config = LoraConfig(
56
+ task_type=TaskType.CAUSAL_LM,
57
+ r=64, lora_alpha=16, lora_dropout=0.1,
58
+ bias="none", fan_in_fan_out=False
59
+ )
60
+ model = get_peft_model(base_model, peft_config)
61
+ model.print_trainable_parameters()
62
+
63
+ log("📦 Parquet dosyaları listeleniyor...")
64
+ api = HfApi()
65
+ files = api.list_repo_files(repo_id=TOKENIZED_DATASET_ID, repo_type="dataset", token=HF_TOKEN)
66
+ selected_files = sorted([f for f in files if f.startswith("chunk_") and f.endswith(".parquet")])[START_NUMBER:END_NUMBER+1]
67
+ if not selected_files:
68
+ log("⚠️ Parquet bulunamadı. Eğitim iptal.")
69
+ exit(0)
70
+
71
+ training_args = TrainingArguments(
72
+ output_dir=OUTPUT_DIR,
73
+ per_device_train_batch_size=BATCH_SIZE,
74
+ num_train_epochs=EPOCHS,
75
+ save_strategy="epoch",
76
+ save_total_limit=2,
77
+ learning_rate=2e-4,
78
+ disable_tqdm=True,
79
+ logging_strategy="steps",
80
+ logging_steps=10,
81
+ report_to=[],
82
+ bf16=True,
83
+ fp16=False
84
+ )
85
+
86
+ for file in selected_files:
87
+ try:
88
+ log(f"\n📄 Yükleniyor: {file}")
89
+ dataset = load_dataset(
90
+ path=TOKENIZED_DATASET_ID,
91
+ data_files={"train": file},
92
+ split="train",
93
+ token=HF_TOKEN
94
+ )
95
+ log(f"🔍 {len(dataset)} örnek")
96
+ if len(dataset) == 0:
97
+ continue
98
+ trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
99
+ log("🚀 Eğitim başlıyor...")
100
+ trainer.train()
101
+ log("✅ Eğitim tamam.")
102
+ except Exception as e:
103
+ log(f"❌ Hata: {file} → {e}")
104
+ traceback.print_exc()
105
+
106
+ # === Zip
107
+ log("📦 Model zipleniyor...")
108
+ try:
109
+ tmp_dir = os.path.join(ZIP_FOLDER, "temp_save")
110
+ os.makedirs(tmp_dir, exist_ok=True)
111
+ model.save_pretrained(tmp_dir)
112
+ tokenizer.save_pretrained(tmp_dir)
113
+
114
+ with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zipf:
115
+ for root, _, files in os.walk(tmp_dir):
116
+ for file in files:
117
+ filepath = os.path.join(root, file)
118
+ arcname = os.path.relpath(filepath, tmp_dir)
119
+ zipf.write(filepath, arcname=os.path.join("output", arcname))
120
+ log(f"✅ Zip oluşturuldu: {ZIP_PATH}")
121
+ except Exception as e:
122
+ log(f"❌ Zipleme hatası: {e}")
123
+ traceback.print_exc()
124
+
125
+ # === Upload
126
+ try:
127
+ log("☁️ Hugging Face'e yükleniyor...")
128
+ api.upload_file(
129
+ path_or_fileobj=ZIP_PATH,
130
+ path_in_repo=zip_name,
131
+ repo_id=ZIP_UPLOAD_REPO,
132
+ repo_type="model",
133
+ token=HF_TOKEN
134
+ )
135
+ log("✅ Upload tamam.")
136
+ except Exception as e:
137
+ log(f"❌ Upload hatası: {e}")
138
+ traceback.print_exc()
139
+
140
+ log("⏸️ Eğitim tamamlandı. Servis bekleme modunda...")
141
+ while True:
142
+ time.sleep(60)