参考链接:
https://blog.csdn.net/weixin_41923961/article/details/80686419
https://blog.csdn.net/qq_27898631/article/details/87607020
https://blog.csdn.net/Sakura55/article/details/81514828
https://blog.csdn.net/yunyi4367/article/details/80489962
tensorflow实现:https://github.com/hwalsuklee/tensorflow-generative-model-collections
keras实现:https://zhuanlan.zhihu.com/p/34139648
main.py
import os
## GAN Variants
from GAN import GAN
from CGAN import CGAN
from infoGAN import infoGAN
from ACGAN import ACGAN
from EBGAN import EBGAN
from WGAN import WGAN
from WGAN_GP import WGAN_GP
from DRAGAN import DRAGAN
from LSGAN import LSGAN
from BEGAN import BEGAN
## VAE Variants
from VAE import VAE
from CVAE import CVAE
from utils import show_all_variables
from utils import check_folder
import tensorflow as tf
import argparse
"""parsing and configuration"""
def parse_args():
    """Build the command-line parser, parse the arguments and validate them.

    Returns:
        The value returned by ``check_args`` for the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Tensorflow implementation of GAN collections")

    # Model and dataset selection.
    supported_models = ['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN',
                        'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE']
    parser.add_argument('--gan_type', type=str, default='CGAN',
                        choices=supported_models,
                        help='The type of GAN', required=False)
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'fashion-mnist', 'celebA'],
                        help='The name of dataset')

    # Integer hyper-parameters: (flag, default, help text).
    for flag, default, text in (
            ('--epoch', 20, 'The number of epochs to run'),
            ('--batch_size', 64, 'The size of batch'),
            ('--z_dim', 62, 'Dimension of noise vector')):
        parser.add_argument(flag, type=int, default=default, help=text)

    # Output directories: (flag, default, help text).
    for flag, default, text in (
            ('--checkpoint_dir', 'data/checkpoint', 'checkpoints'),
            ('--result_dir', 'data/results',
             'Directory name to save the generated images'),
            ('--log_dir', 'data/logs',
             'Directory name to save training logs')):
        parser.add_argument(flag, type=str, default=default, help=text)

    parsed = parser.parse_args()
    return check_args(parsed)
"""checking arguments"""
def check_args(args):
    """Create the output folders and sanity-check the numeric options.

    Args:
        args: parsed namespace with ``checkpoint_dir``, ``result_dir``,
            ``log_dir``, ``epoch``, ``batch_size`` and ``z_dim`` attributes.

    Returns:
        The same ``args`` object, unchanged.

    Raises:
        AssertionError: if ``epoch``, ``batch_size`` or ``z_dim`` is below 1.
    """
    # --checkpoint_dir, --result_dir, --log_dir: make sure each folder exists.
    for folder in (args.checkpoint_dir, args.result_dir, args.log_dir):
        check_folder(folder)

    # --epoch
    assert args.epoch >= 1, 'number of epochs must be larger than or equal to one'

    # --batch_size
    assert args.batch_size >= 1, 'batch size must be larger than or equal to one'

    # --z_dim
    assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one'

    return args
"""main"""
def main():
    """Parse the CLI options, build the selected model, train it and sample.

    Raises:
        Exception: if no model class has a ``model_name`` equal to
            ``--gan_type``.
    """
    args = parse_args()
    if args is None:
        exit()

    # Every model class exposes a ``model_name`` attribute used for lookup.
    models = [GAN, CGAN, infoGAN, ACGAN, EBGAN, WGAN, WGAN_GP, DRAGAN,
              LSGAN, BEGAN, VAE, CVAE]

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Constructor arguments shared by every model class.
        settings = dict(epoch=args.epoch,
                        batch_size=args.batch_size,
                        z_dim=args.z_dim,
                        dataset_name=args.dataset,
                        checkpoint_dir=args.checkpoint_dir,
                        result_dir=args.result_dir,
                        log_dir=args.log_dir)

        gan = None
        for model in models:
            if args.gan_type != model.model_name:
                continue
            gan = model(sess, **settings)

        if gan is None:
            raise Exception("[!] There is no option for " + args.gan_type)

        # Build the graph, then print the network's variables.
        gan.build_model()
        show_all_variables()

        # Run training inside the session.
        gan.train()
        print(" [*] Training finished!")

        # Sample from the learned generator, labelled with the last epoch.
        gan.visualize_results(args.epoch - 1)
        print(" [*] Testing finished!")


if __name__ == '__main__':
    main()
utils.py
"""
Most of this code is taken from https://github.com/carpedm20/DCGAN-tensorflow
"""
from __future__ import division
import math
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime, strftime
from six.moves import xrange
import matplotlib.pyplot as plt
import os, gzip
import tensorflow as tf
import tensorflow.contrib.slim as slim
def load_mnist(dataset_name):
    """Load the gzip-compressed idx files of an MNIST-style dataset.

    The four files ``train-images-idx3-ubyte.gz``,
    ``train-labels-idx1-ubyte.gz``, ``t10k-images-idx3-ubyte.gz`` and
    ``t10k-labels-idx1-ubyte.gz`` are read from ``./data/<dataset_name>``.
    The 60000 train and 10000 test samples are concatenated and shuffled
    with a fixed seed.

    Args:
        dataset_name: sub-directory of ``./data`` holding the files.

    Returns:
        ``(X, y_vec)``: ``X`` is a float array of shape (70000, 28, 28, 1)
        scaled to [0, 1]; ``y_vec`` is the matching (70000, 10) one-hot
        label matrix.

    Note:
        Calls ``np.random.seed`` and therefore resets NumPy's global RNG.
    """
    data_dir = os.path.join("./data", dataset_name)

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the idx header, then read the raw unsigned bytes.
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
        # ``np.float`` was only an alias of the builtin ``float`` and has
        # been removed from NumPy; ``np.float64`` is the same dtype.
        return np.frombuffer(buf, dtype=np.uint8).astype(np.float64)

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    X = np.concatenate((trX, teX), axis=0)
    # The builtin ``int`` replaces the removed ``np.int`` alias.
    y = np.concatenate((trY, teY), axis=0).astype(int)

    # Re-seeding before each shuffle applies the same permutation to the
    # images and to the labels, which keeps the pairs aligned.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # One-hot encode the labels in a single vectorised assignment.
    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1.0

    return X / 255., y_vec
def check_folder(log_dir):
    """Make sure ``log_dir`` exists, creating missing parents, and return it."""
    if os.path.exists(log_dir):
        return log_dir
    os.makedirs(log_dir)
    return log_dir
def show_all_variables():
    """Print slim's analysis of the variables from ``tf.trainable_variables()``."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False):
    """Read the image at ``image_path`` and pass it through ``transform``."""
    return transform(imread(image_path, grayscale), input_height, input_width,
                     resize_height, resize_width, crop)
