一百行代码实现一个GAN网络

GAN:对抗性生成网络,通俗来讲,即有两个网络一个是g(generator )网络,用于生成,一个是d(discriminator)网络,用于判断。
GAN网络的目的就是使其自己生成一副图片,比如说经过对一系列猫的图片的处理,g网络可以自己“绘制”出一张猫的图片,且尽量真实。
d网络则是用来进行判断的,将一张真实的图片和一张由g网络生成的照片同时交给d网络,不断训练d网络,使其可以准确判断,将d网络生成的“假图片”找出来。
再回到两个网络上,g网络不断改进使其可以骗过d网络,而d网络不断改进使其可以更准确找到“假图片”,这种相互促进相互对抗的关系,就叫做对抗网络。

我们可以使用tensorflow中的mnist手写体数据来进行实现。
实现原理如下:
将一张随机像素的图片经过一个全连接层后经过一个Leaky ReLU处理,之后为了避免过拟合dropout后再经过一个全连接层进行tanh激活后,生成一张“假图片”
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs

将待判定的图片经过全连接层-->Leaky ReLU-->全连接层-->sigmoid激活函数处理后,得到0或1的结果。
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

在实现时,我们可以首先把MNIST数据中的标签为0的图像提取出来,存到列表中。
i = j = 0
while i<5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1

这样就可以在训练时只训练标签为0的图像。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("D:/python/MNIST_data/")
img = mnist.train.images[50]


def get_inputs(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    return real_img, noise_img


# 生成
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)  # 全连接层
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs


# 判别
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

img_size = mnist.train.images[0].shape[0]
noise_size = 100
g_units = 128
d_units = 128
alpha = 0.01
learning_rate = 0.001
smooth = 0.1
tf.reset_default_graph()
real_img, noise_img = get_inputs(img_size, noise_size)
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)

d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real)
) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
) * (1 - smooth))

train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)


epochs = 5000
samples = []
n_sample = 10
losses = []

i = j = 0
while i<5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1

print(len(samples))
size = samples[0].size

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for e in range(epochs):
        batch_images = samples[e] * 2 -1
        batch_noise = np.random.uniform(-1, 1, size=noise_size)

        _ = sess.run(d_train_opt, feed_dict={real_img:[batch_images], noise_img:[batch_noise]})
        _ = sess.run(g_train_opt, feed_dict={noise_img:[batch_noise]})

    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size,
                                         reuse=True), feed_dict={
        noise_img:[sample_noise]
    })
    print(g_logit.size)
    g_output = (g_output+1)/2
    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    plt.show()

运行结果:

可以看出,在经过了5000次的迭代后,g网络生成的图片已经可以大致呈现出一个0的形状。


  • 2
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值