目录

神经网络笔记(四)——Spatial Batch Normalization & Spatial Group Normalization

这里我们跟着实验来完成Spatial Batch Normalization和Spatial Group Normalization,用于对CNN进行优化。

Spatial Batch Normalization

回忆之前普通神经网络的BN层,输入为$X_{input}=(N, D)$,输出形状也为$(N, D)$,其作用是将输入进行归一化然后输出。在这里,对于来自卷积层的数据$X_{input}=(N,C,H,W)$,其输出形状也为$(N,C,H,W)$,其中$N$是一个mini-batch的数据数量,$C$是特征映射(feature map)的数量,有几个感受野就会产生几个特征映射,而$(H, W)$则给出特征映射的大小。

如果特征映射是由卷积运算产生的,我们希望对各个特征C映射进行归一化,使得每个特征映射的不同图片(N)和一张图片内的不同位置(H,W)的统计学特征(均值、标准差等)相对一致。也就是说,spatial batch normalization为C个特征通道中的每一个都计算出来对应的均值和方差,而这里的均值和方差则是遍历对应特征通道中N张图片和其空间维度(H,W)计算得出的。可以理解为之前的D是这里的$C$,之前的N在这里则是$N\times H \times W$。

前向传播

对输入$X_{input}=(N, C, H, W)$转置为维度$(N\times H\times W, C)$,转化成普通的BN层输入并传递给普通(vanilla)BN层的前向传播函数,再对输出转化成对应的$(N, C, H, W)$。代码如下:

1
2
3
4
5
6
7
8
9
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    N, C, H, W=x.shape
    x_new=x.transpose((0,2,3,1)).reshape(-1,C)
    out,cache=batchnorm_forward(x_new,gamma,beta,bn_param)
    out=out.reshape(N,H,W,C).transpose((0,3,1,2))
    pass

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

反向传播

1
2
3
4
5
6
7
8
9
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    N, C, H, W=dout.shape
    dout_new=dout.transpose((0,2,3,1)).reshape(-1,C)
    dx, dgamma, dbeta = batchnorm_backward_alt(dout_new,cache)
    dx = dx.reshape((N,H,W,C)).transpose((0,3,1,2))
    pass

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

Spatial Group Normalization

Spatial Group Normalization可看作解决Layer Normalization在CNN上的表现不能够像Batch Normalization一样好的问题的方案。

前向传播

仿照论文中的代码实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    cache = (x, gamma, beta, G, gn_param)

    N, C, H, W=x.shape
    x_new=x.reshape((N,G,C//G,H,W))
    mean=np.mean(x_new,axis=(2,3,4),keepdims=True)
    var=np.var(x_new,axis=(2,3,4),keepdims=True)
    x_new=(x_new-mean)/np.sqrt(var+eps)
    x_new=x_new.reshape((N, C, H, W))
    out=x_new*gamma+beta
    pass

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

反向传播

参考了这篇博客。求导并不复杂,代码实现起来难度较大。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    N, C, H, W=dout.shape
    x, x_new, mean, var, gamma, beta, G, gn_param=cache
    eps=gn_param.get("eps", 1e-5)

    dgamma = np.sum(dout * x_new, axis=(0, 2, 3)).reshape(1, C, 1, 1)
    x = x.reshape(N, G, C // G, H, W)
    # 这里想通过Gradientcheck必须需要将其reshape为(1, C, 1, 1)
    dbeta = np.sum(dout, axis=(0, 2, 3)).reshape(1, C, 1, 1)

    dx_new = (dout * gamma).reshape(N, G, C // G, H, W)
    mean = mean.reshape(N, G, 1, 1, 1)
    var = var.reshape(N, G, 1, 1, 1)
    dL_dvar = -0.5 * np.sum(dx_new * (x - mean), axis=(2, 3, 4)) * np.power(var.squeeze() + eps, -1.5)
    dL_dvar = dL_dvar.reshape(N, G, 1, 1, 1)

    mid = H * W * C // G
    # add L-->y-->x_hat-->x_i
    dx = dx_new / np.sqrt(var + eps)
    # add L-->mean-->x_i
    dx += ((-1 / mid) * np.sum(dx_new / np.sqrt(var + eps), axis=(2, 3, 4))).reshape(N, G, 1, 1, 1) + dL_dvar * (
        np.sum(-2 * (x - mean) / mid, axis=(2, 3, 4))).reshape(N, G, 1, 1, 1)
    # add L-->var-->x_i
    dx += (2 / mid) * (x - mean) * dL_dvar
    dx = dx.reshape((N, C, H, W))

    # dgamma=np.sum(dout*x,axis=(0,2,3)).reshape(1, C, 1, 1)
    # dbeta=dout.sum(axis=(0,2,3)).reshape((1, C, 1, 1))
    pass

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****