| | from collections import defaultdict |
| | from typing import Dict |
| | import datasets |
| | from datasets import Dataset |
| | from sentence_transformers import ( |
| | SentenceTransformer, |
| | SentenceTransformerTrainer, |
| | losses, |
| | evaluation, |
| | SentenceTransformerTrainingArguments |
| | ) |
| | from sentence_transformers.models import Transformer, Pooling, Dense, Normalize |
| |
|
def to_triplets(dataset, *, positive_label: int = 0, negative_label: int = 2):
    """Convert an NLI-style dataset into (anchor, positive, negative) triplets.

    Rows are grouped by their ``"premise"`` value. A premise yields exactly one
    triplet when it has at least one hypothesis labelled ``positive_label`` and
    at least one labelled ``negative_label``. Premises lacking either label
    (including any other label value) are dropped. If a premise has several
    hypotheses with the same label, the one seen last is kept.

    Args:
        dataset: Iterable of mapping-like rows with ``"premise"``, ``"hypothesis"``
            and ``"label"`` keys (e.g. a ``datasets.Dataset`` split).
        positive_label: Label id whose hypothesis becomes the positive. Defaults
            to 0 (presumably "entailment" in the SNLI/MultiNLI label scheme —
            verify against the dataset's ClassLabel).
        negative_label: Label id whose hypothesis becomes the hard negative.
            Defaults to 2 (presumably "contradiction" — verify likewise).

    Returns:
        A ``datasets.Dataset`` with columns ``"anchor"``, ``"positive"`` and
        ``"negative"``, ordered by first appearance of each premise.

    Raises:
        ValueError: If ``positive_label`` equals ``negative_label`` (every
            triplet would have identical positive and negative texts).
    """
    if positive_label == negative_label:
        raise ValueError("positive_label and negative_label must differ")

    # premise -> {label: hypothesis}; dict insertion order keeps premises in
    # first-seen order, and a repeated (premise, label) overwrites the earlier one.
    premises: Dict[str, Dict[int, str]] = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]

    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if positive_label in sentences and negative_label in sentences:
            queries.append(premise)
            positives.append(sentences[positive_label])
            negatives.append(sentences[negative_label])
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })
| |
|
# SNLI: download, then turn each split into (anchor, positive, negative) triplets.
# The comprehension reads the raw DatasetDict before the name is rebound.
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict(
    {split: to_triplets(snli_ds[split]) for split in ("train", "validation", "test")}
)

# MultiNLI: only the train and matched-validation splits are converted.
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict(
    {
        split: to_triplets(multi_nli_ds[split])
        for split in ("train", "validation_matched")
    }
)
| |
|
# Combined "AllNLI" triplets:
#   train      = SNLI train + MultiNLI train
#   validation = SNLI validation + MultiNLI matched validation
#   test       = SNLI test only
all_nli_ds = datasets.DatasetDict(
    {
        "train": datasets.concatenate_datasets(
            [snli_ds["train"], multi_nli_ds["train"]]
        ),
        "validation": datasets.concatenate_datasets(
            [snli_ds["validation"], multi_nli_ds["validation_matched"]]
        ),
        "test": snli_ds["test"],
    }
)
| |
|
# STS-B dev/test pairs used by the similarity evaluators (loaded dev first, then test).
stsb_dev, stsb_test = (
    datasets.load_dataset("mteb/stsbenchmark-sts", split=split_name)
    for split_name in ("validation", "test")
)
| |
|
training_args = SentenceTransformerTrainingArguments(
    # Checkpointing: write to ./checkpoints every 100 steps, retaining at most 2.
    output_dir="checkpoints",
    save_steps=100,
    save_total_limit=2,
    # Optimisation: a single epoch with linear warm-up over the first 10% of steps.
    seed=42,
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    bf16=True,
    # Evaluation and logging both run every 100 steps.
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=100,
    # Best-checkpoint metric: presumably the Spearman/cosine score reported by
    # the evaluator named "sts-dev"; higher is better.
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)
| |
|
# Sentence encoder: bert-tiny token embeddings -> mean pooling -> 256-d dense
# projection -> normalization.
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
# Take the projection's input size from the pooling layer instead of hard-coding
# 128, so changing the backbone or pooling mode cannot silently mismatch shapes.
dense = Dense(pooling.get_sentence_embedding_dimension(), 256)
normalize = Normalize()
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])

# NOTE(review): presumably works around non-contiguous weight tensors in the
# checkpoint, which fail to serialize when saving/pushing the model — confirm.
for param in model.parameters():
    param.data = param.data.contiguous()

# In-batch-negatives ranking loss over the (anchor, positive, negative) columns.
loss = losses.MultipleNegativesRankingLoss(model)
| | |
| |
|
# STS-B dev evaluator; gold scores are divided by 5 before comparison.
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [gold / 5 for gold in stsb_dev["score"]],
    name="sts-dev",
    main_similarity=evaluation.SimilarityFunction.COSINE,
)

# Train on the AllNLI triplets; the dev evaluator runs alongside eval_dataset.
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    loss=loss,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    evaluator=dev_evaluator,
)
trainer.train()
| |
|
# Final STS-B test evaluation of the trained model (gold scores divided by 5,
# matching the dev evaluator).
test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)

# Report the test metrics rather than stopping in a leftover breakpoint(), which
# halts the script in the debugger and blocks unattended runs before the push.
print(results)
model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)