davidtran999 committed
Commit 78adb6c · verified · Parent: 8b0c5c0

Upload backend/venv/lib/python3.10/site-packages/sentence_transformers/training_args.py with huggingface_hub

backend/venv/lib/python3.10/site-packages/sentence_transformers/training_args.py ADDED
@@ -0,0 +1,198 @@
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+
+ from transformers import TrainingArguments as TransformersTrainingArguments
+ from transformers.training_args import ParallelMode
+ from transformers.utils import ExplicitEnum
+
+ logger = logging.getLogger(__name__)
+
+
+ class BatchSamplers(ExplicitEnum):
+     """
+     Stores the acceptable string identifiers for batch samplers.
+
+     The batch sampler is responsible for determining how samples are grouped into batches during training.
+     Valid options are:
+
+     - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`, the default
+       PyTorch batch sampler.
+     - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`,
+       ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as:
+
+       - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
+       - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
+       - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss`
+       - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss`
+       - :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
+       - :class:`~sentence_transformers.losses.GISTEmbedLoss`
+       - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`
+     - ``BatchSamplers.GROUP_BY_LABEL``: Uses :class:`~sentence_transformers.sampler.GroupByLabelBatchSampler`,
+       ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple
+       samples from the same label, such as:
+
+       - :class:`~sentence_transformers.losses.BatchAllTripletLoss`
+       - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss`
+       - :class:`~sentence_transformers.losses.BatchHardTripletLoss`
+       - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss`
+
+     If you want to use a custom batch sampler, you can create a new Trainer class that inherits from
+     :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` and overrides the
+     :meth:`~sentence_transformers.trainer.SentenceTransformerTrainer.get_batch_sampler` method. The
+     method must return a class instance that supports ``__iter__`` and ``__len__`` methods. The former
+     should yield a list of indices for each batch, and the latter should return the number of batches.
+
+     Usage:
+         ::
+
+             from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
+             from sentence_transformers.training_args import BatchSamplers
+             from sentence_transformers.losses import MultipleNegativesRankingLoss
+             from datasets import Dataset
+
+             model = SentenceTransformer("microsoft/mpnet-base")
+             train_dataset = Dataset.from_dict({
+                 "anchor": ["It's nice weather outside today.", "He drove to work."],
+                 "positive": ["It's so sunny.", "He took the car to the office."],
+             })
+             loss = MultipleNegativesRankingLoss(model)
+             args = SentenceTransformerTrainingArguments(
+                 output_dir="checkpoints",
+                 batch_sampler=BatchSamplers.NO_DUPLICATES,
+             )
+             trainer = SentenceTransformerTrainer(
+                 model=model,
+                 args=args,
+                 train_dataset=train_dataset,
+                 loss=loss,
+             )
+             trainer.train()
+     """
+
+     BATCH_SAMPLER = "batch_sampler"
+     NO_DUPLICATES = "no_duplicates"
+     GROUP_BY_LABEL = "group_by_label"
+
+
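A minimal sketch of the custom batch sampler hook described in the ``BatchSamplers`` docstring above: ``ShuffledBatchSampler`` and ``CustomSamplerTrainer`` are illustrative names, and the ``get_batch_sampler`` signature shown here is an assumption that should be checked against the installed sentence-transformers version.

    import math
    import random

    from sentence_transformers import SentenceTransformerTrainer


    class ShuffledBatchSampler:
        """Yields one list of dataset indices per batch; implements the __iter__/__len__ contract."""

        def __init__(self, dataset_size: int, batch_size: int, seed: int = 12):
            self.dataset_size = dataset_size
            self.batch_size = batch_size
            self.seed = seed

        def __iter__(self):
            # Yield a shuffled list of indices for each batch.
            indices = list(range(self.dataset_size))
            random.Random(self.seed).shuffle(indices)
            for start in range(0, self.dataset_size, self.batch_size):
                yield indices[start : start + self.batch_size]

        def __len__(self):
            # Number of batches produced by __iter__.
            return math.ceil(self.dataset_size / self.batch_size)


    class CustomSamplerTrainer(SentenceTransformerTrainer):
        # Assumed v3-era signature; verify against your installed version.
        def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
            return ShuffledBatchSampler(len(dataset), batch_size)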
+ class MultiDatasetBatchSamplers(ExplicitEnum):
+     """
+     Stores the acceptable string identifiers for multi-dataset batch samplers.
+
+     The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple
+     datasets during training. Valid options are:
+
+     - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Uses :class:`~sentence_transformers.sampler.RoundRobinBatchSampler`,
+       which uses round-robin sampling from each dataset until one is exhausted.
+       With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled
+       from equally.
+     - ``MultiDatasetBatchSamplers.PROPORTIONAL``: **[default]** Uses :class:`~sentence_transformers.sampler.ProportionalBatchSampler`,
+       which samples from each dataset in proportion to its size.
+       With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
+
+     Usage:
+         ::
+
+             from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
+             from sentence_transformers.training_args import MultiDatasetBatchSamplers
+             from sentence_transformers.losses import CoSENTLoss
+             from datasets import Dataset, DatasetDict
+
+             model = SentenceTransformer("microsoft/mpnet-base")
+             train_general = Dataset.from_dict({
+                 "sentence_A": ["It's nice weather outside today.", "He drove to work."],
+                 "sentence_B": ["It's so sunny.", "He took the car to the bank."],
+                 "score": [0.9, 0.4],
+             })
+             train_medical = Dataset.from_dict({
+                 "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
+                 "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
+                 "score": [0.8, 0.6, 0.7],
+             })
+             train_legal = Dataset.from_dict({
+                 "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
+                 "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
+                 "score": [0.7, 0.8],
+             })
+             train_dataset = DatasetDict({
+                 "general": train_general,
+                 "medical": train_medical,
+                 "legal": train_legal,
+             })
+
+             loss = CoSENTLoss(model)
+             args = SentenceTransformerTrainingArguments(
+                 output_dir="checkpoints",
+                 multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
+             )
+             trainer = SentenceTransformerTrainer(
+                 model=model,
+                 args=args,
+                 train_dataset=train_dataset,
+                 loss=loss,
+             )
+             trainer.train()
+     """
+
+     ROUND_ROBIN = "round_robin"  # Round-robin sampling from each dataset
+     PROPORTIONAL = "proportional"  # Sample from each dataset in proportion to its size [default]
+
+
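A rough illustration of the trade-off documented above, assuming the three example datasets from the docstring (2, 3, and 2 rows) and a per-device batch size of 1; the exact batch counts depend on the sampler implementation in the installed version.

    from sentence_transformers import SentenceTransformerTrainingArguments
    from sentence_transformers.training_args import MultiDatasetBatchSamplers

    # ROUND_ROBIN alternates general -> medical -> legal and stops once the smallest
    # dataset is exhausted: roughly 6 batches per epoch, one medical row left unused,
    # but every dataset is sampled equally often.
    # PROPORTIONAL uses every row: roughly 7 batches per epoch, with the larger
    # medical dataset sampled most often.
    args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        per_device_train_batch_size=1,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
    )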
+ @dataclass
+ class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
+     """
+     SentenceTransformerTrainingArguments extends :class:`~transformers.TrainingArguments` with additional arguments
+     specific to Sentence Transformers. See :class:`~transformers.TrainingArguments` for the complete list of
+     available arguments.
+
+     Args:
+         output_dir (`str`):
+             The output directory where the model checkpoints will be written.
+         batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*):
+             The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
+             Defaults to ``BatchSamplers.BATCH_SAMPLER``.
+         multi_dataset_batch_sampler (Union[:class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`, `str`], *optional*):
+             The multi-dataset batch sampler to use. See :class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`
+             for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
+     """
+
+     batch_sampler: BatchSamplers | str = field(
+         default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
+     )
+     multi_dataset_batch_sampler: MultiDatasetBatchSamplers | str = field(
+         default=MultiDatasetBatchSamplers.PROPORTIONAL, metadata={"help": "The multi-dataset batch sampler to use."}
+     )
+
+     def __post_init__(self):
+         super().__post_init__()
+
+         self.batch_sampler = BatchSamplers(self.batch_sampler)
+         self.multi_dataset_batch_sampler = MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)
+
+         # The `compute_loss` method in `SentenceTransformerTrainer` is overridden to only compute the prediction loss,
+         # so we set `prediction_loss_only` to `True` here to avoid trying to gather full model predictions during evaluation.
+         self.prediction_loss_only = True
+
+         # Disable broadcasting of buffers to avoid `RuntimeError: one of the variables needed for gradient computation
+         # has been modified by an inplace operation.` when training with DDP & a BertModel-based model.
+         self.ddp_broadcast_buffers = False
+
+         if self.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+             # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
+             # so we don't have to warn.
+             if self.output_dir != "unused":
+                 logger.warning(
+                     "Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. "
+                     "See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information."
+                 )
+
+         elif self.parallel_mode == ParallelMode.DISTRIBUTED and not self.dataloader_drop_last:
+             # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
+             # so we don't have to warn.
+             if self.output_dir != "unused":
+                 logger.warning(
+                     "When using DistributedDataParallel (DDP), it is recommended to set `dataloader_drop_last=True` to avoid hanging issues with an uneven last batch. "
+                     "Setting `dataloader_drop_last=True`."
+                 )
+             self.dataloader_drop_last = True
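
A small sanity-check sketch of what ``__post_init__`` above does to user-supplied values, assuming a working transformers/accelerate install; the asserts simply mirror the assignments in the method.

    from sentence_transformers import SentenceTransformerTrainingArguments
    from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers

    args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        batch_sampler="no_duplicates",              # plain strings are accepted ...
        multi_dataset_batch_sampler="round_robin",
    )

    assert args.batch_sampler == BatchSamplers.NO_DUPLICATES  # ... and normalized to enum members
    assert args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN
    assert args.prediction_loss_only is True      # forced on in __post_init__
    assert args.ddp_broadcast_buffers is False    # forced off to avoid the DDP in-place error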