GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

时间:2021-12-10 20:28:25

生成式对抗网络是20年来机器学习领域最酷的想法。   ——Yann LeCun

自从两年前蒙特利尔大学的Ian Goodfellow等人提出生成式对抗网络(Generative Adversarial Networks,GAN)的概念以来,GAN呈现出井喷式发展。

这篇发布在O’Reilly上的文章中,作者向初学者进行了GAN基础知识答疑,并手把手教给大家如何用GAN创建可以生成手写数字的程序。

本教程由两人完成:Jon Bruner是O’Reilly编辑组的一员,负责管理硬件、互联网、制造和电子学等方面的出版物;Adit Deshpande是加州大学洛杉矶分校计算机科学专业的大二学生。Jon Bruner还对教程内容进行了视频讲解:


量子位提示:本篇文章提到的所有代码都可以在GitHub上下载

https://github.com/jonbruner/generative-adversarial-networks

不多说了,开启你的GAN之旅吧——

序言

GAN是一种神经网络,它会学习创建一些类似已知输入数据的合成数据。目前,研究人员已经可以用GAN合成出从卧室到专辑封面等一系列照片,GAN也显示出反映高阶语义逻辑的非凡能力。

这些例子非常复杂,但构建可以生成简单图像的GAN真的不难。在这个教程里,我们将学习构建分析手写数字图像的GAN,并且从零开始学如何让它学会生成新图像。其实说白了,就是教会神经网络如何写字。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

上面这张图就是我们在本教程中构建的用GAN生成的示例图像。

GAN的架构

GAN中包含两个模型:生成模型(Generative Model)和判别模型(Discriminative Model)。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

判别模型是一个分类器,它判断给定的图片到底是来自数据集的真实图像,还是人工创建的假图像。这基本上就是一个表现为卷积神经网络(CNN)形式的二元分类器。

生成模型通过反卷积神经网络将随机输入值转化为图像。

在数次训练迭代的历程中,判别器和生成器的的权重和偏差都是通过反向传播训练的。判别器学习从一堆生成器生成的假数字图像中,找出真正的数字图像。与此同时,生成器通过判别器的反馈学习如何生成具有欺骗性的图片,防止被判别器识别出来。

准备工作

我们将创造一个可以生成手写数字的GAN,希望可以骗过最好的分类器(当然也包括人类)。我们将使用谷歌开源的TensorFlow使在GPU上训练神经网络更容易。

TensorFlow下载地址:
https://www.tensorflow.org/

在学习这个教程之前,希望你可以了解一些TensorFlow的知识。如果之前没有接触过它,建议你先看看相关的文章和教程。

加载MNIST数据

首先,我们需要给判别器输入一系列真实的手写数字图像,算是给判别器的一个参考。这里使用的是深度学习基准数据集MNIST,这是一个手写数字图片数据库,每一张都是0-9的单个数字,且每一张都是抗锯齿(Anti-aliasing)的灰度图。数据库中包含美国国家标准与技术研究所收集的人口调查局员工和高中生写的70000张数字图像。

MNIST数据集链接(英文):
http://yann.lecun.com/exdb/mnist/

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

我们从导入TensorFlow和其他有用的数据库开始讲起。首先我们需要用TensorFlow的便捷函数导入MNIST的图像,不妨把这个函数称为read_data_sets。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

我们创造的MNIST的变量包含图像和标签,并将数据集分为训练集和验证集(不过在本教程里我们并不需要考虑标签这个事情)。我们可以通过调用mnist上的next_batch进行检索,现在我们先加载一张图片看看。

这张图像一开始被格式化为一列784像素,我们可以将它们改造成28×28像素的图像并且用PyPlot查看。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

如果你再次运行上面的cell,你会发现和MNIST训练集不同的图像。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

判别网络

判别器是一个卷积神经网络,接收图片大小为28×28×1的输入图像,之后返还一个单一标量值来描述输入图像的真伪——判断到底是来自MNIST图像集还是生成器。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

判别器的结构与TensorFlow的样例CNN分类模型密切相关。它有两层特征为5×5像素特征的卷积层,还有两个全连接层按图像中每个像素计算增加权重的层。

创建了神经网络后,通常需要将权重和偏差初始化,这项任务可以在tf.get_variable中完成。权重在截断正态分布中被初始化,偏差在0处被初始化。

