利用GAN生成动漫头像

该博客介绍了如何利用DCGAN结合卷积神经网络生成逼真的动漫头像。作者参考了多个资源,包括论文和教程,并提供了代码实现,包括数据爬取、头像截取以及训练过程。通过opencv爬取和预处理动漫图片,使用Tensorflow实现DCGAN模型,并展示了随着训练迭代,生成的头像质量逐步提升的过程。
摘要由CSDN通过智能技术生成

前言

这篇博客参考自:GAN学习指南:从原理入门到制作生成Demo
前面曾经写过一篇:GAN入门介绍
这里再提供一个视频(文末):干货 | 直观理解GAN背后的原理:以人脸图像生成为例
GAN的原理很简单,但是它有很多变体,如:DCGAN、CycleGAN、DeblurGAN等,它们也被用在不同地方,本文将用到DCGAN来生成动漫头像,可以做到以假乱真的地步。
补充项目资源github地址

原理

这里写图片描述

  • 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
  • D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。
  • G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G

那么如何将图像处理与GAN结合呢?我们可以将CNN(卷积神经网络)与GAN结合,这里是论文地址Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
DCGAN中的G网络示意:

代码实现

参考GAN学习指南:从原理入门到制作生成Demo
爬了动漫图库网站:konachan.net - Konachan.com Anime Wallpapers

  • 原始数据集的搜集

爬虫代码如下:

#采用request+beautiful库爬取
import requests
from bs4 import BeautifulSoup
import os
import traceback#python异常模块

def download(url, filename):#判断文件是否存在,存在则退出本次循环
    if os.path.exists(filename):
        print('file exists!')
        return
    try:
        r = requests.get(url, stream=True, timeout=60)#以流数据形式请求,你可获取来自服务器的原始套接字响应
        r.raise_for_status()
        with open(filename, 'wb') as f:#将文本流保存到文件
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    f.flush()
        return filename
    except KeyboardInterrupt:
        if os.path.exists(filename):
            os.remove(filename)
        raise KeyboardInterrupt
    except Exception:
        traceback.print_exc()#把返回信息输出到控制台
        if os.path.exists(filename):
            os.remove(filename)


if os.path.exists('imgs') is False:
    os.makedirs('imgs')

start = 1
end = 8000
for i in range(start, end + 1):
    url = 'http://konachan.net/post?page=%d&tags=' % i#需要爬取的url
    html = requests.get(url).text#获取的html页面内容
    soup = BeautifulSoup(html, 'html.parser')
    for img in soup.find_all('img', class_="preview"):
        target_url = 'http:' + img['src']
        filename = os.path.join('imgs', target_url.split('/')[-1])
        download(target_url, filename)
    print('%d / %d' % (i, end))

最后,经过大概半天,我爬取大概500多M图片
这里写图片描述

由于这些图片比较复杂,对于网络难以训练,我们需要截取出动漫人物的头像,通过opencv工具,github上面已经有这个项目应用,
nagadomi/lbpcascade_animeface

import cv2#需提前在你的python环境下安装opencv包
import sys
import os.path
from glob import glob

def detect(filename, cascade_file="lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):#lbpcascade_animeface.xml文件可在github上面找到,就是一个巨长的xml格式代码,表示看不懂。
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(gray,
                                     # detector options
                                     scaleFactor=1.1,
                                     minNeighbors=5,
                                     minSize=(48, 48))
    for i, (x, y, w, h) in enumerate(faces):
        face = image[y: y + h, x:x + w, :]
        face = cv2.resize(face, (96, 96))
        save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
        cv2.imwrite("faces/" + save_filename, face)#写入文件


if __name__ == '__main__':
    if os.path.exists('faces') is False:
        os.makedirs('faces')
    file_list = glob('imgs/*.jpg')
    for filename in file_list:
        detect(filename)

上面主要是detectMultiScale函数难理解

#C#版本
    void detectMultiScale(  
        const Mat& image,  #image--待检测图片,一般为灰度图像加快检测速度
        CV_OUT vector<Rect>& objects,  #objects--被检测物体的矩形框向量组
        double scaleFactor = 1.1,  #scaleFactor--表示在前后两次相继的扫描中,搜索窗口的比例系数。默认为1.1即每次搜索窗口依次扩大10%;
        int minNeighbors = 3,   #minNeighbors--表示构成检测目标的相邻矩形的最小个数(默认为3个)。
        如果组成检测目标的小矩形的个数和小于 min_neighbors - 1 都会被排除。
        如果min_neighbors 为 0, 则函数不做任何操作就返回所有的被检候选矩形框,
        这种设定值一般用在用户自定义对检测结果的组合程序上;
        int flags = 0,  #flags--要么使用默认值,要么使用CV_HAAR_DO_CANNY_PRUNING,如果设置为

        CV_HAAR_DO_CANNY_PRUNING,那么函数将会使用Canny边缘检测来排除边缘过多或过少的区域,

        因此这些区域通常不会是人脸所在区域;
        Size minSize = Size(),  #minSize和maxSize用来限制得到的目标区域的范围。
        Size maxSize = Size()  
    );  

截取后的人物数据:
这里写图片描述
500多m的图片,最后只剩60多m。

###mian.py代码

import os
import scipy.misc # 
import numpy as np

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")#迭代次数
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")#学习速率,默认是0.002
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")#训练数据大小
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")#每次迭代的图像数量
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")#需要指定输入图像的高
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")#需要指定输入图像的宽
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")#需要指定处理哪个数据集
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")#输入的文件格式
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")#储存训练样例
目录
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
#控制GPU资源使用率
  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.b
  • 22
    点赞
  • 214
    收藏
    觉得还不错? 一键收藏
  • 45
    评论
评论 45
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值