文章目录
导包
import keras
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras.utils import to_categorical
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
数据集(手写数字识别)
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 28, 28,1)
x_test = x_test.reshape(10000, 28, 28,1)
挑选数据:使用 “1” 这个类别的图片作为基准
indexes = np.where(y_train==1) ## “1” 类别的图片的所有索引
real_img_count = indexes[0].shape[0] ## 查看选出来的图片的数量 --> 6742 张
print(real_img_count)
target_imgs = x_train[indexes] ## 通过索引把 x_train 中所有的 “1” 拿出来
target_labels = y_train[indexes] ## 通过索引把 y_train 中所有的 “1” 拿出来
plt.imshow(target_imgs[2].reshape(28,28)) # 画个图看一下我们是不是真的拿到了想要的图片
- 最后一行的
reshape
只是因为 plt.imshow 要求的输出格式的原因。
核心代码
1.生成器
def generator():
noise = Input((100,))
x = Dense(2048,activation=None)(noise)
x = LeakyReLU()(x)
x = BatchNormalization()(x)
x = Dense(1024,activation=None)(x)
x = LeakyReLU()(x)
x = BatchNormalization()(x)
x = Dense(784,activation=None)(x)
x = LeakyReLU()(x)
x = Reshape((28,28,1))(x)
return Model(noise,x)
gen = generator()
- 看网上很多教程,把
generator
写的很复杂,还通过卷积层和upsampling
进行上采样,不利于初学者理解,其实generator
的构造就注意三点即可:- 一定要通过各种网络层产生足够的参数用来训练(既可以用卷积层也可以粗暴的用全连接层,这样对初学者搭建网络很友好),比如本文为了掩饰简单粗暴就使用了
Dense
层的堆积 - 保证最后一层全连接层的参数能够被
reshape
成你想要产生的图片的维度,比如我们要产生的图片维度是28*28*1
,那么我们就最后一个Dense
层设计为784
,如果你要产生的图片维度是32*32*3
,那最后一个Dense
层就设计为32*32*3 = 3072
。再通过最后的Reshape
层进行约束即可 - 最后一点是防止梯度的消失问题,尽量使用
LeakyRelu
而不用sigmoid
或者其他激活函数(这是公认的经验);为了收敛迅速使用BatchNormalization
层锦上添花
- 一定要通过各种网络层产生足够的参数用来训练(既可以用卷积层也可以粗暴的用全连接层,这样对初学者搭建网络很友好),比如本文为了掩饰简单粗暴就使用了
- 因为后面,generator 要包在 GAN 网络里面,通过 GAN 网络的训练来迭代 generator 的参数,所以 generator 不需要单独训练,因此不用在这里
compile
generator
- 我的
generator
的参数量我控制在 300 万左右,如下图所示:
2. 鉴别器
def discriminator():
inputs = Input((28,28,1))
x = Conv2D(filters=64,kernel_size=3,strides=2,padding='valid')(inputs)
x = Conv2D(filters=128,kernel_size=3,strides=2,padding='valid')(x)
x = Conv2D(filters=256,kernel_size=3,strides=2,padding='valid')(x)
x = Flatten()(x)
x = Dense(512,activation='relu')(x)
x = Dense(1,activation='sigmoid')(x)
return Model(inputs,x)
dis = discriminator()
dis.compile(optimizer=Adam(0.0001),loss=binary_crossentropy,metrics=['accuracy'])
-
鉴别器就更简单了,你只需要保证两点即可:
- 输入的维度是你图片的维度
- 产生足够的参数来进行训练
- 最后一个
Dense
层只有1
个神经元;因为discriminator
的作用就是把输入分类成0
或者1
,他认为真的样本他就给他1
,他认为假的样本就给0
;因此最后一层肯定要用sigmoid
函数归一化到0~1
之间
-
因为
discriminator
需要单独训练,因此在这里先compile
一下 -
我把鉴别器的参数量控制在
100 万
左右。
3. GAN 网络的构造
batch_size = 64
epochs = 10000
generated_img = [] # 每训练100个epochs就把一张图片放入这个列表,看整个过程产生的图片的差别
discriminator_loss = [] # 用来存放每个 epoch 的鉴别器的 loss
gan_loss = [] # 用来存放每个 epoch 的 GAN 网络的 loss,其实就代表生成器的 loss,因为训练 GAN 的时候冻结了鉴别器的参数
for epoch in range(epochs):
## 随机产生噪声,每个 epoch 产生一个 batchsize 这么多的 (100,)的噪声
noise = np.random.random((batch_size,100))
## 每个 epoch 都用生成器将噪声生成图片,这个图片是 fake_imgs,对应的标签就是 0,所以每个 epoch 标签的个数就是 batch_size 个
fake_img = gen.predict(noise) # fake_img 的维度 --> (batchsize,28,28,1)
fake_img_labels = np.zeros((batch_size,1)) # ---> 0 ,维度-->(batchsize,1)
## 每个 epoch 都从我们选出的 “1” 的所有图片中挑选随机的 batchsize 张图作为 real_imgs
pick_indexes = np.random.randint(0,real_img_count,(batch_size,)) # np.random.randint(0,6742,(64,))
real_img = target_imgs[pick_indexes] # real_imgs 的维度 --> (batchsize,28,28,1)
real_img_labels = np.ones((batch_size,1)) # ---> 1,维度-->(batchsize,1)
## 结合 fake_imgs 和 real_imgs 以及 他们的 label
data = np.concatenate([fake_img,real_img])
labels = np.concatenate([fake_img_labels,real_img_labels])
## 训练 discriminator
loss = dis.train_on_batch(data,labels)
discriminator_loss.append(loss[-1]) #每一个 batchsize中取最后一个 loss 作为这个 epoch 的loss
print("discriminator_loss:%f"%loss[-1])
## 训练 gan 网络,更新 generator 的参数
loss_gan = ga.train_on_batch(noise,real_img_labels)
print("loss_gan:%f"%loss_gan[-1])
gan_loss.append(loss_gan[-1])
print("epoch:%d" % epoch + "========")
## 每 100 个 epoch 就存一张图到 generated 列表中,10000个 epoch 就相当于存了100张图片进去
if epoch%100 == 0:
im = fake_img[0].reshape((28,28))
generated_img.append(im)
打印最后生成的结果
1. 前10张图的生成效果
for i in range(0,10):
plt.subplot(2,5,i+1) # 只打印前十张图
plt.imshow(generated_img[i])
2. 最后10张图的生成效果
for i in range(90,100):
plt.subplot(2,5,i-90+1) ## 只打印最后十张图
plt.imshow(generated_img[i])
- 可以看到,我们已经几乎成功地完成了 “1” 类图片的生成
打印训练过程的 loss
plt.plot(discriminator_loss)
plt.plot(gan_loss)
- 可以看看到,
GAN
在训练的时候,loss
一直在震荡,没有稳定的收敛,虽然没有收敛,但是依然产生了很好效果的数据,至于如何解决 loss 震荡和收敛的问题,大家可以去查阅资料作进一步的了解。
整体代码
- 为了给大家更加直观的整体效果,在这里附上整体的代码,我使用的是:
- jupyter notebook 来实现的
- keras-gpu 2.3.1;
- GTX1080Ti,
- 10000次 epoch 只需要 5 分钟就可以跑完~~
import keras from keras.layers import * from keras.optimizers import * from keras.losses import * from keras.utils import to_categorical from keras.models import Model from keras.datasets import mnist import numpy as np import matplotlib.pyplot as plt (x_train,y_train),(x_test,y_test) = mnist.load_data() x_train = x_train.reshape(60000, 28, 28,1) x_test = x_test.reshape(10000, 28, 28,1) # y_train = to_categorical(y_train) # y_test = to_categorical(y_test) ### 第一类的索引,选出第一类的数据以及其 labels indexes = np.where(y_train==1) ## “1” 类别的图片的所有索引 real_img_count = indexes[0].shape[0] ## 查看选出来的图片的数量 --> 6742 张 print(real_img_count) target_imgs = x_train[indexes] ## 通过索引把 x_train 中所有的 “1” 拿出来 target_labels = y_train[indexes] ## 通过索引把 y_train 中所有的 “1” 拿出来 plt.imshow(target_imgs[2].reshape(28,28)) # 画个图看一下我们是不是真的拿到了想要的图片 ### 建立 GAN 网络,用来生成第一类的图片 def generator(): noise = Input((100,)) x = Dense(2048,activation=None)(noise) x = LeakyReLU()(x) x = BatchNormalization()(x) x = Dense(1024,activation=None)(x) x = LeakyReLU()(x) x = BatchNormalization()(x) x = Dense(784,activation=None)(x) x = LeakyReLU()(x) x = Reshape((28,28,1))(x) return Model(noise,x) gen = generator() def discriminator(): inputs = Input((28,28,1)) x = Conv2D(filters=64,kernel_size=3,strides=2,padding='valid')(inputs) x = Conv2D(filters=128,kernel_size=3,strides=2,padding='valid')(x) x = Conv2D(filters=256,kernel_size=3,strides=2,padding='valid')(x) x = Flatten()(x) x = Dense(512,activation='relu')(x) x = Dense(1,activation='sigmoid')(x) return Model(inputs,x) dis = discriminator() dis.compile(optimizer=Adam(0.0001),loss=binary_crossentropy,metrics=['accuracy']) def gan(): inputs = Input((100,)) fake_img = gen(inputs) dis.trainable=False score = dis(fake_img) return Model(inputs,score) ## 这个 score 越接近1越好 ga = gan() ga.compile(optimizer=RMSprop(0.0001),loss=binary_crossentropy,metrics=['accuracy']) batch_size = 64 epochs = 10000 generated_img = [] discriminator_loss = [] gan_loss = [] for epoch in range(epochs): ## 随机产生噪声 noise = np.random.random((batch_size,100)) fake_img = gen.predict(noise) fake_img_labels = np.zeros((batch_size,1)) # ---> 0 pick_indexes = np.random.randint(0,real_img_count,(batch_size,)) real_img = target_imgs[pick_indexes] real_img_labels = np.ones((batch_size,1)) # ---> 1 ## 结合 fake_imgs 和 real_imgs 以及 他们的 label data = np.concatenate([fake_img,real_img]) labels = np.concatenate([fake_img_labels,real_img_labels]) ## 训练 discriminator loss = dis.train_on_batch(data,labels) discriminator_loss.append(loss[-1]) print("discriminator_loss:%f"%loss[-1]) ## 训练 gan 网络,更新 generator 的参数 loss_gan = ga.train_on_batch(noise,real_img_labels) print("loss_gan:%f"%loss_gan[-1]) gan_loss.append(loss_gan[-1]) print("epoch:%d" % epoch + "========") if epoch%100 == 0: im = fake_img[0].reshape((28,28)) generated_img.append(im) for i in range(0,10): plt.subplot(2,5,i+1) plt.imshow(generated_img[i]) for i in range(90,100): plt.subplot(2,5,i-90+1) plt.imshow(generated_img[i]) plt.plot(discriminator_loss) plt.plot(gan_loss)
牛刀小试:生成 fashion_mnist数据集中的图片
- 结果还是很好的~
- 一共10000个 epoch
- 每1000 个 epoch 画一张图