6-Gans-01_手写数据集vanilla_gans




"""
GANs
使用MNIST数据集创建生成对抗网络(generative adversarial network)。
"""
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib as mpl
from tensorflow.examples.tutorials.mnist import input_data
import os

# Plot styling: use the SimHei font so Chinese labels are not garbled, and
# draw minus signs with a plain hyphen instead of the Unicode minus glyph.
mpl.rcParams.update({
    'font.sans-serif': [u'simHei'],
    'axes.unicode_minus': False,
})

# Load (downloading on first use) the MNIST dataset from a sibling directory.
mnist = input_data.read_data_sets('../gans_datas/mnist')

def model_inputs(real_dims, z_dim):
    """
    Create the two input placeholders of the GAN.

    :param real_dims: number of features per real sample (the flattened
        image size; 784 in this script).
    :param z_dim: length of each random noise vector fed to the generator.
    :return: tuple ``(inputs_real, inputs_z)`` of float32 placeholders with
        shapes ``[None, real_dims]`` and ``[None, z_dim]``.
    """
    real_placeholder = tf.placeholder(
        tf.float32, shape=[None, real_dims], name='inputs_real')
    noise_placeholder = tf.placeholder(
        tf.float32, shape=[None, z_dim], name='inputs_z')
    return real_placeholder, noise_placeholder

def generator(inputs_z, output_dims, n_units=128, reuse=False, alpha=0.01):
    """
    Generator network: one LeakyReLU hidden layer followed by a tanh output.

    :param inputs_z: noise tensor (presumably ``[batch, z_dim]`` -- confirm
        against the caller).
    :param output_dims: number of output units, i.e. the flattened image size.
    :param n_units: width of the hidden layer.
    :param reuse: forwarded to ``tf.variable_scope('generator')``; pass True
        to share the weights created by an earlier call.
    :param alpha: negative-slope coefficient of the LeakyReLU.
    :return: generated samples after ``tanh``, so every value lies in [-1, 1].
    """
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer: linear projection, then LeakyReLU.
        hidden = tf.nn.leaky_relu(
            tf.layers.dense(inputs_z, units=n_units, activation=None),
            alpha=alpha)
        # Output layer squashed to [-1, 1] to match the rescaled real images.
        return tf.nn.tanh(tf.layers.dense(hidden, units=output_dims))


