Spaces:
Running
Running
| """Manage global variables for the app. | |
| """ | |
| from huggingface_hub import HfApi | |
| import gradio as gr | |
| from lczerolens import ModelWrapper | |
| import torch | |
| from datasets import load_dataset, Dataset | |
| from src import constants | |
| from src.helpers import SparseAutoEncoder, OutputGenerator | |
# Module-level state shared across the app. These names are only annotated
# here, not assigned: each one stays unbound until `setup()` declares it
# `global` and assigns it.
hf_api: HfApi  # Hub client built with constants.HF_TOKEN
wrapper: ModelWrapper  # ONNX model loaded from the downloaded models folder
sae: SparseAutoEncoder  # sparse autoencoder for constants.SAE_CONFIG
generator: OutputGenerator  # built from `sae` and `wrapper`
f_ds: Dataset  # "test" split of constants.FEATURE_DATASET (torch format)
def _download_assets(api):
    """Mirror the model and SAE Hub repositories into ``constants.ASSETS_FOLDER``.

    Args:
        api: the :class:`HfApi` client used to perform the downloads.
    """
    repos = (
        ("models", "lczero-planning/models"),
        ("saes", "lczero-planning/saes"),
    )
    for subfolder, repo_id in repos:
        api.snapshot_download(
            local_dir=f"{constants.ASSETS_FOLDER}/{subfolder}",
            repo_id=repo_id,
            repo_type="model",
        )


def _load_sae():
    """Build the SAE for ``constants.SAE_CONFIG`` and load its saved weights.

    Returns:
        The :class:`SparseAutoEncoder`, moved to ``constants.DEVICE``.
    """
    # NOTE(review): torch.load unpickles this file, which comes from a remote
    # Hub repo. It is passed straight to load_state_dict, so it should be a
    # plain mapping of tensors; consider `weights_only=True` once the
    # installed torch version is confirmed to support it.
    sae_dict = torch.load(
        f"{constants.ASSETS_FOLDER}/saes/{constants.SAE_CONFIG}/model.pt",
        map_location=constants.DEVICE,
    )
    model = SparseAutoEncoder(
        constants.ACTIVATION_DIM,
        constants.DICTIONARY_SIZE,
        pre_bias=constants.PRE_BIAS,
        init_normalise_dict=constants.INIT_NORMALISE_DICT,
    )
    model.load_state_dict(sae_dict)
    # load_state_dict copies values into the module's existing parameters, so
    # map_location alone does not move the module (presumably constructed on
    # the default device). Move it explicitly, as is done for `wrapper`.
    model.to(constants.DEVICE)
    return model


def setup():
    """Download assets and populate the module-level globals.

    Downloads the ``lczero-planning/models`` and ``lczero-planning/saes``
    Hub repositories into ``constants.ASSETS_FOLDER``, then assigns:

    - ``hf_api``: an :class:`HfApi` client using ``constants.HF_TOKEN``;
    - ``wrapper``: the ONNX model ``constants.MODEL_NAME``, on
      ``constants.DEVICE``;
    - ``sae``: the sparse autoencoder for ``constants.SAE_CONFIG``, on
      ``constants.DEVICE``;
    - ``generator``: an :class:`OutputGenerator` combining ``sae`` and
      ``wrapper``. It targets modules matching
      ``.*block{constants.LAYER}/conv2/relu``;
    - ``f_ds``: the ``test`` split of ``constants.FEATURE_DATASET``
      (config ``constants.SAE_CONFIG``), formatted for torch.
    """
    global hf_api, wrapper, sae, generator, f_ds

    hf_api = HfApi(token=constants.HF_TOKEN)
    _download_assets(hf_api)

    wrapper = ModelWrapper.from_onnx_path(
        f"{constants.ASSETS_FOLDER}/models/{constants.MODEL_NAME}"
    ).to(constants.DEVICE)
    sae = _load_sae()
    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        module_exp=rf".*block{constants.LAYER}/conv2/relu",
    )
    f_ds = load_dataset(
        constants.FEATURE_DATASET,
        constants.SAE_CONFIG,
        split="test",
    ).with_format("torch")
# Gradio's reload mode re-executes this module whenever the source changes.
# Code under `gr.NO_RELOAD` is skipped on those reloads, so the downloads and
# model loading in setup() run only once per process.
if gr.NO_RELOAD:
    setup()