def save_images(images, size, image_path):
    """Undo the [-1, 1] scaling of ``images`` and save them as one tiled image.

    Args:
        images: batch of images, forwarded to ``inverse_transform``.
        size: ``[rows, cols]`` of the grid, forwarded to ``imsave``.
        image_path: destination file path.

    Returns:
        Whatever ``imsave`` returns.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
def imread(path, grayscale=False):
    """Read an image file into a float array.

    Args:
        path: image file to read.
        grayscale: when true, the image is flattened to a single channel
            (``flatten=True``).

    Returns:
        The pixel data converted to the builtin ``float`` dtype.
    """
    # NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and is
    # absent from later releases; this needs an older SciPy to run.
    # ``float`` replaces the ``np.float`` alias that NumPy has removed.
    if grayscale:
        return scipy.misc.imread(path, flatten=True).astype(float)
    return scipy.misc.imread(path).astype(float)
def merge_images(images, size):
    """Return ``inverse_transform(images)``; ``size`` is accepted but unused."""
    rescaled = inverse_transform(images)
    return rescaled
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Images are placed row by row: image ``idx`` goes to grid row
    ``idx // size[1]`` and grid column ``idx % size[1]``.

    Args:
        images: array of shape (N, H, W, C) with C in {1, 3, 4}, or
            (N, H, W) for channel-less images.
        size: ``[rows, cols]`` of the grid.

    Returns:
        A float array of shape (H * rows, W * cols, C) for C in {3, 4}, or
        (H * rows, W * cols) for single-channel / channel-less input.

    Raises:
        ValueError: if the input has an unsupported rank or channel count.
    """
    images = np.asarray(images)
    # The error message below advertises channel-less images; handle them
    # by adding a unit channel axis instead of failing on ``shape[3]``.
    if images.ndim == 3:
        images = images[..., np.newaxis]
    if images.ndim != 4 or images.shape[3] not in (1, 3, 4):
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

    h, w, c = images.shape[1], images.shape[2], images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
        i = idx % size[1]   # grid column
        j = idx // size[1]  # grid row
        img[j * h:j * h + h, i * w:i * w + w, :] = image

    # Single-channel grids are returned as 2-D arrays.
    return img[:, :, 0] if c == 1 else img
def imsave(images, size, path):
    """Tile ``images`` with ``merge`` and write the result to ``path``.

    Returns:
        Whatever ``scipy.misc.imsave`` returns.
    """
    grid = merge(images, size)
    return scipy.misc.imsave(path, np.squeeze(grid))
def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Cut a centred window out of ``x`` and resize it.

    Args:
        x: image array; its first two axes are used as height and width.
        crop_h: height of the window.
        crop_w: width of the window; ``None`` means "same as ``crop_h``".
        resize_h: output height.
        resize_w: output width.

    Returns:
        The result of ``scipy.misc.imresize`` on the cropped window.
    """
    crop_w = crop_h if crop_w is None else crop_w
    height, width = x.shape[:2]
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - crop_w) / 2.))
    window = x[top:top + crop_h, left:left + crop_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])
def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True):
    """Resize an image (optionally centre-cropping first) and rescale it.

    Returns:
        The resized image divided by 127.5 and shifted down by 1.
    """
    target = [resize_height, resize_width]
    if not crop:
        resized = scipy.misc.imresize(image, target)
    else:
        resized = center_crop(image, input_height, input_width,
                              resize_height, resize_width)
    return np.array(resized) / 127.5 - 1.
def inverse_transform(images):
    """Map values linearly so that -1 becomes 0 and 1 becomes 1."""
    shifted = images + 1.
    return shifted / 2.
""" Drawing Tools """
# borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
    """Scatter-plot 2-D points coloured by class and save the figure.

    Args:
        z: array whose columns 0 and 1 are the x / y coordinates.
        id: per-point score matrix; the colour is ``argmax`` along axis 1.
        z_range_x: the x axis is limited to ``[-z_range_x, z_range_x]``.
        z_range_y: the y axis is limited to ``[-z_range_y, z_range_y]``.
        name: output file name handed to ``plt.savefig``.
    """
    N = 10  # number of colour bins / colour-bar ticks
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
        plt.colorbar(ticks=range(N))
        axes = plt.gca()
        axes.set_xlim([-z_range_x, z_range_x])
        axes.set_ylim([-z_range_y, z_range_y])
        plt.grid(True)
        plt.savefig(name)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each call would leave one more open figure behind.
        plt.close(fig)
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of colour bins.
        base_cmap: colormap name, colormap instance, or ``None``.

    Returns:
        A ``LinearSegmentedColormap`` with ``N`` entries sampled evenly
        from the base colormap.
    """
    # Imported locally so the block does not depend on a new file-level
    # import.
    from matplotlib.colors import LinearSegmentedColormap

    # ``plt.get_cmap`` accepts a string, None, or a colormap instance.
    # The former ``plt.cm.get_cmap`` spelling is gone from recent Matplotlib.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Calling ``from_list`` on the class keeps this working for base maps
    # whose type has no ``from_list`` attribute (e.g. listed colormaps).
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
ops.py
"""
Most of this code is taken from https://github.com/carpedm20/DCGAN-tensorflow
"""
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from utils import *
# Pick the concatenation primitive once: ``tf.concat_v2`` when the installed
# TensorFlow exposes it, otherwise ``tf.concat``.
_tf_concat = tf.concat_v2 if "concat_v2" in dir(tf) else tf.concat


def concat(tensors, axis, *args, **kwargs):
    """Concatenate ``tensors`` along ``axis`` with the selected TF primitive."""
    return _tf_concat(tensors, axis, *args, **kwargs)
def bn(x, is_training, scope):
    """Apply ``tf.contrib.layers.batch_norm`` with this project's settings.

    Args:
        x: input tensor.
        is_training: forwarded as ``is_training``.
        scope: variable scope name for the layer.
    """
    settings = dict(decay=0.9,
                    updates_collections=None,
                    epsilon=1e-5,
                    scale=True,
                    is_training=is_training,
                    scope=scope)
    return tf.contrib.layers.batch_norm(x, **settings)
def conv_out_size_same(size, stride):
    """Return ``ceil(size / stride)`` as an int."""
    ratio = float(size) / float(stride)
    return int(math.ceil(ratio))
