全系列GAN网络生成动漫角色

本文介绍了作者在学习GAN网络时,使用动漫头像数据集进行实践,探讨了GAN的基本原理、不同类型的GAN(如GAN,CGAN,WGAN)以及训练过程。作者分享了代码和数据集链接,并指出GAN生成的图像存在噪声和风格差异,期待通过更多训练和优化得到更好的结果。
摘要由CSDN通过智能技术生成

代码地址:Yuki0614/Series-of-GAN-for-Anime (github.com)

数据集地址:Anime Faces | Kaggle

FFT - Fast File Transfer Service (iqiyi.com) 密码: 5Kv2M1

Cartoon Face Recognition: A Benchmark Dataset (arxiv.org)

最近在学习GAN网络,感觉挺有意思的。GAN网络的系列有很多,包括GAN,CGAN,WGAN,CycleGAN等,差别在于训练方式,网络结构,不同任务(风格迁移,多分类图像生成等),在知乎和csdn上找了一些代码,然后找了个动漫图像的数据集复现了一下,想来和大伙交流学习一下。

第一个数据集是无类别的动漫头像,第二是有类别的,可以用于CGAN等网络。数据集比较大,我是选了一部分数据跑的,链接放上面了。

原始GAN网络结构由生成器和判别器组成。生成器输入噪声输出图片,其目的是生成尽可能接近真实图片的图片,来骗过判别器,故其损失函数为判别器对其生成图片真伪的判断和真实图片的标签(全1)构成的损失函数。判别器输入图片,输出真伪(0到1的概率),其目的是尽可能区分是真实图片还是生成器生成的图片,故其损失函数为真实图片标签(1)与输出概率,生成图片标签(0)与输出概率,两部分构成。在GAN中,生成器和判别器均由全连接网络构成。理想状态,经过不断训练,生成器可以生成接近真实的图片,判别器对所有输入都认为概率为0.5(无法分类)。

下面是原始GAN网络的代码,其他网络的放在github了,以后还会继续学习一些新的网络,弄明白以后就会更新。

网络结构:

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 64*64*3), # 数据集图片大小为64*64,3通道
            nn.Tanh()  # 最后必须用tanh,把数据分布到(-1,1)之间
        )
    def forward(self, x):  # x表示长度为100的噪声输入
        img = self.main(x)
        img = img.view(-1,3,64,64)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Linear(64*64*3, 1024),
            nn.LeakyReLU(), # x小于零是是一个很小的值不是0,x大于0是还是x
            nn.Linear(1024,512),
            nn.LeakyReLU(),
            nn.Linear(512,1),
            nn.Sigmoid() # 保证输出范围为(0,1)的概率
        )
    def forward(self, x):
        img = x.view(-1, 64*64*3)
        img = self.main(img)
        return img

训练代码:

import torch
import torch.nn as nn
from torchvision import transforms
from create_dataset import My_dataset, save_img
from torch.utils.data import DataLoader
from model64 import Generator, Discriminator



# 图像变换
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 网络设置图片大小为 64*64,保证图片大小符合网络结构要求
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


dataset = My_dataset(r'E:\work\GAN\1', transform=transform)   # 数据集位置
batch_size, epochs = 32, 500
my_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

discriminator = Discriminator()
generator = Generator()

if torch.cuda.is_available():
    discriminator = discriminator.cuda()
    generator = generator.cuda()


d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(0.5, 0.99), lr=1e-4)  # betas为adam算法两个动量参数
g_optimizer = torch.optim.Adam(generator.parameters(), betas=(0.5, 0.99), lr=1e-4)
criterion = nn.BCELoss()

for epoch in range(epochs):

    for i, img in enumerate(my_dataloader):

        noise = torch.randn(batch_size, 100).cuda() # 随机噪声作为输入
        real_img = img.cuda()
        fake_img = generator(noise)


        real_out = discriminator(real_img)
        fake_out = discriminator(fake_img)
        real_label = torch.ones_like(real_out).cuda()  # 真实图片标签为1,生成图片标签为0
        fake_label = torch.zeros_like(fake_out).cuda()
        real_loss = criterion(real_out, real_label)
        fake_loss = criterion(fake_out, fake_label)

        d_loss = real_loss + fake_loss  # 训练判别器
        d_optimizer.zero_grad()

        d_loss.backward()
        d_optimizer.step()

        noise = torch.randn(batch_size, 100).cuda()
        fake_img = generator(noise)
        output = discriminator(fake_img)

        g_loss = criterion(output, real_label)   # 训练生成器
        g_optimizer.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 5 == 0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D_real: {:.6f},D_fake: {:.6f}'.format(
                epoch, epochs, d_loss.data.item(), g_loss.data.item(),
                real_out.data.mean(), fake_out.data.mean()  # 打印的是真实图片的损失均值
            ))
        if epoch == 0 and i == len(my_dataloader) - 1:          # 保存真实图像
            save_img(img[:64, :, :, :], './sample/real_images.png')
        if (epoch+1) % 50 == 0 and i == len(my_dataloader)-1:             # 每50个epoch保存一次预测图像
            save_img(fake_img[:64, :, :, :], './sample/fake_images_{}.png'.format(epoch + 1))

torch.save(generator.state_dict(), './generator.pth')        # 保存权重文件
torch.save(discriminator.state_dict(), './discriminator.pth')

数据集工具:

import os
import numpy as np
import torch
from PIL import Image
from torchvision.utils import make_grid
import torch.utils.data.dataset as Dataset


class My_dataset(Dataset.Dataset):

    def __init__(self, path, transform):

        self.path = path
        loc_list = os.listdir(self.path)
        self.loc_list = loc_list
        self.tranform = transform

    def __getitem__(self, index):

        loc_data = os.path.join(self.path, self.loc_list[index])
        data = Image.open(loc_data)
        data = np.array(data)
        data = Image.fromarray(data)  # 若导入图像不为PIL格式则需要转换
        data = self.tranform(data)
        return data

    def __len__(self):

        return len(self.loc_list)


def save_img(tensor, fp):

    grid = make_grid(tensor)
    ndarr = (grid.mul(0.5).add_(0.5)).mul(255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr, mode='RGB')
    im.save(fp)

训练效果:GAN50epoch

GAN200epoch

对比了其他网络训练的效果:

WGAN

DCGAN

以上是由同一数据集不同GAN网络学习生成结果 

感觉好像是有点差别的,不知道为什么GAN的图像好像有很多噪声?而且画风都有些不一样了,不过DCGAN和WGAN确实要比GAN效果好些,我这里没有用完整数据集,而且训练epoch比较少,增加训练,或者调大网络结构效果可能会更好些。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值