| import sys | |
| import os | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import xarray as xr | |
| from huggingface_hub import hf_hub_download | |
| from torchvision.transforms.functional import resize | |
| sys.path.append(os.path.abspath("poseidon_demo/external/poseidon")) | |
| from external.poseidon.scOT.model import ScOT, ScOTConfig | |
def load_model():
    """
    Build a POSEIDON (ScOT) model from a fixed configuration, in evaluation mode.

    No pretrained checkpoint is fetched here: the weights are whatever
    ``ScOT(config)`` initializes them to. The configuration requests
    4 input channels and enables all four skip connections.

    Returns:
        model (ScOT): The constructed model, already switched to eval mode.
    """
    cfg = ScOTConfig(num_channels=4, skip_connections=[True] * 4)
    net = ScOT(cfg)
    net.eval()
    return net
def _synthetic_domain_input(domain, generator=None):
    """
    Build the ``(1, 4, 224, 224)`` synthetic input tensor for ``domain``.

    ``generator`` is forwarded to every ``torch.randn`` call; ``None`` means the
    global PyTorch RNG. Unknown domains fall back to standard-normal noise.
    """
    n, c = 224, 4

    def grid(lo, hi):
        # Square n x n coordinate grid over [lo, hi]^2 ("ij" indexing).
        axis = torch.linspace(lo, hi, n)
        return torch.meshgrid(axis, axis, indexing="ij")

    if domain == "Fluid Dynamics":
        # One narrow Gaussian bump centred in [-1, 1]^2, identical on all channels.
        X, Y = grid(-1, 1)
        field = torch.exp(-(X**2 + Y**2) * 10).expand(c, n, n)
    elif domain == "Finance":
        # Linear 0 -> 1 ramp along the last axis plus small per-channel noise.
        ramp = torch.linspace(0, 1, n).reshape(1, -1).repeat(n, 1)
        field = ramp + torch.randn(c, n, n, generator=generator) * 0.05
    elif domain == "Quantum":
        # sin(x) * sin(y) on [0, 4*pi]^2 (two periods per axis), same on all channels.
        X, Y = grid(0, 4 * torch.pi)
        field = (torch.sin(X) * torch.sin(Y)).expand(c, n, n)
    elif domain == "Biology / Medicine":
        # Wider Gaussian bump plus stronger per-channel noise.
        X, Y = grid(-1, 1)
        field = torch.randn(c, n, n, generator=generator) * 0.2 + torch.exp(-(X**2 + Y**2) * 5)
    else:
        # Unrecognised domain: pure standard-normal noise.
        return torch.randn(1, c, n, n, generator=generator)
    return field.unsqueeze(0)


def run_inference_by_domain(model, domain, seed=None):
    """
    Runs the model on a synthetic 4-channel 224x224 input built for ``domain``.

    Args:
        model (ScOT): The POSEIDON model. It is called as
            ``model(pixel_values=..., time=...)`` and its result must expose ``.output``.
        domain (str): One of 'Fluid Dynamics', 'Finance', 'Quantum',
            'Biology / Medicine'. Any other value falls back to pure Gaussian noise.
        seed (int | None): Optional seed for the random parts of the input
            ('Finance', 'Biology / Medicine' and the fallback), making results
            reproducible. ``None`` (default) uses PyTorch's global RNG.

    Returns:
        np.ndarray: The model output with singleton dimensions removed.
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    input_tensor = _synthetic_domain_input(domain, generator)

    # Run on whatever device the model lives on (CPU if it has no parameters).
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        output = model(
            pixel_values=input_tensor.to(device),
            time=torch.tensor([0.0], device=device),
        ).output
    # .cpu() is required before .numpy() when the model runs on an accelerator.
    return output.squeeze().cpu().numpy()
def _dataset_to_model_input(sample, size=224, channels=4):
    """
    Convert a numpy field into a ``(channels, size, size)`` float tensor.

    Arrays with more than three axes first have singleton axes squeezed away;
    any axes still in excess of three are folded into the leading (channel)
    axis. Arrays with fewer than three axes get leading singleton axes. The
    spatial axes are resized to ``size`` x ``size`` and the channel axis is
    zero-padded or truncated to ``channels``.
    """
    if sample.ndim > 3:
        sample = np.squeeze(sample)
    if sample.ndim > 3:
        # Fold extra leading axes into channels so the result is always rank 3.
        sample = sample.reshape(-1, *sample.shape[-2:])
    while sample.ndim < 3:
        sample = np.expand_dims(sample, 0)
    tensor = torch.tensor(sample)
    if tuple(tensor.shape[-2:]) != (size, size):
        tensor = resize(tensor, size=[size, size])
    missing = channels - tensor.shape[0]
    if missing > 0:
        tensor = torch.cat([tensor, tensor.new_zeros((missing, size, size))], dim=0)
    return tensor[:channels]


def run_inference_on_dataset(model, dataset_name):
    """
    Downloads one file of a scientific dataset and runs POSEIDON on its first frame.

    Index 0 is taken along the ``sample`` and ``time`` dimensions (whichever the
    variable has). The remaining field is resized to 224x224 and padded with
    zeros or truncated to exactly 4 channels before being fed to the model.

    Args:
        model (ScOT): The POSEIDON model.
        dataset_name (str): One of 'fluids.incompressible.Sines',
            'fluids.compressible.Riemann', 'reaction_diffusion.AllenCahn'.

    Returns:
        tuple: (input_array, output_array) as numpy arrays. ``input_array`` is
        the (4, 224, 224) field fed to the model; ``output_array`` is the model
        output with singleton dimensions removed.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    dataset_mapping = {
        "fluids.incompressible.Sines": {
            "repo_id": "camlab-ethz/NS-Sines",
            "filename": "velocity_0.nc",
            "variable": "velocity",
        },
        "fluids.compressible.Riemann": {
            "repo_id": "camlab-ethz/CE-RP",
            "filename": "data_0.nc",
            "variable": "data",
        },
        "reaction_diffusion.AllenCahn": {
            "repo_id": "camlab-ethz/ACE",
            "filename": "solution_0.nc",
            "variable": "solution",
        },
    }
    entry = dataset_mapping.get(dataset_name)
    if entry is None:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    file_path = hf_hub_download(
        repo_id=entry["repo_id"],
        filename=entry["filename"],
        repo_type="dataset",
    )
    # Materialise the values inside the context so the file handle is released.
    with xr.open_dataset(file_path) as ds:
        var = ds[entry["variable"]]
        print(f"Loaded shape: {var.shape}, dims: {var.dims}")
        first_frame = {dim: 0 for dim in ("sample", "time") if dim in var.dims}
        sample = var.isel(first_frame).values.astype(np.float32)

    tensor = _dataset_to_model_input(sample)

    # Run on whatever device the model lives on (CPU if it has no parameters).
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        output = model(
            pixel_values=tensor.unsqueeze(0).to(device),
            time=torch.tensor([0.0], device=device),
        ).output
    return tensor.numpy(), output.squeeze().cpu().numpy()
