Spaces:
Running
Running
| """ | |
| Visualisation utils. | |
| """ | |
| import chess | |
| import chess.svg | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
# Shared colormap for the heatmaps: matplotlib's reversed "RdYlBu" map
# ("RdYlBu_r"), resampled to 1000 entries.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Alpha channel that render_heatmap substitutes into every square colour.
ALPHA = 1.0
# Normalisation over [0, 1] without clipping. Not referenced by the functions
# visible in this part of the file.
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Colour the 64 board squares from ``heatmap`` and build a matching colorbar.

    Args:
        board: Position forwarded unchanged to ``chess.svg.board``.
        heatmap: Values indexed with the integers 0..63. Must provide
            ``min()`` / ``max()`` (used when a bound is missing) and
            ``abs()`` / division (used when ``normalise == "abs"``).
        square: Optional square name parsed with ``chess.parse_square``; the
            result is passed as the ``check`` argument of ``chess.svg.board``.
            A name rejected with ``ValueError`` is ignored.
        vmin: Lower bound of the colour scale; defaults to ``heatmap.min()``.
        vmax: Upper bound of the colour scale; defaults to ``heatmap.max()``.
        arrows: Forwarded to ``chess.svg.board``; ``None`` becomes ``[]``.
        normalise: ``"abs"`` divides the heatmap by its largest absolute value
            (skipped when that value is 0) and forces the scale to [-1, 1],
            overriding ``vmin`` / ``vmax``. Any other value changes nothing.

    Returns:
        A ``(svg, fig)`` tuple: the result of ``chess.svg.board`` (size 350)
        and the matplotlib figure holding a horizontal colorbar. The figure
        has already been closed through ``plt.close()``.
    """
    if normalise == "abs":
        peak = heatmap.abs().max()
        if peak != 0:
            heatmap = heatmap / peak
        # "abs" mode always uses a symmetric scale, whatever the caller passed.
        vmin, vmax = -1, 1
    vmin = heatmap.min() if vmin is None else vmin
    vmax = heatmap.max() if vmax is None else vmax
    scale = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    def hex_colour(value):
        # Keep the colormap's RGB but replace its alpha with the module ALPHA.
        rgba = COLOR_MAP(scale(value))
        return matplotlib.colors.to_hex((*rgba[:3], ALPHA), keep_alpha=True)

    fills = {index: hex_colour(heatmap[index]) for index in range(64)}

    # A thin figure whose only content is the colorbar for the scale above.
    colorbar_fig = plt.figure(figsize=(6, 0.6))
    axes = plt.gca()
    axes.axis("off")
    mappable = matplotlib.cm.ScalarMappable(norm=scale, cmap=COLOR_MAP)
    colorbar_fig.colorbar(
        mappable,
        ax=axes,
        orientation="horizontal",
        fraction=1.0,
    )

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            # Unparseable square names are deliberately ignored.
            pass

    arrows = [] if arrows is None else arrows
    plt.close()
    board_svg = chess.svg.board(
        board,
        check=check,
        fill=fills,
        size=350,
        arrows=arrows,
    )
    return board_svg, colorbar_fig
def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.

    Draws two overlaid density histograms of the policy values on shared bin
    edges and a logarithmic y-axis: one for the indices found in
    ``legal_moves`` and one for every other index.

    Args:
        policy: One value per move index. It is boolean-mask indexed and
            passed to ``np.histogram``, so it is presumably a 1-D CPU
            tensor or array -- confirm against the caller.
        legal_moves: Container of legal move indices; only membership tests
            (``index in legal_moves``) are performed on it.
        n_bins: Number of histogram bins computed over the whole policy.

    Returns:
        The matplotlib figure. It is closed in pyplot before being returned
        (as ``render_heatmap`` does with its colorbar figure), so repeated
        calls do not accumulate open figures; the returned object can still
        be rendered or saved by the caller.
    """
    # Size the mask from the policy itself instead of a hard-coded move count.
    n_moves = len(policy)
    legal_mask = torch.tensor(
        [move in legal_moves for move in range(n_moves)],
        dtype=torch.bool,
    )
    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot()
        # Shared bin edges keep the two histograms directly comparable.
        _, bins = np.histogram(policy, bins=n_bins)
        ax.hist(
            policy[~legal_mask],
            bins=bins,
            alpha=0.5,
            density=True,
            label="Illegal moves",
        )
        ax.hist(
            policy[legal_mask],
            bins=bins,
            alpha=0.5,
            density=True,
            label="Legal moves",
        )
        ax.set_xlabel("Policy")
        ax.set_ylabel("Density")
        ax.legend()
        ax.set_yscale("log")
    finally:
        # Drop pyplot's reference whether or not plotting succeeded; figures
        # created through pyplot otherwise stay alive until closed.
        plt.close(fig)
    return fig