本篇博客主要介绍经典的三层BP神经网络的基本结构及反向传播算法的公式推导。
我们首先假设有四类样本,每个样本有三类特征,并且我们在输出层与隐藏层加上一个偏置单元。这样的话,我们可以得到以下经典的三层BP网络结构:
当我们构建BP神经网络的时候,一般是有两个步骤,第一是正向传播(也叫做前向传播),第二是反向传播(也就是误差的反向传播)。
Step1 正向传播
在正向传播之前,可以先给W,b赋初始值,最好不要全设置为0,不然后面会出现问题。赋完初值后,下面开始正向传播:
neth1=W11∗i1+W12∗i2+W13∗i3+b1
Outh1=11+e−neth1 ——>**函数为sigmoid函数:y=11+e−x
隐含层到输出层:
netO1=W′11∗h1+W′12∗h2+W′13∗h3+W′14∗h4+b′1=∑4j=1W′1j∗Outhj+b′1
OutO1=enetO1enetO1+enetO2+enetO3+enetO4 ——>**函数为Softmax型、常用于多分类问题
到这里我们已经完成了正向传播,这里我之所以没给出向量的形式,是因为我觉得标量的形式容易理解。下面我们开始反向传播。
Step2 反向传播
1.计算总误差
这里我们采用交叉熵损失函数,我看到网上大部分都是采用均方误差形式的损失函数,使用交叉熵的相对较少,且使用交叉熵损失函数具有较多优点。
Etotal=−∑4i=1targetOi∗lnOutOi
这里其实只有一项,因为targetOi无论在何时都只有一项为1,也就是说,要么第一类,此时Etotal=−targetO1∗lnOutO1. 要么第二类,此时Etotal=−targetO2∗lnOutO2,要么是第三类,第四类等情况。
2.隐含层到输出层的权值更新
以W′11为例,我们想知道W′11对整体误差产生了多少影响,可用整体误差对W′11求偏导得出。
∂Etotal∂W′11=∂Etotal∂OutO1∗∂OutO1∂netO1∗∂netO1∂W′11
下面我们依次来计算每个式子(不要着急,一步一步算):
∂Etotal∂OutO1=−targetO1∗1OutO1
∂OutO1∂netO1=OutO1∗(1−OutO1)
∂netO1∂netW′11=Outh1
然后将三者相乘,就可以了嘛??
答案是否定的,之前我也是这么推导的,结果在迭代时发现,权值一直在增大,后来经过很长时间的分析才发现,原来这里的∂Etotal∂W′11我求错了。
问题出在哪里呢?
是因为采用了交叉熵的损失函数,在更新W′11时,误差不仅仅来自于O1,还与其他所有的输出层的节点有关系。咋一看非常不可思议,但是仔细一想,你会发现因为在计算 OutOi是,分母中e的指数涉及到了其他所有的神经元的输出,即netO2、netO3等。
所以,我们对W′11的偏导就应该是:
∂Etotal∂W′11=∂Etotal∂OutO1∗∂OutO1∂netO1∗∂netO1∂W′11+∂Etotal∂OutO2∗∂OutO2∂netO1∗∂netO1∂W′11+∂Etotal∂OutO3∗∂OutO3∂netO1∗∂netO1∂W′11+∂Etotal∂OutO4∗∂OutO4∂netO1∗∂netO1∂W′11
因为我这里有四类,所以显得式子很长,但是经过化简,可以得到以下式子:
∂Etotal∂W′11=(OutO1−targetO1)∗Outh1
我们可以令OutO1−targetO1为δO1,意思是在该神经元输出点的误差值,那么我们就可以很容易的得到权值的更新公式:
W′11=W′11−ηδO1Outh1
同理,我们可以得到偏置的更新公式:
b′1=b′1−ηδO1
其中,η表示学习率,这是一个可以自己调节的变量,看自己的数据分布情况,可设为0.1、0.05等等。
从输入层到隐含层的权值更新:
∂Etotal∂W11=∂Etotal∂Outh1∗∂Outh1∂neth1∗∂neth1∂W11
其中,∂Etotal∂Outh1=∂EO1∂Outh1+∂EO2∂Outh1+∂EO3∂Outh1+∂EO4∂Outh1 是因为输出层的每一个神经元都对隐藏层的第一个神经元有误差的传递。
由此我们可以得到:
∂Etotal∂W11=(δO1∗W′11+δO2∗W′12+δO3∗W′13+δO4∗W′14)∗Outh1∗(1−Outh1)∗i1
我们将
(δO1∗W′11+δO2∗W′12+δO3∗W′13+δO4∗W′14)∗Outh1∗(1−Outh1)
记作δh1
那么,我们的输入层到隐含层的权值及偏置的更新策略为:
W11=W11−ηδh1i1
b1=b1−ηδh1
至此,我们已经将两层之间的权值和偏置都更新完了,可以根据以上写成向量或者矩阵的形式,方便后面的运算。
最后,我想说,虽然现在很多机器学习的框架用的如火如荼,但是我觉得对于刚入门的同学来说,一些基本的公式推导和证明还是要掌握的。如有错误,欢迎交流和指正。