Keras is great. I used to build neural networks directly in TensorFlow, but since switching to Keras I have found it very convenient.
So far, my impression is that Keras's strength is how easily it lets you assemble networks from standard components: fully connected layers, convolutional layers, pooling layers, and so on. For experiments that require changing the network internals themselves, Keras can be awkward, and TensorFlow is still the better choice.
In this post, I want to use Keras to write a simple generative adversarial network (GAN).
The goal of the GAN is to generate handwritten digits.
First, a look at the results:
At epoch=1000:
At epoch=10000: the digit 1 already looks somewhat plausible.
At epoch=60000: the digit 1 is very clear, and the other digits are getting clearer too.
At epoch=80000: digits such as 5 and 7 are being generated.
As training goes on, the generated digits become more and more realistic.
The code is open source; project address:
https://github.com/jmhIcoding/GAN_MNIST.git
Model principle
I will not dwell on the theory; this is the most basic GAN architecture.
The model consists of a generator and a discriminator.
The generator takes noise as input and produces a handwritten-digit image.
The discriminator judges whether a given input image was synthesized by the generator.
The generator's goal is to produce images that the discriminator classifies as non-synthetic.
The discriminator's goal is to classify images as synthetic or real with as high an accuracy as possible.
That is the whole idea.
The model's loss functions are built around these two objectives.
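For reference, these two objectives are exactly the minimax game from the original GAN paper (Goodfellow et al., 2014); the code below implements it indirectly, as two binary cross-entropy training steps:

\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

Here D(x) is the discriminator's probability that x is real, and G(z) is the image the generator produces from noise z.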
Model implementation
Generator
__author__ = 'dk'
# Generator
import sys
import numpy as np
import keras
from keras import layers
from keras import models
from keras import optimizers
from keras import losses

class Generator:
    def __init__(self, height=28, width=28, channel=1, latent_space_dimension=100):
        '''
        :param height: height of the generated image, 28 for MNIST
        :param width: width of the generated image, 28 for MNIST
        :param channel: number of channels of the generated image; 1 for MNIST grayscale
        :param latent_space_dimension: dimension of the input noise
        :return:
        '''
        self.latent_space_dimension = latent_space_dimension
        self.height = height
        self.width = width
        self.channel = channel
        self.generator = self.build_model()
        self.generator.summary()

    def build_model(self, block_starting_size=128, num_blocks=4):
        model = models.Sequential(name='generator')
        for i in range(num_blocks):
            if i == 0:
                model.add(layers.Dense(block_starting_size, input_shape=(self.latent_space_dimension,)))
            else:
                # Each later block doubles the layer width: 256, 512, 1024.
                block_size = block_starting_size * (2 ** i)
                model.add(layers.Dense(block_size))
                model.add(layers.LeakyReLU())
                model.add(layers.BatchNormalization(momentum=0.75))
        model.add(layers.Dense(self.height * self.channel * self.width, activation='tanh'))
        model.add(layers.Reshape((self.width, self.height, self.channel)))
        return model

    def summary(self):
        self.generator.summary()

    def save_model(self):
        self.generator.save("generator.h5")
Note that the generator is trained as part of the combined GAN model, so it does not need to be compiled on its own.
Discriminator
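As a quick sanity check (an illustrative sketch, not part of the original repo), you can instantiate the generator and confirm its output shape; predict works on an uncompiled model:

# Illustrative sanity check for the generator's output shape.
import numpy as np
from Generator import Generator

g = Generator(height=28, width=28, channel=1, latent_space_dimension=100)
noise = np.random.normal(0, 1, (16, 100))   # 16 latent vectors
images = g.generator.predict(noise)         # forward pass only, no training
print(images.shape)                         # expected: (16, 28, 28, 1)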
__author__ = 'dk'
# Discriminator
import sys
import os
import keras
from keras import layers
from keras import optimizers
from keras import models
from keras import losses

class Discriminator:
    def __init__(self, height=28, width=28, channel=1):
        '''
        :param height: height of the input image
        :param width: width of the input image
        :param channel: number of channels of the input image
        :return:
        '''
        self.height = height
        self.width = width
        self.channel = channel
        OPTIMIZER = optimizers.Adam()
        self.discriminator = self.build_model()
        self.discriminator.compile(optimizer=OPTIMIZER, loss=losses.binary_crossentropy, metrics=['accuracy'])
        self.discriminator.summary()

    def build_model(self):
        model = models.Sequential(name='discriminator')
        model.add(layers.Flatten(input_shape=(self.width, self.height, self.channel)))
        model.add(layers.Dense(self.height * self.width * self.channel))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(self.height * self.width * self.channel // 2))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def summary(self):
        return self.discriminator.summary()

    def save_model(self):
        self.discriminator.save("discriminator.h5")
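Unlike the generator, the discriminator is compiled in its constructor, because it is also trained standalone via train_on_batch. An illustrative check (not from the repo) of one standalone step on random data:

# Illustrative: one standalone discriminator training step on random data.
import numpy as np
from Discriminator import Discriminator

d = Discriminator(height=28, width=28, channel=1)
x = np.random.uniform(-1, 1, (32, 28, 28, 1))     # fake "images" in the tanh value range
y = np.zeros((32, 1))                              # all labeled as synthetic
loss, acc = d.discriminator.train_on_batch(x, y)   # returns [loss, accuracy]
print(loss, acc)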
GAN network
Combine the generator and the discriminator:
__author__ = 'dk'
# Generative adversarial network
import keras
from keras import layers
from keras import optimizers
from keras import losses
from keras import models
import sys
import os
from Discriminator import Discriminator
from Generator import Generator

class GAN:
    def __init__(self, latent_space_dimension, height, width, channel):
        self.generator = Generator(height, width, channel, latent_space_dimension)
        self.discriminator = Discriminator(height, width, channel)
        self.discriminator.discriminator.trainable = False
        # The combined GAN trains only the generator; the discriminator is trained
        # separately via explicit discriminator.train_on_batch calls.
        self.gan = self.build_model()
        OPTIMIZER = optimizers.Adamax()
        self.gan.compile(optimizer=OPTIMIZER, loss=losses.binary_crossentropy)
        self.gan.summary()

    def build_model(self):
        model = models.Sequential(name='gan')
        model.add(self.generator.generator)
        model.add(self.discriminator.discriminator)
        return model

    def summary(self):
        self.gan.summary()

    def save_model(self):
        self.gan.save("gan.h5")
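This works because of a Keras detail: the set of weights a model updates is fixed when compile is called. The discriminator was compiled inside Discriminator.__init__ while still trainable, so its own train_on_batch keeps updating it; the combined gan model was compiled after trainable was set to False, so gan.train_on_batch updates only the generator. The summaries printed below confirm this: the standalone discriminator reports all 923,553 of its parameters as trainable, while the combined gan model reports only the generator's 1,510,032. A quick illustrative check (hypothetical snippet, not in the repo):

# Illustrative: inspect which weights the combined model treats as trainable.
from GAN import GAN

gan = GAN(latent_space_dimension=100, height=28, width=28, channel=1)
# Only the generator's weight tensors appear here, because the
# discriminator's trainable flag was already False when gan was compiled.
print([w.name for w in gan.gan.trainable_weights])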
Data preparation module
__author__ = 'dk'
# Dataset helper: a thin wrapper around MNIST
from keras.datasets import mnist
import numpy as np

def sample_latent_space(instances_number, latent_space_dimension):
    return np.random.normal(0, 1, (instances_number, latent_space_dimension))

class Dator:
    def __init__(self, batch_size=None, model_type=1):
        '''
        :param batch_size:
        :param model_type: -1 means use all digits 0-9; e.g. model_type=2 means use only the digit 2
        :return:
        '''
        self.batch_size = batch_size
        self.model_type = model_type
        with np.load("mnist.npz", allow_pickle=True) as f:
            X_train, y_train = f['x_train'], f['y_train']
            #X_test, y_test = f['x_test'], f['y_test']
        if model_type != -1:
            X_train = X_train[np.where(y_train == model_type)[0]]
        if batch_size is None:
            self.batch_size = X_train.shape[0]
        else:
            self.batch_size = batch_size
        # Scale pixel values into roughly [-1, 1] to match the generator's tanh output.
        self.X_train = (np.float32(X_train) - 128) / 128.0
        self.X_train = np.expand_dims(self.X_train, 3)
        self.watch_index = 0
        self.train_size = self.X_train.shape[0]

    def next_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        # Take a slice starting at watch_index, wrapping around to the start if needed.
        X = np.concatenate([self.X_train[self.watch_index:(self.watch_index + batch_size)],
                            self.X_train[:batch_size]])[:batch_size]
        self.watch_index = (self.watch_index + batch_size) % self.train_size
        return X

if __name__ == '__main__':
    print(sample_latent_space(5, 4))
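Typical usage looks like this (an illustrative sketch; it assumes mnist.npz sits in the working directory):

# Illustrative usage of the data helper.
from data_utils import Dator, sample_latent_space

dator = Dator(batch_size=4, model_type=-1)   # all digits, batches of 4
x = dator.next_batch()
print(x.shape, x.min(), x.max())             # (4, 28, 28, 1), values roughly in [-1, 1]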
Training main script: train.py
__author__ = 'dk'
# Model training code
from GAN import GAN
from data_utils import Dator, sample_latent_space
import numpy as np
from matplotlib import pyplot as plt
import time

epochs = 50000
height = 28
width = 28
channel = 1
latent_space_dimension = 100
batch = 128
dator = Dator(batch_size=batch, model_type=-1)
gan = GAN(latent_space_dimension, height, width, channel)
image_index = 0
for i in range(epochs):
    real_img = dator.next_batch(batch_size=batch * 2)
    real_label = np.ones(shape=(real_img.shape[0], 1))    # real samples get label 1
    noise = sample_latent_space(real_img.shape[0], latent_space_dimension)
    fake_img = gan.generator.generator.predict(noise)
    fake_label = np.zeros(shape=(fake_img.shape[0], 1))   # generated fakes get label 0
    ### Assemble the batch for the discriminator
    x_batch = np.concatenate([real_img, fake_img])
    y_batch = np.concatenate([real_label, fake_label])
    # One training step
    discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch, y_batch)[0]
    ### Note: this step trains only the discriminator; the generator is untouched.
    ### Assemble the batch for the generator
    noise = sample_latent_space(batch * 2, latent_space_dimension)
    noise_labels = np.ones((batch * 2, 1))
    # The generator's goal is to push its images' labels toward 1
    generator_loss = gan.gan.train_on_batch(noise, noise_labels)
    print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(i, discriminator_loss, generator_loss))
    if i != 0 and (i % 50) == 0:
        print('show time')
        # Every 50 iterations, plot 16 generated images to check progress
        noise = sample_latent_space(16, latent_space_dimension)
        images = gan.generator.generator.predict(noise)
        plt.figure(figsize=(10, 10))
        plt.suptitle('epoch={0}'.format(i), fontsize=16)
        for index in range(images.shape[0]):
            plt.subplot(4, 4, index + 1)
            image = images[index, :, :, :]
            image = image.reshape(height, width)
            plt.imshow(image, cmap='gray')
        #plt.tight_layout()
        plt.savefig("./show_time/{0}.png".format(time.time()))  # the ./show_time directory must already exist
        image_index += 1
        plt.close()
To run, simply execute:
python3 train.py
Output:
Model: "generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 128) 12928
_________________________________________________________________
dense_2 (Dense) (None, 256) 33024
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 256) 0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256) 1024
_________________________________________________________________
dense_3 (Dense) (None, 512) 131584
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 512) 0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512) 2048
_________________________________________________________________
dense_4 (Dense) (None, 1024) 525312
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 1024) 0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024) 4096
_________________________________________________________________
dense_5 (Dense) (None, 784) 803600
_________________________________________________________________
reshape_1 (Reshape) (None, 28, 28, 1) 0
=================================================================
Total params: 1,513,616
Trainable params: 1,510,032
Non-trainable params: 3,584
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_2 (Flatten) (None, 784) 0
_________________________________________________________________
dense_9 (Dense) (None, 784) 615440
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 784) 0
_________________________________________________________________
dense_10 (Dense) (None, 392) 307720
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 392) 0
_________________________________________________________________
dense_11 (Dense) (None, 1) 393
=================================================================
Total params: 923,553
Trainable params: 923,553
Non-trainable params: 0
_________________________________________________________________
Model: "gan"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
generator (Sequential) (None, 28, 28, 1) 1513616
_________________________________________________________________
discriminator (Sequential) (None, 1) 923553
=================================================================
Total params: 2,437,169
Trainable params: 1,510,032
Non-trainable params: 927,137
_________________________________________________________________
...
Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583]
Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736]
Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086]
Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211]
Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916]
Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456]
Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664]
Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626]
Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501]
Epoch : 117763, [Discriminator Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378]
Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154]
Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504]
Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018]
Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523]
Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023]
Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632]
Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176]
Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036]
Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504]
Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963]
Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994]
Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209]
Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123]
Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873]
Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048]
Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316]
Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297]
Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574]
Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]
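Finally, each class above exposes a save_model method. Once gan.generator.save_model() has been called after training (train.py does not call it automatically), the saved generator can be reloaded for sampling without rebuilding any of the classes (an illustrative sketch):

# Illustrative: reload a saved generator and sample new digits from it.
import numpy as np
from keras.models import load_model

generator = load_model("generator.h5")
noise = np.random.normal(0, 1, (16, 100))   # 16 latent vectors of dimension 100
images = generator.predict(noise)           # (16, 28, 28, 1), values in [-1, 1]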