WGAN-GP的Tensorflow实现

相对于有监督学习来说,生成对抗网络是一个很酷的算法。它利用一个生成器和一个判别器来对数据集进行学习。拿生成图片的例子来说,生成器负责生成一个图片,判别器的功能就是来判别这个图片的真假。简而言之,就是判别这张生成的图片是否和数据集中的图片很像。

生成对抗网络算法的思想是这样的,所有原始数据集中的图片判定为真,做法就是标记为1 。而所有生成器生成的图片就判定为假,做法就是标记为0 。上面说的就是判别器的主要功能,它就是不遗余力的把数据集中的原始图片和生成器生成的图片区分开来。

而生成器的作用呢?生成器就是不遗余力的生成一张越来越逼近数据集样式的图片,让判别器来把这张图片判定为真。

可以看出,生成器和判别器是一种相爱相杀的存在。而最终的状态是怎样的呢,最终就是生成器可以生成一张非常逼真的图片,而判别器也无法判别这张图片的真假。也是就判别器的正确率是0.5 。这就是生成对抗网络训练的最佳状态。

说到生成对抗网络,就不得不提DCGAN(深度卷积生成对抗网络),这个算法的优点就是算法好理解,比较容易学习,但缺点也很明显。就是很容易出现模式崩塌现象,最后生成的样本很单一。

所以WGAN-GP就应运而生,它采用一种新的距离度量方法,叫作EM距离。原先DCGAN采用的是JS散度度量方法,JS散度的度量方法在特殊情况下会出现梯度为0的现象,梯度一旦为0,就会导致梯度不下降,因此梯度长时间得不到收敛,于是就出现了梯度弥散的情况。

WGAN-GP采用的EM距离很好的解决了训练过程中梯度为0的现象,另外WGAN-GP还加入了一个梯度惩罚的东西,大大加强了WGAN-GP的训练稳定性。

废话不说了,下面是我的训练过程。我采用的数据集是cifar10

首先加载一些必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,losses,optimizers,datasets
import matplotlib
from matplotlib import pyplot as plt
import numpy as np

下面是生成器的代码

class Generator(keras.Model):
  def __init__(self):
    super(Generator,self).__init__()

    filter = 32

    #第一个卷积层
    self.conv1 = layers.Conv2DTranspose(filter*8,4,1,'valid',use_bias=False)
    self.bn1 = layers.BatchNormalization()

    #第二个卷积层
    self.conv2 = layers.Conv2DTranspose(filter*4,4,2,'same',use_bias=False)
    self.bn2 = layers.BatchNormalization()

    #第三个卷积层
    self.conv3 = layers.Conv2DTranspose(filter*2,4,2,'same',use_bias=False)
    self.bn3 = layers.BatchNormalization()

    #第四个卷积层
    self.conv4 = layers.Conv2DTranspose(3,4,2,'same',use_bias=False)

  def call(self, inputs, training=None):
    x = inputs
    x = tf.reshape(x,[x.shape[0],1,1,x.shape[1]])

    #卷积-bn-激活:(b,4,4,256)
    x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
    #卷积-bn-激活:(b,8,8,128)
    x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
    #卷积-bn-激活:(b,16,16,64)
    x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
    #卷积:(b,32,32,3)
    x = self.conv4(x)

    #使用tanh激活函数,它的值范围是-1~1,与数据集中的数据保持一致
    x = tf.nn.tanh(x)

    return x

因为cifar10中的图片大小是32x32,所以生成器生成的图片大小也是32x32,3个通道。

下面再看判别器的代码

class Discriminator(keras.Model):
  def __init__(self):
    super(Discriminator,self).__init__()

    filter = 32

    #第一个卷积层
    self.conv1 = layers.Conv2D(filter*2,4,2,'same',use_bias=False)
    self.bn1 = layers.BatchNormalization()

    #第二个卷积层
    self.conv2 = layers.Conv2D(filter*4,4,2,'same',use_bias=False)
    self.bn2 = layers.BatchNormalization()

    #第三个卷积层
    self.conv3 = layers.Conv2D(filter*8,4,2,'same',use_bias=False)
    self.bn3 = layers.BatchNormalization()

    #第四个卷积层
    self.conv4 = layers.Conv2D(filter*16,4,2,'same',use_bias=False)
    self.bn4 = layers.BatchNormalization()

    self.conv5 = layers.Conv2D(1,2,1,'valid',use_bias=False)

    #池化层
    # self.pool = layers.GlobalAveragePooling2D()
    # #打平层
    # self.flatten = layers.Flatten()

    # #输出层
    # self.fc = layers.Dense(1)

  def call(self, inputs, training=None):
    #卷积-bn-激活:(b,16,16,64)
    x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training=training))
    #卷积-bn-激活:(b,8,8,128)
    x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
    #卷积-bn-激活:(b,4,4,256)
    x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
    #卷积-bn-激活:(b,2,2,512)
    x = tf.nn.leaky_relu(self.bn4(self.conv4(x), training=training))


    # x = self.pool(x)
    # x = self.flatten(x)

    # logits = self.fc(x)

    x = self.conv5(x)
    logits = tf.reshape(x,[x.shape[0],-1])

    return logits

判别器的输出层最开始用的是全连接层,运行后发现效果不好,极容易出现过拟合现象,后来把输出层换成卷积层,效果就好很多了。

生成器和判别器都搞好后,下面测试一下看是否有错误:

g = Generator()
d = Discriminator()

z = tf.random.normal([64,100])
x_hat = g(z)

print(x_hat.shape)

out = d(x_hat)
print(out.shape)


#输出
(64, 32, 32, 3)
(64, 1)

