GAN 生成动漫头像 手动实现Pytorch全代码

本文分享手动实现DCGAN生成动漫头像的Pytorch代码。

简单来说,DCGAN(Deep Convolutional GAN)就是用全卷积代替了原始GAN的全连接结构,提升了GAN的训练稳定性和生成结果质量。

我使用的数据集,5W张96×96的动漫头像。
在这里插入图片描述

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os


class D_Net(nn.Module):
    def __init__(self):
        super(D_Net,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, True)
        )   # 64, 32, 32
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True)
        )   # 128, 16, 16
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True)
        )   # 256, 8, 8
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True)
        )   # 512, 4, 4
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )   # 1, 1, 1

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    # 判别器参数初始化
    def d_weight_init(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, mean=0, std=0.02)


class G_Net(nn.Module):
    def __init__(self):
        super(G_Net,self).__init__()
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(128, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )   # 512, 4, 4
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )   # 256, 8, 8
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )   # 128, 16, 16
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )   # 128, 32, 32
        self.conv5 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh()
        )   # 3, 96, 96

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    # 生成器参数初始化
    def g_weight_init(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, mean=0, std=0.02)


if __name__ == '__main__':
    batch_size = 225
    if not os.path.exists("./dcgan_img"):
        os.mkdir("./dcgan_img")
    if not os.path.exists("./dcgan_params"):
        os.mkdir("./dcgan_params")
    img_transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, ], [0.5, ])
    ])
    img_dir = r"C:\Cartoon_faces0.1"
    # ImageFolder 不用自己写Dataset
    dataset = datasets.ImageFolder(img_dir, transform=img_transf)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    d_net = D_Net().to(device)
    g_net = G_Net().to(device)

    d_weight_file = r"dcgan_params/d_net.pth"
    g_weight_file = r"dcgan_params/g_net.pth"
    if os.path.exists(d_weight_file) and os.path.getsize(d_weight_file) != 0:
        d_net.load_state_dict(torch.load(d_weight_file))
        print("加载判别器保存参数成功")
    else:
        d_net.apply(d_net.d_weight_init)
        print("加载判别器随机参数成功")

    if os.path.exists(g_weight_file) and os.path.getsize(g_weight_file) != 0:
        g_net.load_state_dict(torch.load(g_weight_file))
        print("加载生成器保存参数成功")
    else:
        g_net.apply(g_net.g_weight_init)
        print("加载生成器随机参数成功")

    loss_fn = nn.BCELoss()
    d_opt = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_opt = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))

    epoch = 1
    while True:
        print("epoch--{}".format(epoch))
        for i, (x, y) in enumerate(loader):
            # 判别器
            real_img = x.to(device)
            real_label = torch.ones(x.size(0), 1, 1, 1).to(device)
            fake_label = torch.zeros(x.size(0), 1, 1, 1).to(device)
            real_out = d_net(real_img)
            d_real_loss = loss_fn(real_out, real_label)

            z = torch.randn(x.size(0), 128, 1, 1).to(device)
            fake_img = g_net(z).detach()
            fake_out = d_net(fake_img)
            d_fake_loss = loss_fn(fake_out, fake_label)

            d_loss = d_real_loss + d_fake_loss
            d_opt.zero_grad()
            d_real_loss.backward()
            d_fake_loss.backward()
            d_opt.step()

            # 生成器
            fake_img = g_net(z)
            fake_out = d_net(fake_img)
            g_loss = loss_fn(fake_out, real_label)
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()

            if i == 100:
                print("d_loss:{:.3f}\tg_loss:{:.3f}\td_real:{:.3f}\td_fake:{:.3f}".
                      format(d_loss.item(), g_loss.item(), real_out.data.mean(), fake_out.data.mean()))

                fake_image = fake_img.cpu().data
                save_image(fake_image, "./dcgan_img/{}_{}-fake_img.jpg".
                           format(epoch, i), nrow=15, normalize=True, scale_each=True)

        torch.save(d_net.state_dict(), "dcgan_params/d_net.pth")
        torch.save(g_net.state_dict(), "dcgan_params/g_net.pth")
        epoch += 1


  1. 生成网络G和判别网络D结构几乎完全对称,G网络用转置卷积实现上采样,参数设置在我的另一篇文章中已解释。偶数卷积核在正常网络模型中很少见,但在生成模型中效果比较好,避免生成图像不均匀的现象。
  2. 两个网络的激活函数和输出函数、网络参数初始化、优化器参数的选择大都是DCGAN论文的默认值,是实验结果。
  3. real_label和fake_label就是全1和全0的值,判别器训练时,真图标签为1,假图标签为0。产生的真图假图两个loss,其实可以合成一个,进行一次backward()就行,但是实验发现分开效果会比较好。
  4. D网络训练时只需要正常判别输入图片是真图还是G网络生成的假图,而G网络则需要混淆视听,尽量提高生成假图在D网络的输出评分,互相对抗学习。关键代码为:g_loss = loss_fn(fake_out, real_label)
  5. 判别器和生成器交替训练,训练一个时,另一个的参数应该固定。这在Pytorch中不需要我们做什么处理,因为在优化器中已经给定了需要优化的参数,虽然backward()计算了两个网络的全部梯度,但step()只更新了对应参数。训练D网络时,生成的假图用detach()操作截断了计算图,即不计算G网络的梯度,没什么大用,只是略微节省了一点时间。
  6. 交替训练时可以指定两个网络的训练频率,比如D网络每个batch训练1次,G网络每个batch训练2次。但是这个比例怎么取比较好需要实验,我这老年机就放弃了……同学们有兴趣可以调一下试试。

一开始是酱的:
在这里插入图片描述

200 epochs later……
在这里插入图片描述
远观尚可,别细看!

本文实现的DCGAN是最基础的,需要自己小心调参。为了解决GAN训练不稳定以及生成器和判别器的训练平衡问题,同学们可以参考一下WGAN,改动很小,效果很好。

  • 6
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值