基于GAN的动漫头像生成

GAN的原理

GAN是一种典型的生成网络模型,它类似于编解码结构,通过训练,他能够生成不同于训练集的各种图片。
在这里插入图片描述
首先先训练判别器,把真图通过判别器的输出和真标签作损失,把假图通过判别器的输出和假标签作损失,让它具备判别真图和假图的能力。然后再训练生成器,把生成器生成的假图通过判别器的输出和真标签作损失。经过反复的训练,让判别器难以分辨生成图的真假,也就是让它判别为真或为假的概率各为0.5

数据集下载

网上下载的动漫头像数据集有很多不清晰的奇异样本,对此我做了清洗,剩下的都是符合标准的,可直接下载
百度网盘:https://pan.baidu.com/s/1–zFrJdg1gtW2wJ6wtWQsQ
密码:bu55

网络结构

生成网络

相当于一个编码器

class NetD(nn.Module):
    # 构建一个判别器,相当与一个二分类问题, 生成一个值
    def __init__(self):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入96*96*3
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 输入32*32*ndf
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),
            # 输入16*16*ndf*2
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),
            # 输入为8*8*ndf*4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True),
            # 输入为4*4*ndf*8
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),
            nn.Sigmoid()  # 分类问题
        )
    def forward(self, x):
        return self.main(x).view(-1)

生成器

相当于一个解码器

class NetG(nn.Module):
    # 定义一个生成模型,通过输入噪声来产生一张图片
    def __init__(self):
        super(NetG, self).__init__()
        ngf = opt.ngf
        self.main = nn.Sequential(
            # 假定输入为一张1*1*opt.nz维的数据(opt.nz维的向量)
            nn.ConvTranspose2d(opt.nz , ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # 输入一个4*4*ngf*8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 输入一个8*8*ngf*4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=True),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 输入一个16*16*ngf*2
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            # 输入一个32*32*ngf
            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()
            # 输出一张96*96*3
        )
    def forward(self, x):
        return self.main(x)

GAN网络结构设计要点

1、在D网络中用stride卷积(stride>1)代替pooling层,在G网络中用conv2d_transpose代替上采样层
2、在G和D网络中直接将BN应用到所有层会导致样本震荡和模型不稳定,通过在G网络输出层和D网络输入层不采用BN层可以有效防止这种现象
3、不使用全连接层作为输出
4、G网络中除了输出层用tanh激活,其他层都是用ReLu激活
5、D网络中都使用LeakyReLu激活

网络模型训练

训练细节

1、预处理环节,将图像scale到tanh的[-1,1]
2、所有的参数初始化由(0,0.02)的正态分布中随机得到
3、LeakyReLu的斜率是0.2(默认)
4、优化器Adam的learning rate=0.0002,momentum参数betas的beta1从0.9降为0.5,beta2默认,防止震荡和不稳定
5、可以G网络训练1次,然后D网络训练1次,如此反复;也可以G网络先训练几次后,D网络再训练1次,如此反复。前者效果出得较快,后者较慢。
训练代码

# opt参数
ngf=96
ndf=96
nz=256
img_size=96
batch_size=100
num_workers=4
netg_path=r"网络参数/netg_5.pt"
netd_path=r"网络参数/netd_5.pt"
lr1=0.0002
lr2=0.0002
beta1=0.5
epochs=200
d_every=1
g_every=5
save_every=20
from torchvision.utils import save_image
import Nets
import torch
from torch.utils.data import DataLoader
import opt
import torch.nn as nn
import dataset

if __name__=="__main__":
   # 1. 加载数据
    dataset = dataset.Dataset()
    dataloader = DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=True)
    # 2.初始化网络
    netg, netd = Nets.NetG(), Nets.NetD()
    # 3. 设定优化器参数
    optimize_g = torch.optim.Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1,0.999))
    optimize_d = torch.optim.Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1,0.999))
    loss_func = nn.BCELoss()
    # 4. 定义标签, 并且开始注入生成器的输入noise
    true_labels = torch.ones(opt.batch_size)
    fake_labels = torch.zeros(opt.batch_size)
    noises = torch.randn(opt.batch_size, opt.nz, 1, 1)
    #  6.训练网络
    netg.train()
    netd.train()
    for epoch in range(opt.epochs):
        for i, img in enumerate(dataloader):
            real_img = img
            # 训练判别器
            if i % opt.d_every == 0:
                optimize_d.zero_grad()
                # 真图
                real_out = netd(real_img)
                error_d_real = loss_func(real_out, true_labels)
                error_d_real.backward()
                # 随机生成的假图
                noises = noises.detach()
                fake_image = netg(noises).detach()
                fake_out = netd(fake_image)
                error_d_fake = loss_func(fake_out, fake_labels)
                error_d_fake.backward()
                optimize_d.step()
                # 计算loss
                error_d = error_d_fake + error_d_real
                print("第{0}轮: 判别网络   损失:{1}  对真图评分:{2}  对生成图评分:{3}".format(epoch+1,error_d.item(),real_out.data.mean(),fake_out.data.mean()))
            # 训练生成器
            if i % opt.g_every == 0 and i>0:
                optimize_g.zero_grad()
                noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                output = netd(fake_img)
                error_g = loss_func(output, true_labels)
                error_g.backward()
                optimize_g.step()
                print("       生成网络   损失:{0}".format(error_g.item()))
        #  7.保存模型和图片
            if i % opt.save_every == 0 and i>0:
                fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1)
                fix_fake_image = netg(fix_noises)
                # save_image(real_img.data*0.5+0.5, "./img/{0}-{1}-real_img.jpg".format(epoch, i), nrow=10)
                save_image(fix_fake_image.data*0.5+0.5, "./image/{0}-{1}-fake_img.jpg".format(epoch, i), nrow=10)
                torch.save(netd.state_dict(), opt.netd_path)
                torch.save(netg.state_dict(), opt.netg_path)

效果展示

生成网络随机生成的头像
在这里插入图片描述

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值