JulioContrerasH commited on
Commit
d6a745b
·
verified ·
1 Parent(s): 39c6fd8

Upload: load.py

Browse files
Files changed (1) hide show
  1. single/load.py +232 -237
single/load.py CHANGED
@@ -1,15 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn
3
- import pathlib
4
- import pystac
5
- from typing import Literal, Tuple
6
  import numpy as np
7
- import itertools
 
8
  from tqdm import tqdm
9
- import math
10
 
11
- # Ensemble model for combining multiple models' outputs
 
 
 
12
  class EnsembleModel(torch.nn.Module):
 
 
 
 
13
  def __init__(self, *models, mode="max"):
14
  super(EnsembleModel, self).__init__()
15
  self.models = torch.nn.ModuleList(models)
@@ -19,58 +36,114 @@ class EnsembleModel(torch.nn.Module):
19
 
20
  def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
21
  """
22
- Forward pass for ensemble.
23
-
24
  Returns:
25
- Tuple of (probabilities, uncertainty):
26
- - probabilities: (B, 1, H, W) - aggregated predictions
27
- - uncertainty: (B, 1, H, W) - normalized std deviation
28
  """
29
- outputs = []
30
- for model in self.models:
31
- output = model(x)
32
- outputs.append(output)
33
 
34
  if not outputs:
35
  return None, None
36
 
37
- # Stack all model outputs: (B, N, H, W) where N = number of models
38
- stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
39
- stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
40
 
41
- # Calculate aggregated probabilities
42
  if self.mode == "max":
43
- output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
44
  elif self.mode == "mean":
45
- output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
46
  elif self.mode == "median":
47
- output_probs = torch.median(stacked_outputs, dim=1, keepdim=True)[0]
48
  elif self.mode == "min":
49
- output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
50
  elif self.mode == "none":
51
- # Return all predictions without aggregation
52
- return stacked_outputs, None
53
- else:
54
- raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
55
-
56
- # Calculate uncertainty (normalized standard deviation)
57
  N = len(outputs)
58
  if N > 1:
59
- # Calculate std across models (dim=1)
60
- std_output = torch.std(stacked_outputs, dim=1, keepdim=True)
61
-
62
- # Normalize the standard deviation [0 - 1]
63
- # Formula: std_max = sqrt(0.25 * N / (N - 1))
64
  std_max = math.sqrt(0.25 * N / (N - 1))
65
- uncertainty = std_output / std_max
66
-
67
- # Clamp to [0, 1] to avoid numerical issues
68
- uncertainty = torch.clamp(uncertainty, 0.0, 1.0)
69
  else:
70
- # Single model: no uncertainty
71
- uncertainty = torch.zeros_like(output_probs)
72
 
73
- return output_probs, uncertainty # Both (B, 1, H, W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def compiled_model(
76
  path: pathlib.Path,
@@ -79,238 +152,160 @@ def compiled_model(
79
  *args, **kwargs
80
  ):
81
  """
82
- Loads model(s) dynamically based on STAC metadata.
83
-
84
- - If single .pt2 → returns single model
85
- - If multiple .pt2 → returns EnsembleModel
86
-
87
- Args:
88
- mode: Aggregation mode for ensembles (ignored for single models)
89
-
90
- Returns:
91
- Single model or EnsembleModel
92
  """
93
- model_paths = []
94
- for asset_key, asset in stac_item.assets.items():
95
- if asset.href.endswith(".pt2"):
96
- model_paths.append(asset.href)
97
 
98
  if not model_paths:
99
  raise ValueError("No .pt2 files found in STAC item assets.")
100
 
101
- model_paths.sort()
102
-
 
 
 
103
  if len(model_paths) == 1:
104
- # Single model
105
- return torch.export.load(model_paths[0]).module()
106
  else:
107
- # Ensemble model
108
- models = [torch.export.load(p).module() for p in model_paths]
109
  return EnsembleModel(*models, mode=mode)
110
 
