GAN生成抛物线

本文主要讲解GAN的原理以及一个小实战,利用GAN生成抛物线,首先我们看一下GAN的原理。

GAN是2014年提出来的,他的原理可以这样理解,他有一个生成器和一个判别器,生成器是不断的生成数据,判别器的原理是将真实图片和生成器制作的数据区分开来,目的就是鉴别生成器生成的数据是假的,把原始数据判定为真。为生成器相反,他的目的就是源源不断的生成数据,让判别器无法分辨真假,从而以假乱真。常看到的一个例子就是坏人制作假币,警察差查假币。最终达到判别器无法分辨谁真谁假,也就是对于生成数据,它判断的为假的概率也是0.5,对于一个真实数据,它判断为真的概率也是0.5,这是最理想的状态。

任何一个神经网络模型都有损失函数,对于GAN模型,自然也不例外,他也有损失函数,因为他有丄生成网络和判别网络之分,所以当然是两个损失函数的,我们先看判别网路的损失函数,

min (-log(D(x)) -log(1-D(G(x))))
这个就是判别网络的损失函数,这样理解,我们把真实数据判定为1是对的,把生成网络生成的数据判定为0是对的,也就是我们把真实数据判定为1那么损失函数越小,把生成数据判定为0,判别网络损失函数越小,然后我们再看损失函数,不难发现,要想损失函数越小,D(x)越接近1越好,这里x表示真实数据输入判别网络,可以看出D(G(x))越接近0损失越小,其中G(x)表示生成网络的输出输入判别网络。这也很符合我们平时的理解。再看生成网络的损失函数

min(-log(D(G(x)))
从生成网络我们可以看出,生成网络的目的是让D(G(x))越接近1越好,越就是D(G(x))越接近1,损失函数越小,这就和判别网络矛盾了,那就形成了竞争。

有了这些理论,我们再看实际代码实现生成抛物线的对抗神经网络:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

np.random.seed(2018)
tf.set_random_seed(2018)

def real_data(num):
    #x = np.random.uniform(-10,10,[num,1])
    x = np.linspace(-3,3, num) + np.random.random(num) * 0.01
    x = x.reshape([-1,1])
    #sample = np.sin(x) + 1
    sample = x**2 +1

    return x,sample


def fake_data(num):
    x = np.linspace(-3,3, num) + np.random.random(num) * 0.01

    return x.reshape([-1,1])


batch_size = 64
iters = 2000
#hidden_units = batch_size//2
alpha = 0.01
lr = 0.0001
#gen_num = 10
# out_dim = None

def generator(inputs,alpha,reuse=False):
    # out_dim 表示输出的大小,最后一层全连接层输出的大小,所以和batch_size大小一样
    # hidden_units隐藏层神经元的个数
    # alpha Leaky ReLU激活函数的参数
    # reuse 是否重用参数变量
    with tf.variable_scope('generator',reuse=reuse) as scope:

        hidden_1 = tf.layers.dense(inputs, 64, activation=None)
        ac1 = tf.maximum(alpha * hidden_1, hidden_1)
        #ac1 = tf.nn.tanh( hidden_1)
        bn1 = tf.layers.batch_normalization(ac1)

        hidden_2 = tf.layers.dense(bn1, 128, activation=None)
        ac2 = tf.maximum(alpha * hidden_2, hidden_2)
        #ac2 = tf.nn.tanh(hidden_2)
        bn2 = tf.layers.batch_normalization(ac2)

        hidden_3 = tf.layers.dense(bn2, 256, activation=None)
        ac3 = tf.maximum(alpha *hidden_3, hidden_3)
        #ac3 = tf.nn.tanh(ac2)
        bn3 = tf.layers.batch_normalization(ac3)


        out = tf.layers.dense(bn3,1,activation=None)
        return out

def discriminator(discr_input,alpha,reuse=False,name='discriminator'):
    # discr_input 判别器的输入
    # alpha Leaky ReLU激活函数的参数
    # hidden_units 隐藏层的神经元个数
    # reuse 是否重用变量

    with tf.variable_scope(name,reuse=reuse) as scope:

        hidden_1 = tf.layers.dense(discr_input,units=64,activation=None)
        #ac1 = tf.maximum(alpha*hidden_1,hidden_1)
        ac1 = tf.nn.tanh(hidden_1)

        hidden_2 = tf.layers.dense(ac1, units=128, activation=None)
        #ac2 = tf.maximum(alpha * hidden_2, hidden_2)
        ac2 = tf.nn.tanh(hidden_2)
        hidden_3 = tf.layers.dense(ac2, units=128, activation=None)
        #ac3 = tf.maximum(alpha * hidden_3, hidden_3)
        ac3 = tf.nn.tanh(hidden_3)
        logits = tf.layers.dense(ac3,1,activation=None)
        out = tf.nn.sigmoid(logits)
        return logits,out
def plot_data(gen_x,gen_y):

    x_r,y_r = real_data(64)

    plt.scatter(x_r, y_r, label='real data')
    plt.scatter(gen_x,gen_y, label='generated data')
    plt.title('GAN')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.show()

with tf.name_scope('gen_input') as scope:
    gen_input = tf.placeholder(dtype=tf.float32,shape=[None,1],name='gen_input')
with tf.name_scope('discriminator_input') as scope:
    real_input = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='real_input')


out_gen = generator(gen_input,alpha,reuse=False)

real_logits,label_real = discriminator(real_input,alpha,reuse=False)
logits_gen,label_fake = discriminator(out_gen,alpha,reuse=True)

with tf.name_scope('discr_train') as scope:
    train_input = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='train_input')
train_disc = discriminator(train_input,alpha,reuse=False,name='train_dis')
para = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='train_dis')
train_loss = tf.reduce_mean(tf.square(train_disc-train_input))

#with tf.Session() as sess:
#    sess.run(tf.global_variables_initializer())

with tf.name_scope('metrics') as name:

    loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_gen)*0.99,logits=logits_gen))
    loss_d_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_gen),logits=logits_gen))
    loss_d_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logits)*0.99, logits=real_logits))
    loss_d = loss_d_fake+loss_d_real


    var_list_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator')
    var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='discriminator')

    d_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_d,var_list=var_list_d)
    g_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_g,var_list=var_list_g)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    writer = tf.summary.FileWriter('./graph/gan',sess.graph)
    # for i in range(1000):
    #     _, real = real_data(batch_size)
    #     _ = sess.run(train_loss,feed_dict={train_input:real})
    # train_weig = sess.run(para)
    # for i in range(len((var_list_d))):
    #     sess.run(var_list_d[i].assign(train_weig[i]))

    for iter in range(iters):
        _,real = real_data(batch_size)

        fake = fake_data(batch_size)
        _,train_loss_d = sess.run([d_optimizer,loss_d],feed_dict={real_input:real,gen_input: fake})
        _, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})
        fake = fake_data(batch_size)
        _, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})
        fake = fake_data(batch_size)
        _, train_loss_g = sess.run([g_optimizer, loss_g], feed_dict={gen_input: fake})

        if iter % 200 == 0:
            print(train_loss_d)
            print(train_loss_g)
            gen_x = np.linspace(-3,3,500).reshape([-1,1])
            gen_y = sess.run(out_gen,feed_dict={gen_input:gen_x})
            plot_data(gen_x, gen_y)

    saver.save(sess, "./checkpoints/gen")
    writer.close()
下面展示一个训练过程的图像:

在做这个的时候,我试图生成正弦曲线,但是效果比较差,我猜测是和正弦函数泰勒展开有关系,又知道的希望提点一下。






  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值