"""Helpers to build a POSEIDON (ScOT) model, run it on synthetic or
downloaded inputs, and plot the results."""

import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import resize

# NOTE(review): the path is resolved against the current working directory,
# so it only points at the vendored sources when the process is started from
# the directory that contains ``poseidon_demo/``. Presumably it lets the
# Poseidon package resolve its own top-level ``scOT`` imports — confirm.
sys.path.append(os.path.abspath("poseidon_demo/external/poseidon"))

from external.poseidon.scOT.model import ScOT, ScOTConfig  # noqa: E402

# Height and width (in pixels) of every tensor fed to the model.
# NOTE(review): presumably this must match ScOTConfig's image size — confirm.
_GRID_SIZE = 224

# Number of input channels the model is configured with in ``load_model``.
_NUM_CHANNELS = 4

# Hugging Face dataset repos this module knows how to fetch: public name ->
# repo id, file inside the repo, and the netCDF variable to read from it.
_DATASETS = {
    "fluids.incompressible.Sines": {
        "repo_id": "camlab-ethz/NS-Sines",
        "filename": "velocity_0.nc",
        "variable": "velocity",
    },
    "fluids.compressible.Riemann": {
        "repo_id": "camlab-ethz/CE-RP",
        "filename": "data_0.nc",
        "variable": "data",
    },
    "reaction_diffusion.AllenCahn": {
        "repo_id": "camlab-ethz/ACE",
        "filename": "solution_0.nc",
        "variable": "solution",
    },
}


def load_model():
    """
    Build a POSEIDON (ScOT) model with a fixed configuration.

    NOTE(review): no checkpoint is loaded here; the weights are whatever
    ``ScOT(config)`` initialises them to (presumably random). Confirm whether
    pretrained weights were intended.

    Returns:
        model (ScOT): The model, switched to evaluation mode.
    """
    config = ScOTConfig(
        num_channels=_NUM_CHANNELS,
        skip_connections=[True, True, True, True],
    )
    model = ScOT(config)
    model.eval()
    return model


def _predict(model, input_tensor):
    """Run one forward pass at ``time=0`` without tracking gradients.

    Args:
        model (ScOT): The POSEIDON model.
        input_tensor (torch.Tensor): Batched input passed as ``pixel_values``.

    Returns:
        np.ndarray: The model's ``.output`` with all size-1 axes squeezed out.
    """
    time_tensor = torch.tensor([0.0])
    with torch.no_grad():
        output = model(pixel_values=input_tensor, time=time_tensor).output
    return output.squeeze().numpy()


def _grid(start, stop):
    """Return ``ij``-indexed (X, Y) coordinate grids spanning [start, stop]."""
    axis = torch.linspace(start, stop, _GRID_SIZE)
    return torch.meshgrid(axis, axis, indexing="ij")


def _synthetic_input(domain):
    """Build a (1, _NUM_CHANNELS, _GRID_SIZE, _GRID_SIZE) input for ``domain``.

    Any unrecognised domain falls back to pure Gaussian noise.
    """
    n = _GRID_SIZE
    if domain == "Fluid Dynamics":
        X, Y = _grid(-1, 1)
        blob = torch.exp(-(X**2 + Y**2) * 10)
        return blob.expand(_NUM_CHANNELS, n, n).unsqueeze(0)
    if domain == "Finance":
        # A left-to-right ramp broadcast over every channel, plus small noise.
        base = torch.linspace(0, 1, n).reshape(1, -1).repeat(n, 1)
        noise = torch.randn(_NUM_CHANNELS, n, n) * 0.05
        return (base + noise).unsqueeze(0)
    if domain == "Quantum":
        X, Y = _grid(0, 4 * torch.pi)
        sin_grid = torch.sin(X) * torch.sin(Y)
        return sin_grid.expand(_NUM_CHANNELS, n, n).unsqueeze(0)
    if domain == "Biology / Medicine":
        X, Y = _grid(-1, 1)
        base_blob = torch.exp(-(X**2 + Y**2) * 5)
        blob = torch.randn(_NUM_CHANNELS, n, n) * 0.2 + base_blob
        return blob.unsqueeze(0)
    return torch.randn(1, _NUM_CHANNELS, n, n)