111
- def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
112
- """
113
- Defines iteration strategy to traverse the image with overlap.
114
- """
115
- dimy, dimx = dimension
116
- if chunk_size > max(dimx, dimy):
117
- return [(0, 0)]
118
- y_step = chunk_size - overlap
119
- x_step = chunk_size - overlap
120
- iterchunks = list(itertools.product(range(0, dimy, y_step), range(0, dimx, x_step)))
121
- iterchunks_fixed = fix_lastchunk(
122
- iterchunks=iterchunks, s2dim=dimension, chunk_size=chunk_size
123
- )
124
- return iterchunks_fixed
125
-
126
-
127
- def fix_lastchunk(iterchunks, s2dim, chunk_size):
128
- """
129
- Adjusts last chunks to prevent them from exceeding boundaries.
130
- """
131
- itercontainer = []
132
- for index_i, index_j in iterchunks:
133
- if index_i + chunk_size > s2dim[0]:
134
- index_i = max(s2dim[0] - chunk_size, 0)
135
- if index_j + chunk_size > s2dim[1]:
136
- index_j = max(s2dim[1] - chunk_size, 0)
137
- itercontainer.append((index_i, index_j))
138
- return list(set(itercontainer)) # Returns unique values just in case
139
-
140
 
141
  def predict_large(
142
  image: np.ndarray,
143
  model: torch.nn.Module,
144
  chunk_size: int = 512,
145
- overlap: int = 64,
146
- device: str = "cpu",
 
 
147
  nodata: float = 0.0
148
  ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
149
  """
150
- Predict a full 'image' (C, H, W) using overlapping patches.
151
-
152
- Args:
153
- image: Input array (C, H, W)
154
- model: Compiled PyTorch model
155
- chunk_size: Tile size for inference
156
- overlap: Overlap between tiles
157
- device: 'cpu' or 'cuda'
158
- nodata: No-data value
159
-
160
- Returns:
161
- - For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
162
- - For single models: probabilities array (1, H, W)
163
-
164
- Compatible with:
165
- - Normal models (with .eval()) - returns probabilities only
166
- - Exported models (.pt2) - returns probabilities only
167
- - Ensembles (EnsembleModel) - returns (probabilities, uncertainty)
168
  """
169
 
170
- # Validate input array dimensions
171
  if image.ndim != 3:
172
- raise ValueError(f"Input array must be (C, H, W). Received {image.shape}")
173
 
174
- bands, height, width = image.shape
175
-
176
- # Prepare model (compatibility logic for .pt2 models)
 
 
 
 
 
 
 
 
177
  try:
178
  model.eval()
179
- for p in model.parameters():
180
- p.requires_grad = False
181
- model = model.to(device)
182
- except (NotImplementedError, AttributeError):
183
- # Exported model (.pt2) or EnsembleModel
184
  model = model.to(device)
185
-
186
- test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
 
 
 
 
187
  with torch.no_grad():
188
- test_output = model(test_input)
189
-
190
- is_ensemble = isinstance(test_output, tuple) and len(test_output) == 2
191
-
192
- # Initialize output arrays
193
- output_probs = np.full((1, height, width), nodata, dtype=np.float32)
194
-
195
- if is_ensemble:
196
- output_uncertainty = np.full((1, height, width), nodata, dtype=np.float32)
 
 
197
 
198
- # Get the list of tile offsets
199
- coords = define_iteration(
200
- dimension=(height, width),
201
- chunk_size=chunk_size,
202
- overlap=overlap
 
 
 
 
 
203
  )
204
 
205
- # Iterate over tiles
206
- for idx, (row_off, col_off) in enumerate(tqdm(coords, desc="Inference")):
207
-
208
- # Read chunk (numpy slicing)
209
- patch = image[
210
- :,
211
- row_off : row_off + chunk_size,
212
- col_off : col_off + chunk_size
213
- ]
214
-
215
- # Convert to tensor and handle padding if tile is smaller than chunk_size
216
- patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).to(device)
217
- _, _, h_tile, w_tile = patch_tensor.shape
218
-
219
- # Calculate padding needed
220
- pad_h = chunk_size - h_tile
221
- pad_w = chunk_size - w_tile
222
-
223
- # Apply padding if necessary
224
- if pad_h > 0 or pad_w > 0:
225
- patch_tensor = torch.nn.functional.pad(
226
- patch_tensor, (0, pad_w, 0, pad_h), "constant", nodata
227
- )
228
 
