gan-cls 具有匹配感知的判别器

原文链接: gan-cls 具有匹配感知的判别器

上一篇: lsgan 生成mnist数据集

下一篇: ESPCN MNIST 数据集超分辨率重建

infogan中,使用了ACGan的方式进行指导模拟数据与生成数据的对应关系,在gan-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片真伪,还能判断匹配真伪

gan-cls 具体做法是,在原有的gan网络上,将判别器的输入变为图片与对应标签的连接数据,这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片,假标签与真实图片,真实标签与假图片进行判断,预期的结果依次为真,假,假,在训练的过程中沿着这个方向收敛即可,而对于生成器,则不需要做任何改动,这样简单的一步就完成了生成根据标签匹配模拟数据的功能。

在lsgan基础上,将判别器的输入改为x和y,新增加的y代表输入的样0本标签,在内部处理中,先通过全连接网络将y变为与图片一样维度的映射,并调整为图片相同的形状,使用concat将二者连接到一起统一处理


def discriminator(x, y):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    print(x.shape)  # (?, 784)
    print(y.shape)  # (?, 10)
    with tf.variable_scope('discriminator', reuse=reuse):
        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=tf.nn.leaky_relu)
        print(y.shape)  # (?, 784)
        y = tf.reshape(y, shape=[-1, 28, 28, 1])
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        x = tf.concat(axis=3, values=[x, y])
        print(x.shape)  # (?, 28, 28, 2)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (?, 14, 14, 64)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (?, 7, 7, 128)
        x = slim.flatten(x)
        print(x.shape)  # (?, 6272)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=tf.nn.leaky_relu)
        print(shared_tensor.shape)  # (?, 1024)
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        print(disc.shape)  # (?, 1)
        disc = tf.squeeze(disc, -1)
        print(disc.shape)  # (?,)
    return disc

添加错误标签misy,同时在判别器中分别将真实样0本与真实标签,生成的图像gen和真实的标签,真实样0本与错误标签组成的输入传递到判别器中,这里将三种输入的x,y分别按照batch_size维度连接变为判别器的一个输入,生成结果再使用split裁为3个结果disc_real,disc_fake,disc_mis,分别代表真实样0本与真实标签,生成图像gen和真实标签,真实样0本与错误标签所对应的判别值,当然也可以一个一个的输入x,y然后调用三次判别器,效果是一样的

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/")  # , one_hot=True)

tf.reset_default_graph()


def generator(x):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0
    with tf.variable_scope('generator', reuse=reuse):
        # 两个全连接
        x = slim.fully_connected(x, 1024)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = slim.fully_connected(x, 7 * 7 * 128)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = tf.reshape(x, [-1, 7, 7, 128])
        # 两个转置卷积
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn=None)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
    return z


batch_size = 10  # 最小批次
classes_dim = 10  # 10 个分类

rand_dim = 38
n_input = 784


def discriminator(x, y):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    print(x.shape)  # (?, 784)
    print(y.shape)  # (?, 10)
    with tf.variable_scope('discriminator', reuse=reuse):
        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=tf.nn.leaky_relu)
        print(y.shape)  # (?, 784)
        y = tf.reshape(y, shape=[-1, 28, 28, 1])
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        x = tf.concat(axis=3, values=[x, y])
        print(x.shape)  # (?, 28, 28, 2)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (?, 14, 14, 64)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (?, 7, 7, 128)
        x = slim.flatten(x)
        print(x.shape)  # (?, 6272)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=tf.nn.leaky_relu)
        print(shared_tensor.shape)  # (?, 1024)
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        print(disc.shape)  # (?, 1)
        disc = tf.squeeze(disc, -1)
        print(disc.shape)  # (?,)
    return disc


x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.int32, [None])
misy = tf.placeholder(tf.int32, [None])

z_rand = tf.random_normal((batch_size, rand_dim))  # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 50列
gen = generator(z)
genout = tf.squeeze(gen, -1)

# 判别器
xin = tf.concat([x, tf.reshape(gen, shape=[-1, 784]), x], 0)
yin = tf.concat(
    [tf.one_hot(y, depth=classes_dim), tf.one_hot(y, depth=classes_dim), tf.one_hot(misy, depth=classes_dim)], 0)
disc_all = discriminator(xin, yin)
disc_real, disc_fake, disc_mis = tf.split(disc_all, 3)

loss_d = tf.reduce_sum(tf.square(disc_real - 1) + (tf.square(disc_fake) + tf.square(disc_mis)) / 2) / 2
loss_g = tf.reduce_sum(tf.square(disc_fake - 1)) / 2

# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]

# disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)

global_step = tf.train.get_or_create_global_step()  # 使用MonitoredTrainingSession,必须有

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d, var_list=d_vars, global_step=global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g, var_list=g_vars, global_step=gen_global_step)

training_epochs = 3
display_step = 1

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpointsnew', save_checkpoint_secs=60) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    print("global_step.eval(session=sess)", global_step.eval(session=sess),
          int(global_step.eval(session=sess) / total_batch))
    for epoch in range(int(global_step.eval(session=sess) / total_batch), training_epochs):
        avg_cost = 0.

        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # 取数据
            _, mis_batch_ys = mnist.train.next_batch(batch_size)  # 取数据
            feeds = {x: batch_xs, y: batch_ys, misy: mis_batch_ys}

            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)

        # 显示训练中的详细信息
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)

    print("完成!")

    # 测试
    _, mis_batch_ys = mnist.train.next_batch(batch_size)
    print("result:",
          loss_d.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
                      session=sess)
          , loss_g.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size], misy: mis_batch_ys},
                        session=sess))

    # 根据图片模拟生成图片
    show_num = 10
    gensimple, inputx, inputy = sess.run(
        [genout, x, y], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})

    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))

    plt.draw()
    plt.show()

效果类似

9a3dad00730f0cb7bd8e7c511fea6321d1d.jpg

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值