# Source: ctr-ll4/src/regression/datasets/FullModelDatasetTorch.py
# Uploaded by sanjin7 ("Upload src/ with huggingface_hub", commit cea4a4b)
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
class FullModelDatasetTorch(Dataset):
    """Map-style dataset serving ad text, CTR target and non-text features from a DataFrame.

    Each item is a dict holding the row's ``text_clean`` value under ``"text"``,
    its ``ctr`` value under ``"ctr"``, and one entry per non-text feature keyed
    by the feature's column name.
    """

    def __init__(self, df: pd.DataFrame, nontext_features: "list[str] | None" = None):
        """Store a float32-cast copy of ``df`` and the list of feature columns.

        Args:
            df (pd.DataFrame): frame with a ``text_clean`` column, a ``ctr``
                column and every column named in ``nontext_features``. It is
                not modified; the dataset keeps its own converted copy.
            nontext_features (list[str] | None): features to use in training
                except for text embeddings. ``None`` selects ``["aov"]``.

        Raises:
            KeyError: if ``ctr`` or one of ``nontext_features`` is not a column
                of ``df``.
        """
        # None sentinel instead of a mutable list default, plus a defensive copy,
        # so editing this attribute (or the caller's list) never leaks into
        # other instances.
        self.nontext_features = ["aov"] if nontext_features is None else list(nontext_features)

        float_columns = self.nontext_features + ["ctr"]
        # astype returns a new frame: the numeric columns become float32 here
        # while the caller's DataFrame keeps its original dtypes.
        self.df = df.astype({column: np.float32 for column in float_columns})

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample(s) at position ``idx``.

        Args:
            idx: integer position, list of positions, or a tensor of either
                (converted with ``tolist()``).

        Returns:
            dict: ``{"text": ..., "ctr": ..., <feature>: ...}``. A scalar
            ``idx`` yields scalar values; a list yields pandas Series.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Positional lookup: samplers index by position 0..len-1, which only
        # coincides with labels when the frame has a fresh RangeIndex. Indexing
        # the column first keeps each value's own dtype.
        sample = {
            "text": self.df["text_clean"].iloc[idx],
            "ctr": self.df["ctr"].iloc[idx],
        }
        features = {feature: self.df[feature].iloc[idx] for feature in self.nontext_features}
        return sample | features
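# NOTE(review): leftover usage sketch. `AdDataset` and the `tokenizer` argument do not match
# FullModelDatasetTorch.__init__ (which takes `df` and `nontext_features`); presumably from an
# earlier version of this dataset - confirm before reusing.
# tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
# train_dataset = AdDataset(df=dataset.train, tokenizer=tokenizer)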