Upload: load.py
Browse files- single/load.py +232 -237
single/load.py
CHANGED
|
@@ -1,15 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn
|
| 3 |
-
import pathlib
|
| 4 |
-
import pystac
|
| 5 |
-
from typing import Literal, Tuple
|
| 6 |
import numpy as np
|
| 7 |
-
import
|
|
|
|
| 8 |
from tqdm import tqdm
|
| 9 |
-
import math
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
| 12 |
class EnsembleModel(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def __init__(self, *models, mode="max"):
|
| 14 |
super(EnsembleModel, self).__init__()
|
| 15 |
self.models = torch.nn.ModuleList(models)
|
|
@@ -19,58 +36,114 @@ class EnsembleModel(torch.nn.Module):
|
|
| 19 |
|
| 20 |
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 21 |
"""
|
| 22 |
-
Forward pass for ensemble.
|
| 23 |
-
|
| 24 |
Returns:
|
| 25 |
-
|
| 26 |
-
-
|
| 27 |
-
- uncertainty: (B, 1, H, W) - normalized std deviation
|
| 28 |
"""
|
| 29 |
-
outputs = []
|
| 30 |
-
for model in self.models:
|
| 31 |
-
output = model(x)
|
| 32 |
-
outputs.append(output)
|
| 33 |
|
| 34 |
if not outputs:
|
| 35 |
return None, None
|
| 36 |
|
| 37 |
-
# Stack
|
| 38 |
-
|
| 39 |
-
stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
|
| 40 |
|
| 41 |
-
#
|
| 42 |
if self.mode == "max":
|
| 43 |
-
|
| 44 |
elif self.mode == "mean":
|
| 45 |
-
|
| 46 |
elif self.mode == "median":
|
| 47 |
-
|
| 48 |
elif self.mode == "min":
|
| 49 |
-
|
| 50 |
elif self.mode == "none":
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
|
| 55 |
-
|
| 56 |
-
# Calculate uncertainty (normalized standard deviation)
|
| 57 |
N = len(outputs)
|
| 58 |
if N > 1:
|
| 59 |
-
|
| 60 |
-
std_output = torch.std(stacked_outputs, dim=1, keepdim=True)
|
| 61 |
-
|
| 62 |
-
# Normalize the standard deviation [0 - 1]
|
| 63 |
-
# Formula: std_max = sqrt(0.25 * N / (N - 1))
|
| 64 |
std_max = math.sqrt(0.25 * N / (N - 1))
|
| 65 |
-
uncertainty =
|
| 66 |
-
|
| 67 |
-
# Clamp to [0, 1] to avoid numerical issues
|
| 68 |
-
uncertainty = torch.clamp(uncertainty, 0.0, 1.0)
|
| 69 |
else:
|
| 70 |
-
|
| 71 |
-
uncertainty = torch.zeros_like(output_probs)
|
| 72 |
|
| 73 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
def compiled_model(
|
| 76 |
path: pathlib.Path,
|
|
@@ -79,238 +152,160 @@ def compiled_model(
|
|
| 79 |
*args, **kwargs
|
| 80 |
):
|
| 81 |
"""
|
| 82 |
-
Loads model(s)
|
| 83 |
-
|
| 84 |
-
- If single .pt2 → returns single model
|
| 85 |
-
- If multiple .pt2 → returns EnsembleModel
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
mode: Aggregation mode for ensembles (ignored for single models)
|
| 89 |
-
|
| 90 |
-
Returns:
|
| 91 |
-
Single model or EnsembleModel
|
| 92 |
"""
|
| 93 |
-
model_paths = [
|
| 94 |
-
|
| 95 |
-
if asset.href.endswith(".pt2")
|
| 96 |
-
|
| 97 |
|
| 98 |
if not model_paths:
|
| 99 |
raise ValueError("No .pt2 files found in STAC item assets.")
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
if len(model_paths) == 1:
|
| 104 |
-
|
| 105 |
-
return torch.export.load(model_paths[0]).module()
|
| 106 |
else:
|
| 107 |
-
|
| 108 |
-
models = [torch.export.load(p).module() for p in model_paths]
|
| 109 |
return EnsembleModel(*models, mode=mode)
|
| 110 |
|
| 111 |
-
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
|
| 112 |
-
"""
|
| 113 |
-
Defines iteration strategy to traverse the image with overlap.
|
| 114 |
-
"""
|
| 115 |
-
dimy, dimx = dimension
|
| 116 |
-
if chunk_size > max(dimx, dimy):
|
| 117 |
-
return [(0, 0)]
|
| 118 |
-
y_step = chunk_size - overlap
|
| 119 |
-
x_step = chunk_size - overlap
|
| 120 |
-
iterchunks = list(itertools.product(range(0, dimy, y_step), range(0, dimx, x_step)))
|
| 121 |
-
iterchunks_fixed = fix_lastchunk(
|
| 122 |
-
iterchunks=iterchunks, s2dim=dimension, chunk_size=chunk_size
|
| 123 |
-
)
|
| 124 |
-
return iterchunks_fixed
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def fix_lastchunk(iterchunks, s2dim, chunk_size):
|
| 128 |
-
"""
|
| 129 |
-
Adjusts last chunks to prevent them from exceeding boundaries.
|
| 130 |
-
"""
|
| 131 |
-
itercontainer = []
|
| 132 |
-
for index_i, index_j in iterchunks:
|
| 133 |
-
if index_i + chunk_size > s2dim[0]:
|
| 134 |
-
index_i = max(s2dim[0] - chunk_size, 0)
|
| 135 |
-
if index_j + chunk_size > s2dim[1]:
|
| 136 |
-
index_j = max(s2dim[1] - chunk_size, 0)
|
| 137 |
-
itercontainer.append((index_i, index_j))
|
| 138 |
-
return list(set(itercontainer)) # Returns unique values just in case
|
| 139 |
-
|
| 140 |
|
| 141 |
def predict_large(
|
| 142 |
image: np.ndarray,
|
| 143 |
model: torch.nn.Module,
|
| 144 |
chunk_size: int = 512,
|
| 145 |
-
overlap: int =
|
| 146 |
-
|
|
|
|
|
|
|
| 147 |
nodata: float = 0.0
|
| 148 |
) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
|
| 149 |
"""
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
Args:
|
| 153 |
-
image: Input array (C, H, W)
|
| 154 |
-
model: Compiled PyTorch model
|
| 155 |
-
chunk_size: Tile size for inference
|
| 156 |
-
overlap: Overlap between tiles
|
| 157 |
-
device: 'cpu' or 'cuda'
|
| 158 |
-
nodata: No-data value
|
| 159 |
-
|
| 160 |
-
Returns:
|
| 161 |
-
- For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
|
| 162 |
-
- For single models: probabilities array (1, H, W)
|
| 163 |
-
|
| 164 |
-
Compatible with:
|
| 165 |
-
- Normal models (with .eval()) - returns probabilities only
|
| 166 |
-
- Exported models (.pt2) - returns probabilities only
|
| 167 |
-
- Ensembles (EnsembleModel) - returns (probabilities, uncertainty)
|
| 168 |
"""
|
| 169 |
|
| 170 |
-
# Validate input array dimensions
|
| 171 |
if image.ndim != 3:
|
| 172 |
-
raise ValueError(f"Input
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
try:
|
| 178 |
model.eval()
|
| 179 |
-
for p in model.parameters():
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
model = model.to(device)
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
with torch.no_grad():
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
-
#
|
| 199 |
-
coords = define_iteration(
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
)
|
| 204 |
|
| 205 |
-
#
|
| 206 |
-
for
|
| 207 |
-
|
| 208 |
-
# Read chunk (numpy slicing)
|
| 209 |
-
patch = image[
|
| 210 |
-
:,
|
| 211 |
-
row_off : row_off + chunk_size,
|
| 212 |
-
col_off : col_off + chunk_size
|
| 213 |
-
]
|
| 214 |
-
|
| 215 |
-
# Convert to tensor and handle padding if tile is smaller than chunk_size
|
| 216 |
-
patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).to(device)
|
| 217 |
-
_, _, h_tile, w_tile = patch_tensor.shape
|
| 218 |
-
|
| 219 |
-
# Calculate padding needed
|
| 220 |
-
pad_h = chunk_size - h_tile
|
| 221 |
-
pad_w = chunk_size - w_tile
|
| 222 |
-
|
| 223 |
-
# Apply padding if necessary
|
| 224 |
-
if pad_h > 0 or pad_w > 0:
|
| 225 |
-
patch_tensor = torch.nn.functional.pad(
|
| 226 |
-
patch_tensor, (0, pad_w, 0, pad_h), "constant", nodata
|
| 227 |
-
)
|
| 228 |
|
| 229 |
-
#
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
| 233 |
with torch.no_grad():
|
| 234 |
-
model_output = model(patch_tensor)
|
| 235 |
-
|
| 236 |
if is_ensemble:
|
| 237 |
-
probs,
|
| 238 |
-
probs = probs.masked_fill(mask_all, nodata)
|
| 239 |
-
uncertainty = uncertainty.masked_fill(mask_all, nodata)
|
| 240 |
else:
|
| 241 |
-
probs =
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
# Remove batch dimension and ensure (1, H, W)
|
| 245 |
-
if probs.ndim == 4:
|
| 246 |
-
probs = probs.squeeze(0) # (1, H, W)
|
| 247 |
-
|
| 248 |
-
# Convert to numpy
|
| 249 |
-
result_probs = probs.cpu().numpy() # (1, H, W)
|
| 250 |
-
|
| 251 |
-
if is_ensemble:
|
| 252 |
-
if uncertainty.ndim == 4:
|
| 253 |
-
uncertainty = uncertainty.squeeze(0)
|
| 254 |
-
result_uncertainty = uncertainty.cpu().numpy()
|
| 255 |
|
| 256 |
-
#
|
| 257 |
-
if
|
| 258 |
-
|
| 259 |
-
else:
|
| 260 |
-
offset_x = col_off + overlap // 2
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
length_x = chunk_size - (overlap // 2)
|
| 272 |
-
sub_x_start = overlap // 2 if col_off != 0 else 0
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
else:
|
| 278 |
-
length_y = chunk_size - (overlap // 2)
|
| 279 |
-
sub_y_start = overlap // 2 if row_off != 0 else 0
|
| 280 |
-
|
| 281 |
-
# Ensure we don't exceed array bounds
|
| 282 |
-
if offset_y + length_y > height:
|
| 283 |
-
length_y = height - offset_y
|
| 284 |
-
if offset_x + length_x > width:
|
| 285 |
-
length_x = width - offset_x
|
| 286 |
-
|
| 287 |
-
# Extract the valid region from the result
|
| 288 |
-
to_write_probs = result_probs[
|
| 289 |
-
:,
|
| 290 |
-
sub_y_start : sub_y_start + length_y,
|
| 291 |
-
sub_x_start : sub_x_start + length_x
|
| 292 |
-
]
|
| 293 |
-
|
| 294 |
-
# Write to the output numpy array
|
| 295 |
-
output_probs[
|
| 296 |
-
:,
|
| 297 |
-
offset_y : offset_y + length_y,
|
| 298 |
-
offset_x : offset_x + length_x
|
| 299 |
-
] = to_write_probs
|
| 300 |
-
|
| 301 |
if is_ensemble:
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
]
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
if is_ensemble:
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
return
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
load.py
|
| 3 |
+
|
| 4 |
+
Module for loading ensemble models (STAC compatible) and performing
|
| 5 |
+
optimized inference on large geospatial imagery using dynamic batching
|
| 6 |
+
and Hann-window (raised-cosine) blending.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import pathlib
|
| 11 |
+
import itertools
|
| 12 |
+
from typing import Literal, Tuple, List
|
| 13 |
+
|
| 14 |
import torch
|
| 15 |
import torch.nn
|
|
|
|
|
|
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
+
import pystac
|
| 18 |
+
from torch.utils.data import Dataset, DataLoader
|
| 19 |
from tqdm import tqdm
|
|
|
|
| 20 |
|
| 21 |
+
# ==============================================================================
|
| 22 |
+
# 1. HELPER CLASSES & FUNCTIONS
|
| 23 |
+
# ==============================================================================
|
| 24 |
+
|
| 25 |
class EnsembleModel(torch.nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
Runtime ensemble model for combining multiple model outputs.
|
| 28 |
+
Used when loading multiple separate .pt2 files.
|
| 29 |
+
"""
|
| 30 |
def __init__(self, *models, mode="max"):
|
| 31 |
super(EnsembleModel, self).__init__()
|
| 32 |
self.models = torch.nn.ModuleList(models)
|
|
|
|
| 36 |
|
| 37 |
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run every member model on ``x`` and aggregate the member outputs.

    Member outputs are stacked along a new dimension 1; dimension 2 of the
    stack is then squeezed away when it has size 1 (assumes each member
    returns (B, 1, H, W), giving a (B, N, H, W) stack for N members).

    Returns:
        - probabilities: (B, 1, H, W), reduced over members per ``self.mode``
        - uncertainty: (B, 1, H, W) (normalized std dev), all zeros when the
          ensemble has a single member

        ``(None, None)`` when the ensemble has no members, and
        ``(stacked, None)`` when ``self.mode == "none"``.

    Raises:
        ValueError: if ``self.mode`` is not one of the supported modes.
    """
    outputs = [model(x) for model in self.models]

    if not outputs:
        return None, None

    # Stack: (B, N, H, W)
    stacked = torch.stack(outputs, dim=1).squeeze(2)

    # Aggregation
    if self.mode == "max":
        probs = torch.max(stacked, dim=1, keepdim=True)[0]
    elif self.mode == "mean":
        probs = torch.mean(stacked, dim=1, keepdim=True)
    elif self.mode == "median":
        probs = torch.median(stacked, dim=1, keepdim=True)[0]
    elif self.mode == "min":
        probs = torch.min(stacked, dim=1, keepdim=True)[0]
    elif self.mode == "none":
        return stacked, None
    else:
        # Without this branch an unknown mode would fall through and fail
        # later with an unrelated UnboundLocalError on ``probs``.
        raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")

    # Uncertainty
    N = len(outputs)
    if N > 1:
        std = torch.std(stacked, dim=1, keepdim=True)
        # Largest sample std dev of N values confined to [0, 1] (population
        # variance is at most 0.25, times the N / (N - 1) sample correction).
        # Assumes member outputs are probabilities in [0, 1].
        std_max = math.sqrt(0.25 * N / (N - 1))
        # Clamp guards against values slightly outside [0, 1].
        uncertainty = torch.clamp(std / std_max, 0.0, 1.0)
    else:
        uncertainty = torch.zeros_like(probs)

    return probs, uncertainty
|
| 73 |
+
|
| 74 |
+
def get_spline_window(window_size: int, power: int = 2) -> np.ndarray:
    """Generate a strictly positive 2D Hann window for smoothing tile edges.

    Args:
        window_size: Side length of the square window, in pixels.
        power: Exponent applied element-wise to the 2D window.

    Returns:
        A float32 array of shape (window_size, window_size) that peaks at the
        centre and decays towards the edges.
    """
    # np.hanning(M) is exactly 0 at both endpoints. Used as a blending weight,
    # that gives the outermost row/column of every tile zero weight, so pixels
    # covered only by a tile edge (the image border) would accumulate no
    # weight at all. Taking the interior of an (M + 2)-point window keeps the
    # same shape while making every weight positive.
    profile = np.hanning(window_size + 2)[1:-1]
    window_2d = np.outer(profile, profile)
    return (window_2d ** power).astype(np.float32)
|
| 79 |
+
|
| 80 |
+
def fix_lastchunk(iterchunks, s2dim, chunk_size):
    """Adjust the last chunks to fit within image boundaries.

    Args:
        iterchunks: Iterable of (row, col) top-left tile coordinates.
        s2dim: Image size as (height, width).
        chunk_size: Side length of a square tile.

    Returns:
        Sorted list of unique (row, col) coordinates in which any tile that
        would run past the bottom or right edge is shifted back so that it
        ends on the edge (or starts at 0 when the image is smaller than a
        tile).
    """
    # Largest top-left offset that still keeps a full tile inside the image.
    last_row = max(s2dim[0] - chunk_size, 0)
    last_col = max(s2dim[1] - chunk_size, 0)

    unique_coords = set()
    for index_i, index_j in iterchunks:
        if index_i + chunk_size > s2dim[0]:
            index_i = last_row
        if index_j + chunk_size > s2dim[1]:
            index_j = last_col
        unique_coords.add((index_i, index_j))

    # Shifting can make several tiles coincide, hence the set; sorting gives a
    # reproducible row-major traversal instead of arbitrary set order.
    return sorted(unique_coords)
|
| 90 |
+
|
| 91 |
+
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
    """Generate top-left coordinates for sliding window inference.

    Args:
        dimension: Image size as (height, width).
        chunk_size: Side length of a square tile.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        List of (row, col) top-left tile coordinates. When ``chunk_size`` is
        larger than both image dimensions a single tile at (0, 0) is returned
        and ``overlap`` is not used.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``chunk_size`` when the
            image has to be tiled.
    """
    dimy, dimx = dimension
    if chunk_size > max(dimx, dimy):
        return [(0, 0)]

    step = chunk_size - overlap
    if step <= 0:
        # A zero step makes range() raise an opaque error and a negative step
        # silently yields no tiles at all; fail with a clear message instead.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})."
        )

    iterchunks = list(itertools.product(
        range(0, dimy, step),
        range(0, dimx, step)
    ))

    return fix_lastchunk(iterchunks, dimension, chunk_size)
|
| 106 |
+
|
| 107 |
+
# ==============================================================================
|
| 108 |
+
# 2. DATASET FOR PARALLEL LOADING
|
| 109 |
+
# ==============================================================================
|
| 110 |
+
|
| 111 |
+
class PatchDataset(Dataset):
    """
    Dataset wrapper to handle image slicing and padding on CPU workers.

    Each item is one square tile cut from ``image`` at a top-left coordinate
    taken from ``coords``, padded with ``nodata`` up to ``chunk_size``.
    """
    def __init__(self, image: np.ndarray, coords: List[Tuple[int, int]], chunk_size: int, nodata: float = 0):
        # image is sliced as image[:, rows, cols], i.e. channels-first.
        self.image = image
        self.coords = coords
        self.chunk_size = chunk_size
        self.nodata = nodata

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(patch, row_off, col_off, h, w, mask_nodata)`` for tile ``idx``.

        ``patch`` is a float32 tensor of shape (C, chunk_size, chunk_size),
        ``h``/``w`` are the tile's extent before padding, and ``mask_nodata``
        is a (chunk_size, chunk_size) bool tensor that is True where every
        channel equals ``nodata`` (this includes the padded area).
        """
        row_off, col_off = self.coords[idx]
        size = self.chunk_size

        # Read patch; slicing past the image edge just yields a smaller patch.
        patch = self.image[:, row_off : row_off + size, col_off : col_off + size]
        _, h, w = patch.shape

        # Cast on the NumPy side: torch.from_numpy does not accept every
        # NumPy dtype (uint16 is rejected by older torch releases), whereas
        # float32 always converts.
        patch_tensor = torch.from_numpy(np.asarray(patch, dtype=np.float32))

        # Apply padding if patch is smaller than chunk_size (edges)
        pad_h = size - h
        pad_w = size - w
        if pad_h > 0 or pad_w > 0:
            patch_tensor = torch.nn.functional.pad(patch_tensor, (0, pad_w, 0, pad_h), "constant", self.nodata)

        # Identify nodata pixels. NaN never compares equal to itself, so a
        # NaN nodata value needs an explicit isnan test.
        if bool(np.isnan(self.nodata)):
            is_nodata = torch.isnan(patch_tensor)
        else:
            is_nodata = patch_tensor == self.nodata
        mask_nodata = is_nodata.all(dim=0)

        return patch_tensor, row_off, col_off, h, w, mask_nodata
|
| 143 |
+
|
| 144 |
+
# ==============================================================================
|
| 145 |
+
# 3. LOADING & INFERENCE LOGIC
|
| 146 |
+
# ==============================================================================
|
| 147 |
|
| 148 |
def compiled_model(
|
| 149 |
path: pathlib.Path,
|
|
|
|
| 152 |
*args, **kwargs
|
| 153 |
):
|
| 154 |
"""
|
| 155 |
+
Loads .pt2 model(s). Returns a single model or an EnsembleModel.
|
| 156 |
+
Automatically unwraps ExportedProgram if possible.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
"""
|
| 158 |
+
model_paths = sorted([
|
| 159 |
+
asset.href for key, asset in stac_item.assets.items()
|
| 160 |
+
if asset.href.endswith(".pt2")
|
| 161 |
+
])
|
| 162 |
|
| 163 |
if not model_paths:
|
| 164 |
raise ValueError("No .pt2 files found in STAC item assets.")
|
| 165 |
|
| 166 |
+
# Helper to load and unwrap
|
| 167 |
+
def load_pt2(p):
|
| 168 |
+
program = torch.export.load(p)
|
| 169 |
+
return program.module() if hasattr(program, "module") else program
|
| 170 |
+
|
| 171 |
if len(model_paths) == 1:
|
| 172 |
+
return load_pt2(model_paths[0])
|
|
|
|
| 173 |
else:
|
| 174 |
+
models = [load_pt2(p) for p in model_paths]
|
|
|
|
| 175 |
return EnsembleModel(*models, mode=mode)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
def predict_large(
    image: np.ndarray,
    model: torch.nn.Module,
    chunk_size: int = 512,
    overlap: int = 128,
    batch_size: int = 16,
    num_workers: int = 8,  # Recommended: 8-16
    device: str = "cuda",
    nodata: float = 0.0
) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
    """
    Optimized inference for large images using batched tiles and window blending.

    The image is cut into overlapping square tiles, the tiles are run through
    ``model`` in batches, and the per-tile outputs are averaged back into a
    full-size map using a 2D window as per-pixel weight so that tile centres
    dominate over tile edges.

    Args:
        image: Input array of shape (C, H, W).
        model: Model to run. Objects exposing a callable ``module()`` (such as
            the container returned by ``torch.export.load``) are unwrapped.
        chunk_size: Side length of the square tiles fed to the model.
        overlap: Number of pixels shared by neighbouring tiles.
        batch_size: Number of tiles per forward pass.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        device: Device used for the model and the tiles.
        nodata: Value marking missing input. Pixels where all channels equal
            it get zero weight; output pixels with no weight are set to it.

    Returns:
        ``(probabilities, uncertainty)`` when the model returns a pair whose
        second element is not None, otherwise ``probabilities`` alone. Arrays
        are float32 with shape (C_out, H, W), where C_out is the channel count
        of the corresponding model output (1 for a 3D output).

    Raises:
        ValueError: if ``image`` is not 3-dimensional, or if the tiling
            parameters produce no tiles.
    """

    if image.ndim != 3:
        raise ValueError(f"Input image must be (C, H, W). Received {image.shape}")

    # --- 1. Robust Model Unwrapping ---
    # Fix for torch.export.load() returning an ExportedProgram container.
    # Best effort: if unwrapping fails the object is used as it is.
    if hasattr(model, "module") and callable(model.module):
        try:
            unpacked = model.module()
            if isinstance(unpacked, torch.nn.Module):
                model = unpacked
        except Exception:
            pass

    # --- 2. Setup Model ---
    # Both steps are best effort because not every supported model object
    # implements them (presumably exported graph modules may reject eval()).
    # `except Exception` rather than a bare `except` so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        model.eval()
    except Exception:
        pass
    try:
        for p in model.parameters():
            p.requires_grad = False
    except Exception:
        pass

    # Only move to device if it's a standard Module
    if isinstance(model, torch.nn.Module):
        model = model.to(device)

    bands, height, width = image.shape

    # --- 3. Check Signature (Ensemble vs Single) ---
    # Dummy pass (batch=2 to respect dynamic shapes)
    dummy = torch.randn(2, bands, chunk_size, chunk_size, device=device)
    with torch.no_grad():
        out = model(dummy)
    returns_pair = isinstance(out, tuple) and len(out) == 2
    # A pair whose second element is None carries no uncertainty map, so it
    # is handled like a single-output model.
    is_ensemble = returns_pair and out[1] is not None

    def _channels(t: torch.Tensor) -> int:
        """Channel count of a (B, C, H, W) output; 1 for a (B, H, W) output."""
        return t.shape[1] if t.ndim == 4 else 1

    probs_channels = _channels(out[0] if returns_pair else out)

    # --- 4. Initialize Buffers (Accumulators) ---
    out_probs = np.zeros((probs_channels, height, width), dtype=np.float32)
    count_map = np.zeros((1, height, width), dtype=np.float32)
    out_uncert = (
        np.zeros((_channels(out[1]), height, width), dtype=np.float32)
        if is_ensemble else None
    )

    # --- 5. Prepare Blending Window ---
    window_spline = get_spline_window(chunk_size, power=2)
    # Keep every weight strictly positive: a weight of exactly 0 on the tile
    # border would leave image-border pixels (covered only by tile borders)
    # without any accumulated weight, and they would be reported as nodata.
    window_spline = np.maximum(window_spline, np.float32(1e-6)).astype(np.float32)
    window_tensor = torch.from_numpy(window_spline).to(device)

    # --- 6. DataLoader Setup ---
    coords = define_iteration((height, width), chunk_size, overlap)
    if not coords:
        raise ValueError(
            f"No tiles generated for image {height}x{width} with "
            f"chunk_size={chunk_size}, overlap={overlap}."
        )
    dataset = PatchDataset(image, coords, chunk_size, nodata)
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        "pin_memory": str(device).startswith("cuda"),
    }
    if num_workers > 0:
        # DataLoader rejects prefetch_factor when loading in the main process.
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(dataset, **loader_kwargs)

    # --- 7. Inference Loop ---
    for batch in tqdm(loader, desc=f"Inference (Batch {batch_size})"):
        patches, r_offs, c_offs, h_actuals, w_actuals, nodata_masks = batch

        # Move inputs to the target device
        patches = patches.to(device, non_blocking=True)
        nodata_masks = nodata_masks.to(device, non_blocking=True)  # (B, H, W)

        # Forward Pass
        with torch.no_grad():
            result = model(patches)
        if returns_pair:
            probs, uncert = result
        else:
            probs, uncert = result, None

        # Ensure correct dimensions (B, C, H, W)
        if probs.ndim == 3:
            probs = probs.unsqueeze(1)
        if is_ensemble and uncert.ndim == 3:
            uncert = uncert.unsqueeze(1)

        # Prepare weights for batch: one copy of the window per tile
        n_tiles = patches.size(0)
        batch_weights = window_tensor.unsqueeze(0).unsqueeze(0).repeat(n_tiles, 1, 1, 1)

        # Zero out weights where input was nodata
        batch_weights[nodata_masks.unsqueeze(1)] = 0.0

        # Apply weights and move to CPU
        probs_cpu = (probs * batch_weights).cpu().numpy()
        weights_cpu = batch_weights.cpu().numpy()
        uncert_cpu = (uncert * batch_weights).cpu().numpy() if is_ensemble else None

        # Accumulate in global map
        tile_geometry = zip(
            r_offs.tolist(), c_offs.tolist(), h_actuals.tolist(), w_actuals.tolist()
        )
        for i, (r, c, h, w) in enumerate(tile_geometry):
            # Only the top-left h x w of each tile is real image; the rest
            # is padding added by the dataset.
            out_probs[:, r:r+h, c:c+w] += probs_cpu[i, :, :h, :w]
            count_map[:, r:r+h, c:c+w] += weights_cpu[i, :, :h, :w]

            if is_ensemble:
                out_uncert[:, r:r+h, c:c+w] += uncert_cpu[i, :, :h, :w]

    # --- 8. Normalization ---
    mask_zero = (count_map == 0)
    count_map[mask_zero] = 1.0  # Prevent div/0
    unweighted = mask_zero[0]  # (H, W): pixels that received no weight

    out_probs /= count_map
    out_probs[:, unweighted] = nodata

    if is_ensemble:
        out_uncert /= count_map
        out_uncert[:, unweighted] = nodata
        return out_probs, out_uncert

    return out_probs
|