Spaces:
Runtime error
Runtime error
predicting-effective-arguments-in-essay
/
source
/services
/predicting_effective_arguments
/train
/seq_classification.py
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer | |
| from datasets import Dataset, load_metric | |
| from sklearn.model_selection import train_test_split | |
| from source.services.predicting_effective_arguments.train.model import TransformersSequenceClassifier | |
class CFG:
    """Static settings for the discourse-effectiveness sequence-classification script."""

    # --- dataset columns -------------------------------------------------
    # Column that is renamed to ``label`` when the datasets are built.
    TARGET = "discourse_effectiveness"
    TEXT = "discourse_text"

    # --- checkpoint and output location ----------------------------------
    MODEL_CHECKPOINT = "distilbert-base-uncased"
    MODEL_OUTPUT_DIR = (
        "source/services/predicting_effective_arguments/model/"
        "hf_textclassification/predicting_effective_arguments_distilbert"
    )
    # NOTE(review): this name does not match MODEL_CHECKPOINT (DistilBERT);
    # presumably a leftover from another experiment -- confirm before relying on it.
    model_name = "debertav3base"

    # --- optimisation ----------------------------------------------------
    learning_rate = 1.5e-5
    weight_decay = 0.02
    hidden_dropout_prob = 0.007
    attention_probs_dropout_prob = 0.007

    # --- training loop ---------------------------------------------------
    num_train_epochs = 10
    n_splits = 4
    batch_size = 12
    random_seed = 42
    save_steps = 100
    max_length = 512
def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic
    behaviour so that repeated runs with the same seed are comparable.

    Args:
        seed: Value used to seed every generator.
    """
    # Imports stay local so this helper is self-contained, as in the original.
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; this is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different kernels from
    # run to run, which undermines reproducibility -- it must be off here.
    torch.backends.cudnn.benchmark = False
def prepare_input_text(df, sep_token):
    """Add an ``inputs`` column combining discourse type and discourse text.

    Each value is the lower-cased ``discourse_type``, a space, ``sep_token``,
    a space, and the lower-cased ``discourse_text``.

    Args:
        df: Frame with ``discourse_type`` and ``discourse_text`` columns.
        sep_token: Separator string placed between the two parts.

    Returns:
        The same ``df`` object, which is modified in place.
    """
    lowered_type = df.discourse_type.str.lower()
    prefix = lowered_type + ' ' + sep_token + ' '
    lowered_text = df.discourse_text.str.lower()
    df['inputs'] = prefix + lowered_text
    return df
def _to_labelled_dataset(frame, target, label_feature):
    """Build a ``Dataset`` with ``inputs`` and an integer-encoded ``label`` column.

    Args:
        frame: DataFrame that already contains ``inputs`` and ``target`` columns.
        target: Name of the column holding the class names.
        label_feature: ``ClassLabel`` shared by every split, so a given class
            name maps to the same integer id in all of them.

    Returns:
        The dataset with ``target`` renamed to ``label`` and cast to
        ``label_feature``.
    """
    dataset = Dataset.from_pandas(frame[['inputs', target]])
    dataset = dataset.rename_column(target, 'label')
    return dataset.cast_column('label', label_feature)


def main():
    """Split the training CSV, fine-tune the classifier and run predictions."""
    # Local import: the file-level import block does not provide ClassLabel.
    from datasets import ClassLabel

    config = CFG()
    # Seed before any random operation so the splits below are reproducible.
    seed_everything(config.random_seed)

    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_CHECKPOINT)
    classifier = TransformersSequenceClassifier(
        model_output_dir=config.MODEL_OUTPUT_DIR,
        tokenizer=tokenizer,
        model_checkpoint=config.MODEL_CHECKPOINT,
        num_labels=3,
    )

    # NOTE(review): only the first 100 rows are used -- looks like a
    # smoke-test limit; confirm before a real training run.
    data = pd.read_csv("data/raw_data/train.csv")[:100]
    # The former read of data/raw_data/test.csv was dropped: its result was
    # overwritten by the hold-out split below before ever being used.

    train_size = 0.7
    valid_size = 0.2
    test_size = 0.1

    # First split: separate out the training set.
    train_df, temp_df = train_test_split(
        data,
        test_size=1 - train_size,
        random_state=config.random_seed,
    )
    # Second split: divide the remainder into validation and test sets.
    valid_df, test_df = train_test_split(
        temp_df,
        test_size=test_size / (test_size + valid_size),
        random_state=config.random_seed,
    )

    # Copy each split first: prepare_input_text adds a column in place, and
    # the frames returned by train_test_split are derived from ``data``.
    train_df = prepare_input_text(train_df.copy(), sep_token=tokenizer.sep_token)
    valid_df = prepare_input_text(valid_df.copy(), sep_token=tokenizer.sep_token)
    test_df = prepare_input_text(test_df.copy(), sep_token=tokenizer.sep_token)

    # One label mapping for all splits. Encoding each split separately would
    # assign different ids whenever a class is missing from a (small) split.
    label_names = sorted(data[config.TARGET].dropna().astype(str).unique())
    label_feature = ClassLabel(names=label_names)

    train_dataset = _to_labelled_dataset(train_df, config.TARGET, label_feature)
    val_dataset = _to_labelled_dataset(valid_df, config.TARGET, label_feature)
    test_dataset = _to_labelled_dataset(test_df, config.TARGET, label_feature)

    train_tok_dataset = classifier.tokenize_dataset(dataset=train_dataset)
    val_tok_dataset = classifier.tokenize_dataset(dataset=val_dataset)
    # NOTE(review): the tokenized test split is not consumed below, because
    # predict_test_data receives raw text; the call is kept as in the original.
    classifier.tokenize_dataset(dataset=test_dataset)

    # NOTE(review): epochs and batch size are hard-coded here rather than
    # taken from CFG.num_train_epochs / CFG.batch_size -- confirm intent.
    classifier.train(
        train_dataset=train_tok_dataset,
        eval_dataset=val_tok_dataset,
        epochs=1,
        batch_size=16,
    )

    classifier.predict_valid_data(val_tok_dataset)
    classifier.predict_test_data(
        model_checkpoint=config.MODEL_OUTPUT_DIR,
        test_data=test_df['inputs'].tolist(),
    )


if __name__ == '__main__':
    main()