pytorch之GAN实现生成动漫头像《深度学习框架pytorch入门与实践》

首先有一点点废话,GAN就是生成对抗网络,由生成器和判别器构成。

生成器和判别器可以比喻成一个新手画家和一个新手鉴赏家,生成器这个画家不断的学习作画,给判别器这个鉴赏家检验.

判别器有时候能看到真画和假画,他的职责就是尽可能判断出真画为1,假画为0

生成器的工作就是尽可能欺骗判别器,让他识别不出假画,也就是把假画判断为1

这两部分一直对抗,一直学习,就是我们的GAN了

1.导入需要的模块

import torch
from torch import nn
import torchvision
from torchvision import transforms
import torch.utils.data as Data
from PIL import Image
from torchvision.transforms import ToTensor,ToPILImage

to_tensor = ToTensor() # 将图片转换成Tensor
to_pil = ToPILImage() # 将Tensor转换成Image对象

2.配置信息

'''基本配置'''
class Config(object):
    dataPath = 'data/'
    numWorkers = 4
    imageSize = 96
    batchSize = 128 # batch size=128的梯度下降方法
    maxEpoch = 200
    lr1 = 2e-4 #生成器学习率
    lr2 = 2e-4 #判别器学习率
    beta1 = 0.5 #Adam优化器的beta1参数
    nz = 100 #随机操声维度
    ngf = 64 #generator feature map
    ndf = 64 #discriminator feature

    savePath = './saveimg/'

    useCuda = True
    vis = True #是否使用visdom
    env = 'GAN' #visdom的env
    plotEvery = 20 #每20batch,visdom画图一次

    dEvery = 1 # 判别器训练周期
    gEvery = 5 # 生成器训练周期
    decayEvery = 10 # 模型保存周期
    netDpath = 'checkpoints/netD.pth'
    netGpath = 'checkpoints/netG.pth'

    ganImg = 'result.png'
    ganNum = 64
    ganSearchNum = 512
    ganMean = 0 #噪声均值
    ganStd = 1 #噪声方差
    sol = 0.2 #LeakyReLU的斜率值
    pat = 0.5 #Momentum的patient

config = Config()

3.生成器网络结构

