File size: 4,814 Bytes
6021dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import mxnet as mx
import numpy as np
from nowcasting.operators.common import constant


def CDNA(data, kernels, mask, batch_size, num_filter, kernel_size):
    """Convolutional Dynamic Neural Advection (CDNA) transformation.

    Convolves each sample in `data` with `num_filter` per-sample predicted
    kernels and blends the `num_filter` results with a per-pixel softmax mask.
    We assume that the kernels and masks are the output of an identity
    activation; the softmax normalizations are applied inside this function.

    Parameters
    ----------
    data : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    kernels : mx.sym.Symbol
        Shape: (batch_size, M, K, K)
    mask : mx.sym.Symbol
        Shape: (batch_size, M, H, W)
    batch_size : int
    num_filter : int
        M
    kernel_size : int
        K; must be odd so that K // 2 padding preserves H and W

    Returns
    -------
    ret : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    """
    assert kernel_size % 2 == 1, "Only support odd kernel size"
    # Use softmax activation so each K*K kernel and each pixel's mask sum to 1
    kernels = mx.sym.SoftmaxActivation(mx.sym.Reshape(kernels,
                                                      shape=(-1, kernel_size * kernel_size)))
    kernels = mx.sym.Reshape(kernels, shape=(-1, num_filter, kernel_size, kernel_size))
    mask = mx.sym.SoftmaxActivation(mask, mode="channel")

    # Per-sample convolution: slice the batch apart and reuse the predicted
    # kernels as the Convolution weight for the matching sample.
    data_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(data, axis=2), axis=0,
                                      num_outputs=batch_size, squeeze_axis=True) # Each Shape: (C, 1, H, W)
    kernels_sliced = mx.sym.SliceChannel(mx.sym.expand_dims(kernels, axis=2),
                                         axis=0, num_outputs=batch_size,
                                         squeeze_axis=True) # Each Shape: (M, 1, K, K)
    out = []
    for i in range(batch_size):
        # BUG FIX: use floor division; `kernel_size / 2` is a float in
        # Python 3 and MXNet requires integer padding (DFN below already
        # uses K // 2).
        ele = mx.sym.Convolution(data=data_sliced[i],
                                 num_filter=num_filter,
                                 kernel=(kernel_size, kernel_size),
                                 pad=(kernel_size // 2, kernel_size // 2),
                                 weight=kernels_sliced[i], no_bias=True) # Shape: (C, M, H, W)
        out.append(mx.sym.expand_dims(ele, axis=0))
    out = mx.sym.Concat(*out, num_args=batch_size, dim=0) # Shape: (batch_size, C, M, H, W)
    # Broadcast the mask over the channel axis and sum out the M kernels.
    mask = mx.sym.Reshape(mask, reverse=True, shape=(batch_size, 1, num_filter, 0, 0))
    out = mx.sym.broadcast_mul(out, mask)
    out = mx.sym.sum(out, axis=2)
    return out

def STP(data, affine_transform_matrices, mask, num_filter, kernel_size):
    """Spatial Transformer Predictor (STP) — placeholder, not yet implemented.

    Parameters
    ----------
    data : mx.sym.Symbol
    affine_transform_matrices
    mask
    num_filter : int
    kernel_size : int

    Raises
    ------
    NotImplementedError
        Always raised; the STP transformation is still to be written.
    """
    raise NotImplementedError()


def DFN(data, local_kernels, K, batch_size):
    """[NIPS2016] Dynamic Filter Network

    Applies a per-pixel, per-sample K x K filter (given in `local_kernels`)
    to `data` by expanding every K x K neighbourhood into K*K channels and
    taking a softmax-weighted sum over them.

    Parameters
    ----------
    data : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    local_kernels : mx.sym.Symbol
        Shape: (batch_size, K*K, H, W)
    K : int
        size of the local convolutional kernel
    batch_size : int
        size of the minibatch

    Returns
    -------
    output : mx.sym.Symbol
        Shape: (batch_size, C, H, W)
    """
    # Normalize each pixel's K*K weights to sum to one.
    local_kernels = mx.sym.SoftmaxActivation(local_kernels, mode="channel")
    # Identity "local-expand" filter: output channel j copies the j-th
    # position of the K x K neighbourhood around each pixel.
    eye = mx.sym.one_hot(indices=mx.sym.arange(K * K), depth=K * K)
    filter_localexpand = mx.sym.reshape(mx.sym.transpose(eye, axes=(1, 0)),
                                        shape=(K * K, 1, K, K))
    per_sample = mx.sym.SliceChannel(data, num_outputs=batch_size, axis=0, squeeze_axis=True)
    expanded = [
        mx.sym.expand_dims(
            mx.sym.Convolution(data=mx.sym.expand_dims(per_sample[i], axis=1),
                               weight=filter_localexpand,
                               num_filter=K * K,
                               kernel=(K, K),
                               pad=(K // 2, K // 2),
                               no_bias=True),  # Shape (C, K*K, H, W)
            axis=0)
        for i in range(batch_size)
    ]
    stacked = mx.sym.Concat(*expanded, num_args=batch_size, dim=0)  # Shape (batch_size, C, K*K, H, W)
    weighted = mx.sym.broadcast_mul(stacked, mx.sym.expand_dims(local_kernels, axis=1))
    return mx.sym.sum(weighted, axis=2)



if __name__ == '__main__':
    # Smoke test: stack 10 DFN layers and run a single forward pass on GPU.
    K, C, H, W = 11, 3, 60, 60
    batch_size = 32
    data = mx.sym.Variable('data')
    local_kernels = mx.sym.Variable('local_kernels')
    sym = data
    for _ in range(10):
        sym = DFN(data=sym, local_kernels=local_kernels, K=K, batch_size=batch_size)
    data_npy = np.random.normal(size=(batch_size, C, H, W))
    local_kernels_npy = np.random.normal(size=(batch_size, K * K, H, W))
    exe = sym.simple_bind(ctx=mx.gpu(), data=(batch_size, C, H, W),
                          local_kernels=(batch_size, K * K, H, W))
    exe.forward(data=data_npy, local_kernels=local_kernels_npy)
    print(exe.outputs[0].asnumpy().shape)