File size: 1,390 Bytes
8ea2eff
 
 
 
 
 
 
d46d294
8ea2eff
 
 
 
 
 
 
 
 
 
 
6e8ad0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch

from wireseghr.model import WireSegHR


def _assert_forward_shapes(backbone: str, size: int = 64) -> None:
    """Run coarse and fine forward passes for ``backbone`` and check output shapes.

    The model is built with ``pretrained=False`` so no weights are downloaded,
    and a single random ``(1, 3, size, size)`` image is fed through it. Asserts:

    * ``forward_coarse`` returns 2-channel logits and a 1-channel conditioning
      map, both at 1/4 of the input resolution (the backbone's stage-0 stride);
    * ``forward_fine`` returns logits with the same shape as the coarse logits.

    Args:
        backbone: Backbone name passed straight to ``WireSegHR``.
        size: Square input side length; assumed divisible by 4.
    """
    model = WireSegHR(backbone=backbone, in_channels=3, pretrained=False)
    # Stage 0 of both MiT and ResNet backbones is expected at 1/4 input resolution.
    out_hw = size // 4

    x = torch.randn(1, 3, size, size)
    # Only output shapes are inspected, so skip building the autograd graph
    # to keep the test light.
    with torch.no_grad():
        logits_coarse, cond = model.forward_coarse(x)
        assert logits_coarse.shape == (1, 2, out_hw, out_hw), logits_coarse.shape
        assert cond.shape == (1, 1, out_hw, out_hw), cond.shape

        logits_fine = model.forward_fine(x)
        assert logits_fine.shape == logits_coarse.shape, (
            logits_fine.shape,
            logits_coarse.shape,
        )


def test_wireseghr_forward_shapes():
    """MiT-B2 backbone: coarse/cond outputs at 1/4 resolution, fine matches coarse."""
    # Small input keeps the test light; pretrained=False avoids downloading weights.
    _assert_forward_shapes("mit_b2")


def test_wireseghr_forward_shapes_resnet50():
    """ResNet-50 alt backbone works and keeps the same 1/4 stage-0 resolution."""
    _assert_forward_shapes("resnet50")