---
license: mit
language:
- en
tags:
- physics
- PDEs
- surrogate
- heat-equation
- diffusion
base_model: thuerey-group/pde-transformer
---

# Diffusion-Transformer mc-s – 2D Diffusion (Heat Equation)

This repository provides a **fine-tuned PDE-Transformer** that aims to give more accurate 2D heat diffusion predictions on regular grids than the general-purpose base model. The base model is the mixed-channel small variant from `thuerey-group/pde-transformer`, further trained on synthetic solutions of the 2D diffusion / heat equation with Gaussian bump initial conditions.

The goal of this model is to act as a **surrogate solver** for short-time predictions of the heat equation, given a pair of previous states.

> **Input:** 2-channel field `[u(t0), u(t1)]`
>
> **Output:** next-step prediction `u(t2)` (via channel index 1 in the model output)

---

## 🌐 Project Links

- **Fine-tuning scripts & experiments**: [https://github.com/psmteja/agentic_ai_PDE_fm](https://github.com/psmteja/agentic_ai_PDE_fm)

---

### How To Load the Model

The fine-tuned PDETransformer can be loaded directly from the Hugging Face Hub with `from_pretrained`:

```python
import torch
from pdetransformer.core.mixed_channels import PDETransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load fine-tuned model from Hugging Face
model = PDETransformer.from_pretrained("saipuppala/diffusion_transformer").to(device)
model.eval()

# Example input: batch of size 1, 2 time steps, 64x64 grid
# x[..., 0, :, :] ~ u(t0), x[..., 1, :, :] ~ u(t1)
x = torch.randn((1, 2, 64, 64), dtype=torch.float32, device=device)


def extract_prediction(out):
    """Return the prediction tensor from a tensor, an output object, or a dict."""
    if isinstance(out, torch.Tensor):
        return out
    # Check attributes explicitly: `a or b` on a multi-element tensor raises an error
    for name in ("prediction", "sample"):
        value = getattr(out, name, None)
        if isinstance(value, torch.Tensor):
            return value
    if isinstance(out, dict):
        return next(v for v in out.values() if isinstance(v, torch.Tensor))
    raise TypeError(f"Unexpected model output type: {type(out)!r}")


with torch.no_grad():
    out = model(x)  # model output (tensor / dict / object)
    pred_all = extract_prediction(out)

# Convention: channel 1 corresponds to the next state prediction u(t2)
u_t2_pred = pred_all[:, 1]  # shape: (B, H, W)
print(u_t2_pred.shape)
```

### How To Download the Model

Alternatively, the raw checkpoint file can be downloaded with `huggingface_hub`:

```python
from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(
    repo_id="saipuppala/diffusion_transformer",
    filename="diffusion_finetuned.pth",
)

state_dict = torch.load(ckpt_path, map_location="cpu")
# then load into a PDETransformer instance as shown above
```

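The card does not document how `diffusion_finetuned.pth` is structured. The sketch below assumes it is either a plain state dict or a dict that wraps one under a key such as `"state_dict"` or `"model"` (both key names are assumptions), possibly with the `module.` prefix that `torch.nn.DataParallel` adds to parameter names. `normalize_state_dict` is a hypothetical helper, not part of `pdetransformer`.

```python
from collections.abc import Mapping


def normalize_state_dict(checkpoint, wrapper_keys=("state_dict", "model")):
    """Hypothetical helper: unwrap a checkpoint into a plain parameter mapping.

    Assumes the checkpoint is either a plain state dict or a dict storing one
    under one of `wrapper_keys` (assumed names). Strips a leading "module."
    prefix left by torch.nn.DataParallel training.
    """
    state = checkpoint
    for key in wrapper_keys:
        if isinstance(state, Mapping) and isinstance(state.get(key), Mapping):
            state = state[key]  # unwrap the nested parameter dict
            break
    if not isinstance(state, Mapping):
        raise TypeError(f"Expected a mapping, got {type(state)!r}")
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in state.items()
    }
```

With a `PDETransformer` instance, `model.load_state_dict(normalize_state_dict(state_dict))` raises an error listing missing or unexpected keys if the assumed file layout is wrong, which makes a mismatch easy to spot.
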
## 📝 Model Description

PDE-Transformer is a transformer-based foundation model for physics simulations on regular grids. It combines architectural ideas from diffusion transformers with design choices tailored to large-scale physical simulations.

This checkpoint starts from the **mixed-channel small (mc-s)** variant and is **fine-tuned only on 2D diffusion**:

- Equation:
  \[
  \partial_t u = \nu (u_{xx} + u_{yy})
  \]
- Domain: \([-1, 1]^2\) discretized on a regular grid (e.g. \(64 \times 64\))
- Boundary conditions: periodic
- Initial condition: random 2D Gaussian bumps (random center, width, amplitude)
- Training target: finite-difference solution `u(t2)` given `[u(t0), u(t1)]`

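The card does not state the diffusivity \(\nu\), the time step, or the exact finite-difference stencil used to build the training set. The following is a minimal sketch of how such `[u(t0), u(t1)] -> u(t2)` triplets could be generated, assuming an explicit forward-Euler (FTCS) scheme with a 5-point periodic Laplacian; `nu`, `dt`, `substeps`, the bump parameter ranges, and the function names are illustrative choices, not the author's settings.

```python
import numpy as np


def gaussian_bump(n=64, rng=None):
    """Random 2D Gaussian bump on the periodic grid of [-1, 1)^2 (illustrative ranges)."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.linspace(-1.0, 1.0, n, endpoint=False)  # x = 1 wraps around to x = -1
    X, Y = np.meshgrid(x, x, indexing="ij")
    cx, cy = rng.uniform(-1.0, 1.0, size=2)
    width = rng.uniform(0.05, 0.2)
    amp = rng.uniform(0.5, 1.5)
    # Minimum-image offsets keep the bump consistent with periodic boundaries
    rx = (X - cx + 1.0) % 2.0 - 1.0
    ry = (Y - cy + 1.0) % 2.0 - 1.0
    return amp * np.exp(-(rx**2 + ry**2) / (2.0 * width**2))


def laplacian_periodic(u, dx):
    """5-point Laplacian with periodic wrap-around via np.roll (uniform spacing dx)."""
    return (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4.0 * u
    ) / dx**2


def heat_triplet(u0, nu=0.01, dt=1e-3, substeps=10):
    """Return (u(t0), u(t1), u(t2)), consecutive snapshots `substeps` FTCS steps apart."""
    n = u0.shape[0]
    dx = 2.0 / n
    # 2D FTCS is stable only if nu * dt * (1/dx^2 + 1/dy^2) <= 1/2, i.e. 4 nu dt / dx^2 <= 1
    assert 4.0 * nu * dt / dx**2 <= 1.0, "FTCS stability limit violated"
    snapshots = [u0]
    u = u0.copy()
    for _ in range(2):
        for _ in range(substeps):
            u = u + nu * dt * laplacian_periodic(u, dx)
        snapshots.append(u.copy())
    return tuple(snapshots)
```

Stacking `t0` and `t1` gives a `(2, H, W)` input matching the card's convention, with `t2` as the target. With periodic boundaries the discrete scheme conserves the total sum of `u`, which is a cheap sanity check on generated data.
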
### What this model is good for

- Fast surrogate for **2D heat equation rollouts** over short time horizons.
- Experiments in:
  - surrogate modeling,
  - model-based control for diffusion-like processes,
  - benchmarking PDE foundation models on simple physics.

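Short-horizon rollouts can be produced autoregressively by sliding the two-state window: the predicted `u(t2)` becomes the second input channel of the next call. The sketch below keeps the predictor abstract (`step_fn` is a hypothetical callable mapping `(u_prev, u_curr)` to `u_next`), so it can wrap either the model from the loading example earlier in this card or a reference solver. The relative L2 error helper is one common way to compare a surrogate against a reference trajectory; it is not a metric reported for this checkpoint.

```python
import numpy as np


def rollout(step_fn, u0, u1, n_steps):
    """Autoregressive rollout: repeatedly predict the next state from the last two."""
    u_prev, u_curr = u0, u1
    trajectory = []
    for _ in range(n_steps):
        u_next = step_fn(u_prev, u_curr)
        trajectory.append(u_next)
        u_prev, u_curr = u_curr, u_next  # slide the two-state window forward
    return np.stack(trajectory)  # shape: (n_steps, H, W)


def relative_l2_error(pred, ref):
    """Per-step relative L2 error ||pred - ref|| / ||ref|| over the spatial axes."""
    diff = np.linalg.norm((pred - ref).reshape(len(pred), -1), axis=1)
    return diff / np.linalg.norm(ref.reshape(len(ref), -1), axis=1)
```

To drive the surrogate, `step_fn` can stack the two fields into a `(1, 2, H, W)` tensor, call the model, and return channel 1 of the output as a NumPy array. Because each prediction is fed back as input, errors can compound from step to step, which is consistent with the card's emphasis on short horizons.
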
### What this model is *not* guaranteed to handle

- PDEs other than diffusion (e.g. Navier–Stokes, Burgers, reaction–diffusion) → use the original foundation model or fine-tune separately.
- Resolutions or domain geometries that differ substantially from those used during training, unless you explicitly adapt or re-fine-tune the model.

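One way to probe how the model behaves outside its Gaussian-bump training distribution is to compare it with an exact solution. On the periodic domain \([-1, 1]^2\), a single Fourier mode \(\sin(m\pi x)\sin(k\pi y)\) with integers \(m, k\) decays as \(e^{-\nu\pi^2(m^2+k^2)t}\). The sketch below evaluates that solution on the same kind of grid; `fourier_mode_solution`, `nu`, and the mode numbers are illustrative parameters, not values from this card.

```python
import numpy as np


def fourier_mode_solution(t, n_grid=64, m=1, k=1, nu=0.01):
    """Exact solution of u_t = nu (u_xx + u_yy) on the periodic square [-1, 1)^2.

    u(x, y, t) = sin(m pi x) sin(k pi y) exp(-nu pi^2 (m^2 + k^2) t), integers m, k.
    """
    x = np.linspace(-1.0, 1.0, n_grid, endpoint=False)
    X, Y = np.meshgrid(x, x, indexing="ij")
    decay = np.exp(-nu * np.pi**2 * (m**2 + k**2) * t)
    return np.sin(m * np.pi * X) * np.sin(k * np.pi * Y) * decay
```

Feeding `[u(t0), u(t0 + dt)]` from this function into the model and comparing channel 1 with `u(t0 + 2 dt)` gives an error measure on non-Gaussian inputs. The frame spacing and \(\nu\) used during fine-tuning are not documented here, so both must be matched to the training setup for the comparison to be meaningful.
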
---