Training loop and around the codebase

Files changed:

- .gitignore (+6 -0)
- README.md (+22 -1)
- SEGMENTATION_PLAN.md (+10 -2)
- configs/default.yaml (+14 -4)
- requirements.txt (+1 -0)
- src/wireseghr/data/dataset.py (+20 -14)
- src/wireseghr/data/sampler.py (+19 -4)
- src/wireseghr/metrics.py (+29 -6)
- src/wireseghr/model/encoder.py (+120 -12)
- src/wireseghr/model/model.py (+2 -2)
- src/wireseghr/train.py (+386 -1)
.gitignore

```diff
@@ -114,3 +114,9 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+
+# Secrets
+secrets/
+
+# dataset
+dataset/
```
README.md

````diff
@@ -28,4 +28,25 @@ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/ima
 
 ## Notes
 - This is a segmentation-only codebase. Inpainting is out of scope here.
-- Defaults locked: MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
+- Defaults locked: SegFormer MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
+
+### Backbone Source
+- Preferred: HuggingFace Transformers SegFormer (e.g., `nvidia/mit-b3`). We set `num_channels` to match input channels.
+- Optional: `timm` features_only if a compatible SegFormer is available.
+- Fallback: a small internal CNN that preserves 1/4, 1/8, 1/16, 1/32 strides with channels [64, 128, 320, 512].
+
+Install requirements to get Transformers:
+```
+pip install -r requirements.txt
+```
+
+## Dataset Convention
+- Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
+- Example (after split 85/5/10):
+  - `dataset/train/images/1.jpg, 2.jpg, ...` and `dataset/train/gts/1.png, 2.png, ...`
+  - `dataset/val/images/...` and `dataset/val/gts/...`
+  - `dataset/test/images/...` and `dataset/test/gts/...`
+- Masks are binary: foreground = white (255), background = black (0).
+- The loader strictly enforces numeric stems and 1:1 pairing and will assert on mismatches.
+
+Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
````
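For a quick pre-flight check of this layout, here is a minimal sketch (a hypothetical helper, not part of the repo) that walks each split and asserts the same pairing rules the loader enforces:

```python
# Hypothetical pre-flight check mirroring the loader's pairing rules.
from pathlib import Path

def check_split(images_dir: str, masks_dir: str) -> int:
    images = sorted(Path(images_dir).glob("*.jp*g"))  # matches .jpg and .jpeg
    for img in images:
        assert img.stem.isdigit(), f"Non-numeric stem: {img.name}"
        mask = Path(masks_dir) / f"{img.stem}.png"
        assert mask.exists(), f"Missing mask for {img.name}"
    return len(images)

for split in ("train", "val", "test"):
    n = check_split(f"dataset/{split}/images", f"dataset/{split}/gts")
    print(f"{split}: {n} pairs OK")
```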
SEGMENTATION_PLAN.md

```diff
@@ -9,7 +9,7 @@ This plan distills the model and pipeline described in the paper sources:
 Focus: segmentation only (no dataset collection or inpainting).
 
 ## Decisions and Defaults (locked)
-- Backbone: SegFormer MiT-B3 (shared encoder `E`).
+- Backbone: SegFormer MiT-B3 via HuggingFace Transformers (shared encoder `E`), with `timm` or tiny CNN fallback.
 - Fine/local patch size p: 768.
 - Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
 - Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
@@ -41,7 +41,7 @@ Focus: segmentation only (no dataset collection or inpainting).
 - `README.md` (segmentation-only usage)
 
 ## Model Specification
-- Shared encoder `E`: SegFormer MiT-B3.
+- Shared encoder `E`: SegFormer MiT-B3 (HF Transformers preferred).
 - Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
 - For the coarse pass, the cond and location channels are zeros to keep channel count consistent (`method_yq.tex`).
 - Weight init for extra channels: copy mean of RGB conv weights or zero-init.
@@ -65,6 +65,14 @@ Focus: segmentation only (no dataset collection or inpainting).
 - Downsample full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
 - Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).
 
+### Dataset Convention (project-specific)
+- Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
+- Example:
+  - `dataset/images/1.jpg, 2.jpg, ..., N.jpg` (or `.jpeg`)
+  - `dataset/gts/1.png, 2.png, ..., N.png`
+- Masks are binary: foreground = white (255), background = black (0).
+- The loader (`data/dataset.py`) strictly enforces numeric stems and 1:1 pairing and will assert on mismatch.
+
 ## Training Pipeline
 - Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
 - Coarse input: downsample augmented full image to 512×512; build channels [RGB+MinMax+zeros(2)] → `E` → `D_C`.
```
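The max-pool label downsampling called out above lives in the repo as `downsample_label_maxpool` (used by `train.py`). A minimal sketch of the idea follows; this is an assumption-level illustration with a hypothetical name, not the repo's exact implementation:

```python
import numpy as np
import torch
import torch.nn.functional as F

def maxpool_downsample(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Downsample a 0/1 mask so any positive pixel in a cell survives.

    Bilinear or nearest resizing can erase thin wires entirely; max-pooling
    keeps a cell positive if any source pixel inside it was positive.
    """
    t = torch.from_numpy(mask.astype(np.float32))[None, None]  # 1x1xHxW
    t = F.adaptive_max_pool2d(t, output_size=(out_h, out_w))
    return (t[0, 0].numpy() > 0.5).astype(np.uint8)
```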
configs/default.yaml

```diff
@@ -1,5 +1,6 @@
 # Default configuration for WireSegHR (segmentation-only)
 backbone: mit_b3
+pretrained: true  # Uses HF SegFormer weights if available; else timm or tiny fallback
 
 coarse:
   train_size: 512
@@ -34,9 +35,18 @@ optim:
   schedule: poly
   power: 1.0
 
+# training housekeeping
+seed: 42
+out_dir: runs/wireseghr
+eval_interval: 500
+ckpt_interval: 1000
+# resume: runs/wireseghr/ckpt_1000.pt  # optional
+
 # dataset paths (placeholders)
 data:
-  train_images: /
-  train_masks: /
-  val_images: /
-  val_masks: /
+  train_images: dataset/train/images
+  train_masks: dataset/train/gts
+  val_images: dataset/val/images
+  val_masks: dataset/val/gts
+  test_images: dataset/test/images
+  test_masks: dataset/test/gts
```
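The new keys are all read with fallbacks in `train.py`, so older configs keep working. A sketch of the lookup pattern:

```python
import yaml

with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

# Optional housekeeping keys fall back to defaults, mirroring train.py
seed = int(cfg.get("seed", 42))
out_dir = cfg.get("out_dir", "runs/wireseghr")
eval_interval = int(cfg.get("eval_interval", 500))
resume_path = cfg.get("resume", None)  # None unless uncommented in the yaml
print(seed, out_dir, eval_interval, cfg["data"]["train_images"])
```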
requirements.txt

```diff
@@ -1,6 +1,7 @@
 torch>=2.1.0
 torchvision>=0.16.0
 timm>=0.9.8
+transformers>=4.37.0
 numpy>=1.24.0
 opencv-python>=4.8.0.76
 Pillow>=9.5.0
```
src/wireseghr/data/dataset.py — `_index_pairs()` rewritten (`@@ -36,19 +36,25 @@`): the old extension-map indexing over `rglob` is replaced with strict id-based pairing that asserts on any mismatch:

```python
    def _index_pairs(self) -> List[tuple[Path, Path]]:
        # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
        img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.is_file()])
        img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.is_file()])
        assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
        pairs: List[tuple[Path, Path]] = []
        ids: List[int] = []
        for p in img_files:
            stem = p.stem
            assert stem.isdigit(), f"Non-numeric filename encountered: {p.name}"
            ids.append(int(stem))
        ids = sorted(ids)
        for i in ids:
            # Prefer .jpg, else .jpeg
            ip_jpg = self.images_dir / f"{i}.jpg"
            ip_jpeg = self.images_dir / f"{i}.jpeg"
            ip = ip_jpg if ip_jpg.exists() else ip_jpeg
            assert ip.exists(), f"Missing image for {i}: {ip_jpg} or {ip_jpeg}"
            mp = self.masks_dir / f"{i}.png"
            assert mp.exists(), f"Missing mask for {i}: {mp}"
            pairs.append((ip, mp))
        assert len(pairs) > 0, f"No numeric pairs found in {self.images_dir} and {self.masks_dir}"
        return pairs
```
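Usage is unchanged; a quick smoke test of the stricter indexing, assuming the default `dataset/train` layout exists:

```python
from wireseghr.data.dataset import WireSegDataset

dset = WireSegDataset("dataset/train/images", "dataset/train/gts", split="train")
item = dset[0]
print(item["image"].shape, item["mask"].shape)  # HxWx3 image, HxW binary mask
print(item["image_path"], item["mask_path"])    # paired as <id>.jpg / <id>.png
```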
src/wireseghr/data/sampler.py — the stub becomes a working rejection sampler (`@@ -1,14 +1,29 @@`):

```python
# Balanced patch sampler (>=1% wire pixels)
"""Balanced patch sampling with >= min_wire_ratio positives.

Sampling is uniform over valid top-left positions; tries a fixed number of
attempts and asserts if none meet the threshold.
"""

from dataclasses import dataclass
import numpy as np


@dataclass
class BalancedPatchSampler:
    patch_size: int = 768
    min_wire_ratio: float = 0.01
    max_tries: int = 200

    def sample(self, image: np.ndarray, mask: np.ndarray) -> tuple[int, int]:
        h, w = mask.shape
        p = self.patch_size
        assert h >= p and w >= p, "Image smaller than patch size"
        for _ in range(self.max_tries):
            y = np.random.randint(0, h - p + 1)
            x = np.random.randint(0, w - p + 1)
            m = mask[y : y + p, x : x + p]
            ratio = float(m.sum()) / float(p * p)
            if ratio >= self.min_wire_ratio:
                return int(y), int(x)
        raise AssertionError("Failed to sample a patch meeting min_wire_ratio")
```
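A self-contained check of the sampler's guarantee, using a synthetic wire-like band (any returned patch must contain at least 1% positives):

```python
import numpy as np
from wireseghr.data.sampler import BalancedPatchSampler

H, W, p = 2048, 2048, 768
image = np.zeros((H, W, 3), dtype=np.float32)
mask = np.zeros((H, W), dtype=np.uint8)
mask[1000:1020, :] = 1  # 20px horizontal band: ~2.6% of any patch that covers it

sampler = BalancedPatchSampler(patch_size=p, min_wire_ratio=0.01)
y0, x0 = sampler.sample(image, mask)
ratio = mask[y0:y0 + p, x0:x0 + p].mean()
print(y0, x0, ratio)  # ratio >= 0.01 by construction
```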
src/wireseghr/metrics.py — the TODO placeholder replaced with real IoU/F1/precision/recall (`@@ -1,9 +1,32 @@`):

```python
from typing import Dict
import numpy as np


def compute_metrics(pred_mask: np.ndarray, gt_mask: np.ndarray) -> Dict[str, float]:
    """Compute binary segmentation metrics on 0/1 numpy masks.

    Args:
        pred_mask: HxW uint8 or bool in {0,1}
        gt_mask: HxW uint8 or bool in {0,1}
    Returns:
        dict with iou, f1, precision, recall
    """
    p = (pred_mask > 0).astype(np.uint8)
    g = (gt_mask > 0).astype(np.uint8)

    tp = int(np.sum((p == 1) & (g == 1)))
    fp = int(np.sum((p == 1) & (g == 0)))
    fn = int(np.sum((p == 0) & (g == 1)))

    denom_iou = tp + fp + fn
    iou = (tp / denom_iou) if denom_iou > 0 else 0.0

    prec_den = tp + fp
    rec_den = tp + fn
    precision = (tp / prec_den) if prec_den > 0 else 0.0
    recall = (tp / rec_den) if rec_den > 0 else 0.0

    denom_f1 = precision + recall
    f1 = (2 * precision * recall / denom_f1) if denom_f1 > 0 else 0.0

    return {"iou": float(iou), "f1": float(f1), "precision": float(precision), "recall": float(recall)}
```
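A tiny worked example pinning down the definitions (tp=1, fp=1, fn=1 below):

```python
import numpy as np
from wireseghr.metrics import compute_metrics

pred = np.array([[1, 1], [0, 0]], dtype=np.uint8)
gt = np.array([[1, 0], [1, 0]], dtype=np.uint8)
print(compute_metrics(pred, gt))
# tp=1, fp=1, fn=1 -> iou=1/3, precision=0.5, recall=0.5, f1=0.5
```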
src/wireseghr/model/encoder.py — `SegFormerEncoder` gains an HF → timm → tiny-CNN fallback chain, plus the `_TinyEncoder` and `_HFEncoderWrapper` classes (`@@ -25,18 +25,126 @@`):

```python
        self.pretrained = pretrained
        self.out_indices = out_indices

        # Prefer HuggingFace SegFormer for 'mit_*' backbones.
        # Otherwise try timm features_only. Always have Tiny CNN fallback.
        self.encoder = None
        self.hf = None
        prefer_hf = backbone.startswith("mit_") or backbone.startswith("segformer")
        if prefer_hf:
            # HF -> timm -> tiny
            try:
                self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
                self.feature_dims = self.hf.feature_dims
            except Exception:
                try:
                    self.encoder = timm.create_model(
                        backbone,
                        pretrained=pretrained,
                        features_only=True,
                        out_indices=out_indices,
                        in_chans=in_channels,
                    )
                    self.feature_dims = list(self.encoder.feature_info.channels())
                except Exception:
                    self.encoder = None
                    self.fallback = _TinyEncoder(in_channels)
                    self.feature_dims = [64, 128, 320, 512]
        else:
            # timm -> HF -> tiny
            try:
                self.encoder = timm.create_model(
                    backbone,
                    pretrained=pretrained,
                    features_only=True,
                    out_indices=out_indices,
                    in_chans=in_channels,
                )
                self.feature_dims = list(self.encoder.feature_info.channels())
            except Exception:
                try:
                    self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
                    self.feature_dims = self.hf.feature_dims
                except Exception:
                    self.encoder = None
                    self.fallback = _TinyEncoder(in_channels)
                    self.feature_dims = [64, 128, 320, 512]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        if self.encoder is not None:
            feats = self.encoder(x)
            assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
            return list(feats)
        elif self.hf is not None:
            return self.hf(x)
        else:
            return self.fallback(x)


class _TinyEncoder(nn.Module):
    def __init__(self, in_chans: int):
        super().__init__()
        # Output strides: 4, 8, 16, 32 with channels 64, 128, 320, 512
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=7, stride=4, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.stage1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.stage2 = nn.Sequential(
            nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(320),
            nn.ReLU(inplace=True),
        )
        self.stage3 = nn.Sequential(
            nn.Conv2d(320, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        c0 = self.stem(x)     # 1/4
        c1 = self.stage1(c0)  # 1/8
        c2 = self.stage2(c1)  # 1/16
        c3 = self.stage3(c2)  # 1/32
        return [c0, c1, c2, c3]


class _HFEncoderWrapper(nn.Module):
    def __init__(self, in_chans: int, backbone: str, pretrained: bool):
        super().__init__()
        # Lazy import to avoid hard dependency during tests if not used
        from transformers import SegformerModel, SegformerConfig

        name_map = {
            "mit_b0": "nvidia/mit-b0",
            "mit_b1": "nvidia/mit-b1",
            "mit_b2": "nvidia/mit-b2",
            "mit_b3": "nvidia/mit-b3",
            "mit_b4": "nvidia/mit-b4",
            "mit_b5": "nvidia/mit-b5",
        }
        model_id = name_map.get(backbone, "nvidia/mit-b0")

        if pretrained:
            base_cfg = SegformerConfig.from_pretrained(model_id)
            base_cfg.num_channels = in_chans
            self.model = SegformerModel.from_pretrained(
                model_id, config=base_cfg, ignore_mismatched_sizes=True
            )
        else:
            cfg = SegformerConfig()  # default config (B0-like)
            cfg.num_channels = in_chans
            self.model = SegformerModel(cfg)

        # Expose channel dims per stage
        self.feature_dims = list(self.model.config.hidden_sizes)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = self.model(pixel_values=x, output_hidden_states=True, return_dict=True)
        feats = list(outputs.hidden_states)
        assert len(feats) == 4
        return feats
```
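A smoke test for the fallback chain; with `pretrained=False` nothing is downloaded, and if `transformers` is absent the encoder degrades to timm or the tiny CNN without changing the interface:

```python
import torch
from wireseghr.model.encoder import SegFormerEncoder

enc = SegFormerEncoder(backbone="mit_b3", in_channels=7, pretrained=False)
x = torch.randn(1, 7, 512, 512)
feats = enc(x)
for f, c in zip(feats, enc.feature_dims):
    print(tuple(f.shape), c)  # four maps at strides 4/8/16/32; channels match feature_dims
```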
src/wireseghr/model/model.py — decoder input dims now come from the encoder instead of being hardcoded (`@@ -22,8 +22,8 @@`):

```python
    def __init__(self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True):
        super().__init__()
        self.encoder = SegFormerEncoder(backbone=backbone, in_channels=in_channels, pretrained=pretrained)
        # Use encoder-exposed feature dims for decoder projections
        in_chs = tuple(self.encoder.feature_dims)
        self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.cond1x1 = Conditioning1x1()
```
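A shape check consistent with the comments in `train.py` (the coarse head predicts at 1/4 of the coarse input; the printed shapes are expected values, assuming the decoders keep that stride):

```python
import torch
from wireseghr.model import WireSegHR

model = WireSegHR(backbone="mit_b3", in_channels=7, pretrained=False).eval()
x = torch.randn(1, 7, 512, 512)
with torch.no_grad():
    logits_c, cond = model.forward_coarse(x)
print(logits_c.shape, cond.shape)  # expected: (1, 2, 128, 128) and (1, 1, 128, 128)
```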
src/wireseghr/train.py — the stub becomes the full training loop (`@@ -2,6 +2,24 @@`, `@@ -18,8 +36,375 @@`): balanced same-size batches, coarse+fine forward with patch-cropped conditioning, AMP, poly LR, periodic validation, checkpointing, and test visualizations. (The `__main__` guard is kept at the end of the file so `main()` can see the helpers, and the redundant local `import cv2` is dropped in favor of the top-level import.)

```python
import argparse
import os
import pprint
import random
from typing import Tuple, List, Optional, Dict

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

from wireseghr.model import WireSegHR
from wireseghr.model.minmax import MinMaxLuminance
from wireseghr.model.label_downsample import downsample_label_maxpool
from wireseghr.data.dataset import WireSegDataset
from wireseghr.data.sampler import BalancedPatchSampler
from wireseghr.metrics import compute_metrics


def main():
    # ... argparse / config loading unchanged ...
    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[WireSegHR][train] Device: {device}")

    # Config
    coarse_train = int(cfg["coarse"]["train_size"])  # 512
    patch_size = int(cfg["fine"]["patch_size"])  # 768
    iters = int(cfg["optim"]["iters"])  # 40000
    batch_size = int(cfg["optim"]["batch_size"])  # 8
    base_lr = float(cfg["optim"]["lr"])  # 6e-5
    weight_decay = float(cfg["optim"]["weight_decay"])  # 0.01
    power = float(cfg["optim"]["power"])  # 1.0
    amp_flag = bool(cfg["optim"].get("amp", True))

    # Housekeeping
    seed = int(cfg.get("seed", 42))
    out_dir = cfg.get("out_dir", "runs/wireseghr")
    eval_interval = int(cfg.get("eval_interval", 500))
    ckpt_interval = int(cfg.get("ckpt_interval", 1000))
    os.makedirs(out_dir, exist_ok=True)
    set_seed(seed)

    # Dataset
    train_images = cfg["data"]["train_images"]
    train_masks = cfg["data"]["train_masks"]
    dset = WireSegDataset(train_images, train_masks, split="train")
    # Validation and test
    val_images = cfg["data"].get("val_images", None)
    val_masks = cfg["data"].get("val_masks", None)
    test_images = cfg["data"].get("test_images", None)
    test_masks = cfg["data"].get("test_masks", None)
    dset_val = WireSegDataset(val_images, val_masks, split="val") if val_images and val_masks else None
    dset_test = WireSegDataset(test_images, test_masks, split="test") if test_images and test_masks else None
    sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
    minmax = MinMaxLuminance(kernel=cfg["minmax"]["kernel"]) if cfg["minmax"]["enable"] else None

    # Model
    # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
    pretrained_flag = bool(cfg.get("pretrained", False))
    model = WireSegHR(backbone=cfg["backbone"], in_channels=7, pretrained=pretrained_flag)
    model = model.to(device)

    # Optimizer and loss
    optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    scaler = GradScaler(enabled=(device.type == "cuda" and amp_flag))
    ce = nn.CrossEntropyLoss()

    # Resume
    start_step = 0
    best_f1 = -1.0
    resume_path = cfg.get("resume", None)
    if resume_path and os.path.isfile(resume_path):
        print(f"[WireSegHR][train] Resuming from {resume_path}")
        start_step, best_f1 = _load_checkpoint(resume_path, model, optim, scaler, device)

    # Training loop
    model.train()
    step = start_step
    pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
    while step < iters:
        optim.zero_grad(set_to_none=True)
        imgs, masks = _sample_batch_same_size(dset, batch_size)
        batch = _prepare_batch(imgs, masks, coarse_train, patch_size, sampler, minmax, device)

        logits_coarse, cond_map = model.forward_coarse(batch["x_coarse"])  # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)

        # Upsample cond to full-res to crop the fine patch-aligned conditioning
        B, _, hc4, wc4 = cond_map.shape
        cond_up = F.interpolate(
            cond_map.detach(), size=(batch["full_h"], batch["full_w"]), mode="bilinear", align_corners=False
        )

        # Build fine inputs: crop cond to patch, concat with patch RGB+MinMax and loc mask
        x_fine = _build_fine_inputs(batch, cond_up, device)
        logits_fine = model.forward_fine(x_fine)

        # Targets
        y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
        y_fine = _build_fine_targets(batch["mask_patches"], logits_fine.shape[2], logits_fine.shape[3], device)

        with autocast(enabled=(device.type == "cuda" and amp_flag)):
            loss_coarse = ce(logits_coarse, y_coarse)
            loss_fine = ce(logits_fine, y_fine)
            loss = loss_coarse + loss_fine

        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        # Poly LR schedule (per optimizer step)
        lr = base_lr * ((1.0 - float(step) / float(iters)) ** power)
        for pg in optim.param_groups:
            pg["lr"] = lr

        if step % 50 == 0:
            print(f"[Iter {step}/{iters}] lr={lr:.6e}")

        # Eval & Checkpoint
        if (step % eval_interval == 0) and (dset_val is not None):
            model.eval()
            val_stats = validate(model, dset_val, coarse_train, device, amp_flag)
            print(f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}")
            # Save best
            if val_stats["f1"] > best_f1:
                best_f1 = val_stats["f1"]
                _save_checkpoint(os.path.join(out_dir, "best.pt"), step, model, optim, scaler, best_f1)
            # Save periodic ckpt
            if ckpt_interval > 0 and (step % ckpt_interval == 0):
                _save_checkpoint(os.path.join(out_dir, f"ckpt_{step}.pt"), step, model, optim, scaler, best_f1)
            # Save test visualizations
            if dset_test is not None:
                save_test_visuals(model, dset_test, coarse_train, device, os.path.join(out_dir, f"test_vis_{step}"), amp_flag, max_samples=8)
            model.train()

        step += 1
        pbar.update(1)

    print("[WireSegHR][train] Done.")


def _sample_batch_same_size(dset: WireSegDataset, batch_size: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    # Select a seed sample, then fill the batch with samples of the same (H,W)
    assert len(dset) > 0
    seed_idx = int(np.random.randint(0, len(dset)))
    seed_item = dset[seed_idx]
    H, W = seed_item["image"].shape[:2]
    imgs: List[np.ndarray] = [seed_item["image"]]
    masks: List[np.ndarray] = [seed_item["mask"]]
    tries = 0
    while len(imgs) < batch_size and tries < 5000:
        idx = int(np.random.randint(0, len(dset)))
        item = dset[idx]
        im = item["image"]
        if im.shape[0] == H and im.shape[1] == W:
            imgs.append(im)
            masks.append(item["mask"])
        tries += 1
    assert len(imgs) == batch_size, "Failed to assemble same-size batch"
    return imgs, masks


def _prepare_batch(
    imgs: List[np.ndarray],
    masks: List[np.ndarray],
    coarse_train: int,
    patch_size: int,
    sampler: BalancedPatchSampler,
    minmax: Optional[MinMaxLuminance],
    device: torch.device,
):
    B = len(imgs)
    assert B == len(masks)
    # Keep numpy versions for geometry and torch versions for model inputs
    full_h = imgs[0].shape[0]
    full_w = imgs[0].shape[1]
    for im, m in zip(imgs, masks):
        assert im.shape[0] == full_h and im.shape[1] == full_w
        assert m.shape[0] == full_h and m.shape[1] == full_w

    xs_coarse = []
    patches_rgb = []
    patches_mask = []
    patches_min = []
    patches_max = []
    loc_masks = []
    yx_list: List[tuple[int, int]] = []

    for img, mask in zip(imgs, masks):
        # Float32 [0,1]
        imgf = img.astype(np.float32) / 255.0
        if minmax is not None:
            y_min, y_max = minmax(imgf)
        else:
            y = (0.299 * imgf[..., 0] + 0.587 * imgf[..., 1] + 0.114 * imgf[..., 2]).astype(np.float32)
            y_min, y_max = y, y

        # Coarse input: resize RGB + MinMax to coarse_train, pad cond+loc zeros to reach 7 channels
        rgb_coarse = cv2.resize(imgf, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
        y_min_c = cv2.resize(y_min, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
        y_max_c = cv2.resize(y_max, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
        c = np.concatenate([
            np.transpose(rgb_coarse, (2, 0, 1)),  # 3xHxW
            y_min_c[None, ...],  # 1xHxW
            y_max_c[None, ...],  # 1xHxW
            np.zeros((1, coarse_train, coarse_train), np.float32),  # cond placeholder
            np.zeros((1, coarse_train, coarse_train), np.float32),  # loc placeholder
        ], axis=0)
        xs_coarse.append(torch.from_numpy(c))

        # Sample fine patch
        y0, x0 = sampler.sample(imgf, mask)
        patch_rgb = imgf[y0 : y0 + patch_size, x0 : x0 + patch_size, :]
        patch_mask = mask[y0 : y0 + patch_size, x0 : x0 + patch_size]
        patches_rgb.append(patch_rgb)
        patches_mask.append(patch_mask)
        patches_min.append(y_min[y0 : y0 + patch_size, x0 : x0 + patch_size])
        patches_max.append(y_max[y0 : y0 + patch_size, x0 : x0 + patch_size])
        # Binary location mask (ones inside the patch)
        loc_masks.append(np.ones((patch_size, patch_size), dtype=np.float32))
        yx_list.append((y0, x0))

    x_coarse = torch.stack(xs_coarse, dim=0).to(device)  # Bx7xHcxWc

    # Store numpy arrays for fine build
    return {
        "x_coarse": x_coarse,
        "full_h": full_h,
        "full_w": full_w,
        "rgb_patches": patches_rgb,
        "mask_patches": patches_mask,
        "ymin_patches": patches_min,
        "ymax_patches": patches_max,
        "loc_patches": loc_masks,
        "patch_yx": yx_list,
        "mask_full": masks,
    }


def _build_fine_inputs(batch, cond_up: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Build fine input tensor Bx7xPxP from per-sample numpy buffers and upsampled cond maps
    B = cond_up.shape[0]
    P = batch["loc_patches"][0].shape[0]
    xs: List[torch.Tensor] = []
    for i in range(B):
        rgb = batch["rgb_patches"][i]
        ymin = batch["ymin_patches"][i]
        ymax = batch["ymax_patches"][i]
        loc = batch["loc_patches"][i]
        y0, x0 = batch["patch_yx"][i]

        cond_patch = cond_up[i : i + 1, :, y0 : y0 + P, x0 : x0 + P]  # 1x1xPxP
        cond_patch = cond_patch.squeeze(1)  # 1xPxP

        # Convert numpy channels to torch and concat
        rgb_t = torch.from_numpy(np.transpose(rgb, (2, 0, 1)))  # 3xPxP
        ymin_t = torch.from_numpy(ymin)[None, ...]  # 1xPxP
        ymax_t = torch.from_numpy(ymax)[None, ...]  # 1xPxP
        loc_t = torch.from_numpy(loc)[None, ...]  # 1xPxP
        x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch.cpu(), loc_t], dim=0).float()  # 7xPxP
        xs.append(x)
    x_fine = torch.stack(xs, dim=0).to(device)
    return x_fine


def _build_coarse_targets(masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
    ys: List[torch.Tensor] = []
    for m in masks:
        dm = downsample_label_maxpool(m, out_h, out_w)
        ys.append(torch.from_numpy(dm.astype(np.int64)))
    y = torch.stack(ys, dim=0).to(device)  # BxHc4xWc4 with values {0,1}
    return y


def _build_fine_targets(mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
    ys: List[torch.Tensor] = []
    for m in mask_patches:
        dm = downsample_label_maxpool(m, out_h, out_w)
        ys.append(torch.from_numpy(dm.astype(np.int64)))
    y = torch.stack(ys, dim=0).to(device)  # BxHf4xWf4 with values {0,1}
    return y


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


def _save_checkpoint(path: str, step: int, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, best_f1: float):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    state = {
        "step": step,
        "model": model.state_dict(),
        "optim": optim.state_dict(),
        "scaler": scaler.state_dict(),
        "best_f1": best_f1,
    }
    torch.save(state, path)
    print(f"[WireSegHR][train] Saved checkpoint: {path}")


def _load_checkpoint(path: str, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, device: torch.device) -> Tuple[int, float]:
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optim.load_state_dict(ckpt["optim"])
    try:
        scaler.load_state_dict(ckpt["scaler"])  # may not exist
    except Exception:
        pass
    step = int(ckpt.get("step", 0))
    best_f1 = float(ckpt.get("best_f1", -1.0))
    return step, best_f1


@torch.no_grad()
def validate(model: WireSegHR, dset_val: WireSegDataset, coarse_size: int, device: torch.device, amp_flag: bool) -> Dict[str, float]:
    # Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
    model = model.to(device)
    metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
    n = 0
    for i in range(len(dset_val)):
        item = dset_val[i]
        img = item["image"].astype(np.float32) / 255.0  # HxWx3
        mask = item["mask"].astype(np.uint8)
        H, W = mask.shape
        # Build coarse input (zeros for cond+loc)
        rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
        y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
        y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
        y_max = y_min
        x = np.concatenate([
            np.transpose(rgb_c, (2, 0, 1)),
            y_min[None, ...],
            y_max[None, ...],
            np.zeros((1, coarse_size, coarse_size), np.float32),
            np.zeros((1, coarse_size, coarse_size), np.float32),
        ], axis=0)
        x_t = torch.from_numpy(x)[None, ...].to(device)
        with autocast(enabled=(device.type == "cuda" and amp_flag)):
            logits_c, _ = model.forward_coarse(x_t)
            prob = torch.softmax(logits_c, dim=1)[:, 1:2]
        prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
        pred = (prob_up > 0.5).astype(np.uint8)
        m = compute_metrics(pred, mask)
        for k in metrics_sum:
            metrics_sum[k] += m[k]
        n += 1
    if n == 0:
        return {k: 0.0 for k in metrics_sum}
    return {k: v / float(n) for k, v in metrics_sum.items()}


@torch.no_grad()
def save_test_visuals(model: WireSegHR, dset_test: WireSegDataset, coarse_size: int, device: torch.device, out_dir: str, amp_flag: bool, max_samples: int = 8):
    os.makedirs(out_dir, exist_ok=True)
    for i in range(min(max_samples, len(dset_test))):
        item = dset_test[i]
        img = item["image"].astype(np.float32) / 255.0
        H, W = img.shape[:2]
        rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
        y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
        y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
        y_max = y_min
        x = np.concatenate([
            np.transpose(rgb_c, (2, 0, 1)),
            y_min[None, ...],
            y_max[None, ...],
            np.zeros((1, coarse_size, coarse_size), np.float32),
            np.zeros((1, coarse_size, coarse_size), np.float32),
        ], axis=0)
        x_t = torch.from_numpy(x)[None, ...].to(device)
        with autocast(enabled=(device.type == "cuda" and amp_flag)):
            logits_c, _ = model.forward_coarse(x_t)
            prob = torch.softmax(logits_c, dim=1)[:, 1:2]
        prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
        pred = (prob_up > 0.5).astype(np.uint8) * 255
        # Save input and prediction
        img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
        cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
        cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)


if __name__ == "__main__":
    main()
```
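Training is launched the usual way (presumably `python src/wireseghr/train.py --config configs/default.yaml`). With `power: 1.0`, the poly schedule applied each step is just a linear decay to zero over `iters`; a quick check of the values it produces:

```python
base_lr, iters, power = 6e-5, 40000, 1.0

def poly_lr(step: int) -> float:
    # Same formula as the per-step update in the training loop
    return base_lr * ((1.0 - step / iters) ** power)

for s in (0, 10000, 20000, 39999):
    print(s, f"{poly_lr(s):.2e}")
# 0 -> 6.00e-05, 10000 -> 4.50e-05, 20000 -> 3.00e-05, 39999 -> 1.50e-09
```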