def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    """
    Discriminator network: one LeakyReLU hidden layer and a single output unit.

    :param x: batch of (real or generated) flattened images.
    :param n_units: width of the hidden layer.
    :param reuse: forwarded to ``tf.variable_scope('discriminator')``; pass
        True on the second call so real and fake inputs share weights.
    :param alpha: negative-slope coefficient of the LeakyReLU.
    :return: tuple ``(logits, prediction)`` where ``prediction`` is
        ``sigmoid(logits)``, the estimated probability that ``x`` is real.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        hidden = tf.nn.leaky_relu(
            tf.layers.dense(x, n_units, activation=None), alpha=alpha)
        # Raw score; the losses consume logits, the sigmoid is for inspection.
        logits = tf.layers.dense(hidden, units=1, activation=None)
        return logits, tf.nn.sigmoid(logits)

# ---- Hyper-parameters ----
input_size = 784   # flattened MNIST image size (28 * 28 pixels)
z_size = 50        # length of the generator's noise vector
hidden_size = 128  # hidden-layer width of both networks
alpha = 0.01       # LeakyReLU negative-slope coefficient
smooth = 0.1       # label-smoothing amount applied to the real labels only
lr = 2e-3          # Adam learning rate (shared by both optimizers)


tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    # 1. Input placeholders for real images and noise.
    inputs_real, inputs_z = model_inputs(input_size, z_size)

    # 2. Generator produces fake images from noise.
    fake_images = generator(
        inputs_z, output_dims=input_size, n_units=hidden_size, alpha=alpha)

    # 3. Discriminator scores real images (creating its variables) and then
    #    fake images with the very same weights (reuse=True).
    d_logits_real, d_model_real = discriminator(
        inputs_real, n_units=hidden_size, reuse=False, alpha=alpha)
    d_logits_fake, d_model_fake = discriminator(
        fake_images, n_units=hidden_size, reuse=True, alpha=alpha)

    # 4. Discriminator loss: real targets are 1 - smooth (one-sided label
    #    smoothing), fake targets are 0.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logits_real,
            labels=tf.ones_like(d_logits_real) * (1 - smooth)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logits_fake,
            labels=tf.zeros_like(d_logits_fake)))
    d_loss = d_loss_real + d_loss_fake

    # 5. Generator loss: it wants its fakes classified as real (target 1).
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logits_fake,
            labels=tf.ones_like(d_logits_fake)))

    # 6. Optimizers. Variables are split by their scope prefix so that each
    #    optimizer updates only its own network.
    vars_list = tf.trainable_variables()
    g_vars = [v for v in vars_list if v.name.startswith('generator')]
    d_vars = [v for v in vars_list if v.name.startswith('discriminator')]

    d_train_opt = tf.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(lr).minimize(g_loss, var_list=g_vars)


def train():
    """
    Train the GAN on MNIST and persist samples, losses and generator weights.

    Each mini-batch performs one discriminator update followed by one
    generator update. After every epoch the function draws 16 generator
    samples and records the ``(d_loss, g_loss)`` of the last batch. Every
    20th epoch it checkpoints the generator variables.

    Side effects:
        * ``./checkpoints/generator.ckpt*`` -- generator checkpoint, written
          every 20 epochs (the directory is created if missing);
        * ``train_samples.pkl`` -- list with one array of 16 generated
          images per epoch;
        * ``losses.pkl`` -- list of ``(d_loss, g_loss)`` tuples, one per epoch.
    """
    batch_size = 128
    epochs = 100
    n_samples = 16
    checkpoint_path = './checkpoints/generator.ckpt'
    samples = []
    losses = []

    # Build the sampling op and the saver ONCE, before training. Calling
    # generator() inside the batch loop would append a fresh copy of the
    # generator ops to the graph on every step, so the graph would grow
    # without bound.
    with graph.as_default():
        sample_op = generator(inputs_z, input_size, n_units=hidden_size,
                              reuse=True, alpha=alpha)
        saver = tf.train.Saver(var_list=g_vars)

    # tf.train.Saver.save refuses to write into a non-existent directory.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        step = 1
        # Epochs are numbered 1..epochs inclusive (range(1, epochs) ran one
        # epoch short).
        for e in range(1, epochs + 1):
            for _ in range(mnist.train.num_examples // batch_size):
                images, _labels = mnist.train.next_batch(batch_size)
                # Rescale pixels from [0, 1] to [-1, 1] to match the
                # generator's tanh output range.
                images = images.reshape([batch_size, input_size])
                images = images * 2.0 - 1.0

                # Uniform noise in [-1, 1] as generator input.
                batch_z = np.random.uniform(-1, 1, size=[batch_size, z_size])

                feed = {inputs_real: images, inputs_z: batch_z}

                # One discriminator step, then one generator step (the
                # generator loss depends only on the noise input).
                sess.run(d_train_opt, feed)
                sess.run(g_train_opt, {inputs_z: batch_z})

                if step % 20 == 0:
                    g_loss_, d_loss_ = sess.run([g_loss, d_loss], feed)
                    print('Epochs:{} - Step:{} - G_loss:{} - D_loss:{}'.format(
                        e, step, g_loss_, d_loss_))
                step += 1

            # ---- Per-epoch bookkeeping (previously ran on every batch) ----
            # Record the losses of the last batch of this epoch.
            epoch_d_loss, epoch_g_loss = sess.run([d_loss, g_loss], feed)
            losses.append((epoch_d_loss, epoch_g_loss))

            # Sample from the generator for viewing after training.
            sample_z = np.random.uniform(-1, 1, size=(n_samples, z_size))
            samples.append(sess.run(sample_op, feed_dict={inputs_z: sample_z}))

            if e % 20 == 0:
                saver.save(sess, checkpoint_path)

        # Save the collected generator samples and losses.
        with open('train_samples.pkl', 'wb') as f:
            pkl.dump(samples, f)
        with open('losses.pkl', 'wb') as f1:
            pkl.dump(losses, f1)


if __name__ == '__main__':
    # Start training only when this file is executed as a script.
    train()
D:\Anaconda\python.exe D:/AI20/HJZ/04-深度学习/5-GANS生成对抗网络/01_手写数据集vanilla_gans/01_手写数据集vanilla_gans.py
WARNING:tensorflow:From D:/AI20/HJZ/04-深度学习/5-GANS生成对抗网络/01_手写数据集vanilla_gans/01_手写数据集vanilla_gans.py:20: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Please use tf.data to implement this functionality.
Extracting ../gans_datas/mnist\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting ../gans_datas/mnist\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting ../gans_datas/mnist\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting ../gans_datas/mnist\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2020-02-19 12:09:23.224499: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epochs:1 - Step:20 - G_loss:0.7817292213439941 - D_loss:0.9641932249069214
Epochs:1 - Step:40 - G_loss:2.90767765045166 - D_loss:0.4107835292816162
Epochs:1 - Step:60 - G_loss:3.216116428375244 - D_loss:0.39682626724243164
Epochs:1 - Step:80 - G_loss:4.175543785095215 - D_loss:0.35861313343048096
Epochs:1 - Step:100 - G_loss:3.4656898975372314 - D_loss:0.3678774833679199
Epochs:1 - Step:120 - G_loss:1.85102117061615 - D_loss:0.5009573101997375
Epochs:1 - Step:140 - G_loss:3.061119318008423 - D_loss:0.38645803928375244
Epochs:1 - Step:160 - G_loss:1.2325481176376343 - D_loss:0.6871609091758728
Epochs:1 - Step:180 - G_loss:3.0256710052490234 - D_loss:0.3843447268009186
Epochs:1 - Step:200 - G_loss:1.8746395111083984 - D_loss:0.5140929818153381
Epochs:1 - Step:220 - G_loss:2.4667201042175293 - D_loss:0.43565645813941956
Epochs:1 - Step:240 - G_loss:3.3306174278259277 - D_loss:0.38084229826927185
Epochs:1 - Step:260 - G_loss:2.340106964111328 - D_loss:0.45705512166023254
Epochs:1 - Step:280 - G_loss:3.2198646068573 - D_loss:0.3850352168083191
Epochs:1 - Step:300 - G_loss:3.836900234222412 - D_loss:0.3552168011665344
Epochs:1 - Step:320 - G_loss:3.1090588569641113 - D_loss:0.38329869508743286
Epochs:1 - Step:340 - G_loss:2.7839038372039795 - D_loss:0.44077199697494507
Epochs:1 - Step:360 - G_loss:3.8576135635375977 - D_loss:0.3560332655906677
Epochs:1 - Step:380 - G_loss:3.5159659385681152 - D_loss:0.3691025972366333
Epochs:1 - Step:400 - G_loss:3.5861971378326416 - D_loss:0.3761620819568634
Epochs:1 - Step:420 - G_loss:3.6675522327423096 - D_loss:0.36183619499206543
Epochs:2 - Step:440 - G_loss:3.7458157539367676 - D_loss:0.3596262037754059
Epochs:2 - Step:460 - G_loss:3.561983108520508 - D_loss:0.36385294795036316
Epochs:2 - Step:480 - G_loss:3.70308780670166 - D_loss:0.3643430769443512
Epochs:2 - Step:500 - G_loss:4.11073112487793 - D_loss:0.35393059253692627
Epochs:2 - Step:520 - G_loss:4.63616943359375 - D_loss:0.3439123034477234
Epochs:2 - Step:540 - G_loss:4.030265808105469 - D_loss:0.3501322865486145
Epochs:2 - Step:560 - G_loss:3.9309778213500977 - D_loss:0.354989618062973
Epochs:2 - Step:580 - G_loss:3.9735004901885986 - D_loss:0.35577061772346497
Epochs:2 - Step:600 - G_loss:4.366448402404785 - D_loss:0.3474786877632141
Epochs:2 - Step:620 - G_loss:4.316489219665527 - D_loss:0.34787890315055847
Epochs:2 - Step:640 - G_loss:3.884830951690674 - D_loss:0.35990938544273376
Epochs:2 - Step:660 - G_loss:3.952484607696533 - D_loss:0.35151758790016174
Epochs:2 - Step:680 - G_loss:4.7438812255859375 - D_loss:0.3428582549095154
Epochs:2 - Step:700 - G_loss:3.887786626815796 - D_loss:0.35907602310180664
Epochs:2 - Step:720 - G_loss:3.8684959411621094 - D_loss:0.3674944043159485
Epochs:2 - Step:740 - G_loss:3.626516819000244 - D_loss:0.3662309944629669
Epochs:2 - Step:760 - G_loss:3.8221027851104736 - D_loss:0.3663366436958313
Epochs:2 - Step:780 - G_loss:4.9205098152160645 - D_loss:0.3417809009552002
Epochs:2 - Step:800 - G_loss:4.62099552154541 - D_loss:0.3406345546245575
Epochs:2 - Step:820 - G_loss:4.176742076873779 - D_loss:0.3505726456642151
Epochs:2 - Step:840 - G_loss:4.228707790374756 - D_loss:0.350876122713089
Epochs:3 - Step:860 - G_loss:1.9485373497009277 - D_loss:0.7147329449653625
Epochs:3 - Step:880 - G_loss:2.9959187507629395 - D_loss:0.3928413689136505
Epochs:3 - Step:900 - G_loss:3.8644604682922363 - D_loss:0.38012346625328064
Epochs:3 - Step:920 - G_loss:2.7844901084899902 - D_loss:0.41952067613601685
Epochs:3 - Step:940 - G_loss:1.9257779121398926 - D_loss:0.7645138502120972
Epochs:3 - Step:960 - G_loss:2.4559528827667236 - D_loss:0.43561944365501404
Epochs:3 - Step:980 - G_loss:1.799408197402954 - D_loss:0.5406032800674438
Epochs:3 - Step:1000 - G_loss:2.633357286453247 - D_loss:0.45080655813217163
Epochs:3 - Step:1020 - G_loss:2.076284885406494 - D_loss:0.5411171913146973
Epochs:3 - Step:1040 - G_loss:1.970678448677063 - D_loss:0.532650887966156
Epochs:3 - Step:1060 - G_loss:2.743875026702881 - D_loss:0.44239479303359985
Epochs:3 - Step:1080 - G_loss:3.954252243041992 - D_loss:0.3728772699832916
Epochs:3 - Step:1100 - G_loss:2.783644437789917 - D_loss:0.4099668562412262
Epochs:3 - Step:1120 - G_loss:3.1912782192230225 - D_loss:0.3929431140422821
Epochs:3 - Step:1140 - G_loss:2.317721366882324 - D_loss:0.4742271900177002
Epochs:3 - Step:1160 - G_loss:2.972458839416504 - D_loss:0.4198395907878876
Epochs:3 - Step:1180 - G_loss:2.892899513244629 - D_loss:0.4024370312690735
Epochs:3 - Step:1200 - G_loss:2.5916624069213867 - D_loss:0.4356590509414673
Epochs:3 - Step:1220 - G_loss:2.032283306121826 - D_loss:0.5036305785179138
Epochs:3 - Step:1240 - G_loss:2.0204825401306152 - D_loss:0.5095717906951904
Epochs:3 - Step:1260 - G_loss:2.79931640625 - D_loss:0.46155110001564026
Epochs:3 - Step:1280 - G_loss:4.067086219787598 - D_loss:0.3892279863357544
Epochs:4 - Step:1300 - G_loss:3.7364625930786133 - D_loss:0.377506285905838
Epochs:4 - Step:1320 - G_loss:3.4091272354125977 - D_loss:0.3864123821258545
Epochs:4 - Step:1340 - G_loss:3.072338104248047 - D_loss:0.4463484585285187
Epochs:4 - Step:1360 - G_loss:2.2842698097229004 - D_loss:0.8370013236999512
Epochs:4 - Step:1380 - G_loss:3.0575428009033203 - D_loss:0.4213850498199463
Epochs:4 - Step:1400 - G_loss:2.8805391788482666 - D_loss:0.44334089756011963
Epochs:4 - Step:1420 - G_loss:2.707998514175415 - D_loss:0.5102008581161499
Epochs:4 - Step:1440 - G_loss:3.76065731048584 - D_loss:0.3874475657939911
Epochs:4 - Step:1460 - G_loss:2.9196529388427734 - D_loss:0.43920594453811646
Epochs:4 - Step:1480 - G_loss:2.459183692932129 - D_loss:0.47488275170326233
Epochs:4 - Step:1500 - G_loss:4.8662943840026855 - D_loss:0.40366634726524353
Epochs:4 - Step:1520 - G_loss:7.328624248504639 - D_loss:0.9148223996162415
Epochs:4 - Step:1540 - G_loss:2.4782156944274902 - D_loss:0.5983544588088989
Epochs:4 - Step:1560 - G_loss:2.1881155967712402 - D_loss:0.5516225695610046
Epochs:4 - Step:1580 - G_loss:2.462240219116211 - D_loss:0.5024042725563049
Epochs:4 - Step:1600 - G_loss:3.1658623218536377 - D_loss:0.4709932208061218
Epochs:4 - Step:1620 - G_loss:4.210256576538086 - D_loss:0.5433186292648315
Epochs:4 - Step:1640 - G_loss:4.455408096313477 - D_loss:0.5004763603210449
Epochs:4 - Step:1660 - G_loss:4.574621677398682 - D_loss:0.8144306540489197
Epochs:4 - Step:1680 - G_loss:1.7993299961090088 - D_loss:0.5978686809539795
Epochs:4 - Step:1700 - G_loss:3.2091588973999023 - D_loss:0.4315302073955536
Epochs:5 - Step:1720 - G_loss:3.8553311824798584 - D_loss:0.40169936418533325
Epochs:5 - Step:1740 - G_loss:3.552783489227295 - D_loss:0.4915289878845215

Process finished with exit code -1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值