def plot_output(output_array, cmap="inferno", contrast=2.0):
    """
    Plots the output array from the model as a normalised, contrast-adjusted heatmap.

    The values are shifted to start at 0, divided by their range and raised to
    ``contrast``. NaN cells are ignored when computing the range. A constant
    field (zero range) is drawn as zeros instead of turning into NaNs. A 3-D
    channel-first array is reduced to its first channel.

    Args:
        output_array (np.ndarray): 2-D model output, or 3-D ``(C, H, W)``.
        cmap (str): Colormap used for visualization.
        contrast (float): Exponent applied after normalisation to [0, 1].
            Values > 1 compress low values toward 0; values < 1 spread them out.

    Returns:
        matplotlib.figure.Figure: The heatmap figure.
    """
    data = np.asarray(output_array, dtype=float)
    if data.ndim == 3:
        # seaborn.heatmap needs 2-D data; show the first channel.
        data = data[0]
    data = data - np.nanmin(data)
    span = np.nanmax(data)
    # Zero range would be 0/0 -> NaN everywhere; multiply by 0 instead (keeps NaN cells).
    data = data / span if span > 0 else data * 0.0
    data = data ** contrast

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        data,
        ax=ax,
        cmap=cmap,
        cbar=True,
        square=True,
        xticklabels=False,
        yticklabels=False,
        linewidths=0,
    )
    ax.set_title("POSEIDON Output")
    ax.axis("off")
    return fig
def plot_comparison(input_array, output_array, cmap="inferno"):
    """
    Plots a side-by-side comparison of the input and the model output.

    Args:
        input_array (np.ndarray): Ground truth or input data, either 2-D
            ``(H, W)`` (shown as-is) or channel-first (channel 0 is shown).
        output_array (np.ndarray): Output predicted by the model, either 2-D or
            channel-first ``(C, H, W)`` (channel 0 is shown). An ``(H, W, 3)`` or
            ``(H, W, 4)`` array is passed to ``imshow`` unchanged as RGB(A).
        cmap (str): Colormap used for both plots.

    Returns:
        matplotlib.figure.Figure: Figure showing input vs output.
    """
    inp = np.asarray(input_array)
    # Only channel-first arrays are indexed; slicing a 2-D field would leave a 1-D row.
    left = inp if inp.ndim == 2 else inp[0]

    out = np.asarray(output_array)
    # imshow renders (H, W, 3|4) as colour; other 3-D arrays are channel-first.
    right = out[0] if out.ndim == 3 and out.shape[-1] not in (3, 4) else out

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    panels = ((axs[0], left, "Ground Truth"), (axs[1], right, "POSEIDON Prediction"))
    for ax, image, title in panels:
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig