# train_model.py (Training Script)
import argparse
import logging
import os

import torch
from datasets import load_dataset
from huggingface_hub import HfApi, login
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForSequenceClassification,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
    GPT2Config,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)
def setup_logging(log_file_path):
    """
    Sets up logging to both the console and a file.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(log_file_path)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.INFO)

    # Create a formatter and add it to both handlers
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(formatter)
    f_handler.setFormatter(formatter)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)
def parse_arguments():
    """
    Parses command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Train a custom LLM.")
    parser.add_argument("--task", type=str, required=True, choices=["generation", "classification"],
                        help="Task type: 'generation' or 'classification'")
    parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="Name of the Hugging Face dataset (e.g., 'wikitext/wikitext-2-raw-v1')")
    parser.add_argument("--num_layers", type=int, default=12, help="Number of hidden layers")
    parser.add_argument("--attention_heads", type=int, default=1, help="Number of attention heads")
    parser.add_argument("--hidden_size", type=int, default=64, help="Hidden size of the model")
    parser.add_argument("--vocab_size", type=int, default=30000, help="Vocabulary size")
    parser.add_argument("--sequence_length", type=int, default=512, help="Maximum sequence length")
    args = parser.parse_args()
    return args
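
# Example invocation (a hypothetical sketch; the model and dataset names below are
# placeholders, not values required by this script). hidden_size must be divisible
# by attention_heads:
#
#   HF_API_TOKEN=<your-token> python train_model.py \
#       --task generation \
#       --model_name tiny-gpt2-demo \
#       --dataset_name tiny_shakespeare \
#       --num_layers 4 --attention_heads 4 --hidden_size 128 \
#       --vocab_size 50257 --sequence_length 256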
def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
    """
    Loads and tokenizes the dataset based on the task.
    """
    logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
    try:
        dataset = load_dataset(dataset_name, split='train')
        logging.info("Dataset loaded successfully.")

        # Log a few examples to check the dataset structure
        logging.info(f"Example data from the dataset: {dataset[:5]}")

        def clean_text(text):
            # Ensure each text entry is a single string
            if isinstance(text, list):
                return " ".join([str(t) for t in text])
            return str(text)

        def tokenize_function(examples):
            try:
                # Clean text to ensure a consistent format
                examples['text'] = [clean_text(text) for text in examples['text']]
                # Log the type and structure of the text for debugging
                logging.info(f"Type of examples['text']: {type(examples['text'])}")
                logging.info(f"First example type: {type(examples['text'][0])}")
                # Tokenize with truncation; padding and tensor creation are deferred
                # to the data collator
                tokens = tokenizer(
                    examples['text'],
                    truncation=True,
                    max_length=sequence_length,
                    padding=False,       # Defer padding to the data collator
                    return_tensors=None  # Let the data collator handle tensor creation
                )
                # Log the tokens for debugging
                logging.info(f"Tokenized example: {tokens}")
                return tokens
            except Exception as e:
                logging.error(f"Error during tokenization: {e}")
                logging.error(f"Problematic example: {examples}")
                raise e

        # Tokenize a small shuffled subset of the dataset. The raw string columns are
        # dropped so the data collators only receive token fields (the Trainer runs with
        # remove_unused_columns=False); a 'label'/'labels' column, if present, is kept
        # for the classification task.
        columns_to_drop = [c for c in dataset.column_names if c not in ("label", "labels")]
        tokenized_datasets = (
            dataset.shuffle(seed=42)
            .select(range(500))
            .map(tokenize_function, batched=True, remove_columns=columns_to_drop)
        )
        logging.info("Dataset tokenization complete.")
        return tokenized_datasets
    except Exception as e:
        logging.error(f"Error loading or tokenizing dataset: {str(e)}")
        raise e
def initialize_model(task, model_name, vocab_size, sequence_length, hidden_size, num_layers, attention_heads):
    """
    Initializes the model configuration and model based on the task.
    """
    logging.info(f"Initializing model for task '{task}'...")
    try:
        if task == "generation":
            # GPT-2 uses its own configuration names (n_layer, n_head, n_inner,
            # activation_function); BERT-style names such as intermediate_size or
            # hidden_act would be silently ignored by GPT2Config.
            config = GPT2Config(
                vocab_size=vocab_size,
                n_positions=sequence_length,
                n_embd=hidden_size,
                n_layer=num_layers,
                n_head=attention_heads,
                n_inner=4 * hidden_size,
                activation_function='gelu',
                use_cache=True,
            )
            model = GPT2LMHeadModel(config)
            logging.info("GPT2LMHeadModel initialized successfully.")
        elif task == "classification":
            config = BertConfig(
                vocab_size=vocab_size,
                max_position_embeddings=sequence_length,
                hidden_size=hidden_size,
                num_hidden_layers=num_layers,
                num_attention_heads=attention_heads,
                intermediate_size=4 * hidden_size,
                hidden_act='gelu',
                num_labels=2  # Adjust based on your classification task
            )
            model = BertForSequenceClassification(config)
            logging.info("BertForSequenceClassification initialized successfully.")
        else:
            raise ValueError("Unsupported task type")
        return model
    except Exception as e:
        logging.error(f"Error initializing model: {str(e)}")
        raise e
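
# Optional sanity check (a hypothetical usage sketch): the size of a freshly
# initialized model can be inspected before committing to a training run, e.g.
#   m = initialize_model("generation", "demo", 30000, 512, 64, 12, 1)
#   print(f"{sum(p.numel() for p in m.parameters()):,} parameters")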
def get_optimizer(model, learning_rate):
    """
    Returns the AdamW optimizer from PyTorch.
    """
    return AdamW(model.parameters(), lr=learning_rate)
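
# Note: this optimizer is handed to the Trainer below as optimizers=(optimizer, None).
# Because no scheduler is supplied, the Trainer builds its default learning-rate
# schedule (linear decay, using the warmup settings from TrainingArguments).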
def main():
    # Parse arguments
    args = parse_arguments()

    # Set up logging
    log_file = "training.log"
    setup_logging(log_file)
    logging.info("Training script started.")

    # Initialize the Hugging Face API client
    api = HfApi()

    # Retrieve the Hugging Face API token from environment variables
    hf_token = os.getenv("HF_API_TOKEN")
    if not hf_token:
        logging.error("HF_API_TOKEN environment variable not set.")
        raise ValueError("HF_API_TOKEN environment variable not set.")

    # Log in using the API token
    try:
        login(token=hf_token)
        logging.info("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}")
        raise e
    # Initialize the tokenizer
    try:
        logging.info("Initializing tokenizer...")
        if args.task == "generation":
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        elif args.task == "classification":
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        else:
            raise ValueError("Unsupported task type")
        logging.info("Tokenizer initialized successfully.")

        # Set pad_token to eos_token if not already set (GPT-2 has no pad token)
        if tokenizer.pad_token is None:
            logging.info("Setting pad_token to eos_token.")
            tokenizer.pad_token = tokenizer.eos_token

        # Initialize the model
        model = initialize_model(
            task=args.task,
            model_name=args.model_name,
            vocab_size=args.vocab_size,
            sequence_length=args.sequence_length,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            attention_heads=args.attention_heads
        )
        # Resize the embedding matrix to match the tokenizer's vocabulary size
        model.resize_token_embeddings(len(tokenizer))
    except Exception as e:
        logging.error(f"Error initializing tokenizer or model: {str(e)}")
        raise e
    # Load and prepare the dataset
    try:
        tokenized_datasets = load_and_prepare_dataset(
            task=args.task,
            dataset_name=args.dataset_name,
            tokenizer=tokenizer,
            sequence_length=args.sequence_length
        )
    except Exception as e:
        logging.error("Failed to load and prepare dataset.")
        raise e

    # Define the data collator
    if args.task == "generation":
        # mlm=False selects causal language modeling: the collator copies input_ids
        # into labels, masking padded positions.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif args.task == "classification":
        # Handle padding dynamically during batching
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest')
    else:
        logging.error("Unsupported task type for data collator.")
        raise ValueError("Unsupported task type for data collator.")
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=f"./models/{args.model_name}",
        num_train_epochs=3,
        per_device_train_batch_size=8 if args.task == "generation" else 16,
        save_steps=5000,
        save_total_limit=2,
        logging_steps=500,
        learning_rate=5e-4 if args.task == "generation" else 5e-5,
        remove_unused_columns=False,
        push_to_hub=False
    )

    # Initialize the Trainer with the data collator and the custom optimizer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets,
        data_collator=data_collator,
        optimizers=(get_optimizer(model, training_args.learning_rate), None)
    )
    # Start training
    logging.info("Starting training...")
    try:
        trainer.train()
        logging.info("Training completed successfully.")
    except Exception as e:
        logging.error(f"Error during training: {str(e)}")
        raise e

    # Save the final model and tokenizer
    try:
        trainer.save_model(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        logging.info(f"Model and tokenizer saved to '{training_args.output_dir}'.")
    except Exception as e:
        logging.error(f"Error saving model or tokenizer: {str(e)}")
        raise e
    # Push the model to the Hugging Face Hub
    model_repo = f"{api.whoami(token=hf_token)['name']}/{args.model_name}"
    try:
        logging.info(f"Pushing model to Hugging Face Hub at '{model_repo}'...")
        api.create_repo(repo_id=model_repo, private=False, token=hf_token)
        logging.info(f"Repository '{model_repo}' created successfully.")
    except Exception as e:
        logging.warning(f"Repository might already exist: {str(e)}")

    try:
        model.push_to_hub(model_repo, token=hf_token)
        tokenizer.push_to_hub(model_repo, token=hf_token)
        logging.info(f"Model and tokenizer pushed to Hugging Face Hub at '{model_repo}'.")
    except Exception as e:
        logging.error(f"Error pushing model to Hugging Face Hub: {str(e)}")
        raise e

    logging.info("Training script finished successfully.")
if __name__ == "__main__":
    main()
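
# Quick smoke test after training (a minimal sketch, assuming --task generation and the
# default output directory used above; run this separately once the script has finished;
# <model_name> is a placeholder for the --model_name you passed):
#
#   from transformers import AutoTokenizer, AutoModelForCausalLM
#   path = "./models/<model_name>"
#   tok = AutoTokenizer.from_pretrained(path)
#   lm = AutoModelForCausalLM.from_pretrained(path)
#   out = lm.generate(**tok("Hello", return_tensors="pt"), max_new_tokens=20)
#   print(tok.decode(out[0], skip_special_tokens=True))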