#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple

import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_df_from_tsv,
    save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm


log = logging.getLogger(__name__)

MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


class CoVoST(Dataset):
    """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).

    Args:
        root (str): root path to the dataset and generated manifests/features
        split (str): dataset split, one of "train", "dev" or "test"
        source_language (str): source (audio) language
        target_language (str, optional): target (text) language,
            None for no translation (default: None)
        version (int, optional): CoVoST version. (default: 2)
    """

    COVOST_URL_TEMPLATE = (
        "https://dl.fbaipublicfiles.com/covost/"
        "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
    )

    VERSIONS = {2}
    SPLITS = ["train", "dev", "test"]

    XX_EN_LANGUAGES = {
        1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
        2: [
            "fr",
            "de",
            "es",
            "ca",
            "it",
            "ru",
            "zh-CN",
            "pt",
            "fa",
            "et",
            "mn",
            "nl",
            "tr",
            "ar",
            "sv-SE",
            "lv",
            "sl",
            "ta",
            "ja",
            "id",
            "cy",
        ],
    }
    EN_XX_LANGUAGES = {
        1: [],
        2: [
            "de",
            "tr",
            "fa",
            "sv-SE",
            "mn",
            "zh-CN",
            "cy",
            "ca",
            "sl",
            "et",
            "id",
            "ar",
            "ta",
            "lv",
            "ja",
        ],
    }

    def __init__(
        self,
        root: str,
        split: str,
        source_language: str,
        target_language: Optional[str] = None,
        version: int = 2,
    ) -> None:
        assert version in self.VERSIONS and split in self.SPLITS
        assert source_language is not None
        self.no_translation = target_language is None
        if not self.no_translation:
            assert "en" in {source_language, target_language}
            if source_language == "en":
                assert target_language in self.EN_XX_LANGUAGES[version]
            else:
                assert source_language in self.XX_EN_LANGUAGES[version]
        else:
            # Hack here so that we can get "split" column from CoVoST TSV.
            # Note that we use CoVoST train split for ASR which is an extension
            # to Common Voice train split.
            target_language = "de" if source_language == "en" else "en"

        self.root: Path = Path(root)

        cv_tsv_path = self.root / "validated.tsv"
        assert cv_tsv_path.is_file()

        covost_url = self.COVOST_URL_TEMPLATE.format(
            src_lang=source_language, tgt_lang=target_language
        )
        covost_archive = self.root / Path(covost_url).name
        if not covost_archive.is_file():
            download_url(covost_url, self.root.as_posix(), hash_value=None)
        extract_archive(covost_archive.as_posix())

        cv_tsv = load_df_from_tsv(cv_tsv_path)
        covost_tsv = load_df_from_tsv(
            self.root / Path(covost_url).name.replace(".tar.gz", "")
        )
        # Join Common Voice metadata with CoVoST translations on the clip path.
        df = pd.merge(
            left=cv_tsv[["path", "sentence", "client_id"]],
            right=covost_tsv[["path", "translation", "split"]],
            how="inner",
            on="path",
        )
        if split == "train":
            df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
        else:
            df = df[df["split"] == split]
        data = df.to_dict(orient="index").items()
        data = [v for k, v in sorted(data, key=lambda x: x[0])]
        # Keep only samples whose audio clips exist and can be read.
        self.data = []
        for e in data:
            try:
                path = self.root / "clips" / e["path"]
                _ = torchaudio.info(path.as_posix())
                self.data.append(e)
            except RuntimeError:
                pass

    def __getitem__(
        self, n: int
    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
            sample_id)``
        """
        data = self.data[n]
        path = self.root / "clips" / data["path"]
        waveform, sample_rate = torchaudio.load(path)
        sentence = data["sentence"]
        translation = None if self.no_translation else data["translation"]
        speaker_id = data["client_id"]
        _id = data["path"].replace(".mp3", "")
        return waveform, sample_rate, sentence, translation, speaker_id, _id

    def __len__(self) -> int:
        return len(self.data)
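

# Example usage of the CoVoST dataset class (a minimal sketch; the data root,
# language pair and split below are illustrative assumptions, not values this
# script requires):
#
#     dataset = CoVoST("/data/common_voice/fr", "dev", "fr", "en")
#     waveform, sample_rate, sentence, translation, speaker_id, utt_id = dataset[0]
#     print(len(dataset), sample_rate, sentence, translation)
#
# The root is expected to already contain the Common Voice "validated.tsv" and
# "clips/" folder; the matching CoVoST translation TSV is downloaded into it on
# first use.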


def process(args):
    root = Path(args.data_root).absolute() / args.src_lang
    if not root.is_dir():
        raise NotADirectoryError(f"{root} does not exist")
    # Extract features
    feature_root = root / "fbank80"
    feature_root.mkdir(exist_ok=True)
    for split in CoVoST.SPLITS:
        print(f"Fetching split {split}...")
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        print("Extracting log mel filter bank features...")
        for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
            extract_fbank_features(
                waveform, sample_rate, feature_root / f"{utt_id}.npy"
            )
    # Pack features into ZIP
    zip_path = root / "fbank80.zip"
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    audio_paths, audio_lengths = get_zip_manifest(zip_path)
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    task = f"asr_{args.src_lang}"
    if args.tgt_lang is not None:
        task = f"st_{args.src_lang}_{args.tgt_lang}"
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        for _, _, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
            manifest["id"].append(utt_id)
            manifest["audio"].append(audio_paths[utt_id])
            manifest["n_frames"].append(audio_lengths[utt_id])
            manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
            manifest["speaker"].append(speaker_id)
        is_train_split = split.startswith("train")
        if is_train_split:
            train_text.extend(manifest["tgt_text"])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, root / f"{split}_{task}.tsv")
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        # Flush buffered text so SentencePiece training sees the full file.
        f.flush()
        gen_vocab(
            Path(f.name),
            root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
        )
    # Generate config YAML
    gen_config_yaml(
        root,
        spm_filename=spm_filename_prefix + ".model",
        yaml_filename=f"config_{task}.yaml",
        specaugment_policy="lb",
    )
    # Clean up
    shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root", "-d", required=True, type=str,
        help="data root with sub-folders for each language <root>/<src_lang>"
    )
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=1000, type=int)
    parser.add_argument("--src-lang", "-s", required=True, type=str)
    parser.add_argument("--tgt-lang", "-t", type=str)
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
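

# Example invocation (a hedged sketch; the data root path and the fr->en language
# pair are illustrative assumptions, and the file name assumes this script is saved
# as prep_covost_data.py):
#
#     python prep_covost_data.py \
#         --data-root /data/common_voice --src-lang fr --tgt-lang en \
#         --vocab-type char
#
# Under /data/common_voice/fr this would produce fbank80.zip, per-split
# {split}_st_fr_en.tsv manifests, a SentencePiece vocabulary with the
# spm_char_st_fr_en prefix, and config_st_fr_en.yaml; omit --tgt-lang to prepare
# an ASR (asr_fr) setup instead.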