# fairseq/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor

from examples.simultaneous_translation.utils.p_choose_strategy import (
    learnable_p_choose,
    waitk_p_choose
)

from examples.simultaneous_translation.utils.monotonic_attention import (
    expected_alignment_from_p_choose,
    expected_soft_attention,
    mass_preservation,
)
from fairseq.modules import MultiheadAttention

from . import register_monotonic_attention


@register_monotonic_attention("hard_aligned")
class MonotonicAttention(MultiheadAttention):
    """
    Abstract class of monotonic attentions
    """
    k_in_proj: Dict[str, nn.Linear]
    q_in_proj: Dict[str, nn.Linear]

    def __init__(self, args):
        super().__init__(
            embed_dim=args.decoder_embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.soft_attention = False

        self.eps = getattr(args, "attention_eps", 1e-6)
        self.mass_preservation = getattr(args, "mass_preservation", True)

        self.noise_type = args.noise_type
        self.noise_mean = args.noise_mean
        self.noise_var = args.noise_var

        self.energy_bias_init = args.energy_bias_init
        self.energy_bias = (
            nn.Parameter(self.energy_bias_init * torch.ones([1]))
            if args.energy_bias is True
            else 0
        )

        self.k_in_proj = {"monotonic": self.k_proj}
        self.q_in_proj = {"monotonic": self.q_proj}
        self.chunk_size = None

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--no-mass-preservation', action="store_false",
                            dest="mass_preservation",
                            help='Do not stay on the last token when decoding')
        parser.add_argument('--mass-preservation', action="store_true",
                            dest="mass_preservation",
                            help='Stay on the last token when decoding')
        parser.set_defaults(mass_preservation=True)

        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')
        parser.add_argument('--noise-type', type=str, default="flat",
                            help='Type of discreteness noise')
        parser.add_argument('--energy-bias', action="store_true",
                            default=False,
                            help='Bias for energy')
        parser.add_argument('--energy-bias-init', type=float, default=-2.0,
                            help='Initial value of the bias for energy')
        parser.add_argument('--attention-eps', type=float, default=1e-6,
                            help='Epsilon when calculating expected attention')
        # fmt: on
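
    # Illustrative example (added; not from the original source): with the flags
    # defined above, a typical configuration passes something like
    #   --mass-preservation --noise-type flat --noise-mean 0.0 --noise-var 1.0 \
    #   --energy-bias --energy-bias-init -2.0 --attention-eps 1e-6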

    def energy_from_qk(
        self,
        query: Tensor,
        key: Tensor,
        energy_type: str,
        key_padding_mask: Optional[Tensor] = None,
        bias: int = 0
    ):
        """
        Compute energy from query and key

        query size: tgt_len, bsz, embed_dim
        key size: src_len, bsz, embed_dim
        key_padding_mask size: bsz * num_heads, src_len
        energy_type selects the projections to use: "monotonic" or "soft"
        bias is an additive scalar (e.g. the energy bias) applied to the energy
        """
        length, bsz, _ = query.size()
        q = self.q_in_proj[energy_type].forward(query)
        q = (
            q.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        q = q * self.scaling
        length, bsz, _ = key.size()
        k = self.k_in_proj[energy_type].forward(key)
        k = (
            k.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        energy = torch.bmm(q, k.transpose(1, 2)) + bias

        if key_padding_mask is not None:
            energy = energy.masked_fill(
                key_padding_mask.unsqueeze(1).to(torch.bool),
                - float("inf")
            )

        return energy
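
    # Shape note (added): with tgt_len = T, src_len = S, bsz = B and
    # num_heads = H, energy_from_qk returns unnormalized scores of shape
    # (B * H, T, S); source positions masked by key_padding_mask are -inf.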

    def p_choose_from_qk(self, query, key, key_padding_mask, incremental_states=None):
        monotonic_energy = self.energy_from_qk(
            query,
            key,
            "monotonic",
            key_padding_mask=key_padding_mask,
            bias=self.energy_bias,
        )

        p_choose = learnable_p_choose(
            monotonic_energy,
            self.noise_mean,
            self.noise_var,
            self.training
        )
        return p_choose

    def p_choose(self, query, key, key_padding_mask, incremental_states=None):
        return self.p_choose_from_qk(query, key, key_padding_mask, incremental_states)
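
    # Note (added): learnable_p_choose applies a sigmoid to the (optionally
    # noise-perturbed) monotonic energy, so p_choose[..., i, j] is the
    # probability of stopping (and attending) at source position j when
    # producing target position i.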

    def monotonic_attention_process_infer(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    ):
        """
        Monotonic attention at inference time
        Notice that this function is designed for simuleval, not sequence_generator
        """
        assert query is not None
        assert key is not None

        if query.size(1) != 1:
            raise RuntimeError(
                "Simultaneous translation models don't support batch decoding."
            )
        # 1. compute stepwise probability
        p_choose = self.p_choose(
            query, key, None, incremental_state
        ).squeeze(1)

        # 2. Compute the alpha
        src_len = key.size(0)
        # Maximum steps allowed in this iteration
        max_steps = src_len - 1 if self.mass_preservation else src_len
        monotonic_cache = self._get_monotonic_buffer(incremental_state)
        # Step for each head
        monotonic_step = monotonic_cache.get(
            'head_step',
            p_choose.new_zeros(1, self.num_heads).long()
        )
        assert monotonic_step is not None
        finish_read = monotonic_step.eq(max_steps)
        p_choose_i = torch.tensor(1)

        while finish_read.sum().item() < self.num_heads:
            # p_choose: self.num_heads, src_len
            # only choose the p at monotonic steps
            # p_choose_i: 1, self.num_heads
            p_choose_i = (
                p_choose.gather(
                    1,
                    monotonic_step
                    .clamp(0, src_len - 1),
                )
            )

            # read_one_step: 1, self.num_heads
            # 1 means the head keeps reading (p_choose_i < 0.5),
            # 0 means the head stops here or has already finished reading
            read_one_step = (
                (p_choose_i < 0.5)
                .type_as(monotonic_step)
                .masked_fill(finish_read, 0)
            )

            monotonic_step += read_one_step

            finish_read = monotonic_step.eq(max_steps) | (read_one_step == 0)

        # p_choose at the final steps
        p_choose_i = (
            p_choose.gather(
                1,
                monotonic_step
                .clamp(0, src_len - 1),
            )
        )

        monotonic_cache["head_step"] = monotonic_step
        # Whether a head is looking for new input
        monotonic_cache["head_read"] = (
            monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
        )
        self._set_monotonic_buffer(incremental_state, monotonic_cache)

        # 3. Update alpha
        alpha = (
            p_choose
            .new_zeros([self.num_heads, src_len])
            .scatter(
                1,
                (monotonic_step)
                .view(self.num_heads, 1).clamp(0, src_len - 1),
                1
            )
        )

        if not self.mass_preservation:
            alpha = alpha.masked_fill(
                (monotonic_step == max_steps)
                .view(self.num_heads, 1),
                0
            )

        # 4. Compute beta
        if self.soft_attention:
            monotonic_step = monotonic_step.t()
            beta_mask = torch.arange(src_len).expand_as(alpha).gt(monotonic_step).unsqueeze(1)
            # If it's soft attention just do softmax on current context
            soft_energy = self.energy_from_qk(
                query,
                key,
                "soft"
            )
            beta = torch.nn.functional.softmax(
                soft_energy.masked_fill(beta_mask, -float("inf")), dim=-1
            )
            # It could happen that a head doesn't move at all
            beta = beta.masked_fill(monotonic_step.eq(0).unsqueeze(1), 0)
        else:
            # If it's hard attention just select the last state
            beta = alpha

        return p_choose, alpha, beta
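
    # Note (added): the "head_step" / "head_read" entries stored in the
    # incremental state above are what downstream simultaneous decoding code
    # can inspect to decide whether to read another source token or write a
    # target token (an assumption about usage; the policy itself is not
    # defined in this file).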

    def monotonic_attention_process_train(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
    ):
        """
        Calculating monotonic attention process for training
        Including:
            stepwise probability: p_choose
            expected hard alignment: alpha
            expected soft attention: beta
        """
        assert query is not None
        assert key is not None

        # 1. compute stepwise probability
        p_choose = self.p_choose_from_qk(query, key, key_padding_mask)

        # 2. compute expected_alignment
        alpha = expected_alignment_from_p_choose(
            p_choose,
            key_padding_mask,
            eps=self.eps,
        )

        if self.mass_preservation:
            alpha = mass_preservation(
                alpha, key_padding_mask
            )

        # 3. compute expected soft attention (soft aligned model only)
        if self.soft_attention:
            soft_energy = self.energy_from_qk(
                query,
                key,
                "soft",
                key_padding_mask=None,
            )

            beta = expected_soft_attention(
                alpha,
                soft_energy,
                padding_mask=key_padding_mask,
                chunk_size=self.chunk_size,
                eps=self.eps,
            )
        else:
            beta = alpha
            soft_energy = alpha

        return p_choose, alpha, beta, soft_energy
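
    # For reference (added comment): in expectation, monotonic attention computes
    #   alpha_{i,j} = p_{i,j} * sum_{k<=j} alpha_{i-1,k} * prod_{l=k}^{j-1} (1 - p_{i,l})
    # and infinite-lookback soft attention computes
    #   beta_{i,j}  = sum_{k>=j} alpha_{i,k} * exp(u_{i,j}) / sum_{l<=k} exp(u_{i,l})
    # where u is the soft energy; expected_alignment_from_p_choose and
    # expected_soft_attention are assumed to implement these closed forms.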

    def forward(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False,
    ):
        """
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        value: src_len, bsz, embed_dim
        """

        assert attn_mask is None
        assert query is not None
        assert key is not None
        assert value is not None

        tgt_len, bsz, embed_dim = query.size()
        src_len = value.size(0)

        if key_padding_mask is not None:
            assert not key_padding_mask[:, 0].any(), (
                "Only right padding is supported."
            )
            key_padding_mask = (
                key_padding_mask
                .unsqueeze(1)
                .expand([bsz, self.num_heads, src_len])
                .contiguous()
                .view(-1, src_len)
            )

        if incremental_state is not None:
            # Inference
            (
                p_choose, alpha, beta
            ) = self.monotonic_attention_process_infer(
                query, key, incremental_state
            )
            soft_energy = beta
        else:
            # Train
            (
                p_choose, alpha, beta, soft_energy
            ) = self.monotonic_attention_process_train(
                query, key, key_padding_mask
            )

        v = self.v_proj(value)
        length, bsz, _ = v.size()
        v = (
            v.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        attn = torch.bmm(beta.type_as(v), v)

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.out_proj(attn)

        p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
        alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
        beta = beta.view(bsz, self.num_heads, tgt_len, src_len)

        return attn, {
            "p_choose": p_choose,
            "alpha": alpha,
            "beta": beta,
        }
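
    # Note (added): forward returns the attended output together with a dict of
    # per-head probabilities; attn has shape (tgt_len, bsz, embed_dim) and
    # p_choose / alpha / beta have shape (bsz, num_heads, tgt_len, src_len).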

    def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
        maybe_incremental_state = self.get_incremental_state(
            incremental_state,
            'monotonic',
        )
        if maybe_incremental_state is None:
            typed_empty_dict: Dict[str, Optional[Tensor]] = {}
            return typed_empty_dict
        else:
            return maybe_incremental_state

    def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
        self.set_incremental_state(
            incremental_state,
            'monotonic',
            buffer,
        )


@register_monotonic_attention("infinite_lookback")
class MonotonicInfiniteLookbackAttention(
    MonotonicAttention
):
    def __init__(self, args):
        super().__init__(args)
        self.soft_attention = True
        self.init_soft_attention()

    def init_soft_attention(self):
        self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
        self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_in_proj["soft"] = self.k_proj_soft
        self.q_in_proj["soft"] = self.q_proj_soft

        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(
                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
            nn.init.xavier_uniform_(
                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
        else:
            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
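
    # Note (added): the "soft" projections above are separate from the
    # "monotonic" ones, so the read/write decision (p_choose) and the soft
    # attention within the attended prefix are learned with independent
    # parameters.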


@register_monotonic_attention("waitk")
class WaitKAttention(
    MonotonicInfiniteLookbackAttention
):
    """
    STACL: Simultaneous Translation with Implicit Anticipation and
    Controllable Latency using Prefix-to-Prefix Framework
    https://www.aclweb.org/anthology/P19-1289/
    """
    def __init__(self, args):
        super().__init__(args)
        self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
        self.k_in_proj["soft"] = self.k_in_proj["monotonic"]

        self.waitk_lagging = args.waitk_lagging
        assert self.waitk_lagging > 0, (
            f"Lagging has to be larger than 0, got {self.waitk_lagging}."
        )

    @staticmethod
    def add_args(parser):
        super(
            MonotonicInfiniteLookbackAttention,
            MonotonicInfiniteLookbackAttention
        ).add_args(parser)

        parser.add_argument(
            "--waitk-lagging", type=int, required=True, help="Wait K lagging"
        )

    def p_choose_from_qk(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        assert query is not None
        assert key is not None

        p_choose = waitk_p_choose(
            tgt_len=query.size(0),
            src_len=key.size(0),
            bsz=query.size(1) * self.num_heads,
            waitk_lagging=self.waitk_lagging,
            key_padding_mask=key_padding_mask,
            incremental_state=incremental_state,
        )

        return p_choose.to(query)
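
    # Note (added): for wait-k the read/write schedule is fixed rather than
    # learned; waitk_p_choose is expected to place probability 1 on the source
    # position that lags the current target position by k, which makes alpha a
    # deterministic step-wise alignment.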


@register_monotonic_attention("chunkwise")
class ChunkwiseAttention(
    MonotonicInfiniteLookbackAttention
):
    def __init__(self, args):
        super().__init__(args)
        self.chunk_size = args.mocha_chunk_size
        assert self.chunk_size > 1

    @staticmethod
    def add_args(parser):
        super(
            MonotonicInfiniteLookbackAttention,
            MonotonicInfiniteLookbackAttention
        ).add_args(parser)

        parser.add_argument(
            "--mocha-chunk-size", type=int,
            required=True, help="Mocha chunk size"
        )
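

# Illustrative usage sketch (added; not part of the original module). The args
# namespace below only lists the attributes read in this file, with made-up
# values:
#
#     from argparse import Namespace
#     args = Namespace(
#         decoder_embed_dim=512, decoder_attention_heads=8,
#         encoder_embed_dim=512, attention_dropout=0.0,
#         noise_type="flat", noise_mean=0.0, noise_var=1.0,
#         energy_bias=False, energy_bias_init=-2.0,
#         mass_preservation=True, attention_eps=1e-6,
#         waitk_lagging=3,
#     )
#     attention = WaitKAttention(args)
#     # attn, extra = attention(query, key, value, key_padding_mask=mask)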