Upload 42 files
Browse files- .gitattributes +1 -0
- README.md +63 -12
- __init__.py +0 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/__init__.cpython-312.pyc +0 -0
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/app.cpython-312.pyc +0 -0
- __pycache__/poseidon_model.cpython-310.pyc +0 -0
- __pycache__/simulations.cpython-310.pyc +0 -0
- app.py +138 -0
- external/.DS_Store +0 -0
- external/poseidon/.gitignore +160 -0
- external/poseidon/README.md +151 -0
- external/poseidon/assets/fig1.png +3 -0
- external/poseidon/configs/README.md +3 -0
- external/poseidon/configs/run.yaml +26 -0
- external/poseidon/configs/sweep.yaml +57 -0
- external/poseidon/pyproject.toml +20 -0
- external/poseidon/scOT/__init__.py +0 -0
- external/poseidon/scOT/__pycache__/__init__.cpython-310.pyc +0 -0
- external/poseidon/scOT/__pycache__/model.cpython-310.pyc +0 -0
- external/poseidon/scOT/inference.py +950 -0
- external/poseidon/scOT/metrics.py +55 -0
- external/poseidon/scOT/model.py +1485 -0
- external/poseidon/scOT/problems/__init__.py +0 -0
- external/poseidon/scOT/problems/base.py +395 -0
- external/poseidon/scOT/problems/elliptic/__init__.py +0 -0
- external/poseidon/scOT/problems/elliptic/helmholtz.py +49 -0
- external/poseidon/scOT/problems/elliptic/poisson.py +50 -0
- external/poseidon/scOT/problems/fluids/__init__.py +0 -0
- external/poseidon/scOT/problems/fluids/compressible.py +308 -0
- external/poseidon/scOT/problems/fluids/incompressible.py +331 -0
- external/poseidon/scOT/problems/fluids/normalization_constants.py +9 -0
- external/poseidon/scOT/problems/reaction_diffusion/__init__.py +0 -0
- external/poseidon/scOT/problems/reaction_diffusion/allen_cahn.py +53 -0
- external/poseidon/scOT/problems/wave/__init__.py +0 -0
- external/poseidon/scOT/problems/wave/acoustic.py +125 -0
- external/poseidon/scOT/train.py +537 -0
- external/poseidon/scOT/trainer.py +762 -0
- external/poseidon/scOT/utils.py +97 -0
- poseidon_model.py +211 -0
- requirements.txt +10 -0
- simulations.py +84 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
external/poseidon/assets/fig1.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
title: Test Space
|
| 3 |
-
emoji: 😻
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.23.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
short_description: Test
|
| 12 |
-
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔱 POSEIDON Playground: Across Scientific Domains 🔱
|
| 2 |
+
|
| 3 |
+
**An interactive Gradio demo** exploring how the POSEIDON foundation model for solving Partial Differential Equations (PDEs) could be applied across physics, finance, quantum mechanics, and biology.
|
| 4 |
+
|
| 5 |
+
> Built with love 💖 🔱 for the Hugging Face Community ML Research Engineer take-home assignment.
|
| 6 |
+
> Inspired by [POSEIDON: A Foundation Model for Solving PDEs](https://arxiv.org/abs/2405.19101) by CamLab ETH Zürich.
|
| 7 |
+
> Code from the original repo: [github.com/camlab-ethz/poseidon](https://github.com/camlab-ethz/poseidon)
|
| 8 |
+
|
| 9 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
## Goal
|
| 12 |
+
|
| 13 |
+
This app highlights the **multidisciplinary potential** of pre-trained PDE models like POSEIDON using:
|
| 14 |
+
|
| 15 |
+
- Intuitive **interactive visualizations**
|
| 16 |
+
- Simple simulations from **four real-world domains**
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## 🪩 What You Can Do
|
| 21 |
+
|
| 22 |
+
| Feature | Description |
|
| 23 |
+
|-----------|-------------|
|
| 24 |
+
| ✔️ Pick a scientific domain | Finance, Quantum, Fluids, Biology |
|
| 25 |
+
| ✔️ Run a mini simulation | See how PDEs behave in each field |
|
| 26 |
+
| ✔️ Try POSEIDON inference | Generate predictions from synthetic inputs |
|
| 27 |
+
| ✔️ Use real PDE datasets | Compare POSEIDON output vs. ground truth |
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## 🚀 Running Locally
|
| 33 |
+
|
| 34 |
+
1. **Clone this repo**
|
| 35 |
+
```bash
|
| 36 |
+
git clone https://github.com/YOUR_USERNAME/poseidon_demo.git
|
| 37 |
+
cd poseidon_demo
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
2. **(Option A) Using Conda**
|
| 41 |
+
```bash
|
| 42 |
+
conda env create -f environment.yml
|
| 43 |
+
conda activate poseidon-demo
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
3. **(Option B) Using virtualenv**
|
| 47 |
+
```bash
|
| 48 |
+
python -m venv .venv
|
| 49 |
+
source .venv/bin/activate
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
4. **(Run the demo**
|
| 55 |
+
```bash
|
| 56 |
+
python -m poseidon_demo.app
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
## 🪩 Big Thanks
|
| 61 |
+
|
| 62 |
+
- Hugging Face for the opportunity and open tools
|
| 63 |
+
- ETH Zürich’s CamLab for releasing the POSEIDON repo
|
| 64 |
+
- 💖 You, the curious science hacker, for playing with this demo!
|
| 65 |
+
|
__init__.py
ADDED
|
File without changes
|
__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (4.28 kB). View file
|
|
|
__pycache__/poseidon_model.cpython-310.pyc
ADDED
|
Binary file (5.53 kB). View file
|
|
|
__pycache__/simulations.cpython-310.pyc
ADDED
|
Binary file (2.68 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from .simulations import finance_demo, quantum_demo, fluid_demo, bio_demo
|
| 3 |
+
from .poseidon_model import load_model, run_inference_by_domain, run_inference_on_dataset, plot_output, plot_comparison
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def run_poseidon_demo(domain, contrast, cmap):
|
| 7 |
+
"""
|
| 8 |
+
Loads the POSEIDON model and runs it on synthetic input data
|
| 9 |
+
based on the selected scientific domain (e.g., Finance, Quantum).
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
domain (str): Selected scientific field.
|
| 13 |
+
contrast (float): Contrast setting for visualization.
|
| 14 |
+
cmap (str): Colormap choice for heatmap.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Matplotlib figure showing the output.
|
| 18 |
+
"""
|
| 19 |
+
model = load_model()
|
| 20 |
+
output = run_inference_by_domain(model, domain)
|
| 21 |
+
return plot_output(output, contrast=contrast, cmap=cmap)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def render_demo(domain):
|
| 26 |
+
"""
|
| 27 |
+
Returns a mini-simulation plot and a descriptive explanation
|
| 28 |
+
for the selected domain.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
domain (str): One of Finance, Quantum, Fluid Dynamics, Biology.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Tuple of (plot, explanatory markdown string).
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
if domain == "Finance":
|
| 38 |
+
return finance_demo(), (
|
| 39 |
+
"📍 **Finance:** PDEs like Black-Scholes are used to model option pricing. "
|
| 40 |
+
"Imagine fine-tuning Poseidon to forecast derivatives across market regimes!"
|
| 41 |
+
)
|
| 42 |
+
elif domain == "Quantum":
|
| 43 |
+
return quantum_demo(), (
|
| 44 |
+
"📍 **Quantum Mechanics:** Schrödinger's equation is a core PDE in quantum physics. "
|
| 45 |
+
"Could Poseidon learn to generalize across quantum systems?"
|
| 46 |
+
)
|
| 47 |
+
elif domain == "Fluid Dynamics":
|
| 48 |
+
return fluid_demo(), (
|
| 49 |
+
"📍 **Fluid Dynamics:** Poseidon is pretrained here! This sim shows 1D flow, "
|
| 50 |
+
"but Poseidon can do much more."
|
| 51 |
+
)
|
| 52 |
+
elif domain == "Biology / Medicine":
|
| 53 |
+
return bio_demo(), (
|
| 54 |
+
"📍 **Biology:** Reaction-diffusion equations appear in tissue growth and morphogenesis. "
|
| 55 |
+
"Poseidon could help model organ behavior!"
|
| 56 |
+
)
|
| 57 |
+
else:
|
| 58 |
+
return None, "Pick a domain to explore how Poseidon might apply!"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def run_poseidon_real_dataset(dataset_name):
|
| 63 |
+
"""
|
| 64 |
+
Loads Poseidon and runs inference on a real scientific dataset from the Hub.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
dataset_name (str): Dataset ID from dropdown.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Matplotlib figure with side-by-side comparison of input vs output.
|
| 71 |
+
"""
|
| 72 |
+
model = load_model()
|
| 73 |
+
input_array, output_array = run_inference_on_dataset(model, dataset_name)
|
| 74 |
+
return plot_comparison(input_array, output_array)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# this part defines the app
|
| 78 |
+
with gr.Blocks() as demo:
|
| 79 |
+
gr.Markdown("# 🔱 POSEIDON Application Across Scientific Domains 🔱")
|
| 80 |
+
|
| 81 |
+
gr.Markdown("### **Welcome to the POSEIDON Playground!**")
|
| 82 |
+
gr.Markdown("Ever dreamed of solving physics equations with a single click? You’re in the right place.")
|
| 83 |
+
gr.Markdown("POSEIDON is a foundation model that learned to solve partial differential equations (PDEs) — the magical "
|
| 84 |
+
"math behind fluid flows, quantum mechanics, financial markets, and even biology!")
|
| 85 |
+
|
| 86 |
+
gr.Markdown("## ☑️ 1. Pick a scientific domain to see a simple PDE simulation and Explanation.")
|
| 87 |
+
domain_dropdown = gr.Dropdown(
|
| 88 |
+
["Finance", "Quantum", "Fluid Dynamics", "Biology / Medicine"],
|
| 89 |
+
label="Choose The Field",
|
| 90 |
+
value="Finance"
|
| 91 |
+
)
|
| 92 |
+
sim_output = gr.Plot()
|
| 93 |
+
sim_text = gr.Markdown()
|
| 94 |
+
|
| 95 |
+
domain_dropdown.change(fn=render_demo, inputs=domain_dropdown, outputs=[sim_output, sim_text])
|
| 96 |
+
|
| 97 |
+
gr.Markdown("## 🚀 Run a test output from the POSEIDON model based on the chosen domain")
|
| 98 |
+
|
| 99 |
+
with gr.Row():
|
| 100 |
+
gr.Markdown("Play with contrast and choose the colormap you prefer.")
|
| 101 |
+
contrast_slider = gr.Slider(0.5, 5.0, value=2.0, step=0.1, label=" Contrast")
|
| 102 |
+
cmap_dropdown = gr.Dropdown(
|
| 103 |
+
["inferno", "viridis", "plasma"],
|
| 104 |
+
label="Colormap",
|
| 105 |
+
value="inferno"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
with gr.Row():
|
| 109 |
+
poseidon_button = gr.Button("POSEIDON Test Output")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
poseidon_plot = gr.Plot()
|
| 113 |
+
poseidon_button.click(
|
| 114 |
+
fn=run_poseidon_demo,
|
| 115 |
+
inputs=[domain_dropdown, contrast_slider, cmap_dropdown],
|
| 116 |
+
outputs=poseidon_plot
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
gr.Markdown("# ")
|
| 120 |
+
|
| 121 |
+
gr.Markdown("## ☑️ 2. Try POSEIDON on Real Scientific Datasets")
|
| 122 |
+
dataset_dropdown = gr.Dropdown(
|
| 123 |
+
["fluids.incompressible.Sines", "fluids.compressible.Riemann", "reaction_diffusion.AllenCahn"],
|
| 124 |
+
label="Choose a Real Dataset"
|
| 125 |
+
)
|
| 126 |
+
dataset_button = gr.Button("POSEIDON on Dataset")
|
| 127 |
+
dataset_plot = gr.Plot()
|
| 128 |
+
|
| 129 |
+
dataset_button.click(
|
| 130 |
+
fn=run_poseidon_real_dataset,
|
| 131 |
+
inputs=[dataset_dropdown],
|
| 132 |
+
outputs=dataset_plot
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
demo.launch()
|
external/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
external/poseidon/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
external/poseidon/README.md
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Poseidon: Efficient Foundation Models for PDEs
|
| 2 |
+
|
| 3 |
+
This is the source code for the paper [*Poseidon: Efficient Foundation Models for PDEs*](https://arxiv.org/abs/2405.19101). It also acts as a package if you want to use the models in your code.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Find pretrained models and pretraining dataset in our collection on the [🤗 Hub – Pretrained Models and Pretraining Datasets](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a). All datasets corresponding to downstream tasks can be downloaded from the respective collection on the [🤗 Hub – Downstream Tasks](https://huggingface.co/collections/camlab-ethz/poseidon-downstream-tasks-664fa237cd6b0c097971ef14) as well. To use them, follow the respective sections below.
|
| 8 |
+
|
| 9 |
+
## Usage
|
| 10 |
+
|
| 11 |
+
### Installation & Requirements
|
| 12 |
+
|
| 13 |
+
To get all requirements and install the package, run (inside this folder), after getting this repository:
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
pip install -e .
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
We recommend running the above command in a [virtual environment](https://docs.python.org/3/library/venv.html).
|
| 20 |
+
|
| 21 |
+
After installation, you can import the models and use the training and inference scripts from everywhere on your system.
|
| 22 |
+
|
| 23 |
+
### Using the models in your own code
|
| 24 |
+
|
| 25 |
+
To use the (pretrained) models in your own code, you can use the following code snippet (after installing):
|
| 26 |
+
|
| 27 |
+
```python
|
| 28 |
+
from scOT.model import ScOT
|
| 29 |
+
|
| 30 |
+
model = ScOT.from_pretrained("camlab-ethz/Poseidon-<MODEL_SIZE>")
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
This will load the pretrained model from the 🤗 Hub. `<MODEL_SIZE>` has to be replaced by `T`, `B`, or `L`, for the respective pretrained model. You can also load a model from a local path by providing the path to the `from_pretrained` method.
|
| 34 |
+
|
| 35 |
+
To finetune and replace embeddings and recovery parameters, load the model as follows:
|
| 36 |
+
|
| 37 |
+
```python
|
| 38 |
+
from scOT.model import ScOT
|
| 39 |
+
|
| 40 |
+
model = ScOT.from_pretrained("camlab-ethz/Poseidon-<MODEL_SIZE>", config=model_config, ignore_mismatched_sizes=True)
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
Here, `model_config` is a `ScOTConfig` with the correct input/output dimensions. We also refer to [the training/finetuning script](scOT/train.py), see below on usage, which might be easier.
|
| 44 |
+
|
| 45 |
+
### Training & Finetuning
|
| 46 |
+
|
| 47 |
+
The easiest way to finetune **Poseidon** on your own dataset is by plugging in your own dataset and running the provided training script as follows:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
accelerate launch scOT/train.py \
|
| 51 |
+
--config <WANDB_CONFIG_FILE> \
|
| 52 |
+
--wandb_run_name <WANDB_RUN_NAME> \
|
| 53 |
+
--wandb_project_name <WANDB_PROJECT_NAME> \
|
| 54 |
+
--checkpoint_path <CHECKPOINT_PATH> \
|
| 55 |
+
--data_path <DATA_PATH> \
|
| 56 |
+
--finetune_from <PRETRAINED_MODEL> \
|
| 57 |
+
--replace_embedding_recovery <SET ONLY IF EMBED/RECOVERY NEEDS TO BE REPLACED>
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
For more arguments and options, see the help message of the script:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
accelerate launch scOT/train.py --help
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
Since the code is built on top of [🤗 Accelerate](https://huggingface.co/docs/accelerate/en/index), you should run `accelerate config` first.
|
| 67 |
+
|
| 68 |
+
We also make heavy use of [Weights and Biases](wandb.com) to log and organise all our runs. The code might run without it (by setting `WANDB_MODE=disabled`), but we don't give any guarantees as this probably breaks the folder structure.
|
| 69 |
+
|
| 70 |
+
Most of the actual training configuration is set in a YAML config file (see for all arguments to set for a single W&B [run](configs/run.yaml) or a W&B [sweep](configs/sweep.yaml) (multiple runs, see the [W&B documentation](https://docs.wandb.ai/guides/sweeps) on how to start a sweep)). The config file is passed to the training script via the `--config` argument.
|
| 71 |
+
|
| 72 |
+
We do our pretrainings with the same script.
|
| 73 |
+
|
| 74 |
+
### Inference/Testing
|
| 75 |
+
|
| 76 |
+
To evaluate a model on a dataset, you can use the inference script, for all possible arguments see the help message:
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
python -m scOT.inference --help
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Datasets
|
| 83 |
+
|
| 84 |
+
We provide all datasets used in the paper on the 🤗 Hub. You can download them from the respective collections:
|
| 85 |
+
- [🤗 Hub – Pretraining Datasets](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a)
|
| 86 |
+
- [🤗 Hub – Downstream Tasks](https://huggingface.co/collections/camlab-ethz/poseidon-downstream-tasks-664fa237cd6b0c097971ef14)
|
| 87 |
+
|
| 88 |
+
### Naming convention in the code
|
| 89 |
+
|
| 90 |
+
In the code, we refer to the datasets by a different identifier than on the 🤗 Hub, see the following table for a mapping:
|
| 91 |
+
|
| 92 |
+
| Code Identifier | 🤗 Hub/Paper Identifier |
|
| 93 |
+
| ----------------|------------------------- |
|
| 94 |
+
|fluids.incompressible.Sines| NS-Sines|
|
| 95 |
+
|fluids.incompressible.Gaussians| NS-Gauss|
|
| 96 |
+
|fluids.compressible.Riemann|CE-RP|
|
| 97 |
+
|fluids.compressible.RiemannCurved|CE-CRP|
|
| 98 |
+
|fluids.compressible.KelvinHelmholtz|CE-KH|
|
| 99 |
+
|fluids.compressible.Gaussians|CE-Gauss|
|
| 100 |
+
|fluids.incompressible.PiecewiseConstants|NS-PwC|
|
| 101 |
+
|fluids.incompressible.VortexSheet|NS-SVS|
|
| 102 |
+
|fluids.incompressible.BrownianBridge|NS-BB|
|
| 103 |
+
|fluids.incompressible.ShearLayer|NS-SL|
|
| 104 |
+
|fluids.incomressible.PiecewiseConstants.tracer|NS-Tracer-PwC|
|
| 105 |
+
|fluids.incompressible.forcing.KolmogorovFlow|FNS-KF|
|
| 106 |
+
|fluids.compressible.RiemannKelvinHelmholtz|CE-RPUI|
|
| 107 |
+
|fluids.compressible.RichtmyerMeshkov|CE-RM|
|
| 108 |
+
|fluids.compressible.gravity.RayleighTaylor|GCE-RT|
|
| 109 |
+
|wave.Layer|Wave-Layer|
|
| 110 |
+
|wave.Gaussians|Wave-Gauss|
|
| 111 |
+
|reaction_diffusion.AllenCahn|ACE|
|
| 112 |
+
|fluids.compressible.steady.Airfoil(.time)|SE-AF|
|
| 113 |
+
|elliptic.poisson.Gaussians(.time)|Poisson-Gauss|
|
| 114 |
+
|elliptic.Helmholtz(.time)|Helmholtz|
|
| 115 |
+
|
| 116 |
+
Adding the suffix `.time` to the dataset identifier will load the dataset as time-dependent dataset, i.e. as a long-time limit – use that suffix for finetuning on time-independent datasets.
|
| 117 |
+
|
| 118 |
+
### Download & Assembly
|
| 119 |
+
|
| 120 |
+
Download all the datasets used in our paper from the 🤗 Hub. You may want to use the CLI provided by the [Hub Python Library](https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-download):
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
huggingface-cli download camlab-ethz/<DATASET IDENTIFIER FROM PAPER> --repo-type dataset --local-dir <LOCAL DIRECTORY>
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
This will download a specific dataset to the specified `LOCAL DIRECTORY`. After download, you need to assemble the datasets to the format expected by the code; for that, we refer to the README in the respective dataset repository. After assembly, remove the chunked dataset files, as they are not needed for training, and place the assembled dataset at the path you specify as `--data_path` for the training/inference script. You may also specify the 🤗 Hub cache location by specifying the environment variable `HF_HOME` as this is where the download will be performed to.
|
| 127 |
+
|
| 128 |
+
### Adding your own dataset
|
| 129 |
+
|
| 130 |
+
We encourage adding your own datasets. For that, you can subclass from [BaseDataset and BaseTimeDataset](scOT/problems/base.py) and add it to the `get_dataset` selector method. You can then use the dataset in the training script by specifying the dataset identifier in the config file.
|
| 131 |
+
|
| 132 |
+
For subclassing, we refer to the docstrings in the base classes and the existing datasets in the [problems](scOT/problems) folder.
|
| 133 |
+
|
| 134 |
+
## Pretrained models
|
| 135 |
+
|
| 136 |
+
Pretrained models are available on the 🤗 Hub, see the [Poseidon collection](https://huggingface.co/collections/camlab-ethz/poseidon-664fa125729c53d8607e209a) for all models. You can download them via the 🤗 Hub API or by using the `from_pretrained` method, see above.
|
| 137 |
+
|
| 138 |
+
## Citation
|
| 139 |
+
|
| 140 |
+
If you use our models, code, or datasets, please consider citing our paper:
|
| 141 |
+
|
| 142 |
+
```bibtex
|
| 143 |
+
@misc{herde2024poseidon,
|
| 144 |
+
title={Poseidon: Efficient Foundation Models for PDEs},
|
| 145 |
+
author={Maximilian Herde and Bogdan Raonić and Tobias Rohner and Roger Käppeli and Roberto Molinaro and Emmanuel de Bézenac and Siddhartha Mishra},
|
| 146 |
+
year={2024},
|
| 147 |
+
eprint={2405.19101},
|
| 148 |
+
archivePrefix={arXiv},
|
| 149 |
+
primaryClass={cs.LG}
|
| 150 |
+
}
|
| 151 |
+
```
|
external/poseidon/assets/fig1.png
ADDED
|
Git LFS Details
|
external/poseidon/configs/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration Files
|
| 2 |
+
|
| 3 |
+
We give two sample configuration files. One for a single finetuning run and one for a finetuning sweep. Both finetune the Poseidon-B model on Wave-Layer.
|
external/poseidon/configs/run.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
value: "wave.Layer"
|
| 3 |
+
num_trajectories:
|
| 4 |
+
value: 128
|
| 5 |
+
model_name:
|
| 6 |
+
value: "B"
|
| 7 |
+
lr:
|
| 8 |
+
value: 0.00005
|
| 9 |
+
lr_embedding_recovery:
|
| 10 |
+
value: 0.0005
|
| 11 |
+
lr_time_embedding:
|
| 12 |
+
value: 0.0005
|
| 13 |
+
weight_decay:
|
| 14 |
+
value: 0.000001
|
| 15 |
+
lr_scheduler:
|
| 16 |
+
value: "cosine"
|
| 17 |
+
warmup_ratio:
|
| 18 |
+
value: 0.0
|
| 19 |
+
early_stopping_patience:
|
| 20 |
+
value: 200
|
| 21 |
+
num_epochs:
|
| 22 |
+
value: 200
|
| 23 |
+
batch_size:
|
| 24 |
+
value: 40
|
| 25 |
+
max_grad_norm:
|
| 26 |
+
value: 5.0
|
external/poseidon/configs/sweep.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project: <WANDB_PROJECT>
|
| 2 |
+
entity: <WANDB_ENTITY>
|
| 3 |
+
program: scOT/train.py
|
| 4 |
+
method: grid
|
| 5 |
+
metric:
|
| 6 |
+
name: "eval/loss"
|
| 7 |
+
goal: minimize
|
| 8 |
+
command:
|
| 9 |
+
- "HDF5_USE_FILE_LOCKING=FALSE"
|
| 10 |
+
- "accelerate"
|
| 11 |
+
- "launch"
|
| 12 |
+
- ${program}
|
| 13 |
+
- "--disable_tqdm"
|
| 14 |
+
- "--json-config"
|
| 15 |
+
- "--finetune_from"
|
| 16 |
+
- "camlab-ethz/Poseidon-B"
|
| 17 |
+
- "--replace_embedding_recovery"
|
| 18 |
+
- "--config"
|
| 19 |
+
- ${args_json}
|
| 20 |
+
parameters:
|
| 21 |
+
dataset:
|
| 22 |
+
value: "wave.Layer"
|
| 23 |
+
num_trajectories:
|
| 24 |
+
values:
|
| 25 |
+
- 1
|
| 26 |
+
- 2
|
| 27 |
+
- 4
|
| 28 |
+
- 8
|
| 29 |
+
- 16
|
| 30 |
+
- 32
|
| 31 |
+
- 64
|
| 32 |
+
- 128
|
| 33 |
+
- 256
|
| 34 |
+
- 512
|
| 35 |
+
- 1024
|
| 36 |
+
model_name:
|
| 37 |
+
value: "B"
|
| 38 |
+
lr:
|
| 39 |
+
value: 0.00005
|
| 40 |
+
lr_embedding_recovery:
|
| 41 |
+
value: 0.0005
|
| 42 |
+
lr_time_embedding:
|
| 43 |
+
value: 0.0005
|
| 44 |
+
weight_decay:
|
| 45 |
+
value: 0.000001
|
| 46 |
+
lr_scheduler:
|
| 47 |
+
value: "cosine"
|
| 48 |
+
warmup_ratio:
|
| 49 |
+
value: 0.0
|
| 50 |
+
early_stopping_patience:
|
| 51 |
+
value: 200
|
| 52 |
+
num_epochs:
|
| 53 |
+
value: 200
|
| 54 |
+
batch_size:
|
| 55 |
+
value: 40
|
| 56 |
+
max_grad_norm:
|
| 57 |
+
value: 5.0
|
external/poseidon/pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "scOT"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "Foundation models for PDEs based on a scalable Operator Transformer"
|
| 5 |
+
dependencies = [
|
| 6 |
+
"torch == 2.0.1",
|
| 7 |
+
"torchvision == 0.15.2",
|
| 8 |
+
"numpy",
|
| 9 |
+
"transformers == 4.29.2",
|
| 10 |
+
"matplotlib",
|
| 11 |
+
"accelerate == 0.31.0",
|
| 12 |
+
"wandb == 0.14.2",
|
| 13 |
+
"h5py",
|
| 14 |
+
"pandas",
|
| 15 |
+
"pyyaml",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[build-system]
|
| 19 |
+
build-backend = "flit_core.buildapi"
|
| 20 |
+
requires = ["flit_core >=3.2,<4"]
|
external/poseidon/scOT/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
external/poseidon/scOT/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (33.9 kB). View file
|
|
|
external/poseidon/scOT/inference.py
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Use this script for inference/testing a scOT model.
|
| 3 |
+
The script can be used in different modes:
|
| 4 |
+
- save_samples: Save samples from a model.
|
| 5 |
+
- save_samples_sweep: Save samples from a sweep.
|
| 6 |
+
- eval: Evaluate a model on the test set.
|
| 7 |
+
- eval_sweep: Evaluate a sweep on the test set.
|
| 8 |
+
- eval_accumulation_error: Evaluate the accumulation error of a model.
|
| 9 |
+
- eval_resolutions: Evaluate a model on different resolutions.
|
| 10 |
+
|
| 11 |
+
See the --help page for more information.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import torch
|
| 16 |
+
import numpy as np
|
| 17 |
+
import random
|
| 18 |
+
import psutil
|
| 19 |
+
import os
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import wandb
|
| 22 |
+
from transformers.trainer_utils import EvalPrediction
|
| 23 |
+
from scOT.model import ScOT
|
| 24 |
+
from scOT.trainer import TrainingArguments, Trainer
|
| 25 |
+
from scOT.problems.base import get_dataset, BaseTimeDataset
|
| 26 |
+
from scOT.metrics import relative_lp_error, lp_error
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
SEED = 0
|
| 30 |
+
torch.manual_seed(SEED)
|
| 31 |
+
np.random.seed(SEED)
|
| 32 |
+
random.seed(SEED)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_trainer(
|
| 36 |
+
model_path,
|
| 37 |
+
batch_size,
|
| 38 |
+
dataset,
|
| 39 |
+
full_data=False,
|
| 40 |
+
output_all_steps=False,
|
| 41 |
+
workers=-1,
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Get a trainer for the model (actually just using the interface for inference).
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
model_path: str
|
| 48 |
+
Path to the model.
|
| 49 |
+
batch_size: int
|
| 50 |
+
Batch size for evaluation.
|
| 51 |
+
dataset: BaseTimeDataset
|
| 52 |
+
Test set.
|
| 53 |
+
full_data: bool
|
| 54 |
+
Whether to save the full data distribution.
|
| 55 |
+
output_all_steps: bool
|
| 56 |
+
Whether to output all preliminary steps in autoregressive rollout.
|
| 57 |
+
workers: int
|
| 58 |
+
Number of workers for evaluation. If -1 will use all available cores.
|
| 59 |
+
"""
|
| 60 |
+
num_cpu_cores = len(psutil.Process().cpu_affinity())
|
| 61 |
+
if workers == -1:
|
| 62 |
+
workers = num_cpu_cores
|
| 63 |
+
if workers > num_cpu_cores:
|
| 64 |
+
workers = num_cpu_cores
|
| 65 |
+
assert workers > 0
|
| 66 |
+
|
| 67 |
+
model = ScOT.from_pretrained(model_path)
|
| 68 |
+
args = TrainingArguments(
|
| 69 |
+
output_dir=".",
|
| 70 |
+
per_device_eval_batch_size=batch_size,
|
| 71 |
+
eval_accumulation_steps=16,
|
| 72 |
+
dataloader_num_workers=workers,
|
| 73 |
+
)
|
| 74 |
+
time_involved = isinstance(dataset, BaseTimeDataset)
|
| 75 |
+
|
| 76 |
+
def compute_metrics(eval_preds):
|
| 77 |
+
if time_involved and output_all_steps:
|
| 78 |
+
return {}
|
| 79 |
+
channel_list = dataset.channel_slice_list
|
| 80 |
+
|
| 81 |
+
def get_relative_statistics(errors):
|
| 82 |
+
median_error = np.median(errors, axis=0)
|
| 83 |
+
mean_error = np.mean(errors, axis=0)
|
| 84 |
+
std_error = np.std(errors, axis=0)
|
| 85 |
+
min_error = np.min(errors, axis=0)
|
| 86 |
+
max_error = np.max(errors, axis=0)
|
| 87 |
+
return {
|
| 88 |
+
"median_relative_l1_error": median_error,
|
| 89 |
+
"mean_relative_l1_error": mean_error,
|
| 90 |
+
"std_relative_l1_error": std_error,
|
| 91 |
+
"min_relative_l1_error": min_error,
|
| 92 |
+
"max_relative_l1_error": max_error,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def get_statistics(errors):
|
| 96 |
+
median_error = np.median(errors, axis=0)
|
| 97 |
+
mean_error = np.mean(errors, axis=0)
|
| 98 |
+
std_error = np.std(errors, axis=0)
|
| 99 |
+
min_error = np.min(errors, axis=0)
|
| 100 |
+
max_error = np.max(errors, axis=0)
|
| 101 |
+
return {
|
| 102 |
+
"median_l1_error": median_error,
|
| 103 |
+
"mean_l1_error": mean_error,
|
| 104 |
+
"std_l1_error": std_error,
|
| 105 |
+
"min_l1_error": min_error,
|
| 106 |
+
"max_l1_error": max_error,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
relative_errors = [
|
| 110 |
+
relative_lp_error(
|
| 111 |
+
eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]],
|
| 112 |
+
eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]],
|
| 113 |
+
p=1,
|
| 114 |
+
return_percent=True,
|
| 115 |
+
)
|
| 116 |
+
for i in range(len(channel_list) - 1)
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
errors = [
|
| 120 |
+
lp_error(
|
| 121 |
+
eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]],
|
| 122 |
+
eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]],
|
| 123 |
+
p=1,
|
| 124 |
+
)
|
| 125 |
+
for i in range(len(channel_list) - 1)
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
relative_error_statistics = [
|
| 129 |
+
get_relative_statistics(relative_errors[i])
|
| 130 |
+
for i in range(len(channel_list) - 1)
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
error_statistics = [
|
| 134 |
+
get_statistics(errors[i]) for i in range(len(channel_list) - 1)
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
if dataset.output_dim == 1:
|
| 138 |
+
relative_error_statistics = relative_error_statistics[0]
|
| 139 |
+
error_statistics = error_statistics[0]
|
| 140 |
+
if full_data:
|
| 141 |
+
relative_error_statistics["relative_full_data"] = relative_errors[
|
| 142 |
+
0
|
| 143 |
+
].tolist()
|
| 144 |
+
error_statistics["full_data"] = errors[0].tolist()
|
| 145 |
+
return {**relative_error_statistics, **error_statistics}
|
| 146 |
+
else:
|
| 147 |
+
mean_over_relative_means = np.mean(
|
| 148 |
+
np.array(
|
| 149 |
+
[
|
| 150 |
+
stats["mean_relative_l1_error"]
|
| 151 |
+
for stats in relative_error_statistics
|
| 152 |
+
]
|
| 153 |
+
),
|
| 154 |
+
axis=0,
|
| 155 |
+
)
|
| 156 |
+
mean_over_relative_medians = np.mean(
|
| 157 |
+
np.array(
|
| 158 |
+
[
|
| 159 |
+
stats["median_relative_l1_error"]
|
| 160 |
+
for stats in relative_error_statistics
|
| 161 |
+
]
|
| 162 |
+
),
|
| 163 |
+
axis=0,
|
| 164 |
+
)
|
| 165 |
+
mean_over_means = np.mean(
|
| 166 |
+
np.array([stats["mean_l1_error"] for stats in error_statistics]), axis=0
|
| 167 |
+
)
|
| 168 |
+
mean_over_medians = np.mean(
|
| 169 |
+
np.array([stats["median_l1_error"] for stats in error_statistics]),
|
| 170 |
+
axis=0,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
error_statistics_ = {
|
| 174 |
+
"mean_relative_l1_error": mean_over_relative_means,
|
| 175 |
+
"mean_over_median_relative_l1_error": mean_over_relative_medians,
|
| 176 |
+
"mean_l1_error": mean_over_means,
|
| 177 |
+
"mean_over_median_l1_error": mean_over_medians,
|
| 178 |
+
}
|
| 179 |
+
#!! The above is different from train and finetune (here mean_relative_l1_error is mean over medians instead of mean over means)
|
| 180 |
+
for i, stats in enumerate(relative_error_statistics):
|
| 181 |
+
for key, value in stats.items():
|
| 182 |
+
error_statistics_[
|
| 183 |
+
dataset.printable_channel_description[i] + "/" + key
|
| 184 |
+
] = value
|
| 185 |
+
if full_data:
|
| 186 |
+
error_statistics_[
|
| 187 |
+
dataset.printable_channel_description[i]
|
| 188 |
+
+ "/"
|
| 189 |
+
+ "relative_full_data"
|
| 190 |
+
] = relative_errors[i].tolist()
|
| 191 |
+
for i, stats in enumerate(error_statistics):
|
| 192 |
+
for key, value in stats.items():
|
| 193 |
+
error_statistics_[
|
| 194 |
+
dataset.printable_channel_description[i] + "/" + key
|
| 195 |
+
] = value
|
| 196 |
+
if full_data:
|
| 197 |
+
error_statistics_[
|
| 198 |
+
dataset.printable_channel_description[i] + "/" + "full_data"
|
| 199 |
+
] = errors[i].tolist()
|
| 200 |
+
return error_statistics_
|
| 201 |
+
|
| 202 |
+
trainer = Trainer(
|
| 203 |
+
model=model,
|
| 204 |
+
args=args,
|
| 205 |
+
compute_metrics=compute_metrics,
|
| 206 |
+
)
|
| 207 |
+
return trainer
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def rollout(trainer, dataset, ar_steps=1, output_all_steps=False):
|
| 211 |
+
"""
|
| 212 |
+
Do a rollout of the model.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
trainer: Trainer
|
| 216 |
+
Trainer for the model.
|
| 217 |
+
dataset: BaseTimeDataset
|
| 218 |
+
Test set.
|
| 219 |
+
ar_steps: int or list
|
| 220 |
+
Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] is interpreted as taking a step of size j_i.
|
| 221 |
+
output_all_steps: bool
|
| 222 |
+
Whether to output all preliminary steps in autoregressive rollout.
|
| 223 |
+
"""
|
| 224 |
+
time_involved = isinstance(dataset, BaseTimeDataset)
|
| 225 |
+
if time_involved and ar_steps != 1:
|
| 226 |
+
trainer.set_ar_steps(ar_steps, output_all_steps=output_all_steps)
|
| 227 |
+
else:
|
| 228 |
+
trainer.set_ar_steps(ar_steps=1, output_all_steps=False)
|
| 229 |
+
|
| 230 |
+
prediction = trainer.predict(dataset, metric_key_prefix="")
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
return prediction.predictions, prediction.label_ids, prediction.metrics
|
| 234 |
+
except:
|
| 235 |
+
return prediction.predictions
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def get_test_set(
|
| 239 |
+
dataset, data_path, initial_time=None, final_time=None, dataset_kwargs={}
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
Get a test set (input at initial_time, output at final_time).
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
dataset: str
|
| 246 |
+
Dataset name.
|
| 247 |
+
data_path: str
|
| 248 |
+
Path to data.
|
| 249 |
+
initial_time: int
|
| 250 |
+
Initial time step to start from.
|
| 251 |
+
final_time: int
|
| 252 |
+
Final time step to end at.
|
| 253 |
+
dataset_kwargs: dict
|
| 254 |
+
Additional arguments for dataset as in scOT.problems.base.get_dataset.
|
| 255 |
+
"""
|
| 256 |
+
if initial_time is not None and final_time is not None:
|
| 257 |
+
dataset_kwargs = {
|
| 258 |
+
**dataset_kwargs,
|
| 259 |
+
"fix_input_to_time_step": initial_time,
|
| 260 |
+
"time_step_size": final_time - initial_time,
|
| 261 |
+
"max_num_time_steps": 1,
|
| 262 |
+
}
|
| 263 |
+
dataset = get_dataset(
|
| 264 |
+
dataset=dataset,
|
| 265 |
+
which="test",
|
| 266 |
+
num_trajectories=1,
|
| 267 |
+
data_path=data_path,
|
| 268 |
+
move_to_local_scratch=None,
|
| 269 |
+
**dataset_kwargs,
|
| 270 |
+
)
|
| 271 |
+
return dataset
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_first_n_inputs(dataset, n):
|
| 275 |
+
"""
|
| 276 |
+
Helper to get the first n inputs of a dataset.
|
| 277 |
+
"""
|
| 278 |
+
inputs = []
|
| 279 |
+
for i in range(n):
|
| 280 |
+
inputs.append(dataset[i]["pixel_values"])
|
| 281 |
+
return torch.stack(inputs)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def get_trajectories(
|
| 285 |
+
dataset, data_path, ar_steps, initial_time, final_time, dataset_kwargs
|
| 286 |
+
):
|
| 287 |
+
"""
|
| 288 |
+
Get full trajectories in a dataset. Helper for accumulation error evaluation.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
dataset: str
|
| 292 |
+
Dataset name.
|
| 293 |
+
data_path: str
|
| 294 |
+
Path to data.
|
| 295 |
+
ar_steps: int or list
|
| 296 |
+
Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] is interpreted as taking a step of size j_i.
|
| 297 |
+
initial_time: int
|
| 298 |
+
Initial time step to start from.
|
| 299 |
+
final_time: int
|
| 300 |
+
Final time step to end at.
|
| 301 |
+
dataset_kwargs: dict
|
| 302 |
+
Additional arguments for dataset as in scOT.problems.base.get_dataset.
|
| 303 |
+
"""
|
| 304 |
+
trajectories = []
|
| 305 |
+
if isinstance(ar_steps, int):
|
| 306 |
+
delta = (final_time - initial_time) // ar_steps
|
| 307 |
+
for i in range(ar_steps):
|
| 308 |
+
dataset_ = get_test_set(
|
| 309 |
+
dataset,
|
| 310 |
+
data_path,
|
| 311 |
+
initial_time + i * delta,
|
| 312 |
+
initial_time + (i + 1) * delta,
|
| 313 |
+
dataset_kwargs,
|
| 314 |
+
)
|
| 315 |
+
traj_ = []
|
| 316 |
+
for j in range(len(dataset_)):
|
| 317 |
+
traj_.append(dataset_[j]["labels"])
|
| 318 |
+
trajectories.append(torch.stack(traj_))
|
| 319 |
+
else:
|
| 320 |
+
running_time = initial_time
|
| 321 |
+
for i in ar_steps:
|
| 322 |
+
dataset_ = get_test_set(
|
| 323 |
+
dataset, data_path, running_time, running_time + i, dataset_kwargs
|
| 324 |
+
)
|
| 325 |
+
running_time += i
|
| 326 |
+
traj_ = []
|
| 327 |
+
for j in range(len(dataset_)):
|
| 328 |
+
traj_.append(dataset_[j]["labels"])
|
| 329 |
+
trajectories.append(torch.stack(traj_))
|
| 330 |
+
return torch.stack(trajectories, dim=1)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def remove_underscore_dict(d):
|
| 334 |
+
return {key[1:] if key.startswith("_") else key: value for key, value in d.items()}
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
if __name__ == "__main__":
|
| 338 |
+
parser = argparse.ArgumentParser(
|
| 339 |
+
description="Do different evaluations for a model, see --mode."
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--model_path",
|
| 343 |
+
type=str,
|
| 344 |
+
required=False,
|
| 345 |
+
help="Model path. Not required when mode==eval_sweep or save_samples_sweep.",
|
| 346 |
+
)
|
| 347 |
+
parser.add_argument(
|
| 348 |
+
"--file",
|
| 349 |
+
type=str,
|
| 350 |
+
required=True,
|
| 351 |
+
help="File to load/write to. May also be a directory to save samples.",
|
| 352 |
+
)
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--data_path",
|
| 355 |
+
type=str,
|
| 356 |
+
required=True,
|
| 357 |
+
help="Path to data.",
|
| 358 |
+
)
|
| 359 |
+
parser.add_argument(
|
| 360 |
+
"--dataset",
|
| 361 |
+
type=str,
|
| 362 |
+
help="Which test set to load. Not required if mode==eval_sweep or save_samples_sweep.",
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--batch_size",
|
| 366 |
+
type=int,
|
| 367 |
+
default=64,
|
| 368 |
+
help="Batch size for evaluation.",
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--full_data",
|
| 372 |
+
action="store_true",
|
| 373 |
+
help="Whether to save full data distributions.",
|
| 374 |
+
)
|
| 375 |
+
parser.add_argument(
|
| 376 |
+
"--initial_time",
|
| 377 |
+
type=int,
|
| 378 |
+
default=None,
|
| 379 |
+
help="Initial time step to start from.",
|
| 380 |
+
)
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--final_time",
|
| 383 |
+
type=int,
|
| 384 |
+
default=None,
|
| 385 |
+
help="Final time step to end at.",
|
| 386 |
+
)
|
| 387 |
+
parser.add_argument(
|
| 388 |
+
"--ar_steps",
|
| 389 |
+
type=int,
|
| 390 |
+
nargs="+",
|
| 391 |
+
default=[1],
|
| 392 |
+
help="Number of autoregressive steps to take. A single int n is interpreted as taking n homogeneous steps, a list of ints [j_0, j_1, ...] is interpreted as taking a step of size j_i.",
|
| 393 |
+
)
|
| 394 |
+
parser.add_argument(
|
| 395 |
+
"--mode",
|
| 396 |
+
type=str,
|
| 397 |
+
choices=[
|
| 398 |
+
"save_samples",
|
| 399 |
+
"save_samples_sweep",
|
| 400 |
+
"eval",
|
| 401 |
+
"eval_sweep",
|
| 402 |
+
"eval_accumulation_error",
|
| 403 |
+
"eval_resolutions",
|
| 404 |
+
],
|
| 405 |
+
default="eval",
|
| 406 |
+
help="Mode to run. Can be either save_samples to save n samples, save_samples_sweep, eval (to evaluate a single model), eval_sweep (to evaluate all models in a wandb sweep), eval_accumulation_error (to evaluate a model's accumulation error), eval_resolutions (to evaluate a model on different resolutions).",
|
| 407 |
+
)
|
| 408 |
+
parser.add_argument(
|
| 409 |
+
"--save_n_samples",
|
| 410 |
+
type=int,
|
| 411 |
+
default=1,
|
| 412 |
+
help="Number of samples to save. Only required for mode==save_samples or save_samples_sweep.",
|
| 413 |
+
)
|
| 414 |
+
parser.add_argument(
|
| 415 |
+
"--resolutions",
|
| 416 |
+
type=int,
|
| 417 |
+
nargs="+",
|
| 418 |
+
help="List of resolutions to evaluate. Only required for mode==eval_resolutions.",
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--wandb_project",
|
| 422 |
+
type=str,
|
| 423 |
+
default="scOT",
|
| 424 |
+
help="Wandb project name. Required if mode==eval_sweep or save_samples_sweep.",
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
"--wandb_entity",
|
| 428 |
+
type=str,
|
| 429 |
+
required=False,
|
| 430 |
+
help="Wandb entity name. Required if mode==eval_sweep or save_samples_sweep.",
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
"--wandb_sweep_id",
|
| 434 |
+
type=str,
|
| 435 |
+
default=None,
|
| 436 |
+
help="Wandb sweep id. Required if mode==eval_sweep or save_samples_sweep.",
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
"--ckpt_dir",
|
| 440 |
+
type=str,
|
| 441 |
+
required=True,
|
| 442 |
+
help="Base checkpoint directory. Required if mode==eval_sweep or save_samples_sweep.",
|
| 443 |
+
)
|
| 444 |
+
parser.add_argument(
|
| 445 |
+
"--exclude_dataset",
|
| 446 |
+
type=str,
|
| 447 |
+
nargs="+",
|
| 448 |
+
default=[],
|
| 449 |
+
help="Datasets to exclude from evaluation. Only relevant when mode==eval_sweep or save_samples_sweep.",
|
| 450 |
+
)
|
| 451 |
+
parser.add_argument(
|
| 452 |
+
"--exclusively_evaluate_dataset",
|
| 453 |
+
type=str,
|
| 454 |
+
nargs="+",
|
| 455 |
+
default=[],
|
| 456 |
+
help="Datasets to exclusively evaluate. Only relevant when mode==eval_sweep or save_samples_sweep.",
|
| 457 |
+
)
|
| 458 |
+
parser.add_argument(
|
| 459 |
+
"--just_velocities",
|
| 460 |
+
action="store_true",
|
| 461 |
+
help="Use just velocities in incompressible flow data.",
|
| 462 |
+
)
|
| 463 |
+
parser.add_argument(
|
| 464 |
+
"--allow_failed",
|
| 465 |
+
action="store_true",
|
| 466 |
+
help="Allow failed runs to be taken into account with eval_sweep.",
|
| 467 |
+
)
|
| 468 |
+
parser.add_argument(
|
| 469 |
+
"--append_time",
|
| 470 |
+
action="store_true",
|
| 471 |
+
help="Append .time to dataset name for evaluation.",
|
| 472 |
+
)
|
| 473 |
+
parser.add_argument(
|
| 474 |
+
"--num_trajectories",
|
| 475 |
+
type=int,
|
| 476 |
+
default=128,
|
| 477 |
+
help="Filter runs for number of training trajectories. Only relevant if mode==eval_sweep or save_samples_sweep.",
|
| 478 |
+
)
|
| 479 |
+
params = parser.parse_args()
|
| 480 |
+
if len(params.ar_steps) == 1:
|
| 481 |
+
params.ar_steps = params.ar_steps[0]
|
| 482 |
+
ar_steps = params.ar_steps
|
| 483 |
+
else:
|
| 484 |
+
ar_steps = params.ar_steps
|
| 485 |
+
params.ar_steps = [
|
| 486 |
+
step / (params.final_time - params.initial_time) for step in params.ar_steps
|
| 487 |
+
]
|
| 488 |
+
dataset_kwargs = {}
|
| 489 |
+
if params.just_velocities:
|
| 490 |
+
dataset_kwargs["just_velocities"] = True
|
| 491 |
+
if params.mode == "save_samples":
|
| 492 |
+
dataset = get_test_set(
|
| 493 |
+
params.dataset,
|
| 494 |
+
params.data_path,
|
| 495 |
+
params.initial_time,
|
| 496 |
+
params.final_time,
|
| 497 |
+
dataset_kwargs,
|
| 498 |
+
)
|
| 499 |
+
trainer = get_trainer(params.model_path, params.batch_size, dataset)
|
| 500 |
+
inputs = get_first_n_inputs(dataset, params.save_n_samples)
|
| 501 |
+
outputs, labels, _ = rollout(trainer, dataset, ar_steps=params.ar_steps)
|
| 502 |
+
np.save(
|
| 503 |
+
params.file + "/" + params.dataset.replace(".", "-") + "/" + "inputs.npy",
|
| 504 |
+
inputs.cpu().numpy(),
|
| 505 |
+
)
|
| 506 |
+
np.save(
|
| 507 |
+
params.file + "/" + params.dataset.replace(".", "-") + "/" + "labels.npy",
|
| 508 |
+
labels[: params.save_n_samples],
|
| 509 |
+
)
|
| 510 |
+
np.save(
|
| 511 |
+
params.file + "/" + params.dataset.replace(".", "-") + "/" + "outputs.npy",
|
| 512 |
+
outputs[: params.save_n_samples],
|
| 513 |
+
)
|
| 514 |
+
elif params.mode == "save_samples_sweep":
|
| 515 |
+
api = wandb.Api()
|
| 516 |
+
sweep = api.sweep(
|
| 517 |
+
params.wandb_entity
|
| 518 |
+
+ "/"
|
| 519 |
+
+ params.wandb_project
|
| 520 |
+
+ "/"
|
| 521 |
+
+ params.wandb_sweep_id
|
| 522 |
+
)
|
| 523 |
+
for run in sweep.runs:
|
| 524 |
+
if run.state == "finished" or (
|
| 525 |
+
params.allow_failed and run.state == "failed"
|
| 526 |
+
):
|
| 527 |
+
dset_name = run.config["dataset"]
|
| 528 |
+
if run.config["num_trajectories"] != params.num_trajectories:
|
| 529 |
+
continue
|
| 530 |
+
if dset_name in params.exclude_dataset:
|
| 531 |
+
continue
|
| 532 |
+
if (
|
| 533 |
+
len(params.exclusively_evaluate_dataset) > 0
|
| 534 |
+
and dset_name not in params.exclusively_evaluate_dataset
|
| 535 |
+
):
|
| 536 |
+
continue
|
| 537 |
+
num_trajectories = run.config["num_trajectories"]
|
| 538 |
+
ckpt_dir = (
|
| 539 |
+
params.ckpt_dir
|
| 540 |
+
+ "/"
|
| 541 |
+
+ params.wandb_project
|
| 542 |
+
+ "/"
|
| 543 |
+
+ params.wandb_sweep_id
|
| 544 |
+
+ "/"
|
| 545 |
+
+ run.name
|
| 546 |
+
)
|
| 547 |
+
items = os.listdir(ckpt_dir)
|
| 548 |
+
dirs = [
|
| 549 |
+
item
|
| 550 |
+
for item in items
|
| 551 |
+
if os.path.isdir(os.path.join(ckpt_dir, item))
|
| 552 |
+
]
|
| 553 |
+
if len(dirs) > 1:
|
| 554 |
+
print(
|
| 555 |
+
"WARNING: more than one checkpoint in run directory " + ckpt_dir
|
| 556 |
+
)
|
| 557 |
+
print("choosing " + dirs[0])
|
| 558 |
+
model_path = os.path.join(ckpt_dir, dirs[0])
|
| 559 |
+
dataset = get_test_set(
|
| 560 |
+
dset_name,
|
| 561 |
+
params.data_path,
|
| 562 |
+
params.initial_time,
|
| 563 |
+
params.final_time,
|
| 564 |
+
dataset_kwargs,
|
| 565 |
+
)
|
| 566 |
+
trainer = get_trainer(model_path, params.batch_size, dataset)
|
| 567 |
+
inputs = get_first_n_inputs(dataset, params.save_n_samples)
|
| 568 |
+
outputs, labels, _ = rollout(trainer, dataset, ar_steps=params.ar_steps)
|
| 569 |
+
if not os.path.exists(params.file + "/" + dset_name.replace(".", "-")):
|
| 570 |
+
os.makedirs(params.file + "/" + dset_name.replace(".", "-"))
|
| 571 |
+
if not os.path.exists(
|
| 572 |
+
params.file
|
| 573 |
+
+ "/"
|
| 574 |
+
+ dset_name.replace(".", "-")
|
| 575 |
+
+ "/"
|
| 576 |
+
+ str(num_trajectories)
|
| 577 |
+
):
|
| 578 |
+
os.makedirs(
|
| 579 |
+
params.file
|
| 580 |
+
+ "/"
|
| 581 |
+
+ dset_name.replace(".", "-")
|
| 582 |
+
+ "/"
|
| 583 |
+
+ str(num_trajectories)
|
| 584 |
+
)
|
| 585 |
+
np.save(
|
| 586 |
+
params.file
|
| 587 |
+
+ "/"
|
| 588 |
+
+ dset_name.replace(".", "-")
|
| 589 |
+
+ "/"
|
| 590 |
+
+ str(num_trajectories)
|
| 591 |
+
+ "/inputs.npy",
|
| 592 |
+
inputs.cpu().numpy(),
|
| 593 |
+
)
|
| 594 |
+
np.save(
|
| 595 |
+
params.file
|
| 596 |
+
+ "/"
|
| 597 |
+
+ dset_name.replace(".", "-")
|
| 598 |
+
+ "/"
|
| 599 |
+
+ str(num_trajectories)
|
| 600 |
+
+ "/labels.npy",
|
| 601 |
+
labels[: params.save_n_samples],
|
| 602 |
+
)
|
| 603 |
+
np.save(
|
| 604 |
+
params.file
|
| 605 |
+
+ "/"
|
| 606 |
+
+ dset_name.replace(".", "-")
|
| 607 |
+
+ "/"
|
| 608 |
+
+ str(num_trajectories)
|
| 609 |
+
+ "/"
|
| 610 |
+
+ "outputs.npy",
|
| 611 |
+
outputs[: params.save_n_samples],
|
| 612 |
+
)
|
| 613 |
+
else:
|
| 614 |
+
if params.mode == "eval":
|
| 615 |
+
dataset = get_test_set(
|
| 616 |
+
params.dataset,
|
| 617 |
+
params.data_path,
|
| 618 |
+
params.initial_time,
|
| 619 |
+
params.final_time,
|
| 620 |
+
dataset_kwargs,
|
| 621 |
+
)
|
| 622 |
+
trainer = get_trainer(
|
| 623 |
+
params.model_path,
|
| 624 |
+
params.batch_size,
|
| 625 |
+
dataset,
|
| 626 |
+
full_data=params.full_data,
|
| 627 |
+
)
|
| 628 |
+
_, _, metrics = rollout(
|
| 629 |
+
trainer,
|
| 630 |
+
dataset,
|
| 631 |
+
ar_steps=params.ar_steps,
|
| 632 |
+
output_all_steps=False,
|
| 633 |
+
)
|
| 634 |
+
data = {
|
| 635 |
+
"dataset": params.dataset,
|
| 636 |
+
"initial_time": params.initial_time,
|
| 637 |
+
"final_time": params.final_time,
|
| 638 |
+
"ar_steps": ar_steps,
|
| 639 |
+
**metrics,
|
| 640 |
+
}
|
| 641 |
+
data = [remove_underscore_dict(data)]
|
| 642 |
+
elif params.mode == "eval_sweep":
|
| 643 |
+
api = wandb.Api()
|
| 644 |
+
sweep = api.sweep(
|
| 645 |
+
params.wandb_entity
|
| 646 |
+
+ "/"
|
| 647 |
+
+ params.wandb_project
|
| 648 |
+
+ "/"
|
| 649 |
+
+ params.wandb_sweep_id
|
| 650 |
+
)
|
| 651 |
+
data = []
|
| 652 |
+
for run in sweep.runs:
|
| 653 |
+
if run.state == "finished" or (
|
| 654 |
+
params.allow_failed and run.state == "failed"
|
| 655 |
+
):
|
| 656 |
+
dset_name = (
|
| 657 |
+
run.config["dataset"]
|
| 658 |
+
if not params.append_time
|
| 659 |
+
else run.config["dataset"] + ".time"
|
| 660 |
+
)
|
| 661 |
+
if dset_name in params.exclude_dataset:
|
| 662 |
+
continue
|
| 663 |
+
if (
|
| 664 |
+
len(params.exclusively_evaluate_dataset) > 0
|
| 665 |
+
and dset_name not in params.exclusively_evaluate_dataset
|
| 666 |
+
):
|
| 667 |
+
continue
|
| 668 |
+
num_trajectories = run.config["num_trajectories"]
|
| 669 |
+
ckpt_dir = (
|
| 670 |
+
params.ckpt_dir
|
| 671 |
+
+ "/"
|
| 672 |
+
+ params.wandb_project
|
| 673 |
+
+ "/"
|
| 674 |
+
+ params.wandb_sweep_id
|
| 675 |
+
+ "/"
|
| 676 |
+
+ run.name
|
| 677 |
+
)
|
| 678 |
+
items = os.listdir(ckpt_dir)
|
| 679 |
+
dirs = [
|
| 680 |
+
item
|
| 681 |
+
for item in items
|
| 682 |
+
if os.path.isdir(os.path.join(ckpt_dir, item))
|
| 683 |
+
]
|
| 684 |
+
if len(dirs) > 1:
|
| 685 |
+
print(
|
| 686 |
+
"WARNING: more than one checkpoint in run directory "
|
| 687 |
+
+ ckpt_dir
|
| 688 |
+
)
|
| 689 |
+
print("choosing " + dirs[0])
|
| 690 |
+
continue
|
| 691 |
+
if len(dirs) == 0:
|
| 692 |
+
continue
|
| 693 |
+
model_path = os.path.join(ckpt_dir, dirs[0])
|
| 694 |
+
dataset = get_test_set(
|
| 695 |
+
dset_name,
|
| 696 |
+
params.data_path,
|
| 697 |
+
params.initial_time,
|
| 698 |
+
params.final_time,
|
| 699 |
+
dataset_kwargs,
|
| 700 |
+
)
|
| 701 |
+
trainer = get_trainer(
|
| 702 |
+
model_path,
|
| 703 |
+
params.batch_size,
|
| 704 |
+
dataset,
|
| 705 |
+
full_data=params.full_data,
|
| 706 |
+
)
|
| 707 |
+
_, _, metrics = rollout(
|
| 708 |
+
trainer,
|
| 709 |
+
dataset,
|
| 710 |
+
ar_steps=params.ar_steps,
|
| 711 |
+
output_all_steps=False,
|
| 712 |
+
)
|
| 713 |
+
data.append(
|
| 714 |
+
remove_underscore_dict(
|
| 715 |
+
{
|
| 716 |
+
"dataset": dset_name,
|
| 717 |
+
"num_trajectories": num_trajectories,
|
| 718 |
+
"initial_time": params.initial_time,
|
| 719 |
+
"final_time": params.final_time,
|
| 720 |
+
"ar_steps": ar_steps,
|
| 721 |
+
**metrics,
|
| 722 |
+
}
|
| 723 |
+
)
|
| 724 |
+
)
|
| 725 |
+
elif params.mode == "eval_accumulation_error":
|
| 726 |
+
dataset = get_test_set(
|
| 727 |
+
params.dataset,
|
| 728 |
+
params.data_path,
|
| 729 |
+
params.initial_time,
|
| 730 |
+
params.final_time,
|
| 731 |
+
dataset_kwargs,
|
| 732 |
+
)
|
| 733 |
+
trainer = get_trainer(
|
| 734 |
+
params.model_path,
|
| 735 |
+
params.batch_size,
|
| 736 |
+
dataset,
|
| 737 |
+
output_all_steps=True,
|
| 738 |
+
full_data=params.full_data,
|
| 739 |
+
)
|
| 740 |
+
predictions, _, _ = rollout(
|
| 741 |
+
trainer,
|
| 742 |
+
dataset,
|
| 743 |
+
ar_steps=params.ar_steps,
|
| 744 |
+
output_all_steps=True,
|
| 745 |
+
)
|
| 746 |
+
labels = get_trajectories(
|
| 747 |
+
params.dataset,
|
| 748 |
+
params.data_path,
|
| 749 |
+
params.ar_steps,
|
| 750 |
+
params.initial_time,
|
| 751 |
+
params.final_time,
|
| 752 |
+
dataset_kwargs,
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
def compute_metrics(eval_preds):
|
| 756 |
+
channel_list = dataset.channel_slice_list
|
| 757 |
+
|
| 758 |
+
def get_relative_statistics(errors):
|
| 759 |
+
median_error = np.median(errors, axis=0)
|
| 760 |
+
mean_error = np.mean(errors, axis=0)
|
| 761 |
+
std_error = np.std(errors, axis=0)
|
| 762 |
+
min_error = np.min(errors, axis=0)
|
| 763 |
+
max_error = np.max(errors, axis=0)
|
| 764 |
+
return {
|
| 765 |
+
"median_relative_l1_error": median_error,
|
| 766 |
+
"mean_relative_l1_error": mean_error,
|
| 767 |
+
"std_relative_l1_error": std_error,
|
| 768 |
+
"min_relative_l1_error": min_error,
|
| 769 |
+
"max_relative_l1_error": max_error,
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
def get_statistics(errors):
|
| 773 |
+
median_error = np.median(errors, axis=0)
|
| 774 |
+
mean_error = np.mean(errors, axis=0)
|
| 775 |
+
std_error = np.std(errors, axis=0)
|
| 776 |
+
min_error = np.min(errors, axis=0)
|
| 777 |
+
max_error = np.max(errors, axis=0)
|
| 778 |
+
return {
|
| 779 |
+
"median_l1_error": median_error,
|
| 780 |
+
"mean_l1_error": mean_error,
|
| 781 |
+
"std_l1_error": std_error,
|
| 782 |
+
"min_l1_error": min_error,
|
| 783 |
+
"max_l1_error": max_error,
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
relative_errors = [
|
| 787 |
+
relative_lp_error(
|
| 788 |
+
eval_preds.predictions[
|
| 789 |
+
:, channel_list[i] : channel_list[i + 1]
|
| 790 |
+
],
|
| 791 |
+
eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]],
|
| 792 |
+
p=1,
|
| 793 |
+
return_percent=True,
|
| 794 |
+
)
|
| 795 |
+
for i in range(len(channel_list) - 1)
|
| 796 |
+
]
|
| 797 |
+
|
| 798 |
+
errors = [
|
| 799 |
+
lp_error(
|
| 800 |
+
eval_preds.predictions[
|
| 801 |
+
:, channel_list[i] : channel_list[i + 1]
|
| 802 |
+
],
|
| 803 |
+
eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]],
|
| 804 |
+
p=1,
|
| 805 |
+
)
|
| 806 |
+
for i in range(len(channel_list) - 1)
|
| 807 |
+
]
|
| 808 |
+
|
| 809 |
+
relative_error_statistics = [
|
| 810 |
+
get_relative_statistics(relative_errors[i])
|
| 811 |
+
for i in range(len(channel_list) - 1)
|
| 812 |
+
]
|
| 813 |
+
|
| 814 |
+
error_statistics = [
|
| 815 |
+
get_statistics(errors[i]) for i in range(len(channel_list) - 1)
|
| 816 |
+
]
|
| 817 |
+
|
| 818 |
+
if dataset.output_dim == 1:
|
| 819 |
+
relative_error_statistics = relative_error_statistics[0]
|
| 820 |
+
error_statistics = error_statistics[0]
|
| 821 |
+
if params.full_data:
|
| 822 |
+
relative_error_statistics["relative_full_data"] = (
|
| 823 |
+
relative_errors[0].tolist()
|
| 824 |
+
)
|
| 825 |
+
error_statistics["full_data"] = errors[0].tolist()
|
| 826 |
+
return {**relative_error_statistics, **error_statistics}
|
| 827 |
+
else:
|
| 828 |
+
mean_over_relative_means = np.mean(
|
| 829 |
+
np.array(
|
| 830 |
+
[
|
| 831 |
+
stats["mean_relative_l1_error"]
|
| 832 |
+
for stats in relative_error_statistics
|
| 833 |
+
]
|
| 834 |
+
),
|
| 835 |
+
axis=0,
|
| 836 |
+
)
|
| 837 |
+
mean_over_relative_medians = np.mean(
|
| 838 |
+
np.array(
|
| 839 |
+
[
|
| 840 |
+
stats["median_relative_l1_error"]
|
| 841 |
+
for stats in relative_error_statistics
|
| 842 |
+
]
|
| 843 |
+
),
|
| 844 |
+
axis=0,
|
| 845 |
+
)
|
| 846 |
+
mean_over_means = np.mean(
|
| 847 |
+
np.array(
|
| 848 |
+
[stats["mean_l1_error"] for stats in error_statistics]
|
| 849 |
+
),
|
| 850 |
+
axis=0,
|
| 851 |
+
)
|
| 852 |
+
mean_over_medians = np.mean(
|
| 853 |
+
np.array(
|
| 854 |
+
[stats["median_l1_error"] for stats in error_statistics]
|
| 855 |
+
),
|
| 856 |
+
axis=0,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
error_statistics_ = {
|
| 860 |
+
"mean_relative_l1_error": mean_over_relative_means,
|
| 861 |
+
"mean_over_median_relative_l1_error": mean_over_relative_medians,
|
| 862 |
+
"mean_l1_error": mean_over_means,
|
| 863 |
+
"mean_over_median_l1_error": mean_over_medians,
|
| 864 |
+
}
|
| 865 |
+
#!! The above is different from train and finetune (here mean_relative_l1_error is mean over medians instead of mean over means)
|
| 866 |
+
for i, stats in enumerate(relative_error_statistics):
|
| 867 |
+
for key, value in stats.items():
|
| 868 |
+
error_statistics_[
|
| 869 |
+
dataset.printable_channel_description[i] + "/" + key
|
| 870 |
+
] = value
|
| 871 |
+
if params.full_data:
|
| 872 |
+
error_statistics_[
|
| 873 |
+
dataset.printable_channel_description[i]
|
| 874 |
+
+ "/"
|
| 875 |
+
+ "relative_full_data"
|
| 876 |
+
] = relative_errors[i].tolist()
|
| 877 |
+
for i, stats in enumerate(error_statistics):
|
| 878 |
+
for key, value in stats.items():
|
| 879 |
+
error_statistics_[
|
| 880 |
+
dataset.printable_channel_description[i] + "/" + key
|
| 881 |
+
] = value
|
| 882 |
+
if params.full_data:
|
| 883 |
+
error_statistics_[
|
| 884 |
+
dataset.printable_channel_description[i]
|
| 885 |
+
+ "/"
|
| 886 |
+
+ "full_data"
|
| 887 |
+
] = errors[i].tolist()
|
| 888 |
+
return error_statistics_
|
| 889 |
+
|
| 890 |
+
data = []
|
| 891 |
+
for step in range(predictions.shape[1]):
|
| 892 |
+
metrics = compute_metrics(
|
| 893 |
+
EvalPrediction(predictions[:, step], labels[:, step].cpu().numpy())
|
| 894 |
+
)
|
| 895 |
+
if isinstance(params.ar_steps, int):
|
| 896 |
+
delta = (params.final_time - params.initial_time) // params.ar_steps
|
| 897 |
+
else:
|
| 898 |
+
delta = params.ar_steps[step]
|
| 899 |
+
data.append(
|
| 900 |
+
remove_underscore_dict(
|
| 901 |
+
{
|
| 902 |
+
"dataset": params.dataset,
|
| 903 |
+
"initial_time": params.initial_time + step * delta,
|
| 904 |
+
"final_time": params.initial_time + (step + 1) * delta,
|
| 905 |
+
**metrics,
|
| 906 |
+
}
|
| 907 |
+
)
|
| 908 |
+
)
|
| 909 |
+
elif params.mode == "eval_resolutions":
|
| 910 |
+
data = []
|
| 911 |
+
for resolution in params.resolutions:
|
| 912 |
+
dataset_kwargs = {"resolution": resolution}
|
| 913 |
+
dataset = get_test_set(
|
| 914 |
+
params.dataset,
|
| 915 |
+
params.data_path,
|
| 916 |
+
params.initial_time,
|
| 917 |
+
params.final_time,
|
| 918 |
+
dataset_kwargs,
|
| 919 |
+
)
|
| 920 |
+
trainer = get_trainer(
|
| 921 |
+
params.model_path,
|
| 922 |
+
params.batch_size,
|
| 923 |
+
dataset,
|
| 924 |
+
full_data=params.full_data,
|
| 925 |
+
)
|
| 926 |
+
_, _, metrics = rollout(
|
| 927 |
+
trainer,
|
| 928 |
+
dataset,
|
| 929 |
+
ar_steps=params.ar_steps,
|
| 930 |
+
output_all_steps=False,
|
| 931 |
+
)
|
| 932 |
+
data.append(
|
| 933 |
+
remove_underscore_dict(
|
| 934 |
+
{
|
| 935 |
+
"dataset": params.dataset,
|
| 936 |
+
"initial_time": params.initial_time,
|
| 937 |
+
"final_time": params.final_time,
|
| 938 |
+
"ar_steps": ar_steps,
|
| 939 |
+
"resolution": resolution,
|
| 940 |
+
**metrics,
|
| 941 |
+
}
|
| 942 |
+
)
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
if os.path.exists(params.file):
|
| 946 |
+
df = pd.read_csv(params.file)
|
| 947 |
+
else:
|
| 948 |
+
df = pd.DataFrame()
|
| 949 |
+
df = pd.concat([df, pd.DataFrame(data)], ignore_index=True)
|
| 950 |
+
df.to_csv(params.file, index=False)
|
external/poseidon/scOT/metrics.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def lp_error(preds: np.ndarray, targets: np.ndarray, p=1):
|
| 5 |
+
num_samples, num_channels, _, _ = preds.shape
|
| 6 |
+
preds = preds.reshape(num_samples, num_channels, -1)
|
| 7 |
+
targets = targets.reshape(num_samples, num_channels, -1)
|
| 8 |
+
errors = np.sum(np.abs(preds - targets) ** p, axis=-1)
|
| 9 |
+
return np.sum(errors, axis=-1) ** (1 / p)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def relative_lp_error(
|
| 13 |
+
preds: np.ndarray,
|
| 14 |
+
targets: np.ndarray,
|
| 15 |
+
p=1,
|
| 16 |
+
return_percent=True,
|
| 17 |
+
):
|
| 18 |
+
num_samples, num_channels, _, _ = preds.shape
|
| 19 |
+
preds = preds.reshape(num_samples, num_channels, -1)
|
| 20 |
+
targets = targets.reshape(num_samples, num_channels, -1)
|
| 21 |
+
errors = np.sum(np.abs(preds - targets) ** p, axis=-1)
|
| 22 |
+
normalization_factor = np.sum(np.abs(targets) ** p, axis=-1)
|
| 23 |
+
|
| 24 |
+
# catch 0 division
|
| 25 |
+
normalization_factor = np.sum(normalization_factor, axis=-1)
|
| 26 |
+
normalization_factor = np.where(
|
| 27 |
+
normalization_factor == 0, 1e-10, normalization_factor
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
errors = (np.sum(errors, axis=-1) / normalization_factor) ** (1 / p)
|
| 31 |
+
|
| 32 |
+
if return_percent:
|
| 33 |
+
errors *= 100
|
| 34 |
+
|
| 35 |
+
return errors
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def mean_relative_lp_error(
|
| 39 |
+
preds: np.ndarray,
|
| 40 |
+
targets: np.ndarray,
|
| 41 |
+
p=1,
|
| 42 |
+
return_percent=True,
|
| 43 |
+
):
|
| 44 |
+
errors = relative_lp_error(preds, targets, p, return_percent)
|
| 45 |
+
return np.mean(errors, axis=0)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def median_relative_lp_error(
|
| 49 |
+
preds: np.ndarray,
|
| 50 |
+
targets: np.ndarray,
|
| 51 |
+
p=1,
|
| 52 |
+
return_percent=True,
|
| 53 |
+
):
|
| 54 |
+
errors = relative_lp_error(preds, targets, p, return_percent)
|
| 55 |
+
return np.median(errors, axis=0)
|
external/poseidon/scOT/model.py
ADDED
|
@@ -0,0 +1,1485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains scOT.
|
| 3 |
+
|
| 4 |
+
A lot of this file is taken from the transformers library and changed to our purposes. Huggingface Transformers is licensed under
|
| 5 |
+
Apache 2.0 License, see trainer.py for details.
|
| 6 |
+
|
| 7 |
+
We follow https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/swinv2/configuration_swinv2.py
|
| 8 |
+
and https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/swinv2/modeling_swinv2.py#L1129
|
| 9 |
+
|
| 10 |
+
The class ConvNeXtBlock is taken from the facebookresearch/ConvNeXt repository and is licensed under the MIT License,
|
| 11 |
+
|
| 12 |
+
MIT License
|
| 13 |
+
|
| 14 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 15 |
+
|
| 16 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 17 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 18 |
+
in the Software without restriction, including without limitation the rights
|
| 19 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 20 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 21 |
+
furnished to do so, subject to the following conditions:
|
| 22 |
+
|
| 23 |
+
The above copyright notice and this permission notice shall be included in all
|
| 24 |
+
copies or substantial portions of the Software.
|
| 25 |
+
|
| 26 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 27 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 28 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 29 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 30 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 31 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 32 |
+
SOFTWARE.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from transformers import (
|
| 36 |
+
Swinv2PreTrainedModel,
|
| 37 |
+
PretrainedConfig,
|
| 38 |
+
)
|
| 39 |
+
from transformers.models.swinv2.modeling_swinv2 import (
|
| 40 |
+
Swinv2EncoderOutput,
|
| 41 |
+
Swinv2Attention,
|
| 42 |
+
Swinv2DropPath,
|
| 43 |
+
Swinv2Intermediate,
|
| 44 |
+
Swinv2Output,
|
| 45 |
+
window_reverse,
|
| 46 |
+
window_partition,
|
| 47 |
+
)
|
| 48 |
+
from transformers.utils import ModelOutput
|
| 49 |
+
from dataclasses import dataclass
|
| 50 |
+
import torch
|
| 51 |
+
from torch import nn
|
| 52 |
+
from typing import Optional, Union, Tuple, List
|
| 53 |
+
import math
|
| 54 |
+
import collections
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class ScOTOutput(ModelOutput):
|
| 59 |
+
loss: Optional[torch.FloatTensor] = None
|
| 60 |
+
output: torch.FloatTensor = None
|
| 61 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 62 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 63 |
+
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ScOTConfig(PretrainedConfig):
|
| 67 |
+
"""https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/swinv2/configuration_swinv2.py"""
|
| 68 |
+
|
| 69 |
+
model_type = "swinv2"
|
| 70 |
+
|
| 71 |
+
attribute_map = {
|
| 72 |
+
"num_attention_heads": "num_heads",
|
| 73 |
+
"num_hidden_layers": "num_layers",
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
image_size=224,
|
| 79 |
+
patch_size=4,
|
| 80 |
+
num_channels=3,
|
| 81 |
+
num_out_channels=1,
|
| 82 |
+
embed_dim=96,
|
| 83 |
+
depths=[2, 2, 6, 2],
|
| 84 |
+
num_heads=[3, 6, 12, 24],
|
| 85 |
+
skip_connections=[True, True, True],
|
| 86 |
+
window_size=7,
|
| 87 |
+
mlp_ratio=4.0,
|
| 88 |
+
qkv_bias=True,
|
| 89 |
+
hidden_dropout_prob=0.0,
|
| 90 |
+
attention_probs_dropout_prob=0.0,
|
| 91 |
+
drop_path_rate=0.1,
|
| 92 |
+
hidden_act="gelu",
|
| 93 |
+
use_absolute_embeddings=False,
|
| 94 |
+
initializer_range=0.02,
|
| 95 |
+
layer_norm_eps=1e-5,
|
| 96 |
+
p=1, # for loss: 1 for l1, 2 for l2
|
| 97 |
+
channel_slice_list_normalized_loss=None, # if None will fall back to absolute loss otherwise normalized loss with split channels
|
| 98 |
+
residual_model="convnext", # "convnext" or "resnet"
|
| 99 |
+
use_conditioning=False,
|
| 100 |
+
learn_residual=False, # learn the residual for time-dependent problems
|
| 101 |
+
**kwargs,
|
| 102 |
+
):
|
| 103 |
+
super().__init__(**kwargs)
|
| 104 |
+
|
| 105 |
+
self.image_size = image_size
|
| 106 |
+
self.patch_size = patch_size
|
| 107 |
+
self.num_channels = num_channels
|
| 108 |
+
self.embed_dim = embed_dim
|
| 109 |
+
self.depths = depths
|
| 110 |
+
self.num_layers = len(depths)
|
| 111 |
+
self.num_heads = num_heads
|
| 112 |
+
self.skip_connections = skip_connections
|
| 113 |
+
self.window_size = window_size
|
| 114 |
+
self.mlp_ratio = mlp_ratio
|
| 115 |
+
self.qkv_bias = qkv_bias
|
| 116 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 117 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 118 |
+
self.drop_path_rate = drop_path_rate
|
| 119 |
+
self.hidden_act = hidden_act
|
| 120 |
+
self.use_absolute_embeddings = use_absolute_embeddings
|
| 121 |
+
self.use_conditioning = use_conditioning
|
| 122 |
+
self.learn_residual = learn_residual if self.use_conditioning else False
|
| 123 |
+
self.layer_norm_eps = layer_norm_eps
|
| 124 |
+
self.initializer_range = initializer_range
|
| 125 |
+
# we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel
|
| 126 |
+
# this indicates the channel dimension after the last stage of the model
|
| 127 |
+
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
| 128 |
+
self.pretrained_window_sizes = (0, 0, 0, 0)
|
| 129 |
+
self.num_out_channels = num_out_channels
|
| 130 |
+
self.p = p
|
| 131 |
+
self.channel_slice_list_normalized_loss = channel_slice_list_normalized_loss
|
| 132 |
+
self.residual_model = residual_model
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class LayerNorm(nn.LayerNorm):
|
| 136 |
+
def __init__(self, *args, **kwargs):
|
| 137 |
+
super().__init__(*args, **kwargs)
|
| 138 |
+
|
| 139 |
+
def forward(self, x, time):
|
| 140 |
+
return super().forward(x)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ConditionalLayerNorm(nn.Module):
|
| 144 |
+
def __init__(self, dim, eps=1e-5):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.eps = eps
|
| 147 |
+
self.weight = nn.Linear(1, dim)
|
| 148 |
+
self.bias = nn.Linear(1, dim)
|
| 149 |
+
|
| 150 |
+
def forward(self, x, time):
|
| 151 |
+
mean = x.mean(dim=-1, keepdim=True)
|
| 152 |
+
var = (x**2).mean(dim=-1, keepdim=True) - mean**2
|
| 153 |
+
x = (x - mean) / (var + self.eps).sqrt()
|
| 154 |
+
time = time.reshape(-1, 1).type_as(x)
|
| 155 |
+
weight = self.weight(time).unsqueeze(1)
|
| 156 |
+
bias = self.bias(time).unsqueeze(1)
|
| 157 |
+
if x.dim() == 4:
|
| 158 |
+
weight = weight.unsqueeze(1)
|
| 159 |
+
bias = bias.unsqueeze(1)
|
| 160 |
+
return weight * x + bias
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ConvNeXtBlock(nn.Module):
|
| 164 |
+
r"""Taken from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
|
| 165 |
+
ConvNeXt Block. There are two equivalent implementations:
|
| 166 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 167 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 168 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
dim (int): Number of input channels.
|
| 172 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 173 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
def __init__(self, config, dim, drop_path=0.0, layer_scale_init_value=1e-6):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.dwconv = nn.Conv2d(
|
| 179 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
| 180 |
+
) # depthwise conv
|
| 181 |
+
if config.use_conditioning:
|
| 182 |
+
layer_norm = ConditionalLayerNorm
|
| 183 |
+
else:
|
| 184 |
+
layer_norm = LayerNorm
|
| 185 |
+
self.norm = layer_norm(dim, eps=config.layer_norm_eps)
|
| 186 |
+
self.pwconv1 = nn.Linear(
|
| 187 |
+
dim, 4 * dim
|
| 188 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 189 |
+
self.act = nn.GELU()
|
| 190 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 191 |
+
self.weight = (
|
| 192 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 193 |
+
if layer_scale_init_value > 0
|
| 194 |
+
else None
|
| 195 |
+
) # was gamma before
|
| 196 |
+
self.drop_path = Swinv2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 197 |
+
|
| 198 |
+
def forward(self, x, time):
|
| 199 |
+
batch_size, sequence_length, hidden_size = x.shape
|
| 200 |
+
#! assumes square images
|
| 201 |
+
input_dim = math.floor(sequence_length**0.5)
|
| 202 |
+
|
| 203 |
+
input = x
|
| 204 |
+
x = x.reshape(batch_size, input_dim, input_dim, hidden_size)
|
| 205 |
+
x = x.permute(0, 3, 1, 2)
|
| 206 |
+
x = self.dwconv(x)
|
| 207 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 208 |
+
x = self.norm(x, time)
|
| 209 |
+
x = self.pwconv1(x)
|
| 210 |
+
x = self.act(x)
|
| 211 |
+
x = self.pwconv2(x)
|
| 212 |
+
if self.weight is not None:
|
| 213 |
+
x = self.weight * x
|
| 214 |
+
x = x.reshape(batch_size, sequence_length, hidden_size)
|
| 215 |
+
|
| 216 |
+
x = input + self.drop_path(x)
|
| 217 |
+
return x
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ResNetBlock(nn.Module):
|
| 221 |
+
def __init__(self, config, dim):
|
| 222 |
+
super().__init__()
|
| 223 |
+
kernel_size = 3
|
| 224 |
+
pad = (kernel_size - 1) // 2
|
| 225 |
+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=pad)
|
| 226 |
+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=pad)
|
| 227 |
+
self.bn1 = nn.BatchNorm2d(dim)
|
| 228 |
+
self.bn2 = nn.BatchNorm2d(dim)
|
| 229 |
+
|
| 230 |
+
def forward(self, x, time):
|
| 231 |
+
batch_size, sequence_length, hidden_size = x.shape
|
| 232 |
+
#! assumes square images
|
| 233 |
+
input_dim = math.floor(sequence_length**0.5)
|
| 234 |
+
|
| 235 |
+
input = x
|
| 236 |
+
x = x.reshape(batch_size, input_dim, input_dim, hidden_size)
|
| 237 |
+
x = x.permute(0, 3, 1, 2)
|
| 238 |
+
x = self.conv1(x)
|
| 239 |
+
x = self.bn1(x)
|
| 240 |
+
x = nn.functional.leaky_relu(x)
|
| 241 |
+
x = self.conv2(x)
|
| 242 |
+
x = self.bn2(x)
|
| 243 |
+
x = x.permute(0, 2, 3, 1)
|
| 244 |
+
x = x.reshape(batch_size, sequence_length, hidden_size)
|
| 245 |
+
x = x + input
|
| 246 |
+
return x
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ScOTPatchEmbeddings(nn.Module):
|
| 250 |
+
"""
|
| 251 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 252 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 253 |
+
Transformer.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, config):
|
| 257 |
+
super().__init__()
|
| 258 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 259 |
+
num_channels, hidden_size = config.num_channels, config.embed_dim
|
| 260 |
+
image_size = (
|
| 261 |
+
image_size
|
| 262 |
+
if isinstance(image_size, collections.abc.Iterable)
|
| 263 |
+
else (image_size, image_size)
|
| 264 |
+
)
|
| 265 |
+
patch_size = (
|
| 266 |
+
patch_size
|
| 267 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
| 268 |
+
else (patch_size, patch_size)
|
| 269 |
+
)
|
| 270 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
| 271 |
+
image_size[0] // patch_size[0]
|
| 272 |
+
)
|
| 273 |
+
self.image_size = image_size
|
| 274 |
+
self.patch_size = patch_size
|
| 275 |
+
self.num_channels = num_channels
|
| 276 |
+
self.num_patches = num_patches
|
| 277 |
+
self.grid_size = (
|
| 278 |
+
image_size[0] // patch_size[0],
|
| 279 |
+
image_size[1] // patch_size[1],
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
self.projection = nn.Conv2d(
|
| 283 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
def maybe_pad(self, pixel_values, height, width):
|
| 287 |
+
if width % self.patch_size[1] != 0:
|
| 288 |
+
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
|
| 289 |
+
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
| 290 |
+
if height % self.patch_size[0] != 0:
|
| 291 |
+
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
|
| 292 |
+
pixel_values = nn.functional.pad(pixel_values, pad_values)
|
| 293 |
+
return pixel_values
|
| 294 |
+
|
| 295 |
+
def forward(
|
| 296 |
+
self, pixel_values: Optional[torch.FloatTensor]
|
| 297 |
+
) -> Tuple[torch.Tensor, Tuple[int]]:
|
| 298 |
+
_, num_channels, height, width = pixel_values.shape
|
| 299 |
+
if num_channels != self.num_channels:
|
| 300 |
+
raise ValueError(
|
| 301 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 302 |
+
)
|
| 303 |
+
# pad the input to be divisible by self.patch_size, if needed
|
| 304 |
+
pixel_values = self.maybe_pad(pixel_values, height, width)
|
| 305 |
+
embeddings = self.projection(pixel_values)
|
| 306 |
+
_, _, height, width = embeddings.shape
|
| 307 |
+
output_dimensions = (height, width)
|
| 308 |
+
embeddings = embeddings.flatten(2).transpose(1, 2)
|
| 309 |
+
|
| 310 |
+
return embeddings, output_dimensions
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class ScOTEmbeddings(nn.Module):
|
| 314 |
+
"""
|
| 315 |
+
Construct the patch and position embeddings. Optionally, also the mask token.
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
def __init__(self, config, use_mask_token=False):
|
| 319 |
+
super().__init__()
|
| 320 |
+
|
| 321 |
+
self.patch_embeddings = ScOTPatchEmbeddings(config)
|
| 322 |
+
num_patches = self.patch_embeddings.num_patches
|
| 323 |
+
self.patch_grid = self.patch_embeddings.grid_size
|
| 324 |
+
self.mask_token = (
|
| 325 |
+
nn.Parameter(torch.zeros(1, 1, config.embed_dim))
|
| 326 |
+
if use_mask_token
|
| 327 |
+
else None
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if config.use_absolute_embeddings:
|
| 331 |
+
self.position_embeddings = nn.Parameter(
|
| 332 |
+
torch.zeros(1, num_patches, config.embed_dim)
|
| 333 |
+
)
|
| 334 |
+
else:
|
| 335 |
+
self.position_embeddings = None
|
| 336 |
+
|
| 337 |
+
if config.use_conditioning:
|
| 338 |
+
layer_norm = ConditionalLayerNorm
|
| 339 |
+
else:
|
| 340 |
+
layer_norm = LayerNorm
|
| 341 |
+
|
| 342 |
+
self.norm = layer_norm(config.embed_dim)
|
| 343 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 344 |
+
|
| 345 |
+
def forward(
|
| 346 |
+
self,
|
| 347 |
+
pixel_values: Optional[torch.FloatTensor],
|
| 348 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 349 |
+
time: Optional[torch.FloatTensor] = None,
|
| 350 |
+
) -> Tuple[torch.Tensor]:
|
| 351 |
+
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
|
| 352 |
+
embeddings = self.norm(embeddings, time)
|
| 353 |
+
batch_size, seq_len, _ = embeddings.size()
|
| 354 |
+
|
| 355 |
+
if bool_masked_pos is not None:
|
| 356 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
| 357 |
+
# replace the masked visual tokens by mask_tokens
|
| 358 |
+
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
| 359 |
+
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
| 360 |
+
|
| 361 |
+
if self.position_embeddings is not None:
|
| 362 |
+
embeddings = embeddings + self.position_embeddings
|
| 363 |
+
|
| 364 |
+
embeddings = self.dropout(embeddings)
|
| 365 |
+
|
| 366 |
+
return embeddings, output_dimensions
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class ScOTLayer(nn.Module):
|
| 370 |
+
def __init__(
|
| 371 |
+
self,
|
| 372 |
+
config,
|
| 373 |
+
dim,
|
| 374 |
+
input_resolution,
|
| 375 |
+
num_heads,
|
| 376 |
+
drop_path=0.0,
|
| 377 |
+
shift_size=0,
|
| 378 |
+
pretrained_window_size=0,
|
| 379 |
+
):
|
| 380 |
+
super().__init__()
|
| 381 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 382 |
+
self.shift_size = shift_size
|
| 383 |
+
self.window_size = config.window_size
|
| 384 |
+
self.input_resolution = input_resolution
|
| 385 |
+
self.set_shift_and_window_size(input_resolution)
|
| 386 |
+
self.attention = Swinv2Attention(
|
| 387 |
+
config=config,
|
| 388 |
+
dim=dim,
|
| 389 |
+
num_heads=num_heads,
|
| 390 |
+
window_size=self.window_size,
|
| 391 |
+
pretrained_window_size=(
|
| 392 |
+
pretrained_window_size
|
| 393 |
+
if isinstance(pretrained_window_size, collections.abc.Iterable)
|
| 394 |
+
else (pretrained_window_size, pretrained_window_size)
|
| 395 |
+
),
|
| 396 |
+
)
|
| 397 |
+
if config.use_conditioning:
|
| 398 |
+
layer_norm = ConditionalLayerNorm
|
| 399 |
+
else:
|
| 400 |
+
layer_norm = LayerNorm
|
| 401 |
+
self.layernorm_before = layer_norm(dim, eps=config.layer_norm_eps)
|
| 402 |
+
self.drop_path = Swinv2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 403 |
+
self.intermediate = Swinv2Intermediate(config, dim)
|
| 404 |
+
self.output = Swinv2Output(config, dim)
|
| 405 |
+
self.layernorm_after = layer_norm(dim, eps=config.layer_norm_eps)
|
| 406 |
+
|
| 407 |
+
def set_shift_and_window_size(self, input_resolution):
|
| 408 |
+
target_window_size = (
|
| 409 |
+
self.window_size
|
| 410 |
+
if isinstance(self.window_size, collections.abc.Iterable)
|
| 411 |
+
else (self.window_size, self.window_size)
|
| 412 |
+
)
|
| 413 |
+
target_shift_size = (
|
| 414 |
+
self.shift_size
|
| 415 |
+
if isinstance(self.shift_size, collections.abc.Iterable)
|
| 416 |
+
else (self.shift_size, self.shift_size)
|
| 417 |
+
)
|
| 418 |
+
window_dim = (
|
| 419 |
+
input_resolution[0].item()
|
| 420 |
+
if torch.is_tensor(input_resolution[0])
|
| 421 |
+
else input_resolution[0]
|
| 422 |
+
)
|
| 423 |
+
self.window_size = (
|
| 424 |
+
window_dim if window_dim <= target_window_size[0] else target_window_size[0]
|
| 425 |
+
)
|
| 426 |
+
self.shift_size = (
|
| 427 |
+
0
|
| 428 |
+
if input_resolution
|
| 429 |
+
<= (
|
| 430 |
+
self.window_size
|
| 431 |
+
if isinstance(self.window_size, collections.abc.Iterable)
|
| 432 |
+
else (self.window_size, self.window_size)
|
| 433 |
+
)
|
| 434 |
+
else target_shift_size[0]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def get_attn_mask(self, height, width, dtype):
|
| 438 |
+
if self.shift_size > 0:
|
| 439 |
+
# calculate attention mask for shifted window multihead self attention
|
| 440 |
+
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
|
| 441 |
+
height_slices = (
|
| 442 |
+
slice(0, -self.window_size),
|
| 443 |
+
slice(-self.window_size, -self.shift_size),
|
| 444 |
+
slice(-self.shift_size, None),
|
| 445 |
+
)
|
| 446 |
+
width_slices = (
|
| 447 |
+
slice(0, -self.window_size),
|
| 448 |
+
slice(-self.window_size, -self.shift_size),
|
| 449 |
+
slice(-self.shift_size, None),
|
| 450 |
+
)
|
| 451 |
+
count = 0
|
| 452 |
+
for height_slice in height_slices:
|
| 453 |
+
for width_slice in width_slices:
|
| 454 |
+
img_mask[:, height_slice, width_slice, :] = count
|
| 455 |
+
count += 1
|
| 456 |
+
|
| 457 |
+
mask_windows = window_partition(img_mask, self.window_size)
|
| 458 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
| 459 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 460 |
+
attn_mask = attn_mask.masked_fill(
|
| 461 |
+
attn_mask != 0, float(-100.0)
|
| 462 |
+
).masked_fill(attn_mask == 0, float(0.0))
|
| 463 |
+
else:
|
| 464 |
+
attn_mask = None
|
| 465 |
+
return attn_mask
|
| 466 |
+
|
| 467 |
+
def maybe_pad(self, hidden_states, height, width):
|
| 468 |
+
pad_right = (self.window_size - width % self.window_size) % self.window_size
|
| 469 |
+
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
|
| 470 |
+
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
|
| 471 |
+
hidden_states = nn.functional.pad(hidden_states, pad_values)
|
| 472 |
+
return hidden_states, pad_values
|
| 473 |
+
|
| 474 |
+
def forward(
|
| 475 |
+
self,
|
| 476 |
+
hidden_states: torch.Tensor,
|
| 477 |
+
input_dimensions: Tuple[int, int],
|
| 478 |
+
time: torch.Tensor,
|
| 479 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 480 |
+
output_attentions: Optional[bool] = False,
|
| 481 |
+
always_partition: Optional[bool] = False,
|
| 482 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 483 |
+
if not always_partition:
|
| 484 |
+
self.set_shift_and_window_size(input_dimensions)
|
| 485 |
+
else:
|
| 486 |
+
pass
|
| 487 |
+
height, width = input_dimensions
|
| 488 |
+
batch_size, _, channels = hidden_states.size()
|
| 489 |
+
shortcut = hidden_states
|
| 490 |
+
|
| 491 |
+
# pad hidden_states to multiples of window size
|
| 492 |
+
hidden_states = hidden_states.view(batch_size, height, width, channels)
|
| 493 |
+
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
|
| 494 |
+
_, height_pad, width_pad, _ = hidden_states.shape
|
| 495 |
+
# cyclic shift
|
| 496 |
+
if self.shift_size > 0:
|
| 497 |
+
shifted_hidden_states = torch.roll(
|
| 498 |
+
hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
shifted_hidden_states = hidden_states
|
| 502 |
+
|
| 503 |
+
# partition windows
|
| 504 |
+
hidden_states_windows = window_partition(
|
| 505 |
+
shifted_hidden_states, self.window_size
|
| 506 |
+
)
|
| 507 |
+
hidden_states_windows = hidden_states_windows.view(
|
| 508 |
+
-1, self.window_size * self.window_size, channels
|
| 509 |
+
)
|
| 510 |
+
attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
|
| 511 |
+
if attn_mask is not None:
|
| 512 |
+
attn_mask = attn_mask.to(hidden_states_windows.device)
|
| 513 |
+
|
| 514 |
+
attention_outputs = self.attention(
|
| 515 |
+
hidden_states_windows,
|
| 516 |
+
attn_mask,
|
| 517 |
+
head_mask,
|
| 518 |
+
output_attentions=output_attentions,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
attention_output = attention_outputs[0]
|
| 522 |
+
|
| 523 |
+
attention_windows = attention_output.view(
|
| 524 |
+
-1, self.window_size, self.window_size, channels
|
| 525 |
+
)
|
| 526 |
+
shifted_windows = window_reverse(
|
| 527 |
+
attention_windows, self.window_size, height_pad, width_pad
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# reverse cyclic shift
|
| 531 |
+
if self.shift_size > 0:
|
| 532 |
+
attention_windows = torch.roll(
|
| 533 |
+
shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
|
| 534 |
+
)
|
| 535 |
+
else:
|
| 536 |
+
attention_windows = shifted_windows
|
| 537 |
+
|
| 538 |
+
was_padded = pad_values[3] > 0 or pad_values[5] > 0
|
| 539 |
+
if was_padded:
|
| 540 |
+
attention_windows = attention_windows[:, :height, :width, :].contiguous()
|
| 541 |
+
|
| 542 |
+
attention_windows = attention_windows.view(batch_size, height * width, channels)
|
| 543 |
+
hidden_states = self.layernorm_before(attention_windows, time)
|
| 544 |
+
hidden_states = shortcut + self.drop_path(hidden_states)
|
| 545 |
+
|
| 546 |
+
layer_output = self.intermediate(hidden_states)
|
| 547 |
+
layer_output = self.output(layer_output)
|
| 548 |
+
layer_output = hidden_states + self.drop_path(
|
| 549 |
+
self.layernorm_after(layer_output, time)
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
layer_outputs = (
|
| 553 |
+
(layer_output, attention_outputs[1])
|
| 554 |
+
if output_attentions
|
| 555 |
+
else (layer_output,)
|
| 556 |
+
)
|
| 557 |
+
return layer_outputs
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class ScOTPatchRecovery(nn.Module):
|
| 561 |
+
"""https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py"""
|
| 562 |
+
|
| 563 |
+
def __init__(self, config):
|
| 564 |
+
super().__init__()
|
| 565 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 566 |
+
num_out_channels, hidden_size = (
|
| 567 |
+
config.num_out_channels,
|
| 568 |
+
config.embed_dim, # if not config.skip_connections[0] else 2 * config.embed_dim,
|
| 569 |
+
)
|
| 570 |
+
image_size = (
|
| 571 |
+
image_size
|
| 572 |
+
if isinstance(image_size, collections.abc.Iterable)
|
| 573 |
+
else (image_size, image_size)
|
| 574 |
+
)
|
| 575 |
+
patch_size = (
|
| 576 |
+
patch_size
|
| 577 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
| 578 |
+
else (patch_size, patch_size)
|
| 579 |
+
)
|
| 580 |
+
num_patches = (image_size[0] // patch_size[0]) * (
|
| 581 |
+
image_size[1] // patch_size[1]
|
| 582 |
+
)
|
| 583 |
+
self.num_patches = num_patches
|
| 584 |
+
self.patch_size = patch_size
|
| 585 |
+
self.image_size = image_size
|
| 586 |
+
self.num_out_channels = num_out_channels
|
| 587 |
+
self.grid_size = (
|
| 588 |
+
image_size[0] // patch_size[0],
|
| 589 |
+
image_size[1] // patch_size[1],
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
self.projection = nn.ConvTranspose2d(
|
| 593 |
+
in_channels=hidden_size,
|
| 594 |
+
out_channels=num_out_channels,
|
| 595 |
+
kernel_size=patch_size,
|
| 596 |
+
stride=patch_size,
|
| 597 |
+
)
|
| 598 |
+
# the following is not done in Pangu
|
| 599 |
+
self.mixup = nn.Conv2d(
|
| 600 |
+
num_out_channels,
|
| 601 |
+
num_out_channels,
|
| 602 |
+
kernel_size=5,
|
| 603 |
+
stride=1,
|
| 604 |
+
padding=2,
|
| 605 |
+
bias=False,
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
def maybe_crop(self, pixel_values, height, width):
|
| 609 |
+
if pixel_values.shape[2] > height:
|
| 610 |
+
pixel_values = pixel_values[:, :, :height, :]
|
| 611 |
+
if pixel_values.shape[3] > width:
|
| 612 |
+
pixel_values = pixel_values[:, :, :, :width]
|
| 613 |
+
return pixel_values
|
| 614 |
+
|
| 615 |
+
def forward(self, hidden_states):
|
| 616 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 617 |
+
hidden_states = hidden_states.reshape(
|
| 618 |
+
hidden_states.shape[0], hidden_states.shape[1], *self.grid_size
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
output = self.projection(hidden_states)
|
| 622 |
+
output = self.maybe_crop(output, self.image_size[0], self.image_size[1])
|
| 623 |
+
return self.mixup(output)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class ScOTPatchMerging(nn.Module):
|
| 627 |
+
"""
|
| 628 |
+
Patch Merging Layer.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
input_resolution (`Tuple[int]`):
|
| 632 |
+
Resolution of input feature.
|
| 633 |
+
dim (`int`):
|
| 634 |
+
Number of input channels.
|
| 635 |
+
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
|
| 636 |
+
Normalization layer class.
|
| 637 |
+
"""
|
| 638 |
+
|
| 639 |
+
def __init__(
|
| 640 |
+
self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = LayerNorm
|
| 641 |
+
) -> None:
|
| 642 |
+
super().__init__()
|
| 643 |
+
self.input_resolution = input_resolution
|
| 644 |
+
self.dim = dim
|
| 645 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 646 |
+
self.norm = norm_layer(2 * dim)
|
| 647 |
+
|
| 648 |
+
def maybe_pad(self, input_feature, height, width):
|
| 649 |
+
should_pad = (height % 2 == 1) or (width % 2 == 1)
|
| 650 |
+
if should_pad:
|
| 651 |
+
pad_values = (0, 0, 0, width % 2, 0, height % 2)
|
| 652 |
+
input_feature = nn.functional.pad(input_feature, pad_values)
|
| 653 |
+
|
| 654 |
+
return input_feature
|
| 655 |
+
|
| 656 |
+
def forward(
|
| 657 |
+
self,
|
| 658 |
+
input_feature: torch.Tensor,
|
| 659 |
+
input_dimensions: Tuple[int, int],
|
| 660 |
+
time: torch.Tensor,
|
| 661 |
+
) -> torch.Tensor:
|
| 662 |
+
height, width = input_dimensions
|
| 663 |
+
# `dim` is height * width
|
| 664 |
+
batch_size, dim, num_channels = input_feature.shape
|
| 665 |
+
|
| 666 |
+
input_feature = input_feature.view(batch_size, height, width, num_channels)
|
| 667 |
+
# pad input to be disible by width and height, if needed
|
| 668 |
+
input_feature = self.maybe_pad(input_feature, height, width)
|
| 669 |
+
# [batch_size, height/2, width/2, num_channels]
|
| 670 |
+
input_feature_0 = input_feature[:, 0::2, 0::2, :]
|
| 671 |
+
# [batch_size, height/2, width/2, num_channels]
|
| 672 |
+
input_feature_1 = input_feature[:, 1::2, 0::2, :]
|
| 673 |
+
# [batch_size, height/2, width/2, num_channels]
|
| 674 |
+
input_feature_2 = input_feature[:, 0::2, 1::2, :]
|
| 675 |
+
# [batch_size, height/2, width/2, num_channels]
|
| 676 |
+
input_feature_3 = input_feature[:, 1::2, 1::2, :]
|
| 677 |
+
# [batch_size, height/2 * width/2, 4*num_channels]
|
| 678 |
+
input_feature = torch.cat(
|
| 679 |
+
[input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
|
| 680 |
+
)
|
| 681 |
+
input_feature = input_feature.view(
|
| 682 |
+
batch_size, -1, 4 * num_channels
|
| 683 |
+
) # [batch_size, height/2 * width/2, 4*C]
|
| 684 |
+
|
| 685 |
+
input_feature = self.reduction(input_feature)
|
| 686 |
+
input_feature = self.norm(input_feature, time)
|
| 687 |
+
|
| 688 |
+
return input_feature
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class ScOTPatchUnmerging(nn.Module):
|
| 692 |
+
def __init__(
|
| 693 |
+
self,
|
| 694 |
+
input_resolution: Tuple[int],
|
| 695 |
+
dim: int,
|
| 696 |
+
norm_layer: nn.Module = LayerNorm,
|
| 697 |
+
) -> None:
|
| 698 |
+
super().__init__()
|
| 699 |
+
self.input_resolution = input_resolution
|
| 700 |
+
self.dim = dim
|
| 701 |
+
self.upsample = nn.Linear(dim, 2 * dim, bias=False)
|
| 702 |
+
self.mixup = nn.Linear(dim // 2, dim // 2, bias=False)
|
| 703 |
+
self.norm = norm_layer(dim // 2)
|
| 704 |
+
|
| 705 |
+
def maybe_crop(self, input_feature, height, width):
|
| 706 |
+
height_in, width_in = input_feature.shape[1], input_feature.shape[2]
|
| 707 |
+
if height_in > height:
|
| 708 |
+
input_feature = input_feature[:, :height, :, :]
|
| 709 |
+
if width_in > width:
|
| 710 |
+
input_feature = input_feature[:, :, :width, :]
|
| 711 |
+
return input_feature
|
| 712 |
+
|
| 713 |
+
def forward(
|
| 714 |
+
self,
|
| 715 |
+
input_feature: torch.Tensor,
|
| 716 |
+
output_dimensions: Tuple[int, int],
|
| 717 |
+
time: torch.Tensor,
|
| 718 |
+
) -> torch.Tensor:
|
| 719 |
+
output_height, output_width = output_dimensions
|
| 720 |
+
batch_size, seq_len, hidden_size = input_feature.shape
|
| 721 |
+
#! assume square image
|
| 722 |
+
input_height = input_width = math.floor(seq_len**0.5)
|
| 723 |
+
input_feature = self.upsample(input_feature)
|
| 724 |
+
input_feature = input_feature.reshape(
|
| 725 |
+
batch_size, input_height, input_width, 2, 2, hidden_size // 2
|
| 726 |
+
)
|
| 727 |
+
input_feature = input_feature.permute(0, 1, 3, 2, 4, 5)
|
| 728 |
+
input_feature = input_feature.reshape(
|
| 729 |
+
batch_size, 2 * input_height, 2 * input_width, hidden_size // 2
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
input_feature = self.maybe_crop(input_feature, output_height, output_width)
|
| 733 |
+
input_feature = input_feature.reshape(batch_size, -1, hidden_size // 2)
|
| 734 |
+
|
| 735 |
+
input_feature = self.norm(input_feature, time)
|
| 736 |
+
return self.mixup(input_feature)
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class ScOTEncodeStage(nn.Module):
|
| 740 |
+
def __init__(
|
| 741 |
+
self,
|
| 742 |
+
config,
|
| 743 |
+
dim,
|
| 744 |
+
input_resolution,
|
| 745 |
+
depth,
|
| 746 |
+
num_heads,
|
| 747 |
+
drop_path,
|
| 748 |
+
downsample,
|
| 749 |
+
pretrained_window_size=0,
|
| 750 |
+
):
|
| 751 |
+
super().__init__()
|
| 752 |
+
self.config = config
|
| 753 |
+
self.dim = dim
|
| 754 |
+
window_size = (
|
| 755 |
+
config.window_size
|
| 756 |
+
if isinstance(config.window_size, collections.abc.Iterable)
|
| 757 |
+
else (config.window_size, config.window_size)
|
| 758 |
+
)
|
| 759 |
+
self.blocks = nn.ModuleList(
|
| 760 |
+
[
|
| 761 |
+
ScOTLayer(
|
| 762 |
+
config=config,
|
| 763 |
+
dim=dim,
|
| 764 |
+
input_resolution=input_resolution,
|
| 765 |
+
num_heads=num_heads,
|
| 766 |
+
shift_size=(
|
| 767 |
+
[0, 0]
|
| 768 |
+
if (i % 2 == 0)
|
| 769 |
+
else [window_size[0] // 2, window_size[1] // 2]
|
| 770 |
+
),
|
| 771 |
+
drop_path=drop_path[i],
|
| 772 |
+
pretrained_window_size=pretrained_window_size,
|
| 773 |
+
)
|
| 774 |
+
for i in range(depth)
|
| 775 |
+
]
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
# patch merging layer
|
| 779 |
+
if downsample is not None:
|
| 780 |
+
if config.use_conditioning:
|
| 781 |
+
layer_norm = ConditionalLayerNorm
|
| 782 |
+
else:
|
| 783 |
+
layer_norm = LayerNorm
|
| 784 |
+
self.downsample = downsample(
|
| 785 |
+
input_resolution, dim=dim, norm_layer=layer_norm
|
| 786 |
+
)
|
| 787 |
+
else:
|
| 788 |
+
self.downsample = None
|
| 789 |
+
|
| 790 |
+
self.pointing = False
|
| 791 |
+
|
| 792 |
+
def forward(
|
| 793 |
+
self,
|
| 794 |
+
hidden_states: torch.Tensor,
|
| 795 |
+
input_dimensions: Tuple[int, int],
|
| 796 |
+
time: torch.Tensor,
|
| 797 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 798 |
+
output_attentions: Optional[bool] = False,
|
| 799 |
+
always_partition: Optional[bool] = False,
|
| 800 |
+
) -> Tuple[torch.Tensor]:
|
| 801 |
+
height, width = input_dimensions
|
| 802 |
+
|
| 803 |
+
inputs = hidden_states
|
| 804 |
+
|
| 805 |
+
for i, layer_module in enumerate(self.blocks):
|
| 806 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 807 |
+
|
| 808 |
+
layer_outputs = layer_module(
|
| 809 |
+
hidden_states,
|
| 810 |
+
input_dimensions,
|
| 811 |
+
time,
|
| 812 |
+
layer_head_mask,
|
| 813 |
+
output_attentions,
|
| 814 |
+
always_partition,
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
hidden_states = layer_outputs[0]
|
| 818 |
+
|
| 819 |
+
hidden_states_before_downsampling = hidden_states
|
| 820 |
+
if self.downsample is not None:
|
| 821 |
+
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
|
| 822 |
+
output_dimensions = (height, width, height_downsampled, width_downsampled)
|
| 823 |
+
hidden_states = self.downsample(
|
| 824 |
+
hidden_states_before_downsampling + inputs, input_dimensions, time
|
| 825 |
+
)
|
| 826 |
+
else:
|
| 827 |
+
output_dimensions = (height, width, height, width)
|
| 828 |
+
|
| 829 |
+
stage_outputs = (
|
| 830 |
+
hidden_states,
|
| 831 |
+
hidden_states_before_downsampling,
|
| 832 |
+
output_dimensions,
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
if output_attentions:
|
| 836 |
+
stage_outputs += layer_outputs[1:]
|
| 837 |
+
return stage_outputs
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
class ScOTDecodeStage(nn.Module):
|
| 841 |
+
def __init__(
|
| 842 |
+
self,
|
| 843 |
+
config,
|
| 844 |
+
dim,
|
| 845 |
+
input_resolution,
|
| 846 |
+
depth,
|
| 847 |
+
num_heads,
|
| 848 |
+
drop_path,
|
| 849 |
+
upsample,
|
| 850 |
+
upsampled_size,
|
| 851 |
+
pretrained_window_size=0,
|
| 852 |
+
):
|
| 853 |
+
super().__init__()
|
| 854 |
+
self.config = config
|
| 855 |
+
self.dim = dim
|
| 856 |
+
window_size = (
|
| 857 |
+
config.window_size
|
| 858 |
+
if isinstance(config.window_size, collections.abc.Iterable)
|
| 859 |
+
else (config.window_size, config.window_size)
|
| 860 |
+
)
|
| 861 |
+
self.blocks = nn.ModuleList(
|
| 862 |
+
[
|
| 863 |
+
ScOTLayer(
|
| 864 |
+
config=config,
|
| 865 |
+
dim=dim,
|
| 866 |
+
input_resolution=input_resolution,
|
| 867 |
+
num_heads=num_heads,
|
| 868 |
+
shift_size=(
|
| 869 |
+
[0, 0]
|
| 870 |
+
if (i % 2 == 0)
|
| 871 |
+
else [window_size[0] // 2, window_size[1] // 2]
|
| 872 |
+
),
|
| 873 |
+
drop_path=drop_path[depth - 1 - i], # TODO: reverse...
|
| 874 |
+
pretrained_window_size=pretrained_window_size,
|
| 875 |
+
)
|
| 876 |
+
for i in reversed(range(depth)) # TODO: reverse here?
|
| 877 |
+
]
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
if upsample is not None:
|
| 881 |
+
if config.use_conditioning:
|
| 882 |
+
layer_norm = ConditionalLayerNorm
|
| 883 |
+
else:
|
| 884 |
+
layer_norm = LayerNorm
|
| 885 |
+
self.upsample = upsample(input_resolution, dim=dim, norm_layer=layer_norm)
|
| 886 |
+
self.upsampled_size = upsampled_size
|
| 887 |
+
else:
|
| 888 |
+
self.upsample = None
|
| 889 |
+
|
| 890 |
+
self.pointing = False
|
| 891 |
+
|
| 892 |
+
def forward(
|
| 893 |
+
self,
|
| 894 |
+
hidden_states: torch.Tensor,
|
| 895 |
+
input_dimensions: Tuple[int, int],
|
| 896 |
+
time: torch.Tensor,
|
| 897 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 898 |
+
output_attentions: Optional[bool] = False,
|
| 899 |
+
always_partition: Optional[bool] = False,
|
| 900 |
+
) -> Tuple[torch.Tensor]:
|
| 901 |
+
height, width = input_dimensions
|
| 902 |
+
|
| 903 |
+
for i, layer_module in enumerate(self.blocks):
|
| 904 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 905 |
+
|
| 906 |
+
layer_outputs = layer_module(
|
| 907 |
+
hidden_states,
|
| 908 |
+
input_dimensions,
|
| 909 |
+
time,
|
| 910 |
+
layer_head_mask,
|
| 911 |
+
output_attentions,
|
| 912 |
+
always_partition,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
hidden_states = layer_outputs[0]
|
| 916 |
+
|
| 917 |
+
hidden_states_before_upsampling = hidden_states
|
| 918 |
+
if self.upsample is not None:
|
| 919 |
+
height_upsampled, width_upsampled = self.upsampled_size
|
| 920 |
+
output_dimensions = (height, width, height_upsampled, width_upsampled)
|
| 921 |
+
hidden_states = self.upsample(
|
| 922 |
+
hidden_states_before_upsampling,
|
| 923 |
+
(height_upsampled, width_upsampled),
|
| 924 |
+
time,
|
| 925 |
+
)
|
| 926 |
+
else:
|
| 927 |
+
output_dimensions = (height, width, height, width)
|
| 928 |
+
|
| 929 |
+
stage_outputs = (
|
| 930 |
+
hidden_states,
|
| 931 |
+
hidden_states_before_upsampling,
|
| 932 |
+
output_dimensions,
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
if output_attentions:
|
| 936 |
+
stage_outputs += layer_outputs[1:]
|
| 937 |
+
return stage_outputs
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
class ScOTEncoder(nn.Module):
|
| 941 |
+
"""
|
| 942 |
+
This is just a Swinv2Encoder with changed dpr.
|
| 943 |
+
We just have to change the drop path rate since we also have a decoder by default.
|
| 944 |
+
"""
|
| 945 |
+
|
| 946 |
+
def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
|
| 947 |
+
super().__init__()
|
| 948 |
+
self.num_layers = len(config.depths)
|
| 949 |
+
self.config = config
|
| 950 |
+
if self.config.pretrained_window_sizes is not None:
|
| 951 |
+
pretrained_window_sizes = config.pretrained_window_sizes
|
| 952 |
+
drop_rates_encode_decode = torch.linspace(
|
| 953 |
+
0, config.drop_path_rate, 2 * sum(config.depths)
|
| 954 |
+
)
|
| 955 |
+
dpr = [
|
| 956 |
+
x.item()
|
| 957 |
+
for x in drop_rates_encode_decode[: drop_rates_encode_decode.shape[0] // 2]
|
| 958 |
+
]
|
| 959 |
+
self.layers = nn.ModuleList(
|
| 960 |
+
[
|
| 961 |
+
ScOTEncodeStage(
|
| 962 |
+
config=config,
|
| 963 |
+
dim=int(config.embed_dim * 2**i_layer),
|
| 964 |
+
input_resolution=(
|
| 965 |
+
grid_size[0] // (2**i_layer),
|
| 966 |
+
grid_size[1] // (2**i_layer),
|
| 967 |
+
),
|
| 968 |
+
depth=config.depths[i_layer],
|
| 969 |
+
num_heads=config.num_heads[i_layer],
|
| 970 |
+
drop_path=dpr[
|
| 971 |
+
sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
|
| 972 |
+
],
|
| 973 |
+
downsample=(
|
| 974 |
+
ScOTPatchMerging if (i_layer < self.num_layers - 1) else None
|
| 975 |
+
),
|
| 976 |
+
pretrained_window_size=pretrained_window_sizes[i_layer],
|
| 977 |
+
)
|
| 978 |
+
for i_layer in range(self.num_layers)
|
| 979 |
+
]
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
self.gradient_checkpointing = False
|
| 983 |
+
|
| 984 |
+
def forward(
|
| 985 |
+
self,
|
| 986 |
+
hidden_states: torch.Tensor,
|
| 987 |
+
input_dimensions: Tuple[int, int],
|
| 988 |
+
time: torch.Tensor,
|
| 989 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 990 |
+
output_attentions: Optional[bool] = False,
|
| 991 |
+
output_hidden_states: Optional[bool] = False,
|
| 992 |
+
output_hidden_states_before_downsampling: Optional[bool] = False,
|
| 993 |
+
always_partition: Optional[bool] = False,
|
| 994 |
+
return_dict: Optional[bool] = True,
|
| 995 |
+
) -> Union[Tuple, Swinv2EncoderOutput]:
|
| 996 |
+
all_hidden_states = () if output_hidden_states else None
|
| 997 |
+
all_reshaped_hidden_states = () if output_hidden_states else None
|
| 998 |
+
all_self_attentions = () if output_attentions else None
|
| 999 |
+
|
| 1000 |
+
if output_hidden_states:
|
| 1001 |
+
batch_size, _, hidden_size = hidden_states.shape
|
| 1002 |
+
# rearrange b (h w) c -> b c h w
|
| 1003 |
+
reshaped_hidden_state = hidden_states.view(
|
| 1004 |
+
batch_size, *input_dimensions, hidden_size
|
| 1005 |
+
)
|
| 1006 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1007 |
+
all_hidden_states += (hidden_states,)
|
| 1008 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1009 |
+
|
| 1010 |
+
for i, layer_module in enumerate(self.layers):
|
| 1011 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 1012 |
+
|
| 1013 |
+
if self.gradient_checkpointing and self.training:
|
| 1014 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1015 |
+
layer_module.__call__,
|
| 1016 |
+
hidden_states,
|
| 1017 |
+
input_dimensions,
|
| 1018 |
+
time,
|
| 1019 |
+
layer_head_mask,
|
| 1020 |
+
output_attentions,
|
| 1021 |
+
)
|
| 1022 |
+
else:
|
| 1023 |
+
layer_outputs = layer_module(
|
| 1024 |
+
hidden_states,
|
| 1025 |
+
input_dimensions,
|
| 1026 |
+
time,
|
| 1027 |
+
layer_head_mask,
|
| 1028 |
+
output_attentions,
|
| 1029 |
+
always_partition,
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
hidden_states = layer_outputs[0]
|
| 1033 |
+
hidden_states_before_downsampling = layer_outputs[1]
|
| 1034 |
+
output_dimensions = layer_outputs[2]
|
| 1035 |
+
|
| 1036 |
+
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
| 1037 |
+
|
| 1038 |
+
if output_hidden_states and output_hidden_states_before_downsampling:
|
| 1039 |
+
batch_size, _, hidden_size = hidden_states_before_downsampling.shape
|
| 1040 |
+
# rearrange b (h w) c -> b c h w
|
| 1041 |
+
# here we use the original (not downsampled) height and width
|
| 1042 |
+
reshaped_hidden_state = hidden_states_before_downsampling.view(
|
| 1043 |
+
batch_size,
|
| 1044 |
+
*(output_dimensions[0], output_dimensions[1]),
|
| 1045 |
+
hidden_size,
|
| 1046 |
+
)
|
| 1047 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1048 |
+
all_hidden_states += (hidden_states_before_downsampling,)
|
| 1049 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1050 |
+
elif output_hidden_states and not output_hidden_states_before_downsampling:
|
| 1051 |
+
batch_size, _, hidden_size = hidden_states.shape
|
| 1052 |
+
# rearrange b (h w) c -> b c h w
|
| 1053 |
+
reshaped_hidden_state = hidden_states.view(
|
| 1054 |
+
batch_size, *input_dimensions, hidden_size
|
| 1055 |
+
)
|
| 1056 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1057 |
+
all_hidden_states += (hidden_states,)
|
| 1058 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1059 |
+
|
| 1060 |
+
if output_attentions:
|
| 1061 |
+
all_self_attentions += layer_outputs[3:]
|
| 1062 |
+
|
| 1063 |
+
if not return_dict:
|
| 1064 |
+
return tuple(
|
| 1065 |
+
v
|
| 1066 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
| 1067 |
+
if v is not None
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
return Swinv2EncoderOutput(
|
| 1071 |
+
last_hidden_state=hidden_states,
|
| 1072 |
+
hidden_states=all_hidden_states,
|
| 1073 |
+
attentions=all_self_attentions,
|
| 1074 |
+
reshaped_hidden_states=all_reshaped_hidden_states,
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
class ScOTDecoder(nn.Module):
|
| 1079 |
+
"""Here we do reverse encoder."""
|
| 1080 |
+
|
| 1081 |
+
def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
|
| 1082 |
+
super().__init__()
|
| 1083 |
+
self.num_layers = len(config.depths)
|
| 1084 |
+
self.config = config
|
| 1085 |
+
if self.config.pretrained_window_sizes is not None:
|
| 1086 |
+
pretrained_window_sizes = config.pretrained_window_sizes
|
| 1087 |
+
drop_rates_encode_decode = torch.linspace(
|
| 1088 |
+
0, config.drop_path_rate, 2 * sum(config.depths)
|
| 1089 |
+
)
|
| 1090 |
+
dpr = [
|
| 1091 |
+
x.item()
|
| 1092 |
+
for x in drop_rates_encode_decode[drop_rates_encode_decode.shape[0] // 2 :]
|
| 1093 |
+
]
|
| 1094 |
+
self.layers = nn.ModuleList(
|
| 1095 |
+
[
|
| 1096 |
+
ScOTDecodeStage(
|
| 1097 |
+
config=config,
|
| 1098 |
+
dim=int(config.embed_dim * 2**i_layer),
|
| 1099 |
+
input_resolution=(
|
| 1100 |
+
grid_size[0] // (2**i_layer),
|
| 1101 |
+
grid_size[1] // (2**i_layer),
|
| 1102 |
+
),
|
| 1103 |
+
depth=config.depths[i_layer],
|
| 1104 |
+
num_heads=config.num_heads[i_layer],
|
| 1105 |
+
drop_path=dpr[
|
| 1106 |
+
sum(config.depths[i_layer + 1 :]) : sum(config.depths[i_layer:])
|
| 1107 |
+
],
|
| 1108 |
+
upsample=ScOTPatchUnmerging if i_layer > 0 else None,
|
| 1109 |
+
upsampled_size=(
|
| 1110 |
+
grid_size[0] // (2 ** (i_layer - 1)),
|
| 1111 |
+
grid_size[1] // (2 ** (i_layer - 1)),
|
| 1112 |
+
),
|
| 1113 |
+
pretrained_window_size=pretrained_window_sizes[i_layer],
|
| 1114 |
+
)
|
| 1115 |
+
for i_layer in reversed(range(self.num_layers))
|
| 1116 |
+
]
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
self.gradient_checkpointing = False
|
| 1120 |
+
|
| 1121 |
+
def forward(
|
| 1122 |
+
self,
|
| 1123 |
+
hidden_states: torch.Tensor,
|
| 1124 |
+
input_dimensions: Tuple[int, int],
|
| 1125 |
+
skip_states: List[torch.FloatTensor],
|
| 1126 |
+
time: torch.Tensor,
|
| 1127 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1128 |
+
output_attentions: Optional[bool] = False,
|
| 1129 |
+
output_hidden_states: Optional[bool] = False,
|
| 1130 |
+
output_hidden_states_before_upsampling: Optional[bool] = False,
|
| 1131 |
+
always_partition: Optional[bool] = False,
|
| 1132 |
+
return_dict: Optional[bool] = True,
|
| 1133 |
+
) -> Union[Tuple, Swinv2EncoderOutput]:
|
| 1134 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1135 |
+
all_reshaped_hidden_states = () if output_hidden_states else None
|
| 1136 |
+
all_self_attentions = () if output_attentions else None
|
| 1137 |
+
|
| 1138 |
+
if output_hidden_states:
|
| 1139 |
+
batch_size, _, hidden_size = hidden_states.shape
|
| 1140 |
+
# rearrange b (h w) c -> b c h w
|
| 1141 |
+
reshaped_hidden_state = hidden_states.view(
|
| 1142 |
+
batch_size, *input_dimensions, hidden_size
|
| 1143 |
+
)
|
| 1144 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1145 |
+
all_hidden_states += (hidden_states,)
|
| 1146 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1147 |
+
|
| 1148 |
+
for i, layer_module in enumerate(self.layers):
|
| 1149 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 1150 |
+
|
| 1151 |
+
if i != 0 and skip_states[len(skip_states) - i] is not None:
|
| 1152 |
+
# residual connection
|
| 1153 |
+
hidden_states = hidden_states + skip_states[len(skip_states) - i]
|
| 1154 |
+
if self.gradient_checkpointing and self.training:
|
| 1155 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1156 |
+
layer_module.__call__,
|
| 1157 |
+
hidden_states,
|
| 1158 |
+
input_dimensions,
|
| 1159 |
+
time,
|
| 1160 |
+
layer_head_mask,
|
| 1161 |
+
output_attentions,
|
| 1162 |
+
)
|
| 1163 |
+
else:
|
| 1164 |
+
layer_outputs = layer_module(
|
| 1165 |
+
hidden_states,
|
| 1166 |
+
input_dimensions,
|
| 1167 |
+
time,
|
| 1168 |
+
layer_head_mask,
|
| 1169 |
+
output_attentions,
|
| 1170 |
+
always_partition,
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
hidden_states = layer_outputs[0]
|
| 1174 |
+
hidden_states_before_upsampling = layer_outputs[1]
|
| 1175 |
+
output_dimensions = layer_outputs[2]
|
| 1176 |
+
|
| 1177 |
+
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
|
| 1178 |
+
|
| 1179 |
+
if output_hidden_states and output_hidden_states_before_upsampling:
|
| 1180 |
+
batch_size, _, hidden_size = hidden_states_before_upsampling.shape
|
| 1181 |
+
# rearrange b (h w) c -> b c h w
|
| 1182 |
+
# here we use the original (not downsampled) height and width
|
| 1183 |
+
reshaped_hidden_state = hidden_states_before_upsampling.view(
|
| 1184 |
+
batch_size,
|
| 1185 |
+
*(output_dimensions[0], output_dimensions[1]),
|
| 1186 |
+
hidden_size,
|
| 1187 |
+
)
|
| 1188 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1189 |
+
all_hidden_states += (hidden_states_before_upsampling,)
|
| 1190 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1191 |
+
elif output_hidden_states and not output_hidden_states_before_upsampling:
|
| 1192 |
+
batch_size, _, hidden_size = hidden_states.shape
|
| 1193 |
+
# rearrange b (h w) c -> b c h w
|
| 1194 |
+
reshaped_hidden_state = hidden_states.view(
|
| 1195 |
+
batch_size, *input_dimensions, hidden_size
|
| 1196 |
+
)
|
| 1197 |
+
reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
|
| 1198 |
+
all_hidden_states += (hidden_states,)
|
| 1199 |
+
all_reshaped_hidden_states += (reshaped_hidden_state,)
|
| 1200 |
+
|
| 1201 |
+
if output_attentions:
|
| 1202 |
+
all_self_attentions += layer_outputs[3:]
|
| 1203 |
+
|
| 1204 |
+
if not return_dict:
|
| 1205 |
+
return tuple(
|
| 1206 |
+
v
|
| 1207 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
| 1208 |
+
if v is not None
|
| 1209 |
+
)
|
| 1210 |
+
|
| 1211 |
+
return Swinv2EncoderOutput(
|
| 1212 |
+
last_hidden_state=hidden_states,
|
| 1213 |
+
hidden_states=all_hidden_states,
|
| 1214 |
+
attentions=all_self_attentions,
|
| 1215 |
+
reshaped_hidden_states=all_reshaped_hidden_states,
|
| 1216 |
+
)
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
class ScOT(Swinv2PreTrainedModel):
|
| 1220 |
+
"""Inspired by https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/swinv2/modeling_swinv2.py#L1129"""
|
| 1221 |
+
|
| 1222 |
+
def __init__(self, config, use_mask_token=False):
|
| 1223 |
+
super().__init__(config)
|
| 1224 |
+
|
| 1225 |
+
self.config = config
|
| 1226 |
+
self.num_layers_encoder = len(config.depths)
|
| 1227 |
+
self.num_layers_decoder = len(config.depths)
|
| 1228 |
+
self.num_features = int(config.embed_dim * 2 ** (self.num_layers_encoder - 1))
|
| 1229 |
+
|
| 1230 |
+
self.embeddings = ScOTEmbeddings(config, use_mask_token=use_mask_token)
|
| 1231 |
+
self.encoder = ScOTEncoder(config, self.embeddings.patch_grid)
|
| 1232 |
+
self.decoder = ScOTDecoder(config, self.embeddings.patch_grid)
|
| 1233 |
+
self.patch_recovery = ScOTPatchRecovery(config)
|
| 1234 |
+
|
| 1235 |
+
if config.residual_model == "convnext":
|
| 1236 |
+
res_model = ConvNeXtBlock
|
| 1237 |
+
elif config.residual_model == "resnet":
|
| 1238 |
+
res_model = ResNetBlock
|
| 1239 |
+
else:
|
| 1240 |
+
raise ValueError("residual_model must be 'convnext' or 'resnet'")
|
| 1241 |
+
|
| 1242 |
+
self.residual_blocks = nn.ModuleList(
|
| 1243 |
+
[
|
| 1244 |
+
(
|
| 1245 |
+
nn.ModuleList(
|
| 1246 |
+
[
|
| 1247 |
+
res_model(config, config.embed_dim * 2**i)
|
| 1248 |
+
for _ in range(depth)
|
| 1249 |
+
]
|
| 1250 |
+
)
|
| 1251 |
+
if depth > 0
|
| 1252 |
+
else nn.ModuleList([nn.Identity()])
|
| 1253 |
+
)
|
| 1254 |
+
for i, depth in enumerate(config.skip_connections)
|
| 1255 |
+
]
|
| 1256 |
+
)
|
| 1257 |
+
|
| 1258 |
+
self.post_init()
|
| 1259 |
+
|
| 1260 |
+
def get_input_embeddings(self):
|
| 1261 |
+
return self.embeddings.patch_embeddings
|
| 1262 |
+
|
| 1263 |
+
def _prune_heads(self, heads_to_prune):
|
| 1264 |
+
for layer, heads in heads_to_prune.items():
|
| 1265 |
+
self.encoder.layers[layer].attention.prune_heads(heads)
|
| 1266 |
+
for layer, heads in reversed(heads_to_prune.items()):
|
| 1267 |
+
self.decoder.layers[layer].attention.prune_heads(heads)
|
| 1268 |
+
|
| 1269 |
+
def _downsample(self, image, target_size):
|
| 1270 |
+
image_size = image.shape[-2]
|
| 1271 |
+
freqs = torch.fft.fftfreq(image_size, d=1 / image_size)
|
| 1272 |
+
sel = torch.logical_and(freqs >= -target_size / 2, freqs <= target_size / 2 - 1)
|
| 1273 |
+
image_hat = torch.fft.fft2(image, norm="forward")
|
| 1274 |
+
image_hat = image_hat[:, :, sel, :][:, :, :, sel]
|
| 1275 |
+
image = torch.fft.ifft2(image_hat, norm="forward").real
|
| 1276 |
+
return image
|
| 1277 |
+
|
| 1278 |
+
def _upsample(self, image, target_size):
|
| 1279 |
+
# https://stackoverflow.com/questions/71143279/upsampling-images-in-frequency-domain-using-pytorch
|
| 1280 |
+
image_size = image.shape[-2]
|
| 1281 |
+
image_hat = torch.fft.fft2(image, norm="forward")
|
| 1282 |
+
image_hat = torch.fft.fftshift(image_hat)
|
| 1283 |
+
pad_size = (target_size - image_size) // 2
|
| 1284 |
+
real = nn.functional.pad(
|
| 1285 |
+
image_hat.real, (pad_size, pad_size, pad_size, pad_size), value=0.0
|
| 1286 |
+
)
|
| 1287 |
+
imag = nn.functional.pad(
|
| 1288 |
+
image_hat.imag, (pad_size, pad_size, pad_size, pad_size), value=0.0
|
| 1289 |
+
)
|
| 1290 |
+
image_hat = torch.fft.ifftshift(torch.complex(real, imag))
|
| 1291 |
+
image = torch.fft.ifft2(image_hat, norm="forward").real
|
| 1292 |
+
return image
|
| 1293 |
+
|
| 1294 |
+
def forward(
|
| 1295 |
+
self,
|
| 1296 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1297 |
+
time: Optional[torch.FloatTensor] = None,
|
| 1298 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 1299 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1300 |
+
pixel_mask: Optional[torch.BoolTensor] = None,
|
| 1301 |
+
labels: Optional[torch.FloatTensor] = None,
|
| 1302 |
+
output_attentions: Optional[bool] = None,
|
| 1303 |
+
output_hidden_states: Optional[bool] = None,
|
| 1304 |
+
return_dict: Optional[bool] = None,
|
| 1305 |
+
) -> Union[Tuple, ScOTOutput]:
|
| 1306 |
+
return_dict = (
|
| 1307 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1308 |
+
)
|
| 1309 |
+
|
| 1310 |
+
output_attentions = (
|
| 1311 |
+
output_attentions
|
| 1312 |
+
if output_attentions is not None
|
| 1313 |
+
else self.config.output_attentions
|
| 1314 |
+
)
|
| 1315 |
+
output_hidden_states = (
|
| 1316 |
+
output_hidden_states
|
| 1317 |
+
if output_hidden_states is not None
|
| 1318 |
+
else self.config.output_hidden_states
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
if pixel_values is None:
|
| 1322 |
+
raise ValueError("pixel_values cannot be None")
|
| 1323 |
+
|
| 1324 |
+
head_mask = self.get_head_mask(
|
| 1325 |
+
head_mask, self.num_layers_encoder + self.num_layers_decoder
|
| 1326 |
+
)
|
| 1327 |
+
|
| 1328 |
+
if isinstance(head_mask, list):
|
| 1329 |
+
head_mask_encoder = head_mask[: self.num_layers_encoder]
|
| 1330 |
+
head_mask_decoder = head_mask[self.num_layers_encoder :]
|
| 1331 |
+
else:
|
| 1332 |
+
head_mask_encoder, head_mask_decoder = head_mask.split(
|
| 1333 |
+
[self.num_layers_encoder, self.num_layers_decoder]
|
| 1334 |
+
)
|
| 1335 |
+
|
| 1336 |
+
image_size = pixel_values.shape[2]
|
| 1337 |
+
# image must be square
|
| 1338 |
+
if image_size != self.config.image_size:
|
| 1339 |
+
if image_size < self.config.image_size:
|
| 1340 |
+
pixel_values = self._upsample(pixel_values, self.config.image_size)
|
| 1341 |
+
else:
|
| 1342 |
+
pixel_values = self._downsample(pixel_values, self.config.image_size)
|
| 1343 |
+
|
| 1344 |
+
embedding_output, input_dimensions = self.embeddings(
|
| 1345 |
+
pixel_values, bool_masked_pos=bool_masked_pos, time=time
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
encoder_outputs = self.encoder(
|
| 1349 |
+
embedding_output,
|
| 1350 |
+
input_dimensions,
|
| 1351 |
+
time,
|
| 1352 |
+
head_mask=head_mask_encoder,
|
| 1353 |
+
output_attentions=output_attentions,
|
| 1354 |
+
output_hidden_states=True,
|
| 1355 |
+
output_hidden_states_before_downsampling=True,
|
| 1356 |
+
return_dict=return_dict,
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
if return_dict:
|
| 1360 |
+
skip_states = list(encoder_outputs.hidden_states[1:])
|
| 1361 |
+
else:
|
| 1362 |
+
skip_states = list(encoder_outputs[1][1:])
|
| 1363 |
+
|
| 1364 |
+
for i in range(len(skip_states)):
|
| 1365 |
+
for block in self.residual_blocks[i]:
|
| 1366 |
+
if isinstance(block, nn.Identity):
|
| 1367 |
+
skip_states[i] = block(skip_states[i])
|
| 1368 |
+
else:
|
| 1369 |
+
skip_states[i] = block(skip_states[i], time)
|
| 1370 |
+
|
| 1371 |
+
#! assumes square images
|
| 1372 |
+
input_dim = math.floor(skip_states[-1].shape[1] ** 0.5)
|
| 1373 |
+
decoder_output = self.decoder(
|
| 1374 |
+
skip_states[-1],
|
| 1375 |
+
(input_dim, input_dim),
|
| 1376 |
+
time=time,
|
| 1377 |
+
skip_states=skip_states[:-1],
|
| 1378 |
+
head_mask=head_mask_decoder,
|
| 1379 |
+
output_attentions=output_attentions,
|
| 1380 |
+
output_hidden_states=output_hidden_states,
|
| 1381 |
+
return_dict=return_dict,
|
| 1382 |
+
)
|
| 1383 |
+
|
| 1384 |
+
sequence_output = decoder_output[0]
|
| 1385 |
+
prediction = self.patch_recovery(sequence_output)
|
| 1386 |
+
# The following can be used for learning just the residual for time-dependent problems
|
| 1387 |
+
if self.config.learn_residual:
|
| 1388 |
+
if self.config.num_channels > self.config.num_out_channels:
|
| 1389 |
+
pixel_values = pixel_values[:, 0 : self.config.num_out_channels]
|
| 1390 |
+
prediction += pixel_values
|
| 1391 |
+
|
| 1392 |
+
if image_size != self.config.image_size:
|
| 1393 |
+
if image_size > self.config.image_size:
|
| 1394 |
+
prediction = self._upsample(prediction, image_size)
|
| 1395 |
+
else:
|
| 1396 |
+
prediction = self._downsample(prediction, image_size)
|
| 1397 |
+
|
| 1398 |
+
if pixel_mask is not None:
|
| 1399 |
+
prediction[pixel_mask] = labels[pixel_mask].type_as(prediction)
|
| 1400 |
+
loss = None
|
| 1401 |
+
if labels is not None:
|
| 1402 |
+
if self.config.p == 1:
|
| 1403 |
+
loss_fn = nn.functional.l1_loss
|
| 1404 |
+
elif self.config.p == 2:
|
| 1405 |
+
loss_fn = nn.functional.mse_loss
|
| 1406 |
+
else:
|
| 1407 |
+
raise ValueError("p must be 1 or 2")
|
| 1408 |
+
if self.config.channel_slice_list_normalized_loss is not None:
|
| 1409 |
+
loss = torch.mean(
|
| 1410 |
+
torch.stack(
|
| 1411 |
+
[
|
| 1412 |
+
loss_fn(
|
| 1413 |
+
prediction[
|
| 1414 |
+
:,
|
| 1415 |
+
self.config.channel_slice_list_normalized_loss[
|
| 1416 |
+
i
|
| 1417 |
+
] : self.config.channel_slice_list_normalized_loss[
|
| 1418 |
+
i + 1
|
| 1419 |
+
],
|
| 1420 |
+
],
|
| 1421 |
+
labels[
|
| 1422 |
+
:,
|
| 1423 |
+
self.config.channel_slice_list_normalized_loss[
|
| 1424 |
+
i
|
| 1425 |
+
] : self.config.channel_slice_list_normalized_loss[
|
| 1426 |
+
i + 1
|
| 1427 |
+
],
|
| 1428 |
+
],
|
| 1429 |
+
)
|
| 1430 |
+
/ (
|
| 1431 |
+
loss_fn(
|
| 1432 |
+
labels[
|
| 1433 |
+
:,
|
| 1434 |
+
self.config.channel_slice_list_normalized_loss[
|
| 1435 |
+
i
|
| 1436 |
+
] : self.config.channel_slice_list_normalized_loss[
|
| 1437 |
+
i + 1
|
| 1438 |
+
],
|
| 1439 |
+
],
|
| 1440 |
+
torch.zeros_like(
|
| 1441 |
+
labels[
|
| 1442 |
+
:,
|
| 1443 |
+
self.config.channel_slice_list_normalized_loss[
|
| 1444 |
+
i
|
| 1445 |
+
] : self.config.channel_slice_list_normalized_loss[
|
| 1446 |
+
i + 1
|
| 1447 |
+
],
|
| 1448 |
+
]
|
| 1449 |
+
),
|
| 1450 |
+
)
|
| 1451 |
+
+ 1e-10
|
| 1452 |
+
)
|
| 1453 |
+
for i in range(
|
| 1454 |
+
len(self.config.channel_slice_list_normalized_loss) - 1
|
| 1455 |
+
)
|
| 1456 |
+
]
|
| 1457 |
+
)
|
| 1458 |
+
)
|
| 1459 |
+
else:
|
| 1460 |
+
loss = loss_fn(prediction, labels)
|
| 1461 |
+
|
| 1462 |
+
if not return_dict:
|
| 1463 |
+
output = (prediction,) + decoder_output[1:] + encoder_outputs[1:]
|
| 1464 |
+
return ((loss,) + output) if loss is not None else output
|
| 1465 |
+
|
| 1466 |
+
return ScOTOutput(
|
| 1467 |
+
loss=loss,
|
| 1468 |
+
output=prediction,
|
| 1469 |
+
hidden_states=(
|
| 1470 |
+
decoder_output.hidden_states + encoder_outputs.hidden_states
|
| 1471 |
+
if output_hidden_states is not None and output_hidden_states is True
|
| 1472 |
+
else None
|
| 1473 |
+
),
|
| 1474 |
+
attentions=(
|
| 1475 |
+
decoder_output.attentions + encoder_outputs.attentions
|
| 1476 |
+
if output_attentions is not None and output_attentions is True
|
| 1477 |
+
else None
|
| 1478 |
+
),
|
| 1479 |
+
reshaped_hidden_states=(
|
| 1480 |
+
decoder_output.reshaped_hidden_states
|
| 1481 |
+
+ encoder_outputs.reshaped_hidden_states
|
| 1482 |
+
if output_hidden_states is not None and output_hidden_states is True
|
| 1483 |
+
else None
|
| 1484 |
+
),
|
| 1485 |
+
)
|
external/poseidon/scOT/problems/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/problems/base.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains the dataset selector get_dataset, as well as the base
|
| 3 |
+
classes for all datasets.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 7 |
+
from typing import Optional, List, Dict
|
| 8 |
+
from abc import ABC
|
| 9 |
+
import re
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
from accelerate.utils import broadcast_object_list
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_dataset(dataset, **kwargs):
|
| 16 |
+
"""
|
| 17 |
+
Get a dataset by name.
|
| 18 |
+
If you enter a list of str, will return a ConcatDataset of the datasets.
|
| 19 |
+
|
| 20 |
+
Available choices are:
|
| 21 |
+
- fluids.incompressible.BrownianBridge(.tracer)
|
| 22 |
+
- fluids.incompressible.Gaussians(.tracer)
|
| 23 |
+
- fluids.incompressible.ShearLayer
|
| 24 |
+
- fluids.incompressible.Sines(.tracer)
|
| 25 |
+
- fluids.incompressible.PiecewiseConstants(.tracer)
|
| 26 |
+
- fluids.incompressible.VortexSheet(.tracer)
|
| 27 |
+
- fluids.incompressible.forcing.KolmogorovFlow
|
| 28 |
+
- fluids.compressible.gravity.RayleighTaylor(.tracer)
|
| 29 |
+
- fluids.compressible.RiemannKelvinHelmholtz
|
| 30 |
+
- fluids.compressible.RiemannCurved
|
| 31 |
+
- fluids.compressible.Riemann
|
| 32 |
+
- fluids.compressible.KelvinHelmholtz
|
| 33 |
+
- fluids.compressible.Gaussians
|
| 34 |
+
- fluids.compressible.RichtmyerMeshkov(.tracer)
|
| 35 |
+
- fluids.compressible.steady.Airfoil(.time)
|
| 36 |
+
- elliptic.poisson.Gaussians(.time)
|
| 37 |
+
- elliptic.Helmholtz(.time)
|
| 38 |
+
- wave.Layer
|
| 39 |
+
- wave.Gaussians
|
| 40 |
+
- reaction_diffusion.AllenCahn
|
| 41 |
+
|
| 42 |
+
Adding .out at the end of the str, returns a dataset with more time steps.
|
| 43 |
+
**kwargs overwrite the default settings.
|
| 44 |
+
.time is a time-wrapped time-independent dataset.
|
| 45 |
+
"""
|
| 46 |
+
if isinstance(dataset, list):
|
| 47 |
+
return ConcatDataset([get_dataset(d, **kwargs) for d in dataset])
|
| 48 |
+
if "fluids" in dataset:
|
| 49 |
+
if "fluids.incompressible" in dataset:
|
| 50 |
+
if "BrownianBridge" in dataset:
|
| 51 |
+
from .fluids.incompressible import BrownianBridge as dset
|
| 52 |
+
elif "Gaussians" in dataset:
|
| 53 |
+
from .fluids.incompressible import Gaussians as dset
|
| 54 |
+
elif "ShearLayer" in dataset:
|
| 55 |
+
from .fluids.incompressible import ShearLayer as dset
|
| 56 |
+
elif "Sines" in dataset:
|
| 57 |
+
from .fluids.incompressible import Sines as dset
|
| 58 |
+
elif "PiecewiseConstants" in dataset:
|
| 59 |
+
from .fluids.incompressible import PiecewiseConstants as dset
|
| 60 |
+
elif "VortexSheet" in dataset:
|
| 61 |
+
from .fluids.incompressible import VortexSheet as dset
|
| 62 |
+
elif "forcing" in dataset:
|
| 63 |
+
if "KolmogorovFlow" in dataset:
|
| 64 |
+
from .fluids.incompressible import KolmogorovFlow as dset
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 69 |
+
elif "fluids.compressible" in dataset:
|
| 70 |
+
if "gravity" in dataset:
|
| 71 |
+
if "RayleighTaylor" in dataset:
|
| 72 |
+
from .fluids.compressible import RayleighTaylor as dset
|
| 73 |
+
|
| 74 |
+
if "out" in dataset:
|
| 75 |
+
default_time_settings = {
|
| 76 |
+
"max_num_time_steps": 10,
|
| 77 |
+
"time_step_size": 1,
|
| 78 |
+
}
|
| 79 |
+
else:
|
| 80 |
+
default_time_settings = {
|
| 81 |
+
"max_num_time_steps": 7,
|
| 82 |
+
"time_step_size": 1,
|
| 83 |
+
}
|
| 84 |
+
kwargs = {**default_time_settings, **kwargs}
|
| 85 |
+
elif "Blast" in dataset:
|
| 86 |
+
from .fluids.compressible import Blast as dset
|
| 87 |
+
elif "RiemannKelvinHelmholtz" in dataset:
|
| 88 |
+
from .fluids.compressible import RiemannKelvinHelmholtz as dset
|
| 89 |
+
elif "RiemannCurved" in dataset:
|
| 90 |
+
from .fluids.compressible import RiemannCurved as dset
|
| 91 |
+
elif "Riemann" in dataset:
|
| 92 |
+
from .fluids.compressible import Riemann as dset
|
| 93 |
+
elif "KelvinHelmholtz" in dataset:
|
| 94 |
+
from .fluids.compressible import KelvinHelmholtz as dset
|
| 95 |
+
elif "Gaussians" in dataset:
|
| 96 |
+
from .fluids.compressible import Gaussians as dset
|
| 97 |
+
elif "RichtmyerMeshkov" in dataset:
|
| 98 |
+
from .fluids.compressible import RichtmyerMeshkov as dset
|
| 99 |
+
elif "steady" in dataset:
|
| 100 |
+
if "steady.Airfoil" in dataset:
|
| 101 |
+
from .fluids.compressible import Airfoil as dset
|
| 102 |
+
|
| 103 |
+
if "out" in dataset:
|
| 104 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 107 |
+
else:
|
| 108 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 111 |
+
if "out" in dataset:
|
| 112 |
+
default_time_settings = {"max_num_time_steps": 10, "time_step_size": 2}
|
| 113 |
+
else:
|
| 114 |
+
default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2}
|
| 115 |
+
if "tracer" in dataset:
|
| 116 |
+
tracer = True
|
| 117 |
+
else:
|
| 118 |
+
tracer = False
|
| 119 |
+
if not "steady" in dataset:
|
| 120 |
+
kwargs = {"tracer": tracer, **default_time_settings, **kwargs}
|
| 121 |
+
elif "elliptic" in dataset:
|
| 122 |
+
if ".out" in dataset:
|
| 123 |
+
raise NotImplementedError(f"Unknown dataset {dataset}")
|
| 124 |
+
if "elliptic.poisson" in dataset:
|
| 125 |
+
if "Gaussians" in dataset:
|
| 126 |
+
from .elliptic.poisson import Gaussians as dset
|
| 127 |
+
else:
|
| 128 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 129 |
+
elif "elliptic.Helmholtz" in dataset:
|
| 130 |
+
from .elliptic.helmholtz import Helmholtz as dset
|
| 131 |
+
else:
|
| 132 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 133 |
+
elif "wave" in dataset:
|
| 134 |
+
if "wave.Layer" in dataset:
|
| 135 |
+
if "out" in dataset:
|
| 136 |
+
default_time_settings = {"max_num_time_steps": 10, "time_step_size": 2}
|
| 137 |
+
else:
|
| 138 |
+
default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2}
|
| 139 |
+
kwargs = {**default_time_settings, **kwargs}
|
| 140 |
+
from .wave.acoustic import Layer as dset
|
| 141 |
+
elif "wave.Gaussians" in dataset:
|
| 142 |
+
if "out" in dataset:
|
| 143 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 144 |
+
else:
|
| 145 |
+
default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2}
|
| 146 |
+
kwargs = {**default_time_settings, **kwargs}
|
| 147 |
+
from .wave.acoustic import Gaussians as dset
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 150 |
+
elif "reaction_diffusion" in dataset:
|
| 151 |
+
if "reaction_diffusion.AllenCahn" in dataset:
|
| 152 |
+
if "out" in dataset:
|
| 153 |
+
default_time_settings = {"max_num_time_steps": 9, "time_step_size": 2}
|
| 154 |
+
else:
|
| 155 |
+
default_time_settings = {"max_num_time_steps": 7, "time_step_size": 2}
|
| 156 |
+
kwargs = {**default_time_settings, **kwargs}
|
| 157 |
+
from .reaction_diffusion.allen_cahn import AllenCahn as dset
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"Unknown dataset {dataset}")
|
| 160 |
+
|
| 161 |
+
return dset(**kwargs) if ".time" not in dataset else TimeWrapper(dset(**kwargs))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class BaseDataset(Dataset, ABC):
|
| 165 |
+
"""A base class for all datasets. Can be directly derived from if you have a steady/non-time dependent problem."""
|
| 166 |
+
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
which: Optional[str] = None,
|
| 170 |
+
num_trajectories: Optional[int] = None,
|
| 171 |
+
data_path: Optional[str] = "./data",
|
| 172 |
+
move_to_local_scratch: Optional[str] = None,
|
| 173 |
+
) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Args:
|
| 176 |
+
which: Which dataset to use, i.e. train, val, or test.
|
| 177 |
+
resolution: The resolution of the dataset.
|
| 178 |
+
num_trajectories: The number of trajectories to use for training.
|
| 179 |
+
data_path: The path to the data files.
|
| 180 |
+
move_to_local_scratch: If not None, move the data to this directory at dataset initialization and use it from there.
|
| 181 |
+
"""
|
| 182 |
+
assert which in ["train", "val", "test"]
|
| 183 |
+
assert num_trajectories is not None and (
|
| 184 |
+
num_trajectories > 0 or num_trajectories in [-1, -2, -8]
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
self.num_trajectories = num_trajectories
|
| 188 |
+
self.data_path = data_path
|
| 189 |
+
self.which = which
|
| 190 |
+
self.move_to_local_scratch = move_to_local_scratch
|
| 191 |
+
|
| 192 |
+
def _move_to_local_scratch(self, file_path):
|
| 193 |
+
if self.move_to_local_scratch is not None:
|
| 194 |
+
data_dir = os.path.join(self.data_path, file_path)
|
| 195 |
+
file = file_path.split("/")[-1]
|
| 196 |
+
scratch_dir = self.move_to_local_scratch
|
| 197 |
+
dest_dir = os.path.join(scratch_dir, file)
|
| 198 |
+
RANK = int(os.environ.get("LOCAL_RANK", -1))
|
| 199 |
+
if not os.path.exists(dest_dir) and (RANK == 0 or RANK == -1):
|
| 200 |
+
print(f"Start copying {file} to {dest_dir}...")
|
| 201 |
+
shutil.copy(data_dir, dest_dir)
|
| 202 |
+
print("Finished data copy.")
|
| 203 |
+
# idk how to do the barrier differently
|
| 204 |
+
ls = broadcast_object_list([dest_dir], from_process=0)
|
| 205 |
+
dest_dir = ls[0]
|
| 206 |
+
return dest_dir
|
| 207 |
+
else:
|
| 208 |
+
return file_path
|
| 209 |
+
|
| 210 |
+
def post_init(self) -> None:
|
| 211 |
+
"""
|
| 212 |
+
Call after self.N_max, self.N_val, self.N_test, as well as the file_paths and normalization constants are set.
|
| 213 |
+
"""
|
| 214 |
+
assert (
|
| 215 |
+
self.N_max is not None
|
| 216 |
+
and self.N_max > 0
|
| 217 |
+
and self.N_max >= self.N_val + self.N_test
|
| 218 |
+
)
|
| 219 |
+
if self.num_trajectories == -1:
|
| 220 |
+
self.num_trajectories = self.N_max - self.N_val - self.N_test
|
| 221 |
+
elif self.num_trajectories == -2:
|
| 222 |
+
self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 2
|
| 223 |
+
elif self.num_trajectories == -8:
|
| 224 |
+
self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 8
|
| 225 |
+
assert self.num_trajectories + self.N_val + self.N_test <= self.N_max
|
| 226 |
+
assert self.N_val is not None and self.N_val > 0
|
| 227 |
+
assert self.N_test is not None and self.N_test > 0
|
| 228 |
+
if self.which == "train":
|
| 229 |
+
self.length = self.num_trajectories
|
| 230 |
+
self.start = 0
|
| 231 |
+
elif self.which == "val":
|
| 232 |
+
self.length = self.N_val
|
| 233 |
+
self.start = self.N_max - self.N_val - self.N_test
|
| 234 |
+
else:
|
| 235 |
+
self.length = self.N_test
|
| 236 |
+
self.start = self.N_max - self.N_test
|
| 237 |
+
|
| 238 |
+
self.output_dim = self.label_description.count(",") + 1
|
| 239 |
+
descriptors, channel_slice_list = self.get_channel_lists(self.label_description)
|
| 240 |
+
self.printable_channel_description = descriptors
|
| 241 |
+
self.channel_slice_list = channel_slice_list
|
| 242 |
+
|
| 243 |
+
def __len__(self) -> int:
|
| 244 |
+
"""
|
| 245 |
+
Returns: overall length of dataset.
|
| 246 |
+
"""
|
| 247 |
+
return self.length
|
| 248 |
+
|
| 249 |
+
def __getitem__(self, idx) -> Dict:
|
| 250 |
+
"""
|
| 251 |
+
Get an item. OVERWRITE!
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
idx: The index of the sample to get.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
A dict of key-value pairs of data.
|
| 258 |
+
"""
|
| 259 |
+
pass
|
| 260 |
+
|
| 261 |
+
@staticmethod
|
| 262 |
+
def get_channel_lists(label_description):
|
| 263 |
+
matches = re.findall(r"\[([^\[\]]+)\]", label_description)
|
| 264 |
+
channel_slice_list = [0] # use as channel_slice_list[i]:channel_slice_list[i+1]
|
| 265 |
+
beautiful_descriptors = []
|
| 266 |
+
for match in matches:
|
| 267 |
+
channel_slice_list.append(channel_slice_list[-1] + 1 + match.count(","))
|
| 268 |
+
splt = match.split(",")
|
| 269 |
+
if len(splt) > 1:
|
| 270 |
+
beautiful_descriptors.append("".join(splt))
|
| 271 |
+
else:
|
| 272 |
+
beautiful_descriptors.append(match)
|
| 273 |
+
return beautiful_descriptors, channel_slice_list
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class BaseTimeDataset(BaseDataset, ABC):
|
| 277 |
+
"""A base class for time dependent problems. Inherit time-dependent problems from here."""
|
| 278 |
+
|
| 279 |
+
def __init__(
|
| 280 |
+
self,
|
| 281 |
+
*args,
|
| 282 |
+
max_num_time_steps: Optional[int] = None,
|
| 283 |
+
time_step_size: Optional[int] = None,
|
| 284 |
+
fix_input_to_time_step: Optional[int] = None,
|
| 285 |
+
allowed_time_transitions: Optional[List[int]] = None,
|
| 286 |
+
**kwargs,
|
| 287 |
+
) -> None:
|
| 288 |
+
"""
|
| 289 |
+
Args:
|
| 290 |
+
max_num_time_steps: The maximum number of time steps to use.
|
| 291 |
+
time_step_size: The size of the time step.
|
| 292 |
+
fix_input_to_time_step: If not None, fix the input to this time step.
|
| 293 |
+
allowed_time_transitions: If not None, only allow these time transitions (time steps).
|
| 294 |
+
"""
|
| 295 |
+
assert max_num_time_steps is not None and max_num_time_steps > 0
|
| 296 |
+
assert time_step_size is not None and time_step_size > 0
|
| 297 |
+
assert fix_input_to_time_step is None or fix_input_to_time_step >= 0
|
| 298 |
+
|
| 299 |
+
super().__init__(*args, **kwargs)
|
| 300 |
+
self.max_num_time_steps = max_num_time_steps
|
| 301 |
+
self.time_step_size = time_step_size
|
| 302 |
+
self.fix_input_to_time_step = fix_input_to_time_step
|
| 303 |
+
self.allowed_time_transitions = allowed_time_transitions
|
| 304 |
+
|
| 305 |
+
def _idx_map(self, idx):
|
| 306 |
+
i = idx // self.multiplier
|
| 307 |
+
_idx = idx - i * self.multiplier
|
| 308 |
+
|
| 309 |
+
if self.fix_input_to_time_step is None:
|
| 310 |
+
t1, t2 = self.time_indices[_idx]
|
| 311 |
+
assert t2 >= t1
|
| 312 |
+
t = t2 - t1
|
| 313 |
+
else:
|
| 314 |
+
t1 = self.fix_input_to_time_step
|
| 315 |
+
t2 = self.time_step_size * (_idx + 1) + self.fix_input_to_time_step
|
| 316 |
+
t = t2 - t1
|
| 317 |
+
return i, t, t1, t2
|
| 318 |
+
|
| 319 |
+
def post_init(self) -> None:
|
| 320 |
+
"""
|
| 321 |
+
Call after self.N_max, self.N_val, self.N_test, as well as the file_paths and normalization constants are set.
|
| 322 |
+
self.max_time_step must have already been set.
|
| 323 |
+
"""
|
| 324 |
+
assert (
|
| 325 |
+
self.N_max is not None
|
| 326 |
+
and self.N_max > 0
|
| 327 |
+
and self.N_max >= self.N_val + self.N_test
|
| 328 |
+
)
|
| 329 |
+
if self.num_trajectories == -1:
|
| 330 |
+
self.num_trajectories = self.N_max - self.N_val - self.N_test
|
| 331 |
+
elif self.num_trajectories == -2:
|
| 332 |
+
self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 2
|
| 333 |
+
elif self.num_trajectories == -8:
|
| 334 |
+
self.num_trajectories = (self.N_max - self.N_val - self.N_test) // 8
|
| 335 |
+
assert self.num_trajectories + self.N_val + self.N_test <= self.N_max
|
| 336 |
+
assert self.N_val is not None and self.N_val > 0
|
| 337 |
+
assert self.N_test is not None and self.N_test > 0
|
| 338 |
+
assert self.max_num_time_steps is not None and self.max_num_time_steps > 0
|
| 339 |
+
|
| 340 |
+
if self.fix_input_to_time_step is not None:
|
| 341 |
+
self.multiplier = self.max_num_time_steps
|
| 342 |
+
else:
|
| 343 |
+
self.time_indices = []
|
| 344 |
+
for i in range(self.max_num_time_steps + 1):
|
| 345 |
+
for j in range(i, self.max_num_time_steps + 1):
|
| 346 |
+
if (
|
| 347 |
+
self.allowed_time_transitions is not None
|
| 348 |
+
and (j - i) not in self.allowed_time_transitions
|
| 349 |
+
):
|
| 350 |
+
continue
|
| 351 |
+
self.time_indices.append(
|
| 352 |
+
(self.time_step_size * i, self.time_step_size * j)
|
| 353 |
+
)
|
| 354 |
+
self.multiplier = len(self.time_indices)
|
| 355 |
+
|
| 356 |
+
if self.which == "train":
|
| 357 |
+
self.length = self.num_trajectories * self.multiplier
|
| 358 |
+
self.start = 0
|
| 359 |
+
elif self.which == "val":
|
| 360 |
+
self.length = self.N_val * self.multiplier
|
| 361 |
+
self.start = self.N_max - self.N_val - self.N_test
|
| 362 |
+
else:
|
| 363 |
+
self.length = self.N_test * self.multiplier
|
| 364 |
+
self.start = self.N_max - self.N_test
|
| 365 |
+
|
| 366 |
+
self.output_dim = self.label_description.count(",") + 1
|
| 367 |
+
descriptors, channel_slice_list = self.get_channel_lists(self.label_description)
|
| 368 |
+
self.printable_channel_description = descriptors
|
| 369 |
+
self.channel_slice_list = channel_slice_list
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class TimeWrapper(BaseTimeDataset):
|
| 373 |
+
"""For time-independent problems to be plugged into time-dependent models."""
|
| 374 |
+
|
| 375 |
+
def __init__(self, dataset):
|
| 376 |
+
super().__init__(
|
| 377 |
+
dataset.which,
|
| 378 |
+
dataset.num_trajectories,
|
| 379 |
+
dataset.data_path,
|
| 380 |
+
None,
|
| 381 |
+
max_num_time_steps=1,
|
| 382 |
+
time_step_size=1,
|
| 383 |
+
)
|
| 384 |
+
self.dataset = dataset
|
| 385 |
+
self.resolution = dataset.resolution
|
| 386 |
+
self.input_dim = dataset.input_dim
|
| 387 |
+
self.output_dim = dataset.output_dim
|
| 388 |
+
self.channel_slice_list = dataset.channel_slice_list
|
| 389 |
+
self.printable_channel_description = dataset.printable_channel_description
|
| 390 |
+
|
| 391 |
+
def __len__(self):
|
| 392 |
+
return len(self.dataset)
|
| 393 |
+
|
| 394 |
+
def __getitem__(self, idx):
|
| 395 |
+
return {**self.dataset[idx], "time": 1.0}
|
external/poseidon/scOT/problems/elliptic/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/problems/elliptic/helmholtz.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import h5py
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scOT.problems.base import BaseDataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Helmholtz(BaseDataset):
|
| 9 |
+
def __init__(self, *args, **kwargs):
|
| 10 |
+
super().__init__(*args, **kwargs)
|
| 11 |
+
|
| 12 |
+
self.N_max = 19675
|
| 13 |
+
self.N_val = 128
|
| 14 |
+
self.N_test = 512
|
| 15 |
+
self.resolution = 128
|
| 16 |
+
|
| 17 |
+
self.file_path = os.path.join(
|
| 18 |
+
self.data_path,
|
| 19 |
+
"Helmholtz.h5",
|
| 20 |
+
)
|
| 21 |
+
self.file_path = self._move_to_local_scratch(self.file_path)
|
| 22 |
+
self.reader = h5py.File(self.file_path, "r")
|
| 23 |
+
self.mean = 0.11523915668552
|
| 24 |
+
self.std = 0.8279975746000605
|
| 25 |
+
|
| 26 |
+
self.input_dim = 2
|
| 27 |
+
self.label_description = "[u]"
|
| 28 |
+
|
| 29 |
+
self.post_init()
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, idx):
|
| 32 |
+
inputs = (
|
| 33 |
+
torch.from_numpy(self.reader["Sample_" + str(idx + self.start)]["a"][:])
|
| 34 |
+
.type(torch.float32)
|
| 35 |
+
.reshape(1, self.resolution, self.resolution)
|
| 36 |
+
)
|
| 37 |
+
inputs = inputs - 1
|
| 38 |
+
b = float(np.array(self.reader["Sample_" + str(idx + self.start)]["bc"]))
|
| 39 |
+
bc = b * torch.ones_like(inputs)
|
| 40 |
+
inputs = torch.cat((inputs, bc), dim=0)
|
| 41 |
+
|
| 42 |
+
labels = (
|
| 43 |
+
torch.from_numpy(self.reader["Sample_" + str(idx + self.start)]["u"][:])
|
| 44 |
+
.type(torch.float32)
|
| 45 |
+
.reshape(1, self.resolution, self.resolution)
|
| 46 |
+
)
|
| 47 |
+
labels = (labels - self.mean) / self.std
|
| 48 |
+
|
| 49 |
+
return {"pixel_values": inputs, "labels": labels}
|
external/poseidon/scOT/problems/elliptic/poisson.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import h5py
|
| 4 |
+
from scOT.problems.base import BaseDataset
|
| 5 |
+
|
| 6 |
+
CONSTANTS = {
|
| 7 |
+
"mean_source": 0.014822142414492256,
|
| 8 |
+
"std_source": 4.755138816607612,
|
| 9 |
+
"mean_solution": 0.0005603458434937093,
|
| 10 |
+
"std_solution": 0.02401226126952699,
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Gaussians(BaseDataset):
|
| 15 |
+
def __init__(self, *args, **kwargs):
|
| 16 |
+
super().__init__(*args, **kwargs)
|
| 17 |
+
self.N_max = 20000
|
| 18 |
+
self.N_val = 120
|
| 19 |
+
self.N_test = 240
|
| 20 |
+
self.resolution = 128
|
| 21 |
+
|
| 22 |
+
self.file_path = os.path.join(self.data_path, "Poisson-Gauss.nc")
|
| 23 |
+
self.file_path = self._move_to_local_scratch(self.file_path)
|
| 24 |
+
self.reader = h5py.File(self.file_path, "r")
|
| 25 |
+
self.constants = CONSTANTS
|
| 26 |
+
|
| 27 |
+
self.input_dim = 1
|
| 28 |
+
self.label_description = "[u]"
|
| 29 |
+
|
| 30 |
+
self.post_init()
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, idx):
|
| 33 |
+
inputs = (
|
| 34 |
+
torch.from_numpy(self.reader["source"][idx + self.start])
|
| 35 |
+
.type(torch.float32)
|
| 36 |
+
.reshape(1, self.resolution, self.resolution)
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
labels = (
|
| 40 |
+
torch.from_numpy(self.reader["solution"][idx + self.start])
|
| 41 |
+
.type(torch.float32)
|
| 42 |
+
.reshape(1, self.resolution, self.resolution)
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
inputs = (inputs - self.constants["mean_source"]) / self.constants["std_source"]
|
| 46 |
+
labels = (labels - self.constants["mean_solution"]) / self.constants[
|
| 47 |
+
"std_solution"
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
return {"pixel_values": inputs, "labels": labels}
|
external/poseidon/scOT/problems/fluids/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/problems/fluids/compressible.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import h5py
|
| 3 |
+
import copy
|
| 4 |
+
from scOT.problems.base import BaseTimeDataset, BaseDataset
|
| 5 |
+
from scOT.problems.fluids.normalization_constants import CONSTANTS
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Airfoil(BaseDataset):
|
| 9 |
+
def __init__(self, *args, **kwargs):
|
| 10 |
+
super().__init__(*args, **kwargs)
|
| 11 |
+
|
| 12 |
+
self.N_max = 10869
|
| 13 |
+
self.N_val = 120
|
| 14 |
+
self.N_test = 240
|
| 15 |
+
self.resolution = 128
|
| 16 |
+
|
| 17 |
+
data_path = self.data_path + "/SE-AF.nc"
|
| 18 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 19 |
+
self.reader = h5py.File(data_path, "r")
|
| 20 |
+
|
| 21 |
+
self.constants = {
|
| 22 |
+
"mean": 0.92984116,
|
| 23 |
+
"std": 0.10864315,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
self.input_dim = 1
|
| 27 |
+
self.label_description = "[rho]"
|
| 28 |
+
|
| 29 |
+
self.post_init()
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, idx):
|
| 32 |
+
i = idx
|
| 33 |
+
inputs = (
|
| 34 |
+
torch.from_numpy(self.reader["solution"][i + self.start, 0])
|
| 35 |
+
.type(torch.float32)
|
| 36 |
+
.reshape(1, self.resolution, self.resolution)
|
| 37 |
+
)
|
| 38 |
+
labels = (
|
| 39 |
+
torch.from_numpy(self.reader["solution"][i + self.start, 1])
|
| 40 |
+
.type(torch.float32)
|
| 41 |
+
.reshape(1, self.resolution, self.resolution)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
labels = (labels - self.constants["mean"]) / self.constants["std"]
|
| 45 |
+
|
| 46 |
+
pixel_mask = inputs == 1
|
| 47 |
+
labels[pixel_mask] = 1
|
| 48 |
+
|
| 49 |
+
return {
|
| 50 |
+
"pixel_values": inputs,
|
| 51 |
+
"labels": labels,
|
| 52 |
+
"pixel_mask": pixel_mask,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RichtmyerMeshkov(BaseTimeDataset):
|
| 57 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 58 |
+
super().__init__(*args, **kwargs)
|
| 59 |
+
assert self.max_num_time_steps * self.time_step_size <= 20
|
| 60 |
+
|
| 61 |
+
self.N_max = 1260
|
| 62 |
+
self.N_val = 100
|
| 63 |
+
self.N_test = 130
|
| 64 |
+
self.resolution = 128
|
| 65 |
+
|
| 66 |
+
data_path = self.data_path + "/CE-RM.nc"
|
| 67 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 68 |
+
self.reader = h5py.File(data_path, "r")
|
| 69 |
+
|
| 70 |
+
self.constants = {
|
| 71 |
+
"mean": torch.tensor([1.1964245, -7.164812e-06, 2.8968952e-06, 1.5648036])
|
| 72 |
+
.unsqueeze(1)
|
| 73 |
+
.unsqueeze(1),
|
| 74 |
+
"std": torch.tensor([0.5543239, 0.24304213, 0.2430597, 0.89639103])
|
| 75 |
+
.unsqueeze(1)
|
| 76 |
+
.unsqueeze(1),
|
| 77 |
+
"time": 20.0,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
self.input_dim = 4
|
| 81 |
+
self.label_description = "[rho],[u,v],[p]"
|
| 82 |
+
|
| 83 |
+
self.pixel_mask = torch.tensor([False, False, False, False])
|
| 84 |
+
|
| 85 |
+
self.post_init()
|
| 86 |
+
|
| 87 |
+
def __getitem__(self, idx):
|
| 88 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 89 |
+
time = t / self.constants["time"]
|
| 90 |
+
|
| 91 |
+
inputs = (
|
| 92 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:4])
|
| 93 |
+
.type(torch.float32)
|
| 94 |
+
.reshape(4, self.resolution, self.resolution)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
label = (
|
| 98 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:4])
|
| 99 |
+
.type(torch.float32)
|
| 100 |
+
.reshape(4, self.resolution, self.resolution)
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 104 |
+
label = (label - self.constants["mean"]) / self.constants["std"]
|
| 105 |
+
|
| 106 |
+
return {
|
| 107 |
+
"pixel_values": inputs,
|
| 108 |
+
"labels": label,
|
| 109 |
+
"time": time,
|
| 110 |
+
"pixel_mask": self.pixel_mask,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class RayleighTaylor(BaseTimeDataset):
|
| 115 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 116 |
+
super().__init__(*args, **kwargs)
|
| 117 |
+
assert self.max_num_time_steps * self.time_step_size <= 10
|
| 118 |
+
|
| 119 |
+
self.N_max = 1260
|
| 120 |
+
self.N_val = 100
|
| 121 |
+
self.N_test = 130
|
| 122 |
+
self.resolution = 128
|
| 123 |
+
|
| 124 |
+
data_path = self.data_path + "/GCE-RT.nc"
|
| 125 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 126 |
+
self.reader = h5py.File(data_path, "r")
|
| 127 |
+
|
| 128 |
+
self.constants = {
|
| 129 |
+
"mean": torch.tensor(
|
| 130 |
+
[0.8970493, 4.0316996e-13, -1.3858967e-13, 0.7133829, -1.7055787]
|
| 131 |
+
)
|
| 132 |
+
.unsqueeze(1)
|
| 133 |
+
.unsqueeze(1),
|
| 134 |
+
"std": torch.tensor(
|
| 135 |
+
[0.12857835, 0.014896976, 0.014896975, 0.21293919, 0.40131348]
|
| 136 |
+
)
|
| 137 |
+
.unsqueeze(1)
|
| 138 |
+
.unsqueeze(1),
|
| 139 |
+
"time": 10.0,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
self.input_dim = 5
|
| 143 |
+
self.label_description = "[rho],[u,v],[p],[g]"
|
| 144 |
+
|
| 145 |
+
self.pixel_mask = torch.tensor([False, False, False, False, False])
|
| 146 |
+
|
| 147 |
+
self.post_init()
|
| 148 |
+
|
| 149 |
+
def __getitem__(self, idx):
|
| 150 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 151 |
+
time = t / self.constants["time"]
|
| 152 |
+
|
| 153 |
+
inputs = (
|
| 154 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:4])
|
| 155 |
+
.type(torch.float32)
|
| 156 |
+
.reshape(4, self.resolution, self.resolution)
|
| 157 |
+
)
|
| 158 |
+
label = (
|
| 159 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:4])
|
| 160 |
+
.type(torch.float32)
|
| 161 |
+
.reshape(4, self.resolution, self.resolution)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
g_1 = (
|
| 165 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1, 5:6])
|
| 166 |
+
.type(torch.float32)
|
| 167 |
+
.reshape(1, self.resolution, self.resolution)
|
| 168 |
+
)
|
| 169 |
+
g_2 = (
|
| 170 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2, 5:6])
|
| 171 |
+
.type(torch.float32)
|
| 172 |
+
.reshape(1, self.resolution, self.resolution)
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
inputs = (inputs - self.constants["mean"][:4]) / self.constants["std"][:4]
|
| 176 |
+
g_1 = (g_1 - self.constants["mean"][4]) / self.constants["std"][4]
|
| 177 |
+
g_2 = (g_2 - self.constants["mean"][4]) / self.constants["std"][4]
|
| 178 |
+
label = (label - self.constants["mean"][:4]) / self.constants["std"][:4]
|
| 179 |
+
|
| 180 |
+
inputs = torch.cat([inputs, g_1], dim=0)
|
| 181 |
+
label = torch.cat([label, g_2], dim=0)
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"pixel_values": inputs,
|
| 185 |
+
"labels": label,
|
| 186 |
+
"time": time,
|
| 187 |
+
"pixel_mask": self.pixel_mask,
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class CompressibleBase(BaseTimeDataset):
|
| 192 |
+
def __init__(self, file_path, *args, tracer=False, **kwargs):
|
| 193 |
+
super().__init__(*args, **kwargs)
|
| 194 |
+
assert self.max_num_time_steps * self.time_step_size <= 20
|
| 195 |
+
|
| 196 |
+
self.N_max = 10000
|
| 197 |
+
self.N_val = 120
|
| 198 |
+
self.N_test = 240
|
| 199 |
+
self.resolution = 128
|
| 200 |
+
self.tracer = tracer
|
| 201 |
+
|
| 202 |
+
data_path = self.data_path + file_path
|
| 203 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 204 |
+
self.reader = h5py.File(data_path, "r")
|
| 205 |
+
|
| 206 |
+
self.constants = copy.deepcopy(CONSTANTS)
|
| 207 |
+
|
| 208 |
+
self.input_dim = 4 if not tracer else 5
|
| 209 |
+
self.label_description = (
|
| 210 |
+
"[rho],[u,v],[p]" if not tracer else "[rho],[u,v],[p],[tracer]"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
self.pixel_mask = (
|
| 214 |
+
torch.tensor([False, False, False, False])
|
| 215 |
+
if not tracer
|
| 216 |
+
else torch.tensor([False, False, False, False, False])
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self.post_init()
|
| 220 |
+
|
| 221 |
+
def __getitem__(self, idx):
|
| 222 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 223 |
+
time = t / self.constants["time"]
|
| 224 |
+
|
| 225 |
+
inputs = (
|
| 226 |
+
torch.from_numpy(self.reader["data"][i + self.start, t1, 0:4])
|
| 227 |
+
.type(torch.float32)
|
| 228 |
+
.reshape(4, self.resolution, self.resolution)
|
| 229 |
+
)
|
| 230 |
+
label = (
|
| 231 |
+
torch.from_numpy(self.reader["data"][i + self.start, t2, 0:4])
|
| 232 |
+
.type(torch.float32)
|
| 233 |
+
.reshape(4, self.resolution, self.resolution)
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
inputs[3] = inputs[3] - self.mean_pressure
|
| 237 |
+
label[3] = label[3] - self.mean_pressure
|
| 238 |
+
|
| 239 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 240 |
+
label = (label - self.constants["mean"]) / self.constants["std"]
|
| 241 |
+
|
| 242 |
+
if self.tracer:
|
| 243 |
+
input_tracer = (
|
| 244 |
+
torch.from_numpy(self.reader["data"][i + self.start, t1, 4:5])
|
| 245 |
+
.type(torch.float32)
|
| 246 |
+
.reshape(1, self.resolution, self.resolution)
|
| 247 |
+
)
|
| 248 |
+
output_tracer = (
|
| 249 |
+
torch.from_numpy(self.reader["data"][i + self.start, t2, 4:5])
|
| 250 |
+
.type(torch.float32)
|
| 251 |
+
.reshape(1, self.resolution, self.resolution)
|
| 252 |
+
)
|
| 253 |
+
inputs = torch.cat([inputs, input_tracer], dim=0)
|
| 254 |
+
label = torch.cat([label, output_tracer], dim=0)
|
| 255 |
+
|
| 256 |
+
return {
|
| 257 |
+
"pixel_values": inputs,
|
| 258 |
+
"labels": label,
|
| 259 |
+
"time": time,
|
| 260 |
+
"pixel_mask": self.pixel_mask,
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class Gaussians(CompressibleBase):
|
| 265 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 266 |
+
self.mean_pressure = 2.513
|
| 267 |
+
file_path = "/CE-Gauss.nc"
|
| 268 |
+
if tracer:
|
| 269 |
+
raise NotImplementedError("Tracer not implemented for Gaussians")
|
| 270 |
+
super().__init__(file_path, *args, tracer=tracer, **kwargs)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class KelvinHelmholtz(CompressibleBase):
|
| 274 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 275 |
+
self.mean_pressure = 1.0
|
| 276 |
+
file_path = "/CE-KH.nc"
|
| 277 |
+
if tracer:
|
| 278 |
+
raise NotImplementedError("Tracer not implemented for KelvinHelmholtz")
|
| 279 |
+
super().__init__(file_path, *args, tracer=tracer, **kwargs)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class Riemann(CompressibleBase):
|
| 283 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 284 |
+
self.mean_pressure = 0.215
|
| 285 |
+
file_path = "/CE-RP.nc"
|
| 286 |
+
if tracer:
|
| 287 |
+
raise NotImplementedError("Tracer not implemented for Riemann")
|
| 288 |
+
super().__init__(file_path, *args, tracer=tracer, **kwargs)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class RiemannCurved(CompressibleBase):
|
| 292 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 293 |
+
self.mean_pressure = 0.553
|
| 294 |
+
file_path = "/CE-CRP.nc"
|
| 295 |
+
if tracer:
|
| 296 |
+
raise NotImplementedError("Tracer not implemented for RiemannCurved")
|
| 297 |
+
super().__init__(file_path, *args, tracer=tracer, **kwargs)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class RiemannKelvinHelmholtz(CompressibleBase):
|
| 301 |
+
def __init__(self, *args, tracer=False, **kwargs):
|
| 302 |
+
self.mean_pressure = 1.33
|
| 303 |
+
file_path = "/CE-RPUI.nc"
|
| 304 |
+
if tracer:
|
| 305 |
+
raise NotImplementedError(
|
| 306 |
+
"Tracer not implemented for RiemannKelvinHelmholtz"
|
| 307 |
+
)
|
| 308 |
+
super().__init__(file_path, *args, tracer=tracer, **kwargs)
|
external/poseidon/scOT/problems/fluids/incompressible.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import h5py
|
| 3 |
+
import numpy as np
|
| 4 |
+
import copy
|
| 5 |
+
from scOT.problems.base import BaseTimeDataset
|
| 6 |
+
from scOT.problems.fluids.normalization_constants import CONSTANTS
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class IncompressibleBase(BaseTimeDataset):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
N_max,
|
| 13 |
+
file_path,
|
| 14 |
+
*args,
|
| 15 |
+
tracer=False,
|
| 16 |
+
just_velocities=False,
|
| 17 |
+
transpose=False,
|
| 18 |
+
resolution=None,
|
| 19 |
+
**kwargs
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
just_velocities: If True, only the velocities are used as input and output.
|
| 23 |
+
transpose: If True, the input and output are transposed.
|
| 24 |
+
"""
|
| 25 |
+
super().__init__(*args, **kwargs)
|
| 26 |
+
assert self.max_num_time_steps * self.time_step_size <= 20
|
| 27 |
+
|
| 28 |
+
self.N_max = N_max
|
| 29 |
+
self.N_val = 120
|
| 30 |
+
self.N_test = 240
|
| 31 |
+
self.resolution = 128
|
| 32 |
+
self.tracer = tracer
|
| 33 |
+
self.just_velocities = just_velocities
|
| 34 |
+
self.transpose = transpose
|
| 35 |
+
|
| 36 |
+
data_path = self.data_path + file_path
|
| 37 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 38 |
+
self.reader = h5py.File(data_path, "r")
|
| 39 |
+
|
| 40 |
+
self.constants = copy.deepcopy(CONSTANTS)
|
| 41 |
+
if just_velocities:
|
| 42 |
+
self.constants["mean"] = self.constants["mean"][1:3]
|
| 43 |
+
self.constants["std"] = self.constants["std"][1:3]
|
| 44 |
+
|
| 45 |
+
self.density = torch.ones(1, self.resolution, self.resolution)
|
| 46 |
+
self.pressure = torch.zeros(1, self.resolution, self.resolution)
|
| 47 |
+
|
| 48 |
+
self.input_dim = 4 if not tracer else 5
|
| 49 |
+
if just_velocities:
|
| 50 |
+
self.input_dim -= 2
|
| 51 |
+
self.label_description = "[u,v]"
|
| 52 |
+
if not self.just_velocities:
|
| 53 |
+
self.label_description = "[rho],[u,v],[p]"
|
| 54 |
+
if tracer:
|
| 55 |
+
self.label_description += ",[tracer]"
|
| 56 |
+
|
| 57 |
+
self.pixel_mask = torch.tensor([False, False])
|
| 58 |
+
if not self.just_velocities:
|
| 59 |
+
self.pixel_mask = torch.tensor([False, False, False, True])
|
| 60 |
+
if tracer:
|
| 61 |
+
self.pixel_mask = torch.cat(
|
| 62 |
+
[self.pixel_mask, torch.tensor([False])],
|
| 63 |
+
dim=0,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
if resolution is None:
|
| 67 |
+
self.res = None
|
| 68 |
+
else:
|
| 69 |
+
if resolution > 128:
|
| 70 |
+
raise ValueError("Resolution must be <= 128")
|
| 71 |
+
self.res = resolution
|
| 72 |
+
|
| 73 |
+
self.post_init()
|
| 74 |
+
|
| 75 |
+
def _downsample(self, image, target_size):
|
| 76 |
+
image = image.unsqueeze(0)
|
| 77 |
+
image_size = image.shape[-2]
|
| 78 |
+
freqs = torch.fft.fftfreq(image_size, d=1 / image_size)
|
| 79 |
+
sel = torch.logical_and(freqs >= -target_size / 2, freqs <= target_size / 2 - 1)
|
| 80 |
+
image_hat = torch.fft.fft2(image, norm="forward")
|
| 81 |
+
image_hat = image_hat[:, :, sel, :][:, :, :, sel]
|
| 82 |
+
image = torch.fft.ifft2(image_hat, norm="forward").real
|
| 83 |
+
return image.squeeze(0)
|
| 84 |
+
|
| 85 |
+
def __getitem__(self, idx):
|
| 86 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 87 |
+
time = t / self.constants["time"]
|
| 88 |
+
|
| 89 |
+
inputs_v = (
|
| 90 |
+
torch.from_numpy(self.reader["velocity"][i + self.start, t1, 0:2])
|
| 91 |
+
.type(torch.float32)
|
| 92 |
+
.reshape(2, self.resolution, self.resolution)
|
| 93 |
+
)
|
| 94 |
+
label_v = (
|
| 95 |
+
torch.from_numpy(self.reader["velocity"][i + self.start, t2, 0:2])
|
| 96 |
+
.type(torch.float32)
|
| 97 |
+
.reshape(2, self.resolution, self.resolution)
|
| 98 |
+
)
|
| 99 |
+
if self.transpose:
|
| 100 |
+
inputs_v = inputs_v.transpose(-2, -1)
|
| 101 |
+
label_v = label_v.transpose(-2, -1)
|
| 102 |
+
|
| 103 |
+
if not self.just_velocities:
|
| 104 |
+
inputs = torch.cat([self.density, inputs_v, self.pressure], dim=0)
|
| 105 |
+
label = torch.cat([self.density, label_v, self.pressure], dim=0)
|
| 106 |
+
else:
|
| 107 |
+
inputs = inputs_v
|
| 108 |
+
label = label_v
|
| 109 |
+
|
| 110 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 111 |
+
label = (label - self.constants["mean"]) / self.constants["std"]
|
| 112 |
+
|
| 113 |
+
if self.tracer:
|
| 114 |
+
input_tracer = (
|
| 115 |
+
torch.from_numpy(self.reader["velocity"][i + self.start, t1, 2:3])
|
| 116 |
+
.type(torch.float32)
|
| 117 |
+
.reshape(1, self.resolution, self.resolution)
|
| 118 |
+
)
|
| 119 |
+
output_tracer = (
|
| 120 |
+
torch.from_numpy(self.reader["velocity"][i + self.start, t2, 2:3])
|
| 121 |
+
.type(torch.float32)
|
| 122 |
+
.reshape(1, self.resolution, self.resolution)
|
| 123 |
+
)
|
| 124 |
+
if self.transpose:
|
| 125 |
+
input_tracer = input_tracer.transpose(-2, -1)
|
| 126 |
+
output_tracer = output_tracer.transpose(-2, -1)
|
| 127 |
+
input_tracer = (
|
| 128 |
+
input_tracer - self.constants["tracer_mean"]
|
| 129 |
+
) / self.constants["tracer_std"]
|
| 130 |
+
output_tracer = (
|
| 131 |
+
output_tracer - self.constants["tracer_mean"]
|
| 132 |
+
) / self.constants["tracer_std"]
|
| 133 |
+
|
| 134 |
+
inputs = torch.cat([inputs, input_tracer], dim=0)
|
| 135 |
+
label = torch.cat([label, output_tracer], dim=0)
|
| 136 |
+
|
| 137 |
+
if self.res is not None:
|
| 138 |
+
inputs = self._downsample(inputs, self.res)
|
| 139 |
+
label = self._downsample(label, self.res)
|
| 140 |
+
|
| 141 |
+
return {
|
| 142 |
+
"pixel_values": inputs,
|
| 143 |
+
"labels": label,
|
| 144 |
+
"time": time,
|
| 145 |
+
"pixel_mask": self.pixel_mask,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class KolmogorovFlow(BaseTimeDataset):
|
| 150 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 151 |
+
super().__init__(*args, **kwargs)
|
| 152 |
+
assert self.max_num_time_steps * self.time_step_size <= 20
|
| 153 |
+
|
| 154 |
+
assert tracer == False
|
| 155 |
+
|
| 156 |
+
self.N_max = 20000
|
| 157 |
+
self.N_val = 120
|
| 158 |
+
self.N_test = 240
|
| 159 |
+
self.resolution = 128
|
| 160 |
+
self.just_velocities = just_velocities
|
| 161 |
+
|
| 162 |
+
data_path = self.data_path + "/FNS-KF.nc"
|
| 163 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 164 |
+
self.reader = h5py.File(data_path, "r")
|
| 165 |
+
|
| 166 |
+
self.constants = copy.deepcopy(CONSTANTS)
|
| 167 |
+
self.constants["mean"][1] = -2.2424793e-13
|
| 168 |
+
self.constants["mean"][2] = 4.1510376e-12
|
| 169 |
+
self.constants["std"][1] = 0.22017328
|
| 170 |
+
self.constants["std"][2] = 0.22078253
|
| 171 |
+
if just_velocities:
|
| 172 |
+
self.constants["mean"] = self.constants["mean"][1:3]
|
| 173 |
+
self.constants["std"] = self.constants["std"][1:3]
|
| 174 |
+
|
| 175 |
+
self.density = torch.ones(1, self.resolution, self.resolution)
|
| 176 |
+
self.pressure = torch.zeros(1, self.resolution, self.resolution)
|
| 177 |
+
X, Y = torch.meshgrid(
|
| 178 |
+
torch.linspace(0, 1, self.resolution),
|
| 179 |
+
torch.linspace(0, 1, self.resolution),
|
| 180 |
+
indexing="ij",
|
| 181 |
+
)
|
| 182 |
+
f = lambda x, y: 0.1 * torch.sin(2.0 * np.pi * (x + y))
|
| 183 |
+
self.forcing = f(X, Y).unsqueeze(0)
|
| 184 |
+
self.constants["mean_forcing"] = -1.2996679288335145e-09
|
| 185 |
+
self.constants["std_forcing"] = 0.0707106739282608
|
| 186 |
+
self.forcing = (self.forcing - self.constants["mean_forcing"]) / self.constants[
|
| 187 |
+
"std_forcing"
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
self.input_dim = 5 if not tracer else 6
|
| 191 |
+
if just_velocities:
|
| 192 |
+
self.input_dim -= 2
|
| 193 |
+
self.label_description = "[u,v],[g]"
|
| 194 |
+
if not self.just_velocities:
|
| 195 |
+
self.label_description = "[rho],[u,v],[p],[g]"
|
| 196 |
+
if tracer:
|
| 197 |
+
self.label_description += ",[tracer]"
|
| 198 |
+
|
| 199 |
+
self.pixel_mask = torch.tensor([False, False, False])
|
| 200 |
+
if not self.just_velocities:
|
| 201 |
+
self.pixel_mask = torch.tensor([False, False, False, True, False])
|
| 202 |
+
if tracer:
|
| 203 |
+
self.pixel_mask = torch.cat(
|
| 204 |
+
[self.pixel_mask, torch.tensor([False])],
|
| 205 |
+
dim=0,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
self.post_init()
|
| 209 |
+
|
| 210 |
+
def __getitem__(self, idx):
|
| 211 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 212 |
+
time = t / self.constants["time"]
|
| 213 |
+
|
| 214 |
+
inputs_v = (
|
| 215 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1, 0:2])
|
| 216 |
+
.type(torch.float32)
|
| 217 |
+
.reshape(2, self.resolution, self.resolution)
|
| 218 |
+
)
|
| 219 |
+
label_v = (
|
| 220 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2, 0:2])
|
| 221 |
+
.type(torch.float32)
|
| 222 |
+
.reshape(2, self.resolution, self.resolution)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if not self.just_velocities:
|
| 226 |
+
inputs = torch.cat([self.density, inputs_v, self.pressure], dim=0)
|
| 227 |
+
label = torch.cat([self.density, label_v, self.pressure], dim=0)
|
| 228 |
+
else:
|
| 229 |
+
inputs = inputs_v
|
| 230 |
+
label = label_v
|
| 231 |
+
|
| 232 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 233 |
+
label = (label - self.constants["mean"]) / self.constants["std"]
|
| 234 |
+
|
| 235 |
+
inputs = torch.cat([inputs, self.forcing], dim=0)
|
| 236 |
+
label = torch.cat([label, self.forcing], dim=0)
|
| 237 |
+
|
| 238 |
+
return {
|
| 239 |
+
"pixel_values": inputs,
|
| 240 |
+
"labels": label,
|
| 241 |
+
"time": time,
|
| 242 |
+
"pixel_mask": self.pixel_mask,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class BrownianBridge(IncompressibleBase):
|
| 247 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 248 |
+
if tracer:
|
| 249 |
+
raise ValueError("BrownianBridge does not have a tracer")
|
| 250 |
+
file_path = "/NS-BB.nc"
|
| 251 |
+
super().__init__(
|
| 252 |
+
20000,
|
| 253 |
+
file_path,
|
| 254 |
+
*args,
|
| 255 |
+
tracer=False,
|
| 256 |
+
just_velocities=just_velocities,
|
| 257 |
+
**kwargs
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class PiecewiseConstants(IncompressibleBase):
|
| 262 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 263 |
+
file_path = "/NS-PwC.nc"
|
| 264 |
+
super().__init__(
|
| 265 |
+
20000,
|
| 266 |
+
file_path,
|
| 267 |
+
*args,
|
| 268 |
+
tracer=tracer,
|
| 269 |
+
just_velocities=just_velocities,
|
| 270 |
+
**kwargs
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class Gaussians(IncompressibleBase):
|
| 275 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 276 |
+
if tracer:
|
| 277 |
+
raise ValueError("Gaussians does not have a tracer")
|
| 278 |
+
file_path = "/NS-Gauss.nc"
|
| 279 |
+
super().__init__(
|
| 280 |
+
20000,
|
| 281 |
+
file_path,
|
| 282 |
+
*args,
|
| 283 |
+
tracer=False,
|
| 284 |
+
just_velocities=just_velocities,
|
| 285 |
+
**kwargs
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class ShearLayer(IncompressibleBase):
|
| 290 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 291 |
+
if tracer:
|
| 292 |
+
raise ValueError("Shear layer does not have a tracer")
|
| 293 |
+
super().__init__(
|
| 294 |
+
40000,
|
| 295 |
+
"/NS-SL.nc",
|
| 296 |
+
*args,
|
| 297 |
+
transpose=True,
|
| 298 |
+
tracer=False,
|
| 299 |
+
just_velocities=just_velocities,
|
| 300 |
+
**kwargs
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class VortexSheet(IncompressibleBase):
|
| 305 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 306 |
+
if tracer:
|
| 307 |
+
raise ValueError("VortexSheet does not have a tracer")
|
| 308 |
+
file_path = "/NS-SVS.nc"
|
| 309 |
+
super().__init__(
|
| 310 |
+
20000,
|
| 311 |
+
file_path,
|
| 312 |
+
*args,
|
| 313 |
+
tracer=False,
|
| 314 |
+
just_velocities=just_velocities,
|
| 315 |
+
**kwargs
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class Sines(IncompressibleBase):
|
| 320 |
+
def __init__(self, *args, tracer=False, just_velocities=False, **kwargs):
|
| 321 |
+
if tracer:
|
| 322 |
+
raise ValueError("Sines does not have a tracer")
|
| 323 |
+
file_path = "/NS-Sines.nc"
|
| 324 |
+
super().__init__(
|
| 325 |
+
20000,
|
| 326 |
+
file_path,
|
| 327 |
+
*args,
|
| 328 |
+
tracer=False,
|
| 329 |
+
just_velocities=just_velocities,
|
| 330 |
+
**kwargs
|
| 331 |
+
)
|
external/poseidon/scOT/problems/fluids/normalization_constants.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
CONSTANTS = {
|
| 4 |
+
"mean": torch.tensor([0.80, 0.0, 0.0, 0.0]).unsqueeze(1).unsqueeze(1),
|
| 5 |
+
"std": torch.tensor([0.31, 0.391, 0.356, 0.185]).unsqueeze(1).unsqueeze(1),
|
| 6 |
+
"time": 20.0,
|
| 7 |
+
"tracer_mean": 0.19586183,
|
| 8 |
+
"tracer_std": 0.37,
|
| 9 |
+
}
|
external/poseidon/scOT/problems/reaction_diffusion/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/problems/reaction_diffusion/allen_cahn.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import h5py
|
| 3 |
+
from scOT.problems.base import BaseTimeDataset
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AllenCahn(BaseTimeDataset):
|
| 7 |
+
def __init__(self, *args, **kwargs):
|
| 8 |
+
super().__init__(*args, **kwargs)
|
| 9 |
+
assert self.max_num_time_steps * self.time_step_size <= 19
|
| 10 |
+
|
| 11 |
+
self.N_max = 15000
|
| 12 |
+
self.N_val = 60
|
| 13 |
+
self.N_test = 240
|
| 14 |
+
self.resolution = 128
|
| 15 |
+
|
| 16 |
+
data_path = self.data_path + "/ACE.nc"
|
| 17 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 18 |
+
self.reader = h5py.File(data_path, "r")
|
| 19 |
+
|
| 20 |
+
self.constants = {
|
| 21 |
+
"mean": 0.002484262,
|
| 22 |
+
"std": 0.65351176,
|
| 23 |
+
"time": 19.0,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
self.input_dim = 1
|
| 27 |
+
self.label_description = "[u]"
|
| 28 |
+
|
| 29 |
+
self.post_init()
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, idx):
|
| 32 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 33 |
+
time = t / self.constants["time"]
|
| 34 |
+
|
| 35 |
+
inputs = (
|
| 36 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1])
|
| 37 |
+
.type(torch.float32)
|
| 38 |
+
.reshape(1, self.resolution, self.resolution)
|
| 39 |
+
)
|
| 40 |
+
labels = (
|
| 41 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2])
|
| 42 |
+
.type(torch.float32)
|
| 43 |
+
.reshape(1, self.resolution, self.resolution)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 47 |
+
labels = (labels - self.constants["mean"]) / self.constants["std"]
|
| 48 |
+
|
| 49 |
+
return {
|
| 50 |
+
"pixel_values": inputs,
|
| 51 |
+
"labels": labels,
|
| 52 |
+
"time": time,
|
| 53 |
+
}
|
external/poseidon/scOT/problems/wave/__init__.py
ADDED
|
File without changes
|
external/poseidon/scOT/problems/wave/acoustic.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import h5py
|
| 3 |
+
from scOT.problems.base import BaseTimeDataset
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Layer(BaseTimeDataset):
|
| 7 |
+
def __init__(self, *args, **kwargs):
|
| 8 |
+
super().__init__(*args, **kwargs)
|
| 9 |
+
assert self.max_num_time_steps * self.time_step_size <= 20
|
| 10 |
+
|
| 11 |
+
self.N_max = 10512
|
| 12 |
+
self.N_val = 60
|
| 13 |
+
self.N_test = 240
|
| 14 |
+
self.resolution = 128
|
| 15 |
+
|
| 16 |
+
data_path = self.data_path + "/Wave-Layer.nc"
|
| 17 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 18 |
+
self.reader = h5py.File(data_path, "r")
|
| 19 |
+
|
| 20 |
+
self.constants = {
|
| 21 |
+
"mean": 0.03467443221585092,
|
| 22 |
+
"std": 0.10442421752963911,
|
| 23 |
+
"mean_c": 3498.5644380917424,
|
| 24 |
+
"std_c": 647.843958567462,
|
| 25 |
+
"time": 20.0,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
self.input_dim = 2
|
| 29 |
+
self.label_description = "[u],[c]"
|
| 30 |
+
|
| 31 |
+
self.post_init()
|
| 32 |
+
|
| 33 |
+
def __getitem__(self, idx):
|
| 34 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 35 |
+
time = t / self.constants["time"]
|
| 36 |
+
|
| 37 |
+
inputs = (
|
| 38 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1])
|
| 39 |
+
.type(torch.float32)
|
| 40 |
+
.reshape(1, self.resolution, self.resolution)
|
| 41 |
+
)
|
| 42 |
+
inputs_c = (
|
| 43 |
+
torch.from_numpy(self.reader["c"][i + self.start])
|
| 44 |
+
.type(torch.float32)
|
| 45 |
+
.reshape(1, self.resolution, self.resolution)
|
| 46 |
+
)
|
| 47 |
+
labels = (
|
| 48 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2])
|
| 49 |
+
.type(torch.float32)
|
| 50 |
+
.reshape(1, self.resolution, self.resolution)
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 54 |
+
inputs_c = (inputs_c - self.constants["mean_c"]) / self.constants["std_c"]
|
| 55 |
+
labels = (labels - self.constants["mean"]) / self.constants["std"]
|
| 56 |
+
|
| 57 |
+
inputs = torch.cat([inputs, inputs_c], dim=0)
|
| 58 |
+
labels = torch.cat([labels, inputs_c], dim=0)
|
| 59 |
+
|
| 60 |
+
return {
|
| 61 |
+
"pixel_values": inputs,
|
| 62 |
+
"labels": labels,
|
| 63 |
+
"time": time,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Gaussians(BaseTimeDataset):
|
| 68 |
+
def __init__(self, *args, **kwargs):
|
| 69 |
+
super().__init__(*args, **kwargs)
|
| 70 |
+
assert self.max_num_time_steps * self.time_step_size <= 15
|
| 71 |
+
|
| 72 |
+
self.N_max = 10512
|
| 73 |
+
self.N_val = 60
|
| 74 |
+
self.N_test = 240
|
| 75 |
+
self.resolution = 128
|
| 76 |
+
|
| 77 |
+
data_path = self.data_path + "/Wave-Gauss.nc"
|
| 78 |
+
data_path = self._move_to_local_scratch(data_path)
|
| 79 |
+
self.reader = h5py.File(data_path, "r")
|
| 80 |
+
|
| 81 |
+
self.constants = {
|
| 82 |
+
"mean": 0.0334376316,
|
| 83 |
+
"std": 0.1171879068,
|
| 84 |
+
"mean_c": 2618.4593933,
|
| 85 |
+
"std_c": 601.51658913,
|
| 86 |
+
"time": 15.0,
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
self.input_dim = 2
|
| 90 |
+
self.label_description = "[u],[c]"
|
| 91 |
+
|
| 92 |
+
self.post_init()
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, idx):
|
| 95 |
+
i, t, t1, t2 = self._idx_map(idx)
|
| 96 |
+
time = t / self.constants["time"]
|
| 97 |
+
|
| 98 |
+
inputs = (
|
| 99 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t1])
|
| 100 |
+
.type(torch.float32)
|
| 101 |
+
.reshape(1, self.resolution, self.resolution)
|
| 102 |
+
)
|
| 103 |
+
inputs_c = (
|
| 104 |
+
torch.from_numpy(self.reader["c"][i + self.start])
|
| 105 |
+
.type(torch.float32)
|
| 106 |
+
.reshape(1, self.resolution, self.resolution)
|
| 107 |
+
)
|
| 108 |
+
labels = (
|
| 109 |
+
torch.from_numpy(self.reader["solution"][i + self.start, t2])
|
| 110 |
+
.type(torch.float32)
|
| 111 |
+
.reshape(1, self.resolution, self.resolution)
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
inputs = (inputs - self.constants["mean"]) / self.constants["std"]
|
| 115 |
+
inputs_c = (inputs_c - self.constants["mean_c"]) / self.constants["std_c"]
|
| 116 |
+
labels = (labels - self.constants["mean"]) / self.constants["std"]
|
| 117 |
+
|
| 118 |
+
inputs = torch.cat([inputs, inputs_c], dim=0)
|
| 119 |
+
labels = torch.cat([labels, inputs_c], dim=0)
|
| 120 |
+
|
| 121 |
+
return {
|
| 122 |
+
"pixel_values": inputs,
|
| 123 |
+
"labels": labels,
|
| 124 |
+
"time": time,
|
| 125 |
+
}
|
external/poseidon/scOT/train.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script trains a scOT or pretrains Poseidon on a PDE dataset.
|
| 3 |
+
Can be also used for finetuning Poseidon.
|
| 4 |
+
Can be used in a single config or sweep setup.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import torch
|
| 9 |
+
import wandb
|
| 10 |
+
import numpy as np
|
| 11 |
+
import random
|
| 12 |
+
import json
|
| 13 |
+
import psutil
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
| 17 |
+
import yaml
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import transformers
|
| 20 |
+
from accelerate.utils import broadcast_object_list
|
| 21 |
+
from scOT.trainer import TrainingArguments, Trainer
|
| 22 |
+
from transformers import EarlyStoppingCallback
|
| 23 |
+
from scOT.model import ScOT, ScOTConfig
|
| 24 |
+
from mpl_toolkits.axes_grid1 import ImageGrid
|
| 25 |
+
from scOT.problems.base import get_dataset, BaseTimeDataset
|
| 26 |
+
from scOT.utils import get_num_parameters, read_cli, get_num_parameters_no_embed
|
| 27 |
+
from scOT.metrics import relative_lp_error
|
| 28 |
+
|
| 29 |
+
SEED = 0
|
| 30 |
+
torch.manual_seed(SEED)
|
| 31 |
+
np.random.seed(SEED)
|
| 32 |
+
random.seed(SEED)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
MODEL_MAP = {
|
| 36 |
+
"T": {
|
| 37 |
+
"num_heads": [3, 6, 12, 24],
|
| 38 |
+
"skip_connections": [2, 2, 2, 0],
|
| 39 |
+
"window_size": 16,
|
| 40 |
+
"patch_size": 4,
|
| 41 |
+
"mlp_ratio": 4.0,
|
| 42 |
+
"depths": [4, 4, 4, 4],
|
| 43 |
+
"embed_dim": 48,
|
| 44 |
+
},
|
| 45 |
+
"S": {
|
| 46 |
+
"num_heads": [3, 6, 12, 24],
|
| 47 |
+
"skip_connections": [2, 2, 2, 0],
|
| 48 |
+
"window_size": 16,
|
| 49 |
+
"patch_size": 4,
|
| 50 |
+
"mlp_ratio": 4.0,
|
| 51 |
+
"depths": [8, 8, 8, 8],
|
| 52 |
+
"embed_dim": 48,
|
| 53 |
+
},
|
| 54 |
+
"B": {
|
| 55 |
+
"num_heads": [3, 6, 12, 24],
|
| 56 |
+
"skip_connections": [2, 2, 2, 0],
|
| 57 |
+
"window_size": 16,
|
| 58 |
+
"patch_size": 4,
|
| 59 |
+
"mlp_ratio": 4.0,
|
| 60 |
+
"depths": [8, 8, 8, 8],
|
| 61 |
+
"embed_dim": 96,
|
| 62 |
+
},
|
| 63 |
+
"L": {
|
| 64 |
+
"num_heads": [3, 6, 12, 24],
|
| 65 |
+
"skip_connections": [2, 2, 2, 0],
|
| 66 |
+
"window_size": 16,
|
| 67 |
+
"patch_size": 4,
|
| 68 |
+
"mlp_ratio": 4.0,
|
| 69 |
+
"depths": [8, 8, 8, 8],
|
| 70 |
+
"embed_dim": 192,
|
| 71 |
+
},
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def create_predictions_plot(predictions, labels, wandb_prefix):
|
| 76 |
+
assert predictions.shape[0] >= 4
|
| 77 |
+
|
| 78 |
+
indices = random.sample(range(predictions.shape[0]), 4)
|
| 79 |
+
|
| 80 |
+
predictions = predictions[indices]
|
| 81 |
+
labels = labels[indices]
|
| 82 |
+
|
| 83 |
+
fig = plt.figure()
|
| 84 |
+
grid = ImageGrid(
|
| 85 |
+
fig, 111, nrows_ncols=(predictions.shape[1] + labels.shape[1], 4), axes_pad=0.1
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
vmax, vmin = max(predictions.max(), labels.max()), min(
|
| 89 |
+
predictions.min(), labels.min()
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
for _i, ax in enumerate(grid):
|
| 93 |
+
i = _i // 4
|
| 94 |
+
j = _i % 4
|
| 95 |
+
|
| 96 |
+
if i % 2 == 0:
|
| 97 |
+
ax.imshow(
|
| 98 |
+
predictions[j, i // 2, :, :],
|
| 99 |
+
cmap="gist_ncar",
|
| 100 |
+
origin="lower",
|
| 101 |
+
vmin=vmin,
|
| 102 |
+
vmax=vmax,
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
ax.imshow(
|
| 106 |
+
labels[j, i // 2, :, :],
|
| 107 |
+
cmap="gist_ncar",
|
| 108 |
+
origin="lower",
|
| 109 |
+
vmin=vmin,
|
| 110 |
+
vmax=vmax,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
ax.set_xticks([])
|
| 114 |
+
ax.set_yticks([])
|
| 115 |
+
|
| 116 |
+
wandb.log({wandb_prefix + "/predictions": wandb.Image(fig)})
|
| 117 |
+
plt.close()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def setup(params, model_map=True):
|
| 121 |
+
config = None
|
| 122 |
+
RANK = int(os.environ.get("LOCAL_RANK", -1))
|
| 123 |
+
CPU_CORES = len(psutil.Process().cpu_affinity())
|
| 124 |
+
CPU_CORES = min(CPU_CORES, 16)
|
| 125 |
+
print(f"Detected {CPU_CORES} CPU cores, will use {CPU_CORES} workers.")
|
| 126 |
+
if params.disable_tqdm:
|
| 127 |
+
transformers.utils.logging.disable_progress_bar()
|
| 128 |
+
if params.json_config:
|
| 129 |
+
config = json.loads(params.config)
|
| 130 |
+
else:
|
| 131 |
+
config = params.config
|
| 132 |
+
|
| 133 |
+
if RANK == 0 or RANK == -1:
|
| 134 |
+
run = wandb.init(
|
| 135 |
+
project=params.wandb_project_name, name=params.wandb_run_name, config=config
|
| 136 |
+
)
|
| 137 |
+
config = wandb.config
|
| 138 |
+
else:
|
| 139 |
+
|
| 140 |
+
def clean_yaml(config):
|
| 141 |
+
d = {}
|
| 142 |
+
for key, inner_dict in config.items():
|
| 143 |
+
d[key] = inner_dict["value"]
|
| 144 |
+
return d
|
| 145 |
+
|
| 146 |
+
if not params.json_config:
|
| 147 |
+
with open(params.config, "r") as s:
|
| 148 |
+
config = yaml.safe_load(s)
|
| 149 |
+
config = clean_yaml(config)
|
| 150 |
+
run = None
|
| 151 |
+
|
| 152 |
+
ckpt_dir = "./"
|
| 153 |
+
if RANK == 0 or RANK == -1:
|
| 154 |
+
if run.sweep_id is not None:
|
| 155 |
+
ckpt_dir = (
|
| 156 |
+
params.checkpoint_path
|
| 157 |
+
+ "/"
|
| 158 |
+
+ run.project
|
| 159 |
+
+ "/"
|
| 160 |
+
+ run.sweep_id
|
| 161 |
+
+ "/"
|
| 162 |
+
+ run.name
|
| 163 |
+
)
|
| 164 |
+
else:
|
| 165 |
+
ckpt_dir = params.checkpoint_path + "/" + run.project + "/" + run.name
|
| 166 |
+
if (RANK == 0 or RANK == -1) and not os.path.exists(ckpt_dir):
|
| 167 |
+
os.makedirs(ckpt_dir)
|
| 168 |
+
ls = broadcast_object_list([ckpt_dir], from_process=0)
|
| 169 |
+
ckpt_dir = ls[0]
|
| 170 |
+
|
| 171 |
+
if model_map and (
|
| 172 |
+
type(config["model_name"]) == str and config["model_name"] in MODEL_MAP.keys()
|
| 173 |
+
):
|
| 174 |
+
config = {**config, **MODEL_MAP[config["model_name"]]}
|
| 175 |
+
if RANK == 0 or RANK == -1:
|
| 176 |
+
wandb.config.update(MODEL_MAP[config["model_name"]], allow_val_change=True)
|
| 177 |
+
|
| 178 |
+
return run, config, ckpt_dir, RANK, CPU_CORES
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
parser = argparse.ArgumentParser(description="Train scOT or pretrain Poseidon.")
|
| 183 |
+
parser.add_argument("--resume_training", action="store_true")
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--finetune_from",
|
| 186 |
+
type=str,
|
| 187 |
+
default=None,
|
| 188 |
+
help="Set this to a str pointing to a HF Hub model checkpoint or a directory with a scOT checkpoint if you want to finetune.",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--replace_embedding_recovery",
|
| 192 |
+
action="store_true",
|
| 193 |
+
help="Set this if you have to replace the embeddings and recovery layers because you are not just using the density, velocity and pressure channels. Only relevant for finetuning.",
|
| 194 |
+
)
|
| 195 |
+
params = read_cli(parser).parse_args()
|
| 196 |
+
run, config, ckpt_dir, RANK, CPU_CORES = setup(params)
|
| 197 |
+
|
| 198 |
+
train_eval_set_kwargs = (
|
| 199 |
+
{"just_velocities": True}
|
| 200 |
+
if ("incompressible" in config["dataset"]) and params.just_velocities
|
| 201 |
+
else {}
|
| 202 |
+
)
|
| 203 |
+
if params.move_data is not None:
|
| 204 |
+
train_eval_set_kwargs["move_to_local_scratch"] = params.move_data
|
| 205 |
+
if params.max_num_train_time_steps is not None:
|
| 206 |
+
train_eval_set_kwargs["max_num_time_steps"] = params.max_num_train_time_steps
|
| 207 |
+
if params.train_time_step_size is not None:
|
| 208 |
+
train_eval_set_kwargs["time_step_size"] = params.train_time_step_size
|
| 209 |
+
if params.train_small_time_transition:
|
| 210 |
+
train_eval_set_kwargs["allowed_time_transitions"] = [1]
|
| 211 |
+
train_dataset = get_dataset(
|
| 212 |
+
dataset=config["dataset"],
|
| 213 |
+
which="train",
|
| 214 |
+
num_trajectories=config["num_trajectories"],
|
| 215 |
+
data_path=params.data_path,
|
| 216 |
+
**train_eval_set_kwargs,
|
| 217 |
+
)
|
| 218 |
+
eval_dataset = get_dataset(
|
| 219 |
+
dataset=config["dataset"],
|
| 220 |
+
which="val",
|
| 221 |
+
num_trajectories=config["num_trajectories"],
|
| 222 |
+
data_path=params.data_path,
|
| 223 |
+
**train_eval_set_kwargs,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
config["effective_train_set_size"] = len(train_dataset)
|
| 227 |
+
time_involved = isinstance(train_dataset, BaseTimeDataset) or (
|
| 228 |
+
isinstance(train_dataset, torch.utils.data.ConcatDataset)
|
| 229 |
+
and isinstance(train_dataset.datasets[0], BaseTimeDataset)
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
if not isinstance(train_dataset, torch.utils.data.ConcatDataset):
|
| 233 |
+
resolution = train_dataset.resolution
|
| 234 |
+
input_dim = train_dataset.input_dim
|
| 235 |
+
output_dim = train_dataset.output_dim
|
| 236 |
+
channel_slice_list = train_dataset.channel_slice_list
|
| 237 |
+
printable_channel_description = train_dataset.printable_channel_description
|
| 238 |
+
else:
|
| 239 |
+
resolution = train_dataset.datasets[0].resolution
|
| 240 |
+
input_dim = train_dataset.datasets[0].input_dim
|
| 241 |
+
output_dim = train_dataset.datasets[0].output_dim
|
| 242 |
+
channel_slice_list = train_dataset.datasets[0].channel_slice_list
|
| 243 |
+
printable_channel_description = train_dataset.datasets[
|
| 244 |
+
0
|
| 245 |
+
].printable_channel_description
|
| 246 |
+
|
| 247 |
+
model_config = (
|
| 248 |
+
ScOTConfig(
|
| 249 |
+
image_size=resolution,
|
| 250 |
+
patch_size=config["patch_size"],
|
| 251 |
+
num_channels=input_dim,
|
| 252 |
+
num_out_channels=output_dim,
|
| 253 |
+
embed_dim=config["embed_dim"],
|
| 254 |
+
depths=config["depths"],
|
| 255 |
+
num_heads=config["num_heads"],
|
| 256 |
+
skip_connections=config["skip_connections"],
|
| 257 |
+
window_size=config["window_size"],
|
| 258 |
+
mlp_ratio=config["mlp_ratio"],
|
| 259 |
+
qkv_bias=True,
|
| 260 |
+
hidden_dropout_prob=0.0, # default
|
| 261 |
+
attention_probs_dropout_prob=0.0, # default
|
| 262 |
+
drop_path_rate=0.0,
|
| 263 |
+
hidden_act="gelu",
|
| 264 |
+
use_absolute_embeddings=False,
|
| 265 |
+
initializer_range=0.02,
|
| 266 |
+
layer_norm_eps=1e-5,
|
| 267 |
+
p=1,
|
| 268 |
+
channel_slice_list_normalized_loss=channel_slice_list,
|
| 269 |
+
residual_model="convnext",
|
| 270 |
+
use_conditioning=time_involved,
|
| 271 |
+
learn_residual=False,
|
| 272 |
+
)
|
| 273 |
+
if params.finetune_from is None or params.replace_embedding_recovery
|
| 274 |
+
else None
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
train_config = TrainingArguments(
|
| 278 |
+
output_dir=ckpt_dir,
|
| 279 |
+
overwrite_output_dir=True, #! OVERWRITE THIS DIRECTORY IN CASE, also for resuming training
|
| 280 |
+
evaluation_strategy="epoch",
|
| 281 |
+
per_device_train_batch_size=config["batch_size"],
|
| 282 |
+
per_device_eval_batch_size=config["batch_size"],
|
| 283 |
+
eval_accumulation_steps=16,
|
| 284 |
+
max_grad_norm=config["max_grad_norm"],
|
| 285 |
+
num_train_epochs=config["num_epochs"],
|
| 286 |
+
optim="adamw_torch",
|
| 287 |
+
learning_rate=config["lr"],
|
| 288 |
+
learning_rate_embedding_recovery=(
|
| 289 |
+
None
|
| 290 |
+
if (params.finetune_from is None or "lr_embedding_recovery" not in config)
|
| 291 |
+
else config["lr_embedding_recovery"]
|
| 292 |
+
),
|
| 293 |
+
learning_rate_time_embedding=(
|
| 294 |
+
None
|
| 295 |
+
if (params.finetune_from is None or "lr_time_embedding" not in config)
|
| 296 |
+
else config["lr_time_embedding"]
|
| 297 |
+
),
|
| 298 |
+
weight_decay=config["weight_decay"],
|
| 299 |
+
adam_beta1=0.9, # default
|
| 300 |
+
adam_beta2=0.999, # default
|
| 301 |
+
adam_epsilon=1e-8, # default
|
| 302 |
+
lr_scheduler_type=config["lr_scheduler"],
|
| 303 |
+
warmup_ratio=config["warmup_ratio"],
|
| 304 |
+
log_level="passive",
|
| 305 |
+
logging_strategy="steps",
|
| 306 |
+
logging_steps=5,
|
| 307 |
+
logging_nan_inf_filter=False,
|
| 308 |
+
save_strategy="epoch",
|
| 309 |
+
save_total_limit=1,
|
| 310 |
+
seed=SEED,
|
| 311 |
+
fp16=False,
|
| 312 |
+
dataloader_num_workers=CPU_CORES,
|
| 313 |
+
load_best_model_at_end=True,
|
| 314 |
+
metric_for_best_model="loss",
|
| 315 |
+
greater_is_better=False,
|
| 316 |
+
dataloader_pin_memory=True,
|
| 317 |
+
gradient_checkpointing=False,
|
| 318 |
+
auto_find_batch_size=False,
|
| 319 |
+
full_determinism=False,
|
| 320 |
+
torch_compile=False,
|
| 321 |
+
report_to="wandb",
|
| 322 |
+
run_name=params.wandb_run_name,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
early_stopping = EarlyStoppingCallback(
|
| 326 |
+
early_stopping_patience=config["early_stopping_patience"],
|
| 327 |
+
early_stopping_threshold=0.0, # set no threshold for now
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if params.finetune_from is not None:
|
| 331 |
+
model = ScOT.from_pretrained(
|
| 332 |
+
params.finetune_from, config=model_config, ignore_mismatched_sizes=True
|
| 333 |
+
)
|
| 334 |
+
else:
|
| 335 |
+
model = ScOT(model_config)
|
| 336 |
+
num_params = get_num_parameters(model)
|
| 337 |
+
config["num_params"] = num_params
|
| 338 |
+
num_params_no_embed = get_num_parameters_no_embed(model)
|
| 339 |
+
config["num_params_wout_embed"] = num_params_no_embed
|
| 340 |
+
if RANK == 0 or RANK == -1:
|
| 341 |
+
print(f"Model size: {num_params}")
|
| 342 |
+
print(f"Model size without embeddings: {num_params_no_embed}")
|
| 343 |
+
|
| 344 |
+
def compute_metrics(eval_preds):
|
| 345 |
+
channel_list = channel_slice_list
|
| 346 |
+
|
| 347 |
+
def get_statistics(errors):
|
| 348 |
+
median_error = np.median(errors, axis=0)
|
| 349 |
+
mean_error = np.mean(errors, axis=0)
|
| 350 |
+
std_error = np.std(errors, axis=0)
|
| 351 |
+
min_error = np.min(errors, axis=0)
|
| 352 |
+
max_error = np.max(errors, axis=0)
|
| 353 |
+
return {
|
| 354 |
+
"median_relative_l1_error": median_error,
|
| 355 |
+
"mean_relative_l1_error": mean_error,
|
| 356 |
+
"std_relative_l1_error": std_error,
|
| 357 |
+
"min_relative_l1_error": min_error,
|
| 358 |
+
"max_relative_l1_error": max_error,
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
error_statistics = [
|
| 362 |
+
get_statistics(
|
| 363 |
+
relative_lp_error(
|
| 364 |
+
eval_preds.predictions[:, channel_list[i] : channel_list[i + 1]],
|
| 365 |
+
eval_preds.label_ids[:, channel_list[i] : channel_list[i + 1]],
|
| 366 |
+
p=1,
|
| 367 |
+
return_percent=True,
|
| 368 |
+
)
|
| 369 |
+
)
|
| 370 |
+
for i in range(len(channel_list) - 1)
|
| 371 |
+
]
|
| 372 |
+
|
| 373 |
+
if output_dim == 1:
|
| 374 |
+
error_statistics = error_statistics[0]
|
| 375 |
+
return error_statistics
|
| 376 |
+
else:
|
| 377 |
+
mean_over_means = np.mean(
|
| 378 |
+
np.array(
|
| 379 |
+
[stats["mean_relative_l1_error"] for stats in error_statistics]
|
| 380 |
+
),
|
| 381 |
+
axis=0,
|
| 382 |
+
)
|
| 383 |
+
mean_over_medians = np.mean(
|
| 384 |
+
np.array(
|
| 385 |
+
[stats["median_relative_l1_error"] for stats in error_statistics]
|
| 386 |
+
),
|
| 387 |
+
axis=0,
|
| 388 |
+
)
|
| 389 |
+
error_statistics_ = {
|
| 390 |
+
"mean_relative_l1_error": mean_over_means,
|
| 391 |
+
"mean_over_median_relative_l1_error": mean_over_medians,
|
| 392 |
+
}
|
| 393 |
+
for i, stats in enumerate(error_statistics):
|
| 394 |
+
for key, value in stats.items():
|
| 395 |
+
error_statistics_[printable_channel_description[i] + "/" + key] = (
|
| 396 |
+
value
|
| 397 |
+
)
|
| 398 |
+
return error_statistics_
|
| 399 |
+
|
| 400 |
+
trainer = Trainer(
|
| 401 |
+
model=model,
|
| 402 |
+
args=train_config,
|
| 403 |
+
train_dataset=train_dataset,
|
| 404 |
+
eval_dataset=eval_dataset,
|
| 405 |
+
compute_metrics=compute_metrics,
|
| 406 |
+
callbacks=[early_stopping],
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
trainer.train(resume_from_checkpoint=params.resume_training)
|
| 410 |
+
trainer.save_model(train_config.output_dir)
|
| 411 |
+
|
| 412 |
+
if (RANK == 0 or RANK == -1) and params.push_to_hf_hub is not None:
|
| 413 |
+
model.push_to_hub(params.push_to_hf_hub)
|
| 414 |
+
|
| 415 |
+
do_test = (
|
| 416 |
+
True
|
| 417 |
+
if params.max_num_train_time_steps is None
|
| 418 |
+
and params.train_time_step_size is None
|
| 419 |
+
and not params.train_small_time_transition
|
| 420 |
+
and not ".time" in config["dataset"]
|
| 421 |
+
else False
|
| 422 |
+
)
|
| 423 |
+
if do_test:
|
| 424 |
+
print("Testing...")
|
| 425 |
+
test_set_kwargs = (
|
| 426 |
+
{"just_velocities": True}
|
| 427 |
+
if ("incompressible" in config["dataset"]) and params.just_velocities
|
| 428 |
+
else {}
|
| 429 |
+
)
|
| 430 |
+
out_test_set_kwargs = (
|
| 431 |
+
{"just_velocities": True}
|
| 432 |
+
if ("incompressible" in config["dataset"]) and params.just_velocities
|
| 433 |
+
else {}
|
| 434 |
+
)
|
| 435 |
+
if params.move_data is not None:
|
| 436 |
+
test_set_kwargs["move_to_local_scratch"] = params.move_data
|
| 437 |
+
out_test_set_kwargs["move_to_local_scratch"] = params.move_data
|
| 438 |
+
if time_involved:
|
| 439 |
+
test_set_kwargs = {
|
| 440 |
+
**test_set_kwargs,
|
| 441 |
+
"max_num_time_steps": 1,
|
| 442 |
+
"time_step_size": 14,
|
| 443 |
+
"allowed_time_transitions": [1],
|
| 444 |
+
}
|
| 445 |
+
out_test_set_kwargs = {
|
| 446 |
+
**out_test_set_kwargs,
|
| 447 |
+
"max_num_time_steps": 1,
|
| 448 |
+
"time_step_size": 20,
|
| 449 |
+
"allowed_time_transitions": [1],
|
| 450 |
+
}
|
| 451 |
+
if "RayleighTaylor" in config["dataset"]:
|
| 452 |
+
test_set_kwargs = {
|
| 453 |
+
**test_set_kwargs,
|
| 454 |
+
"max_num_time_steps": 1,
|
| 455 |
+
"time_step_size": 7,
|
| 456 |
+
"allowed_time_transitions": [1],
|
| 457 |
+
}
|
| 458 |
+
out_test_set_kwargs = {
|
| 459 |
+
**out_test_set_kwargs,
|
| 460 |
+
"max_num_time_steps": 1,
|
| 461 |
+
"time_step_size": 10,
|
| 462 |
+
"allowed_time_transitions": [1],
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
test_dataset = get_dataset(
|
| 466 |
+
dataset=config["dataset"],
|
| 467 |
+
which="test",
|
| 468 |
+
num_trajectories=config["num_trajectories"],
|
| 469 |
+
data_path=params.data_path,
|
| 470 |
+
**test_set_kwargs,
|
| 471 |
+
)
|
| 472 |
+
try:
|
| 473 |
+
out_dist_test_dataset = get_dataset(
|
| 474 |
+
dataset=config["dataset"] + ".out",
|
| 475 |
+
which="test",
|
| 476 |
+
num_trajectories=config["num_trajectories"],
|
| 477 |
+
data_path=params.data_path,
|
| 478 |
+
**out_test_set_kwargs,
|
| 479 |
+
)
|
| 480 |
+
except:
|
| 481 |
+
out_dist_test_dataset = None
|
| 482 |
+
predictions = trainer.predict(test_dataset, metric_key_prefix="")
|
| 483 |
+
if RANK == 0 or RANK == -1:
|
| 484 |
+
metrics = {}
|
| 485 |
+
for key, value in predictions.metrics.items():
|
| 486 |
+
metrics["test/" + key[1:]] = value
|
| 487 |
+
wandb.log(metrics)
|
| 488 |
+
create_predictions_plot(
|
| 489 |
+
predictions.predictions,
|
| 490 |
+
predictions.label_ids,
|
| 491 |
+
wandb_prefix="test",
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# evaluate on out-of-distribution test set
|
| 495 |
+
if out_dist_test_dataset is not None:
|
| 496 |
+
predictions = trainer.predict(out_dist_test_dataset, metric_key_prefix="")
|
| 497 |
+
if RANK == 0 or RANK == -1:
|
| 498 |
+
metrics = {}
|
| 499 |
+
for key, value in predictions.metrics.items():
|
| 500 |
+
metrics["test_out_dist/" + key[1:]] = value
|
| 501 |
+
wandb.log(metrics)
|
| 502 |
+
create_predictions_plot(
|
| 503 |
+
predictions.predictions,
|
| 504 |
+
predictions.label_ids,
|
| 505 |
+
wandb_prefix="test_out_dist",
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
if time_involved and (test_set_kwargs["time_step_size"] // 2 > 0):
|
| 509 |
+
trainer.set_ar_steps(test_set_kwargs["time_step_size"] // 2)
|
| 510 |
+
predictions = trainer.predict(test_dataset, metric_key_prefix="")
|
| 511 |
+
if RANK == 0 or RANK == -1:
|
| 512 |
+
metrics = {}
|
| 513 |
+
for key, value in predictions.metrics.items():
|
| 514 |
+
metrics["test/ar/" + key[1:]] = value
|
| 515 |
+
wandb.log(metrics)
|
| 516 |
+
create_predictions_plot(
|
| 517 |
+
predictions.predictions,
|
| 518 |
+
predictions.label_ids,
|
| 519 |
+
wandb_prefix="test/ar",
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# evaluate on out-of-distribution test set
|
| 523 |
+
if out_dist_test_dataset is not None:
|
| 524 |
+
trainer.set_ar_steps(out_test_set_kwargs["time_step_size"] // 2)
|
| 525 |
+
predictions = trainer.predict(
|
| 526 |
+
out_dist_test_dataset, metric_key_prefix=""
|
| 527 |
+
)
|
| 528 |
+
if RANK == 0 or RANK == -1:
|
| 529 |
+
metrics = {}
|
| 530 |
+
for key, value in predictions.metrics.items():
|
| 531 |
+
metrics["test_out_dist/ar/" + key[1:]] = value
|
| 532 |
+
wandb.log(metrics)
|
| 533 |
+
create_predictions_plot(
|
| 534 |
+
predictions.predictions,
|
| 535 |
+
predictions.label_ids,
|
| 536 |
+
wandb_prefix="test_out_dist/ar",
|
| 537 |
+
)
|
external/poseidon/scOT/trainer.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Our version of the Huggingface Trainer class.
|
| 3 |
+
It adds learning_rate_time_embedding, learning_rate_embedding_recovery as
|
| 4 |
+
additional learning rates and groups parameters for the optimizer.
|
| 5 |
+
It also allows for autoregressive rollouts by using
|
| 6 |
+
trainer.set_ar_steps(AR_STEPS) where AR_STEPS is either a an integer for a
|
| 7 |
+
homogeneous rollout of AR_STEPS steps or a list of integers for a heterogeneous
|
| 8 |
+
rollout where each element is the timestep.
|
| 9 |
+
If, additionally, output_all_steps is also set, the predict function will
|
| 10 |
+
output all intermediate steps as well.
|
| 11 |
+
|
| 12 |
+
We sublass a Huggingface Trainer to allow for autoregressive rollouts and multiple parameter groups in the optimizer.
|
| 13 |
+
It is specifically subclassed for our purpose.
|
| 14 |
+
|
| 15 |
+
A lot of code is copied over because only slight changes have been made.
|
| 16 |
+
|
| 17 |
+
The original code of Huggingface Transformers is distributed under the Apache 2.0 license. See below:
|
| 18 |
+
|
| 19 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
| 20 |
+
|
| 21 |
+
Apache License
|
| 22 |
+
Version 2.0, January 2004
|
| 23 |
+
http://www.apache.org/licenses/
|
| 24 |
+
|
| 25 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 26 |
+
|
| 27 |
+
1. Definitions.
|
| 28 |
+
|
| 29 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 30 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 31 |
+
|
| 32 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 33 |
+
the copyright owner that is granting the License.
|
| 34 |
+
|
| 35 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 36 |
+
other entities that control, are controlled by, or are under common
|
| 37 |
+
control with that entity. For the purposes of this definition,
|
| 38 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 39 |
+
direction or management of such entity, whether by contract or
|
| 40 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 41 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 42 |
+
|
| 43 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 44 |
+
exercising permissions granted by this License.
|
| 45 |
+
|
| 46 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 47 |
+
including but not limited to software source code, documentation
|
| 48 |
+
source, and configuration files.
|
| 49 |
+
|
| 50 |
+
"Object" form shall mean any form resulting from mechanical
|
| 51 |
+
transformation or translation of a Source form, including but
|
| 52 |
+
not limited to compiled object code, generated documentation,
|
| 53 |
+
and conversions to other media types.
|
| 54 |
+
|
| 55 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 56 |
+
Object form, made available under the License, as indicated by a
|
| 57 |
+
copyright notice that is included in or attached to the work
|
| 58 |
+
(an example is provided in the Appendix below).
|
| 59 |
+
|
| 60 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 61 |
+
form, that is based on (or derived from) the Work and for which the
|
| 62 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 63 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 64 |
+
of this License, Derivative Works shall not include works that remain
|
| 65 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 66 |
+
the Work and Derivative Works thereof.
|
| 67 |
+
|
| 68 |
+
"Contribution" shall mean any work of authorship, including
|
| 69 |
+
the original version of the Work and any modifications or additions
|
| 70 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 71 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 72 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 73 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 74 |
+
means any form of electronic, verbal, or written communication sent
|
| 75 |
+
to the Licensor or its representatives, including but not limited to
|
| 76 |
+
communication on electronic mailing lists, source code control systems,
|
| 77 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 78 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 79 |
+
excluding communication that is conspicuously marked or otherwise
|
| 80 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 81 |
+
|
| 82 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 83 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 84 |
+
subsequently incorporated within the Work.
|
| 85 |
+
|
| 86 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 87 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 88 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 89 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 90 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 91 |
+
Work and such Derivative Works in Source or Object form.
|
| 92 |
+
|
| 93 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 94 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 95 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 96 |
+
(except as stated in this section) patent license to make, have made,
|
| 97 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 98 |
+
where such license applies only to those patent claims licensable
|
| 99 |
+
by such Contributor that are necessarily infringed by their
|
| 100 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 101 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 102 |
+
institute patent litigation against any entity (including a
|
| 103 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 104 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 105 |
+
or contributory patent infringement, then any patent licenses
|
| 106 |
+
granted to You under this License for that Work shall terminate
|
| 107 |
+
as of the date such litigation is filed.
|
| 108 |
+
|
| 109 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 110 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 111 |
+
modifications, and in Source or Object form, provided that You
|
| 112 |
+
meet the following conditions:
|
| 113 |
+
|
| 114 |
+
(a) You must give any other recipients of the Work or
|
| 115 |
+
Derivative Works a copy of this License; and
|
| 116 |
+
|
| 117 |
+
(b) You must cause any modified files to carry prominent notices
|
| 118 |
+
stating that You changed the files; and
|
| 119 |
+
|
| 120 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 121 |
+
that You distribute, all copyright, patent, trademark, and
|
| 122 |
+
attribution notices from the Source form of the Work,
|
| 123 |
+
excluding those notices that do not pertain to any part of
|
| 124 |
+
the Derivative Works; and
|
| 125 |
+
|
| 126 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 127 |
+
distribution, then any Derivative Works that You distribute must
|
| 128 |
+
include a readable copy of the attribution notices contained
|
| 129 |
+
within such NOTICE file, excluding those notices that do not
|
| 130 |
+
pertain to any part of the Derivative Works, in at least one
|
| 131 |
+
of the following places: within a NOTICE text file distributed
|
| 132 |
+
as part of the Derivative Works; within the Source form or
|
| 133 |
+
documentation, if provided along with the Derivative Works; or,
|
| 134 |
+
within a display generated by the Derivative Works, if and
|
| 135 |
+
wherever such third-party notices normally appear. The contents
|
| 136 |
+
of the NOTICE file are for informational purposes only and
|
| 137 |
+
do not modify the License. You may add Your own attribution
|
| 138 |
+
notices within Derivative Works that You distribute, alongside
|
| 139 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 140 |
+
that such additional attribution notices cannot be construed
|
| 141 |
+
as modifying the License.
|
| 142 |
+
|
| 143 |
+
You may add Your own copyright statement to Your modifications and
|
| 144 |
+
may provide additional or different license terms and conditions
|
| 145 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 146 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 147 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 148 |
+
the conditions stated in this License.
|
| 149 |
+
|
| 150 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 151 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 152 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 153 |
+
this License, without any additional terms or conditions.
|
| 154 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 155 |
+
the terms of any separate license agreement you may have executed
|
| 156 |
+
with Licensor regarding such Contributions.
|
| 157 |
+
|
| 158 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 159 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 160 |
+
except as required for reasonable and customary use in describing the
|
| 161 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 162 |
+
|
| 163 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 164 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 165 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 166 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 167 |
+
implied, including, without limitation, any warranties or conditions
|
| 168 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 169 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 170 |
+
appropriateness of using or redistributing the Work and assume any
|
| 171 |
+
risks associated with Your exercise of permissions under this License.
|
| 172 |
+
|
| 173 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 174 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 175 |
+
unless required by applicable law (such as deliberate and grossly
|
| 176 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 177 |
+
liable to You for damages, including any direct, indirect, special,
|
| 178 |
+
incidental, or consequential damages of any character arising as a
|
| 179 |
+
result of this License or out of the use or inability to use the
|
| 180 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 181 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 182 |
+
other commercial damages or losses), even if such Contributor
|
| 183 |
+
has been advised of the possibility of such damages.
|
| 184 |
+
|
| 185 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 186 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 187 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 188 |
+
or other liability obligations and/or rights consistent with this
|
| 189 |
+
License. However, in accepting such obligations, You may act only
|
| 190 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 191 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 192 |
+
defend, and hold each Contributor harmless for any liability
|
| 193 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 194 |
+
of your accepting any such warranty or additional liability.
|
| 195 |
+
|
| 196 |
+
END OF TERMS AND CONDITIONS
|
| 197 |
+
|
| 198 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 199 |
+
|
| 200 |
+
To apply the Apache License to your work, attach the following
|
| 201 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 202 |
+
replaced with your own identifying information. (Don't include
|
| 203 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 204 |
+
comment syntax for the file format. We also recommend that a
|
| 205 |
+
file or class name and description of purpose be included on the
|
| 206 |
+
same "printed page" as the copyright notice for easier
|
| 207 |
+
identification within third-party archives.
|
| 208 |
+
|
| 209 |
+
Copyright [yyyy] [name of copyright owner]
|
| 210 |
+
|
| 211 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 212 |
+
you may not use this file except in compliance with the License.
|
| 213 |
+
You may obtain a copy of the License at
|
| 214 |
+
|
| 215 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 216 |
+
|
| 217 |
+
Unless required by applicable law or agreed to in writing, software
|
| 218 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 219 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 220 |
+
See the License for the specific language governing permissions and
|
| 221 |
+
limitations under the License.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
import torch
|
| 225 |
+
from torch import nn
|
| 226 |
+
from typing import List, Optional, Dict, Tuple, Union, Any
|
| 227 |
+
from transformers.trainer import *
|
| 228 |
+
from transformers import Trainer as Trainer_
|
| 229 |
+
from transformers import TrainingArguments as TrainingArguments_
|
| 230 |
+
from scOT.model import LayerNorm, ConditionalLayerNorm
|
| 231 |
+
from dataclasses import dataclass, field
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@dataclass
|
| 235 |
+
class TrainingArguments(TrainingArguments_):
|
| 236 |
+
learning_rate_embedding_recovery: Optional[float] = field(
|
| 237 |
+
default=None,
|
| 238 |
+
metadata={
|
| 239 |
+
"help": "The initial learning rate for the embedding/recovery. When not provided, falls back to `learning_rate`."
|
| 240 |
+
},
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
learning_rate_time_embedding: Optional[float] = field(
|
| 244 |
+
default=None,
|
| 245 |
+
metadata={
|
| 246 |
+
"help": "The initial learning rate for the time embedding. When not provided, falls back to `learning_rate`. Only used when embedding and recovery are also fine-tuned with different lr."
|
| 247 |
+
},
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
def set_training(
|
| 251 |
+
self,
|
| 252 |
+
*args,
|
| 253 |
+
learning_rate_embedding_recovery: Optional[float] = None,
|
| 254 |
+
learning_rate_time_embedding: Optional[float] = None,
|
| 255 |
+
**kwargs,
|
| 256 |
+
):
|
| 257 |
+
self = super().set_training(*args, **kwargs)
|
| 258 |
+
self.learning_rate_embedding_recovery = learning_rate_embedding_recovery
|
| 259 |
+
self.learning_rate_time_embedding = learning_rate_time_embedding
|
| 260 |
+
return self
|
| 261 |
+
|
| 262 |
+
def set_optimizer(
|
| 263 |
+
self,
|
| 264 |
+
*args,
|
| 265 |
+
learning_rate_embedding_recovery: Optional[float] = None,
|
| 266 |
+
learning_rate_time_embedding: Optional[float] = None,
|
| 267 |
+
**kwargs,
|
| 268 |
+
):
|
| 269 |
+
self = super().set_optimizer(*args, **kwargs)
|
| 270 |
+
self.learning_rate_embedding_recovery = learning_rate_embedding_recovery
|
| 271 |
+
self.learning_rate_time_embedding = learning_rate_time_embedding
|
| 272 |
+
return self
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class Trainer(Trainer_):
|
| 276 |
+
def __init__(self, *args, **kwargs):
|
| 277 |
+
super().__init__(*args, **kwargs)
|
| 278 |
+
self.ar_steps = None
|
| 279 |
+
self.output_all_steps = False
|
| 280 |
+
|
| 281 |
+
def get_decay_parameter_names(self, model) -> List[str]:
|
| 282 |
+
ALL_LAYERNORM_LAYERS = [torch.nn.LayerNorm, LayerNorm, ConditionalLayerNorm]
|
| 283 |
+
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
|
| 284 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 285 |
+
return decay_parameters
|
| 286 |
+
|
| 287 |
+
def get_conditional_norm_params(self, model):
|
| 288 |
+
params = []
|
| 289 |
+
for name, module in model.named_modules():
|
| 290 |
+
if isinstance(module, ConditionalLayerNorm):
|
| 291 |
+
for param_name, _ in module.named_parameters():
|
| 292 |
+
params.append(f"{name}.{param_name}")
|
| 293 |
+
return params
|
| 294 |
+
|
| 295 |
+
def create_optimizer(self):
|
| 296 |
+
"""This is the same as in the standard trainer, except param groups"""
|
| 297 |
+
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
| 298 |
+
if self.optimizer is None:
|
| 299 |
+
decay_parameters = self.get_decay_parameter_names(self.model)
|
| 300 |
+
if self.args.learning_rate_embedding_recovery is not None:
|
| 301 |
+
if self.args.learning_rate_time_embedding is not None:
|
| 302 |
+
time_embedding_params = self.get_conditional_norm_params(self.model)
|
| 303 |
+
params = {
|
| 304 |
+
"standard": [],
|
| 305 |
+
"no_weight_decay": [],
|
| 306 |
+
"embeddings": [],
|
| 307 |
+
"time_embedding": [],
|
| 308 |
+
}
|
| 309 |
+
for n, p in opt_model.named_parameters():
|
| 310 |
+
if (
|
| 311 |
+
"embeddings" in n or "patch_recovery" in n
|
| 312 |
+
) and p.requires_grad:
|
| 313 |
+
params["embeddings"].append(p)
|
| 314 |
+
elif n in decay_parameters and p.requires_grad:
|
| 315 |
+
params["standard"].append(p)
|
| 316 |
+
elif p.requires_grad:
|
| 317 |
+
if n in time_embedding_params:
|
| 318 |
+
params["time_embedding"].append(p)
|
| 319 |
+
else:
|
| 320 |
+
params["no_weight_decay"].append(p)
|
| 321 |
+
optimizer_grouped_parameters = [
|
| 322 |
+
{
|
| 323 |
+
"params": params["standard"],
|
| 324 |
+
"weight_decay": self.args.weight_decay,
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"params": params["no_weight_decay"],
|
| 328 |
+
"weight_decay": 0.0,
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"params": params["embeddings"],
|
| 332 |
+
"lr": self.args.learning_rate_embedding_recovery,
|
| 333 |
+
"weight_decay": self.args.weight_decay,
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"params": params["time_embedding"],
|
| 337 |
+
"lr": self.args.learning_rate_time_embedding,
|
| 338 |
+
"weight_decay": 0.0,
|
| 339 |
+
},
|
| 340 |
+
]
|
| 341 |
+
else:
|
| 342 |
+
params = {"standard": [], "no_weight_decay": [], "embeddings": []}
|
| 343 |
+
for n, p in opt_model.named_parameters():
|
| 344 |
+
if (
|
| 345 |
+
"embeddings" in n or "patch_recovery" in n
|
| 346 |
+
) and p.requires_grad:
|
| 347 |
+
params["embeddings"].append(p)
|
| 348 |
+
elif n in decay_parameters and p.requires_grad:
|
| 349 |
+
params["standard"].append(p)
|
| 350 |
+
elif p.requires_grad:
|
| 351 |
+
params["no_weight_decay"].append(p)
|
| 352 |
+
optimizer_grouped_parameters = [
|
| 353 |
+
{
|
| 354 |
+
"params": params["standard"],
|
| 355 |
+
"weight_decay": self.args.weight_decay,
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"params": params["no_weight_decay"],
|
| 359 |
+
"weight_decay": 0.0,
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"params": params["embeddings"],
|
| 363 |
+
"lr": self.args.learning_rate_embedding_recovery,
|
| 364 |
+
"weight_decay": self.args.weight_decay,
|
| 365 |
+
},
|
| 366 |
+
]
|
| 367 |
+
elif self.args.learning_rate_time_embedding is not None:
|
| 368 |
+
time_embedding_params = self.get_conditional_norm_params(self.model)
|
| 369 |
+
params = {"standard": [], "no_weight_decay": [], "time_embedding": []}
|
| 370 |
+
for n, p in opt_model.named_parameters():
|
| 371 |
+
if n in decay_parameters and p.requires_grad:
|
| 372 |
+
params["standard"].append(p)
|
| 373 |
+
elif p.requires_grad:
|
| 374 |
+
if n in time_embedding_params:
|
| 375 |
+
params["time_embedding"].append(p)
|
| 376 |
+
else:
|
| 377 |
+
params["no_weight_decay"].append(p)
|
| 378 |
+
optimizer_grouped_parameters = [
|
| 379 |
+
{
|
| 380 |
+
"params": params["standard"],
|
| 381 |
+
"weight_decay": self.args.weight_decay,
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"params": params["no_weight_decay"],
|
| 385 |
+
"weight_decay": 0.0,
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"params": params["time_embedding"],
|
| 389 |
+
"lr": self.args.learning_rate_time_embedding,
|
| 390 |
+
"weight_decay": 0.0,
|
| 391 |
+
},
|
| 392 |
+
]
|
| 393 |
+
else:
|
| 394 |
+
optimizer_grouped_parameters = [
|
| 395 |
+
{
|
| 396 |
+
"params": [
|
| 397 |
+
p
|
| 398 |
+
for n, p in opt_model.named_parameters()
|
| 399 |
+
if (n in decay_parameters and p.requires_grad)
|
| 400 |
+
],
|
| 401 |
+
"weight_decay": self.args.weight_decay,
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"params": [
|
| 405 |
+
p
|
| 406 |
+
for n, p in opt_model.named_parameters()
|
| 407 |
+
if (n not in decay_parameters and p.requires_grad)
|
| 408 |
+
],
|
| 409 |
+
"weight_decay": 0.0,
|
| 410 |
+
},
|
| 411 |
+
]
|
| 412 |
+
|
| 413 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
| 414 |
+
self.args
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
self.optimizer = optimizer_cls(
|
| 418 |
+
optimizer_grouped_parameters, **optimizer_kwargs
|
| 419 |
+
)
|
| 420 |
+
if optimizer_cls.__name__ == "Adam8bit":
|
| 421 |
+
import bitsandbytes
|
| 422 |
+
|
| 423 |
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
| 424 |
+
|
| 425 |
+
skipped = 0
|
| 426 |
+
for module in opt_model.modules():
|
| 427 |
+
if isinstance(module, nn.Embedding):
|
| 428 |
+
skipped += sum(
|
| 429 |
+
{
|
| 430 |
+
p.data_ptr(): p.numel() for p in module.parameters()
|
| 431 |
+
}.values()
|
| 432 |
+
)
|
| 433 |
+
print(f"skipped {module}: {skipped/2**20}M params")
|
| 434 |
+
manager.register_module_override(
|
| 435 |
+
module, "weight", {"optim_bits": 32}
|
| 436 |
+
)
|
| 437 |
+
logger.debug(
|
| 438 |
+
f"bitsandbytes: will optimize {module} in fp32"
|
| 439 |
+
)
|
| 440 |
+
print(f"skipped: {skipped/2**20}M params")
|
| 441 |
+
|
| 442 |
+
if is_sagemaker_mp_enabled():
|
| 443 |
+
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
| 444 |
+
|
| 445 |
+
return self.optimizer
|
| 446 |
+
|
| 447 |
+
def set_ar_steps(self, ar_steps=None, output_all_steps=False):
|
| 448 |
+
self.ar_steps = ar_steps
|
| 449 |
+
if self.ar_steps is not None and output_all_steps:
|
| 450 |
+
self.output_all_steps = True
|
| 451 |
+
|
| 452 |
+
def _model_forward(self, model, inputs):
|
| 453 |
+
if self.ar_steps is not None and model.config.use_conditioning:
|
| 454 |
+
channel_difference = (
|
| 455 |
+
model.config.num_channels > model.config.num_out_channels
|
| 456 |
+
)
|
| 457 |
+
# TODO: if outputs is not a dataclass this will break
|
| 458 |
+
if isinstance(self.ar_steps, int):
|
| 459 |
+
inputs = {**inputs, **{"time": inputs["time"] / self.ar_steps}}
|
| 460 |
+
if self.output_all_steps:
|
| 461 |
+
loss_ = []
|
| 462 |
+
outputs_ = []
|
| 463 |
+
hidden_states_ = []
|
| 464 |
+
attentions_ = []
|
| 465 |
+
reshaped_hidden_states_ = []
|
| 466 |
+
else:
|
| 467 |
+
loss = 0
|
| 468 |
+
for i in range(self.ar_steps):
|
| 469 |
+
outputs = model(**inputs)
|
| 470 |
+
if self.output_all_steps:
|
| 471 |
+
outputs_.append(outputs.output.detach())
|
| 472 |
+
if outputs.hidden_states is not None:
|
| 473 |
+
hidden_states_.append(outputs.hidden_states)
|
| 474 |
+
if outputs.attentions is not None:
|
| 475 |
+
attentions_.append(outputs.attentions)
|
| 476 |
+
if outputs.reshaped_hidden_states is not None:
|
| 477 |
+
reshaped_hidden_states_.append(
|
| 478 |
+
outputs.reshaped_hidden_states
|
| 479 |
+
)
|
| 480 |
+
if outputs.loss is not None:
|
| 481 |
+
loss_.append(outputs.loss)
|
| 482 |
+
else:
|
| 483 |
+
if outputs.loss is not None:
|
| 484 |
+
loss += outputs.loss
|
| 485 |
+
inputs = {
|
| 486 |
+
**inputs,
|
| 487 |
+
**{
|
| 488 |
+
"pixel_values": (
|
| 489 |
+
outputs.output.detach()
|
| 490 |
+
if not channel_difference
|
| 491 |
+
else torch.cat(
|
| 492 |
+
[
|
| 493 |
+
outputs.output.detach(),
|
| 494 |
+
inputs["pixel_values"][
|
| 495 |
+
:,
|
| 496 |
+
model.config.num_out_channels :,
|
| 497 |
+
],
|
| 498 |
+
],
|
| 499 |
+
dim=1,
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
+
},
|
| 503 |
+
}
|
| 504 |
+
if self.output_all_steps:
|
| 505 |
+
outputs.output = torch.stack(outputs_, dim=1)
|
| 506 |
+
if len(loss_) > 0:
|
| 507 |
+
outputs.loss = torch.stack(loss_, dim=0)
|
| 508 |
+
if len(hidden_states_) > 0:
|
| 509 |
+
outputs.hidden_states = [
|
| 510 |
+
torch.stack(hs, dim=1) for hs in zip(*hidden_states_)
|
| 511 |
+
]
|
| 512 |
+
if len(attentions_) > 0:
|
| 513 |
+
outputs.attentions = [
|
| 514 |
+
torch.stack(att, dim=1) for att in zip(*attentions_)
|
| 515 |
+
]
|
| 516 |
+
if len(reshaped_hidden_states_) > 0:
|
| 517 |
+
outputs.reshaped_hidden_states = [
|
| 518 |
+
torch.stack(rhs, dim=1)
|
| 519 |
+
for rhs in zip(*reshaped_hidden_states_)
|
| 520 |
+
]
|
| 521 |
+
else:
|
| 522 |
+
loss /= self.ar_steps
|
| 523 |
+
outputs.loss = loss
|
| 524 |
+
elif isinstance(self.ar_steps, list):
|
| 525 |
+
if self.output_all_steps:
|
| 526 |
+
loss_ = []
|
| 527 |
+
outputs_ = []
|
| 528 |
+
hidden_states_ = []
|
| 529 |
+
attentions_ = []
|
| 530 |
+
reshaped_hidden_states_ = []
|
| 531 |
+
else:
|
| 532 |
+
loss = 0
|
| 533 |
+
lead_time = inputs["time"]
|
| 534 |
+
for i in self.ar_steps:
|
| 535 |
+
inputs = {
|
| 536 |
+
**inputs,
|
| 537 |
+
**{"time": lead_time * i},
|
| 538 |
+
}
|
| 539 |
+
outputs = model(**inputs)
|
| 540 |
+
if self.output_all_steps:
|
| 541 |
+
outputs_.append(outputs.output.detach())
|
| 542 |
+
if self.output_all_steps:
|
| 543 |
+
outputs_.append(outputs.output.detach())
|
| 544 |
+
if outputs.hidden_states is not None:
|
| 545 |
+
hidden_states_.append(outputs.hidden_states)
|
| 546 |
+
if outputs.attentions is not None:
|
| 547 |
+
attentions_.append(outputs.attentions)
|
| 548 |
+
if outputs.reshaped_hidden_states is not None:
|
| 549 |
+
reshaped_hidden_states_.append(
|
| 550 |
+
outputs.reshaped_hidden_states
|
| 551 |
+
)
|
| 552 |
+
if outputs.loss is not None:
|
| 553 |
+
loss_.append(outputs.loss)
|
| 554 |
+
else:
|
| 555 |
+
if outputs.loss is not None:
|
| 556 |
+
loss += outputs.loss
|
| 557 |
+
inputs = {
|
| 558 |
+
**inputs,
|
| 559 |
+
**{
|
| 560 |
+
"pixel_values": (
|
| 561 |
+
outputs.output.detach()
|
| 562 |
+
if not channel_difference
|
| 563 |
+
else torch.cat(
|
| 564 |
+
[
|
| 565 |
+
outputs.output.detach(),
|
| 566 |
+
inputs["pixel_values"][
|
| 567 |
+
:,
|
| 568 |
+
model.config.num_out_channels :,
|
| 569 |
+
],
|
| 570 |
+
],
|
| 571 |
+
dim=1,
|
| 572 |
+
)
|
| 573 |
+
)
|
| 574 |
+
},
|
| 575 |
+
}
|
| 576 |
+
if self.output_all_steps:
|
| 577 |
+
outputs.output = torch.stack(outputs_, dim=1)
|
| 578 |
+
if len(loss_) > 0:
|
| 579 |
+
outputs.loss = torch.stack(loss_, dim=1)
|
| 580 |
+
if len(hidden_states_) > 0:
|
| 581 |
+
outputs.hidden_states = [
|
| 582 |
+
torch.stack(hs, dim=1) for hs in zip(*hidden_states_)
|
| 583 |
+
]
|
| 584 |
+
if len(attentions_) > 0:
|
| 585 |
+
outputs.attentions = [
|
| 586 |
+
torch.stack(att, dim=1) for att in zip(*attentions_)
|
| 587 |
+
]
|
| 588 |
+
if len(reshaped_hidden_states_) > 0:
|
| 589 |
+
outputs.reshaped_hidden_states = [
|
| 590 |
+
torch.stack(rhs, dim=1)
|
| 591 |
+
for rhs in zip(*reshaped_hidden_states_)
|
| 592 |
+
]
|
| 593 |
+
else:
|
| 594 |
+
loss /= len(self.ar_steps)
|
| 595 |
+
outputs.loss = loss
|
| 596 |
+
else:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
"num_ar_steps must be an integer or a list of integers."
|
| 599 |
+
)
|
| 600 |
+
else:
|
| 601 |
+
outputs = model(**inputs)
|
| 602 |
+
|
| 603 |
+
return outputs
|
| 604 |
+
|
| 605 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
| 606 |
+
if self.label_smoother is not None and "labels" in inputs:
|
| 607 |
+
labels = inputs.pop("labels")
|
| 608 |
+
else:
|
| 609 |
+
labels = None
|
| 610 |
+
outputs = self._model_forward(model, inputs)
|
| 611 |
+
# Save past state if it exists
|
| 612 |
+
# TODO: this needs to be fixed and made cleaner later.
|
| 613 |
+
if self.args.past_index >= 0:
|
| 614 |
+
self._past = outputs[self.args.past_index]
|
| 615 |
+
|
| 616 |
+
if labels is not None:
|
| 617 |
+
unwrapped_model = unwrap_model(model)
|
| 618 |
+
if _is_peft_model(unwrapped_model):
|
| 619 |
+
model_name = unwrapped_model.base_model.model._get_name()
|
| 620 |
+
else:
|
| 621 |
+
model_name = unwrapped_model._get_name()
|
| 622 |
+
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
| 623 |
+
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
| 624 |
+
else:
|
| 625 |
+
loss = self.label_smoother(outputs, labels)
|
| 626 |
+
else:
|
| 627 |
+
if isinstance(outputs, dict) and "loss" not in outputs:
|
| 628 |
+
raise ValueError(
|
| 629 |
+
"The model did not return a loss from the inputs, only the following keys: "
|
| 630 |
+
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
| 631 |
+
)
|
| 632 |
+
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
| 633 |
+
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
| 634 |
+
|
| 635 |
+
return (loss, outputs) if return_outputs else loss
|
| 636 |
+
|
| 637 |
+
def prediction_step(
|
| 638 |
+
self,
|
| 639 |
+
model: nn.Module,
|
| 640 |
+
inputs: Dict[str, Union[torch.Tensor, Any]],
|
| 641 |
+
prediction_loss_only: bool,
|
| 642 |
+
ignore_keys: Optional[List[str]] = None,
|
| 643 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 644 |
+
"""
|
| 645 |
+
Perform an evaluation step on `model` using `inputs`.
|
| 646 |
+
|
| 647 |
+
Subclass and override to inject custom behavior.
|
| 648 |
+
|
| 649 |
+
Args:
|
| 650 |
+
model (`nn.Module`):
|
| 651 |
+
The model to evaluate.
|
| 652 |
+
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
| 653 |
+
The inputs and targets of the model.
|
| 654 |
+
|
| 655 |
+
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
| 656 |
+
argument `labels`. Check your model's documentation for all accepted arguments.
|
| 657 |
+
prediction_loss_only (`bool`):
|
| 658 |
+
Whether or not to return the loss only.
|
| 659 |
+
ignore_keys (`List[str]`, *optional*):
|
| 660 |
+
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
| 661 |
+
gathering predictions.
|
| 662 |
+
|
| 663 |
+
Return:
|
| 664 |
+
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
|
| 665 |
+
logits and labels (each being optional).
|
| 666 |
+
"""
|
| 667 |
+
has_labels = (
|
| 668 |
+
False
|
| 669 |
+
if len(self.label_names) == 0
|
| 670 |
+
else all(inputs.get(k) is not None for k in self.label_names)
|
| 671 |
+
)
|
| 672 |
+
# For CLIP-like models capable of returning loss values.
|
| 673 |
+
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
|
| 674 |
+
# is `True` in `model.forward`.
|
| 675 |
+
return_loss = inputs.get("return_loss", None)
|
| 676 |
+
if return_loss is None:
|
| 677 |
+
return_loss = self.can_return_loss
|
| 678 |
+
loss_without_labels = (
|
| 679 |
+
True if len(self.label_names) == 0 and return_loss else False
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
inputs = self._prepare_inputs(inputs)
|
| 683 |
+
if ignore_keys is None:
|
| 684 |
+
if hasattr(self.model, "config"):
|
| 685 |
+
ignore_keys = getattr(
|
| 686 |
+
self.model.config, "keys_to_ignore_at_inference", []
|
| 687 |
+
)
|
| 688 |
+
else:
|
| 689 |
+
ignore_keys = []
|
| 690 |
+
|
| 691 |
+
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
|
| 692 |
+
if has_labels or loss_without_labels:
|
| 693 |
+
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
| 694 |
+
if len(labels) == 1:
|
| 695 |
+
labels = labels[0]
|
| 696 |
+
else:
|
| 697 |
+
labels = None
|
| 698 |
+
|
| 699 |
+
with torch.no_grad():
|
| 700 |
+
if is_sagemaker_mp_enabled():
|
| 701 |
+
raw_outputs = smp_forward_only(model, inputs)
|
| 702 |
+
if has_labels or loss_without_labels:
|
| 703 |
+
if isinstance(raw_outputs, dict):
|
| 704 |
+
loss_mb = raw_outputs["loss"]
|
| 705 |
+
logits_mb = tuple(
|
| 706 |
+
v
|
| 707 |
+
for k, v in raw_outputs.items()
|
| 708 |
+
if k not in ignore_keys + ["loss"]
|
| 709 |
+
)
|
| 710 |
+
else:
|
| 711 |
+
loss_mb = raw_outputs[0]
|
| 712 |
+
logits_mb = raw_outputs[1:]
|
| 713 |
+
|
| 714 |
+
loss = loss_mb.reduce_mean().detach().cpu()
|
| 715 |
+
logits = smp_nested_concat(logits_mb)
|
| 716 |
+
else:
|
| 717 |
+
loss = None
|
| 718 |
+
if isinstance(raw_outputs, dict):
|
| 719 |
+
logits_mb = tuple(
|
| 720 |
+
v for k, v in raw_outputs.items() if k not in ignore_keys
|
| 721 |
+
)
|
| 722 |
+
else:
|
| 723 |
+
logits_mb = raw_outputs
|
| 724 |
+
logits = smp_nested_concat(logits_mb)
|
| 725 |
+
else:
|
| 726 |
+
if has_labels or loss_without_labels:
|
| 727 |
+
with self.compute_loss_context_manager():
|
| 728 |
+
loss, outputs = self.compute_loss(
|
| 729 |
+
model, inputs, return_outputs=True
|
| 730 |
+
)
|
| 731 |
+
loss = loss.mean().detach()
|
| 732 |
+
|
| 733 |
+
if isinstance(outputs, dict):
|
| 734 |
+
logits = tuple(
|
| 735 |
+
v
|
| 736 |
+
for k, v in outputs.items()
|
| 737 |
+
if k not in ignore_keys + ["loss"]
|
| 738 |
+
)
|
| 739 |
+
else:
|
| 740 |
+
logits = outputs[1:]
|
| 741 |
+
else:
|
| 742 |
+
loss = None
|
| 743 |
+
with self.compute_loss_context_manager():
|
| 744 |
+
outputs = self._model_forward(model, inputs)
|
| 745 |
+
if isinstance(outputs, dict):
|
| 746 |
+
logits = tuple(
|
| 747 |
+
v for k, v in outputs.items() if k not in ignore_keys
|
| 748 |
+
)
|
| 749 |
+
else:
|
| 750 |
+
logits = outputs
|
| 751 |
+
# TODO: this needs to be fixed and made cleaner later.
|
| 752 |
+
if self.args.past_index >= 0:
|
| 753 |
+
self._past = outputs[self.args.past_index - 1]
|
| 754 |
+
|
| 755 |
+
if prediction_loss_only:
|
| 756 |
+
return (loss, None, None)
|
| 757 |
+
|
| 758 |
+
logits = nested_detach(logits)
|
| 759 |
+
if len(logits) == 1:
|
| 760 |
+
logits = logits[0]
|
| 761 |
+
|
| 762 |
+
return (loss, logits, labels)
|
external/poseidon/scOT/utils.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def read_cli(parser):
|
| 5 |
+
"""Reads command line arguments."""
|
| 6 |
+
|
| 7 |
+
parser.add_argument(
|
| 8 |
+
"--config",
|
| 9 |
+
type=str,
|
| 10 |
+
required=True,
|
| 11 |
+
help="Path to config file or JSON string",
|
| 12 |
+
)
|
| 13 |
+
parser.add_argument(
|
| 14 |
+
"--json_config",
|
| 15 |
+
action="store_true",
|
| 16 |
+
help="Whether the config is a JSON string",
|
| 17 |
+
)
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--wandb_run_name",
|
| 20 |
+
type=str,
|
| 21 |
+
required=False,
|
| 22 |
+
default=None,
|
| 23 |
+
help="Name of the run in wandb",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--wandb_project_name",
|
| 27 |
+
type=str,
|
| 28 |
+
default="scOT",
|
| 29 |
+
help="Name of the wandb project",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--max_num_train_time_steps",
|
| 33 |
+
type=int,
|
| 34 |
+
default=None,
|
| 35 |
+
help="Maximum number of time steps to use for training and validation.",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--train_time_step_size",
|
| 39 |
+
type=int,
|
| 40 |
+
default=None,
|
| 41 |
+
help="Time step size to use for training and validation.",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--train_small_time_transition",
|
| 45 |
+
action="store_true",
|
| 46 |
+
help="Whether to train only for next step prediction.",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--data_path",
|
| 50 |
+
type=str,
|
| 51 |
+
required=True,
|
| 52 |
+
help="Base path to data.",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--checkpoint_path",
|
| 56 |
+
type=str,
|
| 57 |
+
required=True,
|
| 58 |
+
help="Path to checkpoint directory. Will be prepended by wandb project and run name.",
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--disable_tqdm",
|
| 62 |
+
action="store_true",
|
| 63 |
+
help="Whether to disable tqdm progress bar",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--push_to_hf_hub",
|
| 67 |
+
type=str,
|
| 68 |
+
default=None,
|
| 69 |
+
help="Whether to push the model to Huggingface Hub. Specify the model repository name.",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--just_velocities",
|
| 73 |
+
action="store_true",
|
| 74 |
+
help="Whether to only use velocities as input. Only relevant for incompressible flow datasets.",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--move_data",
|
| 78 |
+
type=str,
|
| 79 |
+
default=None,
|
| 80 |
+
help="If set, moves the data to this directory and trains from there.",
|
| 81 |
+
)
|
| 82 |
+
return parser
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_num_parameters(model):
|
| 86 |
+
"""Returns the number of trainable parameters in a model."""
|
| 87 |
+
|
| 88 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_num_parameters_no_embed(model):
|
| 92 |
+
"""Returns the number of trainable parameters in a scOT model without embedding and recovery."""
|
| 93 |
+
out = 0
|
| 94 |
+
for name, p in model.named_parameters():
|
| 95 |
+
if not ("embeddings" in name or "patch_recovery" in name) and p.requires_grad:
|
| 96 |
+
out += p.numel()
|
| 97 |
+
return out
|
poseidon_model.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
import xarray as xr
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
from torchvision.transforms.functional import resize
|
| 10 |
+
|
| 11 |
+
sys.path.append(os.path.abspath("poseidon_demo/external/poseidon"))
|
| 12 |
+
from scOT.model import ScOT, ScOTConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_model():
|
| 16 |
+
"""
|
| 17 |
+
Initializes and loads a POSEIDON model with fixed configuration.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
model (ScOT): An instance of the POSEIDON model in evaluation mode.
|
| 21 |
+
"""
|
| 22 |
+
config = ScOTConfig(
|
| 23 |
+
num_channels=4,
|
| 24 |
+
skip_connections=[True, True, True, True]
|
| 25 |
+
)
|
| 26 |
+
model = ScOT(config)
|
| 27 |
+
model.eval()
|
| 28 |
+
return model
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_inference_by_domain(model, domain):
|
| 32 |
+
"""
|
| 33 |
+
Runs the model on a synthetic input based on the chosen domain.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model (ScOT): The POSEIDON model.
|
| 37 |
+
domain (str): Domain to simulate input for. One of: 'Fluid Dynamics', 'Finance', 'Quantum', 'Biology / Medicine'.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
np.ndarray: The predicted model output.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
if domain == "Fluid Dynamics":
|
| 44 |
+
x = torch.linspace(-1, 1, 224)
|
| 45 |
+
y = torch.linspace(-1, 1, 224)
|
| 46 |
+
X, Y = torch.meshgrid(x, y, indexing="ij")
|
| 47 |
+
blob = torch.exp(-(X**2 + Y**2) * 10)
|
| 48 |
+
input_tensor = blob.expand(4, 224, 224).unsqueeze(0)
|
| 49 |
+
|
| 50 |
+
elif domain == "Finance":
|
| 51 |
+
base = torch.linspace(0, 1, 224).reshape(1, -1).repeat(224, 1)
|
| 52 |
+
noise = torch.randn(4, 224, 224) * 0.05
|
| 53 |
+
input_tensor = (base + noise).unsqueeze(0)
|
| 54 |
+
|
| 55 |
+
elif domain == "Quantum":
|
| 56 |
+
x = torch.linspace(0, 4 * torch.pi, 224)
|
| 57 |
+
y = torch.linspace(0, 4 * torch.pi, 224)
|
| 58 |
+
X, Y = torch.meshgrid(x, y, indexing="ij")
|
| 59 |
+
sin_grid = torch.sin(X) * torch.sin(Y)
|
| 60 |
+
input_tensor = sin_grid.expand(4, 224, 224).unsqueeze(0)
|
| 61 |
+
|
| 62 |
+
elif domain == "Biology / Medicine":
|
| 63 |
+
x = torch.linspace(-1, 1, 224)
|
| 64 |
+
y = torch.linspace(-1, 1, 224)
|
| 65 |
+
X, Y = torch.meshgrid(x, y, indexing="ij")
|
| 66 |
+
base_blob = torch.exp(-(X**2 + Y**2) * 5)
|
| 67 |
+
blob = torch.randn(4, 224, 224) * 0.2 + base_blob
|
| 68 |
+
input_tensor = blob.unsqueeze(0)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
else:
|
| 72 |
+
input_tensor = torch.randn(1, 4, 224, 224)
|
| 73 |
+
|
| 74 |
+
time_tensor = torch.tensor([0.0])
|
| 75 |
+
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
output = model(pixel_values=input_tensor, time=time_tensor).output
|
| 78 |
+
return output.squeeze().numpy()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def run_inference_on_dataset(model, dataset_name):
|
| 82 |
+
"""
|
| 83 |
+
Downloads and runs inference on a real scientific dataset using POSEIDON.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
model (ScOT): The POSEIDON model.
|
| 87 |
+
dataset_name (str): Identifier for the dataset.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
tuple: (input_array, output_array) as numpy arrays.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
dataset_mapping = {
|
| 94 |
+
"fluids.incompressible.Sines": {
|
| 95 |
+
"repo_id": "camlab-ethz/NS-Sines",
|
| 96 |
+
"filename": "velocity_0.nc",
|
| 97 |
+
"variable": "velocity"
|
| 98 |
+
},
|
| 99 |
+
"fluids.compressible.Riemann": {
|
| 100 |
+
"repo_id": "camlab-ethz/CE-RP",
|
| 101 |
+
"filename": "data_0.nc",
|
| 102 |
+
"variable": "data"
|
| 103 |
+
},
|
| 104 |
+
"reaction_diffusion.AllenCahn": {
|
| 105 |
+
"repo_id": "camlab-ethz/ACE",
|
| 106 |
+
"filename": "solution_0.nc",
|
| 107 |
+
"variable": "solution"
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
entry = dataset_mapping.get(dataset_name)
|
| 112 |
+
if entry is None:
|
| 113 |
+
raise ValueError(f"Unknown dataset name: {dataset_name}")
|
| 114 |
+
|
| 115 |
+
file_path = hf_hub_download(
|
| 116 |
+
repo_id=entry["repo_id"],
|
| 117 |
+
filename=entry["filename"],
|
| 118 |
+
repo_type="dataset"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
ds = xr.open_dataset(file_path)
|
| 122 |
+
var = ds[entry["variable"]]
|
| 123 |
+
print(f"Loaded shape: {var.shape}, dims: {var.dims}")
|
| 124 |
+
|
| 125 |
+
if "sample" in var.dims:
|
| 126 |
+
sample = var.isel(sample=0, time=0).values.astype(np.float32)
|
| 127 |
+
else:
|
| 128 |
+
sample = var.isel(time=0).values.astype(np.float32)
|
| 129 |
+
|
| 130 |
+
if sample.ndim > 3:
|
| 131 |
+
sample = np.squeeze(sample)
|
| 132 |
+
while sample.ndim < 3:
|
| 133 |
+
sample = np.expand_dims(sample, 0)
|
| 134 |
+
|
| 135 |
+
tensor = torch.tensor(sample)
|
| 136 |
+
if tensor.shape[-1] != 224 or tensor.shape[-2] != 224:
|
| 137 |
+
tensor = resize(tensor, size=[224, 224])
|
| 138 |
+
|
| 139 |
+
if tensor.shape[0] < 4:
|
| 140 |
+
pad = 4 - tensor.shape[0]
|
| 141 |
+
extra = torch.zeros((pad, 224, 224))
|
| 142 |
+
tensor = torch.cat([tensor, extra], dim=0)
|
| 143 |
+
elif tensor.shape[0] > 4:
|
| 144 |
+
tensor = tensor[:4]
|
| 145 |
+
|
| 146 |
+
input_tensor = tensor.unsqueeze(0)
|
| 147 |
+
time_tensor = torch.tensor([0.0])
|
| 148 |
+
|
| 149 |
+
with torch.no_grad():
|
| 150 |
+
output = model(pixel_values=input_tensor, time=time_tensor).output
|
| 151 |
+
|
| 152 |
+
return tensor.squeeze().numpy(), output.squeeze().numpy()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def plot_output(output_array, cmap="inferno", contrast=2.0):
|
| 156 |
+
"""
|
| 157 |
+
Plots the output array from the model using a heatmap.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
output_array (np.ndarray): Output from the model.
|
| 161 |
+
cmap (str): Colormap used for visualization.
|
| 162 |
+
contrast (float): Contrast scaling factor.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
matplotlib.figure.Figure: The heatmap figure.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
output_array = output_array - output_array.min()
|
| 169 |
+
output_array = output_array / output_array.max()
|
| 170 |
+
output_array = output_array ** contrast
|
| 171 |
+
|
| 172 |
+
fig, ax = plt.subplots(figsize=(6, 5))
|
| 173 |
+
sns.heatmap(
|
| 174 |
+
output_array,
|
| 175 |
+
ax=ax,
|
| 176 |
+
cmap=cmap,
|
| 177 |
+
cbar=True,
|
| 178 |
+
square=True,
|
| 179 |
+
xticklabels=False,
|
| 180 |
+
yticklabels=False,
|
| 181 |
+
linewidths=0,
|
| 182 |
+
)
|
| 183 |
+
ax.set_title("POSEIDON Output")
|
| 184 |
+
ax.axis("off")
|
| 185 |
+
return fig
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def plot_comparison(input_array, output_array, cmap="inferno"):
|
| 189 |
+
"""
|
| 190 |
+
Plots a side-by-side comparison of the input and the model output.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
input_array (np.ndarray): Ground truth or input data.
|
| 194 |
+
output_array (np.ndarray): Output predicted by the model.
|
| 195 |
+
cmap (str): Colormap used for both plots.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
matplotlib.figure.Figure: Figure showing input vs output.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
|
| 202 |
+
axs[0].imshow(input_array[0], cmap=cmap)
|
| 203 |
+
axs[0].set_title("Ground Truth")
|
| 204 |
+
axs[0].axis("off")
|
| 205 |
+
|
| 206 |
+
axs[1].imshow(output_array, cmap=cmap)
|
| 207 |
+
axs[1].set_title("POSEIDON Prediction")
|
| 208 |
+
axs[1].axis("off")
|
| 209 |
+
|
| 210 |
+
plt.tight_layout()
|
| 211 |
+
return fig
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
matplotlib
|
| 3 |
+
numpy
|
| 4 |
+
torch
|
| 5 |
+
torchvision
|
| 6 |
+
scipy
|
| 7 |
+
plotly
|
| 8 |
+
seaborn
|
| 9 |
+
huggingface_hub
|
| 10 |
+
xarray
|
simulations.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.stats import norm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def finance_demo():
|
| 7 |
+
"""
|
| 8 |
+
Simulates a Black-Scholes pricing scenario for European call options.
|
| 9 |
+
|
| 10 |
+
Returns:
|
| 11 |
+
matplotlib.figure.Figure: A plot of option price vs stock price using the Black-Scholes formula.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
fig, ax = plt.subplots()
|
| 15 |
+
S = np.linspace(1, 100, 100)
|
| 16 |
+
K = 50 # strike price
|
| 17 |
+
T = 1 # time to maturity
|
| 18 |
+
r = 0.05 # risk-free rate
|
| 19 |
+
sigma = 0.2 # volatility
|
| 20 |
+
d1 = (np.log(S / K) + (r + sigma**2 / 2) * T) / (sigma * np.sqrt(T))
|
| 21 |
+
d2 = d1 - sigma * np.sqrt(T)
|
| 22 |
+
call_price = S * norm.cdf(d1) - K * np.exp(-r * T) * norm.cdf(d2)
|
| 23 |
+
ax.plot(S, call_price)
|
| 24 |
+
ax.set_title("Black-Scholes Call Option Price")
|
| 25 |
+
ax.set_xlabel("Stock Price")
|
| 26 |
+
ax.set_ylabel("Option Price")
|
| 27 |
+
return fig
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def quantum_demo():
|
| 31 |
+
"""
|
| 32 |
+
Simulates a 1D quantum wavefunction as a product of a Gaussian envelope and a cosine wave.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
matplotlib.figure.Figure: A plot representing a wavefunction in space.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
x = np.linspace(-5, 5, 500)
|
| 39 |
+
t = 0.1
|
| 40 |
+
psi = np.exp(-x**2) * np.cos(5 * x - t)
|
| 41 |
+
fig, ax = plt.subplots()
|
| 42 |
+
ax.plot(x, psi)
|
| 43 |
+
ax.set_title("Wavefunction: Particle in a Potential")
|
| 44 |
+
ax.set_xlabel("Position")
|
| 45 |
+
ax.set_ylabel("Amplitude")
|
| 46 |
+
return fig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def fluid_demo():
|
| 50 |
+
"""
|
| 51 |
+
Simulates a 1D velocity field representing wave-like fluid behavior.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
matplotlib.figure.Figure: A sine wave representing fluid velocity over space.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
x = np.linspace(0, 2 * np.pi, 100)
|
| 58 |
+
t = 1.0
|
| 59 |
+
u = np.sin(x - t)
|
| 60 |
+
fig, ax = plt.subplots()
|
| 61 |
+
ax.plot(x, u)
|
| 62 |
+
ax.set_title("1D Fluid Velocity Field")
|
| 63 |
+
ax.set_xlabel("x")
|
| 64 |
+
ax.set_ylabel("u(x, t)")
|
| 65 |
+
return fig
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def bio_demo():
|
| 69 |
+
"""
|
| 70 |
+
Simulates a reaction-diffusion pattern, commonly seen in developmental biology.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
matplotlib.figure.Figure: A morphogen concentration gradient over space.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
x = np.linspace(0, 1, 100)
|
| 77 |
+
t = 0.1
|
| 78 |
+
u = np.exp(-10 * (x - 0.5) ** 2) * np.exp(-t)
|
| 79 |
+
fig, ax = plt.subplots()
|
| 80 |
+
ax.plot(x, u)
|
| 81 |
+
ax.set_title("Reaction-Diffusion: Morphogen Gradient")
|
| 82 |
+
ax.set_xlabel("Position")
|
| 83 |
+
ax.set_ylabel("Concentration")
|
| 84 |
+
return fig
|