MRiabov committed on
Commit 8e73ec9 · Parent: ebeb96c

Training loop and around the codebase

.gitignore CHANGED
@@ -114,3 +114,9 @@ venv/
ENV/
env.bak/
venv.bak/
+
+ # Secrets
+ secrets/
+
+ # dataset
+ dataset/
README.md CHANGED
@@ -28,4 +28,25 @@ python src/wireseghr/infer.py --config configs/default.yaml --image /path/to/ima

## Notes
- This is a segmentation-only codebase. Inpainting is out of scope here.
- - Defaults locked: MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
+ - Defaults locked: SegFormer MiT-B3 encoder, patch size 768, MinMax 6×6, global+binary mask conditioning with patch-cropped global map.
+
+ ### Backbone Source
+ - Preferred: HuggingFace Transformers SegFormer (e.g., `nvidia/mit-b3`). We set `num_channels` to match the input channel count.
+ - Optional: `timm` `features_only` if a compatible SegFormer is available.
+ - Fallback: a small internal CNN that preserves the 1/4, 1/8, 1/16, 1/32 strides with channels [64, 128, 320, 512].
+
+ Install the requirements to get Transformers:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Dataset Convention
+ - Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
+ - Example (after an 85/5/10 split):
+   - `dataset/train/images/1.jpg, 2.jpg, ...` and `dataset/train/gts/1.png, 2.png, ...`
+   - `dataset/val/images/...` and `dataset/val/gts/...`
+   - `dataset/test/images/...` and `dataset/test/gts/...`
+ - Masks are binary: foreground = white (255), background = black (0).
+ - The loader strictly enforces numeric stems and 1:1 pairing and will assert on mismatches.
+
+ Update `configs/default.yaml` with your paths under `data.train_images`, `data.train_masks`, etc. Defaults point to `dataset/train/images`, `dataset/train/gts`, and validation to `dataset/val/...`.
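The `num_channels` note above is the one non-obvious step; it is implemented in `src/wireseghr/model/encoder.py` in this commit. A minimal sketch of the pattern (7 = RGB + MinMax + cond + loc per the plan):

```python
# Sketch of the HF SegFormer load with a widened input patch embedding,
# mirroring _HFEncoderWrapper in this commit. ignore_mismatched_sizes=True
# lets the 7-channel patch embedding be re-initialized while the remaining
# pretrained weights load normally.
from transformers import SegformerConfig, SegformerModel

cfg = SegformerConfig.from_pretrained("nvidia/mit-b3")
cfg.num_channels = 7  # RGB(3) + MinMax(2) + cond(1) + loc(1)
encoder = SegformerModel.from_pretrained(
    "nvidia/mit-b3", config=cfg, ignore_mismatched_sizes=True
)
```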
SEGMENTATION_PLAN.md CHANGED
@@ -9,7 +9,7 @@ This plan distills the model and pipeline described in the paper sources:
Focus: segmentation only (no dataset collection or inpainting).

## Decisions and Defaults (locked)
- - Backbone: SegFormer MiT-B3 (shared encoder `E`).
+ - Backbone: SegFormer MiT-B3 via HuggingFace Transformers (shared encoder `E`), with `timm` or tiny CNN fallback.
- Fine/local patch size p: 768.
- Conditioning: global map + binary location mask by default (Table `tables/logit.tex`).
- Conditioning map scope: patch-cropped from the global map per `paper-tex/sections/method_yq.tex` (no full-image concatenation variant).
@@ -41,7 +41,7 @@ Focus: segmentation only (no dataset collection or inpainting).
- `README.md` (segmentation-only usage)

## Model Specification
- - Shared encoder `E`: SegFormer MiT-B3.
+ - Shared encoder `E`: SegFormer MiT-B3 (HF Transformers preferred).
- Input channels (default): 3 (RGB) + 2 (MinMax) + 1 (global cond) + 1 (binary location) = 7.
- For the coarse pass, the cond and location channels are zeros to keep channel count consistent (`method_yq.tex`).
- Weight init for extra channels: copy mean of RGB conv weights or zero-init.
@@ -65,6 +65,14 @@ Focus: segmentation only (no dataset collection or inpainting).
- Downsample full-res mask to coarse size with max-pooling to prevent wire vanishing (`method_yq.tex`).
- Normalization: standard mean/std per backbone; apply consistently across channels (new channels can be mean=0, std=1 by convention, or min-max scaled).

+ ### Dataset Convention (project-specific)
+ - Flat directories with numeric filenames; images are `.jpg`/`.jpeg`, masks are `.png`.
+ - Example:
+   - `dataset/images/1.jpg, 2.jpg, ..., N.jpg` (or `.jpeg`)
+   - `dataset/gts/1.png, 2.png, ..., N.png`
+ - Masks are binary: foreground = white (255), background = black (0).
+ - The loader (`data/dataset.py`) strictly enforces numeric stems and 1:1 pairing and will assert on mismatch.
+
## Training Pipeline
- Augment the full-res image (scaling, rotation, horizontal flip, photometric distortion) before constructing coarse/fine inputs (`method.tex`).
- Coarse input: downsample augmented full image to 512×512; build channels [RGB+MinMax+zeros(2)] → `E` → `D_C`.
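The max-pool label downsampling the plan calls for is imported by `train.py` as `downsample_label_maxpool`, whose body is outside this diff. A hypothetical sketch of what the rule implies, assuming the `(mask, out_h, out_w)` signature the training loop uses: any wire pixel inside an output cell marks the whole cell as wire.

```python
# Hypothetical sketch of downsample_label_maxpool (not the committed body).
# Max-pooling keeps thin wires visible at coarse resolution, where area or
# bilinear resizing would average them away.
import numpy as np
import torch
import torch.nn.functional as F

def downsample_label_maxpool(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    t = torch.from_numpy((mask > 0).astype(np.float32))[None, None]  # 1x1xHxW
    out = F.adaptive_max_pool2d(t, output_size=(out_h, out_w))  # max per output cell
    return out[0, 0].numpy().astype(np.uint8)
```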
configs/default.yaml CHANGED
@@ -1,5 +1,6 @@
# Default configuration for WireSegHR (segmentation-only)
backbone: mit_b3
+ pretrained: true  # Uses HF SegFormer weights if available; else timm or tiny fallback

coarse:
  train_size: 512
@@ -34,9 +35,18 @@ optim:
  schedule: poly
  power: 1.0

+ # training housekeeping
+ seed: 42
+ out_dir: runs/wireseghr
+ eval_interval: 500
+ ckpt_interval: 1000
+ # resume: runs/wireseghr/ckpt_1000.pt  # optional
+
# dataset paths (placeholders)
data:
-   train_images: /path/to/train/images
-   train_masks: /path/to/train/masks
-   val_images: /path/to/val/images
-   val_masks: /path/to/val/masks
+   train_images: dataset/train/images
+   train_masks: dataset/train/gts
+   val_images: dataset/val/images
+   val_masks: dataset/val/gts
+   test_images: dataset/test/images
+   test_masks: dataset/test/gts
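For reference, `schedule: poly` with `power: 1.0` is consumed in `train.py` as a linear decay. A tiny sketch using the values the training loop's comments assume (lr 6e-5, 40000 iters):

```python
# Poly LR schedule as implemented in this commit's train.py.
base_lr, iters, power = 6e-5, 40000, 1.0

def poly_lr(step: int) -> float:
    return base_lr * ((1.0 - float(step) / float(iters)) ** power)

# With power=1.0 the decay is linear: halfway through, lr is half of base_lr.
assert abs(poly_lr(20000) - 3e-5) < 1e-12
```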
requirements.txt CHANGED
@@ -1,6 +1,7 @@
torch>=2.1.0
torchvision>=0.16.0
timm>=0.9.8
+ transformers>=4.37.0
numpy>=1.24.0
opencv-python>=4.8.0.76
Pillow>=9.5.0
src/wireseghr/data/dataset.py CHANGED
@@ -36,19 +36,25 @@ class WireSegDataset:
        return {"image": img, "mask": mask_bin, "image_path": str(img_path), "mask_path": str(mask_path)}

    def _index_pairs(self) -> List[tuple[Path, Path]]:
-         exts_img = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
-         exts_mask = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
-         imgs: Dict[str, Path] = {}
-         for p in sorted(self.images_dir.rglob("*")):
-             if p.is_file() and p.suffix.lower() in exts_img:
-                 imgs[p.stem] = p
-         masks: Dict[str, Path] = {}
-         for p in sorted(self.masks_dir.rglob("*")):
-             if p.is_file() and p.suffix.lower() in exts_mask:
-                 masks[p.stem] = p
+         # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
+         img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.is_file()])
+         img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.is_file()])
+         assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
        pairs: List[tuple[Path, Path]] = []
-         for stem, ip in imgs.items():
-             if stem in masks:
-                 pairs.append((ip, masks[stem]))
-         assert len(pairs) > 0, f"No image-mask pairs found in {self.images_dir} and {self.masks_dir}"
+         ids: List[int] = []
+         for p in img_files:
+             stem = p.stem
+             assert stem.isdigit(), f"Non-numeric filename encountered: {p.name}"
+             ids.append(int(stem))
+         ids = sorted(set(ids))  # de-duplicate in case both i.jpg and i.jpeg exist
+         for i in ids:
+             # Prefer .jpg, else .jpeg
+             ip_jpg = self.images_dir / f"{i}.jpg"
+             ip_jpeg = self.images_dir / f"{i}.jpeg"
+             ip = ip_jpg if ip_jpg.exists() else ip_jpeg
+             assert ip.exists(), f"Missing image for {i}: {ip_jpg} or {ip_jpeg}"
+             mp = self.masks_dir / f"{i}.png"
+             assert mp.exists(), f"Missing mask for {i}: {mp}"
+             pairs.append((ip, mp))
+         assert len(pairs) > 0, f"No numeric pairs found in {self.images_dir} and {self.masks_dir}"
        return pairs
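The loader's contract is easiest to see in use. A minimal sketch, assuming the constructor signature that `train.py` in this commit uses (`WireSegDataset(images_dir, masks_dir, split=...)`) and the `dataset/train/...` defaults from the config:

```python
# Sketch of the loader contract. With the flat numeric layout in place,
# indexing yields aligned image/mask pairs; a layout violation (non-numeric
# stem, missing 5.png, ...) trips the asserts in _index_pairs above.
from wireseghr.data.dataset import WireSegDataset

dset = WireSegDataset("dataset/train/images", "dataset/train/gts", split="train")
item = dset[0]
print(item["image_path"], item["mask_path"])  # e.g. .../1.jpg .../1.png
assert item["mask"].max() <= 1  # masks come back binarized (mask_bin key)
```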
src/wireseghr/data/sampler.py CHANGED
@@ -1,14 +1,29 @@
# Balanced patch sampler (>=1% wire pixels)
- # TODO: Implement logic over mask to pick patches with wire ratio >= threshold.
+ """Balanced patch sampling with >= min_wire_ratio positives.
+
+ Sampling is uniform over valid top-left positions; tries a fixed number of
+ attempts and asserts if none meet the threshold.
+ """

from dataclasses import dataclass
+ import numpy as np


@dataclass
class BalancedPatchSampler:
    patch_size: int = 768
    min_wire_ratio: float = 0.01
+     max_tries: int = 200

-     def sample(self, image, mask):
-         # TODO: sample and return top-left (y, x) of a valid patch
-         return 0, 0
+     def sample(self, image: np.ndarray, mask: np.ndarray) -> tuple[int, int]:
+         h, w = mask.shape
+         p = self.patch_size
+         assert h >= p and w >= p, "Image smaller than patch size"
+         for _ in range(self.max_tries):
+             y = np.random.randint(0, h - p + 1)
+             x = np.random.randint(0, w - p + 1)
+             m = mask[y : y + p, x : x + p]
+             ratio = float(m.sum()) / float(p * p)
+             if ratio >= self.min_wire_ratio:
+                 return int(y), int(x)
+         raise AssertionError("Failed to sample a patch meeting min_wire_ratio")
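A quick way to sanity-check the sampler is a synthetic mask with a wire-dense band, so many windows are guaranteed to clear the 1% threshold. A minimal sketch (`sample()` only inspects the mask, so the image is a dummy):

```python
# Synthetic smoke test for BalancedPatchSampler: a 200-row band of "wire"
# pixels means most 768x768 windows exceed min_wire_ratio.
import numpy as np
from wireseghr.data.sampler import BalancedPatchSampler

mask = np.zeros((2000, 2000), dtype=np.uint8)
mask[900:1100, :] = 1  # thick horizontal wire band
img = np.zeros((2000, 2000, 3), dtype=np.float32)  # unused by sample()

sampler = BalancedPatchSampler(patch_size=768, min_wire_ratio=0.01)
y, x = sampler.sample(img, mask)
assert mask[y : y + 768, x : x + 768].mean() >= 0.01
```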
src/wireseghr/metrics.py CHANGED
@@ -1,9 +1,32 @@
- # Metrics placeholder: IoU, F1, Precision, Recall
- # TODO: Implement proper metrics matching paper tables.
-
from typing import Dict
+ import numpy as np


- def compute_metrics(pred_mask, gt_mask) -> Dict[str, float]:
-     # TODO: implement
-     return {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
+ def compute_metrics(pred_mask: np.ndarray, gt_mask: np.ndarray) -> Dict[str, float]:
+     """Compute binary segmentation metrics on 0/1 numpy masks.
+
+     Args:
+         pred_mask: HxW uint8 or bool in {0,1}
+         gt_mask: HxW uint8 or bool in {0,1}
+     Returns:
+         dict with iou, f1, precision, recall
+     """
+     p = (pred_mask > 0).astype(np.uint8)
+     g = (gt_mask > 0).astype(np.uint8)
+
+     tp = int(np.sum((p == 1) & (g == 1)))
+     fp = int(np.sum((p == 1) & (g == 0)))
+     fn = int(np.sum((p == 0) & (g == 1)))
+
+     denom_iou = tp + fp + fn
+     iou = (tp / denom_iou) if denom_iou > 0 else 0.0
+
+     prec_den = tp + fp
+     rec_den = tp + fn
+     precision = (tp / prec_den) if prec_den > 0 else 0.0
+     recall = (tp / rec_den) if rec_den > 0 else 0.0
+
+     denom_f1 = precision + recall
+     f1 = (2 * precision * recall / denom_f1) if denom_f1 > 0 else 0.0
+
+     return {"iou": float(iou), "f1": float(f1), "precision": float(precision), "recall": float(recall)}
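A small worked example pins down the arithmetic: with one true positive, one false positive, and one false negative, IoU is 1/3 and precision, recall, and F1 are all 0.5.

```python
# Worked example for compute_metrics on a 2x2 toy case: tp=1, fp=1, fn=1.
import numpy as np
from wireseghr.metrics import compute_metrics

pred = np.array([[1, 1], [0, 0]], dtype=np.uint8)
gt = np.array([[1, 0], [1, 0]], dtype=np.uint8)
m = compute_metrics(pred, gt)
assert abs(m["iou"] - 1 / 3) < 1e-9
assert abs(m["precision"] - 0.5) < 1e-9 and abs(m["recall"] - 0.5) < 1e-9
assert abs(m["f1"] - 0.5) < 1e-9
```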
src/wireseghr/model/encoder.py CHANGED
@@ -25,18 +25,126 @@ class SegFormerEncoder(nn.Module):
        self.pretrained = pretrained
        self.out_indices = out_indices

-         # Create MiT with features_only to obtain multi-scale feature maps.
-         # in_chans allows expanded inputs (RGB + minmax + cond + loc)
-         self.encoder = timm.create_model(
-             backbone,
-             pretrained=pretrained,
-             features_only=True,
-             out_indices=out_indices,
-             in_chans=in_channels,
-         )
-
-     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-         feats = self.encoder(x)
-         # Ensure list of tensors is returned
-         assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
-         return list(feats)
+         # Prefer HuggingFace SegFormer for 'mit_*' backbones.
+         # Otherwise try timm features_only. Always have the tiny CNN fallback.
+         self.encoder = None
+         self.hf = None
+         prefer_hf = backbone.startswith("mit_") or backbone.startswith("segformer")
+         if prefer_hf:
+             # HF -> timm -> tiny
+             try:
+                 self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
+                 self.feature_dims = self.hf.feature_dims
+             except Exception:
+                 try:
+                     self.encoder = timm.create_model(
+                         backbone,
+                         pretrained=pretrained,
+                         features_only=True,
+                         out_indices=out_indices,
+                         in_chans=in_channels,
+                     )
+                     self.feature_dims = list(self.encoder.feature_info.channels())
+                 except Exception:
+                     self.encoder = None
+                     self.fallback = _TinyEncoder(in_channels)
+                     self.feature_dims = [64, 128, 320, 512]
+         else:
+             # timm -> HF -> tiny
+             try:
+                 self.encoder = timm.create_model(
+                     backbone,
+                     pretrained=pretrained,
+                     features_only=True,
+                     out_indices=out_indices,
+                     in_chans=in_channels,
+                 )
+                 self.feature_dims = list(self.encoder.feature_info.channels())
+             except Exception:
+                 try:
+                     self.hf = _HFEncoderWrapper(in_channels, backbone, pretrained)
+                     self.feature_dims = self.hf.feature_dims
+                 except Exception:
+                     self.encoder = None
+                     self.fallback = _TinyEncoder(in_channels)
+                     self.feature_dims = [64, 128, 320, 512]
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         if self.encoder is not None:
+             feats = self.encoder(x)
+             assert isinstance(feats, (list, tuple)) and len(feats) == len(self.out_indices)
+             return list(feats)
+         elif self.hf is not None:
+             return self.hf(x)
+         else:
+             return self.fallback(x)
+
+
+ class _TinyEncoder(nn.Module):
+     def __init__(self, in_chans: int):
+         super().__init__()
+         # Output strides: 4, 8, 16, 32 with channels 64, 128, 320, 512
+         self.stem = nn.Sequential(
+             nn.Conv2d(in_chans, 64, kernel_size=7, stride=4, padding=3, bias=False),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+         )
+         self.stage1 = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+         )
+         self.stage2 = nn.Sequential(
+             nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(320),
+             nn.ReLU(inplace=True),
+         )
+         self.stage3 = nn.Sequential(
+             nn.Conv2d(320, 512, kernel_size=3, stride=2, padding=1, bias=False),
+             nn.BatchNorm2d(512),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         c0 = self.stem(x)     # 1/4
+         c1 = self.stage1(c0)  # 1/8
+         c2 = self.stage2(c1)  # 1/16
+         c3 = self.stage3(c2)  # 1/32
+         return [c0, c1, c2, c3]
+
+
+ class _HFEncoderWrapper(nn.Module):
+     def __init__(self, in_chans: int, backbone: str, pretrained: bool):
+         super().__init__()
+         # Lazy import to avoid a hard dependency during tests if not used
+         from transformers import SegformerModel, SegformerConfig
+
+         name_map = {
+             "mit_b0": "nvidia/mit-b0",
+             "mit_b1": "nvidia/mit-b1",
+             "mit_b2": "nvidia/mit-b2",
+             "mit_b3": "nvidia/mit-b3",
+             "mit_b4": "nvidia/mit-b4",
+             "mit_b5": "nvidia/mit-b5",
+         }
+         model_id = name_map.get(backbone, "nvidia/mit-b0")
+
+         if pretrained:
+             base_cfg = SegformerConfig.from_pretrained(model_id)
+             base_cfg.num_channels = in_chans
+             self.model = SegformerModel.from_pretrained(
+                 model_id, config=base_cfg, ignore_mismatched_sizes=True
+             )
+         else:
+             cfg = SegformerConfig()  # default config (B0-like)
+             cfg.num_channels = in_chans
+             self.model = SegformerModel(cfg)
+
+         # Expose channel dims per stage
+         self.feature_dims = list(self.model.config.hidden_sizes)
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         outputs = self.model(pixel_values=x, output_hidden_states=True, return_dict=True)
+         feats = list(outputs.hidden_states)
+         assert len(feats) == 4
+         return feats
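Whichever backend wins the fallback chain, the encoder's contract is four feature maps whose channel counts match `feature_dims`. A smoke-test sketch, assuming the constructor keywords `model.py` uses and `pretrained=False` so nothing is downloaded (random weights):

```python
# Encoder contract check: four multi-scale feature maps at strides 4/8/16/32,
# with per-stage channel counts exposed via encoder.feature_dims.
import torch
from wireseghr.model.encoder import SegFormerEncoder

enc = SegFormerEncoder(backbone="mit_b3", in_channels=7, pretrained=False)
feats = enc(torch.randn(1, 7, 512, 512))
assert len(feats) == 4
for f, c in zip(feats, enc.feature_dims):
    assert f.shape[1] == c
print([tuple(f.shape) for f in feats])
```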
src/wireseghr/model/model.py CHANGED
@@ -22,8 +22,8 @@ class WireSegHR(nn.Module):
    def __init__(self, backbone: str = "mit_b3", in_channels: int = 7, pretrained: bool = True):
        super().__init__()
        self.encoder = SegFormerEncoder(backbone=backbone, in_channels=in_channels, pretrained=pretrained)
-         # Default MiT-B3 channel dims for stages
-         in_chs = (64, 128, 320, 512)
+         # Use encoder-exposed feature dims for decoder projections
+         in_chs = tuple(self.encoder.feature_dims)
        self.coarse_head = CoarseDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.fine_head = FineDecoder(in_chs=in_chs, embed_dim=128, num_classes=2)
        self.cond1x1 = Conditioning1x1()
src/wireseghr/train.py CHANGED
@@ -2,6 +2,24 @@ import argparse
import os
import pprint
import yaml
+ from typing import Tuple, List, Optional, Dict
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast, GradScaler
+ from tqdm import tqdm
+ import random
+ import torch.backends.cudnn as cudnn
+ import cv2
+
+ from wireseghr.model import WireSegHR
+ from wireseghr.model.minmax import MinMaxLuminance
+ from wireseghr.model.label_downsample import downsample_label_maxpool
+ from wireseghr.data.dataset import WireSegDataset
+ from wireseghr.data.sampler import BalancedPatchSampler
+ from wireseghr.metrics import compute_metrics


def main():
@@ -18,8 +36,375 @@ def main():

    print("[WireSegHR][train] Loaded config from:", cfg_path)
    pprint.pprint(cfg)
-     print("[WireSegHR][train] Skeleton OK. Implement training per SEGMENTATION_PLAN.md.")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"[WireSegHR][train] Device: {device}")
+
+     # Config
+     coarse_train = int(cfg["coarse"]["train_size"])  # 512
+     patch_size = int(cfg["fine"]["patch_size"])  # 768
+     iters = int(cfg["optim"]["iters"])  # 40000
+     batch_size = int(cfg["optim"]["batch_size"])  # 8
+     base_lr = float(cfg["optim"]["lr"])  # 6e-5
+     weight_decay = float(cfg["optim"]["weight_decay"])  # 0.01
+     power = float(cfg["optim"]["power"])  # 1.0
+     amp_flag = bool(cfg["optim"].get("amp", True))
+
+     # Housekeeping
+     seed = int(cfg.get("seed", 42))
+     out_dir = cfg.get("out_dir", "runs/wireseghr")
+     eval_interval = int(cfg.get("eval_interval", 500))
+     ckpt_interval = int(cfg.get("ckpt_interval", 1000))
+     os.makedirs(out_dir, exist_ok=True)
+     set_seed(seed)
+
+     # Dataset
+     train_images = cfg["data"]["train_images"]
+     train_masks = cfg["data"]["train_masks"]
+     dset = WireSegDataset(train_images, train_masks, split="train")
+     # Validation and test
+     val_images = cfg["data"].get("val_images", None)
+     val_masks = cfg["data"].get("val_masks", None)
+     test_images = cfg["data"].get("test_images", None)
+     test_masks = cfg["data"].get("test_masks", None)
+     dset_val = WireSegDataset(val_images, val_masks, split="val") if val_images and val_masks else None
+     dset_test = WireSegDataset(test_images, test_masks, split="test") if test_images and test_masks else None
+     sampler = BalancedPatchSampler(patch_size=patch_size, min_wire_ratio=0.01)
+     minmax = MinMaxLuminance(kernel=cfg["minmax"]["kernel"]) if cfg["minmax"]["enable"] else None
+
+     # Model
+     # Channel definition: RGB(3) + MinMax(2) + cond(1) + loc(1) = 7
+     pretrained_flag = bool(cfg.get("pretrained", False))
+     model = WireSegHR(backbone=cfg["backbone"], in_channels=7, pretrained=pretrained_flag)
+     model = model.to(device)
+
+     # Optimizer and loss
+     optim = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
+     scaler = GradScaler(enabled=(device.type == "cuda" and amp_flag))
+     ce = nn.CrossEntropyLoss()
+
+     # Resume
+     start_step = 0
+     best_f1 = -1.0
+     resume_path = cfg.get("resume", None)
+     if resume_path and os.path.isfile(resume_path):
+         print(f"[WireSegHR][train] Resuming from {resume_path}")
+         start_step, best_f1 = _load_checkpoint(resume_path, model, optim, scaler, device)
+
+     # Training loop
+     model.train()
+     step = start_step
+     pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
+     while step < iters:
+         optim.zero_grad(set_to_none=True)
+         imgs, masks = _sample_batch_same_size(dset, batch_size)
+         batch = _prepare_batch(imgs, masks, coarse_train, patch_size, sampler, minmax, device)
+
+         logits_coarse, cond_map = model.forward_coarse(batch["x_coarse"])  # (B,2,Hc/4,Wc/4) and (B,1,Hc/4,Wc/4)
+
+         # Upsample cond to full-res to crop the patch-aligned conditioning for the fine pass
+         _, _, hc4, wc4 = cond_map.shape
+         cond_up = F.interpolate(
+             cond_map.detach(), size=(batch["full_h"], batch["full_w"]), mode="bilinear", align_corners=False
+         )
+
+         # Build fine inputs: crop cond to the patch, concat with patch RGB+MinMax and loc mask
+         x_fine = _build_fine_inputs(batch, cond_up, device)
+         logits_fine = model.forward_fine(x_fine)
+
+         # Targets
+         y_coarse = _build_coarse_targets(batch["mask_full"], hc4, wc4, device)
+         y_fine = _build_fine_targets(batch["mask_patches"], logits_fine.shape[2], logits_fine.shape[3], device)
+
+         with autocast(enabled=(device.type == "cuda" and amp_flag)):
+             loss_coarse = ce(logits_coarse, y_coarse)
+             loss_fine = ce(logits_fine, y_fine)
+             loss = loss_coarse + loss_fine
+
+         scaler.scale(loss).backward()
+         scaler.step(optim)
+         scaler.update()
+
+         # Poly LR schedule (per optimizer step)
+         lr = base_lr * ((1.0 - float(step) / float(iters)) ** power)
+         for pg in optim.param_groups:
+             pg["lr"] = lr
+
+         if step % 50 == 0:
+             print(f"[Iter {step}/{iters}] lr={lr:.6e}")
+
+         # Eval & checkpoint
+         if (step % eval_interval == 0) and (dset_val is not None):
+             model.eval()
+             val_stats = validate(model, dset_val, coarse_train, device, amp_flag)
+             print(
+                 f"[Val @ {step}] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} "
+                 f"P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
+             )
+             # Save best
+             if val_stats["f1"] > best_f1:
+                 best_f1 = val_stats["f1"]
+                 _save_checkpoint(os.path.join(out_dir, "best.pt"), step, model, optim, scaler, best_f1)
+             # Save periodic ckpt
+             if ckpt_interval > 0 and (step % ckpt_interval == 0):
+                 _save_checkpoint(os.path.join(out_dir, f"ckpt_{step}.pt"), step, model, optim, scaler, best_f1)
+                 # Save test visualizations
+                 if dset_test is not None:
+                     save_test_visuals(
+                         model, dset_test, coarse_train, device,
+                         os.path.join(out_dir, f"test_vis_{step}"), amp_flag, max_samples=8,
+                     )
+             model.train()
+
+         step += 1
+         pbar.update(1)
+
+     pbar.close()
+     print("[WireSegHR][train] Done.")
+
+
+ def set_seed(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     cudnn.benchmark = False
+     cudnn.deterministic = True
+
+
+ def _sample_batch_same_size(dset: WireSegDataset, batch_size: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+     # Select a seed sample, then fill the batch with samples of the same (H, W)
+     assert len(dset) > 0
+     seed_idx = int(np.random.randint(0, len(dset)))
+     seed_item = dset[seed_idx]
+     H, W = seed_item["image"].shape[:2]
+     imgs: List[np.ndarray] = [seed_item["image"]]
+     masks: List[np.ndarray] = [seed_item["mask"]]
+     tries = 0
+     while len(imgs) < batch_size and tries < 5000:
+         idx = int(np.random.randint(0, len(dset)))
+         item = dset[idx]
+         im = item["image"]
+         if im.shape[0] == H and im.shape[1] == W:
+             imgs.append(im)
+             masks.append(item["mask"])
+         tries += 1
+     assert len(imgs) == batch_size, "Failed to assemble same-size batch"
+     return imgs, masks
+
+
+ def _prepare_batch(
+     imgs: List[np.ndarray],
+     masks: List[np.ndarray],
+     coarse_train: int,
+     patch_size: int,
+     sampler: BalancedPatchSampler,
+     minmax: Optional[MinMaxLuminance],
+     device: torch.device,
+ ):
+     B = len(imgs)
+     assert B == len(masks)
+     # Keep numpy versions for geometry and torch versions for model inputs
+     full_h = imgs[0].shape[0]
+     full_w = imgs[0].shape[1]
+     for im, m in zip(imgs, masks):
+         assert im.shape[0] == full_h and im.shape[1] == full_w
+         assert m.shape[0] == full_h and m.shape[1] == full_w
+
+     xs_coarse = []
+     patches_rgb = []
+     patches_mask = []
+     patches_min = []
+     patches_max = []
+     loc_masks = []
+     yx_list: List[tuple[int, int]] = []
+
+     for img, mask in zip(imgs, masks):
+         # Float32 [0, 1]
+         imgf = img.astype(np.float32) / 255.0
+         if minmax is not None:
+             y_min, y_max = minmax(imgf)
+         else:
+             y = (0.299 * imgf[..., 0] + 0.587 * imgf[..., 1] + 0.114 * imgf[..., 2]).astype(np.float32)
+             y_min, y_max = y, y
+
+         # Coarse input: resize RGB + MinMax to coarse_train, pad cond+loc zeros to reach 7 channels
+         rgb_coarse = cv2.resize(imgf, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
+         y_min_c = cv2.resize(y_min, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
+         y_max_c = cv2.resize(y_max, (coarse_train, coarse_train), interpolation=cv2.INTER_LINEAR)
+         c = np.concatenate([
+             np.transpose(rgb_coarse, (2, 0, 1)),  # 3xHxW
+             y_min_c[None, ...],  # 1xHxW
+             y_max_c[None, ...],  # 1xHxW
+             np.zeros((1, coarse_train, coarse_train), np.float32),  # cond placeholder
+             np.zeros((1, coarse_train, coarse_train), np.float32),  # loc placeholder
+         ], axis=0)
+         xs_coarse.append(torch.from_numpy(c))
+
+         # Sample fine patch
+         y0, x0 = sampler.sample(imgf, mask)
+         patch_rgb = imgf[y0 : y0 + patch_size, x0 : x0 + patch_size, :]
+         patch_mask = mask[y0 : y0 + patch_size, x0 : x0 + patch_size]
+         patches_rgb.append(patch_rgb)
+         patches_mask.append(patch_mask)
+         patches_min.append(y_min[y0 : y0 + patch_size, x0 : x0 + patch_size])
+         patches_max.append(y_max[y0 : y0 + patch_size, x0 : x0 + patch_size])
+         # Binary location mask (ones inside the patch)
+         loc_masks.append(np.ones((patch_size, patch_size), dtype=np.float32))
+         yx_list.append((y0, x0))
+
+     x_coarse = torch.stack(xs_coarse, dim=0).to(device)  # Bx7xHcxWc
+
+     # Store numpy arrays for the fine build
+     return {
+         "x_coarse": x_coarse,
+         "full_h": full_h,
+         "full_w": full_w,
+         "rgb_patches": patches_rgb,
+         "mask_patches": patches_mask,
+         "ymin_patches": patches_min,
+         "ymax_patches": patches_max,
+         "loc_patches": loc_masks,
+         "patch_yx": yx_list,
+         "mask_full": masks,
+     }
+
+
+ def _build_fine_inputs(batch, cond_up: torch.Tensor, device: torch.device) -> torch.Tensor:
+     # Build the fine input tensor Bx7xPxP from per-sample numpy buffers and upsampled cond maps
+     B = cond_up.shape[0]
+     P = batch["loc_patches"][0].shape[0]
+     xs: List[torch.Tensor] = []
+     for i in range(B):
+         rgb = batch["rgb_patches"][i]
+         ymin = batch["ymin_patches"][i]
+         ymax = batch["ymax_patches"][i]
+         loc = batch["loc_patches"][i]
+         y0, x0 = batch["patch_yx"][i]
+
+         cond_patch = cond_up[i : i + 1, :, y0 : y0 + P, x0 : x0 + P]  # 1x1xPxP
+         cond_patch = cond_patch.squeeze(1)  # 1xPxP
+
+         # Convert numpy channels to torch and concat
+         rgb_t = torch.from_numpy(np.transpose(rgb, (2, 0, 1)))  # 3xPxP
+         ymin_t = torch.from_numpy(ymin)[None, ...]  # 1xPxP
+         ymax_t = torch.from_numpy(ymax)[None, ...]  # 1xPxP
+         loc_t = torch.from_numpy(loc)[None, ...]  # 1xPxP
+         x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch.cpu(), loc_t], dim=0).float()  # 7xPxP
+         xs.append(x)
+     x_fine = torch.stack(xs, dim=0).to(device)
+     return x_fine
+
+
+ def _build_coarse_targets(masks: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
+     ys: List[torch.Tensor] = []
+     for m in masks:
+         dm = downsample_label_maxpool(m, out_h, out_w)
+         ys.append(torch.from_numpy(dm.astype(np.int64)))
+     y = torch.stack(ys, dim=0).to(device)  # BxHc4xWc4 with values {0, 1}
+     return y
+
+
+ def _build_fine_targets(mask_patches: List[np.ndarray], out_h: int, out_w: int, device: torch.device) -> torch.Tensor:
+     ys: List[torch.Tensor] = []
+     for m in mask_patches:
+         dm = downsample_label_maxpool(m, out_h, out_w)
+         ys.append(torch.from_numpy(dm.astype(np.int64)))
+     y = torch.stack(ys, dim=0).to(device)  # BxHf4xWf4 with values {0, 1}
+     return y
+
+
+ def _save_checkpoint(path: str, step: int, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, best_f1: float):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     state = {
+         "step": step,
+         "model": model.state_dict(),
+         "optim": optim.state_dict(),
+         "scaler": scaler.state_dict(),
+         "best_f1": best_f1,
+     }
+     torch.save(state, path)
+     print(f"[WireSegHR][train] Saved checkpoint: {path}")
+
+
+ def _load_checkpoint(path: str, model: nn.Module, optim: torch.optim.Optimizer, scaler: GradScaler, device: torch.device) -> Tuple[int, float]:
+     ckpt = torch.load(path, map_location=device)
+     model.load_state_dict(ckpt["model"])
+     optim.load_state_dict(ckpt["optim"])
+     try:
+         scaler.load_state_dict(ckpt["scaler"])  # may not exist in older checkpoints
+     except Exception:
+         pass
+     step = int(ckpt.get("step", 0))
+     best_f1 = float(ckpt.get("best_f1", -1.0))
+     return step, best_f1
+
+
+ @torch.no_grad()
+ def validate(model: WireSegHR, dset_val: WireSegDataset, coarse_size: int, device: torch.device, amp_flag: bool) -> Dict[str, float]:
+     # Coarse-only validation: resize the image to coarse_size, predict coarse logits,
+     # upsample to full resolution, and compute metrics
+     model = model.to(device)
+     metrics_sum = {"iou": 0.0, "f1": 0.0, "precision": 0.0, "recall": 0.0}
+     n = 0
+     for i in range(len(dset_val)):
+         item = dset_val[i]
+         img = item["image"].astype(np.float32) / 255.0  # HxWx3
+         mask = item["mask"].astype(np.uint8)
+         H, W = mask.shape
+         # Build the coarse input (zeros for cond+loc)
+         rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
+         y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
+         y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
+         y_max = y_min
+         x = np.concatenate([
+             np.transpose(rgb_c, (2, 0, 1)),
+             y_min[None, ...],
+             y_max[None, ...],
+             np.zeros((1, coarse_size, coarse_size), np.float32),
+             np.zeros((1, coarse_size, coarse_size), np.float32),
+         ], axis=0)
+         x_t = torch.from_numpy(x)[None, ...].to(device)
+         with autocast(enabled=(device.type == "cuda" and amp_flag)):
+             logits_c, _ = model.forward_coarse(x_t)
+         prob = torch.softmax(logits_c, dim=1)[:, 1:2]
+         prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
+         pred = (prob_up > 0.5).astype(np.uint8)
+         m = compute_metrics(pred, mask)
+         for k in metrics_sum:
+             metrics_sum[k] += m[k]
+         n += 1
+     if n == 0:
+         return {k: 0.0 for k in metrics_sum}
+     return {k: v / float(n) for k, v in metrics_sum.items()}
+
+
+ @torch.no_grad()
+ def save_test_visuals(model: WireSegHR, dset_test: WireSegDataset, coarse_size: int, device: torch.device, out_dir: str, amp_flag: bool, max_samples: int = 8):
+     os.makedirs(out_dir, exist_ok=True)
+     for i in range(min(max_samples, len(dset_test))):
+         item = dset_test[i]
+         img = item["image"].astype(np.float32) / 255.0
+         H, W = img.shape[:2]
+         rgb_c = cv2.resize(img, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
+         y = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.float32)
+         y_min = cv2.resize(y, (coarse_size, coarse_size), interpolation=cv2.INTER_LINEAR)
+         y_max = y_min
+         x = np.concatenate([
+             np.transpose(rgb_c, (2, 0, 1)),
+             y_min[None, ...],
+             y_max[None, ...],
+             np.zeros((1, coarse_size, coarse_size), np.float32),
+             np.zeros((1, coarse_size, coarse_size), np.float32),
+         ], axis=0)
+         x_t = torch.from_numpy(x)[None, ...].to(device)
+         with autocast(enabled=(device.type == "cuda" and amp_flag)):
+             logits_c, _ = model.forward_coarse(x_t)
+         prob = torch.softmax(logits_c, dim=1)[:, 1:2]
+         prob_up = F.interpolate(prob, size=(H, W), mode="bilinear", align_corners=False)[0, 0].detach().cpu().numpy()
+         pred = (prob_up > 0.5).astype(np.uint8) * 255
+         # Save input and prediction
+         img_bgr = (img[..., ::-1] * 255.0).astype(np.uint8)
+         cv2.imwrite(os.path.join(out_dir, f"{i:03d}_input.jpg"), img_bgr)
+         cv2.imwrite(os.path.join(out_dir, f"{i:03d}_pred.png"), pred)


if __name__ == "__main__":
    main()