Anime Avatar Generation with a GAN: PyTorch Implementation (Full Code Included)

Introduction

This article is a PyTorch version of HW3_1 (anime avatar generation, implemented in Keras) from Hung-yi Lee's GAN course. I wrote it because, for one, I have only just started working with GANs, and for another, I am more used to PyTorch, so I ported the Keras implementation to PyTorch.

Resources needed for the implementation:

Link: https://pan.baidu.com/s/1cLmFNQpJe1DOI96IVuvVyQ
Extraction code: nha2

One change made here is replacing kernel_size=4 with 3, since kernel sizes are usually odd. Everything else is essentially the same as the original network.
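As a quick sanity check: with stride 1, a convolution whose padding is kernel_size//2 preserves the spatial size only for odd kernels; an even kernel such as 4 would change it. A minimal sketch (the tensor sizes below are illustrative only):

import torch
import torch.nn as nn

x = torch.randn(1, 128, 32, 32)
print(nn.Conv2d(128, 128, 3, padding=3//2)(x).shape)  # torch.Size([1, 128, 32, 32]) -- size preserved
print(nn.Conv2d(128, 128, 4, padding=4//2)(x).shape)  # torch.Size([1, 128, 33, 33]) -- size changes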

The main parts of the code are shown below: the network module and the training/validation/testing module.
The full code is available at https://github.com/AsajuHuishi/Generate_a_quadratic_image_with_GAN

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
from torch.autograd import Variable

import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import time
import visdom

1. Network module

Generator
## default convolution
def default_conv(in_channels,out_channels,kernel_size,bias=True):
    return nn.Conv2d(in_channels,out_channels,
                     kernel_size,padding=kernel_size//2,   # padding keeps the spatial size
                     bias=bias)
## default ReLU
def default_relu():
    return nn.ReLU(inplace=True)
## reshape the flat feature vector into a (128,16,16) feature map
def get_feature(x):
    return x.reshape(x.size()[0],128,16,16)

class Generator(nn.Module):
    def __init__(self,input_dim=100,conv=default_conv,relu=default_relu,reshape=get_feature):
        super(Generator,self).__init__()
        head = [nn.Linear(input_dim,128*16*16),
                relu()]
        self.reshape = reshape                               #16x16
        body = [nn.Upsample(scale_factor=2,mode='nearest'),  #32x32
                conv(128,128,3),
                relu(),
                nn.Upsample(scale_factor=2,mode='nearest'),  #64x64
                conv(128,64,3),
                relu(),
                conv(64,3,3),
                nn.Tanh()]
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        
    def forward(self,x):#x:(batchsize,input_dim)
        x = self.head(x)
        x = self.reshape(x)
        x = self.body(x)
        return x        #(batchsize,3,64,64)
    def name(self):
        return 'Generator'
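A quick shape check of the generator (a minimal sketch; the batch size of 8 is illustrative and the latent dimension of 100 matches the default input_dim above):

netG = Generator(input_dim=100)
z = torch.randn(8, 100)       # (batch_size, input_dim)
print(netG(z).shape)          # torch.Size([8, 3, 64, 64])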
Discriminator
class Discriminator(nn.Module):
    def __init__(self,conv=default_conv,relu=default_relu):
        super(Discriminator,self).__init__()
        main = [conv(3,32,3),
                relu(),
                conv(32,64,3),
                relu(),
                conv(64,128,3),
                relu(),
                conv(128,256,3),
                relu()]
        self.main = nn.Sequential(*main)
        self.fc = nn.Linear(256*64*64,1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,x):#x:(batchsize,3,64,64)
        x = self.main(x)#(b,256,64,64)
        x = x.view(x.size()[0],-1)#(b,256*64*64)
        x = self.fc(x) #(b,1)
        x = self.sigmoid(x)
        return x
    def name(self):
        return 'Discriminator'
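Since every convolution in the discriminator uses stride 1 with size-preserving padding, the feature map stays at 64x64, so the final Linear layer takes 256*64*64 = 1,048,576 input features and holds most of the model's parameters. A quick check (sketch only):

netD = Discriminator()
img = torch.randn(8, 3, 64, 64)
print(netD(img).shape)                            # torch.Size([8, 1])
print(sum(p.numel() for p in netD.parameters()))  # roughly 1.4M parameters, dominated by the Linear layer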

2. Training / validation / testing module

Parameters and model initialization
class GAN(nn.Module):
    def __init__(self,args):
        super(GAN,self).__init__()
        self.img_size = 64
        self.channels = 3    
        self.latent_dim = args.latent_dim
        self.num_epoch = args.num_epoch
        self.batch_size = args.batch_size
        self.cuda = args.cuda
        self.interval = 20 # validate once every 20 epochs
        self.continue_training = args.continue_training # whether to resume training
        ## initialize the generator
        self.generator = Generator(self.latent_dim)
        ## initialize the discriminator
        self.discriminator = Discriminator()
        self.testmodelpath = args.testmodelpath
        self.datapath = args.datapath
        if self.cuda:
            self.generator.cuda()
            self.discriminator.cuda()
        self.continue_training_isrequired() 
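continue_training_isrequired() is not shown here; it decides whether to resume from saved checkpoints and is expected to set self.start_epoch, which trainer() reads below. A minimal sketch of what it could look like, assuming the checkpoint format saved in trainer() (the actual implementation in the repository may differ):

    def continue_training_isrequired(self):
        ## start from epoch 0 unless resuming from the latest saved checkpoints
        self.start_epoch = 0
        if self.continue_training and os.path.isdir('checkpoint'):
            checkpoint_gen = torch.load('checkpoint/'+self.generator.name())
            checkpoint_dis = torch.load('checkpoint/'+self.discriminator.name())
            self.generator.load_state_dict(checkpoint_gen['gen_model'])
            self.discriminator.load_state_dict(checkpoint_dis['dis_model'])
            self.start_epoch = checkpoint_gen['epoch'] + 1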
Training and the DataLoader
    def trainer(self):
        ## load the image data and split it into batches
        print('===> Data preparing...')
        import torchvision.transforms as transforms
        from torch.utils.data import DataLoader
        from torchvision.datasets import ImageFolder
        transform = transforms.ToTensor()  ## the DataLoader must yield tensors; omitting this transform raises an error
        dataset = ImageFolder(self.datapath,transform=transform)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)       
        ## drop_last: the dataset size may not be a multiple of batch_size; drop_last=True discards the final incomplete batch
        num_batch = len(dataloader) # number of batches = len(dataloader) = total images // batch_size
        print('num_batch:',num_batch)
        # dataloader yields (batch_size,3,64,64) tensors with values in [0,1]
        ## target labels for the discriminator
        target_real = Variable(torch.ones(self.batch_size,1))
        target_false = Variable(torch.zeros(self.batch_size,1))
        one_const = Variable(torch.ones(self.batch_size,1))
        if self.cuda:
            target_real = target_real.cuda()
            target_false = target_false.cuda()
            one_const = one_const.cuda()
        ## optimizers
        optim_generator = optim.Adam(self.generator.parameters(),lr=0.0002,betas=(0.5,0.999))
        optim_discriminator = optim.Adam(self.discriminator.parameters(),lr=0.0002,betas=(0.5,0.999))
        ## loss functions
#        content_criterion = nn.MSELoss()
        adversarial_criterion = nn.BCELoss()
        ## start training
        for epoch in range(self.start_epoch,self.num_epoch): ## epoch loop
            ## accumulate per-batch losses to report the average over the epoch
            mean_dis_loss = 0.0
            mean_gen_con_loss = 0.0
            mean_gen_adv_loss = 0.0
            mean_gen_total_loss = 0.0
            for i,data in enumerate(dataloader):  ## number of iterations = len(dataloader) = total images // batch_size
                if epoch<3 and i%10==0:
                    print('epoch%d: %d/%d'%(epoch,i,len(dataloader)))
                ## 1.1 sample noise
                gen_input = np.random.normal(0,1,(self.batch_size,self.latent_dim)).astype(np.float32)
                gen_input = torch.from_numpy(gen_input)
                gen_input = torch.autograd.Variable(gen_input,requires_grad=True)
                if self.cuda:
                    gen_input = gen_input.cuda()                    
                fake = self.generator(gen_input) ## generator output (batch_size,3,64,64)
                real, _ = data  # data is a list [images, labels]; keep the images, real: (batch_size,3,64,64)
                if self.cuda:
                    real = real.cuda()
                    fake = fake.cuda()
                ## 1. fix G, train the discriminator D
                self.discriminator.zero_grad()
                dis_loss1 = adversarial_criterion(self.discriminator(real),target_real)
                dis_loss2 = adversarial_criterion(self.discriminator(fake.detach()),target_false)## detach() the generator output before feeding it to D so gradients do not flow back into G
                dis_loss = 0.5*(dis_loss1+dis_loss2)
#                print('epoch:%d--%d, discriminator loss:%.6f'%(epoch,i,dis_loss))
                dis_loss.backward()
                optim_discriminator.step()
                
                mean_dis_loss+=dis_loss.item() # use .item() so the computation graph is not kept alive
                ## 2. fix D, train the generator G
                self.generator.zero_grad()
                ## sample noise
                gen_input = np.random.normal(0,1,(self.batch_size,self.latent_dim)).astype(np.float32)
                gen_input = torch.from_numpy(gen_input)
                gen_input = torch.autograd.Variable(gen_input,requires_grad=True)
                if self.cuda:
                    gen_input = gen_input.cuda()    
                fake = self.generator(gen_input) ## generator output (batch_size,3,64,64)
                gen_con_loss = 0
                gen_adv_loss = adversarial_criterion(self.discriminator(fake),one_const)## update G with D fixed (targets are all ones)
                gen_total_loss = gen_con_loss + gen_adv_loss
#                print('epoch:%d--%d, generator loss:%.6f'%(epoch,i,gen_total_loss))
                gen_total_loss.backward()
                optim_generator.step()
                mean_gen_con_loss+=gen_con_loss
                mean_gen_adv_loss+=gen_adv_loss.item()
                mean_gen_total_loss+=gen_total_loss.item()
                
            ## print once per epoch
            print('epoch:%d/%d'%(epoch, self.num_epoch))
            print('Discriminator_Loss: %.4f'%(mean_dis_loss/num_batch))
            print('Generator_total_Loss:%.4f'%(mean_gen_total_loss/num_batch))
            
            ## save the models
            state_dis = {'dis_model': self.discriminator.state_dict(), 'epoch': epoch}
            state_gen = {'gen_model': self.generator.state_dict(), 'epoch': epoch}
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint') 
            torch.save(state_dis, 'checkpoint/'+self.discriminator.name()+'__'+str(epoch+1)) #each epoch
            torch.save(state_gen, 'checkpoint/'+self.generator.name()+'__'+str(epoch+1))     #each epoch
            torch.save(state_dis, 'checkpoint/'+self.discriminator.name())    #final  
            torch.save(state_gen, 'checkpoint/'+self.generator.name())        #final  
            ## validate the model
            if epoch<45 or epoch%self.interval==0:
                 self.validater(epoch)
            print('--'.center(12,'-'))
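Note that training G against all-ones targets with BCELoss implements the non-saturating generator loss -log D(G(z)) rather than log(1 - D(G(z))), which gives stronger gradients early in training when D easily rejects the fakes. A tiny numerical check of that equivalence (the discriminator outputs here are made up):

import torch
import torch.nn.functional as F

d_fake = torch.tensor([0.1, 0.3, 0.7])  # hypothetical D outputs for fake images
print(F.binary_cross_entropy(d_fake, torch.ones_like(d_fake)))  # tensor(1.2877)
print((-torch.log(d_fake)).mean())                              # tensor(1.2877), same value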
Validation
    def validater(self,epoch):
        vis = visdom.Visdom(env='generate_girl_epoch%d'%(epoch))
        r,c = 3,3
        gen_input_val = np.random.normal(0,1,(r*c,self.latent_dim)).astype(np.float32)
        gen_input_val = torch.from_numpy(gen_input_val)
        gen_input_val = torch.autograd.Variable(gen_input_val)
        if self.cuda:
            gen_input_val = gen_input_val.cuda()   
        output_val = self.generator(gen_input_val)     #(r*c,3,64,64)
        output_val = output_val.cpu()
        output_val = output_val.data.numpy()      #(r*c,3,64,64)        
        img = np.transpose(output_val,(0,2,3,1))  #(r*c,64,64,3) 
        fig, axs = plt.subplots(r,c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                vis.image(output_val[cnt],opts={'title':'epoch%d_cnt%d'%(epoch,cnt)}) 
                axs[i, j].imshow(img[cnt, :, :, :])
                axs[i, j].axis('off')
                cnt += 1   
        if not os.path.isdir('images'):
            os.mkdir('images') 
        fig.savefig('images/val_%d.png'%(epoch+1)) ## save the validation grid
        plt.close()
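validater() pushes images to a visdom server, so one has to be running before training starts (typically launched with python -m visdom.server, serving on port 8097 by default); the matplotlib grid is written to images/ either way. A small optional check that could be run before training (sketch only):

vis = visdom.Visdom()
if not vis.check_connection():
    print('visdom server is not running; only the matplotlib grids will be saved')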
Testing
    def tester(self,gen_input_test): # input: (N, latent_dim)
        assert gen_input_test.shape[1]==self.latent_dim, \
        'dimension 1 expects size %d, but got %d'%(self.latent_dim,gen_input_test.shape[1])
        gen_input_test = gen_input_test.astype(np.float32)
        gen_input_test = torch.from_numpy(gen_input_test)
        gen_input_test = torch.autograd.Variable(gen_input_test)
        if self.cuda:
            gen_input_test = gen_input_test.cuda()   
        ## load the trained generator weights
        if os.path.isdir('checkpoint'):
            try:
                checkpoint_gen = torch.load(self.testmodelpath)
                self.generator.load_state_dict(checkpoint_gen['gen_model'])
            except FileNotFoundError:
                print('Can\'t find the checkpoint file')
        output_test = self.generator(gen_input_test)          
        output_test = output_test.cpu()
        output_test = output_test.data.numpy()      #(N,3,64,64)
        img = np.transpose(output_test,(0,2,3,1))  #(N,64,64,3) 
        if not os.path.isdir('images'):
            os.mkdir('images')         
        N = img.shape[0] # number of images
        for i in range(N):
            plt.imshow(img[i, :, :, :])
            plt.axis('off')
            plt.savefig('images/test_%d.png'%(i+1)) ## save the result
            plt.close()

Compared with the original Keras version, the results are essentially the same, which is expected since the networks are nearly identical, so don't expect too much. The network itself is quite small, and generating a good-looking face is challenging because the facial features have to be well proportioned and consistent with each other.

Input:

np.random.normal(0,1,(1,self.latent_dim)).astype(np.float32)
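Putting it together, a hedged usage sketch for generating test images. The flag names mirror the attributes used by the GAN class above, but the defaults and the exact argparse setup are assumptions that may not match the repository:

parser = argparse.ArgumentParser()
parser.add_argument('--latent_dim', type=int, default=100)
parser.add_argument('--num_epoch', type=int, default=300)          # assumed default
parser.add_argument('--batch_size', type=int, default=64)          # assumed default
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--continue_training', action='store_true')
parser.add_argument('--testmodelpath', type=str, default='checkpoint/Generator')
parser.add_argument('--datapath', type=str, default='./faces')     # assumed dataset folder
args = parser.parse_args()

gan = GAN(args)
# gan.trainer()                                             # train (or resume) first
gan.tester(np.random.normal(0, 1, (9, args.latent_dim)))    # writes images/test_1.png ... test_9.png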

Sample results:
(generated avatar images omitted)
