'''
Code for HuggingFace Hub Compatibility
'''
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel

from .original import TransformerModel, LMHead


class HF_LMModel(PreTrainedModel):
    """Transformer with language model head only."""

    def __init__(self, config):
        super().__init__(config)
        # Wrap the original (non-HF) transformer and language-model head.
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        if self.return_probs or self.return_acts:
            # Additive mask that pushes the positional-embedding slots at the
            # end of the vocabulary to -1e12 so they are never predicted.
            pos_emb_mask = torch.zeros(1, 1, config.n_vocab)
            pos_emb_mask[:, :, -config.n_ctx:] = -1e12
            # Registered as a buffer (not a parameter) so it follows the model
            # across devices and is stored in the state dict.
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            # Normalized token probabilities with the position slots masked out.
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            # Raw activations (logits) with the position slots masked out.
            lm_logits = lm_logits + self.pos_emb_mask
        return {"logits": lm_logits}
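

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal
# transformers.PretrainedConfig subclass exposing only the fields HF_LMModel
# reads above (n_vocab, n_ctx, return_probs, return_acts). TransformerModel
# and LMHead in .original almost certainly read further fields (layer sizes,
# dropout, ...) that are not visible in this file, so this is a skeleton,
# not a working config; "hf_lm" and the default values below are placeholders.
#
# from transformers import PretrainedConfig
#
# class HF_LMConfig(PretrainedConfig):
#     model_type = "hf_lm"
#
#     def __init__(self, n_vocab=40000, n_ctx=512,
#                  return_probs=False, return_acts=True, **kwargs):
#         self.n_vocab = n_vocab
#         self.n_ctx = n_ctx
#         self.return_probs = return_probs
#         self.return_acts = return_acts
#         super().__init__(**kwargs)
#
# Once a complete config exists, the standard Hub workflow applies:
#
# model = HF_LMModel(HF_LMConfig(...))
# model.save_pretrained("hf_lm_checkpoint")
# ---------------------------------------------------------------------------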