[GAN Learning] Generating Anime Character Avatars

Generative AI, and the newer buzzword AIGC, have been getting a lot of attention lately. Diffusion models have also been popular for a while, but they are hard to train on an ordinary computer, so here I use a GAN to generate anime character avatars. It is simple and easy to follow.

Dataset Preparation

What we need is anime face data. Ready-made datasets exist; for example, the reference repo provides a cloud-drive link. You can also build your own, which is what I do here, since the point is to walk through the whole pipeline.

Finding a pile of high-resolution images containing character faces is not difficult, but the faces still have to be extracted, which requires face detection: you can crop out the faces with a detector that someone else built with OpenCV or with a deep learning model.

To crawl the image data you can use Bionus/imgbrd-grabber: Very customizable imageboard/booru downloader with powerful filenaming features. (github.com); I used it before when training LoRA for diffusion models (LoRA Training Guide (rentry.co)) and it is very pleasant and simple. Alternatively, use mikf/gallery-dl: Command-line program to download image galleries and collections from several image hosting sites (github.com), which installs directly with pip:

python3 -m pip install -U gallery-dl

Then run the download command:

gallery-dl --range 1:1000 "https://danbooru.donmai.us/posts?tags=misaka_mikoto" 

This may still download videos, so you can add --filter "file_ext == 'png' or file_ext == 'jpg'" to keep only images. For example:

gallery-dl --range 1001:5000 --filter "file_ext == 'png' or file_ext == 'jpg'" "https://danbooru.donmai.us/posts?tags=misaka_mikoto"

This downloads a range of images from the given URL. Here there are about 1000 images, but a dataset like MNIST has 60,000 training images, so it is better to collect more. Also, gallery-dl is an excellent crawler for anime images and should come in handy again later.

(screenshot: images downloaded with gallery-dl)

Then use a detector. Python has a dedicated library, nya3jp/python-animeface: A library to detect anime faces in images. (github.com), or you can use someone else's pre-trained detector (an OpenCV cascade).

If you use the latter and hit an error about needing to rebuild the library, install the package below. The former does not seem to work well on Windows.

pip install opencv-contrib-python

import cv2
import sys
import os.path


def detect(filename, cascade_file="./lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename, cv2.IMREAD_COLOR)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(
        gray,
        # detector options
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(24, 24),
    )
    for (x, y, w, h) in faces:
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), 2)

    cv2.imshow("AnimeFaceDetect", image)
    cv2.waitKey(0)
    cv2.imwrite("out.png", image)


if __name__ == "__main__":
    detect(
        "../gallery-dl/danbooru/misaka_mikoto/danbooru_4831620_03359e23330ae19467b0b772b62cd89b.jpg"
    )

The result is shown below. The faces returned by the code above are the four values (x, y, w, h) used to draw the bounding box, so you can crop directly with those coordinates.

(image: anime face detection result with bounding box)

That is, something like the following; note that the image array is laid out as (H, W, C).

for (x, y, w, h) in faces:
    # cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), 2)
    image = image[y : y + h, x : x + w]

It may also happen that no face is detected (or the downloaded image itself is broken). If the detected region is degenerate or falls outside the image, skip it. The script then becomes:

import cv2
import os
import os.path
from tqdm import tqdm


def detect(filename, cascade_file="./lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)
    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename, cv2.IMREAD_COLOR)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)
    faces = cascade.detectMultiScale(
        gray,
        # detector options
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(24, 24),
    )
    if len(faces) > 0:
        # keep only the first detection: cropping inside a loop would apply
        # later boxes to an already-cropped image
        x, y, w, h = faces[0]
        # shift the box up by 10% of its height (clamped to the image) to keep a bit more hair
        top = max(int(y - 0.1 * h), 0)
        image = image[top : top + h, x : x + w]
        height, width, _ = image.shape
        if height > 0 and width > 0:
            image = cv2.resize(image, (64, 64))
            dir_path = "../assets/misaka_mikoto"
            os.makedirs(dir_path, exist_ok=True)
            cv2.imwrite(
                dir_path
                + "/crop_"
                + filename.split("_")[-2]
                + "."
                + filename.split(".")[-1],
                image,
            )
        else:
            return
    else:
        return


if __name__ == "__main__":
    # walk the download directory
    for root, dirs, files in os.walk("../gallery-dl/danbooru/misaka_mikoto"):
        pbar = tqdm(files)
        for idx, file in enumerate(pbar):
            pbar.set_postfix(index=idx, file=file)
            if not file.startswith("crop") and not file.endswith(".part"):
                detect(os.path.join(root, file))

With that, the data is ready. You can also refer to Anime-Face-Dataset/src at master · bchao1/Anime-Face-Dataset (github.com).

There are still some issues. For example, the images are fetched here by searching a booru tag, but images with that tag may also contain other characters, so faces of other characters can end up in the crops.

So later on you could use a dedicated, higher-precision detector trained to extract one specific character. Also, when batch-processing it is best to wrap each image in try/except and discard failures, otherwise a single bad image forces you to start over; alternatively, keep a counter of where the error occurred and resume from there. A sketch of such a loop is shown below.
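As a concrete illustration, here is a minimal fault-tolerant version of the loop above (my own sketch, reusing the detect function and directory layout from earlier; the failed list is an assumption, not part of the original script):

import os
from tqdm import tqdm

SRC_DIR = "../gallery-dl/danbooru/misaka_mikoto"
failed = []  # record files that raised errors instead of aborting the whole run

for root, dirs, files in os.walk(SRC_DIR):
    for file in tqdm(files):
        if file.startswith("crop") or file.endswith(".part"):
            continue
        try:
            detect(os.path.join(root, file))
        except Exception as exc:  # one corrupt image should not kill the pass
            failed.append((file, str(exc)))

print(f"{len(failed)} files failed:", failed[:10])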

DataLoader

You can define a custom Dataset:

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn


class AnimeDataset(Dataset):
    def __init__(self):
        super().__init__()
        pass

    def __getitem__(self, index):
        # return the sample at the given index
        # data = self.preprocess(self.data[index])  # if preprocessing is needed
        return self.data[index]

    def __len__(self):
        pass

    def preprocess(self, data):
        # apply some preprocessing to data
        pass

For example:

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Normalize


class Pic_Data(Dataset):  # inherit from Dataset
    def __init__(self, root_dir, transform=None):  # basic parameters of the dataset
        self.root_dir = root_dir                 # image directory
        self.transform = transform               # optional transform
        self.images = os.listdir(self.root_dir)  # all files in the directory

    def __len__(self):  # size of the whole dataset
        return len(self.images)

    def __getitem__(self, index):  # return dataset[index]
        image_index = self.images[index]                      # file name at this index
        img_path = os.path.join(self.root_dir, image_index)   # full path of the image
        img = Image.open(img_path)                             # load the image
        label = int(image_index[-5])                           # parse the label from the file name
        sample = self.transform(img) if self.transform else img  # apply the transform if given
        return sample, label                                   # return the sample


transform_fn = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])
train_data = Pic_Data("./data/MNIST/train", transform=transform_fn)
test_data = Pic_Data("./data/MNIST/test")

Reference: Pytorch 创建Dataset类.

But you can also use ImageFolder to build the dataset directly.

from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, # target folder of images
                                  transform=data_transform, # transforms to perform on data (images)
                                  target_transform=None)

For example, like this:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import datasets
from tqdm import tqdm

epochs = 500
batch_size = 64
lr = 0.0002
z_dim = 100

# normalize to [-1, 1] so the data range matches the generator's Tanh output
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# note: ImageFolder expects one sub-directory per class under root,
# e.g. ../assets/misaka_mikoto/<class_name>/*.jpg
animeFaceDataset = datasets.ImageFolder(
    root="../assets/misaka_mikoto", transform=transform
)

dataloader = DataLoader(animeFaceDataset, batch_size=batch_size, shuffle=True)

When you have time, I recommend working through 04. PyTorch Custom Datasets - Zero to Mastery Learn PyTorch for Deep Learning. Since no labels are needed for now (there is no class split), this is good enough; later you could assign different characters to different classes, and that label information can then be used for a conditional GAN (cGAN), etc.
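To give a flavour of how such labels could be used, here is a minimal sketch of a label-conditioned generator input (my own illustration, assuming one sub-folder per character so that ImageFolder assigns each image a class index; names like ConditionalInput, num_classes and label_emb are hypothetical):

import torch
import torch.nn as nn


class ConditionalInput(nn.Module):
    """Concatenate a learned label embedding with the noise vector z (cGAN-style)."""

    def __init__(self, z_dim=100, num_classes=5, emb_dim=32):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, emb_dim)

    def forward(self, z, labels):
        # (batch, z_dim) + (batch, emb_dim) -> (batch, z_dim + emb_dim)
        return torch.cat([z, self.label_emb(labels)], dim=1)


# usage: the generator's first Linear layer would then take z_dim + emb_dim inputs
cond = ConditionalInput()
z = torch.randn(8, 100)
labels = torch.randint(0, 5, (8,))
g_input = cond(z, labels)  # shape: (8, 132)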

Model Design

It does not really matter whether you use DCGAN, WGAN, or something else; ready-made implementations are all on GitHub.

The key point is that the DataLoader must deliver images of a uniform size. DCGAN is used here.
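Both networks below call weights_init, which the original snippet does not define; a standard DCGAN-style initializer is assumed here:

import torch.nn as nn


def weights_init(m):
    """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm scales."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)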

# Generator
class Generator(nn.Module):
    """
    Input shape: (batch, in_dim)
    Output shape: (batch, 3, 64, 64)
    """
    def __init__(self, in_dim, feature_dim=64):
        super().__init__()
        #input: (batch, 100)
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, feature_dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(feature_dim * 8 * 4 * 4),
            nn.ReLU()
        )
        self.l2 = nn.Sequential(
            self.dconv_bn_relu(feature_dim * 8, feature_dim * 4),               #(batch, feature_dim * 4, 8, 8)
            self.dconv_bn_relu(feature_dim * 4, feature_dim * 2),               #(batch, feature_dim * 2, 16, 16)
            self.dconv_bn_relu(feature_dim * 2, feature_dim),                   #(batch, feature_dim, 32, 32)
        )
        self.l3 = nn.Sequential(
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)
    def dconv_bn_relu(self, in_dim, out_dim):
        return nn.Sequential(
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size=5, stride=2,
                               padding=2, output_padding=1, bias=False),        #double height and width
            nn.BatchNorm2d(out_dim),
            nn.ReLU(True)
        )
    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2(y)
        y = self.l3(y)
        return y
# Discriminator
class Discriminator(nn.Module):
    """
    Input shape: (batch, 3, 64, 64)
    Output shape: (batch, 1)
    """

    def __init__(self, in_dim, feature_dim=64):
        super(Discriminator, self).__init__()

        # input: (batch, 3, 64, 64)
        """
        NOTE FOR SETTING DISCRIMINATOR:

        Remove last sigmoid layer for WGAN
        """
        self.l1 = nn.Sequential(
            nn.Conv2d(
                in_dim, feature_dim, kernel_size=4, stride=2, padding=1
            ),  # (batch, feature_dim, 32, 32)
            nn.LeakyReLU(0.2),
            self.conv_bn_lrelu(feature_dim, feature_dim * 2),  # (batch, feature_dim * 2, 16, 16)
            self.conv_bn_lrelu(feature_dim * 2, feature_dim * 4),  # (batch, feature_dim * 4, 8, 8)
            self.conv_bn_lrelu(feature_dim * 4, feature_dim * 8),  # (batch, feature_dim * 8, 4, 4)
            nn.Conv2d(feature_dim * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )
        self.apply(weights_init)

    def conv_bn_lrelu(self, in_dim, out_dim):
        """
        NOTE FOR SETTING DISCRIMINATOR:

        You can't use nn.Batchnorm for WGAN-GP
        Use nn.InstanceNorm2d instead
        """

        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 4, 2, 1),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        y = self.l1(x)
        y = y.view(y.size(0), -1)
        return y
def train():
    G = Generator(100).cuda()
    D = Discriminator(3).cuda()
    criterion = nn.BCELoss()
    opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))

    for e in range(epochs):
        progressbar = tqdm(dataloader)
        progressbar.set_description(f"Epoch {e+1}")
        for i, (imgs, _) in enumerate(progressbar):
            imgs = imgs.cuda()
            bs = imgs.size(0)

            z = torch.randn(bs, z_dim).cuda()
            fake_imgs = G(z)

            real_out = D(imgs)
            # detach so the discriminator update does not backprop into G
            fake_out = D(fake_imgs.detach())
            fake_label = torch.zeros(bs, 1).cuda()
            real_label = torch.ones(bs, 1).cuda()

            loss_d = (
                criterion(real_out, real_label) + criterion(fake_out, fake_label)
            ) / 2

            D.zero_grad()
            loss_d.backward()
            opt_D.step()

            # train G (e % 1 is always true; raise the modulus to update G less often)
            if e % 1 == 0:
                z = torch.randn(bs, z_dim).cuda()
                f_imgs = G(z)
                fake_out = D(f_imgs)
                loss_g = criterion(fake_out, real_label)
                G.zero_grad()
                loss_g.backward()
                opt_G.step()
    os.makedirs("../checkpoints", exist_ok=True)
    torch.save(G.state_dict(), "../checkpoints/G.pth")
    torch.save(D.state_dict(), "../checkpoints/D.pth")

During training, make a habit of using the logging, opencv, matplotlib and pillow libraries to process images and record logs; save the model weights every few epochs or steps, record the loss of every epoch, and visualize it with TensorBoard, and so on.
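As an illustration, a minimal logging/checkpointing sketch (my own addition, using torch.utils.tensorboard and torchvision.utils; the paths and intervals are arbitrary assumptions) could be called from the training loop like this:

import os
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="../runs/dcgan_anime")
os.makedirs("../checkpoints", exist_ok=True)


def log_step(step, loss_d, loss_g):
    # scalar curves, viewable with `tensorboard --logdir ../runs`
    writer.add_scalar("loss/D", loss_d.item(), step)
    writer.add_scalar("loss/G", loss_g.item(), step)


def log_epoch(epoch, G, z_dim=100, device="cuda"):
    # every 10 epochs: log a grid of generated samples and save a checkpoint
    if epoch % 10 == 0:
        G.eval()
        with torch.no_grad():
            samples = G(torch.randn(64, z_dim, device=device))
        G.train()
        grid = vutils.make_grid(samples, normalize=True)
        writer.add_image("samples", grid, epoch)
        torch.save(G.state_dict(), f"../checkpoints/G_epoch{epoch}.pth")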

Results

(images: grids of generated anime face samples)

Only 500 epochs were trained here, and the training set has only a bit over 1000 images, so the results are just passable.
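For reference, the kind of sampling script used to produce such grids might look like this (a sketch assuming the Generator class and the checkpoint path from above):

import torch
from torchvision.utils import save_image

G = Generator(100).cuda()
G.load_state_dict(torch.load("../checkpoints/G.pth"))
G.eval()

with torch.no_grad():
    z = torch.randn(64, 100).cuda()
    fake = G(z)  # (64, 3, 64, 64), values in [-1, 1] from Tanh
    save_image(fake, "samples.png", nrow=8, normalize=True)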

Follow-ups will be pushed to drowning-in-codes/myGAN: learn GAN through self-taught (github.com).

reThink

Only DCGAN is used here; other GAN variants are worth trying. The dataset preprocessing could also be integrated better with PyTorch. And since this is an anime-face dataset, the crawled images have to be cropped; a pre-trained detector is used for that, which inevitably has some error rate, so this step can be improved. For example, the image below was cropped incorrectly.

(image: an example of an incorrectly cropped face)

Loss curves can be visualized with TensorBoard or Visdom.

Other datasets commonly used with GANs include MNIST, Fashion-MNIST, CelebA, SVHN, and so on, plus image-to-image translation datasets such as monet2photo and edges2shoes; feel free to switch datasets. A sketch of swapping in MNIST follows.
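For instance, MNIST could be swapped in while reusing the same 3x64x64 DCGAN with a transform like the following (a sketch; resizing and replicating the grayscale channel to three channels is my own shortcut, and the data root is arbitrary):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST images are grayscale 28x28; resize to 64x64 and replicate to 3 channels
# so the DCGAN above (which expects 3 x 64 x 64 inputs) can be reused unchanged
mnist_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
mnist_data = datasets.MNIST(root="../data", train=True, download=True,
                            transform=mnist_transform)
mnist_loader = DataLoader(mnist_data, batch_size=64, shuffle=True)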

Reference repos

  1. nagadomi/lbpcascade_animeface: A Face detector for anime/manga using OpenCV (github.com)
  2. jayleicn/animeGAN: A simple PyTorch Implementation of Generative Adversarial Networks, focusing on anime face drawing. (github.com)
  3. ML_HW6.ipynb - Colaboratory (google.com)