Diffusion-Transformer mc-s – 2D Diffusion (Heat Equation)

This repository provides a fine-tuned PDE-Transformer for more accurate 2D heat-diffusion predictions on regular grids. The base model is the mixed-channel small variant from thuerey-group/pde-transformer, further trained on synthetic solutions of the 2D diffusion / heat equation with Gaussian bump initial conditions.

The goal of this model is to act as a surrogate solver for short-time predictions of the heat equation, given a pair of previous states.

Input: 2-channel field [u(t0), u(t1)]
Output: next-step prediction u(t2) (via channel index 1 in the model output)


How To Load Model

PDETransformer can be loaded via

import torch
from pdetransformer.core.mixed_channels import PDETransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load fine-tuned model from Hugging Face
model = PDETransformer.from_pretrained(
    "saipuppala/diffusion_transformer"
).to(device)
model.eval()

# Example input: batch of size 1, 2 time steps, 64x64 grid
# x[..., 0, :, :] ~ u(t0), x[..., 1, :, :] ~ u(t1)
x = torch.randn((1, 2, 64, 64), dtype=torch.float32, device=device)

with torch.no_grad():
    out = model(x)                      # model output (tensor / dict / object)
    # Avoid `or`-chaining tensors: truth-testing a multi-element tensor raises a RuntimeError.
    if isinstance(out, torch.Tensor):
        pred_all = out
    elif isinstance(out, dict):
        pred_all = next(v for v in out.values() if isinstance(v, torch.Tensor))
    else:
        pred_all = getattr(out, "prediction", None)
        if pred_all is None:
            pred_all = getattr(out, "sample", None)

# Convention: channel 1 corresponds to the next state prediction u(t2)
u_t2_pred = pred_all[:, 1]             # shape: (B, H, W)
print(u_t2_pred.shape)
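
For multi-step rollouts, the prediction can be fed back in autoregressively. Below is a minimal sketch (not part of the official API) that factors the output-extraction logic above into a helper; the rollout length is an illustrative choice:

def extract_prediction(out):
    # Same output-extraction logic as above, factored into a helper.
    if isinstance(out, torch.Tensor):
        return out
    if isinstance(out, dict):
        return next(v for v in out.values() if isinstance(v, torch.Tensor))
    pred = getattr(out, "prediction", None)
    return pred if pred is not None else getattr(out, "sample", None)

u_prev, u_curr = x[:, 0], x[:, 1]           # (B, H, W) each
trajectory = [u_prev, u_curr]

with torch.no_grad():
    for _ in range(10):                     # illustrative horizon; keep rollouts short
        inp = torch.stack([u_prev, u_curr], dim=1)   # (B, 2, H, W)
        u_next = extract_prediction(model(inp))[:, 1]
        trajectory.append(u_next)
        u_prev, u_curr = u_curr, u_next

rollout = torch.stack(trajectory, dim=1)    # (B, T, H, W)

Since the model was trained for short horizons, expect errors to accumulate over longer rollouts.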

How To Download the Model

The raw checkpoint file can be downloaded via

from huggingface_hub import hf_hub_download
import torch

ckpt_path = hf_hub_download(
    repo_id="saipuppala/diffusion_transformer ",
    filename="diffusion_finetuned.pth",
)

state_dict = torch.load(ckpt_path, map_location="cpu")
# then load into a PDETransformer instance as shown above
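
A minimal sketch of that last step, assuming diffusion_finetuned.pth stores a plain state_dict (an assumption about the checkpoint layout; if it is wrapped in a larger checkpoint dict, index into it first):

from pdetransformer.core.mixed_channels import PDETransformer

model = PDETransformer.from_pretrained("saipuppala/diffusion_transformer")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing, unexpected)   # both lists should be empty if the checkpoint matches the architecture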

πŸ“ Model Description

PDE-Transformer is a transformer-based foundation model for physics simulations on regular grids. It combines architectural ideas from diffusion transformers with design choices tailored to large-scale physical simulations.

This checkpoint starts from the mixed-channel small (mc-s) variant and is fine-tuned only on 2D diffusion:

  • Equation: $\partial_t u = \nu \, (u_{xx} + u_{yy})$
  • Domain: $[-1, 1]^2$ discretized on a regular grid (e.g. $64 \times 64$)
  • Boundary conditions: periodic
  • Initial condition: random 2D Gaussian bumps (random center, width, amplitude)
  • Training target: finite-difference solution u(t2) given [u(t0), u(t1)] (see the sketch after this list)
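
For reference, here is a minimal sketch of the kind of data-generation step described above, using an explicit (FTCS) finite-difference update with periodic boundaries via torch.roll. The grid size, ν, time step, and bump ranges are illustrative assumptions, not the exact training configuration:

import torch

N, nu, dt = 64, 0.01, 1e-3                 # illustrative values, not the training config
dx = 2.0 / N                               # grid spacing on [-1, 1]^2

# Random Gaussian bump initial condition (random center, width, amplitude)
xs = torch.linspace(-1, 1, N)
X, Y = torch.meshgrid(xs, xs, indexing="ij")
cx, cy = torch.rand(2) * 2 - 1
width = 0.1 + 0.2 * torch.rand(1)
amp = 0.5 + torch.rand(1)
u = amp * torch.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * width ** 2))

def diffusion_step(u):
    # One explicit FTCS step; torch.roll implements the periodic boundaries.
    lap = (torch.roll(u, 1, 0) + torch.roll(u, -1, 0)
           + torch.roll(u, 1, 1) + torch.roll(u, -1, 1) - 4 * u) / dx ** 2
    return u + nu * dt * lap

u_t0 = u
u_t1 = diffusion_step(u_t0)
u_t2 = diffusion_step(u_t1)                # training target given [u(t0), u(t1)]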

What this model is good for

  • Fast surrogate for 2D heat equation rollouts over short time horizons.
  • Experiments in:
    • surrogate modeling,
    • model-based control for diffusion-like processes,
    • benchmarking PDE foundation models on simple physics.

What this model is not guaranteed to handle

  • Arbitrary PDEs outside diffusion (e.g. Navier–Stokes, Burgers, reaction–diffusion)
    → use the original foundation model or fine-tune separately.
  • Very different resolutions or domain geometries than used during training, unless you explicitly adapt / re-fine-tune (see the resampling sketch below).
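
If your data lives on a different grid, one pragmatic (unvalidated) workaround is to resample to the training resolution before calling the model and back afterwards, e.g. with bilinear interpolation (reusing the extract_prediction helper from the rollout sketch above):

import torch.nn.functional as F

x_hires = torch.randn(1, 2, 128, 128, device=device)   # hypothetical 128x128 input
x_small = F.interpolate(x_hires, size=(64, 64), mode="bilinear", align_corners=False)

with torch.no_grad():
    u_next = extract_prediction(model(x_small))[:, 1]  # (B, 64, 64)

u_next_hires = F.interpolate(
    u_next.unsqueeze(1), size=(128, 128), mode="bilinear", align_corners=False
).squeeze(1)

Whether this preserves accuracy is untested; re-fine-tuning at the target resolution is the safer route.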
