使用生成式对抗网络(GAN)生成动漫人物图像

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

如今AI艺术创作能力越来越强大,Google发布的ImageGen项目基于文本提示作画的结果和真实艺术家的成品难辨真假。本项目将使用PyTorch实现生成式对抗网络生成式对抗网络来完成AI生成动漫人物图像。

本项目中使用的数据集是一个由63 632个高质量动画人脸组成的数据集,从www.getchu.com中抓取,然后使用https://github.com/nagadomi/lbpcascade_animeface中的动画人脸检测算法进行裁剪。图像大小从90×90到120×120不等。该数据集包含高质量的动漫角色图像,具有干净的背景和丰富的颜色。数据集下载链接:https://github.com/bchao1/Anime-Face-Dataset

我们知道在生成式对抗网络中有两个模型——生成模型(Generative Model,G)和判别模型(Discriminative Model,D)。G就是一个生成图片的网络,它接收一个随机的噪声z,然后通过这个噪声生成图片,生成的数据记作G(z)。D是一个判别网络,判别一幅图片是不是“真实的”(是不是捏造的)。它的输入参数是x,x代表一幅图片,输出D(x)代表x为真实图片的概率,如果为1,就代表是真实的图片,而输出为0,就代表不可能是真实的图片。

  1. 定义生成器Generator:生成器的输入为100维的高斯噪声,生成器会利用这个噪声生成指定大小的图片,关于最初的噪声,可以看成10011的特征图,然后利用转置卷积来进行尺寸还原操作,标准的卷积操作是不断缩小尺寸,转置卷积就可以理解为它的逆操作,这样就可以不断放大图像。
  2. 定义判别器Discriminator:判别器就是一个典型的二分类网络,首先它的输入是我们输入的图片,我们会利用一系列卷积操作来形成一维特征图进行分类操作,这里可以发现判别器的网络和生成器的相关操作是可逆的,唯独不一样的是激活函数。

模型训练的步骤如下:

   步骤1:首先固定生成器,训练判别器,提高真实样本被判别为真的概率,同时降低生成器生成的假图像被判别为真的概率,目标是判别器能准确进行分类。

   步骤2:固定判别器,训练生成器,生成器生成图像,尽可能提高该图像被判别器判别为真的概率,目标是生成器的结果能够骗过判别器。

   步骤3:重复,循环交替训练,最终生成器生成的样本足够逼真,使得鉴别器只有大约50%的判断正确率(相当于乱猜)。

完整代码如下:

#####################GANDEMO.py####################
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm

class Config(object):
    data_path = './gandata/data/'
    image_size = 96
    batch_size = 32
    epochs = 200
    lr1 = 2e-3
    lr2 = 2e-4
    beta1 = 0.5
    gpu = False
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    nz = 100
    ngf = 64
    ndf = 64
    save_path = './gandata/images'
    generator_path = './gandata/generator.pkl' 			#模型保存路径
    discriminator_path = './gandata/discriminator.pkl' 	#模型保存路径
    gen_img = './gandata/result.png'
    gen_num = 64
    gen_search_num = 5000
    gen_mean = 0
    gen_std = 1

config = Config()

# 1.数据转换
data_transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.CenterCrop(config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
                                     transform=data_transform)

# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           config.batch_size,
                                           True,
                                           drop_last=True)
print('using {} images for training.'.format(len(train_dataset)))

class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ngf = config.ngf

        self.model = nn.Sequential(
            nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
            nn.Tanh()
        )

    def forward(self, x):
        output = self.model(x)
        return output


class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ndf = config.ndf

        self.model = nn.Sequential(
            nn.Conv2d(3, ndf, 5, 3, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, 1, 4, 1, 0)
        )

    def forward(self, x):
        output = self.model(x)
        return output.view(-1)

generator = Generator(config)
discriminator = Discriminator(config)

optimizer_generator = torch.optim.Adam(generator.parameters(),
                                       config.lr1,
                                       betas=(config.beta1, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                           config.lr2,
                                           betas=(config.beta1, 0.999))

true_labels = torch.ones(config.batch_size)
fake_labels = torch.zeros(config.batch_size)
fix_noises = torch.randn(config.batch_size, config.nz, 1, 1)
noises = torch.randn(config.batch_size, config.nz, 1, 1)

for epoch in range(config.epochs):
    for ii, (img, _) in tqdm(enumerate(train_loader)):
        real_img = img.to(config.device)

        if ii % 2 == 0:
            optimizer_discriminator.zero_grad()

            r_preds = discriminator(real_img)
            noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
            fake_img = generator(noises).detach()
            f_preds = discriminator(fake_img)

            r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)
            f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)
            loss_d_real = (1 - r_f_diff).mean()
            loss_d_fake = (1 + f_r_diff).mean()
            loss_d = loss_d_real + loss_d_fake

            loss_d.backward()
            optimizer_discriminator.step()

        else:
            optimizer_generator.zero_grad()
            noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
            fake_img = generator(noises)
            f_preds = discriminator(fake_img)
            r_preds = discriminator(real_img)
            r_f_diff = r_preds - torch.mean(f_preds)
            f_r_diff = f_preds - torch.mean(r_preds)
            loss_g = torch.mean(F.relu(1 + r_f_diff)) \
                     + torch.mean(F.relu(1 - f_r_diff))
            loss_g.backward()
            optimizer_generator.step()

    if epoch == config.epochs - 1:
        # 保存模型
        torch.save(discriminator.state_dict(), config.discriminator_path)
        torch.save(generator.state_dict(), config.generator_path)

print('Finished Training')

generator = Generator(config)
discriminator = Discriminator(config)

noises = torch.randn(config.gen_search_num,
                     config.nz, 1, 1).normal_(config.gen_mean,
                                                                     config.gen_std)
noises = noises.to(config.device)

generator.load_state_dict(torch.load(config.generator_path,
                                     map_location='cpu'))
discriminator.load_state_dict(torch.load(config.discriminator_path,
                                         map_location='cpu'))
generator.to(config.device)
discriminator.to(config.device)

fake_img = generator(noises)
scores = discriminator(fake_img).detach()

indexs = scores.topk(config.gen_num)[1]
result = []
for ii in indexs:
    result.append(fake_img.data[ii])

torchvision.utils.save_image(torch.stack(result), config.gen_img,
                             normalize=True, value_range=(-1, 1))

代码运行结果如下:

using 900 images for training.
28it [00:20,  1.40it/s]
28it [00:20,  1.33it/s]
28it [00:21,  1.29it/s]
…
28it [00:26,  1.06it/s]
Finished Training

效果图如图13-9所示,由于只训练了100个Epoch,因此图像生成的纹理还不算太清楚,大家计算资源允许的话,可以多训练一些Epoch来生成更多的图像细节。

图13-9

  • 34
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值