| | import torch |
| | import joblib |
| | import numpy as np |
| | import mdtraj as md |
| | import matplotlib.pyplot as plt |
| | import pyemma.coordinates as coor |
| |
|
| | from .utils import compute_dihedral |
| | from matplotlib.colors import LinearSegmentedColormap |
| |
|
class Plot:
    """Draw sampled paths on top of a 2-D landscape and save the figure.

    For ``molecule == "aldp"`` the paths are drawn in the (phi, psi) dihedral
    plane over the landscape stored in ``./data/aldp/landscape.dat``.  For any
    other molecule they are drawn in the plane of the first two TICA
    components over the PMF stored in ``./data/<molecule>/pmf.npy``.

    Args:
        args: Object exposing ``device``, ``save_dir``, ``molecule``,
            ``start_state`` and ``num_samples``.
        mds: Object exposing ``start_position``, ``target_position`` and
            ``energy_function``.
    """

    # Colour ramps for the two kinds of background.
    _PMF_COLORS = ["#05009E", "#6B67EE", "#50B2D7", "#B0ADF1"]
    _LANDSCAPE_COLORS = ["#05009E", "#6B67EE", "#50B2D7", "#F7EFFF"]

    # Atom indices handed to ``compute_dihedral`` for each angle.
    _PSI_ATOMS = [6, 8, 14, 16]
    _PHI_ATOMS = [1, 6, 8, 14]

    _ALDP_LANDSCAPE = "./data/aldp/landscape.dat"
    _ALDP_GRID = (90, 90)

    # Start/target/end markers are drawn well above the path points.
    _MARKER_ZORDER = 32
    _STAR_SIZE = 500

    def __init__(self, args, mds):
        self.device = args.device
        self.save_dir = args.save_dir
        self.molecule = args.molecule
        self.start_state = args.start_state
        self.num_samples = args.num_samples
        self.start_position = mds.start_position
        self.target_position = mds.target_position
        self.energy_function = mds.energy_function

    def __call__(self):
        """Load ``num_samples`` saved paths from ``save_dir`` and plot them.

        Path ``i`` is read from ``{save_dir}/positions/{i}.npy``.
        """
        positions, potentials = [], []
        for i in range(self.num_samples):
            position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            # NOTE(review): the potentials are collected but not used by the
            # plot; the evaluation is kept so behaviour stays unchanged.
            potentials.append(self.energy_function(position)[1])
            positions.append(torch.from_numpy(position).to(self.device))
        self.paths(positions)

    def paths(self, positions):
        """Plot every path in ``positions`` and save ``{save_dir}/paths.png``.

        Args:
            positions: Iterable of torch tensors, one per path.  Each tensor
                is indexed as ``[frame, atom, coordinate]``.

        Returns:
            The matplotlib figure.  It has already been closed via
            ``plt.close()`` by the time it is returned.
        """
        if self.molecule == "aldp":
            fig = self._draw_dihedral_plane(positions)
        else:
            fig = self._draw_tica_plane(positions)

        plt.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/paths.png", dpi=300, bbox_inches="tight")
        plt.show()
        plt.close()
        return fig

    # ------------------------------------------------------------------
    # Landscape helpers (pure torch / numpy)
    # ------------------------------------------------------------------
    @staticmethod
    def _read_landscape(path, dims=(90, 90)):
        """Parse a whitespace-separated landscape file into grid tensors.

        The first line is treated as a header and skipped.  On every other
        non-empty line the first two fields are the coordinates and the last
        field is the value.  Rows fill the grid in row-major order; cells
        without a row stay zero.

        Returns:
            ``(locations, values)`` with shapes ``(*dims, 2)`` and ``dims``.
        """
        n_rows, n_cols = int(dims[0]), int(dims[1])
        locations = torch.zeros((n_rows, n_cols, 2))
        values = torch.zeros((n_rows, n_cols))
        with open(path, encoding="utf-8") as handle:
            next(handle, None)  # header line
            cell = 0
            for line in handle:
                # split() copes with a missing trailing newline, repeated
                # blanks and tabs; slicing off the last character did not.
                fields = line.split()
                if not fields:
                    continue
                row, col = divmod(cell, n_cols)
                locations[row, col, 0] = float(fields[0])
                locations[row, col, 1] = float(fields[1])
                values[row, col] = float(fields[-1])
                cell += 1
        return locations, values

    @staticmethod
    def _sample_landscape(locations, values, xs, ys):
        """Look up the nearest landscape value for every point of a mesh.

        Args:
            locations: ``(rows, cols, 2)`` tensor of grid coordinates.
            values: ``(rows, cols)`` tensor of landscape values.
            xs, ys: 1-D numpy arrays spanning the query mesh.

        Returns:
            Tensor of shape ``(len(ys), len(xs))``.
        """
        mesh_x, mesh_y = np.meshgrid(xs, ys)
        queries = torch.tensor(np.array([mesh_x, mesh_y])).view(2, -1).T
        flat_locations = locations.view(-1, 2).double()
        nearest = torch.cdist(queries, flat_locations, p=2).argmin(dim=1)
        # ``view(-1, 2)`` flattens row-major, so the flat index has to be
        # decomposed with the number of columns.
        n_cols = locations.shape[1]
        rows = torch.div(nearest, n_cols, rounding_mode="trunc")
        cols = nearest % n_cols
        return values[rows, cols].view(mesh_y.shape[0], mesh_y.shape[1])

    # ------------------------------------------------------------------
    # Marker helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _draw_path(ax, xs, ys):
        """Draw one path as small white dots without connecting lines."""
        ax.plot(
            xs,
            ys,
            marker="o",
            linestyle="None",
            markersize=2,
            alpha=1.0,
            markerfacecolor="white",
            markeredgecolor="none",
            markeredgewidth=0,
        )

    def _draw_path_end(self, ax, x, y):
        """Mark the last frame of a path with a filled circle."""
        ax.scatter(
            [x],
            [y],
            s=70,
            c="#D577FF",
            edgecolors="w",
            linewidths=0.8,
            zorder=self._MARKER_ZORDER + 1,
            marker="o",
        )

    def _draw_state(self, ax, xs, ys):
        """Mark a start or target state with a star."""
        ax.scatter(
            xs,
            ys,
            edgecolors="w",
            c="#9793F8",
            zorder=self._MARKER_ZORDER,
            s=self._STAR_SIZE,
            marker="*",
        )

    # ------------------------------------------------------------------
    # Dihedral (aldp) plot
    # ------------------------------------------------------------------
    def _dihedrals(self, position):
        """Return ``(phi, psi)`` numpy arrays for ``position``.

        Presumably angles in radians, given the +/- pi axis limits of the
        plot; confirm against ``compute_dihedral``.
        """
        psi = compute_dihedral(position[:, self._PSI_ATOMS, :]).detach().cpu().numpy()
        phi = compute_dihedral(position[:, self._PHI_ATOMS, :]).detach().cpu().numpy()
        return phi, psi

    def _draw_dihedral_plane(self, positions):
        """Draw the landscape, paths and end states in the (phi, psi) plane."""
        # Drop whatever figure pyplot currently holds before starting.
        plt.clf()
        plt.close()
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        plt.xlim([-np.pi, np.pi])
        plt.ylim([-np.pi, np.pi])

        locations, values = self._read_landscape(self._ALDP_LANDSCAPE, self._ALDP_GRID)
        xs = np.arange(-np.pi, np.pi + 0.1, 0.1)
        ys = np.arange(-np.pi, np.pi + 0.1, 0.1)
        z = self._sample_landscape(locations, values, xs, ys)
        cmap = LinearSegmentedColormap.from_list("my_cmap", self._LANDSCAPE_COLORS)
        plt.contourf(xs, ys, z, levels=100, zorder=0, cmap=cmap)

        for position in positions:
            phi, psi = self._dihedrals(position)
            self._draw_path(ax, phi, psi)
            self._draw_path_end(ax, phi[-1], psi[-1])

        start_phi, start_psi = self._dihedrals(self.start_position)
        target_phi, target_psi = self._dihedrals(self.target_position)
        self._draw_state(ax, start_phi, start_psi)
        self._draw_state(ax, target_phi, target_psi)

        plt.xlabel("\u03A6", fontsize=35, fontweight="medium")
        plt.ylabel("\u03A8", fontsize=35, fontweight="medium")
        return fig

    # ------------------------------------------------------------------
    # TICA plot
    # ------------------------------------------------------------------
    @staticmethod
    def _project(position, topology, feat, tica_model):
        """Featurize ``position`` and transform it with the TICA model."""
        traj = md.Trajectory(position.cpu().numpy(), topology)
        return tica_model.transform(feat.transform(traj))

    def _draw_tica_plane(self, positions):
        """Draw the PMF, paths and end states in the TIC 1 / TIC 2 plane."""
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)

        data_dir = f"./data/{self.molecule}"
        pmf = np.load(f"{data_dir}/pmf.npy")
        xs = np.load(f"{data_dir}/xs.npy")
        ys = np.load(f"{data_dir}/ys.npy")
        cmap = LinearSegmentedColormap.from_list("my_cmap", self._PMF_COLORS)
        plt.pcolormesh(xs, ys, pmf.T, cmap=cmap)

        # NOTE(review): joblib.load unpickles the file; only load models
        # from a trusted source.
        tica_model = joblib.load(f"{data_dir}/tica_model.pkl")
        pdb_path = f"{data_dir}/{self.start_state}.pdb"
        feat = coor.featurizer(pdb_path)
        feat.add_backbone_torsions(cossin=True)
        # The topology is the same for every path: read the PDB once
        # instead of once per trajectory.
        topology = md.load(pdb_path).topology

        for position in positions:
            tica = self._project(position, topology, feat, tica_model)
            self._draw_path(ax, tica[:, 0], tica[:, 1])
            self._draw_path_end(ax, tica[-1, 0], tica[-1, 1])

        start_tica = self._project(self.start_position, topology, feat, tica_model)
        self._draw_state(ax, start_tica[:, 0], start_tica[:, 1])
        target_tica = self._project(self.target_position, topology, feat, tica_model)
        self._draw_state(ax, target_tica[:, 0], target_tica[:, 1])

        plt.xlabel("TIC 1", fontsize=35, fontweight="medium")
        plt.ylabel("TIC 2", fontsize=35, fontweight="medium")
        plt.xlim(xs.min(), xs.max())
        plt.ylim(ys.min(), ys.max())
        return fig
| |
|