import os

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)

from .config import Config
from .dataset import DataProcessor
from .metrics import compute_metrics
from .visualization import plot_training_history


def main():
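    # Pick the best available accelerator: MPS on Apple Silicon, then CUDA, then CPU.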
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using device: MPS (Apple Silicon acceleration)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using device: CUDA")
    else:
        device = torch.device("cpu")
        print("Using device: CPU")

    print(f"Loading tokenizer from {Config.BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

    print("Preparing datasets...")
    processor = DataProcessor(tokenizer)
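
    # Use all but one CPU core for dataset preprocessing; os.cpu_count() can
    # return None, hence the fallback.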
    num_proc = max(1, (os.cpu_count() or 2) - 1)

    dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR, num_proc=num_proc)

    train_dataset = dataset['train']
    eval_dataset = dataset['test']

    print(f"Training on {len(train_dataset)} samples, validating on {len(eval_dataset)} samples.")

    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID,
    )
    model.to(device)
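
    # Evaluate and checkpoint on a step schedule; note that with
    # load_best_model_at_end=True the Trainer requires SAVE_STEPS to be a
    # round multiple of EVAL_STEPS.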
    training_args = TrainingArguments(
        output_dir=Config.RESULTS_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.EVAL_STEPS,
        save_steps=Config.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
    )
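
    # Passing the tokenizer lets the Trainer pad batches dynamically and save
    # the tokenizer alongside each checkpoint.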
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()
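
    # With load_best_model_at_end=True, the model saved here is the best
    # checkpoint by eval F1, not necessarily the final training step.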
    print(f"Saving model to {Config.CHECKPOINT_DIR}...")
    trainer.save_model(Config.CHECKPOINT_DIR)
    tokenizer.save_pretrained(Config.CHECKPOINT_DIR)
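
    # trainer.state.log_history accumulates every logged loss/metric entry,
    # which the plotting helper turns into training curves.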
    print("Generating training plots...")
    plot_save_path = os.path.join(Config.RESULTS_DIR, 'training_curves.png')
    plot_training_history(trainer.state.log_history, save_path=plot_save_path)

    print("Training completed!")


if __name__ == "__main__":
    main()