DCGAN

The previous post covered the basic principles of GAN along with the related concepts, and also pointed out some of GAN's weaknesses, such as unstable training, an uncontrollable generation process, and a lack of interpretability. This post looks at one of GAN's improved variants: DCGAN (Deep Convolutional GAN).

1. Improvements over GAN

DCGAN improves on GAN (and on a plain CNN) in the following ways:
(1) Replace pooling with strided convolutions and fractionally-strided (transposed) convolutions, i.e. pooling layers are replaced by convolutions.
**For the discriminator:** this lets the network learn its own spatial downsampling.
**For the generator:** this lets it learn its own spatial upsampling.
(2) Add batch normalization to both the generator and the discriminator.

  • fixes problems caused by poor initialization
  • helps gradients propagate to every layer
  • prevents the generator from collapsing all samples to a single point

(3) Remove the fully connected layers from the CNN.
(4) In the generator, use ReLU in every layer except the output layer, which uses tanh.
(5) In the discriminator, use LeakyReLU in all layers.
[Figure] The DCGAN generator used for LSUN scene modeling. A 100-dimensional uniform distribution z is projected to a small-extent spatial convolutional representation with many feature maps. A series of four fractionally-strided convolutions (in some recent papers these are mistakenly called deconvolutions) then converts this high-level representation into a 64*64 pixel image. Notably, no fully connected or pooling layers are used.
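To make guidelines (1)-(5) concrete, here is a minimal, self-contained sketch of such a pair of networks. This is an illustration only, not the code analyzed below (which builds the same structure from hand-rolled ops); the names toy_generator and toy_discriminator are hypothetical, and it assumes the TF 1.x tf.layers API:

import tensorflow as tf

def toy_generator(z):                                    # z: [batch, 100]
    h = tf.layers.dense(z, 4 * 4 * 512)                  # project and reshape; no pooling anywhere
    h = tf.reshape(h, [-1, 4, 4, 512])
    h = tf.nn.relu(tf.layers.batch_normalization(h))     # (2) BN, (4) ReLU in hidden layers
    for depth in (256, 128, 64):                         # (1) learned upsampling: strided deconvolutions
        h = tf.layers.conv2d_transpose(h, depth, 5, strides=2, padding='same')
        h = tf.nn.relu(tf.layers.batch_normalization(h))
    h = tf.layers.conv2d_transpose(h, 3, 5, strides=2, padding='same')
    return tf.nn.tanh(h)                                 # (4) tanh on the output layer -> [-1, 1]

def toy_discriminator(x):                                # x: [batch, 64, 64, 3]
    h = tf.nn.leaky_relu(tf.layers.conv2d(x, 64, 5, strides=2, padding='same'), 0.2)
    for depth in (128, 256, 512):                        # (1) learned downsampling: strided convolutions
        h = tf.layers.conv2d(h, depth, 5, strides=2, padding='same')
        h = tf.nn.leaky_relu(tf.layers.batch_normalization(h), 0.2)   # (2) BN, (5) LeakyReLU
    h = tf.layers.flatten(h)
    return tf.layers.dense(h, 1)                         # (3) only the final logit is a dense layer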

2. Implementing DCGAN

2.1 Collecting the raw dataset

First we need a large number of anime images; these can be scraped from the anime site konachan.net with a crawler. The crawler code is as follows:

import requests
from bs4 import BeautifulSoup
import os
import traceback

def download(url, filename):
    if os.path.exists(filename):
        print('file exists!')
        return
    try:
        r = requests.get(url, stream=True, timeout=60)
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    f.flush()
        return filename
    except KeyboardInterrupt:
        if os.path.exists(filename):
            os.remove(filename)
        raise KeyboardInterrupt
    except Exception:
        traceback.print_exc()
        if os.path.exists(filename):
            os.remove(filename)


if os.path.exists('imgs') is False:
    os.makedirs('imgs')

start = 1
end = 8000
for i in range(start, end + 1):
    url = 'http://konachan.net/post?page=%d&tags=' % i
    html = requests.get(url).text
    soup = BeautifulSoup(html, 'html.parser')
    for img in soup.find_all('img', class_="preview"):
        target_url = 'http:' + img['src']
        filename = os.path.join('imgs', target_url.split('/')[-1])
        download(target_url, filename)
    print('%d / %d' % (i, end))

A sample of the scraped images: [figure]
We now have the raw images, but our goal is to generate anime avatars: we do not need the whole picture, and the extra content would interfere with training, so we run face detection and crop out just the faces.

2.2 Detecting and cropping faces

As in the original post, face cropping uses an OpenCV-based tool from GitHub: nagadomi/lbpcascade_animeface

import cv2
import sys
import os.path
from glob import glob

def detect(filename, cascade_file="lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(gray,
                                     # detector options
                                     scaleFactor=1.1,
                                     minNeighbors=5,
                                     minSize=(48, 48))
    for i, (x, y, w, h) in enumerate(faces):
        face = image[y: y + h, x:x + w, :]
        face = cv2.resize(face, (96, 96))
        save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
        cv2.imwrite("faces/" + save_filename, face)


if __name__ == '__main__':
    if os.path.exists('faces') is False:
        os.makedirs('faces')
    file_list = glob('imgs/*.jpg')
    for filename in file_list:
        detect(filename)

The processed images look like this: [figure]
These can now be used for training!

2.3 Source code walkthrough

The code follows DCGAN-tensorflow.
There are four files in total: main.py, model.py, ops.py, and utils.py.

2.3.1 main.py
import os
import scipy.misc
import numpy as np
import json

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("data_dir", "./data", "path to datasets [e.g. $HOME/data]")
flags.DEFINE_string("out_dir", "./out", "Root directory for outputs [e.g. $HOME/out]")
flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
flags.DEFINE_boolean("freeze", False, "True for exporting with new batch size")
flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep")
flags.DEFINE_integer("sample_freq", 200, "sample every this many iterations")
flags.DEFINE_integer("ckpt_freq", 200, "save checkpoint every this many iterations")
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
flags.DEFINE_boolean("G_img_sum", False, "Save generator image summaries in log")
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
FLAGS = flags.FLAGS

def main(_):
  pp.pprint(flags.FLAGS.__flags)
  
  # expand user name and environment variables
  FLAGS.data_dir = expand_path(FLAGS.data_dir)
  FLAGS.out_dir = expand_path(FLAGS.out_dir)
  FLAGS.out_name = expand_path(FLAGS.out_name)
  FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
  FLAGS.sample_dir = expand_path(FLAGS.sample_dir)

  if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
  if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

  # output folders
  if FLAGS.out_name == "":
      FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
      if FLAGS.train:
        FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)

  FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
  FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
  FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

  if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

  with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
    flags_dict = {k:FLAGS[k].value for k in FLAGS}
    json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
  

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.z_dim,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir,
          out_dir=FLAGS.out_dir,
          max_to_keep=FLAGS.max_to_keep)
    else:
      dcgan = DCGAN(                                    # create the model object
          sess,
          input_width=FLAGS.input_width,                # input size, 108 by default (configurable)
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,              # output 64x64x3
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,                  # batch_size=64
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.z_dim,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,                              # whether to center-crop the input images
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir,
          out_dir=FLAGS.out_dir,
          max_to_keep=FLAGS.max_to_keep)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
    else:
      load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
      if not load_success:
        raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir)


    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Below is code for visualization
      if FLAGS.export:
        export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size))
        dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)

      if FLAGS.freeze:
        export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size))
        dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)

      if FLAGS.visualize:
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)

if __name__ == '__main__':
  tf.app.run()

This file uses model.py and utils.py.

step0: Before main executes, the flags are parsed. Under the hood TensorFlow uses the python-gflags project, wrapped as the tf.app.flags interface; that is, TensorFlow passes the parameters tf.app.run() needs via flags. We can initialize the flags in code before running, or pass them as command-line arguments at launch.

The main flags set here are (a minimal sketch of the flags mechanism follows the list):

  • epoch: number of training epochs
  • learning_rate: learning rate, 0.0002 by default
  • beta1
  • train_size
  • batch_size: number of images per iteration
  • input_height: height of the input images (must be specified)
  • input_width: width of the input images
  • output_height: height of the generated images
  • output_width: width of the generated images
  • dataset: which dataset to process
  • input_fname_pattern
  • checkpoint_dir
  • sample_dir
  • train: True for training, False for testing
  • crop: True for training, False for testing
  • visualize

step1: First the parameters are printed; then, for the input and output sizes, any width left unspecified defaults to the corresponding height.
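For readers unfamiliar with tf.app.flags, here is a minimal sketch of the mechanism (TF 1.x; the file name demo.py is hypothetical):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
FLAGS = flags.FLAGS

def main(_):
    # The default can be overridden on the command line,
    # e.g. `python demo.py --epoch 50` prints 50.
    print(FLAGS.epoch)

if __name__ == '__main__':
    tf.app.run()   # parses argv into FLAGS, then calls main()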

step2: Then check whether the checkpoint and sample directories exist, and create them if they do not.

step3: Then configure the session. tf.ConfigProto is generally used when creating a session to set its options:

# tf.ConfigProto() options:
log_device_placement=True : whether to print device placement logs
allow_soft_placement=True : if the specified device does not exist, let TF assign one automatically
tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)

Controlling GPU memory usage:
# allow growth
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config, ...)
# With the allow_growth option, TF allocates a small amount of GPU memory at first and grows it on demand; since memory is never released, fragmentation can result

step4: Run the session. First determine which dataset is being processed, then construct the DCGAN class (defined in model.py) with the corresponding parameters.

step5: Show all variables involved in training.

step6: Decide between training and testing. If training, run dcgan.train(FLAGS); otherwise load a previously trained model for testing, and raise a "Checkpoint not found" exception if none exists, i.e. you must train a model first.

step7: Finally, visualization: visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir).

main.py is the program's entry point: it wires together the model and the image-processing helpers defined elsewhere to run training and testing.

2.3.2 utils.py
"""
Some codes from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import imageio
import cv2
import numpy as np
import os
import time
import datetime
from time import gmtime, strftime
from six.moves import xrange
from PIL import Image

import tensorflow as tf
import tensorflow.contrib.slim as slim

pp = pprint.PrettyPrinter()

get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])


def expand_path(path):
  return os.path.expanduser(os.path.expandvars(path))

def timestamp(s='%Y%m%d.%H%M%S', ts=None):
  if not ts: ts = time.time()
  st = datetime.datetime.fromtimestamp(ts).strftime(s)
  return st
  
def show_all_variables():
  model_vars = tf.trainable_variables()
  slim.model_analyzer.analyze_vars(model_vars, print_info=True)

def get_image(image_path, input_height, input_width,
              resize_height=64, resize_width=64,
              crop=True, grayscale=False):
  image = imread(image_path, grayscale)                                             # read the image from image_path
  return transform(image, input_height, input_width,                                # preprocess: center-crop, resize to 64, normalize
                   resize_height, resize_width, crop)

def save_images(images, size, image_path):
  return imsave(inverse_transform(images), size, image_path)

def imread(path, grayscale = False):
  if (grayscale):
    return scipy.misc.imread(path, flatten = True).astype(np.float)
  else:
    # Reference: https://github.com/carpedm20/DCGAN-tensorflow/issues/162#issuecomment-315519747
    img_bgr = cv2.imread(path)
    # Reference: https://stackoverflow.com/a/15074748/
    img_rgb = img_bgr[..., ::-1]
    return img_rgb.astype(np.float)

def merge_images(images, size):
  return inverse_transform(images)

def merge(images, size):
  h, w = images.shape[1], images.shape[2]
  if (images.shape[3] in (3,4)):
    c = images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      img[j * h:j * h + h, i * w:i * w + w, :] = image
    return img
  elif images.shape[3]==1:
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
      i = idx % size[1]
      j = idx // size[1]
      img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
    return img
  else:
    raise ValueError('in merge(images,size) images parameter '
                     'must have dimensions: HxW or HxWx3 or HxWx4')

def imsave(images, size, path):
  image = np.squeeze(merge(images, size))
  return imageio.imwrite(path, image)

def center_crop(x, crop_h, crop_w,
                resize_h=64, resize_w=64):
  if crop_w is None:
    crop_w = crop_h
  h, w = x.shape[:2]
  j = int(round((h - crop_h)/2.))
  i = int(round((w - crop_w)/2.))
  im = Image.fromarray(np.uint8(x[j:j+crop_h, i:i+crop_w]))
  return np.array(im.resize([resize_w, resize_h], Image.BILINEAR))

def transform(image, input_height, input_width, 
              resize_height=64, resize_width=64, crop=True):
  if crop:
    # center-crop to input_height x input_width, then resize
    out = center_crop(image, input_height, input_width,
                      resize_height, resize_width)
  else:
    # no crop: just resize the whole image
    im = Image.fromarray(np.uint8(image))
    out = np.array(im.resize([resize_width, resize_height], Image.BILINEAR))
  # squash the processed image into the [-1, 1] range
  return out/127.5 - 1.

def inverse_transform(images):
  return (images+1.)/2.

def to_json(output_path, *layers):
  with open(output_path, "w") as layer_f:
    lines = ""
    for w, b, bn in layers:
      layer_idx = w.name.split('/')[0].split('h')[1]

      B = b.eval()

      if "lin/" in w.name:
        W = w.eval()
        depth = W.shape[1]
      else:
        W = np.rollaxis(w.eval(), 2, 0)
        depth = W.shape[0]

      biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}
      if bn != None:
        gamma = bn.gamma.eval()
        beta = bn.beta.eval()

        gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}
        beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}
      else:
        gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}
        beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}

      if "lin/" in w.name:
        fs = []
        for w in W.T:
          fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})

        lines += """
          var layer_%s = {
            "layer_type": "fc", 
            "sy": 1, "sx": 1, 
            "out_sx": 1, "out_sy": 1,
            "stride": 1, "pad": 0,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)
      else:
        fs = []
        for w_ in W:
          fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})

        lines += """
          var layer_%s = {
            "layer_type": "deconv", 
            "sy": 5, "sx": 5,
            "out_sx": %s, "out_sy": %s,
            "stride": 2, "pad": 1,
            "out_depth": %s, "in_depth": %s,
            "biases": %s,
            "gamma": %s,
            "beta": %s,
            "filters": %s
          };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),
               W.shape[0], W.shape[3], biases, gamma, beta, fs)
    layer_f.write(" ".join(lines.replace("'","").split()))

def make_gif(images, fname, duration=2, true_image=False):
  import moviepy.editor as mpy

  def make_frame(t):
    try:
      x = images[int(len(images)/duration*t)]
    except:
      x = images[-1]

    if true_image:
      return x.astype(np.uint8)
    else:
      return ((x+1)/2*255).astype(np.uint8)

  clip = mpy.VideoClip(make_frame, duration=duration)
  clip.write_gif(fname, fps = len(images) / duration)

def visualize(sess, dcgan, config, option, sample_dir='samples'):
  image_frame_dim = int(math.ceil(config.batch_size**.5))
  if option == 0:
    z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
    samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
    save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
  elif option == 1:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim))
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx)))
  elif option == 2:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
      print(" [*] %d" % idx)
      z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
      z_sample = np.tile(z, (config.batch_size, 1))
      #z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      if config.dataset == "mnist":
        y = np.random.choice(10, config.batch_size)
        y_one_hot = np.zeros((config.batch_size, 10))
        y_one_hot[np.arange(config.batch_size), y] = 1

        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})
      else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

      try:
        make_gif(samples, './samples/test_gif_%s.gif' % (idx))
      except:
        save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
  elif option == 3:
    values = np.arange(0, 1, 1./config.batch_size)
    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample):
        z[idx] = values[kdx]

      samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
      make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
  elif option == 4:
    image_set = []
    values = np.arange(0, 1, 1./config.batch_size)

    for idx in xrange(dcgan.z_dim):
      print(" [*] %d" % idx)
      z_sample = np.zeros([config.batch_size, dcgan.z_dim])
      for kdx, z in enumerate(z_sample): z[idx] = values[kdx]

      image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
      make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))

    new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
        for idx in list(range(64)) + list(range(63, -1, -1))]
    make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)


def image_manifold_size(num_images):
  manifold_h = int(np.floor(np.sqrt(num_images)))
  manifold_w = int(np.ceil(np.sqrt(num_images)))
  assert manifold_h * manifold_w == num_images
  return manifold_h, manifold_w

This file defines assorted image-processing functions; it effectively acts as a header for the other three files.

step0: First pp = pprint.PrettyPrinter() is defined, for conveniently printing data structures.

step1: The get_stddev lambda computes the reciprocal of the square root of the product of its three arguments (kernel height, kernel width, and input channel count), i.e. a fan-in-scaled standard deviation for weight initialization.

step2: show_all_variables(): tf.trainable_variables returns the list of variables to be trained; model_analyzer.analyze_vars from tensorflow.contrib.slim then prints information about all of them. Usage:

Code:

import tensorflow as tf
import tensorflow.contrib.slim as slim

x1 = tf.Variable(tf.constant(1, shape=[1], dtype=tf.float32), name='x11')
x2 = tf.Variable(tf.constant(2, shape=[1], dtype=tf.float32), name='x22')
m = tf.train.ExponentialMovingAverage(0.99, 5)
v = tf.trainable_variables()
for i in v:
    print(i)

slim.model_analyzer.analyze_vars(v, print_info=True)

The resulting output: [figure]
Note: steps 3 through 11 all define image-processing functions that call one another.

step3: get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False): first reads the image from the given path, optionally as grayscale, then crops and resizes it according to the remaining parameters.

step4: save_images(images, size, image_path): calls imsave(inverse_transform(images), size, image_path), i.e. de-normalizes the images and writes them out.

step5: imread(path, grayscale=False): if grayscale, calls scipy.misc.imread with flatten=True; otherwise reads the image with cv2.imread and reverses the channel order from BGR to RGB. Either way the result is converted to np.float.

step6: merge_images(images, size): calls inverse_transform(images) and returns the result.

step7: merge(images, size): first gets the height and width of the images, then branches on whether they are color or grayscale. If the channel count is 3 or 4, it allocates a zeroed canvas size[0] images tall and size[1] images wide (e.g. a batch of 64 becomes an 8*8 grid) and tiles every image of the batch into it, returning the big mosaic. If the channel count is 1, it does the same with a single channel. Otherwise it raises a ValueError.
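A toy sanity check of that grid layout (assuming the merge defined above is in scope):

import numpy as np

batch = np.random.rand(64, 64, 64, 3)   # a batch of 64 RGB samples, each 64x64
grid = merge(batch, (8, 8))             # tile them into an 8x8 mosaic
print(grid.shape)                       # (512, 512, 3)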

step8: imsave(images, size, path): first removes any length-1 axes from the mosaic returned by merge() with np.squeeze(), then saves the image to the given path with imageio.imwrite().

step9: center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): computes the top-left corner as half the difference between the image size and the crop size, slices out the central crop, and resizes it (bilinearly, via PIL).

step10: transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True): if crop is true, center-crops and resizes via center_crop(); otherwise resizes the whole image to the target size. Either way the result is normalized to [-1, 1] before being returned.

step11: inverse_transform(images): maps images from [-1, 1] back to [0, 1] and returns them.
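A quick check of the normalization round trip in plain NumPy:

import numpy as np

x = np.array([0., 127.5, 255.])
z = x / 127.5 - 1.            # transform's final step: [0, 255] -> [-1, 1]
print(z)                      # [-1.  0.  1.]
print((z + 1.) / 2.)          # inverse_transform: [-1, 1] -> [0.  0.5 1.]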

step12: to_json(output_path, *layers): exports each layer's weights, biases, and batch-norm parameters as JavaScript (apparently for a web demo), but nothing in this pipeline calls it, so we can set it aside.

step13: make_gif(images, fname, duration=2, true_image=False): uses the moviepy.editor module to build an animated GIF for visualization. The inner function make_frame(t) indexes into the image list in proportion to the elapsed time t and returns one frame; the clip is then written out as a GIF.

step14: visualize(sess, dcgan, config, option, sample_dir): supports options 0 through 4. With option 0 it simply samples the generator and saves the images; with option 1 it sweeps one z dimension at a time (with dataset-specific handling for mnist) and saves the samples with save_images(); and so on. main.py uses option 1.

step15: image_manifold_size(num_images): takes the floor of the square root of the image count as h and the ceiling as w, asserts that h*w equals the count, and returns (h, w).
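For example (values worked by hand):

# image_manifold_size(64) -> (8, 8), since floor(sqrt(64)) = ceil(sqrt(64)) = 8.
# For a non-square count such as 72: floor(sqrt(72)) = 8, ceil(sqrt(72)) = 9,
# and 8 * 9 = 72, so (8, 9) passes the assert; a count like 60 would fail it.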

That is all of utils.py: it handles basic image operations, loading and saving images, normalization and its inverse, plus moviepy-based visualization of the training process.

2.3.3 ops.py
import math
import numpy as np 
import tensorflow as tf

from tensorflow.python.framework import ops

from utils import *

try:
  image_summary = tf.image_summary
  scalar_summary = tf.scalar_summary
  histogram_summary = tf.histogram_summary
  merge_summary = tf.merge_summary
  SummaryWriter = tf.train.SummaryWriter
except:
  image_summary = tf.summary.image
  scalar_summary = tf.summary.scalar
  histogram_summary = tf.summary.histogram
  merge_summary = tf.summary.merge
  SummaryWriter = tf.summary.FileWriter

if "concat_v2" in dir(tf):
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat_v2(tensors, axis, *args, **kwargs)
else:
  def concat(tensors, axis, *args, **kwargs):
    return tf.concat(tensors, axis, *args, **kwargs)

class batch_norm(object):
  def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
    with tf.variable_scope(name):
      self.epsilon  = epsilon
      self.momentum = momentum
      self.name = name

  def __call__(self, x, train=True):
    return tf.contrib.layers.batch_norm(x,
                      decay=self.momentum, 
                      updates_collections=None,
                      epsilon=self.epsilon,
                      scale=True,
                      is_training=train,
                      scope=self.name)

def conv_cond_concat(x, y):
  """Concatenate conditioning vector on feature map axis."""
  x_shapes = x.get_shape()
  y_shapes = y.get_shape()
  return concat([
    x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

def conv2d(input_, output_dim, 
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="conv2d"):
  with tf.variable_scope(name):
    w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],        # e.g. h0->h1: w is [5,5,3,64], 3 input channels -> 64 output
              initializer=tf.truncated_normal_initializer(stddev=stddev))
    conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')        # stride = 2

    biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))  # bias shape: [64]
    conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

    return conv

def deconv2d(input_, output_shape,                                                  # e.g. input [4,4,512] -> output [8,8,256]
       k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
       name="deconv2d", with_w=False):
  with tf.variable_scope(name):
    # filter : [height, width, output_channels, in_channels]                        # e.g. h0's filter: [5,5,256,512]
    w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
              initializer=tf.random_normal_initializer(stddev=stddev))
    
    try:
      deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])                                           # stride 2 (d_h=2, d_w=2)

    # Support for verisons of TensorFlow before 0.7.0
    except AttributeError:
      deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                strides=[1, d_h, d_w, 1])

    biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))  # bias matches the output depth, e.g. 256 for h0->h1
    deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

    if with_w:
      return deconv, w, biases
    else:
      return deconv
     
def lrelu(x, leak=0.2, name="lrelu"):
  return tf.maximum(x, leak*x)

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):  # fully connected projection
  shape = input_.get_shape().as_list()

  with tf.variable_scope(scope or "Linear"):                                        # generator use          discriminator use
    try:
      matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,       # e.g. 100x8192 matrix    h4: [8192, 1]
                 tf.random_normal_initializer(stddev=stddev))                       # Gaussian-initialized w
    except ValueError as err:
        msg = "NOTE: Usually, this is due to an issue with the image dimensions.  Did you correctly set '--crop' or '--input_height' or '--output_height'?"
        err.args = err.args + (msg,)
        raise
    bias = tf.get_variable("bias", [output_size],                                   # bias shape: [8192]      b4 shape: [1]
      initializer=tf.constant_initializer(bias_start))                              # constant-initialized b
    if with_w:
      return tf.matmul(input_, matrix) + bias, matrix, bias
    else:
      return tf.matmul(input_, matrix) + bias

This file uses utils.py.

step0: First the tensorflow.python.framework module is imported; it contains TensorFlow's graph and tensor definitions.

step1: Then a try...except block defines a set of names: image_summary, scalar_summary, histogram_summary, merge_summary, and SummaryWriter. They are taken from the legacy top-level TensorFlow names if those still exist, and otherwise from tf.summary (tf.summary.image, ..., tf.summary.FileWriter).

step2: concat joins multiple tensors. dir(tf) is checked for "concat_v2": if present, concat(tensors, axis, *args, **kwargs) is defined to return tf.concat_v2(tensors, axis, *args, **kwargs); otherwise the same function is defined to return tf.concat(tensors, axis, *args, **kwargs). tf.concat is used like this:

t1 = tf.constant([[1, 2, 3], [4, 5, 6]])
t2 = tf.constant([[7, 8, 9], [10, 11, 12]])
t3 = tf.concat([t1, t2], 0)   # shape [4, 3]: stacked along rows
t4 = tf.concat([t1, t2], 1)   # shape [2, 6]: stacked along columns
print(t1)
print(t2)
print(t3)
print(t4)

step3: The batch_norm class has two methods, __init__ and __call__. __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm") opens a variable scope with the given name and stores epsilon, momentum, and name on self. __call__(self, x, train=True) applies batch normalization to x via tf.contrib.layers.batch_norm.

step4: conv_cond_concat(x, y): concatenates x, along the channel axis, with y multiplied by a ones tensor of shape [x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]], i.e. the conditioning vector broadcast over x's spatial dimensions (see the shape sketch below).
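At the shape level, with hypothetical sizes:

# x:  [64, 16, 16, 128]   feature maps
# yb: [64,  1,  1,  10]   one-hot labels reshaped for broadcasting
# y * tf.ones([64, 16, 16, 10]) copies the label vector to every spatial position,
# so the result is [64, 16, 16, 138]: labels appended along the channel axis.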

step5: conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): the convolution helper: it creates a truncated-normal weight tensor, applies tf.nn.conv2d with stride 2 and SAME padding, creates a zero-initialized bias, adds it, and returns the result.

step6: deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="deconv2d", with_w=False): the transposed-convolution (deconvolution) helper: it creates a random-normal filter, applies tf.nn.conv2d_transpose (falling back to tf.nn.deconv2d on TensorFlow versions before 0.7.0), adds a zero-initialized bias, and returns the result, plus the weights and biases if with_w is true.

step7: lrelu(x, leak=0.2, name="lrelu"): the LeakyReLU activation, max(x, leak*x).
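The same function in NumPy, for illustration (lrelu_np is a hypothetical stand-in for the TF op above):

import numpy as np

def lrelu_np(x, leak=0.2):
    return np.maximum(x, leak * x)   # identity for x >= 0, slope `leak` for x < 0

print(lrelu_np(np.array([-2.0, -0.5, 0.0, 1.0])))   # [-0.4 -0.1  0.   1. ]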

step8: linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): a fully connected layer: it creates a random-normal weight matrix and a constant-initialized bias, then returns x·W + b, plus W and b if with_w is true.

In short, this file defines the tensor-concatenation helper, batch normalization, convolution, transposed convolution, the LeakyReLU activation, and the linear layer.

2.3.4 model.py
from __future__ import division
from __future__ import print_function
import os
import time
import math
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

from ops import *
from utils import *

def conv_out_size_same(size, stride):
  return int(math.ceil(float(size) / float(stride)))

def gen_random(mode, size):
    if mode=='normal01': return np.random.normal(0,1,size=size)
    if mode=='uniform_signed': return np.random.uniform(-1,1,size=size)
    if mode=='uniform_unsigned': return np.random.uniform(0,1,size=size)


class DCGAN(object):
  def __init__(self, sess, input_height=108, input_width=108, crop=True,
         batch_size=64, sample_num = 64, output_height=64, output_width=64,
         y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
         gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
         max_to_keep=1,
         input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
    """

    Args:
      sess: TensorFlow session
      batch_size: The size of batch. Should be specified before training.
      y_dim: (optional) Dimension of dim for y. [None]
      z_dim: (optional) Dimension of dim for Z. [100]
      gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
      df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
      gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
      dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
      c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
    """
    self.sess = sess
    self.crop = crop

    self.batch_size = batch_size
    self.sample_num = sample_num        # samples generated during training for monitoring

    self.input_height = input_height
    self.input_width = input_width
    self.output_height = output_height
    self.output_width = output_width

    self.y_dim = y_dim
    self.z_dim = z_dim                  # dimension of the generator's input noise, 100

    self.gf_dim = gf_dim                # base filter count for the generator
    self.df_dim = df_dim                # base filter count for the discriminator

    self.gfc_dim = gfc_dim              # fully connected layer sizes
    self.dfc_dim = dfc_dim

    # batch normalization : deals with poor initialization helps gradient flow
    self.d_bn1 = batch_norm(name='d_bn1')  # DCGAN adds batch norm (BN) to both G and D,
    self.d_bn2 = batch_norm(name='d_bn2')  # between each conv layer and its activation

    if not self.y_dim:
      self.d_bn3 = batch_norm(name='d_bn3')

    self.g_bn0 = batch_norm(name='g_bn0')
    self.g_bn1 = batch_norm(name='g_bn1')
    self.g_bn2 = batch_norm(name='g_bn2')

    if not self.y_dim:
      self.g_bn3 = batch_norm(name='g_bn3')

    self.dataset_name = dataset_name
    self.input_fname_pattern = input_fname_pattern
    self.checkpoint_dir = checkpoint_dir
    self.data_dir = data_dir
    self.out_dir = out_dir
    self.max_to_keep = max_to_keep

    if self.dataset_name == 'mnist':
      self.data_X, self.data_y = self.load_mnist()
      self.c_dim = self.data_X[0].shape[-1]
    else:
      data_path = os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern)  # locate the dataset files
      self.data = glob(data_path)
      if len(self.data) == 0:
        raise Exception("[!] No data found in '" + data_path + "'")
      np.random.shuffle(self.data)
      imreadImg = imread(self.data[0])
      if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number
        self.c_dim = imread(self.data[0]).shape[-1]
      else:
        self.c_dim = 1

      if len(self.data) < self.batch_size:
        raise Exception("[!] Entire dataset size is less than the configured batch_size")
    
    self.grayscale = (self.c_dim == 1)

    self.build_model()

  def build_model(self):
    if self.y_dim:
      self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
    else:
      self.y = None

    if self.crop:
      image_dims = [self.output_height, self.output_width, self.c_dim]
    else:
      image_dims = [self.input_height, self.input_width, self.c_dim]

    self.inputs = tf.placeholder(
      tf.float32, [self.batch_size] + image_dims, name='real_images')

    inputs = self.inputs

    self.z = tf.placeholder(
      tf.float32, [None, self.z_dim], name='z')                                     # the generator's input noise, self.z [?, 100]
    self.z_sum = histogram_summary("z", self.z)

    self.G                  = self.generator(self.z, self.y)                        # G network, fed the noise vector self.z
    self.D, self.D_logits   = self.discriminator(inputs, self.y, reuse=False)       # D network, fed real data
    self.sampler            = self.sampler(self.z, self.y)                          # sampler, for generating test samples
    self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)        # D_ on generated data: same network, reused (shared) weights
    
    self.d_sum = histogram_summary("d", self.D)
    self.d__sum = histogram_summary("d_", self.D_)
    self.G_sum = image_summary("G", self.G)

    def sigmoid_cross_entropy_with_logits(x, y):                                    # sigmoid cross-entropy helper
      try:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
      except:
        return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)

    self.d_loss_real = tf.reduce_mean(                                              # D on real inputs: target labels = 1
      sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))
    self.d_loss_fake = tf.reduce_mean(                                              # D on generated inputs: target labels = 0
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))
    self.g_loss = tf.reduce_mean(                                                   # g_loss: G wants D to output 1 on fakes
      sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))

    self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
    self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)
                          
    self.d_loss = self.d_loss_real + self.d_loss_fake                               # total d_loss

    self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
    self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

    t_vars = tf.trainable_variables()                                               # collect the trainable variables

    self.d_vars = [var for var in t_vars if 'd_' in var.name]
    self.g_vars = [var for var in t_vars if 'g_' in var.name]

    self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)

  def train(self, config):
    # optimize the losses; the original paper uses the Adam optimizer rather than vanilla gradient descent
    d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.d_loss, var_list=self.d_vars)
    g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
              .minimize(self.g_loss, var_list=self.g_vars)
    try:
      tf.global_variables_initializer().run()                                       # initialize global variables
    except:
      tf.initialize_all_variables().run()

    if config.G_img_sum:
      self.g_sum = merge_summary([self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
    else:
      self.g_sum = merge_summary([self.z_sum, self.d__sum, self.d_loss_fake_sum, self.g_loss_sum])
    self.d_sum = merge_summary(
        [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph)

    sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim))       # a fixed batch of noise vectors [64, 100], drawn from [-1, 1] by default
    
    if config.dataset == 'mnist':
      sample_inputs = self.data_X[0:self.sample_num]
      sample_labels = self.data_y[0:self.sample_num]
    else:
      sample_files = self.data[0:self.sample_num]
      sample = [                                                                    # load one batch of sample images
          get_image(sample_file,
                    input_height=self.input_height,
                    input_width=self.input_width,
                    resize_height=self.output_height,
                    resize_width=self.output_width,
                    crop=self.crop,
                    grayscale=self.grayscale) for sample_file in sample_files]
      if (self.grayscale):
        sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
      else:
        sample_inputs = np.array(sample).astype(np.float32)                         # preprocessed images squashed to [-1, 1], size 64x64
  
    counter = 1
    start_time = time.time()
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:                                                                  # check for a previously trained model
      counter = checkpoint_counter                                                  # if found, resume from its counter
      print(" [*] Load SUCCESS")                                            
    else:
      print(" [!] Load failed...")                                                  # otherwise train from scratch

    for epoch in xrange(config.epoch):
      if config.dataset == 'mnist':
        batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
      else:      
        self.data = glob(os.path.join(
          config.data_dir, config.dataset, self.input_fname_pattern))
        np.random.shuffle(self.data)
        batch_idxs = min(len(self.data), config.train_size) // config.batch_size

      for idx in xrange(0, int(batch_idxs)):
        if config.dataset == 'mnist':
          batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size]
          batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size]
        else:
          batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size]  # slice out one batch, e.g. images 0..63 for the first
          batch = [
              get_image(batch_file,
                        input_height=self.input_height,
                        input_width=self.input_width,
                        resize_height=self.output_height,
                        resize_width=self.output_width,
                        crop=self.crop,
                        grayscale=self.grayscale) for batch_file in batch_files]
          if self.grayscale:
            batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
          else:
            batch_images = np.array(batch).astype(np.float32)

        batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \
              .astype(np.float32)

        if config.dataset == 'mnist':
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ 
              self.inputs: batch_images,
              self.z: batch_z,
              self.y:batch_labels,
            })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={
              self.z: batch_z, 
              self.y:batch_labels,
            })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z, self.y:batch_labels })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({
              self.z: batch_z, 
              self.y:batch_labels
          })
          errD_real = self.d_loss_real.eval({
              self.inputs: batch_images,
              self.y:batch_labels
          })
          errG = self.g_loss.eval({
              self.z: batch_z,
              self.y: batch_labels
          })
        else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],                     # run the optimizer, feeding the placeholders
            feed_dict={ self.inputs: batch_images, self.z: batch_z })               # (inputs and z were defined as placeholders)
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

        print("[%8d Epoch:[%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
          % (counter, epoch, config.epoch, idx, batch_idxs,
            time.time() - start_time, errD_fake+errD_real, errG))

        if np.mod(counter, config.sample_freq) == 0:                                # every sample_freq (200) iterations, save sample images
          if config.dataset == 'mnist':
            samples, d_loss, g_loss = self.sess.run(
              [self.sampler, self.d_loss, self.g_loss],
              feed_dict={
                  self.z: sample_z,
                  self.inputs: sample_inputs,
                  self.y:sample_labels,
              }
            )
            save_images(samples, image_manifold_size(samples.shape[0]),
                  './{}/train_{:08d}.png'.format(config.sample_dir, counter))
            print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
          else:
            try:
              samples, d_loss, g_loss = self.sess.run(
                [self.sampler, self.d_loss, self.g_loss],
                feed_dict={
                    self.z: sample_z,
                    self.inputs: sample_inputs,
                },
              )
              save_images(samples, image_manifold_size(samples.shape[0]),
                    './{}/train_{:08d}.png'.format(config.sample_dir, counter))
              print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 
            except:
              print("one pic error!...")

        if np.mod(counter, config.ckpt_freq) == 0:
          self.save(config.checkpoint_dir, counter)
        
        counter += 1
        
  def discriminator(self, image, y=None, reuse=False):                              # discriminator: strided forward convolutions
    with tf.variable_scope("discriminator") as scope:
      if reuse:
        scope.reuse_variables()

      if not self.y_dim:                                                            # input image: [64, 64, 3]
        h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
        h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
        h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')

        return tf.nn.sigmoid(h4), h4                                                # the final linear layer is squashed to [0, 1] by a sigmoid
      else:
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        x = conv_cond_concat(image, yb)

        h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))
        h0 = conv_cond_concat(h0, yb)

        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
        h1 = tf.reshape(h1, [self.batch_size, -1])      
        h1 = concat([h1, y], 1)
        
        h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
        h2 = concat([h2, y], 1)

        h3 = linear(h2, 1, 'd_h3_lin')
        
        return tf.nn.sigmoid(h3), h3

  def generator(self, z, y=None):                                                   # generator: transposed convolutions
    with tf.variable_scope("generator") as scope:
      if not self.y_dim:                                                            # compute the deconv feature-map sizes
        s_h, s_w = self.output_height, self.output_width                            # 64, 64
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)         # 32, 32
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)       # 16, 16
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)       # 8, 8
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)     # 4, 4

        # project `z` and reshape
        # first run z through a fully connected layer to get a high-dimensional vector
        self.z_, self.h0_w, self.h0_b = linear(                                     # self.z_ [?, 8192]
            z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)                  # input z, output 64*8*4*4 = 8192 dims

        self.h0 = tf.reshape(
            self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])                           # reshape self.z_ to [-1, 4, 4, 512]
        h0 = tf.nn.relu(self.g_bn0(self.h0))                                        # BN then ReLU before deconvolving h0

        self.h1, self.h1_w, self.h1_b = deconv2d(                                   # h0->h1 deconv: 4x4x512 -> 8x8x256
            h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
        h1 = tf.nn.relu(self.g_bn1(self.h1))

        h2, self.h2_w, self.h2_b = deconv2d(                                        # h2 [16, 16, 128]
            h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
        h2 = tf.nn.relu(self.g_bn2(h2))

        h3, self.h3_w, self.h3_b = deconv2d(                                        # h3 [32, 32, 64]
            h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
        h3 = tf.nn.relu(self.g_bn3(h3))

        h4, self.h4_w, self.h4_b = deconv2d(                                        # h4 [64, 64, 3], batch size 64
            h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

        return tf.nn.tanh(h4)
      else:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(
            self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
            [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(
            deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def sampler(self, z, y=None):
    with tf.variable_scope("generator") as scope:
      scope.reuse_variables()

      if not self.y_dim:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

        # project `z` and reshape
        h0 = tf.reshape(
            linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
            [-1, s_h16, s_w16, self.gf_dim * 8])
        h0 = tf.nn.relu(self.g_bn0(h0, train=False))

        h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
        h1 = tf.nn.relu(self.g_bn1(h1, train=False))

        h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
        h2 = tf.nn.relu(self.g_bn2(h2, train=False))

        h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
        h3 = tf.nn.relu(self.g_bn3(h3, train=False))

        h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

        return tf.nn.tanh(h4)
      else:
        s_h, s_w = self.output_height, self.output_width
        s_h2, s_h4 = int(s_h/2), int(s_h/4)
        s_w2, s_w4 = int(s_w/2), int(s_w/4)

        # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
        yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
        z = concat([z, y], 1)

        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False))
        h0 = concat([h0, y], 1)

        h1 = tf.nn.relu(self.g_bn1(
            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
        h1 = conv_cond_concat(h1, yb)

        h2 = tf.nn.relu(self.g_bn2(
            deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
        h2 = conv_cond_concat(h2, yb)

        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

  def load_mnist(self):
    data_dir = os.path.join(self.data_dir, self.dataset_name)
    
    fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)

    fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
    loaded = np.fromfile(file=fd,dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(np.float)

    trY = np.asarray(trY)
    teY = np.asarray(teY)
    
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)
    
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)
    
    y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)
    for i, label in enumerate(y):
      y_vec[i,y[i]] = 1.0
    
    return X/255.,y_vec

  @property
  def model_dir(self):
    return "{}_{}_{}_{}".format(
        self.dataset_name, self.batch_size,
        self.output_height, self.output_width)

  def save(self, checkpoint_dir, step, filename='model', ckpt=True, frozen=False):
    # model_name = "DCGAN.model"
    # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    filename += '.b' + str(self.batch_size)
    if not os.path.exists(checkpoint_dir):
      os.makedirs(checkpoint_dir)

    if ckpt:
      self.saver.save(self.sess,
              os.path.join(checkpoint_dir, filename),
              global_step=step)

    if frozen:
      tf.train.write_graph(
              tf.graph_util.convert_variables_to_constants(self.sess, self.sess.graph_def, ["generator_1/Tanh"]),
              checkpoint_dir,
              '{}-{:06d}_frz.pb'.format(filename, step),
              as_text=False)

  def load(self, checkpoint_dir):
    #import re
    print(" [*] Reading checkpoints...", checkpoint_dir)
    # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
    # print("     ->", checkpoint_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
      #counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
      counter = int(ckpt_name.split('-')[-1])
      print(" [*] Success to read {}".format(ckpt_name))
      return True, counter
    else:
      print(" [*] Failed to find a checkpoint")
      return False, 0

This file defines the DCGAN model itself; it uses utils.py and ops.py.

step0: conv_out_size_same(size, stride) returns ceil(size / stride), the spatial size a stride-`stride` layer produces; the generator uses it to derive its deconvolution output sizes.
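Applied repeatedly with stride 2 starting from the 64x64 output, it produces the ladder of sizes the generator uses:

import math

def conv_out_size_same(size, stride):
    return int(math.ceil(float(size) / float(stride)))

s = 64
for _ in range(4):
    s = conv_out_size_same(s, 2)
    print(s)   # 32, 16, 8, 4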

step1: Then the DCGAN class is defined; all of the remaining code is in this class, so the following steps all take place inside it.

step2: The initializer __init__ mainly sets up the default parameters: the session, crop, the batch size batch_size, the sample count sample_num, the input and output heights and widths, the various dimensions, the batch-norm layers for the generator and discriminator, the dataset name, and the grayscale flag, and finally calls the model-building function. Note that it checks whether the dataset name is mnist: if so, the data is loaded directly with load_mnist(); otherwise the file list is read from the local data folder, and single-channel images are treated as grayscale.

step3: The model-construction function build_model(self) (the loss formulas follow this list).

  1. First check y_dim and, if set, define and initialize y with a tf.placeholder.
  2. Check crop: if true (the testing path), the image dimensions are the output dimensions; otherwise they are the input dimensions.
  3. Define inputs with tf.placeholder: the batch of real images.
  4. Define and initialize the generator's noise z and its summary z_sum.
  5. Check y_dim again: if set, initialize the generator G from the noise z and labels y, the discriminator D and D_logits from inputs, the sampler, and D_ and D_logits_ from G and y; if not set, initialize the same variables, just without the labels y.
  6. Record D, D_, and G from point 5 in the summaries d_sum, d__sum, and G_sum.
  7. Define the sigmoid cross-entropy helper sigmoid_cross_entropy_with_logits(x, y). Both branches call tf.nn.sigmoid_cross_entropy_with_logits; the except branch merely falls back to the older targets= keyword for old TensorFlow versions.
  8. Define the losses: the real-data discriminator loss d_loss_real, the fake-data discriminator loss d_loss_fake, the generator loss g_loss, and the total discriminator loss d_loss.
  9. Collect all trainable variables t_vars.
  10. Split them into the generator and discriminator parameter sets.
  11. Finally, create the saver.
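Written out, the losses in point 8 are the standard non-saturating GAN objectives:

L_D = -E_x[log D(x)] - E_z[log(1 - D(G(z)))],    L_G = -E_z[log D(G(z))]

d_loss_real and d_loss_fake are the first and second terms of L_D, implemented as sigmoid cross-entropy against all-ones and all-zeros targets respectively.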

step4: The training function train(self, config).

  1. Define the discriminator optimizer d_optim and the generator optimizer g_optim (both Adam).
  2. Initialize the variables.
  3. Merge the generator-related and discriminator-related summaries into one op each, and create the event-file writer.
  4. Initialize the fixed noise sample_z.
  5. Depending on whether the dataset is mnist, fetch the sample inputs (and labels); this uses the get_image function from utils.py.
  6. Define the counter counter and the start time start_time.
  7. Load a checkpoint, and report whether loading succeeded.
  8. Start the for epoch in xrange(config.epoch) training loop; depending on whether the dataset is mnist, compute the number of batches per epoch.
  9. Start the inner for idx in xrange(0, batch_idxs) loop; depending on the dataset, load and initialize the batch images and labels.
  10. Draw the batch noise batch_z.
  11. Depending on the dataset, update the discriminator and generator networks. Setting the mnist-specific handling aside: the generator optimizer is run twice per step so that d_loss does not collapse to zero, and then the real-data and fake-data discriminator losses and the generator loss are evaluated.
  12. Print the training stats for the batch: the global counter, the epoch, the batch index, the elapsed time, the discriminator loss, and the generator loss.
  13. Every sample_freq (default 200) iterations, fetch samples and the discriminator and generator losses (with dataset-specific feeds), save the samples with save_images from utils.py, naming the file with the counter, and print the two losses.
  14. Every ckpt_freq (default 200) iterations, save a checkpoint.

step5: The discriminator function discriminator(self, image, y=None, reuse=False) (a shape trace follows this list).

  1. with tf.variable_scope("discriminator") as scope shares variables within one scope.
  2. If reuse is set, scope.reuse_variables() reuses the existing variables, so the real and generated batches go through the same weights.
  3. If y_dim is not set, there are five layers: four strided convolutions with the lrelu activation (batch-normalized except the first), followed by a linear layer; the function returns sigmoid(h4) and h4.
  4. If y_dim is set, y is first reshaped to yb, then conv_cond_concat from ops.py joins image with yb into x; four layers follow, three lrelu-activated conv/linear layers each concatenated with the labels, and a final linear layer; the function returns sigmoid(h3) and h3.
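A shape trace of the unconditional branch (point 3), assuming batch_size=64, df_dim=64, and 64x64 RGB inputs:

# image: [64, 64, 64, 3]
# h0:    [64, 32, 32, 64]    conv2d stride 2 + lrelu (no BN on the first layer)
# h1:    [64, 16, 16, 128]   conv2d + BN + lrelu
# h2:    [64,  8,  8, 256]   conv2d + BN + lrelu
# h3:    [64,  4,  4, 512]   conv2d + BN + lrelu
# h4:    [64, 1]             linear on flattened h3, then sigmoid -> [0, 1]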

step6: The generator function generator(self, z, y=None).

  1. with tf.variable_scope("generator") as scope shares variables within one scope.
  2. Branch on y_dim to build the network.
  3. If y_dim is not set: first get the output width and height, then repeatedly halve them to get the intermediate feature-map sizes. For h0, project the noise z through linear (keeping the weights and biases), reshape, and apply batch norm and relu. For h1, deconvolve h0 (again keeping the weights and biases) and apply batch norm and relu; h2 and h3 are analogous. For h4, deconvolve h3 and return tanh(h4). (A shape trace follows this list.)
  4. If y_dim is set: likewise compute the output sizes, then build yb and concatenate the labels into z. h0 is a relu-activated linear layer concatenated with y; h1 is a relu-activated linear layer, reshaped and concatenated with yb; h2 is a relu-activated deconvolution concatenated with yb; the output is the sigmoid of a final deconvolution.
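The mirror-image shape trace for the unconditional generator (batch_size=64, gf_dim=64, z_dim=100):

# z:       [64, 100]
# linear:  [64, 8192]          8192 = 4 * 4 * 512
# reshape: [64,  4,  4, 512]   BN + relu
# h1:      [64,  8,  8, 256]   deconv2d + BN + relu
# h2:      [64, 16, 16, 128]   deconv2d + BN + relu
# h3:      [64, 32, 32, 64]    deconv2d + BN + relu
# h4:      [64, 64, 64, 3]     deconv2d, then tanh -> [-1, 1]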

step7: The sampler(self, z, y=None) function.

  1. tf.variable_scope("generator") as scope shares variables within one scope.
  2. scope.reuse_variables() reuses them, so the sampler shares the generator's weights.
  3. Branch on y_dim to build the network.
  4. The rest mirrors the generator, except that batch norm runs with train=False; no need to repeat it here.

step8: load_mnist(self) is specific to the mnist dataset, so we will not dwell on it.

step9: The model_dir property returns a string built from the dataset name, the batch size, and the output height and width.

step10: save(self, checkpoint_dir, step, ...) saves the trained model: it creates the checkpoint directory if it does not exist, then saves the session under it (it can also export a frozen graph).

step11: load(self, checkpoint_dir) reads the checkpoint state, restores the session from the latest checkpoint, and recovers the step counter from the checkpoint file name, printing a success message; if no checkpoint is found, it prints a failure message and returns False with a counter of 0.

That is all of model.py: it defines the DCGAN class, i.e. the complete generator and discriminator implementation.

3. Results

Run:

python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces_ --input_fname_pattern "*.jpg" --crop --train --epoch 100

Epoch 1: [figure]
Epoch 5: [figure]
Epoch 10: [figure]
Epoch 20: [figure]
Epoch 200: [figure]