tf.nn.conv2d()是TensorFlow中的标准卷积函数,它包含四个参数:首个参数就是输入图像(input volume),也就是本示例中的28×28像素的图片;第二个参数是滤波器/权矩阵,最终你也可以改变卷积的“步幅”和“填充”。这两个参数控制着输出图像的尺寸大小。

其实上面这些就是一个普通简单的二进制分类器,如果你不是初次接触CNN,应该对此并不陌生。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

定义了判别器之后,我们需要回头看看生成模型。我们将以Tim O’Shea编写的简单生成器代码为基础构建模型的整体结构。

Code链接:
https://github.com/osh/KerasGAN

其实你可以把生成器想象成反向卷积神经网络的一种。判别器就是一个典型的CNN,它能将二维或三维的像素值矩阵(matrix of pixel values)转化成一个概率。然而作为一个生成器,需要d-维度向量 d-dimensional vector ,并需要将其变为28*28的图像。ReLU和批量标准化(batch normalization)也经常用于稳定每一层的输出。

在这个神经网络中,我们用了三个卷积层和插值,直到形成28*28像素的图像。

我们在输出层添加了一个tf.sigmoid() 激活函数,它将挤压灰色呈现白色或黑色相,从而产生一个更清晰的图像。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

生成样本图像

定义完生成器和判别函数,我们现在看看没有训练过的生成器会生成怎样的样例。

首先打开TensorFlow,为我们的生成器创建一个占位符。占位符的形式是None x z_dimensions,关键字None意味着可以在运行会话时确定它的值。我们通常用None作为我们第一个维度,所以我们的批处理大小是可变的。有了关键词None,所以不需要指定batch_size。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

现在,我们创建能够保存生成器输出的变量(generated_image_output),还要将输入的随机噪声向量初始化。np.random.normal()函数具备了3个参数,前两个定义了正态分布的平均偏差和标准偏差,最后一个定义了向量(1 x 100)的形状。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

接下来需要将所有变量初始化,将z_batch 放到占位符中,并运行这部分代码。

sess.run()函数有两个参数。第一个叫做“获取”参数,定义你在计算中感兴趣的值。在这个案例中,我们想看到生成器会输出什么。如果看看最后的代码片段,你将看到生成函数的输出被存储在generated_image_output里,我们将使用generated_image_output作为第一个参数。

第二个参数相当于一个输入字典,在运行时可以取代计算图,也就是我们要填到占位符里的。在我们的例子中,我们需要将z_batch变量输入到之前定义的z_placeholder中,之后在PyPlot中将图片重新调整为28*28像素。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

它看起来像噪音对吧。现在我们需要训练生成网络中的权重和偏差,将随机数转变为可识别的数字。我们再看看损失函数和优化。

训练GAN

构建和调试GAN就复杂在它有两个损失函数:一个鼓励生成器创造出更好的图像,另一个鼓励判别器区分哪个是真图像,哪个是生成器生成的。

我们同时训练生成器和判别器,当判别器能够很好区分图像来自哪里时,生成器也能更好地调整它的权重和偏差来生成更以假乱真的图像。

这个网络的输入和输出如下:

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

所以,让我们首先考虑一下我们需要在网络中得到什么。判别器的目标是正确地将MNIST图像标记为真,而判别器生成的标记为假。我们将计算判别器的两种损失:Dx和1(代表MNIST中的真实图像)的损失,以及Dg与0(代表生成图像)的损失。我们将这个函数在TensorFlow中的tf.nn.sigmoid_cross_entropy_with_logits()函数上运行,计算Dx和0与Dg和1之间的交叉熵损失。

sigmoid_cross_entropy_with_logits 是在未缩放的值下运行的,而不是在0到1之间的概率值。看一下判别器的最后一行:这里并没有softmax或sigmoid函数层。如果判别器“饱和”了,或者有足够的信心可以在给出生成图像后返回0,那么就会使判别器的梯度下降失去作用。

tf.reduce_mean()函数选取的是交叉熵函数返回的矩阵中所有分量的平均值。这是一种将损失减小到单个标量值的方法,而不是向量或矩阵。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

现在我们来设置生成器的损失函数。我们想让生成网络的图像骗过判别器:当输入生成图像时,判别器可以输出接近1的值,来计算Dg与1之间的损失。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

