Spaces:
Running
Running
| """Script to generate features for a given board state. | |
| """ | |
| from typing import Optional | |
| from lczerolens import ModelWrapper | |
| from lczerolens.xai import ActivationLens | |
| from lczerolens.encodings import InputEncoding | |
| import chess | |
| import einops | |
| import torch | |
| from .sae import SparseAutoEncoder | |
class OutputGenerator:
    """Feed a (root, trajectory) pair of boards through a model and an SAE.

    The generator captures the activations of a single module of the wrapped
    model for both boards, concatenates them per square along the channel
    axis, and passes the result to the sparse autoencoder.
    """

    def __init__(self, sae: SparseAutoEncoder, wrapper: ModelWrapper, module_exp: Optional[str] = None):
        """Store the SAE and model wrapper and build the activation lens.

        Args:
            sae: Callable applied to the concatenated activations; it is
                invoked as ``sae(pixel_acts, output_features=True)``.
            wrapper: Model wrapper handed to the lens for the forward pass.
            module_exp: Expression forwarded to ``ActivationLens`` to select
                the module(s) to record. ``generate`` requires that it ends
                up matching exactly one module.
        """
        self.sae = sae
        self.wrapper = wrapper
        self.lens = ActivationLens(module_exp=module_exp)

    def generate(
        self,
        root_fen: Optional[str] = None,
        traj_fen: Optional[str] = None,
        root_board: Optional[chess.Board] = None,
        traj_board: Optional[chess.Board] = None,
    ):
        """Compute model output, per-square activations and SAE output.

        The two positions must be supplied as a complete pair, either as
        boards or as FEN strings. When both pairs are given, the boards take
        precedence and the FEN strings are ignored.

        Args:
            root_fen: FEN of the root position (used with ``traj_fen``).
            traj_fen: FEN of the trajectory position (used with ``root_fen``).
            root_board: Root position as a board (used with ``traj_board``).
            traj_board: Trajectory position as a board (used with
                ``root_board``).

        Returns:
            A tuple ``(model_output, pixel_acts, sae_output)`` where
            ``pixel_acts`` is the ``(h * w, 2 * c)`` tensor fed to the SAE and
            ``sae_output`` is whatever the SAE returns when called with
            ``output_features=True``.

        Raises:
            ValueError: If neither a complete board pair nor a complete FEN
                pair is given, or if the lens did not record exactly one
                module.
        """
        if root_board is not None and traj_board is not None:
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE
        elif root_fen is not None and traj_fen is not None:
            root_board = chess.Board(root_fen)
            traj_board = chess.Board(traj_fen)
            # NOTE(review): boards built from a FEN carry no move stack, which
            # is presumably why the "repeated" encoding is used here - confirm.
            input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
        else:
            raise ValueError(
                "Provide either both root_board and traj_board, "
                "or both root_fen and traj_fen."
            )

        # A single batch holding both boards, so one forward pass yields the
        # activations of the root (index 0) and the trajectory (index 1).
        iter_boards = iter([([root_board, traj_board],)])
        result_iter = self.lens.analyse_batched_boards(
            iter_boards,
            self.wrapper,
            return_output=True,
            wrapper_kwargs={
                "input_encoding": input_encoding,
            },
        )
        act_dict, (model_output,) = next(result_iter)

        if not act_dict:
            raise ValueError("No module matched the given expression.")
        if len(act_dict) > 1:
            raise ValueError("Multiple modules matched the given expression.")
        acts = next(iter(act_dict.values()))

        # Flatten each (c, h, w) activation map into one row per square, then
        # place the root and trajectory channels side by side.
        root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
        traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
        pixel_acts = torch.cat([root_acts, traj_acts], dim=1)

        sae_output = self.sae(pixel_acts, output_features=True)
        return model_output, pixel_acts, sae_output