import torch
import wandb
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from datasets import load_dataset
import os
import shutil
import time
# ==========================================
# [FINAL SCRIPT] Run from the terminal
# ==========================================
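# Fine-tunes beomi/Llama-3-Open-Ko-8B with a 4-bit quantized base model plus LoRA adapters
# (Unsloth + TRL's SFTTrainer) on sft_train_llama.jsonl, logs to WandB, and resumes from the
# latest checkpoint in outputs_final when one exists.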
print(">>> [System] Script started. Libraries loaded.")
# 1. Force-delete any leftover WandB folder
if os.path.exists("wandb"):
    try:
        shutil.rmtree("wandb")
        print(">>> [System] Removed existing WandB cache")
    except OSError:
        pass
# 2. Check for a resumable run
output_dir = "outputs_final"
last_checkpoint = None
if os.path.isdir(output_dir):
    last_checkpoint = get_last_checkpoint(output_dir)
if last_checkpoint:
    print(f">>> [Resume] Found previous training run: {last_checkpoint}")
else:
    print(">>> [Start] Starting a fresh training run")
# 3. WandB setup
try:
    wandb.finish()  # close any run left open from a previous session
except Exception:
    pass
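# A timestamp-based id makes every launch a distinct WandB run; resume="allow" lets a
# relaunch that reuses the same id continue logging to it instead of failing.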
unique_id = f"run_{int(time.time())}"
wandb.init(
    entity="hambur1203-project",
    project="BiddinMate_Production_SFT",
    name="Llama3-8B-Final-3Epochs",
    id=unique_id,
    resume="allow",
)
# 4. Load the model (pinned to GPU 0)
print(">>> [Model] Loading Llama-3...")
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "beomi/Llama-3-Open-Ko-8B",
    max_seq_length = max_seq_length,
    dtype = None,          # auto-detect (bf16 where supported, else fp16)
    load_in_4bit = True,   # 4-bit quantized base model
    device_map = {"": 0},  # key point: place the entire model on GPU 0
)
# 5. LoRA configuration
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory-efficient checkpointing
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)
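# With r=16 and lora_alpha=16 the LoRA scaling factor (alpha / r) is 1.0; the base model
# stays frozen and only the adapters on the listed attention and MLP projections are trained.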
# 6. Load the dataset
print(">>> [Data] Loading dataset...")
dataset = load_dataset("json", data_files="sft_train_llama.jsonl", split="train")
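# Each JSONL record is expected to carry a "text" field; SFTTrainer reads it via
# dataset_text_field below.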
# 7. Training configuration
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        num_train_epochs = 3,
        warmup_steps = 100,
        learning_rate = 2e-4,
        report_to = "wandb",
        run_name = "Llama3-8B-Final-3Epochs",
        logging_steps = 1,
        save_strategy = "epoch",
        output_dir = output_dir,
        fp16 = not is_bfloat16_supported(),  # fall back to fp16 on GPUs without bf16
        bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit",
        weight_decay = 0.01,
        seed = 3407,
    ),
)
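# Effective batch size: 2 per device x 4 accumulation steps = 8 sequences per optimizer update.
# save_strategy="epoch" writes a checkpoint-* folder to output_dir after each of the 3 epochs,
# which is what the resume check above picks up on a relaunch.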
# 8. Run
print(">>> [Train] Training started! (check WandB)")
if last_checkpoint:
    trainer.train(resume_from_checkpoint=True)  # resume from the newest checkpoint in output_dir
else:
    trainer.train()