# Diffusion-Transformer mc-s → 2D Diffusion (Heat Equation)
This repository provides a fine-tuned PDE-Transformer for more accurate 2D heat diffusion predictions on regular grids. The base model is the mixed-channel small (mc-s) variant
from thuerey-group/pde-transformer, further trained on synthetic solutions of the
2D diffusion / heat equation with Gaussian bump initial conditions.
The goal of this model is to act as a surrogate solver for short-time predictions of the heat equation, given a pair of previous states.
- Input: 2-channel field `[u(t0), u(t1)]`
- Output: next-step prediction `u(t2)` (via channel index 1 in the model output)
## Project Links
- Fine-tuning scripts & experiments: https://github.com/psmteja/agentic_ai_PDE_fm
## How To Load the Model
The fine-tuned `PDETransformer` can be loaded via `from_pretrained`:
```python
import torch

from pdetransformer.core.mixed_channels import PDETransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load fine-tuned model from Hugging Face
model = PDETransformer.from_pretrained(
    "saipuppala/diffusion_transformer"
).to(device)
model.eval()

# Example input: batch of size 1, 2 time steps, 64x64 grid
# x[..., 0, :, :] ~ u(t0), x[..., 1, :, :] ~ u(t1)
x = torch.randn((1, 2, 64, 64), dtype=torch.float32, device=device)

with torch.no_grad():
    out = model(x)  # model output (tensor / dict / output object)

def extract_tensor(out):
    # Unwrap the prediction tensor from a raw tensor, an output object,
    # or a dict of tensors. (Using `or` between tensors would trigger an
    # ambiguous-truth-value error, so check attributes explicitly.)
    if isinstance(out, torch.Tensor):
        return out
    for name in ("prediction", "sample"):
        val = getattr(out, name, None)
        if isinstance(val, torch.Tensor):
            return val
    return next(v for v in out.values() if isinstance(v, torch.Tensor))

pred_all = extract_tensor(out)

# Convention: channel 1 corresponds to the next-state prediction u(t2)
u_t2_pred = pred_all[:, 1]  # shape: (B, H, W)
print(u_t2_pred.shape)
```
## How To Download the Model
Alternatively, the raw checkpoint file can be downloaded via `hf_hub_download`:
```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="saipuppala/diffusion_transformer",
    filename="diffusion_finetuned.pth",
)
state_dict = torch.load(ckpt_path, map_location="cpu")
# then load into a PDETransformer instance, as sketched below
```
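A minimal sketch of that last step, assuming the file stores a plain state dict compatible with the fine-tuned architecture (if the checkpoint wraps it under a key such as `"state_dict"`, unwrap it first):

```python
# Reuse the `model` instantiated in the loading section above and overwrite
# its weights with the downloaded checkpoint; strict loading raises on mismatch.
sd = state_dict.get("state_dict", state_dict)  # unwrap a wrapped checkpoint
model.load_state_dict(sd)
model.eval()
```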
## Model Description
PDE-Transformer is a transformer-based foundation model for physics simulations on regular grids. It combines architectural ideas from diffusion transformers with design choices tailored to large-scale physical simulations.
This checkpoint starts from the mixed-channel small (mc-s) variant and is fine-tuned only on 2D diffusion:
- Equation: $\partial_t u = \nu \, (u_{xx} + u_{yy})$
- Domain: $[-1, 1]^2$, discretized on a regular grid (e.g. $64 \times 64$)
- Boundary conditions: periodic
- Initial condition: random 2D Gaussian bumps (random center, width, amplitude)
- Training target: finite-difference solution `u(t2)` given `[u(t0), u(t1)]` (see the sketch after this list)
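For reference, a minimal sketch of how such training pairs can be generated; the parameter ranges (`nu`, `dt`, bump center/width/amplitude) are illustrative assumptions, not the exact fine-tuning values:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_bump(n=64):
    # Random 2D Gaussian bump on [-1, 1]^2; center, width, and amplitude
    # ranges are assumptions for illustration.
    xs = np.linspace(-1.0, 1.0, n, endpoint=False)
    X, Y = np.meshgrid(xs, xs, indexing="ij")
    cx, cy = rng.uniform(-0.5, 0.5, size=2)
    sigma = rng.uniform(0.05, 0.2)
    amp = rng.uniform(0.5, 1.5)
    return amp * np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2.0 * sigma**2))

def diffusion_step(u, nu=0.01, dt=1e-4):
    # One explicit finite-difference step of u_t = nu * (u_xx + u_yy),
    # with periodic boundaries implemented via np.roll.
    dx = 2.0 / u.shape[0]  # grid spacing on the [-1, 1] domain
    lap = (np.roll(u, 1, 0) + np.roll(u, -1, 0)
           + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4.0 * u) / dx**2
    return u + dt * nu * lap

u0 = gaussian_bump()
u1 = diffusion_step(u0)
u2 = diffusion_step(u1)  # target u(t2) for the input pair [u(t0), u(t1)]
```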
### What this model is good for

- Fast surrogate for 2D heat equation rollouts over short time horizons (see the rollout sketch after this list).
- Experiments in:
  - surrogate modeling,
  - model-based control for diffusion-like processes,
  - benchmarking PDE foundation models on simple physics.
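A hedged sketch of such an autoregressive rollout, reusing the `model`, `x`, and `extract_tensor` names from the loading example above:

```python
# Roll the surrogate forward by repeatedly feeding its own prediction back in
# as the newest state. Errors typically accumulate, hence "short horizons".
states = [x[:, 0], x[:, 1]]  # u(t0), u(t1), each of shape (B, H, W)
for _ in range(10):
    inp = torch.stack(states[-2:], dim=1)  # (B, 2, H, W)
    with torch.no_grad():
        out = model(inp)
    states.append(extract_tensor(out)[:, 1])  # next-state channel
rollout = torch.stack(states, dim=1)  # (B, T, H, W)
```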
### What this model is not guaranteed to handle

- Arbitrary PDEs outside diffusion (e.g. Navier–Stokes, Burgers, reaction–diffusion) → use the original foundation model or fine-tune separately.
- Very different resolutions or domain geometries than used during training, unless you explicitly adapt / re-fine-tune.