| """Script to download external data for the project at build time.""" | |
| import argparse | |
| import logging | |
| import os | |
| import tarfile | |
| import wget | |
def download_and_extract_models(models_url: str) -> None:
    """Downloads the models folder from the server and extracts it.

    The archive is fetched into the current working directory, unpacked as
    a gzip-compressed tarball into ``data/models/`` and then deleted,
    whether or not extraction succeeded. Failures are logged and swallowed
    (best-effort), so this function does not raise on download or
    extraction errors.

    Args:
        models_url: URL to download the models from.
    """
    logging.debug("Downloading models folder.")
    models_targz = "models.tar.gz"
    models_folder = "data/models/"
    # Path actually written by wget; stays None until the download succeeds.
    downloaded_path = None
    try:
        logging.debug("Downloading models from %s.", models_url)
        # wget.download() returns the name it really saved to, which differs
        # from the requested one when that name is already taken. Opening
        # the returned path avoids extracting a stale leftover archive.
        downloaded_path = wget.download(models_url, models_targz)
        logging.debug("Extracting models folder.")
        with tarfile.open(downloaded_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths / traversal entries in the
                # downloaded archive where the interpreter supports it.
                tar.extractall(models_folder, filter="data")
            else:
                tar.extractall(models_folder)
        logging.debug("Models folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading models folder: %s", e)
    finally:
        # Always drop the archive so a failed run cannot leave a corrupt
        # file behind for the next invocation.
        if downloaded_path is not None and os.path.exists(downloaded_path):
            try:
                os.remove(downloaded_path)
            except OSError as e:
                logging.error("Error downloading models folder: %s", e)
def download_and_extract_item_embeddings(item_embeddings_url: str) -> None:
    """Downloads the item embeddings folder from the server and extracts it.

    The archive is fetched into the current working directory, unpacked as
    a bzip2-compressed tarball into ``data/`` and then deleted, whether or
    not extraction succeeded. Failures are logged and swallowed
    (best-effort), so this function does not raise on download or
    extraction errors.

    Args:
        item_embeddings_url: URL to download the item embeddings from.
    """
    logging.debug("Downloading item embeddings folder.")
    item_embeddings_tarbz = "item_embeddings.tar.bz2"
    item_embeddings_folder = "data/"
    # Path actually written by wget; stays None until the download succeeds.
    downloaded_path = None
    try:
        logging.debug(
            "Downloading item embeddings from %s.", item_embeddings_url
        )
        # wget.download() returns the name it really saved to, which differs
        # from the requested one when that name is already taken. Opening
        # the returned path avoids extracting a stale leftover archive.
        downloaded_path = wget.download(
            item_embeddings_url, item_embeddings_tarbz
        )
        logging.debug("Extracting item embeddings folder.")
        with tarfile.open(downloaded_path, "r:bz2") as tar:
            if hasattr(tarfile, "data_filter"):
                # Reject absolute paths / traversal entries in the
                # downloaded archive where the interpreter supports it.
                tar.extractall(item_embeddings_folder, filter="data")
            else:
                tar.extractall(item_embeddings_folder)
        logging.debug("Item embeddings folder downloaded and extracted.")
    except Exception as e:
        logging.error("Error downloading item embeddings folder: %s", e)
    finally:
        # Always drop the archive so a failed run cannot leave a corrupt
        # file behind for the next invocation.
        if downloaded_path is not None and os.path.exists(downloaded_path):
            try:
                os.remove(downloaded_path)
            except OSError as e:
                logging.error(
                    "Error downloading item embeddings folder: %s", e
                )
def parse_args() -> argparse.Namespace:
    """Builds the command line parser and parses ``sys.argv``.

    Two positional string arguments are registered, in this order:
    ``models`` and ``embeddings``.

    Returns:
        The parsed arguments namespace.
    """
    # (argument name, help text), in the positional order expected on the CLI.
    positional_urls = (
        ("models", "URL to download the models folder."),
        ("embeddings", "URL to download the item embeddings folder"),
    )

    cli = argparse.ArgumentParser(
        description="Download external data for the project."
    )
    for arg_name, help_text in positional_urls:
        cli.add_argument(arg_name, type=str, help=help_text)

    namespace = cli.parse_args()
    return namespace
if __name__ == "__main__":
    # Script entry point: fetch each data set whose marker path is missing.
    logging.basicConfig(level=logging.DEBUG)
    args = parse_args()

    # (marker path, log message, download function, URL), processed in order.
    # Each marker is only checked when its turn comes, i.e. after any earlier
    # download has already run.
    pending_downloads = (
        (
            "data/models",
            "Downloading models...",
            download_and_extract_models,
            args.models,
        ),
        (
            "data/embed_items",
            "Downloading item embeddings...",
            download_and_extract_item_embeddings,
            args.embeddings,
        ),
    )
    for marker_path, message, fetch, url in pending_downloads:
        if os.path.exists(marker_path):
            # Already present on disk; nothing to download.
            continue
        logging.info(message)
        fetch(url)