def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_dims = x.get_shape()
    y_dims = y.get_shape()
    # Multiply y by a ones tensor with x's first three dims so it
    # broadcasts to a matching shape before concatenating on axis 3.
    tiled = y * tf.ones([x_dims[0], x_dims[1], x_dims[2], y_dims[3]])
    return concat([x, tiled], 3)
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    """2-D convolution with 'SAME' padding plus a bias.

    Args:
        input_: input tensor; its last dimension is the input channel count.
        output_dim: number of output channels.
        k_h, k_w: kernel height / width.
        d_h, d_w: stride along height / width.
        stddev: stddev of the truncated-normal kernel initializer.
        name: variable scope holding the ``w`` and ``biases`` variables.

    Returns:
        The biased convolution output, reshaped to the convolution's shape.
    """
    with tf.variable_scope(name):
        kernel_shape = [k_h, k_w, input_.get_shape()[-1], output_dim]
        w = tf.get_variable(
            'w', kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
        biases = tf.get_variable(
            'biases', [output_dim],
            initializer=tf.constant_initializer(0.0))
        return tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False):
    """Transposed 2-D convolution plus a bias.

    Args:
        input_: input tensor; its last dimension is the input channel count.
        output_shape: full output shape; its last entry is the output
            channel count.
        k_h, k_w: kernel height / width.
        d_h, d_w: stride along height / width.
        name: variable scope holding the ``w`` and ``biases`` variables.
        stddev: stddev of the random-normal kernel initializer.
        with_w: when true, also return the kernel and bias variables.

    Returns:
        ``deconv`` or, if ``with_w`` is true, ``(deconv, w, biases)``.
    """
    strides = [1, d_h, d_w, 1]
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        filter_shape = [k_h, k_w, output_shape[-1], input_.get_shape()[-1]]
        w = tf.get_variable(
            'w', filter_shape,
            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=strides)
        except AttributeError:
            # Support for versions of TensorFlow before 0.7.0
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=strides)

        biases = tf.get_variable(
            'biases', [output_shape[-1]],
            initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        return (deconv, w, biases) if with_w else deconv
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: element-wise ``max(x, leak * x)``; ``name`` is unused."""
    leaked = leak * x
    return tf.maximum(x, leaked)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer: ``input_ @ Matrix + bias``.

    Args:
        input_: tensor whose dimension 1 is the input feature count.
        output_size: number of output features.
        scope: variable scope name; falls back to ``"Linear"``.
        stddev: stddev of the random-normal weight initializer.
        bias_start: constant initial value of the bias.
        with_w: when true, also return the weight and bias variables.

    Returns:
        The layer output or, if ``with_w`` is true,
        ``(output, matrix, bias)``.
    """
    in_features = input_.get_shape().as_list()[1]

    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable(
            "Matrix", [in_features, output_size], tf.float32,
            tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable(
            "bias", [output_size],
            initializer=tf.constant_initializer(bias_start))
        output = tf.matmul(input_, matrix) + bias
        if not with_w:
            return output
        return output, matrix, bias
prior_factory.py
"""
Most of this code is taken from https://github.com/musyoku/adversarial-autoencoder/blob/master/sampler.py
"""
import numpy as np
from math import sin,cos,sqrt
def onehot_categorical(batch_size, n_labels):
    """Draw ``batch_size`` random one-hot rows over ``n_labels`` classes.

    Returns:
        A float32 array of shape (batch_size, n_labels) with a single 1 per
        row, at a column drawn with ``np.random.randint``.
    """
    onehot = np.zeros((batch_size, n_labels), dtype=np.float32)
    chosen = np.random.randint(0, n_labels, batch_size)
    # Set one entry per row in a single fancy-indexed assignment.
    onehot[np.arange(batch_size), chosen] = 1
    return onehot
def uniform(batch_size, n_dim, n_labels=10, minv=-1, maxv=1, label_indices=None):
    """Sample points uniformly, optionally confined to a per-label grid cell.

    Without ``label_indices`` every coordinate is drawn from
    ``U(minv, maxv)``. With ``label_indices`` the square
    ``[minv, maxv]^2`` is split into a ``num x num`` grid, where
    ``num = ceil(sqrt(n_labels))``, and sample ``b`` is drawn uniformly
    inside the cell at row ``label // num`` and column ``label % num``.

    Args:
        batch_size: number of samples.
        n_dim: dimensionality of each sample (must be 2 with labels).
        n_labels: number of labels used to size the grid.
        minv, maxv: bounds of the sampled range.
        label_indices: optional per-sample labels, indexed ``0..batch_size-1``.

    Returns:
        A float32 array of shape (batch_size, n_dim).

    Raises:
        Exception: if ``label_indices`` is given and ``n_dim`` is not 2.
    """
    if label_indices is None:
        return np.random.uniform(minv, maxv, (batch_size, n_dim)).astype(np.float32)

    if n_dim != 2:
        raise Exception("n_dim must be 2.")

    def sample(label, n_labels):
        num = int(np.ceil(np.sqrt(n_labels)))
        size = (maxv - minv) * 1.0 / num
        x, y = np.random.uniform(-size / 2, size / 2, (2,))
        # Floor division is required: "/" yields a fractional row on
        # Python 3 and moves the sample out of its label's cell.
        i = label // num
        j = label % num
        x += j * size + minv + 0.5 * size
        y += i * size + minv + 0.5 * size
        return np.array([x, y]).reshape((2,))

    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for batch in range(batch_size):
        for zi in range(n_dim // 2):
            z[batch, zi * 2:zi * 2 + 2] = sample(label_indices[batch], n_labels)
    return z
def gaussian(batch_size, n_dim, mean=0, var=1, n_labels=10, use_label_info=False):
    """Sample from a normal distribution, optionally labelling by angle.

    ``var`` is handed to ``np.random.normal`` as its scale argument.

    Args:
        batch_size: number of samples.
        n_dim: dimensionality of each sample (must be 2 with label info).
        mean: centre of the distribution.
        var: scale passed to ``np.random.normal``.
        n_labels: number of angular sectors used for labelling.
        use_label_info: when true, also return a sector label per sample.

    Returns:
        ``z`` (float32, shape (batch_size, n_dim)) or, with
        ``use_label_info``, ``(z, z_id)`` where ``z_id`` is an int32 array
        of shape (batch_size, 1).

    Raises:
        Exception: if ``use_label_info`` is true and ``n_dim`` is not 2.
    """
    if not use_label_info:
        return np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32)

    if n_dim != 2:
        raise Exception("n_dim must be 2.")

    def draw_point(num_labels):
        px, py = np.random.normal(mean, var, (2,))
        # Angle of the point around the mean, in degrees.
        degrees = np.angle((px - mean) + 1j * (py - mean), deg=True)
        sector = int(num_labels * degrees) // 360
        if sector < 0:
            sector += num_labels
        return np.array([px, py]).reshape((2,)), sector

    z = np.empty((batch_size, n_dim), dtype=np.float32)
    z_id = np.empty((batch_size, 1), dtype=np.int32)
    for row in range(batch_size):
        for pair in range(int(n_dim / 2)):
            point, sector = draw_point(n_labels)
            z[row, pair * 2:pair * 2 + 2] = point
            z_id[row] = sector
    return z, z_id
def gaussian_mixture(batch_size, n_dim=2, n_labels=10, x_var=0.5, y_var=0.1, label_indices=None):
    """Sample from Gaussian blobs arranged around a circle.

    Each sample is a normal draw that is rotated by the angle
    ``2*pi*label/n_labels`` and then shifted by 1.4 along that direction.

    Args:
        batch_size: number of samples.
        n_dim: must be 2.
        n_labels: number of blobs.
        x_var, y_var: scales passed to ``np.random.normal`` for the two axes.
        label_indices: optional per-sample blob index; a random index is
            drawn per sample when omitted.

    Returns:
        A float32 array of shape (batch_size, 2).

    Raises:
        Exception: if ``n_dim`` is not 2.
    """
    if n_dim != 2:
        raise Exception("n_dim must be 2.")

    def place(px, py, label, num_labels):
        shift = 1.4
        theta = 2.0 * np.pi / float(num_labels) * float(label)
        # Rotate by theta, then push outwards along theta.
        rx = px * cos(theta) - py * sin(theta)
        ry = px * sin(theta) + py * cos(theta)
        rx += shift * cos(theta)
        ry += shift * sin(theta)
        return np.array([rx, ry]).reshape((2,))

    n_pairs = int(n_dim / 2)
    x = np.random.normal(0, x_var, (batch_size, n_pairs))
    y = np.random.normal(0, y_var, (batch_size, n_pairs))
    z = np.empty((batch_size, n_dim), dtype=np.float32)
    for row in range(batch_size):
        for pair in range(n_pairs):
            if label_indices is None:
                label = np.random.randint(0, n_labels)
            else:
                label = label_indices[row]
            z[row, pair * 2:pair * 2 + 2] = place(x[row, pair], y[row, pair], label, n_labels)
    return z
def swiss_roll(batch_size, n_dim=2, n_labels=10, label_indices=None):
    """Sample 2-D points along a spiral, one segment per label.

    For each sample a value ``u`` is drawn uniformly from
    ``[label/n_labels, (label+1)/n_labels)``; the point lies at radius
    ``3*sqrt(u)`` and angle ``4*pi*sqrt(u)``.

    Args:
        batch_size: number of samples.
        n_dim: must be 2.
        n_labels: number of spiral segments.
        label_indices: optional per-sample segment index; a random index is
            drawn per sample when omitted.

    Returns:
        A float32 array of shape (batch_size, 2).

    Raises:
        Exception: if ``n_dim`` is not 2.
    """
    if n_dim != 2:
        raise Exception("n_dim must be 2.")

    def point_on_roll(label, num_labels):
        u = np.random.uniform(0.0, 1.0) / float(num_labels) + float(label) / float(num_labels)
        radius = sqrt(u) * 3.0
        angle = np.pi * 4.0 * sqrt(u)
        return np.array([radius * cos(angle), radius * sin(angle)]).reshape((2,))

    z = np.zeros((batch_size, n_dim), dtype=np.float32)
    for row in range(batch_size):
        for pair in range(int(n_dim / 2)):
            if label_indices is None:
                label = np.random.randint(0, n_labels)
            else:
                label = label_indices[row]
            z[row, pair * 2:pair * 2 + 2] = point_on_roll(label, n_labels)
    return z
GAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class GAN(object):
    """GAN with a sigmoid cross-entropy loss for 28x28x1 images.

    Supports the ``mnist`` and ``fashion-mnist`` datasets. The class builds
    the TensorFlow graph (``build_model``), trains it with periodic sample
    dumps and checkpoints (``train``), and renders generator samples
    (``visualize_results``).
    """

    model_name = "GAN"  # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            epoch: number of epochs to train.
            batch_size: fixed batch size baked into the graph.
            z_dim: dimension of the noise vector.
            dataset_name: ``'mnist'`` or ``'fashion-mnist'``.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summaries.

        Raises:
            NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim  # dimension of noise-vector
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator.

        Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        (the source credits infoGAN, https://arxiv.org/abs/1606.03657).

        Returns:
            ``(out, out_logit, net)``: the sigmoid output, its logit and
            the 1024-unit feature layer.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator.

        Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        (the source credits infoGAN, https://arxiv.org/abs/1606.03657).

        Returns:
            A ``[batch_size, 28, 28, 1]`` tensor with sigmoid activation.
        """
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, test op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # --- Graph Input ---
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # --- Loss Function ---
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        # --- Training ---
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers (the generator uses 5x the discriminator's rate)
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # --- Testing ---
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # --- Summary ---
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def _result_path(self, filename):
        """Return the path of ``filename`` inside this model's result folder.

        The folder ``<result_dir>/<model_dir>`` is created on demand.
        """
        folder = check_folder(os.path.join(self.result_dir, self.model_dir))
        return os.path.join(folder, filename)

    def train(self):
        """Run the training loop, resuming from a checkpoint if one exists."""
        # initialize all variables
        tf.global_variables_initializer().run()

        # fixed noise used to visualize training progress
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                       feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    # The path is built without a './' prefix so that an
                    # absolute result_dir keeps pointing at the folder
                    # that check_folder created.
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                self._result_path(
                                    self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx)))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Generate samples from random noise and save them as one image.

        Args:
            epoch: epoch number embedded in the output file name.
        """
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Folder name encoding model, dataset, batch size and noise size."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` below ``checkpoint_dir``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    @staticmethod
    def _parse_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int."""
        import re
        # Raw string: "\d" in a normal string literal is an invalid escape.
        return int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if any.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step``
            is parsed from the checkpoint file name; ``(False, 0)``
            otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._parse_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
ACGAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class ACGAN(object):
model_name = "ACGAN" # name for checkpoint
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
    """Store the run configuration and load the training data.

    Args:
        sess: TensorFlow session used for every ``run`` call.
        epoch: number of epochs to train.
        batch_size: fixed batch size baked into the graph.
        z_dim: dimension of the noise vector.
        dataset_name: ``'mnist'`` or ``'fashion-mnist'``.
        checkpoint_dir: root directory for checkpoints.
        result_dir: root directory for generated images.
        log_dir: root directory for summaries.

    Raises:
        NotImplementedError: for any other ``dataset_name``.
    """
    self.sess = sess
    self.dataset_name = dataset_name
    self.checkpoint_dir = checkpoint_dir
    self.result_dir = result_dir
    self.log_dir = log_dir
    self.epoch = epoch
    self.batch_size = batch_size

    # Only the 28x28 single-channel datasets are implemented.
    if dataset_name not in ('mnist', 'fashion-mnist'):
        raise NotImplementedError

    # image geometry
    self.input_height = 28
    self.input_width = 28
    self.output_height = 28
    self.output_width = 28

    self.z_dim = z_dim  # dimension of noise-vector
    self.y_dim = 10     # dimension of code-vector (label)
    self.c_dim = 1

    # optimizer settings
    self.learning_rate = 0.0002
    self.beta1 = 0.5

    # number of generated images to be saved at test time
    self.sample_num = 64

    # code sizes
    self.len_discrete_code = 10    # categorical distribution (i.e. label)
    self.len_continuous_code = 2   # gaussian distribution (e.g. rotation, thickness)

    # training data and number of batches in a single epoch
    self.data_X, self.data_y = load_mnist(self.dataset_name)
    self.num_batches = len(self.data_X) // self.batch_size
def classifier(self, x, is_training=True, reuse=False):
    """Build the classification head on top of discriminator features.

    Layers: FC128 + batch norm + leaky ReLU, then FC ``y_dim`` + softmax.

    Args:
        x: feature tensor (``build_model`` passes the discriminator's
            last hidden layer).
        is_training: forwarded to batch norm.
        reuse: whether to reuse the ``classifier`` variable scope.

    Returns:
        ``(out, out_logit)``: class probabilities and their logits.
    """
    with tf.variable_scope("classifier", reuse=reuse):
        hidden = linear(x, 128, scope='c_fc1')
        hidden = lrelu(bn(hidden, is_training=is_training, scope='c_bn1'))
        out_logit = linear(hidden, self.y_dim, scope='c_fc2')
        return tf.nn.softmax(out_logit), out_logit
def discriminator(self, x, is_training=True, reuse=False):
    """Build the discriminator.

    Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    (the source credits infoGAN, https://arxiv.org/abs/1606.03657).

    Returns:
        ``(out, out_logit, net)``: the sigmoid output, its logit and the
        1024-unit feature layer.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        conv1 = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))

        conv2 = conv2d(conv1, 128, 4, 4, 2, 2, name='d_conv2')
        conv2 = lrelu(bn(conv2, is_training=is_training, scope='d_bn2'))

        flat = tf.reshape(conv2, [self.batch_size, -1])
        net = lrelu(bn(linear(flat, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))

        out_logit = linear(net, 1, scope='d_fc4')
        return tf.nn.sigmoid(out_logit), out_logit, net
def generator(self, z, y, is_training=True, reuse=False):
    """Build the conditional generator.

    Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    (the source credits infoGAN, https://arxiv.org/abs/1606.03657).

    Args:
        z: noise tensor.
        y: code (label) tensor, concatenated to ``z`` along axis 1.
        is_training: forwarded to batch norm.
        reuse: whether to reuse the ``generator`` variable scope.

    Returns:
        A ``[batch_size, 28, 28, 1]`` tensor with sigmoid activation.
    """
    with tf.variable_scope("generator", reuse=reuse):
        # merge noise and code
        merged = concat([z, y], 1)

        fc1 = tf.nn.relu(bn(linear(merged, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
        fc2 = tf.nn.relu(bn(linear(fc1, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
        maps = tf.reshape(fc2, [self.batch_size, 7, 7, 128])

        up1 = deconv2d(maps, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3')
        up1 = tf.nn.relu(bn(up1, is_training=is_training, scope='g_bn3'))

        up2 = deconv2d(up1, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4')
        return tf.nn.sigmoid(up2)
def build_model(self):
    """Create placeholders, losses, optimizers, test op and summaries.

    Three objectives are built: the discriminator loss, the generator
    loss, and the classification ("information") loss, which is the
    softmax cross-entropy of the classifier on both real and generated
    images against the fed labels.
    """
    # some parameters
    image_dims = [self.input_height, self.input_width, self.c_dim]
    bs = self.batch_size

    # --- Graph Input ---
    # images
    self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

    # labels
    self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')

    # noises
    self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

    # --- Loss Function ---
    ## 1. GAN Loss
    # output of D for real images
    D_real, D_real_logits, input4classifier_real = self.discriminator(self.inputs, is_training=True, reuse=False)

    # output of D for fake images
    G = self.generator(self.z, self.y, is_training=True, reuse=False)
    D_fake, D_fake_logits, input4classifier_fake = self.discriminator(G, is_training=True, reuse=True)

    # get loss for discriminator
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

    self.d_loss = d_loss_real + d_loss_fake

    # get loss for generator
    self.g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

    ## 2. Information Loss
    code_fake, code_logit_fake = self.classifier(input4classifier_fake, is_training=True, reuse=False)
    code_real, code_logit_real = self.classifier(input4classifier_real, is_training=True, reuse=True)

    # For real samples
    q_real_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_real, labels=self.y))

    # For fake samples
    q_fake_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=code_logit_fake, labels=self.y))

    # get information loss
    self.q_loss = q_fake_loss + q_real_loss

    # --- Training ---
    # divide trainable variables into a group for D and a group for G;
    # the classification loss updates D, classifier and G variables.
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]
    q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)]

    # optimizers
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
            .minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
            .minimize(self.g_loss, var_list=g_vars)
        self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
            .minimize(self.q_loss, var_list=q_vars)

    # --- Testing ---
    # for test
    self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

    # --- Summary ---
    d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
    d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
    d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
    g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

    # The classification loss gets its own tag; it was previously
    # registered as "g_loss", which clashed with the generator summary.
    q_loss_sum = tf.summary.scalar("q_loss", self.q_loss)
    q_real_sum = tf.summary.scalar("q_real_loss", q_real_loss)
    q_fake_sum = tf.summary.scalar("q_fake_loss", q_fake_loss)

    # final summary operations
    self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
    self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
    self.q_sum = tf.summary.merge([q_loss_sum, q_real_sum, q_fake_sum])
def train(self):
    """Run the training loop, resuming from a checkpoint if one exists.

    Each step updates the discriminator first, then the generator and the
    classifier in one ``sess.run`` call. Every 300 steps a grid of samples
    is written to ``<result_dir>/<model_dir>``; after each epoch a
    checkpoint is saved and ``visualize_results`` is called.
    """
    # initialize all variables
    tf.global_variables_initializer().run()

    # fixed noise / labels used to visualize training progress
    self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
    self.test_codes = self.data_y[0:self.batch_size]

    # saver to save model
    self.saver = tf.train.Saver()

    # summary writer
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

    # restore check-point if it exists
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        start_epoch = (int)(checkpoint_counter / self.num_batches)
        start_batch_id = checkpoint_counter - start_epoch * self.num_batches
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    # loop for epoch
    start_time = time.time()
    for epoch in range(start_epoch, self.epoch):

        # get batch data
        for idx in range(start_batch_id, self.num_batches):
            batch_images = self.data_X[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_codes = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]

            batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

            # update D network
            _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                   feed_dict={self.inputs: batch_images, self.y: batch_codes,
                                                              self.z: batch_z})
            self.writer.add_summary(summary_str, counter)

            # update G & Q network
            _, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run(
                [self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss],
                feed_dict={self.z: batch_z, self.y: batch_codes, self.inputs: batch_images})
            self.writer.add_summary(summary_str_g, counter)
            self.writer.add_summary(summary_str_q, counter)

            # display training status
            counter += 1
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

            # save training results for every 300 steps
            if np.mod(counter, 300) == 0:
                samples = self.sess.run(self.fake_images,
                                        feed_dict={self.z: self.sample_z, self.y: self.test_codes})
                tot_num_samples = min(self.sample_num, self.batch_size)
                manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                # The path is built without a './' prefix so that an
                # absolute result_dir keeps pointing at the folder that
                # check_folder created.
                out_dir = check_folder(os.path.join(self.result_dir, self.model_dir))
                out_name = self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, idx)
                save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                            os.path.join(out_dir, out_name))

        # After an epoch, start_batch_id is set to zero
        # non-zero value is only for the first epoch after loading pre-trained model
        start_batch_id = 0

        # save model
        self.save(self.checkpoint_dir, counter)

        # show temporal results
        self.visualize_results(epoch)

    # save model for final step
    self.save(self.checkpoint_dir, counter)
def visualize_results(self, epoch):
    """Generate sample grids for ``epoch`` and write them as PNG files.

    Three kinds of image are written:
      * one grid produced from randomly drawn discrete codes,
      * one grid per discrete code value (all rows share the same noise),
      * one merged grid with a row per picked noise vector and a column per
        discrete code value, to check style consistency.

    Args:
        epoch: epoch index; it is embedded in every output file name.
    """
    grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))
    n_shown = grid_side * grid_side
    grid_shape = [grid_side, grid_side]

    def out_path(suffix):
        # check_folder (from utils) is invoked once per written image.
        return (check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                + '_epoch%03d' % epoch + suffix)

    def generate(labels):
        # One-hot encode the integer labels into the full y_dim-wide code vector.
        codes = np.zeros((self.batch_size, self.y_dim))
        codes[np.arange(self.batch_size), labels] = 1
        return self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: codes})

    # The same noise batch is reused for every image written by this call.
    z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

    # random noise, random discrete code, fixed continuous code
    random_labels = np.random.choice(self.len_discrete_code, self.batch_size)
    images = generate(random_labels)
    save_images(images[:n_shown, :, :, :], grid_shape, out_path('_test_all_classes.png'))

    # specified condition, random noise
    n_styles = 10  # must be less than or equal to self.batch_size
    np.random.seed()
    picked = np.random.choice(self.batch_size, n_styles)

    per_class = []
    for code_value in range(self.len_discrete_code):
        fixed_labels = np.zeros(self.batch_size, dtype=np.int64) + code_value
        images = generate(fixed_labels)
        save_images(images[:n_shown, :, :, :], grid_shape,
                    out_path('_test_class_%d.png' % code_value))
        per_class.append(images[picked, :, :, :])
    all_samples = np.concatenate(per_class, axis=0)

    # save merged images to check style-consistency:
    # all_samples is ordered class-major (class * n_styles + style); swapping the
    # two leading axes yields the style-major order (style * n_classes + class).
    class_major = all_samples.reshape((self.len_discrete_code, n_styles) + all_samples.shape[1:])
    canvas = class_major.swapaxes(0, 1).reshape(all_samples.shape)
    save_images(canvas, [n_styles, self.len_discrete_code],
                out_path('_test_all_classes_style_by_style.png'))
