Update train_model.py
train_model.py (CHANGED, +17 -29)
@@ -12,10 +12,10 @@ from transformers import (
     DataCollatorForLanguageModeling,
     DataCollatorWithPadding,
 )
-from datasets import load_dataset
+from datasets import load_dataset
 import torch
 import os
-from huggingface_hub import login, HfApi
+from huggingface_hub import login, HfApi
 import logging
 
 from torch.optim import AdamW  # Import PyTorch's AdamW
@@ -34,10 +34,9 @@ def setup_logging(log_file_path):
     f_handler.setLevel(logging.INFO)
 
     # Create formatters and add to handlers
-
-
-
-    f_handler.setFormatter(f_format)
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+    c_handler.setFormatter(formatter)
+    f_handler.setFormatter(formatter)
 
     # Add handlers to the logger
     logger.addHandler(c_handler)
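For reference, a minimal sketch of the logging setup this hunk produces. The surrounding function body is not shown in the diff, so the scaffolding here is assumed: c_handler and f_handler as a console StreamHandler and a file FileHandler created earlier in setup_logging, and the removed lines (whose text did not survive, apart from f_handler.setFormatter(f_format)) presumably building separate per-handler formatters that this change replaces with one shared formatter.

import logging

def setup_logging(log_file_path):
    # Assumed scaffolding: the hunk only shows f_handler.setLevel and the formatter lines
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(log_file_path)
    f_handler.setLevel(logging.INFO)

    # Create formatters and add to handlers (one shared formatter after this change)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(formatter)
    f_handler.setFormatter(formatter)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)
    return logger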
@@ -66,30 +65,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     """
     logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
     try:
-        if task == "generation":
-
-            if '/' in dataset_name:
-                dataset, config = dataset_name.split('/', 1)
-                dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
-            else:
-                dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
-            logging.info("Dataset loaded successfully for generation task.")
-            def tokenize_function(examples):
-                return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
-        elif task == "classification":
-            if '/' in dataset_name:
-                dataset, config = dataset_name.split('/', 1)
-                dataset = load_dataset(dataset, config, split='train')
-            else:
-                dataset = load_dataset(dataset_name, split='train')
-            logging.info("Dataset loaded successfully for classification task.")
-            # Assuming the dataset has 'text' and 'label' columns
-            def tokenize_function(examples):
-                return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
+        if '/' in dataset_name:
+            dataset, config = dataset_name.split('/', 1)
+            dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
         else:
-
+            dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
 
-
+        logging.info("Dataset loaded successfully.")
+
+        def tokenize_function(examples):
+            return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
+
+        # Tokenize the dataset
         tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
         logging.info("Dataset tokenization complete.")
         return tokenized_datasets
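Assembled from the context and added lines of this hunk, the updated load_and_prepare_dataset takes a single path regardless of task. A sketch of the resulting function follows; the except clause is assumed, since it sits outside the hunk.

from datasets import load_dataset
import logging

def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
    logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
    try:
        # Both branches load the same wikitext config
        if '/' in dataset_name:
            dataset, config = dataset_name.split('/', 1)
            dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
        else:
            dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')

        logging.info("Dataset loaded successfully.")

        def tokenize_function(examples):
            return tokenizer(examples['text'], truncation=True, max_length=sequence_length)

        # Tokenize a small shuffled subset (500 examples) of the dataset
        tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
        logging.info("Dataset tokenization complete.")
        return tokenized_datasets
    except Exception as e:  # assumed: the matching except is not part of this hunk
        logging.error(f"Dataset loading failed: {e}")
        raise

Note that after this change both branches of the dataset_name check load the hardcoded "Salesforce/wikitext" / "wikitext-103-raw-v1" training split, so the dataset/config split in the first branch has no effect on what gets loaded.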
@@ -186,7 +173,7 @@ def main():
     logging.info("Setting pad_token to eos_token.")
     tokenizer.pad_token = tokenizer.eos_token
     logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
-    #
+    # Initialize model after setting pad_token
     model = initialize_model(
         task=args.task,
         model_name=args.model_name,
@@ -315,3 +302,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+