| """Script to create MLE model for retriever component of CRB-CRS. | |
| For ReDial, use the following command: | |
| python -m script.crb_crs.create_mle \ | |
| --corpus_file data/redial/GT_corpus_tokens.txt \ | |
| --output_file data/models/crb_crs/mle_model.pkl | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| from src.model.crb_crs.retriever.mle_model import NGramMLE | |
def parse_args() -> argparse.Namespace:
    """Parses command line arguments.

    Returns:
        Namespace with the attributes `corpus_file`, `output_file` and `n`.

    Raises:
        SystemExit: Raised by argparse when a required argument is missing or
          when `--n` is not a positive integer.
    """

    def positive_int(value: str) -> int:
        """Converts `value` to an int and rejects values smaller than 1."""
        # A ValueError from int() is caught by argparse and reported as a
        # regular usage error, so it does not need handling here.
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"n-gram order must be a positive integer, got {number}."
            )
        return number

    parser = argparse.ArgumentParser(
        description="Create MLE model for retriever component of CRB-CRS."
    )
    parser.add_argument(
        "--corpus_file",
        type=str,
        required=True,
        help="Path to the corpus file.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Path to save the created MLE model.",
    )
    parser.add_argument(
        "--n",
        # An n-gram order below 1 is meaningless; fail fast at the CLI
        # boundary instead of passing it on to the model.
        type=positive_int,
        default=2,
        help="Maximum n-gram order. Defaults to 2.",
    )
    return parser.parse_args()
def main(args: argparse.Namespace) -> None:
    """Creates MLE model for retriever component of CRB-CRS.

    Instantiates `NGramMLE` with the requested order and corpus file, calls
    `create_ngrams()` on it, and saves the model to the output path. The
    parent directory of the output path is created when it does not exist.

    Args:
        args: Command line arguments. The attributes `n`, `corpus_file` and
          `output_file` are read.
    """
    model = NGramMLE(args.n, args.corpus_file)
    model.create_ngrams()

    # os.path.dirname() returns "" for a bare filename (e.g. "model.pkl") and
    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when the output path actually has a directory component.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model.save(args.output_file)
    logging.info(f"MLE model saved at {args.output_file}.")
if __name__ == "__main__":
    # Set the root logger to INFO so that the confirmation message logged by
    # main() is emitted.
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    main(cli_args)