Generative Adversarial Network对抗生成网络,这是当下机器视觉比较热门的一个技术,由两部分组成生成器(Gnet)和判别器(D_{net})组成
GAN区别与传统的生成网络,生成的图片还原度高,主要缘于D网络基于数据相对位置和数据本身对real数据奖励,对fake数据惩罚的缘故
1.GAN思想 & 与单个传统生成器和判别器的对比
1.1GAN的思想类似于"零和博弈",百度百科这样介绍:
零和游戏的原理如下:两人对弈,总会有一个赢,一个输,如果我们把获胜计算为得1分,而输棋为-1分。则若A获胜次数为N,B的失败次数必然也为N。若A失败的次数为M,则B获胜的次数必然为M。这样,A的总分为(N-M),B的总分为(M-N),显然(N-M)+(M-N)=0,这就是零和游戏的数学表达式。
也就是奖励获胜者,惩罚失败者,在GAN中就是奖励真实图片,且惩罚伪造图片,且奖励和惩罚同时发生,当然现在说这个有点早,往后看你会慢慢的发现这就是D网络的一个反馈机制
1.2单个生成器和判别器与GAN的对比
1.2.1 生成器(Generation)
就是利用模型对图片的学习,最终达到可以自己生成图片的目的
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzc5LzRlYWU0MjNlYWYyM2U2ODIzYzA0NzU4ZWQ1ZTNjZjlmLnBuZw%3D%3D.png?w=700&webp=1)
就像上图表示的就是生成器的一种(还有一种变分自编码器这里不做过多的赘述)
step1:将图片传入解码器 NN-Encoder 转化为机器可以识别的array形式,然后通过 NN-Decoder生成图片 Picfake
step2:已知真实图片 Picreal,通过loss函数MSE,计算真实图片和生成图片的loss,进而反馈网络
这样看起来好像是没有什么问题,但是需要注意一个问题,这里的loss仅仅计算数据之间的差异,图片的像素val不仅仅是数据的堆叠那么简单,同样相对位置(数据之间的相关性)也是很重要的一个部分,由于G网络没有办法学习到位置的相关性
所以Generation不能生成高还原度的图片
1.2.2 判别器(Discriminator)
简单来说就是一个判断 real图片和 fake图片的二分类模型
input:xoutput:yy∈[0,1]
Discriminator是一个卷积的神经网络,所以可以有效的区分图片的相对位置(即注重数据的相关性),但是由于Discriminator只对真实数据奖励(此时的output大),对伪造的数据惩罚(此时的output小)
所以随机数据的选取比较困难
这样对比下 G 和 D的优劣:
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzQ5MS9hMTdhNGI1NTFkZjE5ZWU4OWNjNjMxZmE2MGI5NmI5My5wbmc%3D.png?w=700&webp=1)
这样来看 G网络和 D网络各有优缺点,但是刚刚好可以互补,所以GAN网络顺势而生
2.GAN原理
2.1 Generation
由于单一G网络不能学习到数据之间的相关性,所以G网络的反向传播依赖于D网络
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzc2LzY1NjlhMjIwMjllZWVjODcxZmE0M2Q4OGZlYmM4MTFjLnBuZw%3D%3D.png?w=700&webp=1)
对于生成器而言,它的目的是Generation的output要无限接近于真实的数据分布:
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzQxMy9iYjIxODFiMTc2ZDk1MDViZjU2MDRmYjUxMzVmN2UwNS5wbmc%3D.png?w=700&webp=1)
这里会用到极大似然估计:
step1:给定真实的数据分布:Pdata,G网络output:x=G(z) ,这里的z是G网络的intput
step2: 那么这个问题就变成一个求使G网络output无限接近于Pdata这个真实分布的θ的极大似然估计求解过程
这里我们用P(x;θ)表示G网络output与Pdata相似的概率,所以G网络就是求
∀θ; P(x;θ)最大 的过程;
下面是求解的过程:
θ∗=argmax∏i=1mPG(xi;θ)=argmax∑i=1mlog(PG(xi;θ))
≈argmaxEx∼pdata[log(PG(xi;θ))]
在这里我们要构建一个KL−divergence的形式,我们都知道个KL−divergence是描述两个概率之间差异的形式,上式后面加一个 ∫xPdata(x)log(Pdata(x)dx ,这是一个与θ无关的项所以不会影响后序结果,却可以辅助构建KL−divergence形式,所以上式可以这样变形
上式 =argmax[∫xPdata(x)log(PG(xi;θ))−∫xPdata(x)log(Pdata(x)dx]
=argmaxKL(PG∣∣Pdata)
=argmniKL(Pdata∣∣PG)
这样看来G网络的计算就是求解argmniGKL(Pdata∣∣PG);但是PG的分布和Pdata的差异(也就是Pdata和PG的KL−divergence),G网络是没有完全办法计算的(G网络不具备数据相关性的学习能力),需要用到D网络的卷积来进行有效计算loss,所以接下来我们要引入D网络进行鉴别;
2.2 Discriminator
Discriminator鉴别器的机制是奖励真实样本,惩罚伪造样本,鉴别器需要获取G网络数据分布 PG和真实数据分布Pdata
- step1:samplefromPdatasamplefromPG
-
step2:这样我们就获取到了real和fake的数据分布,用于D网络的loss计算
下面给出D网络的loss函数:
V(G,D)=Ex∼pdata[log(D(x))]+Ex∼pG[log(1−D(x))]
这里简单赘述下,上面成本函数的计算过程,后面会详细提到:
V(G,D)可以看做是一个组合的loss函数
-
若x是生成的数据x∼,Pdata(x∼)=0PG(x∼)=1,那么:
V(G,D)=Ex∼pG[log(1−D(x∼))]
-
若x是真实的数据x,Pdata(x)=1PG(x)=0,那么:
V(G,D)=Ex∼pdata[log(D(x))]
所以实际用到的:
V(G,D)=Ex∼pdata[log(D(x))]+Ex∼∼pG[log(1−D(x∼))]
下面是求解的过程:
正如上面所说Discriminator鉴别器的机制是奖励真实样本,惩罚伪造样本;所以D网络的训练过程就是迭代计算使得其loss函数V(G,D)最大化的过程;也就是argmaxV(D,G)的过程;
为了方便计算出V(G,D)的最大值,我们求解最优的D∗(也就是D(x)),D网络运行阶段G网络可以看做是固定不变的;
V(G,D)=Ex∼pdata[log(D(x))]+Ex∼pG[log(1−D(x))]
=∫xPdata(x)logD(x)dx+∫xPG(x)log(1−D(x))dx
=∫x[Pdata(x)logD(x)+∫xPG(x)log(1−D(x))]dx
令 a=Pdata(x)D=D(x)b=PG(x)
则 V(G,D)=alogD+blog(1−D)
通过偏导来求上述公式的最大值:
∂D∂V(G,D)=Da+1−Db=0
则: D=a/a+b
所以 D∗=Pdata(x)/(Pdata(x)+PG(x)) 此为使V(D,G)最大化的最优解
代入V(G,D)
上式 =V(G,D∗)
=Ex∼pdata[logPdata(x)+PG(x)Pdata(x)]+Ex∼pG[logPdata(x)+PG(x)PG(x)]
=∫xPdata(x)logPdata(x)+PG(x)Pdata(x)dx+∫xPG(x)logPdata(x)+PG(x)PG(x)dx
=∫xPdata(x)log2Pdata(x)+PG(x)Pdata(x)∗21dx+∫xPG(x)log2Pdata(x)+PG(x)PG(x)∗21dx
=−2log2+∫xPdata(x)log2Pdata(x)+PG(x)Pdata(x)dx+∫xPG(x)log2Pdata(x)+PG(x)PG(x)dx
这里需要提到J一个知识点:
-
JSDdivergence 是 KLdivergence 的对称平滑版本,表示了两个分布之间的差异,上式没有办法转化为 KL−divergence.所以这里我们使用JSD
-
JSD公式: JSD(P∣∣Q)=21D(P∣∣M)+21D(Q∣∣M) M=21(P+Q)
上式 =−2log2+KL(Pdata(x)∣∣2Pdata(x)+PG(x))+KL(PG(x)∣∣2Pdata(x)+PG(x))
=−2log2+2JSD(P∣∣Q)
在数学中可以证明(这里不详细赘述),JSDmax=log2
所以V(G,D)最大值是0,最小值是−2log2;也就是说 JSD越大P和Q的差异越大,JSD越小P和Q的差异就越小
所以D网络最优的场景应当是:
maxD(G,D)最小的情况,此时PG=Pdata也就是生成数据完全与真实数据相等
综上来看,GAN就是 θG,θD=argminGmaxDV(G,D)的过程
3.GAN训练过程
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzg5OS8wOGVjMDY0M2EyNjkyM2RhN2Q3NGI5MTUyODZkODQ2My5wbmc%3D.png?w=700&webp=1)
这就是 GAN的整个训练过程,蓝色框是D网络的训练过程,红色框是G网络的训练过程
这里我们会注意到:
- D的loss迭代过程中要趋向于最大,所以θd=θd+η∇loss;
- G的loss迭代过程中要趋向于最小,所以θd=θd−η∇loss;
- 可以看出来一般情况下D网络每迭代多次,G网络仅迭代一次;主要原因G,D的反馈传播均依赖于D网络,G网络迭代一次,会让D网络的loss较之前下降,所以D网络要调节多次使得D网络的loss尽可能的大;
- D的loss可以看做对D网络而言分辨real数据和fake数据的损失,所以要最大化真实数据的期望logD(x),同时最小化生成数据期望logD(x∼),也就是最大化log(1−D(x∼)),而 lossD=Ex∼pdata[log(D(x))]+Ex∼∼pG[log(1−D(x∼))],所以D的期望是最大化loss
- 而G网络的lossG=Ex∼∼pG[log(1−D(x∼))],G网络的输入是没有Pdata作为input,所以G网络仅保留V的后半部分,也可以看做一个类别的二分类器.是计算生成图片与目标图片的距离;所以越小越好
4.GAN的优化
我们先来看下G网络loss的图像:
![深入浅出 --- GAN网络原理解析 深入浅出 --- GAN网络原理解析](https://image.shishitao.com:8440/aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzE3Ny85NzJmN2JkZDI4MGY0YzgxOTg1YTkzZTJhYzU5MGZmOS5wbmc%3D.png?w=700&webp=1)
可以看到 原始的G网络的loss=log(1−D(x)),首先我们知道我们初始化一般从0开始,而这个loss在0附近梯度较小,从0->1,梯度越来越大;这显然不符合我们的习惯,我们期望的模型迭代应当是初期梯度较大,随着epoch的增加梯度越来越小,这样有利于函数的收敛
所以我们可以把G网络的loss函数转化为−log(D(x))
以上是GAN基础学习中的一些感悟和整理,感谢阅读