"""
GANs
使用MNIST数据集创建生成对抗网络(generative adversarial network)。
"""
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib as mpl
from tensorflow.examples.tutorials.mnist import input_data
import os
# Configure matplotlib to render CJK text (SimHei font) and show the minus
# sign correctly when that font is active.
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
# Download/extract MNIST into ../gans_datas/mnist on first run.
# Labels are ignored below — the GAN only consumes the images.
mnist = input_data.read_data_sets('../gans_datas/mnist')
def model_inputs(real_dims, z_dim):
    """Create the two input placeholders for the GAN graph.

    :param real_dims: flattened length of a real image vector (e.g. 784)
    :param z_dim: length of the random noise vector fed to the generator
    :return: tuple ``(inputs_real, inputs_z)`` of float32 placeholders,
        each with a dynamic batch dimension
    """
    real_ph = tf.placeholder(
        tf.float32, shape=[None, real_dims], name='inputs_real')
    noise_ph = tf.placeholder(
        tf.float32, shape=[None, z_dim], name='inputs_z')
    return real_ph, noise_ph
def generator(inputs_z, output_dims, n_units=128, reuse=False, alpha=0.01):
    """Generator network: map noise vectors to fake image vectors.

    One leaky-ReLU hidden layer followed by a tanh output, so generated
    pixels lie in [-1, 1] (real images are rescaled to match elsewhere).

    :param inputs_z: noise tensor, shape [batch, z_dim]
    :param output_dims: length of the generated (flattened) image vector
    :param n_units: width of the hidden layer
    :param reuse: reuse existing variables in the 'generator' scope
    :param alpha: negative slope of the leaky ReLU
    :return: tanh-activated output tensor, shape [batch, output_dims]
    """
    with tf.variable_scope('generator', reuse=reuse):
        hidden = tf.nn.leaky_relu(
            tf.layers.dense(inputs_z, units=n_units, activation=None),
            alpha=alpha)
        return tf.nn.tanh(tf.layers.dense(hidden, units=output_dims))
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    """Discriminator network: score inputs as real vs. generated.

    One leaky-ReLU hidden layer followed by a single-unit output; the raw
    logit is returned alongside its sigmoid so losses can use the
    numerically stable logits form.

    :param x: image tensor, shape [batch, input_dims]
    :param n_units: width of the hidden layer
    :param reuse: reuse existing variables in the 'discriminator' scope
    :param alpha: negative slope of the leaky ReLU
    :return: tuple ``(logits, prediction)``, each of shape [batch, 1]
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        hidden = tf.nn.leaky_relu(
            tf.layers.dense(x, n_units, activation=None),
            alpha=alpha)
        logits = tf.layers.dense(hidden, units=1, activation=None)
        return logits, tf.nn.sigmoid(logits)
# Hyperparameters.
input_size = 784   # 28*28 flattened MNIST image
z_size = 50        # length of the generator's input noise vector
hidden_size = 128  # hidden-layer width for both networks
alpha = 0.01       # leaky-ReLU negative slope
smooth = 0.1       # one-sided label smoothing applied to "real" labels
lr = 2e-3          # Adam learning rate (shared by both optimizers)
# Build the model inside an explicit tf.Graph.
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    inputs_real, inputs_z = model_inputs(input_size, z_size)
    # Generator output (tanh range [-1, 1]), fed to the fake branch below.
    fake_images = generator(
        inputs_z, output_dims=input_size, n_units=hidden_size, alpha=alpha)
    # Discriminator applied twice with SHARED weights: reuse=False creates
    # the variables on the real branch, reuse=True reuses them on fakes.
    d_logits_real, d_model_real = discriminator(
        inputs_real, n_units=hidden_size, reuse=False, alpha=alpha)
    d_logits_fake, d_model_fake = discriminator(
        fake_images, n_units=hidden_size, reuse=True, alpha=alpha)
with graph.as_default():
    # Discriminator: push real logits toward 1 (smoothed to 1 - smooth)
    # and fake logits toward 0.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)
    ))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)
    ))
    d_loss = d_loss_real + d_loss_fake
    # Generator (non-saturating form): make the discriminator call fakes real.
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)
    ))
with graph.as_default():
    # Partition trainable variables by scope name so each optimizer only
    # updates its own network (standard alternating-GAN training).
    vars_list = tf.trainable_variables()
    g_vars = [var for var in vars_list if var.name.startswith('generator')]
    d_vars = [var for var in vars_list if var.name.startswith('discriminator')]
    d_train_opt = tf.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(lr).minimize(g_loss, var_list=g_vars)
def train():
    """Run the alternating GAN training loop on MNIST.

    Per batch: one discriminator update, then one generator update. Losses
    are logged every 20 steps, 16 images are sampled from the generator at
    the end of each epoch, and generator variables are checkpointed every
    20 epochs. Samples and losses are pickled to 'train_samples.pkl' and
    'losses.pkl'.
    """
    batch_size = 128
    epochs = 100
    samples = []  # one [16, input_size] array of generated images per epoch
    losses = []   # (d_loss, g_loss) pairs recorded every 20 steps

    # Build the epoch-end sampling op ONCE. The original called generator()
    # inside the epoch loop, adding a fresh set of ops to the graph every
    # epoch (unbounded graph growth); reuse=True shares the trained weights.
    with graph.as_default():
        sample_op = generator(
            inputs_z, input_size, n_units=hidden_size, reuse=True, alpha=alpha)

    # saver.save below fails if the directory doesn't exist yet.
    os.makedirs('./checkpoints', exist_ok=True)

    with tf.Session(graph=graph) as sess:
        # Checkpoint only the generator — that's all inference needs.
        saver = tf.train.Saver(var_list=g_vars)
        sess.run(tf.global_variables_initializer())
        step = 1
        # range(1, epochs + 1): the original range(1, epochs) ran 99 epochs.
        for e in range(1, epochs + 1):
            for _ in range(mnist.train.num_examples // batch_size):
                images, _ = mnist.train.next_batch(batch_size)
                images = images.reshape([batch_size, 784])
                # Rescale [0, 1] pixels to [-1, 1] to match the tanh output.
                images = images * 2.0 - 1.0
                batch_z = np.random.uniform(-1, 1, size=[batch_size, z_size])
                feed = {inputs_real: images, inputs_z: batch_z}
                sess.run(d_train_opt, feed)
                sess.run(g_train_opt, {inputs_z: batch_z})
                if step % 20 == 0:
                    g_loss_, d_loss_ = sess.run([g_loss, d_loss], feed)
                    # The original never appended here, so losses.pkl was
                    # always dumped as an empty list.
                    losses.append((d_loss_, g_loss_))
                    print('Epochs:{} - Step:{} - G_loss:{} - D_loss:{}'.format(
                        e, step, g_loss_, d_loss_))
                step += 1
            # Sample 16 images from the (shared-weight) generator.
            sample_z = np.random.uniform(-1, 1, size=(16, z_size))
            gen_samples = sess.run(sample_op, feed_dict={inputs_z: sample_z})
            samples.append(gen_samples)
            if e % 20 == 0:
                saver.save(sess, './checkpoints/generator.ckpt')

    with open('train_samples.pkl', 'wb') as f:
        pkl.dump(samples, f)
    with open('losses.pkl', 'wb') as f1:
        pkl.dump(losses, f1)
# Script entry point: train only when executed directly, not on import.
if __name__ == '__main__':
    train()
D:\Anaconda\python.exe D:/AI20/HJZ/04-深度学习/5-GANS生成对抗网络/01_手写数据集vanilla_gans/01_手写数据集vanilla_gans.py
WARNING:tensorflow:From D:/AI20/HJZ/04-深度学习/5-GANS生成对抗网络/01_手写数据集vanilla_gans/01_手写数据集vanilla_gans.py:20: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Please use tf.data to implement this functionality.
Extracting ../gans_datas/mnist\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting ../gans_datas/mnist\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting ../gans_datas/mnist\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting ../gans_datas/mnist\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From D:\Anaconda\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2020-02-19 12:09:23.224499: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epochs:1 - Step:20 - G_loss:0.7817292213439941 - D_loss:0.9641932249069214
Epochs:1 - Step:40 - G_loss:2.90767765045166 - D_loss:0.4107835292816162
Epochs:1 - Step:60 - G_loss:3.216116428375244 - D_loss:0.39682626724243164
Epochs:1 - Step:80 - G_loss:4.175543785095215 - D_loss:0.35861313343048096
Epochs:1 - Step:100 - G_loss:3.4656898975372314 - D_loss:0.3678774833679199
Epochs:1 - Step:120 - G_loss:1.85102117061615 - D_loss:0.5009573101997375
Epochs:1 - Step:140 - G_loss:3.061119318008423 - D_loss:0.38645803928375244
Epochs:1 - Step:160 - G_loss:1.2325481176376343 - D_loss:0.6871609091758728
Epochs:1 - Step:180 - G_loss:3.0256710052490234 - D_loss:0.3843447268009186
Epochs:1 - Step:200 - G_loss:1.8746395111083984 - D_loss:0.5140929818153381
Epochs:1 - Step:220 - G_loss:2.4667201042175293 - D_loss:0.43565645813941956
Epochs:1 - Step:240 - G_loss:3.3306174278259277 - D_loss:0.38084229826927185
Epochs:1 - Step:260 - G_loss:2.340106964111328 - D_loss:0.45705512166023254
Epochs:1 - Step:280 - G_loss:3.2198646068573 - D_loss:0.3850352168083191
Epochs:1 - Step:300 - G_loss:3.836900234222412 - D_loss:0.3552168011665344
Epochs:1 - Step:320 - G_loss:3.1090588569641113 - D_loss:0.38329869508743286
Epochs:1 - Step:340 - G_loss:2.7839038372039795 - D_loss:0.44077199697494507
Epochs:1 - Step:360 - G_loss:3.8576135635375977 - D_loss:0.3560332655906677
Epochs:1 - Step:380 - G_loss:3.5159659385681152 - D_loss:0.3691025972366333
Epochs:1 - Step:400 - G_loss:3.5861971378326416 - D_loss:0.3761620819568634
Epochs:1 - Step:420 - G_loss:3.6675522327423096 - D_loss:0.36183619499206543
Epochs:2 - Step:440 - G_loss:3.7458157539367676 - D_loss:0.3596262037754059
Epochs:2 - Step:460 - G_loss:3.561983108520508 - D_loss:0.36385294795036316
Epochs:2 - Step:480 - G_loss:3.70308780670166 - D_loss:0.3643430769443512
Epochs:2 - Step:500 - G_loss:4.11073112487793 - D_loss:0.35393059253692627
Epochs:2 - Step:520 - G_loss:4.63616943359375 - D_loss:0.3439123034477234
Epochs:2 - Step:540 - G_loss:4.030265808105469 - D_loss:0.3501322865486145
Epochs:2 - Step:560 - G_loss:3.9309778213500977 - D_loss:0.354989618062973
Epochs:2 - Step:580 - G_loss:3.9735004901885986 - D_loss:0.35577061772346497
Epochs:2 - Step:600 - G_loss:4.366448402404785 - D_loss:0.3474786877632141
Epochs:2 - Step:620 - G_loss:4.316489219665527 - D_loss:0.34787890315055847
Epochs:2 - Step:640 - G_loss:3.884830951690674 - D_loss:0.35990938544273376
Epochs:2 - Step:660 - G_loss:3.952484607696533 - D_loss:0.35151758790016174
Epochs:2 - Step:680 - G_loss:4.7438812255859375 - D_loss:0.3428582549095154
Epochs:2 - Step:700 - G_loss:3.887786626815796 - D_loss:0.35907602310180664
Epochs:2 - Step:720 - G_loss:3.8684959411621094 - D_loss:0.3674944043159485
Epochs:2 - Step:740 - G_loss:3.626516819000244 - D_loss:0.3662309944629669
Epochs:2 - Step:760 - G_loss:3.8221027851104736 - D_loss:0.3663366436958313
Epochs:2 - Step:780 - G_loss:4.9205098152160645 - D_loss:0.3417809009552002
Epochs:2 - Step:800 - G_loss:4.62099552154541 - D_loss:0.3406345546245575
Epochs:2 - Step:820 - G_loss:4.176742076873779 - D_loss:0.3505726456642151
Epochs:2 - Step:840 - G_loss:4.228707790374756 - D_loss:0.350876122713089
Epochs:3 - Step:860 - G_loss:1.9485373497009277 - D_loss:0.7147329449653625
Epochs:3 - Step:880 - G_loss:2.9959187507629395 - D_loss:0.3928413689136505
Epochs:3 - Step:900 - G_loss:3.8644604682922363 - D_loss:0.38012346625328064
Epochs:3 - Step:920 - G_loss:2.7844901084899902 - D_loss:0.41952067613601685
Epochs:3 - Step:940 - G_loss:1.9257779121398926 - D_loss:0.7645138502120972
Epochs:3 - Step:960 - G_loss:2.4559528827667236 - D_loss:0.43561944365501404
Epochs:3 - Step:980 - G_loss:1.799408197402954 - D_loss:0.5406032800674438
Epochs:3 - Step:1000 - G_loss:2.633357286453247 - D_loss:0.45080655813217163
Epochs:3 - Step:1020 - G_loss:2.076284885406494 - D_loss:0.5411171913146973
Epochs:3 - Step:1040 - G_loss:1.970678448677063 - D_loss:0.532650887966156
Epochs:3 - Step:1060 - G_loss:2.743875026702881 - D_loss:0.44239479303359985
Epochs:3 - Step:1080 - G_loss:3.954252243041992 - D_loss:0.3728772699832916
Epochs:3 - Step:1100 - G_loss:2.783644437789917 - D_loss:0.4099668562412262
Epochs:3 - Step:1120 - G_loss:3.1912782192230225 - D_loss:0.3929431140422821
Epochs:3 - Step:1140 - G_loss:2.317721366882324 - D_loss:0.4742271900177002
Epochs:3 - Step:1160 - G_loss:2.972458839416504 - D_loss:0.4198395907878876
Epochs:3 - Step:1180 - G_loss:2.892899513244629 - D_loss:0.4024370312690735
Epochs:3 - Step:1200 - G_loss:2.5916624069213867 - D_loss:0.4356590509414673
Epochs:3 - Step:1220 - G_loss:2.032283306121826 - D_loss:0.5036305785179138
Epochs:3 - Step:1240 - G_loss:2.0204825401306152 - D_loss:0.5095717906951904
Epochs:3 - Step:1260 - G_loss:2.79931640625 - D_loss:0.46155110001564026
Epochs:3 - Step:1280 - G_loss:4.067086219787598 - D_loss:0.3892279863357544
Epochs:4 - Step:1300 - G_loss:3.7364625930786133 - D_loss:0.377506285905838
Epochs:4 - Step:1320 - G_loss:3.4091272354125977 - D_loss:0.3864123821258545
Epochs:4 - Step:1340 - G_loss:3.072338104248047 - D_loss:0.4463484585285187
Epochs:4 - Step:1360 - G_loss:2.2842698097229004 - D_loss:0.8370013236999512
Epochs:4 - Step:1380 - G_loss:3.0575428009033203 - D_loss:0.4213850498199463
Epochs:4 - Step:1400 - G_loss:2.8805391788482666 - D_loss:0.44334089756011963
Epochs:4 - Step:1420 - G_loss:2.707998514175415 - D_loss:0.5102008581161499
Epochs:4 - Step:1440 - G_loss:3.76065731048584 - D_loss:0.3874475657939911
Epochs:4 - Step:1460 - G_loss:2.9196529388427734 - D_loss:0.43920594453811646
Epochs:4 - Step:1480 - G_loss:2.459183692932129 - D_loss:0.47488275170326233
Epochs:4 - Step:1500 - G_loss:4.8662943840026855 - D_loss:0.40366634726524353
Epochs:4 - Step:1520 - G_loss:7.328624248504639 - D_loss:0.9148223996162415
Epochs:4 - Step:1540 - G_loss:2.4782156944274902 - D_loss:0.5983544588088989
Epochs:4 - Step:1560 - G_loss:2.1881155967712402 - D_loss:0.5516225695610046
Epochs:4 - Step:1580 - G_loss:2.462240219116211 - D_loss:0.5024042725563049
Epochs:4 - Step:1600 - G_loss:3.1658623218536377 - D_loss:0.4709932208061218
Epochs:4 - Step:1620 - G_loss:4.210256576538086 - D_loss:0.5433186292648315
Epochs:4 - Step:1640 - G_loss:4.455408096313477 - D_loss:0.5004763603210449
Epochs:4 - Step:1660 - G_loss:4.574621677398682 - D_loss:0.8144306540489197
Epochs:4 - Step:1680 - G_loss:1.7993299961090088 - D_loss:0.5978686809539795
Epochs:4 - Step:1700 - G_loss:3.2091588973999023 - D_loss:0.4315302073955536
Epochs:5 - Step:1720 - G_loss:3.8553311824798584 - D_loss:0.40169936418533325
Epochs:5 - Step:1740 - G_loss:3.552783489227295 - D_loss:0.4915289878845215
Process finished with exit code -1