| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
class FullModelDatasetTorch(Dataset):
    """Map-style torch ``Dataset`` over a pandas DataFrame (train data).

    Each item is a dict with the raw ``"text"`` (taken from column
    ``text_clean``), the ``"ctr"`` value, and one entry per requested
    non-text feature column.
    """

    def __init__(self, df: pd.DataFrame, nontext_features: list[str] = ["aov"]):
        """
        Args:
            df (pd.DataFrame): train dataframe. Must contain ``text_clean``,
                ``ctr`` and every column listed in ``nontext_features``.
                NOTE: the ``ctr`` and non-text feature columns are cast to
                ``float32`` in place on this DataFrame (the caller's object).
            nontext_features (list[str]): features to use in training except
                for text embeddings.
        """
        self.df = df
        # Store a private copy: the default argument is a single list shared
        # by all calls, so aliasing it (or the caller's list) on the instance
        # would let a mutation of one dataset leak into every other.
        self.nontext_features = list(nontext_features)
        numeric_cols = self.nontext_features + ["ctr"]
        df[numeric_cols] = df[numeric_cols].astype(np.float32)

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at *positional* index ``idx``.

        ``idx`` is treated positionally (``0 .. len(self) - 1``), which is
        what DataLoader samplers produce. Positional ``iloc`` is used instead
        of label-based ``loc`` so the DataFrame's index need not be a default
        RangeIndex (e.g. after a train/test split or shuffle).

        Args:
            idx: an int, a list of ints, or a torch tensor of either.

        Returns:
            dict with keys ``"text"``, ``"ctr"`` and each non-text feature;
            values are scalars for an int ``idx`` and Series for a list.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        text = self.df["text_clean"].iloc[idx]
        ctr = self.df["ctr"].iloc[idx]
        nontext_features = {
            feature: self.df[feature].iloc[idx] for feature in self.nontext_features
        }
        return {"text": text, "ctr": ctr} | nontext_features
| # tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity") | |
| # train_dataset = AdDataset(df=dataset.train, tokenizer=tokenizer) | |