使用 Fashion-MNIST 数据集进行训练；采样测试时使用的标签是训练集中前 batch_size 张图像所对应的标签。
# -*- coding: utf-8 -*-
# @Time : 2019/3/10 20:52
# @Author : YYLin
# @Email : 854280599@qq.com
import os
import time
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
class CGAN(object):
    """Conditional GAN on 28x28x1 images (MNIST / Fashion-MNIST).

    Both the generator and the discriminator are conditioned on a
    ``y_dim``-dimensional label vector. Typical use: construct, call
    ``build_model()``, then ``train()``; ``train()`` periodically writes
    sample grids under ``<result_dir>/<model_dir>``.
    """

    model_name = "CGAN"  # used in the results sub-directory and image file names

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, result_dir):
        """Store hyper-parameters and load the training data.

        Args:
            sess: TensorFlow session used to run the training/sampling ops.
            epoch: number of training epochs.
            batch_size: images per mini-batch; also fixed into every graph shape.
            z_dim: dimension of the noise vector.
            dataset_name: must be 'mnist' or 'fashion-mnist'.
            result_dir: directory under which sample image grids are saved.

        Raises:
            NotImplementedError: if ``dataset_name`` is not supported.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.result_dir = result_dir
        self.epoch = epoch
        self.batch_size = batch_size

        if dataset_name not in ('mnist', 'fashion-mnist'):
            print("********there is no other dataset to do *********")
            raise NotImplementedError("unsupported dataset: %r" % (dataset_name,))

        # image geometry
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28
        self.z_dim = z_dim  # dimension of the noise vector
        self.y_dim = 10     # dimension of the condition (label) vector
        self.c_dim = 1      # number of image channels

        # optimizer settings
        self.learning_rate = 0.0002
        self.beta1 = 0.5
        self.sample_num = 64  # upper bound on images shown in one saved sample grid

        # NOTE(review): load_mnist() receives no dataset argument, so which
        # dataset it returns is decided inside utils -- confirm it matches
        # dataset_name.
        self.data_X, self.data_y = load_mnist()

        # number of full batches per epoch (a trailing partial batch is dropped)
        self.num_batches = len(self.data_X) // self.batch_size

    def discriminator(self, x, y, is_training=True, reuse=False):
        """Score images ``x`` conditioned on labels ``y``.

        ``y`` is reshaped to [batch_size, 1, 1, y_dim] and joined to the image
        by ``conv_cond_concat`` (presumably tiled over the spatial dims and
        concatenated on channels -- see ops), followed by two stride-2
        convolutions and two fully connected layers.

        Returns:
            (out, out_logit): the sigmoid probability and the raw logit, one
            value per example.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            y = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            x = conv_cond_concat(x, y)
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),
                           is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'),
                           is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)
            return out, out_logit

    def generator(self, z, y, is_training=True, reuse=False):
        """Map noise ``z`` plus labels ``y`` to images of shape [batch_size, 28, 28, 1].

        ``z`` and ``y`` are concatenated along axis 1, passed through two fully
        connected layers, reshaped to 7x7x128 and upsampled twice by
        transposed convolutions. The final sigmoid keeps pixels in (0, 1).
        """
        with tf.variable_scope("generator", reuse=reuse):
            z = concat([z, y], 1)
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'),
                                is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'),
                                is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'),
                   is_training=is_training, scope='g_bn3'))
            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))
            return out

    def build_model(self):
        """Build placeholders, both networks, the GAN losses and the Adam optimizers.

        Also builds ``self.fake_images``, a generator in inference mode
        (``is_training=False``) that shares the trained weights and is used
        for sampling.
        """
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        # input placeholders (batch size is fixed at graph-construction time)
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # D on real data, G, then D on generated data (reusing D's variables)
        D_real, D_real_logits = self.discriminator(self.inputs, self.y, is_training=True, reuse=False)
        G = self.generator(self.z, self.y, is_training=True, reuse=False)
        D_fake, D_fake_logits = self.discriminator(G, self.y, is_training=True, reuse=True)

        # discriminator loss: real images labelled 1, generated images labelled 0
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))
        self.d_loss = d_loss_real + d_loss_fake

        # generator loss: push D to label generated images as real (non-saturating form)
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))

        # each optimizer only updates its own network's variables
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # Run the ops in UPDATE_OPS (presumably batch-norm statistic updates
        # registered by `bn`) before each optimizer step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                .minimize(self.d_loss, var_list=d_vars)
            # the generator deliberately uses 5x the base learning rate
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \
                .minimize(self.g_loss, var_list=g_vars)

        # Testing: inference-mode generator (is_training=False) used to produce sample images
        self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True)

    def train(self):
        """Run the adversarial training loop.

        Each mini-batch does one discriminator update followed by three
        generator updates on the same noise and labels. Every 300 steps a
        grid of samples is saved via ``_save_samples``. The variable
        initializer is run with ``.run()`` and no explicit session, so ``sess``
        must be the default session when this is called.
        """
        tf.global_variables_initializer().run()

        # Fixed noise and labels for sample grids, so images from different
        # steps are comparable; the labels are those of the first batch_size
        # training images.
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
        self.test_labels = self.data_y[0:self.batch_size]

        self.saver = tf.train.Saver()

        start_epoch = 0
        start_batch_id = 0
        counter = 1

        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):
            for idx in range(start_batch_id, self.num_batches):
                lo, hi = idx * self.batch_size, (idx + 1) * self.batch_size
                batch_images = self.data_X[lo:hi]
                batch_labels = self.data_y[lo:hi]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update the discriminator once ...
                _, d_loss = self.sess.run([self.d_optim, self.d_loss],
                                          feed_dict={self.inputs: batch_images, self.y: batch_labels,
                                                     self.z: batch_z})
                # ... then the generator three times
                for _ in range(3):
                    _, g_loss = self.sess.run([self.g_optim, self.g_loss],
                                              feed_dict={self.y: batch_labels, self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                if np.mod(counter, 300) == 0:
                    self._save_samples(epoch, idx)

            # any resumed start offset applies only to the first epoch
            start_batch_id = 0

    def _save_samples(self, epoch, idx):
        """Generate images from the fixed sample noise/labels and save them as a square grid."""
        samples = self.sess.run(self.fake_images,
                                feed_dict={self.z: self.sample_z, self.y: self.test_labels})
        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
        manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
        image_path = self._sample_image_path(epoch, idx)
        check_folder(os.path.dirname(image_path))
        save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w], image_path)

    def _sample_image_path(self, epoch, idx):
        """Return the file path of the sample grid saved at (epoch, idx).

        Uses os.path.join so an absolute ``result_dir`` stays absolute; the
        old ``'./' + result_dir`` concatenation turned '/tmp/x' into the
        relative path './/tmp/x', whose directory was never created.
        """
        file_name = '{}_train_{:02d}_{:04d}.png'.format(self.model_name, epoch, idx)
        return os.path.join(self.result_dir, self.model_dir, file_name)

    @property
    def model_dir(self):
        """Sub-directory name encoding the model, dataset, batch size and z_dim."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)
图像生成结果：