Support torch_dtype and CLS pooling (#6)
- support cls pooling and torch_dtype (dfe03e244d4edb62b846a3ed83842cc6c89e9129)
- configuration_xlm_roberta.py +9 -1
- modeling_xlm_roberta.py +25 -3
configuration_xlm_roberta.py
CHANGED

@@ -1,4 +1,5 @@
 from transformers import PretrainedConfig
+import torch
 
 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
@@ -22,6 +23,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         use_flash_attn=True,
+        torch_dtype=None,
+        emb_pooler=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -42,4 +45,9 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
         self.use_flash_attn = use_flash_attn
-
+        self.emb_pooler = emb_pooler
+        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+            self.torch_dtype = getattr(torch, torch_dtype)
+        else:
+            self.torch_dtype = torch_dtype
+
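In short, `torch_dtype` may be given as a string naming a `torch` dtype and is resolved to the dtype object, while anything else is stored as-is; `emb_pooler` simply records the pooling strategy. A minimal sketch of the resulting behaviour, assuming the class is imported from this repository's `configuration_xlm_roberta.py` and that the remaining constructor arguments keep their defaults:

```python
import torch

from configuration_xlm_roberta import XLMRobertaFlashConfig

# A string that names a torch dtype is resolved to the dtype object itself.
config = XLMRobertaFlashConfig(torch_dtype="bfloat16", emb_pooler="cls")
assert config.torch_dtype is torch.bfloat16
assert config.emb_pooler == "cls"

# None (or a string that does not name a torch dtype) is stored unchanged.
config = XLMRobertaFlashConfig(torch_dtype=None)
assert config.torch_dtype is None
```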
modeling_xlm_roberta.py
CHANGED

@@ -395,6 +395,17 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         if isinstance(module, XLMRobertaEncoder):
             module.gradient_checkpointing = value
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        **kwargs,
+    ):
+        if not 'torch_dtype' in kwargs:
+            kwargs['torch_dtype'] = 'auto'
+        return super().from_pretrained(*args, **kwargs)
+
+
 
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
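The override above makes `torch_dtype='auto'` the default whenever the caller does not pass one, so weights are loaded in the dtype recorded with the checkpoint/config rather than the usual float32. A hedged sketch of the effect; the repository id below is a placeholder, and loading through `transformers` requires `trust_remote_code=True` so this modeling file is actually used:

```python
import torch
from transformers import AutoModel

# No torch_dtype given: the override injects torch_dtype='auto', so transformers
# keeps the dtype stored with the checkpoint (e.g. float16/bfloat16 weights).
model = AutoModel.from_pretrained("org/xlm-roberta-flash", trust_remote_code=True)

# An explicit dtype still takes precedence over the injected 'auto'.
model_fp32 = AutoModel.from_pretrained(
    "org/xlm-roberta-flash", trust_remote_code=True, torch_dtype=torch.float32
)
```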
@@ -545,9 +556,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             elif output_value is None:
                 raise NotImplementedError
             else:
-                embeddings = self.mean_pooling(
-                    token_embs, encoded_input['attention_mask']
-                )
+                if self.config.emb_pooler == 'cls':
+                    embeddings = self.cls_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
+                else:
+                    embeddings = self.mean_pooling(
+                        token_embs, encoded_input['attention_mask']
+                    )
 
                 if normalize_embeddings:
                     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
@@ -580,6 +596,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
 
+    def cls_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ):
+        return token_embeddings[:,0]
+
+
     def forward(
         self,
         input_ids,
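For reference, the two strategies that `encode` now switches between differ only in how the token embeddings are reduced: CLS pooling keeps the first token's vector (the attention mask is accepted but unused), while mean pooling averages the non-padding tokens. A self-contained sketch on dummy tensors; the masked mean shown here is the conventional formulation, not code copied from this file:

```python
import torch

token_embs = torch.randn(2, 5, 8)                  # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# CLS pooling: embedding of the first token, mask ignored.
cls_pooled = token_embs[:, 0]                      # (2, 8)

# Mean pooling: average over non-padding positions only.
mask = attention_mask.unsqueeze(-1).to(token_embs.dtype)          # (2, 5, 1)
mean_pooled = (token_embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

print(cls_pooled.shape, mean_pooled.shape)         # torch.Size([2, 8]) for both
```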