通过自己的实践，并参考他人的代码，实现了两种 GAN（Vanilla GAN 和 DCGAN），现贴出代码：
Vanilla GAN：
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# discriminator (fully connected)
def discrimitive(x, scope=None):
    """Score a batch of flattened images with a small MLP discriminator.

    Architecture: two 256-unit dense layers with leaky-ReLU activations,
    followed by a 1-unit dense layer with no activation.

    Args:
        x: batch of flattened images fed to the first dense layer.
        scope: variable scope name. Variables are created on the first call
            and reused afterwards (tf.AUTO_REUSE), so calling this twice with
            the same scope shares weights between the real and fake paths.

    Returns:
        Raw (pre-sigmoid) logits, one per example.
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        hidden = x
        # Same layer creation order as two explicit dense calls, so the
        # auto-generated variable names (and thus reuse) are unchanged.
        for _ in range(2):
            hidden = tf.layers.dense(inputs=hidden, units=256, activation=tf.nn.leaky_relu)
        return tf.layers.dense(inputs=hidden, units=1)
# generator (fully connected)
def generator(z, scope=None):
    """Map a batch of noise vectors to flattened 784-pixel images.

    Architecture: two 1024-unit ReLU dense layers, then a 784-unit dense
    layer with tanh, so every output value lies in [-1, 1].

    Args:
        z: batch of noise vectors.
        scope: variable scope name; variables are created or reused via
            tf.AUTO_REUSE.

    Returns:
        Generated images, one 784-vector per example.
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        hidden = z
        # Layer creation order matches the explicit three-call version.
        for _ in range(2):
            hidden = tf.layers.dense(inputs=hidden, units=1024, activation=tf.nn.relu)
        return tf.layers.dense(inputs=hidden, units=784, activation=tf.nn.tanh)
# ---- hyper-parameters ----
dim = 784          # flattened 28x28 image size
batch_size = 100
iterations = 3000
learning_rate = 1e-3
beta1 = 0.5
noise_dim = 1024   # length of each generator input vector

# read_data_sets (default float32 dtype) rescales pixels to [0, 1].
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

# ---- graph ----
x = tf.placeholder(dtype=tf.float32, shape=[batch_size, dim])
z = tf.placeholder(dtype=tf.float32, shape=[batch_size, noise_dim])

labels_G = tf.placeholder(dtype=tf.float32, shape=[batch_size, 1])
labels_D = tf.placeholder(dtype=tf.float32, shape=[batch_size, 1])
logits_D = discrimitive(x, 'd')
logits_G1 = generator(z, 'g')                  # generated images
logits_G = discrimitive(logits_G1, 'd')        # D's score of generated images (shared weights)

D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
# sigmoid_cross_entropy_with_logits is per-example; reduce to a scalar
# objective instead of relying on minimize() implicitly summing a tensor.
D_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_D, logits=logits_D))
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_G, logits=logits_G))
# get_collection filters names with re.match, so a bare 'd' would match any
# variable whose name merely starts with "d"; anchor on the scope separator.
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'd/')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'g/')
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)

sess = tf.Session()
# tf.initialize_all_variables is deprecated in favour of this call.
sess.run(tf.global_variables_initializer())

# Label convention: 0 = real, 1 = fake. The arrays never change, so build once.
labels1 = np.zeros([batch_size, 1])
labels2 = np.ones([batch_size, 1])

# ---- training ----
for i in range(iterations):
    zz = np.random.uniform(-1, 1, [batch_size, noise_dim])
    zzz = sess.run(logits_G1, feed_dict={z: zz})
    xx, _ = mnist.train.next_batch(batch_size)
    # The generator ends in tanh (outputs in [-1, 1]); rescale real images from
    # [0, 1] to the same range so D cannot separate them by value range alone.
    xx = xx * 2.0 - 1.0

    # Discriminator: one step on real images, one on generated images.
    sess.run(D_train_step, feed_dict={x: xx, labels_D: labels1})
    sess.run(D_train_step, feed_dict={x: zzz, labels_D: labels2})

    # Generator: push D to label generated images as real (0).
    sess.run(G_train_step, feed_dict={labels_G: labels1, z: zz})

    if i % 100 == 0:
        # -1 keeps this correct for any batch_size (indices below need >= 34).
        samples = zzz.reshape(-1, 28, 28)
        for panel, idx in enumerate((0, 11, 22, 33), start=1):
            plt.subplot(2, 2, panel)
            plt.imshow(samples[idx], cmap=plt.cm.gray)
        # With non-interactive backends this blocks until the window is closed.
        plt.show()
