简介
在人脸数据上训练DCGAN,并生成一些人脸图片
数据
使用两个数据集
- LFW:http://vis-www.cs.umass.edu/lfw/,Labeled Faces in the Wild,超过1.3万张图片,其中1680人拥有两张或以上图片
- CelebA:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html,CelebFaces Attributes Dataset,包括10177人共计超过20万张图片,并且每张图片还包括人脸的5个关键点位置和40个属性的0/1标注,例如是否有眼镜、帽子、胡子等
实现
和上节课的代码差不多,根据彩色图片进行适当调整即可
加载库
# -*- coding: utf-8 -*-
# Imports for the DCGAN face-generation notebook (TensorFlow 1.x API is used below).
import tensorflow as tf
import numpy as np
# NOTE(review): code below calls urllib.request.urlretrieve; bare `import urllib`
# does not guarantee the `request` submodule is loaded — prefer `import urllib.request`.
import urllib
import tarfile
import os
import matplotlib.pyplot as plt
# Jupyter/IPython magic: renders matplotlib figures inline; not valid in a plain .py script.
%matplotlib inline
from imageio import imread, imsave, mimsave
# NOTE(review): scipy.misc.imresize was removed in SciPy 1.3 — this import fails on
# modern SciPy; PIL.Image.resize or skimage.transform.resize are the usual replacements.
from scipy.misc import imresize
import glob
下载LFW数据并解压处理,CelebA数据已经准备好
# Download the LFW archive (if not already present), extract it, and flatten the
# per-person sub-directories into a single folder of sequentially numbered PNGs,
# which the training loader can then read with a simple glob. CelebA is assumed
# to be prepared separately. All work is skipped when `new_dir` already exists.
import urllib.request  # explicit submodule import; `import urllib` alone does not provide it

url = 'http://vis-www.cs.umass.edu/lfw/lfw.tgz'
filename = 'lfw.tgz'
directory = 'lfw_imgs'
new_dir = 'lfw_new_imgs'

if not os.path.isdir(new_dir):
    os.mkdir(new_dir)

    if not os.path.isdir(directory):
        if not os.path.isfile(filename):
            urllib.request.urlretrieve(url, filename)
        # Context manager guarantees the archive handle is closed even if
        # extraction raises. NOTE(review): extractall on a downloaded archive is
        # vulnerable to path traversal ("tar-slip") — acceptable for this known
        # dataset, but do not reuse with untrusted URLs.
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall(path=directory)

    # LFW stores one sub-directory per person; re-save every image under a
    # flat sequential name (0.png, 1.png, ...).
    count = 0
    for dir_, _, files in os.walk(directory):
        for file_ in files:
            img = imread(os.path.join(dir_, file_))
            imsave(os.path.join(new_dir, '%d.png' % count), img)
            count += 1
设定用于生成人脸的数据集
# Select the training dataset (uncomment the LFW line to switch).
# dataset = 'lfw_new_imgs' # LFW
dataset = 'celeba' # CelebA

# Gather every image file directly under the dataset directory and report the count.
pattern = os.path.join(dataset, '*.*')
images = glob.glob(pattern)
print(len(images))
定义一些常量、网络输入、辅助函数
# Training hyper-parameters and fixed image geometry.
batch_size = 100  # images per training step
z_dim = 100       # dimensionality of the generator's latent noise vector
WIDTH = 64        # training / generated image width in pixels
HEIGHT = 64       # training / generated image height in pixels
# Generated sample images are written to a dataset-specific folder.
OUTPUT_DIR = 'samples_' + dataset
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)
# TensorFlow 1.x graph inputs: real RGB image batch, latent noise batch, and a
# boolean flag (presumably toggling batch-norm train/inference mode — confirm
# against the network definitions below).
X = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, 3], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')
def lrelu(x, leak=0.2):
    """Leaky ReLU: pass positive values through, scale negatives by `leak`."""
    # max(leak*x, x) equals x for x >= 0 and leak*x for x < 0 when 0 < leak < 1;
    # tf.maximum is commutative, so operand order is irrelevant.
    scaled = leak * x
    return tf.maximum(scaled, x)
def sigmoid_cross_entropy_with_logits(x, y):
    """Positional-argument shorthand for tf.nn.sigmoid_cross_entropy_with_logits.

    `x` holds the raw logits and `y` the target labels of the same shape.
    """
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
    return loss
判别器部分
def discriminator(image, reuse=None, is_training=is_training