|
|
import torch |
|
|
|
|
|
from wireseghr.model import WireSegHR |
|
|
|
|
|
|
|
|
def test_wireseghr_forward_shapes():
    """Check WireSegHR output shapes with the ``"mit_b2"`` backbone.

    A random 1x3x64x64 image goes through ``forward_coarse`` and then
    ``forward_fine``. The test expects:

    * coarse logits with batch 1, 2 channels and a 16x16 spatial size,
    * a conditioning map with batch 1, 1 channel and a 16x16 spatial size,
    * fine logits whose shape equals the coarse logits' shape.
    """
    net = WireSegHR(backbone="mit_b2", in_channels=3, pretrained=False)
    image = torch.randn(1, 3, 64, 64)

    coarse_logits, cond_map = net.forward_coarse(image)

    # Batch / channel dimensions.
    assert coarse_logits.shape[0] == 1 and coarse_logits.shape[1] == 2
    assert cond_map.shape[0] == 1 and cond_map.shape[1] == 1

    # Spatial dimensions: the test expects 64x64 input -> 16x16 output.
    assert coarse_logits.shape[2] == 16 and coarse_logits.shape[3] == 16
    assert cond_map.shape[2] == 16 and cond_map.shape[3] == 16

    # The fine branch, fed the same image, must agree with the coarse shape.
    fine_logits = net.forward_fine(image)
    assert fine_logits.shape == coarse_logits.shape
|
|
|
|
|
|
|
|
def test_wireseghr_forward_shapes_resnet50():
    """Check WireSegHR output shapes with the ``"resnet50"`` backbone.

    A random 1x3x64x64 image goes through ``forward_coarse`` and then
    ``forward_fine``. The test expects:

    * coarse logits with batch 1, 2 channels and a 16x16 spatial size,
    * a conditioning map with batch 1, 1 channel and a 16x16 spatial size,
    * fine logits whose shape equals the coarse logits' shape.
    """
    net = WireSegHR(backbone="resnet50", in_channels=3, pretrained=False)
    image = torch.randn(1, 3, 64, 64)

    coarse_logits, cond_map = net.forward_coarse(image)

    # Batch / channel dimensions.
    assert coarse_logits.shape[0] == 1 and coarse_logits.shape[1] == 2
    assert cond_map.shape[0] == 1 and cond_map.shape[1] == 1

    # Spatial dimensions: the test expects 64x64 input -> 16x16 output.
    assert coarse_logits.shape[2] == 16 and coarse_logits.shape[3] == 16
    assert cond_map.shape[2] == 16 and cond_map.shape[3] == 16

    # The fine branch, fed the same image, must agree with the coarse shape.
    fine_logits = net.forward_fine(image)
    assert fine_logits.shape == coarse_logits.shape
|
|
|