saipuppala commited on
Commit
3ac016b
·
verified ·
1 Parent(s): 06ffd6f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +34 -0
README.md CHANGED
@@ -31,6 +31,40 @@ of the heat equation, given a pair of previous states.
31
 
32
  ---
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ## 📝 Model Description
35
 
36
  PDE-Transformer is a transformer-based foundation model for physics simulations on regular grids.
 
31
 
32
  ---
33
 
34
+ ### How To Load Pretrained Models
35
+
36
+ PDETransformer can be loaded via
37
+
38
+ ```python
39
+ import torch
40
+ from pdetransformer.core.mixed_channels import PDETransformer
41
+
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+ # Load fine-tuned model from Hugging Face
45
+ model = PDETransformer.from_pretrained(
46
+ "saipuppala/pde-transformer-mc-s-diffusion-heat"
47
+ ).to(device)
48
+ model.eval()
49
+
50
+ # Example input: batch of size 1, 2 time steps, 64x64 grid
51
+ # x[..., 0, :, :] ~ u(t0), x[..., 1, :, :] ~ u(t1)
52
+ x = torch.randn((1, 2, 64, 64), dtype=torch.float32, device=device)
53
+
54
+ with torch.no_grad():
55
+ out = model(x) # model output (tensor / dict / object)
56
+ pred_all = out if isinstance(out, torch.Tensor) else (
57
+ getattr(out, "prediction", None)
58
+ or getattr(out, "sample", None)
59
+ or next(v for v in out.values() if isinstance(v, torch.Tensor))
60
+ )
61
+
62
+ # Convention: channel 1 corresponds to the next state prediction u(t2)
63
+ u_t2_pred = pred_all[:, 1] # shape: (B, H, W)
64
+ print(u_t2_pred.shape)
65
+
66
+ ```
67
+
68
  ## 📝 Model Description
69
 
70
  PDE-Transformer is a transformer-based foundation model for physics simulations on regular grids.