229
- # Create mask for nodata areas (all bands are nodata)
230
- mask_all = (patch_tensor == nodata).all(dim=1, keepdim=True)
231
-
232
- # Forward pass
 
233
  with torch.no_grad():
234
- model_output = model(patch_tensor)
235
-
236
  if is_ensemble:
237
- probs, uncertainty = model_output
238
- probs = probs.masked_fill(mask_all, nodata)
239
- uncertainty = uncertainty.masked_fill(mask_all, nodata)
240
  else:
241
- probs = model_output
242
- probs = probs.masked_fill(mask_all, nodata)
243
-
244
- # Remove batch dimension and ensure (1, H, W)
245
- if probs.ndim == 4:
246
- probs = probs.squeeze(0) # (1, H, W)
247
-
248
- # Convert to numpy
249
- result_probs = probs.cpu().numpy() # (1, H, W)
250
-
251
- if is_ensemble:
252
- if uncertainty.ndim == 4:
253
- uncertainty = uncertainty.squeeze(0)
254
- result_uncertainty = uncertainty.cpu().numpy()
255
 
256
- # Logic for partial writing
257
- if col_off == 0:
258
- offset_x = 0
259
- else:
260
- offset_x = col_off + overlap // 2
261
 
262
- if row_off == 0:
263
- offset_y = 0
264
- else:
265
- offset_y = row_off + overlap // 2
 
 
266
 
