深度学习 Batch Normalization 论文笔记
标题: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
作者: Sergey Ioffe, Christian Szegedy
下载地址: http://proceedings.mlr.press/v37/ioffe15.html
1 简介
- 在深度学习中,调参一直都是很重要的一环,尤其是学习率和初始值的设置。基于数据在网络中的流通方式,输入的一个小改变可能会影响到后面所有的网络层,层数越深影响越大。简单来说,如果输入的分布发生变化,网络就不得不再去适应新的数据分布,进而影响网络的泛化能力。这个现象称作covariate shift。
- covariate shift这个概念不仅存在于整个网络的输入中,深层网络中每个子网络的输入变化也会出现类似covariate shift的现象。作者将内部网络节点输入的分布发生改变称作Internal Covariate Shift。为了消除这个现象,作者提出了batch normalization算法。batch normalization的主要工作是固定每层输入的均值和方差,它有以下几个强大的优点:
- 使用了batch normalization,可以让网络的梯度不那么依赖于参数和参数初始值的规模。于是我们在训练时可以使用更高的学习速率,而不用担心无法拟合的问题。
- batch normalization可以正则化模型,有效防止过拟合,因此可以替代Dropout。
- batch normalization有效防止了网络出现饱和,因此允许我们使用一些饱和的非线性函数,例如sigmoid。
2 Towards Reducing Internal Covariate Shift
- 作者对Internal Covariate Shift的定义是:“the change in the distribution of network activations due to the change in network parameters during training”。消除Internal Covariate Shift的基本思路,就是在训练时固定每层输入x的分布。传统的方法是对输入进行白化处理,即通过线性变换使其均值为0,方差为1,并且去相关。对网络每层的输入都进行白化处理,可以在一定程度上消除Internal Covariate Shift的病态影响。
- 然而,白化本身也存在诸多问题。在训练时,“the gradient descent step may attempt to update the parameters in a way that requires the normalization to be updated”。文章中举了一个例子,就不展开说了。作者大只想表达的意思是,如果进行白化,某些节点中数值的更新则被白化消除了,于是参数一直增长,但网络的输出和损失几乎没有变化。除此之外,白化要计算整个训练样本的协方差矩阵,计算量过大。
- 总结一下batch normalization提出的动机,就是希望找到一种算法不仅能够进行可微分的归一化,还能不用在整个训练集上进行操作。
3 Normalization via Mini-Batch Statistics
- 基于全局白化的问题,作者提出了两个重要的简化。第一个就是按如下方式对特征(神经元)的每个维度单独做归一化,而非以往的所有输入单元联合白化。第二个就是用每个mini-batches的期望和方差来估计全局的期望和方差:
- 其中期望和方差都是指这一个批次的,上标k代表x的第k个维度。这便是以样本分布来估计全局分布。然而这样做还是有一个问题,就是生成的x绝对值太小,在sigmoid函数中仅在线性范围,降低了模型表达能力。为此,作者还用两个可学习的参数
γ(k)
,
β(k)
,对x做线性处理:
- 这样做,至少可以保证输出
y(k)
与原先的
x(k)
相等。网络可以从中学习
x(k)
的最优重构方案。完整的BN处理过程和反向传播如下所示:
3.1 Training and Inference with Batch-Normalized Networks
- 带BN的网络的训练还是比较好理解的,但inference的时候往往只有一个值,所谓的正则化就没有必要了。一旦网络训练完毕,inference的时候用如下正则化:
- 注意,此时的E[x]和Var[x]是针对整个训练样本的,而不再是某个mini-batch了。因此,在训练过程中还要记录每一个Batch的均值和方差,以便训练完成之后按照下式计算整体的均值和方差。于是,inference时的BN层只是简单的线性变换。完整的inference过程如下
3.2 Batch-Normalized Convolutional Networks
- 在CNN中,如果没有BN,那么一个卷积层或全连接层的表达式为
z=g(Wu+b)
。其中u为输入,b为偏置,W为权重矩阵,g( )为非线性操作。增加了BN之后由于
beta(k)
可以起到偏置的作用,b就可以忽略了。另外,作者认为由于输入u可能是某个非线性操作的输出,直接对u进行BN容易改变下一个非线性操作的输入分布;而Wu+b更可能是对称、非稀疏分布,因此将BN放在Wu+b之后。最后网络变为
z=g(BN(Wu))
。
- 另外,由于CNN参数很多,对每个参数都分别BN显然不现实。因此作者采用了参数共享策略,即一个卷积层只有一组
γ(k)
,
β(k)
。假设某一层的batch大小为m,特征图有c个通道,每个通道尺寸为p*q。那么将m、p、q三个维度整合起来,显得是batch大小有m*p*q。