import torch

from .decode_strategy import DecodeStrategy


class BeamSearch(DecodeStrategy):
    """Generation with beam search."""

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best,
                 min_length, return_attention, max_length):
        super(BeamSearch, self).__init__(
            pad, bos, eos, batch_size, beam_size, min_length,
            return_attention, max_length)
        self.beam_size = beam_size
        self.n_best = n_best
        # result caching
        self.hypotheses = [[] for _ in range(batch_size)]
        # beam state
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)
        self.select_indices = None
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Repeat src objects `beam_size` times."""
        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)
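
        # For example, with batch_size=2 and beam_size=3, fn_map_state turns
        # a state whose dim-0 rows are [s0, s1] into [s0, s0, s0, s1, s1, s1]:
        # repeat_interleave keeps all beams of one example adjacent, matching
        # the flat (batch * beam) layout used throughout this class.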
        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
        if device is None:
            device = memory_bank.device
        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)
        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)
        # Start with only the first beam of each example live (score 0.0,
        # the others -inf), so the first topk does not select beam_size
        # copies of the same token.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size)
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty(
            [self.batch_size, self.beam_size], dtype=torch.long, device=device)
        return fn_map_state, memory_bank

    def current_predictions(self):
        return self.alive_seq[:, -1]

    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size)

    def batch_offset(self):
        return self._batch_offset

    def _pick(self, log_probs):
        """Return token decision for a step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size)
        """
        vocab_size = log_probs.size(-1)
        # Flatten each example's beam_size x vocab_size candidate scores so
        # that all of its candidates compete in a single topk.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids
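
    # The topk_ids returned by _pick index the flattened (beam * vocab) axis:
    # with beam_size=2 and vocab_size=5, for example, flat id 7 decodes to
    # beam 7 // 5 = 1 and token 7 % 5 = 2. advance() below recovers the beam
    # origin with an integer division and the word id with fmod_.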

    def advance(self, log_probs, attn):
        """
        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)
            attn (FloatTensor): attention for the current step; dim 1
                indexes the (B * beam_size) flat beam entries
        """
        vocab_size = log_probs.size(-1)
        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size
        step = len(self)  # current length of alive_seq
        self.ensure_min_length(log_probs)
        # Accumulate: add each beam's running log prob to its token scores.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
        curr_length = step + 1
        # Select on length-averaged log probs...
        curr_scores = log_probs / curr_length  # avg log_prob
        self.topk_scores, self.topk_ids = self._pick(curr_scores)
        # topk_scores/topk_ids: (batch_size, beam_size)
        # ...but store the un-averaged sums so the next step accumulates
        # consistently.
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)
        # Resolve beam origin and map to batch index flat representation.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids
        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)
        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()
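
    # When return_attention is set, alive_attn grows along dim 0, one slice
    # per decoded token, giving shape (step - 1, B * beam_size, src_len);
    # update_finished() views it as (step - 1, B, beam_size, src_len) when
    # extracting the attention of finished hypotheses.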

    def update_finished(self):
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        # Sink finished beams so they are never selected again.
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # An example is done once its top beam is finished and at least
            # n_best hypotheses have been collected for it.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)
        if len(non_finished) == 0:
            self.done = True
            return
        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step.
        self.top_beam_finished = self.top_beam_finished.index_select(
            0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
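

# A minimal sketch of the decode loop this class is designed to drive, kept
# as a comment because it depends on a model interface that is not part of
# this file: `model.encode` and `model.decode_step` are hypothetical
# stand-ins, while every `search.*` call mirrors the API above.
#
#     search = BeamSearch(pad=1, bos=2, eos=3, batch_size=4, beam_size=5,
#                         n_best=1, min_length=0, return_attention=False,
#                         max_length=100)
#     fn_map_state, memory_bank = search.initialize(model.encode(src))
#     dec_state = fn_map_state(dec_state, dim=0)  # tile state per beam
#     while not search.done:
#         log_probs, attn = model.decode_step(
#             search.current_predictions(), dec_state, memory_bank)
#         search.advance(log_probs, attn)
#         if search.is_finished.any():
#             search.update_finished()
#             if search.done:
#                 break
#             memory_bank = memory_bank.index_select(0, search.select_indices)
#         # Reorder any per-beam state to follow the surviving beams.
#         dec_state = dec_state.index_select(0, search.select_indices)
#
# Results are then available in search.predictions, search.scores, and
# search.attention, filled in by update_finished().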