GAN 学习 (二)

GAN 学习 DCGAN1

今天学习了Soumith Chintala 的DCGAN,UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS
下面学习一下具体的编程过程,继续看代码,依旧是tensorflow下做的代码。这个代码没运行成功。先看和
- 一、代码解读

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #下载minis数据的code,
#具体使用方法 from tensorflow.examples.tutorials.mnistimport input_data   
#            mnist =input_data.read_data_sets("MNIST_data/", one_hot=True) 
import numpy as np  # 调用numpy  简单的入门学习网址 https://zhuanlan.zhihu.com/p/24309547
import matplotlib as mpl # 调用 matplotlib, 网址同上
mpl.use('Agg')  #调用 mpl的Agg渲染包,该命令必须在 import matplotlib.pyplot 之前
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os,sys  #sys模块主要是用于提供对python解释器相关的操作;OS模块是Python标准库中的一个用于访问操作系统功能的模块,使用OS模块中提供的接口,可以实现跨平台访问

sys.path.append('utils') #sys.path 返回的是一个列表!该路径已经添加到系统的环境变量了,当我们要添加自己的搜索目录时,可以通过列表的append()方法;对于模块和自己写的脚本不在同一个目录下,在脚本开头加sys.path.append('xxx'):
from nets import *
from datas import *

def sample_z(m, n):
    return np.random.uniform(-1., 1., size=[m, n]) #功能:从一个均匀分布[low,high)中随机采样,注意定义域是左闭右开,即包含low,不包含high.


 #看一第行,语法是class 后面紧接着,类的名字,最后别忘记“冒号”,这样来定义一个类。类的名字,首字母,有一个不可文的规定,最好是大写,这样需要在代码中识别区分每个类。和函数非常相似,但是与普通函数不同的是,它的内部有一个“self”,参数,它的作用是对于对象自身的引用。
class DCGAN():
    def __init__(self, generator, discriminator, data):  #类实例初始化函数
        self.generator = generator
        self.discriminator = discriminator
        self.data = data    #self: 类实例本身,self.data: 类实例本身的成员"

        # data
        self.z_dim = self.data.z_dim
        self.size = self.data.size
        self.channel = self.data.channel

        self.X = tf.placeholder(tf.float32, shape=[None, self.size, self.size, self.channel])  #tf.placeholder(dtype, shape=None, name=None)此函数可以理解为形参,用于定义过程,在执行的时候再赋具体的值。
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])

        # nets
        self.G_sample = self.generator(self.z)

        self.D_real, _ = self.discriminator(self.X)
        self.D_fake, _ = self.discriminator(self.G_sample, reuse = True)

        # loss
        self.D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=tf.ones_like(self.D_real))) + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=tf.zeros_like(self.D_fake))) #tf.nn.sigmoid_cross_entropy_with_logits,一种交叉熵函数见http://blog.csdn.net/QW_sunny/article/details/72885403,tf.ones_like是一种拷贝函数,默认情况下,它会拷贝参数tensor的类型,维度等数据,并将其中的值设置为1.
        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=tf.ones_like(self.D_fake)))

        # solver
        #优化器的简单学习,http://blog.csdn.net/xierhacker/article/details/53174558;minimize 通过更新var_list来减小loss.
        self.D_solver = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(self.D_loss, var_list=self.discriminator.vars)
        self.G_solver = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(self.G_loss, var_list=self.generator.vars)

        self.saver = tf.train.Saver()#我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
        gpu_options = tf.GPUOptions(allow_growth=True) #设置GPU
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    def train(self, sample_dir, ckpt_dir='ckpt', training_epoches = 1000000, batch_size = 32):
        fig_count = 0
        self.sess.run(tf.global_variables_initializer())  
        #tf的表达式中所有的变量或者是常量都应该是 tf 的类型。只要是声明了变量,就得用 sess.run(tf.global_variables_initializer()) 或者 x.initializer.run() 方法来初始化才能用。
        #下面的就是具体的训练过程.
        for epoch in range(training_epoches):
            # update D
            X_b = self.data(batch_size)
            self.sess.run(
                self.D_solver,
                feed_dict={self.X: X_b, self.z: sample_z(batch_size, self.z_dim)}
                )
            # update G
            k = 1
            for _ in range(k):
                self.sess.run(
                    self.G_solver,
                    feed_dict={self.z: sample_z(batch_size, self.z_dim)}
                )

            # save img, model. print loss
            if epoch % 100 == 0 or epoch < 100:
                D_loss_curr = self.sess.run(
                        self.D_loss,
                        feed_dict={self.X: X_b, self.z: sample_z(batch_size, self.z_dim)})
                G_loss_curr = self.sess.run(
                        self.G_loss,
                        feed_dict={self.z: sample_z(batch_size, self.z_dim)})
                print('Iter: {}; D loss: {:.4}; G_loss: {:.4}'.format(epoch, D_loss_curr, G_loss_curr))

                if epoch % 1000 == 0:
                    samples = self.sess.run(self.G_sample, feed_dict={self.z: sample_z(16, self.z_dim)})

                    fig = self.data.data2fig(samples)
                    plt.savefig('{}/{}.png'.format(sample_dir, str(fig_count).zfill(3)), bbox_inches='tight')
                    fig_count += 1
                    plt.close(fig)

                #if epoch % 2000 == 0:
                    #self.saver.save(self.sess, os.path.join(ckpt_dir, "dcgan.ckpt"))

#   一个python的文件有两种使用的方法,第一是直接作为脚本执行,第二是import到其他的python脚本中被调用(模块重用)执行。因此if __name__ == 'main': 的作用就是控制这两种情况执行代码的过程,在if __name__ == 'main': 下的代码只有在第一种情况下(即文件作为脚本直接执行)才会被执行,而import到其他脚本中是不会被执行的。
if __name__ == '__main__':  

    # constraint GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # save generated images
    sample_dir = 'Samples/celebA_dcgan'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    # param
    generator = G_conv()
    discriminator = D_conv()

    data = celebA()

    # run
    dcgan = DCGAN(generator, discriminator, data)
    dcgan.train(sample_dir)

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值