Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import itertools as it | |
| from typing import Any, Dict, List | |
| import torch | |
| from fairseq.data.dictionary import Dictionary | |
| from fairseq.models.fairseq_model import FairseqModel | |
class BaseDecoder:
    """Common scaffolding for decoders that map model emissions to tokens.

    The constructor resolves the blank and silence symbol indices from the
    target dictionary; ``generate`` runs the first model and hands its
    emissions to ``decode``, which concrete subclasses must implement.
    """

    def __init__(self, tgt_dict: Dictionary) -> None:
        """Store the dictionary and look up the special symbol indices.

        Args:
            tgt_dict: Target dictionary. It must support ``len()`` and expose
                ``indices``, ``index()``, ``bos()`` and ``eos()``.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        # Use the dedicated "<ctc_blank>" symbol when the dictionary has one;
        # otherwise the BOS index doubles as the blank.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()

        # Silence is the first of "<sep>" / "|" present in the dictionary,
        # falling back to the EOS index when neither exists.
        for symbol in ("<sep>", "|"):
            if symbol in tgt_dict.indices:
                self.silence = tgt_dict.index(symbol)
                break
        else:
            self.silence = tgt_dict.eos()

    def generate(
        self, models: List[FairseqModel], sample: Dict[str, Any], **unused
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Compute emissions for ``sample`` and decode them.

        Args:
            models: Models to run; only the first one is used (see
                ``get_emissions``).
            sample: Batch dictionary whose ``"net_input"`` entry holds the
                keyword arguments for the model.
            **unused: Extra keyword arguments, accepted and ignored.

        Returns:
            Whatever ``decode`` returns for the computed emissions.
        """
        # "prev_output_tokens" is dropped so it is never passed to the model.
        encoder_input = {
            name: value
            for name, value in sample["net_input"].items()
            if name != "prev_output_tokens"
        }
        return self.decode(self.get_emissions(models, encoder_input))

    def get_emissions(
        self,
        models: List[FairseqModel],
        encoder_input: Dict[str, Any],
    ) -> torch.FloatTensor:
        """Run the first model and return its emissions as a CPU float tensor.

        Args:
            models: Models to run; only ``models[0]`` is used.
            encoder_input: Keyword arguments forwarded to the model call.

        Returns:
            The emissions with their first two dimensions swapped, cast to
            float, moved to the CPU and made contiguous.
        """
        model = models[0]
        net_output = model(**encoder_input)

        # Prefer the model's own ``get_logits`` when it offers one; otherwise
        # fall back to log-probabilities from ``get_normalized_probs``.
        if not hasattr(model, "get_logits"):
            emissions = model.get_normalized_probs(net_output, log_probs=True)
        else:
            emissions = model.get_logits(net_output)

        # NOTE(review): presumably (time, batch, vocab) -> (batch, time,
        # vocab); confirm against the model's output layout.
        swapped = emissions.transpose(0, 1)
        return swapped.float().cpu().contiguous()

    def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse repeated indices and strip blanks from an index sequence.

        Args:
            idxs: Iterable of token indices.

        Returns:
            A ``torch.LongTensor`` holding one entry per run of equal
            consecutive indices, excluding runs of the blank index.
        """
        # ``groupby`` yields one key per run of consecutive equal values.
        collapsed = [
            token for token, _run in it.groupby(idxs) if token != self.blank
        ]
        return torch.LongTensor(collapsed)

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Turn emissions into hypotheses; must be provided by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError