Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
摘要
当前神经网络层之前的神经网络层的参数变化,引起神经网络每一层输入数据的分布产生了变化,这使得训练一个深度神经网络变得复杂。这样就要求使用更小的学习率,参数初始化也需要更为谨慎的设置。并且由于非线性饱和(注:如sigmoid激活函数的非线性饱和问题),训练一个深度神经网络会非常困难。我们称这个现象为:internal covariate shift。同时利用归一化层输入解决这个问题。我们将归一化层输入作为神经网络的结构,并且对每一个小批量训练数据执行这一操作。Batch Normalization(BN) 能使用更高的学习率,并且不需要过多地注重参数初始化问题。BN 的过程与正则化相似,在某些情况下可以去除Dropout
引言
随即梯度下降法(SGD)通过最小化 \(\theta\) 来最小化损失函数
其中X1…N为训练数据集。在使用SGD时,每次迭代我们使用一个大小为m 的小批量数据X1…m 。通过计算
来逼近损失函数关于权值的梯度。在迭代过程中使用小批量数据相比使用一个样本有几个好处。首先,由小批量数据计算而来的损失函数梯度是由整个训练数据集的损失函数梯度的估计。并且随着小批量数据大小的增加,其性能会越好。其次,由于现代计算平台的并行性,小批量训练会比单个样例训练更高效
尽管随机梯度下降法简单有效,但却需要谨慎的调整模型的参数,特别是在优化过程中加入学习率和参数初始化方式的选择。每一层的输入都会受之前所有层的参数影响,并且随着网络越深,即使参数的变化很小也为对每一层的输入产生很大的影响。这使得训练一个网络变得十分复杂。神经网络层输入分布的改变,使得神经网络层必须不停的适应新的数据分布。当一个学习系统的输入数据分布产生变化,称这种现象为:Experience Covariate Shift。解决这种现象的典型方法是领域适应。输入数据分布相同这一特性,使得子网络更容易训练。因此保持输入的分布不变是有利的。保持一个子网络的输入数据分布不变,对该子网络以外的隐藏层也有积极的作用
称在训练深度神经网络的过程中,网络内部节点的分布发生变换这一现象为 Internal Covariate Shift。而消除这个现象能够加速网络的训练。因此提出了Batch Normalization ,通过减少依赖参数缩放和初始化,进而缓解Internal Covariate Shift,并动态的加速深度神经网络的训练速度。BN 允许使用更高的学习率,而不会有发散的风险。进一步的,BN能够正则化模型,并且不需要Dropout。最后,BN能够使用s型激活函数,同时不会陷入饱和端
降低内协变量漂移(Internal Covariate Shift)
将 Internal Covariate Shift 定义为:在神经网络的训练过程中,由于参数改变,而引起的神经网络激活值分布的改变。通过缓解 Internal Covariate Shift 来提高训练。在训练的过程中保持神经网络层输入的分布不变,来提高训练速度。已知,如果对网络的输入进行白化(输入线性变换为具有零均值和单位方差,并去相关),网络训练将会收敛的更快。通过白化每一层的输入,采取措施实现输入的固定分布,消除内部协变量转移的不良影响
考虑在每个训练步骤或在某些间隔来白化激活值,通过直接修改网络或根据网络激活值来更改优化方法的参数。然而,如果这些修改分散在优化步骤中,那么梯度下降步骤可能会试图以要求标准化进行更新的方式来更新参数,这会降低梯度下降步骤的影响
我们希望确保对于任何参数值,网络总是产生具有所需分布的激活值。这样做将允许关于模型参数损失的梯度来解释标准化,以及它对模型参数\(\Theta\)的依赖。希望通过对相对于整个训练数据统计信息的单个训练样本的激活值进行归一化来保留网络中的信息
通过Mini-Batch统计进行标准化
由于每一层输入的整个白化是代价昂贵的并且不是到处可微分的,因此做了两个必要的简化。首先是单独标准化每个标量特征,从而代替在层输入输出对特征进行共同白化,使其具有零均值和单位方差。对于多为输入的层,将标准化每一维。简单标准化层的每一个输入可能会改变层可以表示什么。例如,标准化sigmoid的输入会将它们约束到非线性的线性状态。为了解决这个问题,要确保插入到网络中的变换可以表示恒等变换。为了实现这个,对于每一个激活值x(k),引入成对的参数\(\gamma^{(k)}, \beta^{(k)}\),它们会归一化和移动标准值
\[
y^{k} = \gamma^{k} \hat a^{k} + \beta(k)
\]
这些参数与原始的模型参数一起学习,并恢复网络的表示能力
每个训练步骤的批处理设置是基于整个训练集的,将使用整个训练集来标准化激活值。然而,当使用随机优化时,这是不切实际的。因此,做了第二个简化:由于在随机梯度训练中使用小批量,每个小批量产生每次激活平均值和方差的估计。这样,用于标准化的统计信息可以完全参与梯度反向传播。通过计算每一维的方差而不是联合协方差,可以实现小批量的使用;在联合情况下,将需要正则化,因为小批量大小可能小于白化的激活值的数量,从而导致单个协方差矩阵
批标准化步骤
BN变换可以添加到网络上来操纵任何激活。在公式\(y=BN_{\gamma,\beta}(x)\)中,参数\(\gamma\)和\(\beta\)需要进行学习,但应该注意到在每一个训练样本中BN变换不单独处理激活。相反,\(BN_{\gamma,\beta}(x)\)取决于训练样本和小批量数据中的其它样本。所有的这些子网络输入都有固定的均值和方差,尽管这些标准化的\(\hat {x} ^{(k)}\)的联合分布可能在训练过程中改变,但预计标准化输入的引入会加速子网络的训练,从而加速整个网络的训练
BN变换是将标准化激活引入到网络中的可微变换。这确保了在模型训练时,层可以继续学习输入分布,表现出更少的内部协变量转移,从而加快训练。此外,应用于这些标准化的激活上的学习到的仿射变换允许BN变换表示恒等变换并保留网络的能力
批量标准化网络的训练与推理
为了批标准化一个网络,根据上面的算法,指定一个激活的子集,然后在每一个激活中插入BN变换。任何以前接收x作为输入的层现在接收BN(x)作为输入。采用批标准化的模型可以使用批梯度下降,或者用小批量数据大小为m>1的随机梯度下降,或使用它的任何变种例如Adagrad进行训练。依赖小批量数据的激活值的标准化可以有效地训练,但在推断过程中是不必要的也是不需要的;希望输出只确定性地取决于输入。为此,一旦网络训练完成,将使用总体统计来进行标准化,而不是小批量数据统计
\[
\hat x = \frac {x - E[x]}{\sqrt {Var[x] + \epsilon}}
\]
此时,如果忽略\(\epsilon\),这些标准化的激活具有相同的均值0和方差1,使用无偏方差估计\(Var[x] = \frac{m}{m-1} E_{\beta} [\sigma^2_{\beta}]\),其中期望是在大小为m的小批量训练数据上得到的,\(\sigma^2_{\beta}\)是其样本方差。使用这些值移动平均,在训练过程中可以跟踪模型的准确性。由于均值和方差在推断时是固定的,因此标准化是应用到每一个激活上的简单线性变换。它可以进一步由缩放\(\gamma\)和转移\(\beta\)组成,以产生代替BN(x)的单线性变换
训练批标准化网络的过程
批标准化卷积网络
批标准化可以应用于网络的任何激活集合。这里专注于仿射变换和元素级非线性组成的变换
\[
z = g(Wu + b)
\]
其中W和b是模型学习的参数,g(⋅)是非线性例如sigmoid或ReLU。这个公式涵盖了全连接层和卷积层。在非线性之前通过标准化x=Wu+b加入BN变换。也可以标准化层输入u,但由于u可能是另一个非线性的输出,它的分布形状可能在训练过程中改变,并且限制其第一矩或第二矩不能去除协变量转移。相比之下,Wu+b更可能具有对称,非稀疏分布,即“更高斯”,对其标准化可能产生具有稳定分布的激活
由于对Wu+b进行标准化,偏置b可以忽略,因为它的效应将会被后面的中心化取消(偏置的作用会归入到算法1的β)。因此,z=g(Wu+b)被\(z=g(BN(Wu))\)替代,其中BN变换独立地应用到x=Wu的每一维,每一维具有单独的成对学习参数\(\gamma^{(k)},\beta^{(k)}\)
另外,对于卷积层,希望标准化遵循卷积特性——为的是同一特征映射的不同元素,在不同的位置,以相同的方式进行标准化。为了实现这个,在所有位置联合标准化了小批量数据中的所有激活。在第一个算法中,让\(B\)是跨越小批量数据的所有元素和空间位置的特征图中所有值的集合——因此对于大小为m的小批量数据和大小为p×q的特征映射,使用有效的大小为m' = \(|B|\) = m ⋅ pq 的小批量数据。每个特征映射学习一对参数\(\gamma^{(k)}\)和\(\beta^{(k)}\),而不是每个激活。第二个算法进行类似的修改,以便推断期间BN变换对在给定的特征映射上的每一个激活应用同样的线性变换
批标准化可以提高学习率
通过标准化整个网络的激活值,在数据通过深度网络传播时,它可以防止层参数的微小变化被放大。批标准化也使训练对参数的缩放更有弹性。通常,大的学习率可能会增加层参数的缩放,这会在反向传播中放大梯度并导致模型爆炸。然而,通过批标准化,通过层的反向传播不受其参数缩放的影响。对于标量\(\alpha\)
\[
BN(Wu)=BN((aW)u)
\]
因此,\(\frac {∂BN((aW)u)}{∂u}= \frac {∂BN(Wu)}{∂u}\)。因此,标量不影响层的雅可比行列式,从而不影响梯度传播。此外,\(\frac {∂BN((aW)u)}{∂(aW)}=\frac {1}{a}⋅\frac {∂BN(Wu)}{∂W}\),因此更大的权重会导致更小的梯度,并且批标准化会稳定参数的增长。研究者进一步推测,批标准化可能会导致雅可比行列式的奇异值接近于1,这被认为对训练是有利的
实验
实验表明,批标准化有助于网络训练的更快,取得更高的准确率,原因是随着训练的进行,批标准化网络中的分布更加稳定,这有助于训练
加速BN网络
提高学习率。在批标准化模型中,能够从高学习率中实现训练加速,没有不良的副作用
删除丢弃。发现从BN-Inception中删除丢弃可以使网络实现更高的验证准确率
推测批标准化提供了类似丢弃的正则化收益,因为对于训练样本观察到的激活受到了同一小批量数据中样本随机选择的影响
更彻底地搅乱训练样本。这导致验证准确率提高了约1%
减少L2全中正则化。虽然在Inception中模型参数的L2损失会控制过拟合,但在修改的BN-Inception中,损失的权重减少了5倍。研究者发现这提高了在提供的验证数据上的准确性
通过仅使用批标准化(BN-Baseline),可以在不到Inception一半的训练步骤数量内将准确度与其相匹配,这显著提高了网络的训练速度
结论
提出的方法(批标准化)大大加快了深度网络的训练。该方法从标准化激活以及将这种标准化结合到网络体系结构本身中汲取了它的力量。批处理规范化每次激活只添加两个额外的参数,并且这样做保留了网络的表示能力。本文提出了一种用批量标准化网络构建、训练和执行推理的算法。所得到的网络可以用饱和非线性进行训练,更容忍增加的训练速率,并且通常不需要dropout来进行正规化
总结
为什么需要它
在BN出现之前,我们的归一化操作一般都在数据输入层,对输入的数据进行求均值以及求方差做归一化,但是BN的出现打破了这一个规定,我们可以在网络中任意一层进行归一化处理。因为我们现在所用的优化方法大多都是min-batch SGD,所以我们的归一化操作就成为Batch Normalization
神经网络层之前的神经网络层的参数变化,将会引起神经网络每一层输入数据的分布产生了变化,这使得训练一个深度神经网络会非常困难。作者称这个现象为:internal covariate shift
BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同的分布的方法,是用来解决“Internal Covariate Shift”问题的
Mini-Batch SGD相对于One Example SGD的两个优势:梯度更新方向更准确;并行计算速度快
它的思想
BN的基本思想是:因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近,这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,即通过这种方法让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度
BN其实就是把每个隐层神经元的激活输入分布从偏离均值为0方差为1的正态分布通过平移均值压缩或者扩大曲线尖锐程度,调整为均值为0方差为1的正态分布
当输入均值为0,方差为1时,当使用sigmoid激活函数时,绝大多数的输入都落到了[-2,2]的区间,而这一段是sigmoid函数接近于线性变换的区域,意味着x的小变化会导致非线性函数值较大的变化,也即是梯度变化较大,对应导数函数图中明显大于0的区域,就是梯度非饱和区。即经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程
然后为了不让网络的表达能力下降,保证非线性的获得,它又对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),这样就等价于将非线性函数的值从正中心周围的线性区往非线性区移动了一点
推理的过程中,我们得不到实例集合的均值和方差,因此,这里用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量
BN的优点
- 提升了训练速度,收敛过程大大加快
- 增加分类效果,一种解释是这是类似于Dropout的,一种是自带了防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果
- 调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等
参考链接
深入理解Batch Normalization批标准化
批归一化(Batch Normalization)
视频链接
GoogLeNetv1 论文研读笔记
GoogLeNetv3 论文研读笔记
GoogLeNetv4 论文研读笔记
GoogLeNetv2 论文研读笔记的更多相关文章
-
GoogLeNetv4 论文研读笔记
Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning 原文链接 摘要 向传统体系结构中引入 ...
-
GoogLeNetv3 论文研读笔记
Rethinking the Inception Architecture for Computer Vision 原文链接 摘要 卷积网络是目前最新的计算机视觉解决方案的核心,对于大多数任务而言,虽 ...
-
GoogLeNetv1 论文研读笔记
Going deeper with convolutions 原文链接 摘要 研究提出了一个名为"Inception"的深度卷积神经网结构,其目标是将分类.识别ILSVRC14数据 ...
-
ResNet 论文研读笔记
Deep Residual Learning for Image Recognition 原文链接 摘要 深度神经网络很难去训练,本文提出了一个残差学习框架来简化那些非常深的网络的训练,该框架使得层能 ...
-
<; AlexNet - 论文研读个人笔记 >;
Alexnet - 论文研读个人笔记 一.论文架构 摘要: 简要说明了获得成绩.网络架构.技巧特点 1.introduction 领域方向概述 前人模型成绩 本文具体贡献 2.The Dataset ...
-
《DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks》研读笔记
<DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks>研读笔记 论文标题:DSLR-Quality ...
-
论文阅读笔记 - YARN : Architecture of Next Generation Apache Hadoop MapReduceFramework
作者:刘旭晖 Raymond 转载请注明出处 Email:colorant at 163.com BLOG:http://blog.csdn.net/colorant/ 更多论文阅读笔记 http:/ ...
-
论文阅读笔记 - Mesos: A Platform for Fine-Grained ResourceSharing in the Data Center
作者:刘旭晖 Raymond 转载请注明出处 Email:colorant at 163.com BLOG:http://blog.csdn.net/colorant/ 更多论文阅读笔记 http:/ ...
-
论文阅读笔记 Word Embeddings A Survey
论文阅读笔记 Word Embeddings A Survey 收获 Word Embedding 的定义 dense, distributed, fixed-length word vectors, ...
随机推荐
-
hdu 5976 Detachment
Detachment Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total ...
-
Hbase中的BloomFilter(布隆过滤器)
(1) Bloomfilter在hbase中的作用 Hbase利用bloomfilter来提高随机读(get)的性能,对于顺序读(scan)而言,设置Bloomfilter是没有作用的(0.9 ...
-
php7安装
# 配置参数 ./configure --prefix=/usr/local/php7 \ --with-config-file-path=/usr/local/php7/etc \ --with-m ...
-
zabbix3.0 安装方法,一键实现短信、电话、微信、APP 告警
引言 免费开源监控工具 Zabbix 因其强大的监控功能得到各大互联网公司的广泛认可,具体功能不再详细介绍,在之前发布的 Zabbix 2.4.1 安装及微信短信提醒已经做了详细介绍,本篇主要对 Za ...
-
linux下安装vld
将vld-0.10.1下载并传到/home/wangxiaolan/tar 1.进行解压 tar zxvf vld-0.10.tgz 2.进入 cd vld-0.10.1 3.usr/local/ph ...
-
服务器、应用框架、MVC、MTV
web服务器:负责处理http请求,响应静态文件,常见的有Apache,Nginx以及微软的IIS. 应用服务器:负责处理逻辑的服务器.比如php.python的代码,是不能直接通过nginx这种we ...
-
Windows设置.txt文件默认打开程序
一.配置某个程序默认打开哪些类型的文件(以firefox为例) 依次打开”控制面板\程序\默认程序“,点击”设置默认程序“ 在右侧列表找到firefox,选中 以firefox为例,”将此程序设置为默 ...
-
Qt简介 及与MFC、GDK+的比较
Qt C++图形用户界面应用程序开发框架. Qt的由来和发展 1.QT由来 Haavard Nord 和Eirik Chambe-Eng于1991年开始开发"Qt",1994年3月 ...
-
Population Size CodeForces - 416D (贪心,模拟)
大意: 给定$n$元素序列$a$, 求将$a$划分为连续的等差数列, 且划分数尽量小. $a$中的$-1$表示可以替换为任意正整数, 等差数列中必须也都是正整数. 贪心策略就是从前到后尽量添进一个等差 ...
-
(转)Python进阶:函数式编程(高阶函数,map,reduce,filter,sorted,返回函数,匿名函数,偏函数)
原文:https://www.cnblogs.com/chenwolong/p/reduce.html 函数式编程 函数是Python内建支持的一种封装,我们通过把大段代码拆成函数,通过一层一层的函数 ...