# WireSegHR / tests/test_model_forward.py
# Author: MRiabov
# Commit 6e8ad0e: "(test) test ResNet backbone"
import torch
from wireseghr.model import WireSegHR
def test_wireseghr_forward_shapes():
    """Check coarse/fine output shapes of WireSegHR with the MiT-B2 backbone.

    A small random input and ``pretrained=False`` keep the test light and
    avoid downloading weights.
    """
    batch, height, width = 1, 64, 64
    model = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
    # This is a pure shape check: eval() disables train-mode-only behaviour and
    # no_grad() avoids building an autograd graph, keeping the test cheap.
    model.eval()
    x = torch.randn(batch, 3, height, width)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    # Outputs are expected at the MiT stage-0 resolution: exactly 1/4 of the
    # input (64 -> 16).
    stage0_hw = (height // 4, width // 4)
    # Coarse logits have 2 channels; the conditioning map has 1 channel.
    assert logits_coarse.shape == (batch, 2, *stage0_hw)
    assert cond.shape == (batch, 1, *stage0_hw)
    # The fine branch must produce logits of the same shape as the coarse one.
    assert logits_fine.shape == logits_coarse.shape
def test_wireseghr_forward_shapes_resnet50():
    """Check that the ResNet-50 alternative backbone keeps 1/4 stage-0 outputs.

    Mirrors the MiT-B2 shape test: small random input, no pretrained weights.
    """
    batch, height, width = 1, 64, 64
    model = WireSegHR(backbone="resnet50", in_channels=3, pretrained=False)
    # Shape-only check: eval() avoids train-mode BatchNorm statistic updates,
    # and no_grad() skips autograd bookkeeping.
    model.eval()
    x = torch.randn(batch, 3, height, width)
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        logits_fine = model.forward_fine(x)

    # ResNet stage 0 is also exactly 1/4 of the input (64 -> 16).
    stage0_hw = (height // 4, width // 4)
    # Coarse logits have 2 channels; the conditioning map has 1 channel.
    assert logits_coarse.shape == (batch, 2, *stage0_hw)
    assert cond.shape == (batch, 1, *stage0_hw)
    # The fine branch must produce logits of the same shape as the coarse one.
    assert logits_fine.shape == logits_coarse.shape