Spaces:
Running
Running
| """ | |
| Gradio interface for browsing the boards on which a dictionary feature activates most strongly. | |
| """ | |
| import chess | |
| import gradio as gr | |
| import uuid | |
| import torch | |
| from lczerolens.encodings import encode_move | |
| from src import constants, global_variables, visualisation | |
def render_feature_index(
    file_id,
    feature_index
):
    """Render the 16 dataset rows on which a feature activates most strongly.

    For each of the top-16 rows of ``opt_features[:, feature_index]`` the
    64 consecutive activations surrounding the row are drawn as a heatmap
    over the corresponding chess position, together with an arrow for the
    move stored in the sample. Each board is written to
    ``{constants.FIGURES_FOLER}/{file_id}_{rank}.svg``.

    Args:
        file_id: Prefix used for the SVG file names. When ``None`` a fresh
            UUID4 string is generated.
        feature_index: Column of ``opt_features`` to visualise. Coerced to
            ``int`` because Gradio documents slider values as floats.

    Returns:
        A flat tuple ``(file_id, *svg_paths, *colorbar_figures)`` with 16
        SVG paths followed by 16 figures, in decreasing activation order.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    # A float cannot be used as a tensor index; a whole-number float from the
    # slider is safely converted here.
    feature_index = int(feature_index)

    dataset = global_variables.f_ds
    opt_features = dataset["opt_features"]
    # Fetch the column once instead of on every loop iteration.
    pixel_indices = dataset["pixel_index"]
    top_indices = opt_features[:, feature_index].topk(16).indices

    board_images = []
    colorbars = []
    for rank, idx in enumerate(top_indices):
        sample = dataset[idx.item()]
        # First row of the 64-row window containing ``idx``.
        # NOTE(review): presumably ``pixel_index`` is the square's position
        # within its board, so the window covers exactly one board - confirm.
        window_start = int((idx - pixel_indices[idx]).item())
        features = torch.stack(
            [
                opt_features[window_start + square, feature_index]
                for square in range(64)
            ]
        )

        fen = sample["opt_fen"]
        current_depth = sample["current_depth"]
        # NOTE(review): the +6 offset into ``moves_opt`` is taken from the
        # original code; its meaning is not visible here - confirm.
        uci_move = sample["moves_opt"][current_depth + 6]
        move = chess.Move.from_uci(uci_move)
        board = chess.Board(fen)

        if board.turn:
            heatmap = features.view(64)
        else:
            # Black to move: reverse the order of the 8 rows.
            # NOTE(review): presumably the activations are stored from the
            # side-to-move's perspective - confirm.
            heatmap = features.view(8, 8).flip(0).view(64)

        svg_board, fig = visualisation.render_heatmap(
            board,
            heatmap,
            arrows=[(move.from_square, move.to_square)],
        )

        # Build the path once so the written file and the returned path
        # cannot drift apart.
        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{rank}.svg"
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)

    return file_id, *board_images, *colorbars
# Layout: one slider selecting the feature, followed by a 4x4 grid of
# (board image, colorbar plot) pairs - one pair per top activation returned
# by render_feature_index.
GRID_ROWS = 4
GRID_COLS = 4

with gr.Blocks() as interface:
    with gr.Row():
        feature_index = gr.Slider(
            minimum=0,
            maximum=constants.DICTIONARY_SIZE - 1,
            value=0,
            step=1,
            label="Feature index",
        )

    board_images = []
    colorbars = []
    for i in range(GRID_ROWS):
        with gr.Row():
            for j in range(GRID_COLS):
                with gr.Column(), gr.Group():
                    # Slot number in row-major order; used only for labels.
                    idx = GRID_COLS * i + j
                    with gr.Row():
                        board_slot = gr.Image(label=f"Board {idx}")
                    board_images.append(board_slot)
                    with gr.Row():
                        colorbar_slot = gr.Plot(label=f"Colorbar {idx}")
                    colorbars.append(colorbar_slot)

    # Per-session prefix for the generated SVG files; the callback fills it
    # in on the first call and it is passed back on later ones.
    file_id = gr.State(None)

    # Re-render every board and colorbar whenever the slider moves.
    feature_index.change(
        fn=render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id] + board_images + colorbars,
    )