GAN 学习 DCGAN1
今天学习了Soumith Chintala 的DCGAN,UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS。
下面学习一下具体的编程过程,继续看代码,依旧是tensorflow下做的代码。这个代码没运行成功。先看和
- 一、代码解读
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #下载minis数据的code,
#具体使用方法 from tensorflow.examples.tutorials.mnistimport input_data
# mnist =input_data.read_data_sets("MNIST_data/", one_hot=True)
import numpy as np # 调用numpy 简单的入门学习网址 https://zhuanlan.zhihu.com/p/24309547
import matplotlib as mpl # 调用 matplotlib, 网址同上
mpl.use('Agg') #调用 mpl的Agg渲染包,该命令必须在 import matplotlib.pyplot 之前
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os,sys #sys模块主要是用于提供对python解释器相关的操作;OS模块是Python标准库中的一个用于访问操作系统功能的模块,使用OS模块中提供的接口,可以实现跨平台访问
sys.path.append('utils') #sys.path 返回的是一个列表!该路径已经添加到系统的环境变量了,当我们要添加自己的搜索目录时,可以通过列表的append()方法;对于模块和自己写的脚本不在同一个目录下,在脚本开头加sys.path.append('xxx'):
from nets import *
from datas import *
def sample_z(m, n):
return np.random.uniform(-1., 1., size=[m, n]) #功能:从一个均匀分布[low,high)中随机采样,注意定义域是左闭右开,即包含low,不包含high.
#看一第行,语法是class 后面紧接着,类的名字,最后别忘记“冒号”,这样来定义一个类。类的名字,首字母,有一个不可文的规定,最好是大写,这样需要在代码中识别区分每个类。和函数非常相似,但是与普通函数不同的是,它的内部有一个“self”,参数,它的作用是对于对象自身的引用。
class DCGAN():
def __init__(self, generator, discriminator, data): #类实例初始化函数
self.generator = generator
self.discriminator = discriminator
self.data = data #self: 类实例本身,self.data: 类实例本身的成员"
# data
self.z_dim = self.data.z_dim
self.size = self.data.size
self.channel = self.data.channel
self.X = tf.placeholder(tf.float32, shape=[None, self.size, self.size, self.channel]) #tf.placeholder(dtype, shape=None, name=None)此函数可以理解为形参,用于定义过程,在执行的时候再赋具体的值。
self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])
# nets
self.G_sample = self.generator(self.z)
self.D_real, _ = self.discriminator(self.X)
self.D_fake, _ = self.discriminator(self.G_sample, reuse = True)
# loss
self.D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=tf.ones_like(self.D_real))) + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=tf.zeros_like(self.D_fake))) #tf.nn.sigmoid_cross_entropy_with_logits,一种交叉熵函数见http://blog.csdn.net/QW_sunny/article/details/72885403,tf.ones_like是一种拷贝函数,默认情况下,它会拷贝参数tensor的类型,维度等数据,并将其中的值设置为1.
self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=tf.ones_like(self.D_fake)))
# solver
#优化器的简单学习,http://blog.csdn.net/xierhacker/article/details/53174558;minimize 通过更新var_list来减小loss.
self.D_solver = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(self.D_loss, var_list=self.discriminator.vars)
self.G_solver = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(self.G_loss, var_list=self.generator.vars)
self.saver = tf.train.Saver()#我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
gpu_options = tf.GPUOptions(allow_growth=True) #设置GPU
self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
def train(self, sample_dir, ckpt_dir='ckpt', training_epoches = 1000000, batch_size = 32):
fig_count = 0
self.sess.run(tf.global_variables_initializer())
#tf的表达式中所有的变量或者是常量都应该是 tf 的类型。只要是声明了变量,就得用 sess.run(tf.global_variables_initializer()) 或者 x.initializer.run() 方法来初始化才能用。
#下面的就是具体的训练过程.
for epoch in range(training_epoches):
# update D
X_b = self.data(batch_size)
self.sess.run(
self.D_solver,
feed_dict={self.X: X_b, self.z: sample_z(batch_size, self.z_dim)}
)
# update G
k = 1
for _ in range(k):
self.sess.run(
self.G_solver,
feed_dict={self.z: sample_z(batch_size, self.z_dim)}
)
# save img, model. print loss
if epoch % 100 == 0 or epoch < 100:
D_loss_curr = self.sess.run(
self.D_loss,
feed_dict={self.X: X_b, self.z: sample_z(batch_size, self.z_dim)})
G_loss_curr = self.sess.run(
self.G_loss,
feed_dict={self.z: sample_z(batch_size, self.z_dim)})
print('Iter: {}; D loss: {:.4}; G_loss: {:.4}'.format(epoch, D_loss_curr, G_loss_curr))
if epoch % 1000 == 0:
samples = self.sess.run(self.G_sample, feed_dict={self.z: sample_z(16, self.z_dim)})
fig = self.data.data2fig(samples)
plt.savefig('{}/{}.png'.format(sample_dir, str(fig_count).zfill(3)), bbox_inches='tight')
fig_count += 1
plt.close(fig)
#if epoch % 2000 == 0:
#self.saver.save(self.sess, os.path.join(ckpt_dir, "dcgan.ckpt"))
# 一个python的文件有两种使用的方法,第一是直接作为脚本执行,第二是import到其他的python脚本中被调用(模块重用)执行。因此if __name__ == 'main': 的作用就是控制这两种情况执行代码的过程,在if __name__ == 'main': 下的代码只有在第一种情况下(即文件作为脚本直接执行)才会被执行,而import到其他脚本中是不会被执行的。
if __name__ == '__main__':
# constraint GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# save generated images
sample_dir = 'Samples/celebA_dcgan'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# param
generator = G_conv()
discriminator = D_conv()
data = celebA()
# run
dcgan = DCGAN(generator, discriminator, data)
dcgan.train(sample_dir)