267
- if (offset_x + chunk_size) == width:
268
- length_x = chunk_size
269
- sub_x_start = 0
270
- else:
271
- length_x = chunk_size - (overlap // 2)
272
- sub_x_start = overlap // 2 if col_off != 0 else 0
273
 
274
- if (offset_y + chunk_size) == height:
275
- length_y = chunk_size
276
- sub_y_start = 0
277
- else:
278
- length_y = chunk_size - (overlap // 2)
279
- sub_y_start = overlap // 2 if row_off != 0 else 0
280
-
281
- # Ensure we don't exceed array bounds
282
- if offset_y + length_y > height:
283
- length_y = height - offset_y
284
- if offset_x + length_x > width:
285
- length_x = width - offset_x
286
-
287
- # Extract the valid region from the result
288
- to_write_probs = result_probs[
289
- :,
290
- sub_y_start : sub_y_start + length_y,
291
- sub_x_start : sub_x_start + length_x
292
- ]
293
-
294
- # Write to the output numpy array
295
- output_probs[
296
- :,
297
- offset_y : offset_y + length_y,
298
- offset_x : offset_x + length_x
299
- ] = to_write_probs
300
-
301
  if is_ensemble:
302
- to_write_uncertainty = result_uncertainty[
303
- :,
304
- sub_y_start : sub_y_start + length_y,
305
- sub_x_start : sub_x_start + length_x
306
- ]
307
- output_uncertainty[
308
- :,
309
- offset_y : offset_y + length_y,
310
- offset_x : offset_x + length_x
311
- ] = to_write_uncertainty
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  if is_ensemble:
314
- return output_probs, output_uncertainty
315
- else:
316
- return output_probs
 
 
 
1
+ """
2
+ load.py
3
+
4
+ Module for loading ensemble models (STAC compatible) and performing
5
+ optimized inference on large geospatial imagery using dynamic batching
6
+ and Gaussian blending.
7
+ """
8
+
9
+ import math
10
+ import pathlib
11
+ import itertools
12
+ from typing import Literal, Tuple, List
13
+
14
  import torch
15
  import torch.nn
 
 
 
16
  import numpy as np
17
+ import pystac
18
+ from torch.utils.data import Dataset, DataLoader
19
  from tqdm import tqdm
 
20
 
21
+ # ==============================================================================
22
+ # 1. HELPER CLASSES & FUNCTIONS
23
+ # ==============================================================================
24
+
25
  class EnsembleModel(torch.nn.Module):
26
+ """
27
+ Runtime ensemble model for combining multiple model outputs.
28
+ Used when loading multiple separate .pt2 files.
29
+ """
30
  def __init__(self, *models, mode="max"):
31
  super(EnsembleModel, self).__init__()
32
  self.models = torch.nn.ModuleList(models)
 
36
 
37
  def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
38
  """
 
 
39
  Returns:
40
+ - probabilities: (B, 1, H, W)
41
+ - uncertainty: (B, 1, H, W) (normalized std dev)
 
42
  """
43
+ outputs = [model(x) for model in self.models]
 
 
 
44
 
45
  if not outputs:
46
  return None, None
47
 
48
+ # Stack: (B, N, H, W)
49
+ stacked = torch.stack(outputs, dim=1).squeeze(2)
 
50
 
51
+ # Aggregation
52
  if self.mode == "max":
53
+ probs = torch.max(stacked, dim=1, keepdim=True)[0]
54
  elif self.mode == "mean":
55
+ probs = torch.mean(stacked, dim=1, keepdim=True)
56
  elif self.mode == "median":
57
+ probs = torch.median(stacked, dim=1, keepdim=True)[0]
58
  elif self.mode == "min":
59
+ probs = torch.min(stacked, dim=1, keepdim=True)[0]
60
  elif self.mode == "none":
61
+ return stacked, None
62
+
63
+ # Uncertainty
 
 
 
64
  N = len(outputs)
65
  if N > 1:
66
+ std = torch.std(stacked, dim=1, keepdim=True)
 
 
 
 
67
  std_max = math.sqrt(0.25 * N / (N - 1))
68
+ uncertainty = torch.clamp(std / std_max, 0.0, 1.0)
 
 
 
69
  else:
70
+ uncertainty = torch.zeros_like(probs)
 
71
 
72
+ return probs, uncertainty
73
+
74
+ def get_spline_window(window_size: int, power: int = 2) -> np.ndarray:
75
+ """Generates a 2D Hann window for smoothing tile edges."""
76
+ intersection = np.hanning(window_size)
77
+ window_2d = np.outer(intersection, intersection)
78
+ return (window_2d ** power).astype(np.float32)
79
+
80
+ def fix_lastchunk(iterchunks, s2dim, chunk_size):
81
+ """Adjusts the last chunks to fit within image boundaries."""
82
+ itercontainer = []
83
+ for index_i, index_j in iterchunks:
84
+ if index_i + chunk_size > s2dim[0]:
85
+ index_i = max(s2dim[0] - chunk_size, 0)
86
+ if index_j + chunk_size > s2dim[1]:
87
+ index_j = max(s2dim[1] - chunk_size, 0)
88
+ itercontainer.append((index_i, index_j))
89
+ return list(set(itercontainer))
90
+
91
+ def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
92
+ """Generates top-left coordinates for sliding window inference."""
93
+ dimy, dimx = dimension
94
+ if chunk_size > max(dimx, dimy):
95
+ return [(0, 0)]
96
+
97
+ y_step = chunk_size - overlap
98
+ x_step = chunk_size - overlap
99
+
100
+ iterchunks = list(itertools.product(
101
+ range(0, dimy, y_step),
102
+ range(0, dimx, x_step)
103
+ ))
104
+
105
+ return fix_lastchunk(iterchunks, dimension, chunk_size)
106
+
107
+ # ==============================================================================
108
+ # 2. DATASET FOR PARALLEL LOADING
109
+ # ==============================================================================
110
+
111
+ class PatchDataset(Dataset):
112
+ """
113
+ Dataset wrapper to handle image slicing and padding on CPU workers.
114
+ """
115
+ def __init__(self, image: np.ndarray, coords: List[Tuple[int, int]], chunk_size: int, nodata: float = 0):
116
+ self.image = image
117
+ self.coords = coords
118
+ self.chunk_size = chunk_size
119
+ self.nodata = nodata
120
+
121
+ def __len__(self):
122
+ return len(self.coords)
123
+
124
+ def __getitem__(self, idx):
125
+ row_off, col_off = self.coords[idx]
126
+
127
+ # Read patch
128
+ patch = self.image[:, row_off : row_off + self.chunk_size, col_off : col_off + self.chunk_size]
129
+ c, h, w = patch.shape
130
+
131
+ patch_tensor = torch.from_numpy(patch).float()
132
+
133
+ # Apply padding if patch is smaller than chunk_size (edges)
134
+ pad_h = self.chunk_size - h
135
+ pad_w = self.chunk_size - w
136
+ if pad_h > 0 or pad_w > 0:
137
+ patch_tensor = torch.nn.functional.pad(patch_tensor, (0, pad_w, 0, pad_h), "constant", self.nodata)
138
+
139
+ # Identify nodata pixels
140
+ mask_nodata = (patch_tensor == self.nodata).all(dim=0)
141
+
142
+ return patch_tensor, row_off, col_off, h, w, mask_nodata
143
+
144
+ # ==============================================================================
145
+ # 3. LOADING & INFERENCE LOGIC
146
+ # ==============================================================================
147
 
148
  def compiled_model(
149
  path: pathlib.Path,
 
152
  *args, **kwargs
153
  ):
154
  """
155
+ Loads .pt2 model(s). Returns a single model or an EnsembleModel.
156
+ Automatically unwraps ExportedProgram if possible.
 
 
 
 
 
 
 
 
157
  """
158
+ model_paths = sorted([
159
+ asset.href for key, asset in stac_item.assets.items()
160
+ if asset.href.endswith(".pt2")
161
+ ])
162
 
163
  if not model_paths:
164
  raise ValueError("No .pt2 files found in STAC item assets.")
165
 
166
+ # Helper to load and unwrap
167
+ def load_pt2(p):
168
+ program = torch.export.load(p)
169
+ return program.module() if hasattr(program, "module") else program
170
+
171
  if len(model_paths) == 1:
172
+ return load_pt2(model_paths[0])
 
173
  else:
174
+ models = [load_pt2(p) for p in model_paths]
 
175
  return EnsembleModel(*models, mode=mode)
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  def predict_large(
179
  image: np.ndarray,
180
  model: torch.nn.Module,
181
  chunk_size: int = 512,
182
+ overlap: int = 128,
183
+ batch_size: int = 16,
184
+ num_workers: int = 8, # Recommended: 8-16
185
+ device: str = "cuda",
186
  nodata: float = 0.0
187
  ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
188
  """
189
+ Optimized inference for large images using Dynamic Batching and Gaussian Blending.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  """
191
 
 
192
  if image.ndim != 3:
193
+ raise ValueError(f"Input image must be (C, H, W). Received {image.shape}")
194
 
195
+ # --- 1. Robust Model Unwrapping ---
196
+ # Fix for torch.export.load() returning an ExportedProgram container
197
+ if hasattr(model, "module") and callable(model.module):
198
+ try:
199
+ unpacked = model.module()
200
+ if isinstance(unpacked, torch.nn.Module):
201
+ model = unpacked
202
+ except Exception:
203
+ pass
204
+
205
+ # --- 2. Setup Model ---
206
  try:
207
  model.eval()
208
+ for p in model.parameters(): p.requires_grad = False
209
+ except: pass
210
+
211
+ # Only move to device if it's a standard Module (ExportedProgram handles device internally or via input)
212
+ if isinstance(model, torch.nn.Module):
213
  model = model.to(device)
214
+
215
+ bands, height, width = image.shape
216
+
217
+ # --- 3. Check Signature (Ensemble vs Single) ---
218
+ # Dummy pass (batch=2 to respect dynamic shapes)
219
+ dummy = torch.randn(2, bands, chunk_size, chunk_size).to(device)
220
  with torch.no_grad():
221
+ out = model(dummy)
222
+ is_ensemble = isinstance(out, tuple) and len(out) == 2
223
+
224
+ # --- 4. Initialize Buffers (Accumulators) ---
225
+ out_probs = np.zeros((1, height, width), dtype=np.float32)
226
+ count_map = np.zeros((1, height, width), dtype=np.float32)
227
+ out_uncert = np.zeros((1, height, width), dtype=np.float32) if is_ensemble else None
228
+
229
+ # --- 5. Prepare Spline Window ---
230
+ window_spline = get_spline_window(chunk_size, power=2)
231
+ window_tensor = torch.from_numpy(window_spline).to(device)
232
 
233
+ # --- 6. DataLoader Setup ---
234
+ coords = define_iteration((height, width), chunk_size, overlap)
235
+ dataset = PatchDataset(image, coords, chunk_size, nodata)
236
+ loader = DataLoader(
237
+ dataset,
238
+ batch_size=batch_size,
239
+ shuffle=False,
240
+ num_workers=num_workers,
241
+ prefetch_factor=2,
242
+ pin_memory=True
243
  )
244
 
245
+ # --- 7. Inference Loop ---
246
+ for batch in tqdm(loader, desc=f"Inference (Batch {batch_size})"):
247
+ patches, r_offs, c_offs, h_actuals, w_actuals, nodata_masks = batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
+ # Move inputs to GPU
250
+ patches = patches.to(device, non_blocking=True)
251
+ nodata_masks = nodata_masks.to(device, non_blocking=True) # (B, H, W)
252
+
253
+ # Forward Pass
254
  with torch.no_grad():
 
 
255
  if is_ensemble:
256
+ probs, uncert = model(patches)
 
 
257
  else:
258
+ probs = model(patches)
259
+ uncert = None
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
+ # Ensure correct dimensions (B, C, H, W)
262
+ if probs.ndim == 3: probs = probs.unsqueeze(1)
263
+ if is_ensemble and uncert.ndim == 3: uncert = uncert.unsqueeze(1)
 
 
264
 
265
+ # Prepare weights for batch
266
+ B = patches.size(0)
267
+ batch_weights = window_tensor.unsqueeze(0).unsqueeze(0).repeat(B, 1, 1, 1)
268
+
269
+ # Zero out weights where input was nodata
270
+ batch_weights[nodata_masks.unsqueeze(1)] = 0.0
271
 
272
+ # Apply weights
273
+ probs_weighted = probs * batch_weights
274
+ if is_ensemble:
275
+ uncert_weighted = uncert * batch_weights
 
 
276
 
277
+ # Move to CPU
278
+ probs_cpu = probs_weighted.cpu().numpy()
279
+ weights_cpu = batch_weights.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  if is_ensemble:
281
+ uncert_cpu = uncert_weighted.cpu().numpy()
282
+
283
+ # Accumulate in global map
284
+ for i in range(B):
285
+ r, c = r_offs[i].item(), c_offs[i].item()
286
+ h, w = h_actuals[i].item(), w_actuals[i].item()
287
+
288
+ # Slice valid regions
289
+ valid_probs = probs_cpu[i, :, :h, :w]
290
+ valid_weights = weights_cpu[i, :, :h, :w]
291
 
292
+ out_probs[:, r:r+h, c:c+w] += valid_probs
293
+ count_map[:, r:r+h, c:c+w] += valid_weights
294
+
295
+ if is_ensemble:
296
+ valid_uncert = uncert_cpu[i, :, :h, :w]
297
+ out_uncert[:, r:r+h, c:c+w] += valid_uncert
298
+
299
+ # --- 8. Normalization ---
300
+ mask_zero = (count_map == 0)
301
+ count_map[mask_zero] = 1.0 # Prevent div/0
302
+
303
+ out_probs /= count_map
304
+ out_probs[mask_zero] = nodata
305
+
306
  if is_ensemble:
307
+ out_uncert /= count_map
308
+ out_uncert[mask_zero] = nodata
309
+ return out_probs, out_uncert
310
+
311
+ return out_probs