效果如下：
DCGAN：
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# discriminator (convolutional)
def discrimitive(x, scope=None):
    """Score a batch of flattened 28x28 images with a small CNN discriminator.

    Each input row is reshaped to 28x28x1. Two blocks of 5x5 conv (leaky
    ReLU) + 2x2 max-pool follow, with 32 then 64 filters. The conv layers use
    tf.layers' default 'valid' padding, so the spatial size goes
    28 -> 24 -> 12 -> 8 -> 4. The resulting 4*4*64 = 1024 features go through
    a 1024-unit leaky-ReLU dense layer and a 1-unit logit layer.

    Args:
        x: batch of flattened images (784 values per row, per the reshape).
        scope: variable scope name; variables are created or reused via
            tf.AUTO_REUSE so real and fake paths share weights.

    Returns:
        Raw (pre-sigmoid) logits, one per example.
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        features = tf.reshape(x, shape=[-1, 28, 28, 1])
        # Layers are created in the same order as the unrolled version
        # (conv, pool, conv, pool), keeping auto-generated names identical.
        for n_filters in (32, 64):
            features = tf.layers.conv2d(inputs=features, kernel_size=5, strides=1,
                                        filters=n_filters, activation=tf.nn.leaky_relu)
            features = tf.layers.max_pooling2d(inputs=features, pool_size=2, strides=2)
        flat = tf.reshape(features, shape=[-1, 4 * 4 * 64])
        hidden = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.leaky_relu)
        return tf.layers.dense(inputs=hidden, units=1)
# generator (DCGAN-style)
def generator(z, scope=None):
    """Map a batch of noise vectors to flattened 28x28 images.

    Pipeline:
      * two dense ReLU layers (1024 units, then 7*7*128 units), each followed
        by batch normalization;
      * reshape to 7x7x128;
      * transposed conv (64 filters, 4x4, stride 2, 'same', ReLU) to 14x14,
        then batch normalization;
      * transposed conv (1 filter, 4x4, stride 2, 'same', tanh) to 28x28.

    Batch normalization is always built with training=True, so batch
    statistics are used on every call.

    Args:
        z: batch of noise vectors.
        scope: variable scope name; variables are created or reused via
            tf.AUTO_REUSE.

    Returns:
        Generated images flattened to 784 values per example, in [-1, 1].
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        net = z
        # dense -> BN, dense -> BN: same creation order as the unrolled code.
        for width in (1024, 7 * 7 * 128):
            net = tf.layers.dense(inputs=net, units=width, activation=tf.nn.relu)
            net = tf.layers.batch_normalization(inputs=net, training=True)
        net = tf.reshape(net, shape=[-1, 7, 7, 128])
        net = tf.layers.conv2d_transpose(inputs=net, filters=64, kernel_size=4, strides=2,
                                         activation=tf.nn.relu, padding='same')
        net = tf.layers.batch_normalization(inputs=net, training=True)
        net = tf.layers.conv2d_transpose(inputs=net, filters=1, kernel_size=4, strides=2,
                                         activation=tf.nn.tanh, padding='same')
        return tf.reshape(net, shape=[-1, 784])
# ---- hyper-parameters ----
dim = 784          # flattened 28x28 image size
batch_size = 100
iterations = 3000
learning_rate = 1e-3
beta1 = 0.5
noise_dim = 1024   # length of each generator input vector

# read_data_sets (default float32 dtype) rescales pixels to [0, 1].
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

# ---- graph ----
x = tf.placeholder(dtype=tf.float32, shape=[batch_size, dim])
z = tf.placeholder(dtype=tf.float32, shape=[batch_size, noise_dim])

labels_G = tf.placeholder(dtype=tf.float32, shape=[batch_size, 1])
labels_D = tf.placeholder(dtype=tf.float32, shape=[batch_size, 1])
logits_D = discrimitive(x, 'd')
logits_G1 = generator(z, 'g')                  # generated images
logits_G = discrimitive(logits_G1, 'd')        # D's score of generated images (shared weights)

D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
# sigmoid_cross_entropy_with_logits is per-example; reduce to a scalar
# objective instead of relying on minimize() implicitly summing a tensor.
D_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_D, logits=logits_D))
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_G, logits=logits_G))
# get_collection filters names with re.match, so a bare 'd' would match any
# variable whose name merely starts with "d"; anchor on the scope separator.
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'd/')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'g/')
# NOTE: the generator's batch-norm layers are built with training=True, so
# batch statistics are always used; the UPDATE_OPS moving averages are never
# read and therefore are not wired into the train step.
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)

sess = tf.Session()
# tf.initialize_all_variables is deprecated in favour of this call.
sess.run(tf.global_variables_initializer())

# Label convention: 0 = real, 1 = fake. The arrays never change, so build once.
labels1 = np.zeros([batch_size, 1])
labels2 = np.ones([batch_size, 1])

# ---- training ----
for i in range(iterations):
    zz = np.random.uniform(-1, 1, [batch_size, noise_dim])
    zzz = sess.run(logits_G1, feed_dict={z: zz})
    xx, _ = mnist.train.next_batch(batch_size)
    # The generator ends in tanh (outputs in [-1, 1]); rescale real images from
    # [0, 1] to the same range so D cannot separate them by value range alone.
    xx = xx * 2.0 - 1.0

    # Discriminator: one step on real images, one on generated images.
    sess.run(D_train_step, feed_dict={x: xx, labels_D: labels1})
    sess.run(D_train_step, feed_dict={x: zzz, labels_D: labels2})

    # Generator: push D to label generated images as real (0).
    sess.run(G_train_step, feed_dict={labels_G: labels1, z: zz})

    if i % 10 == 0:
        # -1 keeps this correct for any batch_size (indices below need >= 34).
        samples = zzz.reshape(-1, 28, 28)
        for panel, idx in enumerate((0, 11, 22, 33), start=1):
            plt.subplot(2, 2, panel)
            plt.imshow(samples[idx], cmap=plt.cm.gray)
        # With non-interactive backends this blocks until the window is closed.
        plt.show()