|
|
import torch
import wandb
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from datasets import load_dataset
import os
import shutil
import time

print(">>> [System] ์คํฌ๋ฆฝํธ ์์. ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ก๋ฉ ์๋ฃ.") |
|
|
|
|
|
|
|
|
if os.path.exists("wandb"): |
|
|
try: |
|
|
shutil.rmtree("wandb") |
|
|
print(">>> [System] ๊ธฐ์กด WandB ์บ์ ์ญ์ ์๋ฃ") |
|
|
except: |
|
|
pass |
|
|
|
|
|
|
|
|
output_dir = "outputs_final" |
|
|
last_checkpoint = None |
|
|
|
|
|
if os.path.isdir(output_dir): |
|
|
last_checkpoint = get_last_checkpoint(output_dir) |
|
|
if last_checkpoint: |
|
|
print(f">>> [Resume] ์ด์ ํ์ต ๊ธฐ๋ก ๋ฐ๊ฒฌ: {last_checkpoint}") |
|
|
else: |
|
|
print(">>> [Start] ์๋ก์ด ํ์ต ์์") |
|
|
|
|
|
|
|
|
# Close any WandB run left over from a previous session, then start a fresh one.
try:
    wandb.finish()
except Exception:
    pass

# A timestamp-based run id keeps each launch unique while still allowing resume.
unique_id = f"run_{int(time.time())}"

wandb.init(
    entity="hambur1203-project",
    project="BiddinMate_Production_SFT",
    name="Llama3-8B-Final-3Epochs",
    id=unique_id,
    resume="allow",
)

print(">>> [Model] Llama-3 ๋ก๋ ์ค...") |
|
|
max_seq_length = 2048 |
|
|
model, tokenizer = FastLanguageModel.from_pretrained( |
|
|
model_name = "beomi/Llama-3-Open-Ko-8B", |
|
|
max_seq_length = max_seq_length, |
|
|
dtype = None, |
|
|
load_in_4bit = True, |
|
|
device_map = {"": 0} |
|
|
) |
|
|
|
|
|
|
|
|
# Attach LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

print(">>> [Data] ๋ฐ์ดํฐ์
๋ก๋ ์ค...") |
|
|
dataset = load_dataset("json", data_files="sft_train_llama.jsonl", split="train") |
|
|
|
|
|
|
|
|
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,   # effective batch size of 8 on a single GPU
        num_train_epochs = 3,
        warmup_steps = 100,
        learning_rate = 2e-4,
        report_to = "wandb",
        run_name = "Llama3-8B-Final-3Epochs",
        logging_steps = 1,
        save_strategy = "epoch",
        output_dir = output_dir,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit",
        weight_decay = 0.01,
        seed = 3407,
    ),
)

print(">>> [Train] ํ์ต ์์! (WandB๋ฅผ ํ์ธํ์ธ์)") |
|
|
if last_checkpoint: |
|
|
trainer.train(resume_from_checkpoint=True) |
|
|
else: |
|
|
trainer.train() |
|
|
|
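# Optional follow-up, not part of the original script: a minimal sketch of persisting the
# trained LoRA adapter and closing the WandB run once training finishes. The "lora_adapter"
# output path is an assumed name for illustration only.
model.save_pretrained("lora_adapter")        # writes only the LoRA adapter weights
tokenizer.save_pretrained("lora_adapter")    # keep the tokenizer alongside the adapter
wandb.finish()                               # close the tracking run cleanly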