@property
def model_dir(self):
    """Return the run directory name ``<model_name>_<dataset_name>_<batch_size>_<z_dim>``."""
    fields = (self.model_name, self.dataset_name, self.batch_size, self.z_dim)
    return "_".join("{}".format(field) for field in fields)
def save(self, checkpoint_dir, step):
    """Write a checkpoint of the current session.

    The checkpoint goes to ``<checkpoint_dir>/<model_dir>/<model_name>/`` with
    the file prefix ``<model_name>.model``; the directory is created when it
    does not exist yet.

    Args:
        checkpoint_dir: root directory for checkpoints.
        step: value passed to the saver as ``global_step``.
    """
    target_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
    already_there = os.path.exists(target_dir)
    if not already_there:
        os.makedirs(target_dir)
    prefix = os.path.join(target_dir, self.model_name + '.model')
    self.saver.save(self.sess, prefix, global_step=step)
def load(self, checkpoint_dir):
    """Restore the latest checkpoint, if one exists.

    Looks in ``<checkpoint_dir>/<model_dir>/<model_name>/``.

    Args:
        checkpoint_dir: root directory for checkpoints.

    Returns:
        ``(True, step)`` when a checkpoint was restored, where ``step`` is the
        last run of digits in the checkpoint file name; ``(False, 0)`` otherwise.
    """
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # Raw string: "\d" in a normal literal is an invalid escape sequence
        # (warning on recent Python 3). The pattern itself is unchanged: the
        # last run of digits in the name, i.e. the global step suffix.
        counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
        print(" [*] Success to read {}".format(ckpt_name))
        return True, counter
    else:
        print(" [*] Failed to find a checkpoint")
        return False, 0
BEGAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class BEGAN(object):
    """BEGAN model for 28x28 single-channel images (mnist / fashion-mnist).

    The discriminator is an auto-encoder scored by its reconstruction error.
    The non-trainable variable ``k`` weights the fake-image term of the
    discriminator loss and is updated after every training step.
    """

    model_name = "BEGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: session used for every ``run`` call of this model.
            epoch: number of epochs to train.
            batch_size: number of samples per training step.
            z_dim: dimension of the noise vector.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # BEGAN Parameter
            self.gamma = 0.75
            self.lamda = 0.001

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the auto-encoder discriminator.

        Args:
            x: image batch fed to the auto-encoder.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "discriminator" scope.

        Returns:
            ``(out, recon_error, code)``: the reconstructed images, the scalar
            reconstruction error and the 32-unit bottleneck activation.
        """
        # It must be Auto-Encoder style architecture
        # Architecture : (64)4c2s-FC32_BR-FC64*14*14_BR-(1)4dc2s_S
        with tf.variable_scope("discriminator", reuse=reuse):
            net = tf.nn.relu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = tf.reshape(net, [self.batch_size, -1])
            code = tf.nn.relu(bn(linear(net, 32, scope='d_fc6'), is_training=is_training, scope='d_bn6'))
            net = tf.nn.relu(bn(linear(code, 64 * 14 * 14, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            net = tf.reshape(net, [self.batch_size, 14, 14, 64])
            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='d_dc5'))

            # recon loss: L2 norm of the residual divided by the batch size
            recon_error = tf.sqrt(2 * tf.nn.l2_loss(out - x)) / self.batch_size
            return out, recon_error, code

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator mapping noise ``z`` to 28x28x1 images.

        Args:
            z: noise batch.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "generator" scope.

        Returns:
            The sigmoid-activated image batch.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, the test op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # BEGAN variable
        self.k = tf.Variable(0., trainable=False)

        # Graph Input
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # Loss Function
        # output of D for real images
        D_real_img, D_real_err, D_real_code = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake_img, D_fake_err, D_fake_code = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        self.d_loss = D_real_err - self.k*D_fake_err

        # get loss for generator
        self.g_loss = D_fake_err

        # convergence metric
        self.M = D_real_err + tf.abs(self.gamma*D_real_err - D_fake_err)

        # operation for updating k (kept inside [0, 1])
        self.update_k = self.k.assign(
            tf.clip_by_value(self.k + self.lamda*(self.gamma*D_real_err - D_fake_err), 0, 1))

        # Training
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                      .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.g_loss, var_list=g_vars)

        # Testing
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # Summary
        d_loss_real_sum = tf.summary.scalar("d_error_real", D_real_err)
        d_loss_fake_sum = tf.summary.scalar("d_error_fake", D_fake_err)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        M_sum = tf.summary.scalar("M", self.M)
        k_sum = tf.summary.scalar("k", self.k)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
        self.p_sum = tf.summary.merge([M_sum, k_sum])

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        The counter starts at 1 and is incremented after every step, so a
        checkpoint written with counter ``c`` has completed ``c - 1`` steps.
        Dividing ``c`` itself by ``num_batches`` (as was done before) resumes
        one batch too late and skips batch 0 after a full epoch.
        """
        completed_steps = max(checkpoint_counter - 1, 0)
        return divmod(completed_steps, num_batches)

    @staticmethod
    def _checkpoint_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int."""
        import re
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        return int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if any.

        Per step: update D, update G, update ``k``; write the summaries;
        every 300 steps save a sample grid. After every epoch a checkpoint
        and the visualisation images are written.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                               feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update k
                _, summary_str, M_value, k_value = self.sess.run([self.update_k, self.p_sum, self.M, self.k], feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, M: %.8f, k: %.8f"
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss, M_value, k_value))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Save one grid of images generated from fresh random noise.

        Args:
            epoch: epoch index, embedded in the output file name.
        """
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Run directory name ``<model_name>_<dataset_name>_<batch_size>_<z_dim>``."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint under ``<checkpoint_dir>/<model_dir>/<model_name>/``.

        Args:
            checkpoint_dir: root directory for checkpoints.
            step: value passed to the saver as ``global_step``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if one exists.

        Args:
            checkpoint_dir: root directory for checkpoints.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._checkpoint_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
CGAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class CGAN(object):
    """Conditional GAN for 28x28 single-channel images (mnist / fashion-mnist).

    Both the generator and the discriminator receive a ``y_dim``-wide
    condition vector in addition to their usual input.
    """

    model_name = "CGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: session used for every ``run`` call of this model.
            epoch: number of epochs to train.
            batch_size: number of samples per training step.
            z_dim: dimension of the noise vector.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.y_dim = 10         # dimension of condition-vector (label)
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, y, is_training=True, reuse=False):
        """Build the conditional discriminator.

        Args:
            x: image batch.
            y: condition batch; reshaped to ``[batch_size, 1, 1, y_dim]`` and
                merged with ``x`` through ``conv_cond_concat``.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "discriminator" scope.

        Returns:
            ``(out, out_logit, net)``: sigmoid output, its logit and the
            activation of the 1024-unit layer.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            # merge image and label
            y = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            x = conv_cond_concat(x, y)

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, y, is_training=True, reuse=False):
        """Build the conditional generator.

        Args:
            z: noise batch.
            y: condition batch, concatenated to ``z`` along axis 1.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "generator" scope.

        Returns:
            The sigmoid-activated image batch.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):

            # merge noise and label
            z = concat([z, y], 1)

            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, the test op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # Graph Input
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # labels
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # Loss Function
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, self.y, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, self.y, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        # Training
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                      .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.g_loss, var_list=g_vars)

        # Testing
        # for test
        self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

        # Summary
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        The counter starts at 1 and is incremented after every step, so a
        checkpoint written with counter ``c`` has completed ``c - 1`` steps.
        Dividing ``c`` itself by ``num_batches`` (as was done before) resumes
        one batch too late and skips batch 0 after a full epoch.
        """
        completed_steps = max(checkpoint_counter - 1, 0)
        return divmod(completed_steps, num_batches)

    @staticmethod
    def _checkpoint_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int."""
        import re
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        return int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if any.

        Per step: update D, then update G, writing both summaries; every 300
        steps save a sample grid built from the fixed ``sample_z`` and
        ``test_labels``. After every epoch a checkpoint and the visualisation
        images are written.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))
        self.test_labels = self.data_y[0:self.batch_size]

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_labels = self.data_y[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict={self.inputs: batch_images, self.y: batch_labels,
                                                                  self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                       feed_dict={self.y: batch_labels, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z, self.y: self.test_labels})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Generate sample grids for ``epoch`` and write them as PNG files.

        Writes one grid with random conditions, one grid per condition value
        and a merged grid with a row per picked noise vector and a column per
        condition value.

        Args:
            epoch: epoch index, embedded in every output file name.
        """
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random condition, random noise
        y = np.random.choice(self.y_dim, self.batch_size)
        y_one_hot = np.zeros((self.batch_size, self.y_dim))
        y_one_hot[np.arange(self.batch_size), y] = 1

        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

        # specified condition, random noise
        n_styles = 10  # must be less than or equal to self.batch_size

        np.random.seed()
        si = np.random.choice(self.batch_size, n_styles)

        per_class = []
        for l in range(self.y_dim):
            y = np.zeros(self.batch_size, dtype=np.int64) + l
            y_one_hot = np.zeros((self.batch_size, self.y_dim))
            y_one_hot[np.arange(self.batch_size), y] = 1

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample, self.y: y_one_hot})
            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                        check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)

            per_class.append(samples[si, :, :, :])

        # single concatenate instead of growing the array on every iteration
        all_samples = np.concatenate(per_class, axis=0)

        # save merged images to check style-consistency
        canvas = np.zeros_like(all_samples)
        for s in range(n_styles):
            for c in range(self.y_dim):
                canvas[s * self.y_dim + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]

        save_images(canvas, [n_styles, self.y_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')

    @property
    def model_dir(self):
        """Run directory name ``<model_name>_<dataset_name>_<batch_size>_<z_dim>``."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint under ``<checkpoint_dir>/<model_dir>/<model_name>/``.

        Args:
            checkpoint_dir: root directory for checkpoints.
            step: value passed to the saver as ``global_step``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if one exists.

        Args:
            checkpoint_dir: root directory for checkpoints.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._checkpoint_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
LSGAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class LSGAN(object):
    """LSGAN model for 28x28 single-channel images (mnist / fashion-mnist).

    The discriminator logits are regressed towards 1 (real) and 0 (fake)
    through ``mse_loss`` instead of a cross-entropy loss.
    """

    model_name = "LSGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the run configuration and load the training data.

        Args:
            sess: session used for every ``run`` call of this model.
            epoch: number of epochs to train.
            batch_size: number of samples per training step.
            z_dim: dimension of the noise vector.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim         # dimension of noise-vector
            self.c_dim = 1

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator.

        Args:
            x: image batch.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "discriminator" scope.

        Returns:
            ``(out, out_logit, net)``: sigmoid output, its logit and the
            activation of the 1024-unit layer.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator mapping noise ``z`` to 28x28x1 images.

        Args:
            z: noise batch.
            is_training: forwarded to the batch-norm layers.
            reuse: whether to reuse the variables of the "generator" scope.

        Returns:
            The sigmoid-activated image batch.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def mse_loss(self, pred, data):
        """Return ``sqrt(sum((pred - data)**2)) / batch_size`` as a scalar.

        NOTE(review): ``tf.nn.l2_loss`` is ``sum(t**2) / 2``, so this is the
        L2 norm of the residual divided by the batch size rather than a mean
        of squared errors - confirm this is the intended LSGAN objective.
        """
        loss_val = tf.sqrt(2 * tf.nn.l2_loss(pred - data)) / self.batch_size
        return loss_val

    def build_model(self):
        """Create placeholders, losses, optimizers, the test op and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # Graph Input
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # Loss Function
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = tf.reduce_mean(self.mse_loss(D_real_logits, tf.ones_like(D_real_logits)))
        d_loss_fake = tf.reduce_mean(self.mse_loss(D_fake_logits, tf.zeros_like(D_fake_logits)))

        self.d_loss = 0.5*(d_loss_real + d_loss_fake)

        # get loss for generator
        self.g_loss = tf.reduce_mean(self.mse_loss(D_fake_logits, tf.ones_like(D_fake_logits)))

        # Training
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                      .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.g_loss, var_list=g_vars)

        # weight clipping
        # NOTE(review): clipping D's weights to [-0.01, 0.01] looks carried over
        # from a WGAN-style critic - confirm it is wanted for this model.
        self.clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]

        # Testing
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # Summary
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    @staticmethod
    def _resume_position(checkpoint_counter, num_batches):
        """Map a restored step counter to ``(start_epoch, start_batch_id)``.

        The counter starts at 1 and is incremented after every step, so a
        checkpoint written with counter ``c`` has completed ``c - 1`` steps.
        Dividing ``c`` itself by ``num_batches`` (as was done before) resumes
        one batch too late and skips batch 0 after a full epoch.
        """
        completed_steps = max(checkpoint_counter - 1, 0)
        return divmod(completed_steps, num_batches)

    @staticmethod
    def _checkpoint_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int."""
        import re
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        return int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if any.

        Per step: update D (the weight-clipping ops are fetched in the same
        ``run`` call), then update G; every 300 steps save a sample grid.
        After every epoch a checkpoint and the visualisation image are written.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch, start_batch_id = self._resume_position(checkpoint_counter, self.num_batches)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                # NOTE(review): d_optim and clip_D are fetched together, so their
                # relative execution order within this call is not pinned by the code.
                _, _, summary_str, d_loss = self.sess.run([self.d_optim, self.clip_D, self.d_sum, self.d_loss],
                                               feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def visualize_results(self, epoch):
        """Save one grid of images generated from fresh random noise.

        Args:
            epoch: epoch index, embedded in the output file name.
        """
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Run directory name ``<model_name>_<dataset_name>_<batch_size>_<z_dim>``."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint under ``<checkpoint_dir>/<model_dir>/<model_name>/``.

        Args:
            checkpoint_dir: root directory for checkpoints.
            step: value passed to the saver as ``global_step``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint, if one exists.

        Args:
            checkpoint_dir: root directory for checkpoints.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            parsed from the checkpoint file name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._checkpoint_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
WGAN.py
#-*- coding: utf-8 -*-
from __future__ import division
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class WGAN(object):
    """WGAN trainer for 28x28 single-channel datasets ('mnist' / 'fashion-mnist').

    The discriminator (critic) loss is the difference of mean logits on fake and
    real images, and every discriminator variable is clipped to [-0.01, 0.01]
    by the ``clip_D`` ops that are run together with each discriminator step.

    Typical use: construct, call ``build_model()``, then ``train()``.
    """
    model_name = "WGAN"     # name for checkpoint

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        """Store the configuration and load the dataset.

        Args:
            sess: session object used for every ``sess.run`` call.
            epoch: number of epochs; training iterates ``range(start_epoch, epoch)``.
            batch_size: batch size; it is baked into the graph shapes.
            z_dim: dimension of the noise vector.
            dataset_name: 'mnist' or 'fashion-mnist'.
            checkpoint_dir: root directory for checkpoints.
            result_dir: root directory for generated images.
            log_dir: root directory for summary logs.

        Raises:
            NotImplementedError: for any other ``dataset_name``.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim      # dimension of noise-vector
            self.c_dim = 1

            # WGAN parameter
            self.disc_iters = 1     # The number of critic iterations for one-step of generator

            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64    # number of generated images to be saved

            # load mnist
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # get number of batches for a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            raise NotImplementedError

    def discriminator(self, x, is_training=True, reuse=False):
        """Build the discriminator (critic) on ``x``.

        Returns:
            (sigmoid(logit), logit, last hidden layer). Only the logit is used by
            the losses in ``build_model``.
        """
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            # Flatten per example; the batch dimension is fixed to self.batch_size.
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net

    def generator(self, z, is_training=True, reuse=False):
        """Build the generator on noise ``z`` and return the sigmoid-activated images."""
        # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
        # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))

            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    def build_model(self):
        """Create placeholders, losses, optimizers, clipping ops and summaries."""
        # some parameters
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # ---- Graph Input ----
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # ---- Loss Function ----
        # output of D for real images
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)

        # output of D for fake images
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

        # get loss for discriminator
        d_loss_real = - tf.reduce_mean(D_real_logits)
        d_loss_fake = tf.reduce_mean(D_fake_logits)

        self.d_loss = d_loss_real + d_loss_fake

        # get loss for generator
        self.g_loss = - d_loss_fake

        # ---- Training ----
        # divide trainable variables into a group for D and a group for G
        # (selection is by substring of the variable name)
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # optimizers; they are made to depend on the UPDATE_OPS collection
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # weight clipping
        # NOTE(review): clip_D is fetched in the same sess.run as d_optim with no
        # explicit dependency between them -- confirm the intended ordering.
        self.clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]

        # ---- Testing ----
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)

        # final summary operations
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])

    def train(self):
        """Run the training loop, resuming from the latest checkpoint if present.

        Every batch runs one discriminator step (with weight clipping); a
        generator step runs when ``(counter - 1) % disc_iters == 0``. A sample
        grid is written whenever ``counter`` is a multiple of 300, and a
        checkpoint plus a result grid are written after every epoch.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            # NOTE(review): counter starts at 1, so a counter saved at an epoch
            # boundary resumes at batch 1 of the next epoch -- confirm intended.
            start_epoch = checkpoint_counter // self.num_batches
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # The generator is not necessarily updated on the first iteration (e.g.
        # after resuming with disc_iters > 1); keep the status line printable.
        g_loss = float('nan')

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, _, summary_str, d_loss = self.sess.run([self.d_optim, self.clip_D, self.d_sum, self.d_loss],
                                                          feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # update G network
                if (counter - 1) % self.disc_iters == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                           feed_dict={self.z: batch_z})
                    self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    manifold_h = manifold_w = self._sample_grid_dim()
                    # No './' prefix: it would turn an absolute result_dir into a
                    # relative path, and visualize_results does not use one either.
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                                + '_train_{:02d}_{:04d}.png'.format(epoch, idx))

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    def _sample_grid_dim(self):
        """Return the side of the largest square grid that fits the saved samples."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        return int(np.floor(np.sqrt(tot_num_samples)))

    def visualize_results(self, epoch):
        """Generate images from fresh random noise and save them as a square grid.

        Args:
            epoch: epoch index, used only in the output file name.
        """
        image_frame_dim = self._sample_grid_dim()

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name
                    + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        """Run identifier: "<model_name>_<dataset_name>_<batch_size>_<z_dim>"."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        """Save a checkpoint under <checkpoint_dir>/<model_dir>/<model_name>/.

        Args:
            checkpoint_dir: root checkpoint directory.
            step: value passed to the saver as ``global_step``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    @staticmethod
    def _checkpoint_step(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int, or 0 if none."""
        import re
        # Raw string: "\d" in a normal string literal is an invalid escape sequence.
        step_match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        return int(step_match.group(0)) if step_match else 0

    def load(self, checkpoint_dir):
        """Restore the most recent checkpoint, if one exists.

        Args:
            checkpoint_dir: root checkpoint directory.

        Returns:
            (True, step) when a checkpoint was restored, with ``step`` parsed from
            the checkpoint file name; (False, 0) when none was found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = self._checkpoint_step(ckpt_name)
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0