from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput

@dataclass
class BaseModelOutputWithPastAndCrossAttentionsSkim(ModelOutput):
    """Base encoder output, extended with the attention mask and skim mask."""

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None

@dataclass
class BaseModelOutputWithPoolingAndCrossAttentionsSkim(ModelOutput):
    """Pooled encoder output, extended with the attention mask and skim mask."""

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None

@dataclass
class SequenceClassifierOutputSkim(ModelOutput):
    """Classification output with skimming statistics: the skim loss, the
    task (classification) loss, and how many tokens remained per layer."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class QuestionAnsweringModelOutputSkim(QuestionAnsweringModelOutput):
    """QA output extended with the same skimming fields as above."""

    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class MaskedLMOutputSkim(MaskedLMOutput):
    """Masked-LM output extended with the same skimming fields as above."""

    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None

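def _demo_output_usage():
    # Hedged usage sketch: the helper name and every value below are
    # illustrative placeholders, not taken from a real forward pass.
    out = SequenceClassifierOutputSkim(
        logits=torch.randn(2, 3),            # [batch, num_labels]
        skim_mask=torch.ones(2, 16),         # [batch, seq_len], 1 = token kept
        tokens_remained=torch.tensor(0.75),  # overall fraction of tokens kept
    )
    # ModelOutput allows both attribute and dict-style access; fields left
    # as None (e.g. loss here) are omitted from the mapping.
    assert out.logits is out["logits"]
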
def masked_softmax(vec, mask, dim=1, eps=1e-6):
    """Softmax over `vec` that ignores positions where `mask` is 0.

    vec:  attention scores, [batch, heads, from, to]
    mask: skim mask, [batch, seq_len], broadcast over the head and query dims
    """
    mask = mask[:, None, None, :]
    exps = torch.exp(vec)
    # eps keeps fully masked rows from dividing by zero
    masked_exps = exps * mask.float() + eps
    masked_sums = masked_exps.sum(dim, keepdim=True)
    return masked_exps / masked_sums

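def _demo_masked_softmax():
    # Hedged usage sketch (all shapes here are assumptions): normalize
    # attention scores while suppressing skimmed token positions.
    scores = torch.randn(2, 12, 16, 16)  # [batch, heads, from, to]
    skim_mask = torch.ones(2, 16)        # [batch, seq_len], 1 = keep
    skim_mask[:, 8:] = 0                 # pretend the last 8 tokens were skimmed
    probs = masked_softmax(scores, skim_mask, dim=-1)
    print(probs.sum(-1)[0, 0])  # each row still sums to 1
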
def convert_softmax_mask_to_digit(skim_mask):
    # skim_mask: [batch, seq_len] -> [batch, 1, 1, seq_len]; the result
    # broadcasts over the head and query dims, with 1 marking skimmed positions
    return (skim_mask == 0).to(dtype=torch.int64).unsqueeze(1).unsqueeze(1)

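def _demo_digit_mask():
    # Sketch (hypothetical values): a {0,1} keep-mask becomes a
    # [batch, 1, 1, seq_len] tensor in which 1 now marks the *skimmed* positions.
    skim_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
    digit = convert_softmax_mask_to_digit(skim_mask)
    print(digit)  # tensor([[[[0, 0, 1, 1]]]])
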
def trunc_with_mask_batched(input, mask, dim):
    """Truncate a batched tensor along `dim`, keeping positions where `mask` is True.

    e.g. hidden_states [batch, seq_len, hidden_size] with dim=1,
         attention_mask [batch, layer, head, seq_len] with dim=3.
    mask: boolean tensor, [batch, seq_len]; every example in the batch must
    keep the same number of positions, otherwise the reshape below fails.
    """
    assert input.shape[dim] == mask.shape[1]
    if dim != 1:
        input = input.transpose(1, dim)
    transpose_shape = list(input.shape)
    transpose_shape[1] = -1
    # Boolean indexing flattens the kept positions; the view restores the
    # batch dimension, relying on a uniform number of kept tokens per example.
    trunc_input = input[mask].view(transpose_shape)
    if dim != 1:
        trunc_input = trunc_input.transpose(1, dim)
    return trunc_input

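def _demo_trunc():
    # Sketch (hypothetical values): drop skimmed tokens from the hidden
    # states. The boolean mask keeps the same number of positions in every
    # example, which is what the -1 reshape above relies on.
    hidden = torch.arange(24.0).view(2, 4, 3)         # [batch, seq_len, hidden]
    keep = torch.tensor([[True, True, False, True],
                         [True, False, True, True]])  # 3 tokens kept per row
    print(trunc_with_mask_batched(hidden, keep, dim=1).shape)  # [2, 3, 3]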