# Copyright (c) OpenMMLab. All rights reserved.
def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.



    Args:

        x (Tensor): The input tensor of shape [N, L, C] before conversion.

        hw_shape (Sequence[int]): The height and width of output feature map.



    Returns:

        Tensor: The output tensor of shape [N, C, H, W] after conversion.
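
    Example:
        >>> import torch
        >>> # Example shapes chosen arbitrarily: N=2, C=16, H=4, W=8.
        >>> x = torch.rand(2, 32, 16)  # L = H * W = 32
        >>> nlc_to_nchw(x, (4, 8)).shape
        torch.Size([2, 16, 4, 8])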
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
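    # [N, L, C] -> [N, C, L] -> [N, C, H, W]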
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.



    Args:

        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.



    Returns:

        Tensor: The output tensor of shape [N, L, C] after conversion.
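
    Example:
        >>> import torch
        >>> # Example shapes chosen arbitrarily: N=2, C=16, H=4, W=8.
        >>> x = torch.rand(2, 16, 4, 8)
        >>> nchw_to_nlc(x).shape
        torch.Size([2, 32, 16])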
    """
    assert len(x.shape) == 4
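    # [N, C, H, W] -> [N, C, L] -> [N, L, C]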
    return x.flatten(2).transpose(1, 2).contiguous()


def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
    """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the

    reshaped tensor as the input of `module`, and the convert the output of

    `module`, whose shape is.



    [N, L, C], to [N, C, H, W].



    Args:

        module (Callable): A callable object the takes a tensor

            with shape [N, L, C] as input.

        x (Tensor): The input tensor of shape [N, C, H, W].

                contiguous:

        contiguous (Bool): Whether to make the tensor contiguous

            after each shape transform.



    Returns:

        Tensor: The output tensor of shape [N, C, H, W].



    Example:

        >>> import torch

        >>> import torch.nn as nn

        >>> norm = nn.LayerNorm(4)

        >>> feature_map = torch.rand(4, 4, 5, 5)

        >>> output = nchw2nlc2nchw(norm, feature_map)

    """
    B, C, H, W = x.shape
    if not contiguous:
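        # [N, C, H, W] -> [N, C, L] -> [N, L, C]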
        x = x.flatten(2).transpose(1, 2)
        x = module(x, **kwargs)
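        # [N, L, C] -> [N, C, L] -> [N, C, H, W]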
        x = x.transpose(1, 2).reshape(B, C, H, W)
    else:
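        # Same transforms, but force a contiguous layout after each step.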
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = module(x, **kwargs)
        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
    return x


def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
    """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the

    reshaped tensor as the input of `module`, and convert the output of

    `module`, whose shape is.



    [N, C, H, W], to [N, L, C].



    Args:

        module (Callable): A callable object the takes a tensor

            with shape [N, C, H, W] as input.

        x (Tensor): The input tensor of shape [N, L, C].

        hw_shape: (Sequence[int]): The height and width of the

            feature map with shape [N, C, H, W].

        contiguous (Bool): Whether to make the tensor contiguous

            after each shape transform.



    Returns:

        Tensor: The output tensor of shape [N, L, C].



    Example:

        >>> import torch

        >>> import torch.nn as nn

        >>> conv = nn.Conv2d(16, 16, 3, 1, 1)

        >>> feature_map = torch.rand(4, 25, 16)

        >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))

    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
    if not contiguous:
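        # [N, L, C] -> [N, C, L] -> [N, C, H, W]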
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = module(x, **kwargs)
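        # [N, C, H, W] -> [N, C, L] -> [N, L, C]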
        x = x.flatten(2).transpose(1, 2)
    else:
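        # Same transforms, but force a contiguous layout after each step.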
        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
        x = module(x, **kwargs)
        x = x.flatten(2).transpose(1, 2).contiguous()
    return x