Spaces:
Running
Running
| """ | |
| Gradio interface for plotting policy. | |
| """ | |
| import chess | |
| import gradio as gr | |
| import uuid | |
| from lczerolens.encodings import encode_move | |
| from src import constants, global_variables, visualisation | |
def compute_features_fn(
    features,
    model_output,
    file_id,
    root_fen,
    traj_fen,
    feature_index
):
    """Generate model/SAE outputs for the given FENs and render one feature.

    The ``features`` and ``model_output`` arguments carry the current Gradio
    state; they are not read here but replaced by freshly generated values.

    Returns:
        The tuple produced by ``render_feature_index`` followed by a
        ``"WDL: ..."`` summary string for the game-info textbox.
    """
    # Discard the incoming state and recompute from the two FENs.
    model_output, _, sae_output = global_variables.generator.generate(
        root_fen=root_fen,
        traj_fen=traj_fen,
    )
    features = sae_output["f"]
    rendered = render_feature_index(
        features, model_output, file_id, feature_index, traj_fen
    )
    wdl = model_output.get("wdl")
    return (*rendered, f"WDL: {wdl}")
def render_feature_index(
    features,
    model_output,
    file_id,
    feature_index,
    traj_fen,
):
    """Render one feature as a board heatmap and save the board SVG to disk.

    An arrow is drawn for the legal move with the highest policy logit.

    Args:
        features: Feature activations, indexed as ``features[:, feature_index]``
            and reshaped to 64 squares (presumably shape (64, n_features) --
            TODO confirm against the generator).
        model_output: Mapping whose ``"policy"`` entry is indexed as
            ``[1, move_index]``.
        file_id: Name for the SVG file; a fresh UUID4 string is used when None.
        feature_index: Column of ``features`` to plot.
        traj_fen: FEN of the position to display.

    Returns:
        Tuple ``(features, model_output, file_id, svg_path, fig)`` matching the
        Gradio outputs wired to this function.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    board = chess.Board(traj_fen)

    pixel_features = features[:, feature_index]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Flip ranks when black is to move; presumably the features are stored
        # from the side-to-move's perspective -- TODO confirm.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    us_them = (board.turn, not board.turn)
    policy = model_output["policy"]
    # Row 1 of the policy batch; assumes it corresponds to traj_fen (row 0
    # presumably root_fen) -- TODO confirm against the generator.
    # max() keeps the first move among equal logits; default=None covers
    # positions with no legal moves (checkmate / stalemate).
    best_legal_move = max(
        board.legal_moves,
        key=lambda move: policy[1, encode_move(move, us_them)].item(),
        default=None,
    )
    if best_legal_move is None:
        arrows = []
    else:
        arrows = [(best_legal_move.from_square, best_legal_move.to_square)]

    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )
    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        svg_path,
        fig,
    )
# Layout: FEN inputs, feature selector, game info and colorbar on the left;
# rendered board on the right. Per-session state is held in gr.State.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            root_fen = gr.Textbox(
                label="Root FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            traj_fen = gr.Textbox(
                label="Trajectory FEN",
                lines=1,
                max_lines=1,
                value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
            )
            compute_features = gr.Button("Compute features")
            with gr.Group():
                with gr.Row():
                    # Indices are 0-based, so the last valid one is
                    # N_FEATURES - 1 (N_FEATURES itself would be out of range).
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.N_FEATURES - 1,
                        step=1,
                        value=0,
                    )
            with gr.Group():
                with gr.Row():
                    game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Per-session state shared between callbacks.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)

    compute_features.click(
        compute_features_fn,
        inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar, game_info],
    )