# Source: ctr-ll4 / src/regression/datasets/DecoderDatasetTorch.py
# Uploaded by sanjin7 ("Upload src/ with huggingface_hub"), commit cea4a4b
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class DecoderDatasetTorch(Dataset):
    """Map-style torch dataset yielding embedding/ctr pairs from a DataFrame.

    Each item is a dict ``{"embedding": ..., "ctr": ...}`` taken from the row at
    the given *position* (0 .. len - 1), which is what
    ``torch.utils.data.DataLoader`` samplers pass to ``__getitem__``.
    """

    def __init__(self, df: pd.DataFrame, embedding_column: str = "my_full_mean_embedding"):
        """
        Args:
            df (pd.DataFrame): dataframe with ads; must contain ``embedding_column``
                and a ``"ctr"`` column. The caller's frame is not modified.
            embedding_column (str, optional): Column whose values to output in
                ``__getitem__``. Defaults to ``'my_full_mean_embedding'``.
        """
        # Work on a copy with a fresh 0..n-1 index. This avoids mutating the
        # caller's frame (and SettingWithCopyWarning when df is a slice), and
        # keeps positional sampler indices valid for filtered/shuffled frames.
        self.df = df.reset_index(drop=True)
        self.embedding_column = embedding_column
        # np.float32(sequence) returns a float32 ndarray, so every embedding
        # becomes a float32 array.
        self.df[embedding_column] = self.df[embedding_column].map(np.float32)
        # Cast the whole column: pandas re-infers element-wise np.float32
        # results as float64, which would make DataLoader collate doubles.
        self.df["ctr"] = self.df["ctr"].astype(np.float32)

    def __len__(self) -> int:
        """Return the number of rows in the dataset."""
        return len(self.df)

    def __getitem__(self, idx) -> dict:
        """Return ``{"embedding": ..., "ctr": ...}`` for row position(s) ``idx``."""
        if torch.is_tensor(idx):
            # Samplers may pass tensors; pandas indexers need ints or lists.
            idx = idx.tolist()
        # Positional (.iloc) access: idx is a row position, not an index label.
        embedding = self.df[self.embedding_column].iloc[idx]
        ctr = self.df["ctr"].iloc[idx]
        return {"embedding": embedding, "ctr": ctr}
# tokenizer = BertTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
# train_dataset = AdDataset(df=dataset.train, tokenizer=tokenizer)