def run_inference_by_domain(model, domain):
    """
    Runs the model on a synthetic input based on the chosen domain.

    Args:
        model (ScOT): The POSEIDON model.
        domain (str): Domain to simulate input for. One of:
            'Fluid Dynamics', 'Finance', 'Quantum', 'Biology / Medicine'.
            Any other value falls back to random Gaussian noise.

    Returns:
        np.ndarray: The predicted model output (size-1 axes squeezed out).
    """
    return _predict(model, _synthetic_input(domain))


def _to_model_input(frame):
    """Coerce a NumPy frame into a (_NUM_CHANNELS, _GRID_SIZE, _GRID_SIZE) tensor.

    Extra size-1 axes are squeezed away, missing leading axes are added, the
    last two axes are resized to the model grid, and the first axis is
    zero-padded or truncated to the model's channel count.

    NOTE(review): assumes the squeezed frame is at most 3-D and laid out
    channels-first (C, H, W); confirm this for every dataset in ``_DATASETS``.
    """
    if frame.ndim > 3:
        frame = np.squeeze(frame)
    while frame.ndim < 3:
        frame = np.expand_dims(frame, 0)

    tensor = torch.tensor(frame)
    if tuple(tensor.shape[-2:]) != (_GRID_SIZE, _GRID_SIZE):
        tensor = resize(tensor, size=[_GRID_SIZE, _GRID_SIZE])

    channels = tensor.shape[0]
    if channels < _NUM_CHANNELS:
        padding = torch.zeros((_NUM_CHANNELS - channels, _GRID_SIZE, _GRID_SIZE))
        tensor = torch.cat([tensor, padding], dim=0)
    else:
        tensor = tensor[:_NUM_CHANNELS]
    return tensor


def run_inference_on_dataset(model, dataset_name):
    """
    Downloads and runs inference on a real scientific dataset using POSEIDON.

    Only the first ``sample`` and first ``time`` entry are used; either index
    is skipped when the variable has no such dimension.

    Args:
        model (ScOT): The POSEIDON model.
        dataset_name (str): A key of ``_DATASETS``.

    Returns:
        tuple: (input_array, output_array) as numpy arrays.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    entry = _DATASETS.get(dataset_name)
    if entry is None:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    file_path = hf_hub_download(
        repo_id=entry["repo_id"],
        filename=entry["filename"],
        repo_type="dataset",
    )

    # Use the dataset as a context manager so the netCDF file handle is
    # released once the selected frame has been read into memory.
    with xr.open_dataset(file_path, engine="netcdf4") as ds:
        var = ds[entry["variable"]]
        print(f"Loaded shape: {var.shape}, dims: {var.dims}")
        indexers = {dim: 0 for dim in ("sample", "time") if dim in var.dims}
        frame = var.isel(indexers).values.astype(np.float32)

    tensor = _to_model_input(frame)
    return tensor.squeeze().numpy(), _predict(model, tensor.unsqueeze(0))


def plot_output(output_array, cmap="inferno", contrast=2.0):
    """
    Plots the output array from the model using a heatmap.

    The array is min-max normalised to [0, 1] and then raised to the power
    ``contrast``. A constant array is drawn as all zeros instead of NaN.

    Args:
        output_array (np.ndarray): 2-D output from the model.
        cmap (str): Colormap used for visualization.
        contrast (float): Exponent applied after normalisation.

    Returns:
        matplotlib.figure.Figure: The heatmap figure.
    """
    normalised = output_array - output_array.min()
    peak = normalised.max()
    # Guard against 0/0 when every value is identical.
    if peak > 0:
        normalised = normalised / peak
    normalised = normalised ** contrast

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        normalised,
        ax=ax,
        cmap=cmap,
        cbar=True,
        square=True,
        xticklabels=False,
        yticklabels=False,
        linewidths=0,
    )
    ax.set_title("POSEIDON Output")
    ax.axis("off")
    return fig


def plot_comparison(input_array, output_array, cmap="inferno"):
    """
    Plots a side-by-side comparison of the input and the model output.

    Args:
        input_array (np.ndarray): Input data. For a 3-D (C, H, W) array only
            the first channel is shown; a 2-D array is shown as is.
        output_array (np.ndarray): 2-D output predicted by the model.
        cmap (str): Colormap used for both plots.

    Returns:
        matplotlib.figure.Figure: Figure showing input vs output.
    """
    input_array = np.asarray(input_array)
    ground_truth = input_array[0] if input_array.ndim == 3 else input_array

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    panels = ((ground_truth, "Ground Truth"), (output_array, "POSEIDON Prediction"))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig