File size: 1,390 Bytes
8ea2eff d46d294 8ea2eff 6e8ad0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from wireseghr.model import WireSegHR
def test_wireseghr_forward_shapes():
    """Check WireSegHR coarse/fine output shapes with the MiT-B2 backbone.

    Uses a small random 64x64 input and ``pretrained=False`` to keep the test
    light and avoid downloading weights. Expects:
      * coarse logits with 2 channels,
      * a conditioning map with 1 channel,
      * both at ~1/4 of the input resolution (MiT stage 0),
      * fine logits with the same shape as the coarse logits for the same input.
    """
    batch, height, width = 1, 64, 64
    model = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
    # Shape-only check: put any dropout/normalization layers in inference mode
    # and skip building an autograd graph that is never used.
    model.eval()
    x = torch.randn(batch, 3, height, width)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    # Expect stage 0 resolution ~ 1/4 of input for MiT
    expected_hw = (height // 4, width // 4)
    assert tuple(logits_coarse.shape) == (batch, 2, *expected_hw), logits_coarse.shape
    assert tuple(cond.shape) == (batch, 1, *expected_hw), cond.shape
    assert logits_fine.shape == logits_coarse.shape, (
        logits_fine.shape,
        logits_coarse.shape,
    )
def test_wireseghr_forward_shapes_resnet50():
    """Check WireSegHR coarse/fine output shapes with the ResNet-50 backbone.

    Ensures the alternative ResNet-50 backbone works and keeps the same output
    contract as MiT:
      * coarse logits with 2 channels,
      * a conditioning map with 1 channel,
      * both at 1/4 of the input resolution (ResNet stage 0),
      * fine logits with the same shape as the coarse logits for the same input.
    Uses a small random 64x64 input and ``pretrained=False`` to stay light.
    """
    batch, height, width = 1, 64, 64
    model = WireSegHR(backbone="resnet50", in_channels=3, pretrained=False)
    # Shape-only check: put any dropout/normalization layers in inference mode
    # and skip building an autograd graph that is never used.
    model.eval()
    x = torch.randn(batch, 3, height, width)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    # ResNet stage0 is also 1/4 of input
    expected_hw = (height // 4, width // 4)
    assert tuple(logits_coarse.shape) == (batch, 2, *expected_hw), logits_coarse.shape
    assert tuple(cond.shape) == (batch, 1, *expected_hw), cond.shape
    assert logits_fine.shape == logits_coarse.shape, (
        logits_fine.shape,
        logits_coarse.shape,
    )
|