Generative Adversarial Networks with Keras (1): Building a Simple GAN in Keras to Generate Handwritten Digits

Posted: 2024-09-29 15:21:33
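Before the code, a one-line recap of what the two binary_crossentropy losses below implement: the standard GAN minimax game of Goodfellow et al. (2014),

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))],

where the discriminator D is trained to assign 1 to real MNIST images and 0 to generated ones. Note that the script trains G with the labels flipped to 1 (the non-saturating variant, which maximizes \log D(G(z))) rather than minimizing \log(1 - D(G(z))) directly; this gives stronger gradients early in training.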
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
import numpy as np
import os


class GAN():
    def __init__(self):
        # 28x28x1 grayscale images
        self.img_shape = (28, 28, 1)
        self.latent_dim = 100  # dimensionality of the input noise vector
        optimizer = Adam(0.0002, 0.5)  # Adam optimizer (lr=0.0002, beta_1=0.5)

        # Discriminator: compiled with its own loss and optimizer
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Generator
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)

        # Freeze the discriminator while training the generator
        self.discriminator.trainable = False
        validity = self.discriminator(img)  # prediction on the generated (fake) image
        # In the combined model the discriminator is frozen, so only the
        # generator's weights are updated
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Generator model
    # Input: 100-dim noise vector; output: 28x28 image with pixels in (-1, 1) (tanh)
    def build_generator(self):
        model = Sequential()
        noise = Input(shape=(self.latent_dim,))  # input

        # First dense layer: 100 --> 256
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))              # activation
        model.add(BatchNormalization(momentum=0.8))  # batch normalization
        # Second dense layer: 256 --> 512
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # Third dense layer: 512 --> 1024
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # Fourth (output) dense layer: 1024 --> 784, i.e. np.prod((28, 28, 1))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        img = model(noise)
        # Plot the model (filename chosen here; the original left to_file empty)
        plot_model(model, to_file='generator.png', show_shapes=True,
                   show_layer_names=False, rankdir='TB')
        return Model(noise, img)

    # Discriminator model
    # Input: 28x28 image; output: a number in [0, 1]
    def build_discriminator(self):
        model = Sequential()
        img = Input(shape=self.img_shape)
        model.add(Flatten(input_shape=self.img_shape))  # flatten to 784

        # First dense layer: 784 --> 512
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        # Second dense layer: 512 --> 256
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Third (output) dense layer: 256 --> 1
        model.add(Dense(1, activation='sigmoid'))

        validity = model(img)
        # Plot the model (filename chosen here; the original left to_file empty)
        plot_model(model, to_file='discriminator.png', show_shapes=True,
                   show_layer_names=False, rankdir='TB')
        return Model(img, validity)

    # Training loop
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Build the training set
        (X_train, _), (_, _) = mnist.load_data()  # load MNIST
        # Rescale to (-1, 1) to match the generator's tanh output
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)  # 60000x28x28 --> 60000x28x28x1

        # Labels for the binary classification problem
        valid = np.ones((batch_size, 1))  # all ones -- real
        fake = np.zeros((batch_size, 1))  # all zeros -- fake

        for epoch in range(epochs):
            # Train the discriminator
            idx = np.random.randint(0, X_train.shape[0], batch_size)  # random batch indices
            imgs = X_train[idx]  # real images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # Gaussian noise
            gen_imgs = self.generator.predict(noise)  # fake images

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)     # on real images
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)  # on fake images
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # average of the two

            # Train the generator through the combined model, labels flipped to "real"
            g_loss = self.combined.train_on_batch(noise, valid)

            print("steps:%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # Save a sample grid every sample_interval steps
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
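    # A Keras subtlety worth flagging: the trainable flag is read when a model
    # is compiled. self.discriminator was compiled *before* trainable was set
    # to False, so calling train_on_batch on it directly still updates its
    # weights; only the copy wired into self.combined (compiled after the
    # freeze) is held fixed. This is what makes the alternating scheme in
    # train() behave as intended.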
    # Save a 5x5 grid of generated digits
    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))  # random noise
        gen_imgs = self.generator.predict(noise)  # generated images
        # Rescale the generator's (-1, 1) tanh output back to (0, 1)
        gen_imgs = 0.5 * gen_imgs + 0.5

        # Plot the 5x5 grid
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)
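The script only writes sample grids to images/; the trained weights are lost when the process exits. Below is a minimal sketch of persisting and reusing the generator. model.save and keras.models.load_model are standard Keras APIs; the generator.h5 filename and the short smoke-test epoch count are my own choices, and the sketch assumes the GAN class above is defined in (or imported into) the same file.

from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import os

os.makedirs("images", exist_ok=True)  # train() saves sample grids here

gan = GAN()
gan.train(epochs=2000, batch_size=256, sample_interval=200)  # short run as a smoke test
gan.generator.save("generator.h5")    # persist only the generator

# Later, in a fresh session: reload and sample without retraining
g = load_model("generator.h5")
noise = np.random.normal(0, 1, (1, 100))  # latent_dim = 100
img = 0.5 * g.predict(noise) + 0.5        # rescale tanh output from (-1, 1) to (0, 1)
plt.imshow(img[0, :, :, 0], cmap='gray')
plt.savefig("reloaded_sample.png")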