基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字
from keras.datasets import mnist
from keras.models import Sequential,Model
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization, Activation
from keras.layers import LeakyReLU
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
import numpy as np
import os
class GAN():
def __init__(self):
#28,28,1
self.img_shape = (28,28,1)
self.latent_dim=100 #输入维度--100
optimizer = Adam(0.0002, 0.5)#定义Adam优化器
#判别器
self.discriminator=self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])#定义loss函数和优化器
#生成器
self.generator=self.build_generator()
gan_input = Input(shape=(self.latent_dim,))
img = self.generator(gan_input)
#训练生成器
# 在训练generator的时候不训练discriminator
self.discriminator.trainable = False#冻结discriminator
validity = self.discriminator(img)# 对生成的假图片进行预测
self.combined = Model(gan_input, validity)#在Model中discriminator已经被冻结,仅剩下generator在运行了
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)#定义loss函数和优化器
#定义生成器模型
#输入:100维向量;输出:28*28图像(像素大小为(-1,1)--tanh)
def build_generator(self):
model=Sequential()
noise=Input(shape=self.latent_dim)#输入
#第一层全连接层:784-->256
model.add(Dense(256, input_dim=self.latent_dim))#全连接
model.add(LeakyReLU(alpha=0.2))#激活函数
model.add(BatchNormalization(momentum=0.8))#标准化
#第二层全连接层:256-->512
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
#第三层全连接层:512-->1024
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
#第四层全连接层(输出层):1024-->784
model.add(Dense(np.prod(self.img_shape), activation='tanh'))#product--乘积,即28*28*1=784
model.add(Reshape(self.img_shape))
img = model(noise)
plot_model(model, to_file='', show_shapes=True, show_layer_names=False, rankdir='TB')#绘制模型
return Model(noise, img)
#定义判别器模型
#输入:28*28图像;输出:0-1数字
def build_discriminator(self):
model=Sequential()
img=Input(shape=self.img_shape)
model.add(Flatten(input_shape=self.img_shape))#铺平
#第一层全连接层:784-->512
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
#第二层全连接层:512-->256
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
#第三层全连接层(输出层):256-->1
model.add(Dense(1, activation='sigmoid'))
validity = model(img)
plot_model(model, to_file='', show_shapes=True, show_layer_names=False, rankdir='TB')#绘制模型
return Model(img,validity)
#定义训练函数
def train(self,epochs,batch_size=128,sample_interval=50):
#创建训练数据集
(X_train, _), (_, _) = mnist.load_data() # 从mnist数据集中获得数据
X_train =(X_train-0)/(255-0) # 标准化为(0,1)
X_train = np.expand_dims(X_train, axis=3) #60000*28*28变为60000*28*28*1;X_train的shape是60000*28*28*1
#创建标签(0或者1)--二分类问题
valid = np.ones((batch_size, 1))#全1阵--真
fake = np.zeros((batch_size, 1))#全0阵--假
for epoch in range(epochs):
#训练discriminator
idx=np.random.randint(0,X_train.shape[0],batch_size)#随机选择batch_size个数
imgs=X_train[idx]#真图像
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))#生成正态分布噪声
gen_imgs = self.generator.predict(noise)#假图像
#训练discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)#真图像训练
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)#假图像训练
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)#数据求和
#训练生成器
g_loss=self.combined.train_on_batch(noise,valid)
print("steps:%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"% (epoch, d_loss[0], 100*d_loss[1],g_loss))
#迭代50次打印一次
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self,epoch):
r,c=5,5
noise=np.random.normal(0,1,(r*c,self.latent_dim))#生成随机噪声
gen_imgs = self.generator.predict(noise)#生成器预测的图像
gen_imgs = 0.5 * gen_imgs + 0.5 #将generator输出的(-1,1)像素值反归一化为(0,1)
#绘制5*5图像
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%" % epoch)
plt.close()
if __name__=='__main__':
if not os.path.exists("./images"):
os.makedirs("./images")
gan = GAN()
gan.train(epochs=30000, batch_size=256, sample_interval=200)