import torch
import wandb
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from datasets import load_dataset
import os
import shutil
import time
# ==========================================
# [FINAL SCRIPT] Run directly from the terminal
# ==========================================
print(">>> [System] Script started. Libraries loaded.")
# 1. Force-delete any leftover WandB run folder
if os.path.exists("wandb"):
    try:
        shutil.rmtree("wandb")
        print(">>> [System] Removed existing WandB cache")
    except OSError:
        pass
# 2. Check for a previous run to resume
output_dir = "outputs_final"
last_checkpoint = None
if os.path.isdir(output_dir):
    last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint:
    print(f">>> [Resume] Found previous training checkpoint: {last_checkpoint}")
else:
    print(">>> [Start] Starting a new training run")
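# Note: get_last_checkpoint() returns the most recent "checkpoint-*" directory
# inside output_dir (or None), so an interrupted run can continue from its last save.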
# 3. WandB setup
try:
    wandb.finish()
except Exception:
    pass
unique_id = f"run_{int(time.time())}"
wandb.init(
    entity="hambur1203-project",
    project="BiddinMate_Production_SFT",
    name="Llama3-8B-Final-3Epochs",
    id=unique_id,
    resume="allow"
)
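# The run id is derived from the current timestamp, so resume="allow" normally
# starts a fresh WandB run; it would only resume if a run with that exact id existed.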
# 4. Load the model (pinned to GPU 0)
print(">>> [Model] Loading Llama-3...")
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "beomi/Llama-3-Open-Ko-8B",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
    device_map = {"": 0}  # Key: pin the whole model to GPU 0
)
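# dtype=None lets Unsloth auto-detect bfloat16 vs. float16 for the GPU, and
# load_in_4bit=True quantizes the 8B base weights to 4 bits (QLoRA-style) so the
# model fits on a single GPU; device_map={"": 0} places every module on cuda:0.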
# 5. LoRA configuration
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)
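# With r=16 and lora_alpha=16 the LoRA scaling factor is alpha/r = 1.0, and only the
# adapter weights are trained; use_gradient_checkpointing="unsloth" enables Unsloth's
# memory-efficient checkpointing to cut VRAM use during backpropagation.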
# 6. Load the dataset
print(">>> [Data] Loading dataset...")
dataset = load_dataset("json", data_files="sft_train_llama.jsonl", split="train")
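# Each line of sft_train_llama.jsonl is assumed to contain a "text" field,
# matching dataset_text_field="text" below.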
# 7. Training setup
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        num_train_epochs = 3,
        warmup_steps = 100,
        learning_rate = 2e-4,
        report_to = "wandb",
        run_name = "Llama3-8B-Final-3Epochs",
        logging_steps = 1,
        save_strategy = "epoch",
        output_dir = output_dir,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit",
        weight_decay = 0.01,
        seed = 3407,
    ),
)
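# Effective batch size = per_device_train_batch_size (2) * gradient_accumulation_steps (4) = 8.
# fp16/bf16 are mutually exclusive and chosen by hardware support; "adamw_8bit" uses the
# bitsandbytes 8-bit AdamW optimizer to reduce optimizer-state memory.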
# 8. Run
print(">>> [Train] Training started! (Watch the WandB dashboard)")
if last_checkpoint:
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
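# save_strategy="epoch" writes a checkpoint to output_dir after every epoch;
# resume_from_checkpoint=True tells the Trainer to reload the latest of those checkpoints.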