前言
这篇博客参考自:GAN学习指南:从原理入门到制作生成Demo
前面曾经写过一篇:GAN入门介绍
这里再提供一个视频(文末):干货 | 直观理解GAN背后的原理:以人脸图像生成为例
GAN的原理很简单,但是它有很多变体,如:DCGAN、CycleGAN、DeblurGAN等,它们也被用在不同地方,本文将用到DCGAN来生成动漫头像,可以做到以假乱真的地步。
补充项目资源github地址
原理
- 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
- D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。
- G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G
那么如何将图像处理与GAN结合呢?我们可以将CNN(卷积神经网络)与GAN结合,这里是论文地址Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
DCGAN中的G网络示意:
代码实现
参考GAN学习指南:从原理入门到制作生成Demo
爬了动漫图库网站:konachan.net - Konachan.com Anime Wallpapers。
- 原始数据集的搜集
爬虫代码如下:
#采用request+beautiful库爬取
import requests
from bs4 import BeautifulSoup
import os
import traceback#python异常模块
def download(url, filename):#判断文件是否存在,存在则退出本次循环
if os.path.exists(filename):
print('file exists!')
return
try:
r = requests.get(url, stream=True, timeout=60)#以流数据形式请求,你可获取来自服务器的原始套接字响应
r.raise_for_status()
with open(filename, 'wb') as f:#将文本流保存到文件
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
f.flush()
return filename
except KeyboardInterrupt:
if os.path.exists(filename):
os.remove(filename)
raise KeyboardInterrupt
except Exception:
traceback.print_exc()#把返回信息输出到控制台
if os.path.exists(filename):
os.remove(filename)
if os.path.exists('imgs') is False:
os.makedirs('imgs')
start = 1
end = 8000
for i in range(start, end + 1):
url = 'http://konachan.net/post?page=%d&tags=' % i#需要爬取的url
html = requests.get(url).text#获取的html页面内容
soup = BeautifulSoup(html, 'html.parser')
for img in soup.find_all('img', class_="preview"):
target_url = 'http:' + img['src']
filename = os.path.join('imgs', target_url.split('/')[-1])
download(target_url, filename)
print('%d / %d' % (i, end))
最后,经过大概半天,我爬取大概500多M图片
由于这些图片比较复杂,对于网络难以训练,我们需要截取出动漫人物的头像,通过opencv工具,github上面已经有这个项目应用,
nagadomi/lbpcascade_animeface
import cv2#需提前在你的python环境下安装opencv包
import sys
import os.path
from glob import glob
def detect(filename, cascade_file="lbpcascade_animeface.xml"):
if not os.path.isfile(cascade_file):#lbpcascade_animeface.xml文件可在github上面找到,就是一个巨长的xml格式代码,表示看不懂。
raise RuntimeError("%s: not found" % cascade_file)
cascade = cv2.CascadeClassifier(cascade_file)
image = cv2.imread(filename)
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
gray = cv2.equalizeHist(gray)
faces = cascade.detectMultiScale(gray,
# detector options
scaleFactor=1.1,
minNeighbors=5,
minSize=(48, 48))
for i, (x, y, w, h) in enumerate(faces):
face = image[y: y + h, x:x + w, :]
face = cv2.resize(face, (96, 96))
save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
cv2.imwrite("faces/" + save_filename, face)#写入文件
if __name__ == '__main__':
if os.path.exists('faces') is False:
os.makedirs('faces')
file_list = glob('imgs/*.jpg')
for filename in file_list:
detect(filename)
上面主要是detectMultiScale函数难理解
#C#版本
void detectMultiScale(
const Mat& image, #image--待检测图片,一般为灰度图像加快检测速度
CV_OUT vector<Rect>& objects, #objects--被检测物体的矩形框向量组
double scaleFactor = 1.1, #scaleFactor--表示在前后两次相继的扫描中,搜索窗口的比例系数。默认为1.1即每次搜索窗口依次扩大10%;
int minNeighbors = 3, #minNeighbors--表示构成检测目标的相邻矩形的最小个数(默认为3个)。
如果组成检测目标的小矩形的个数和小于 min_neighbors - 1 都会被排除。
如果min_neighbors 为 0, 则函数不做任何操作就返回所有的被检候选矩形框,
这种设定值一般用在用户自定义对检测结果的组合程序上;
int flags = 0, #flags--要么使用默认值,要么使用CV_HAAR_DO_CANNY_PRUNING,如果设置为
CV_HAAR_DO_CANNY_PRUNING,那么函数将会使用Canny边缘检测来排除边缘过多或过少的区域,
因此这些区域通常不会是人脸所在区域;
Size minSize = Size(), #minSize和maxSize用来限制得到的目标区域的范围。
Size maxSize = Size()
);
截取后的人物数据:
500多m的图片,最后只剩60多m。
- 训练图像
代码只能引用DCGAN的github代码:carpedm20/DCGAN-tensorflow
###mian.py代码
import os
import scipy.misc #
import numpy as np
from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables
import tensorflow as tf
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")#迭代次数
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")#学习速率,默认是0.002
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")#训练数据大小
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")#每次迭代的图像数量
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")#需要指定输入图像的高
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")#需要指定输入图像的宽
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")#需要指定处理哪个数据集
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")#输入的文件格式
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")#储存训练样例
目录
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS
def main(_):
pp.pprint(flags.FLAGS.__flags)
if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
#控制GPU资源使用率
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
run_config = tf.ConfigProto()
run_config.gpu_options.allow_growth=True
with tf.Session(config=run_config) as sess:
if FLAGS.dataset == 'mnist':
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
else:
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.b