现在我们已经得到损失函数,需要定义优化程序了。生成网络的优化程序只需要升级生成器的权重,而不是判别器的。同样的,当训练判别器的时候,我们需要固定生成器的权重。

为了使这些看起来不同,我们需要创建两个变量列表,一个是判别器的权重和偏差,另一个是生成器的权重和偏差。这就是当给TensorFlow变量取名字需要深思熟虑的原因。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

下一步,你需要制定两个优化器,我们一般选择Adam优化算法,它利用了自适应学习速率和动量。我们调用Adam最小函数并且指定我们想更新的变量——也就是我们训练生成器时的生成器权重和偏差,和我们训练判别器时的判别器权重和偏差。

我们为判别器设置了两套不同的训练方案:一种是用真实图像训练判别器,另一种是用生成的“假图像”训练它。有时使用不同的学习速率很有必要,或者单独使用它们来规范学习的其他方面。

其他方面是指的什么?Code下载链接:
https://github.com/jonbruner/ezgan

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

收敛GAN是一件棘手的事情,经常需要训练很长时间。可以用TensorBoard追踪训练过程:它可以用图表描绘标量属性(如损失),展示训练中的样本图像,并展示神经网络中的拓扑结构。

想了解更多TensorBoard信息?链接(英文):
https://www.tensorflow.org/get_started/summaries_and_tensorboard

如果在自己的机器上运行此脚本,要记得包含下面的cell。之后在终端窗口中运行 tensorboard —logdir=tensorboard/ ,再在浏览器中输入http://localhost:6006,打开TensorBoard。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

现在先给判别器几个简单的原始训练进行迭代,这种方法有助于形成对生成器有用的梯度。

之后我们继续进行主要的训练循环。当训练生成器的时候,我们需要将随机的z向量输入到生成器中,并将其输出传递给判别器(这就是我们早先定义的Dg变量)。生成器的权重和偏差将被改变,主要是为了生成能骗过判别器的图像。

为了训练判别器,我们将给它提供一组来自MNIST数据集中的正面例子,并且再次用生成的图像训练判别器,用它们作为反面例子。

因为训练GAN通常需要很长时间,所以我们建议如果您是第一次使用这个教程,建议先不要运行这个代码块。但你可以先执行下面的代码块,让它生成出一个预先训练模型。

如果你想自己运行这个代码块,请做好长时间等待的准备:在速度相对较快的GPU上运行大概需要3小时,在台式机的CPU上可能耗费10倍时间。

所以,建议你跳过上面直接执行下面的cell。它加载了一个我们在高速GPU机器上训练了10小时的模型,你可以试验下训练过的GAN。

GAN入门教程 | 从0开始,手把手教你学会最火的神经网络

训练不易

众所周知训练GAN很艰难。在没有正确的超参数、网络体系结构和培训流程的情况下,判别器会压制生成器。

一个常见的故障模式是,判别器压制了生成器,肯定地把生成图像定义为假的。当判别器以绝对肯定时,会使生成器无梯度可降。这就是为什么我们建立判别器来产生未缩放的输出,而不是通过一个sigmoid函数将其输出推到0或1。

在另一种常见的故障模式(模式崩溃)中,生成器发现并利用了判别器中的一些弱点,如果它不顾生成器输入z.变量,生成了很多相似图像,你是可以识别出这种模式崩溃的。模式崩溃有时可以通过“强化”鉴别器来修正,例如通过调整其训练速率或重新配置它的层。

研究人员已经确定了一些帮助建立稳定的GAN的小方法。

你也想让GANs稳定一下?Code链接:
https://github.com/soumith/ganhacks

结论

GANs有巨大的潜力重塑我们每天与之互动的数字世界。这个领域还很年轻,下一个GAN新发现可能就来自你。

其他资料(英文)

1.2014年Ian Goodfellow和他的伙伴合作发表的GAN论文

Paper链接:
https://arxiv.org/abs/1406.2661

2.Goodfellow最近的一篇教程,通俗易懂地解释了GAN

Tutorial链接:
https://arxiv.org/abs/1701.00160

3.Alec Radford、Luke Metz和Soumith Chintala等人的论文,介绍了本教程中我们在生成器上使用的复杂GANs的基本结构

Paper链接:
https://arxiv.org/abs/1511.06434

GitHub上的DCGAN代码:
https://github.com/Newmu/dcgan_code