PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)-1. GAN

时间:2024-01-26 15:48:31

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

GAN原理

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。