'''生成器'''
class Generator(nn.Module):
    def __init__(self,config):
        config = config
        super(Generator,self).__init__()
        self.out = nn.Sequential(
            # 100*1*1 --> (64*8)*4*4
            # ConvTranspose2d 是二维转置卷积
            nn.ConvTranspose2d(config.nz, config.ngf * 8, kernel_size=4, bias=False), 
            nn.BatchNorm2d(config.ngf * 8), # 批规范化  #如果不好加上0.5试试
            nn.ReLU(True), # True为直接修改覆盖 ,节省内存

            # (64*8)*4*4 --> (64*4)*8*8
            nn.ConvTranspose2d(config.ngf * 8, config.ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),  
            nn.BatchNorm2d(config.ngf * 4),
            nn.ReLU(True),

            # (64*4)*8*8 --> (64*2)*16*16
            nn.ConvTranspose2d(config.ngf * 4, config.ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(config.ngf * 2),
            nn.ReLU(True),

            # (64*2)*16*16 --> 64*32*32
            nn.ConvTranspose2d(config.ngf * 2, config.ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(config.ngf),
            nn.ReLU(True),

            # 64*32*32 --> 3*96*96
            nn.ConvTranspose2d(config.ngf, 3, kernel_size=5, stride=3, padding=1, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        return self.out(x)

4.判别器网络结构

'''判别器'''
class Discriminator(nn.Module):
    def __init__(self,config):
        config = config
        super(Discriminator,self).__init__()
        self.out = nn.Sequential(
            # 3*96*96 --> 64*32*32
            nn.Conv2d(3, config.ndf, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(config.sol,True),

            # 64*32*32 --> (64*2)*16*16
            nn.Conv2d(config.ndf, config.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False), 
            nn.BatchNorm2d(config.ndf * 2),
            nn.LeakyReLU(config.sol,True),

            # (64*2)*16*16 --> (64*4)*8*8
            nn.Conv2d(config.ndf * 2, config.ndf * 4, kernel_size=4, stride=2,padding=1, bias=False), 
            nn.BatchNorm2d(config.ndf * 4),
            nn.LeakyReLU(config.sol,True),
            # (64*4)*8*8 --> (64*8)*4*4
            nn.Conv2d(config.ndf * 4, config.ndf * 8, kernel_size=4, stride=2,padding=1, bias=False),
            nn.BatchNorm2d(config.ndf * 8),
            nn.LeakyReLU(config.sol,True),

            # (64*8)*4*4 --> 1 * 1 * 1
            nn.Conv2d(config.ndf * 8, 1, kernel_size=4, bias=False), 
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.out(x).view(-1)

从生成器和判别器的结构就可以很容易的看出,两者是对称相反的结构 。

5.准备数据

'''准备数据'''
tfs = transforms.Compose([
        transforms.Resize(config.imageSize),# 改成(size * size)
        transforms.CenterCrop(config.imageSize), #中心切割
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #标准化
    ])

'''这里的数据放在./data/faces/下,注意dataPath = "./data" 
    这样ImageFolder判断faces下所有图片为一类'''
trainset = torchvision.datasets.ImageFolder(config.dataPath ,transform=tfs)

trainloader = Data.DataLoader(
    trainset,
    batch_size=config.batchSize,
    shuffle=True,
    num_workers=config.numWorkers,
    drop_last=True
)

6.训练模型

if __name__ == '__main__':
    map_location = lambda storage,loc:storage
    netG = Generator(config)  # 生成器
    netD = Discriminator(config)  # 判别器
    if config.vis:
        vis = visdom.Visdom(env=config.env)
    if config.netDpath:
         netD.load_state_dict(torch.load(config.netDpath,map_location=map_location))
    if config.netGpath:
         netG.load_state_dict(torch.load(config.netGpath,map_location=map_location))
    optG = torch.optim.Adam(netG.parameters(), config.lr1, betas=(config.beta1, 0.999))  # 生成器优化器
    optD = torch.optim.Adam(netD.parameters(), config.lr2, betas=(config.beta1, 0.999))  # 判别器优化器
    loss_func = torch.nn.BCELoss()

    true_labels = torch.ones(config.batchSize)  # 真图片为1
    false_labels = torch.zeros(config.batchSize)  # 假图片为0
    fix_noises = torch.randn(config.batchSize, config.nz, 1, 1)  # batch组 nz*1*1的数据
    noises = torch.randn(config.batchSize, config.nz, 1, 1)  # 随机生成噪声

    '''判断是否使用GPU'''
    if config.useCuda:
        netD.cuda()
        netG.cuda()
        loss_func.cuda()
        true_labels,false_labels = true_labels.cuda(),false_labels.cuda()
        fix_noises,noises = fix_noises.cuda(),noises.cuda()

    '''开始训练'''
    for epoch in range(config.maxEpoch):
        for i, (img, _) in enumerate(trainloader):
            real_img = img
            if config.useCuda:
                real_img = real_img.cuda()

            # 训练判别器
            if (i + 1) % config.dEvery == 0:
                optD.zero_grad()
                out = netD(real_img)  # 尽可能把真的图片判别为1
                loss_real = loss_func(out, true_labels)
                loss_real.backward()

                noises.data.copy_(torch.randn(config.batchSize, config.nz, 1, 1))
                fake_img = netG(noises).detach()  # 生成假图片 detach是切断求导关联
                fake_out = netD(fake_img)  # 尽可能把假的图片判别为0
                loss_fake = loss_func(fake_out, false_labels)
                loss_fake.backward()
                optD.step()


            # 训练生成器
            if i % config.gEvery == 0:
                optG.zero_grad()
                noises.data.copy_(torch.randn(config.batchSize, config.nz, 1, 1))
                fake_img = netG(noises)  # 尽可能让噪声为真,让判别器把假的图片判为1
                fake_out = netD(fake_img)
                loss_fake = loss_func(fake_out, true_labels)
                loss_fake.backward()
                optG.step()
            '''
            这段代码不够成熟请忽略
            if i %config.plotEvery == config.plotEvery - 1:
                #可视化
                fix_fake_imgs = netG(fix_noises)
                fix_fake_imgs = fix_fake_imgs.data.cpu()[:1] * 0.5 + 0.5
                check_real_img = real_img.data.cpu()[:1] * 0.5 + 0.5
                
                to_pil(fix_fake_imgs.squeeze())
                to_pil(check_real_img.squeeze())
            '''

        if epoch % config.decayEvery == 0:
            # 保存模型、图片
            fix_fake_imgs = fix_fake_imgs.data.cpu()[:1] * 0.5 + 0.5
            to_pil(fix_fake_imgs.squeeze()).save('%s/%s.png' % (config.savePath, epoch))

            torch.save(netD.state_dict(), './netd_%s.pth' % epoch)
            torch.save(netG.state_dict(), './netg_%s.pth' % epoch)

            optG = torch.optim.Adam(netG.parameters(), config.lr1, betas=(config.beta1, 0.999))
            optD = torch.optim.Adam(netD.parameters(), config.lr2, betas=(config.beta1, 0.999))
            #
            # fix_imgs = netG(fix_noises)
            # vis.images(fix_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')

7.加载模型并随机生成一张图片

    netG = Generator(config)  # 生成器
    netG.load_state_dict(torch.load('./netG.pth'))
    rand_img = netG(torch.randn(1, config.nz, 1, 1))
    rand_img = rand_img.data.cpu()[:1] * 0.5 + 0.5
    to_pil(rand_img.squeeze()) 
    '''本段代码是在jupyter notebook上跑,因此会自动输出Image对象,
    如果读者不是和我一样,则应该另寻可视化方法,这些都不是重点
    '''

强调:本文章是作者学习陈云所著《深度学习框架pytorch入门与实践》中的实践代码,有一些代码的精简,可能不太成熟,如有出错请指正

  • 1
    点赞
  • 50
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值