结果显示,生成器和判别器的输出没有问题。

下面是WGAN-GP的梯度惩罚函数:

def gradient_penalty(discriminator, batch_x, fake_image):
  #梯度惩罚项计算函数
  batchsz = batch_x.shape[0]

  #每个样本均随机采样t,用于插值
  t = tf.random.uniform([batchsz,1,1,1])
  #自动扩展为x的形状:[b,1,1,1] => [b,h,w,c]
  t = tf.broadcast_to(t, batch_x.shape)
  #在真假图片之间作线性插值
  interplate = t * batch_x + (1-t) * fake_image
  #在梯度环境中计算 D 对插值样本的梯度
  with tf.GradientTape() as tape:
    tape.watch([interplate])
    d_interplate_logits = discriminator(interplate)

  grads = tape.gradient(d_interplate_logits, interplate)

  #计算每个样本的梯度的范数:[b,h,w,c] => [b,-1]
  grads = tf.reshape(grads, [grads.shape[0],-1])
  gp = tf.norm(grads,axis=1)
  #计算梯度惩罚项
  gp = tf.reduce_mean((gp-1.)**2)

  return gp

梯度惩罚项最终是为判别器的损失函数服务的。判别器的损失函数代码如下:

def d_loss_fn(generator,discriminator,batch_z,batch_x,is_training):
  #计算判别器的损失函数

  #生成样本
  fake_image = generator(batch_z,is_training)
  #判别生成样本
  d_fake_logits = discriminator(fake_image,is_training)
  #判别真实样本
  d_real_logits = discriminator(batch_x,is_training)
  #计算梯度惩罚项
  gp = gradient_penalty(discriminator,batch_x,fake_image)

  loss = tf.reduce_mean(d_fake_logits) - tf.reduce_mean(d_real_logits) + 10.*gp

  return loss,gp

既然有判别器的损失函数,那么必然会有生成器的损失函数,生成器的损失函数比较好理解,代码如下:

def g_loss_fn(generator,discriminator,batch_z,is_training):
  #计算生成器的损失函数
  fake_image = generator(batch_z,is_training)
  d_fake_logits = discriminator(fake_image,is_training)
  #最大化假样本的输出值
  loss = -tf.reduce_mean(d_fake_logits)

  return loss

下面,我们准备一下训练过程中要用到的超参数:

epochs = 3000000#训练回合数
batch_sz = 64#批处理大小
learn_rate = 0.0001#学习率
is_training = True#是否训练的标志
z_dim = 100#采样向量的长度

再准备一下要训练的数据集,前面说过,我用的数据集是cifar10.
首先加载数据集:

(x_train,y_train),(x_test,y_test) = datasets.cifar10.load_data()
print(x_train.shape)

#输出:
(50000, 32, 32, 3)

然后再做进一步处理:

def preprocess(x):
  #预处理函数
  x = tf.cast(x,dtype=tf.float32)/127.5 - 1 
  return x

#开始处理数据
train_db = tf.data.Dataset.from_tensor_slices((x_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_sz)

sample = next(iter(train_db))
print(sample.shape)

#输出:
(64, 32, 32, 3)

数据集准备好后,我们来准备生成器和判别器的实例:

#生成器
generator = Generator()
#判别器
discriminator = Discriminator()

generator.build(input_shape=(4,z_dim))
generator.summary()

discriminator.build(input_shape=(4,32,32,3))
discriminator.summary()

#g_optimizer = optimizers.Adam(learn_rate,beta_1=0.5)
#d_optimizer = optimizers.Adam(learn_rate,beta_1=0.5)

#生成器的优化器
g_optimizer = optimizers.RMSprop(learn_rate)
#判别器的优化器
d_optimizer = optimizers.RMSprop(learn_rate)

下面,就开始正式训练了:

for epoch in range(epochs):
  for _ in range(1):
    batch_z = tf.random.normal([batch_sz,z_dim])
    batch_x = next(iter(train_db))

    with tf.GradientTape() as tape:
      d_loss = d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)[0]
    grads = tape.gradient(d_loss,discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))

  batch_z = tf.random.normal([batch_sz,z_dim])
  with tf.GradientTape() as tape:
    g_loss = g_loss_fn(generator,discriminator,batch_z,is_training)
  grads = tape.gradient(g_loss,generator.trainable_variables)
  g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))

  if epoch!=0 and epoch %100 ==0:
    print(epoch, ' d-loss:',float(d_loss), ' g_loss:',float(g_loss))

    batch_z = tf.random.normal([100,z_dim])
    fake_images = generator(batch_z,False)
    save_image(fake_images.numpy(), 'sample_data/ganimgs3/wgan-gp%d.png'%epoch)

代码思路是每训练一次判别器就训练一次生成器,同时每100个回合生成一次图片,并保存图片,保存图片代码如下:

from PIL import Image

def save_image(imgs,name):
  new_imgs = Image.new('RGB',(320,320))

  index = 0
  for i in range(0,320,32):
    for j in range(0,320,32):

      img = imgs[index]
      img = ((img+1)*127.5).astype(np.uint8)
      img = Image.fromarray(img,mode='RGB')

      new_imgs.paste(img,(i,j))

      index += 1

  new_imgs.save(name)

最终生成的图片效果如下所示:

wgan-gp17900.png

而数据集的图片如下所示:

数据集真实图片

对比可以发现,生成的图片清晰度还是不够高,说明模型还是有很多可以改进的地方的,希望看到的大神朋友可以教教我改进的方法。



作者:climb66的夏天
链接:https://www.jianshu.com/p/e8e1bb5b3af8
来源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值