Spaces:
Sleeping
Sleeping
| import torch | |
| from dateutil.parser import parse as parse_date | |
| from sklearn.model_selection import train_test_split | |
| from transformers import ( | |
| pipeline, | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| TrainingArguments, | |
| Trainer | |
| ) | |
| from torch.utils.data import Dataset | |
class GroundingDataset(Dataset):
    """Wraps question/answer/context records as tokenized classifier inputs.

    Each record is expected to be a dict carrying "question", "answer",
    "context" and an integer "label".  The answer and context are joined
    with a literal " [SEP] " marker and tokenized as the second segment
    of a text pair, padded/truncated to a fixed length.
    """

    def __init__(self, data, tokenizer, max_length=512):
        # `data`: indexable sequence of dicts (see class docstring).
        # `tokenizer`: any HF-style callable tokenizer returning tensors.
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Answer and supporting context share the second segment, separated
        # by an explicit marker so the model can tell them apart.
        second_segment = sample["answer"] + " [SEP] " + sample["context"]
        encoded = self.tokenizer(
            sample["question"],
            text_pair=second_segment,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze() drops the batch dim the tokenizer adds for return_tensors="pt".
        features = {
            key: encoded[key].squeeze() for key in ("input_ids", "attention_mask")
        }
        features["labels"] = torch.tensor(sample["label"])
        return features
class GroundingTrainer:
    """Fine-tunes a binary sequence classifier for grounding detection.

    Wraps a HF tokenizer/model pair and a Trainer run over records that
    GroundingDataset understands (dicts with "question", "answer",
    "context", and integer "label").
    """

    def __init__(self, model_name="distilbert-base-uncased"):
        # `model_name` generalizes the previously hard-coded checkpoint;
        # the default preserves the original behavior for existing callers.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

    def train(self, dataset, test_size=0.2, seed=42, save_dir="./grounding_detector"):
        """Fine-tune on `dataset` and save model + tokenizer to `save_dir`.

        Args:
            dataset: sequence of record dicts consumed by GroundingDataset.
            test_size: fraction of records held out for per-epoch evaluation.
            seed: random_state for the train/validation split.  The original
                split was unseeded, so every run trained and evaluated on a
                different partition; seeding makes runs reproducible.
            save_dir: directory the fine-tuned model and tokenizer are
                written to (default matches the original hard-coded path).
        """
        train_data, val_data = train_test_split(
            dataset, test_size=test_size, random_state=seed
        )
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                output_dir="./results",
                num_train_epochs=3,
                per_device_train_batch_size=8,
                # NOTE(review): this kwarg was renamed `eval_strategy` in
                # transformers >= 4.46 -- confirm against the pinned version
                # before upgrading.
                evaluation_strategy="epoch",
                logging_dir="./logs",
            ),
            train_dataset=GroundingDataset(train_data, self.tokenizer),
            eval_dataset=GroundingDataset(val_data, self.tokenizer),
        )
        trainer.train()
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)