pytorch训练GAN生成人脸图像

GAN(Generative Adversarial Networks)生成对抗网络https://arxiv.org/pdf/1406.2661.pdf是在2014提出的一个图像生成模型。基于这种思想后续出现了非常多的GAN架构的生成模型,其中比较典型的是DCGAN 和WGAN, 这篇文章通过讲解GAN思想,实现一下DCGAN代码生成人脸,最后再分析一下WGAN。

一、GAN原理

在图像领域,检测、识别、分割是较为常见的CV任务,这么看好像少了一个图像生成,GAN就是用来生成图像的网络。

如果想凭空生成一张图像,在数学上该怎么设计呢?很显然,我们需要知道已知数据的分布才行,统计出数据分布规律,然后随机采样就可以生成一个数据了。比如高斯分布N~(u, \sigma ^{2}),从数据中拟合出u和sigma两个参数,就可以随机采样生成数据了。

那在图像领域,比如我们有很多美女脸的图像,这些图像的像素分布肯定是符合某种复杂规律的,该怎么统计人脸分布呢?很明显,用神经网络来叠加非线性单元就行了(理论上神经网络可以拟合任意复杂的函数)。

所以我们可以通用设计神经网络来构建一个生成器Generator, 输入为随机噪声,输出为图像。现在的问题就是,该如何训练这个Generator。

GAN就是用来解决这个问题的。

如图所示就是GAN的原理示意图,从图中可以看出GAN是通过一个判别器Discrminator来训练Genrator的。 第一步,先送入一些read images和Generator生成的fake images,训练Discriminator,这就是个二分类问题。此时Discriminator的分辨能力会提升。第二步,训练Generator, 使其生成的图像更加逼真,此时Dsicriminator的分辨能力会降低,就这样此消彼长,交替训练,直到达到纳什均衡,即Dsicriminator再也区分不出real image和fake image, 只能按照50%的概率来猜测是否是real image。

所以我们现在定义个噪声分布p_z(z),定义一个生成器G(z;\theta_g)为从噪声空间到图像分布空间的多层感知机的映射。同时,我们定义一个判别器D(x;\theta_d),输出为属于real image的概率。那么,我们通过交叉熵H(1, D(x))H(0, D(G(z)))来训练D,通过H(1, D(G(z))))来训练生成器就可以了。其中,H表示交叉熵。

上述公示为GAN原论文给出的损失函数。其中第一项就是E_x就是求D(x)的期望,实际上就是取负号后就是交叉熵H(p_{data},D(x) ),这里求E_x最大,就是求H最小。同理,对于第二项E_z的期望就是负的交叉熵H(p_z, 1-D(G(z))), 因为p_z对应的为fake image, 所以D的输出概率应该1-D(G(z)), 最小化G(z)就相当于最大化E_z了。

二. DCGAN代码实现

 DCGAN(Deep Convolution GAN)顾名思义就是深度卷积GAN,因为原生的GAN是采用全连接的,DCGAN就把full connection换成convolution来提取特征。

上图是DCGAN论文中给出的网络结构,上采样使用TransConvolution来实现。本文将在网上找开源数据集来训练一下DCGAN。

2.1 数据 

 我这里是找了一个全是网红脸图像的数据集,来源新数据集

 定义一个读取数据的文件Face1.py

from torch.utils.data import DataLoader,Dataset
from torchvision import transforms as T
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np
import torch
import torchvision.utils as vutils

 
class Face1(Dataset):
    def __init__(self, root, transforms=None):
        imgs = []
        for path in os.listdir(root):
            imgs.append(os.path.join(root, path))
 
        self.imgs = imgs
        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
            self.transforms = T.Compose([
                    T.Resize(64),
                    T.CenterCrop(64),
                    T.ToTensor(),
                    normalize
            ])
        else:
            self.transforms = transforms
             
    def __getitem__(self, index):
        img_path = self.imgs[index]
 
        data = Image.open(img_path)
        if data.mode != "RGB":
            data = data.convert("RGB")
        data = self.transforms(data)
        return data
 
    def __len__(self):
        return len(self.imgs)
 

if __name__ == "__main__":
    root = "/home/elvis/dataset/face"
    train_dataset = Face1(root)
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    real_batch = next(iter(train_dataloader))
    real_batch = np.transpose(vutils.make_grid(real_batch.to(device)[:64], padding=2, normalize=False).cpu(),(1,2,0)).numpy()
    real_batch = real_batch[:,:,:] * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(real_batch)
    plt.show()

 2.2 网络定义

写一个定义DCGAN网络的文件dcgan.py,注释都比较全。

import torch
from torch import nn


class Generator(nn.Module):
    def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        '''
        self.ngpu : train gpu num
        nz  :  Size of z latent vector (i.e. size of generator input)
        ngf :  Size of feature maps in generator
        nc  :  Number of channels in the generator images. For color images this is 3
        '''
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)




class Discriminator(nn.Module):
    def __init__(self, ngpu, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        '''
        ngpu: train gpu num
        nc  : Number of channels in the real images. For color images this is 3
        ndf : Size of feature maps in discriminator 
        '''
        self.ngpu = ngpu
        # Generator Code
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)



# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

 2.3 训练代码

import torch
import os
from torch.utils.data import DataLoader
import numpy as np
from torch import nn
from torch import optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt

from face1 import Face1
from dcgan import Generator, Discriminator, weights_init





batch_size = 128
# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 20
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. load real image
root = "/home/elvis/dataset/seeprettyface_chs_wanghong/xinggan_face"
dataset = Face1(root)
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(len(dataloader))


# 2. load generator and discriminator
netG = Generator(ngpu=ngpu, nz=nz, ngf=ngf).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
netG.apply(weights_init)   #  to ``mean=0``, ``stdev=0.02``.
print(netG)


# Create the Discriminator
netD = Discriminator(ngpu).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
netD.apply(weights_init)   # `to mean=0, stdev=0.2``.
print(netD)

# 3. loss function
criterion = nn.BCELoss()

# 4. Create batch of latent vectors that we will use to visualize
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

real_label = 1.
fake_label = 0.

# 5. Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))


# 6. train loop
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)  # .detach表示不迭代梯度,此时netG不动
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()

                img_batch = vutils.make_grid(fake[:64], padding=2, normalize=False)
                img_batch = np.transpose(img_batch, (1,2,0)).numpy()
                img_batch = img_batch[:,:,:] * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]

                plt.figure(figsize=(8,8))
                plt.axis("off")
                plt.title("Generate Images")
                plt.imshow(img_batch)
                model_path = "result/"+str(iters)
                if not os.path.exists(model_path):
                    os.makedirs(model_path)
                plt.savefig(model_path+"/fake.png")
                torch.save(netG.state_dict(), model_path+"/model.pth")
            # img_list.append(vutils.make_grid(fake[:64], padding=2, normalize=True))

        iters += 1

print("End Training Loop...")
x = range(iters)
plt.figure(figsize=(10,5))
plt.title("Loss")
plt.xlabel("iters")
plt.ylabel("loss value")
plt.plot(x, D_losses,'-',label="D_loss", color='r')
plt.plot(x, G_losses,'-',label="G_loss", color='b')
plt.legend()
plt.grid(True)
plt.savefig("loss.png")

2.4 生成图片与训练结果

最终的生成效果和训练结果是

 

很遗憾,并不收敛。理想的效果应该是,D_loss上下横跳,但随着G训练的越来越好,D_loss横跳的幅度越来越小才对,G_loss同理。

直观上的理解就是,GAN是学习一个随机高斯分布到特定图像分布的映射,但在实际中,这两种分布很可能是完全不相交的。总体来说,GAN有如下缺点:

  1. 原始的GAN训练困难。需要很小心地平衡生成器和判别器的训练程度,如果判别器过强,会导致生成器梯度消失严重,难以进化,进而大大增加训练所需时间。
  2. 生成器和判别器的loss无法指示进程,也就是说,我们无法通过生成器与判别器的loss来判断我们生成的图像是否到达了我们所满意的情况。只能通过显示训练图像自行感受训练程度。
  3. 生成样本缺乏多样性。容易产生模型崩坏,即生成的图像中有着大量的重复图像。
     

3. WGAN

3.1 WGAN原理

为了解决GAN的问题,提出了许许多多的建立在GAN思想上的算法,其中比较有代表性的是WGAN(Wasserstein GAN)。

WGAN从数学上证明了,GAN之所以难以训练是主要是损失函数选择的原因。具体推导过程,推荐一篇写的很不错的知乎:令人拍案叫绝的Wasserstein GAN - 知乎

WGAN主要就是从损失函数解决训练困难的问题的,使用Wasserstein距离代替交叉熵,主要改变和算法流程如下:

  • 判别器最后一层去掉sigmoid。sigmoid函数容易出现梯度消失的情况。
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

 3.2 WGAN训练

根据3.1分析,修改我们的用来训练DCGAN的代码即可,这里注意,要把DCGAN模型定义文件里的Discriminator的sigmoid层注释掉。

然后重写一个训练文件trainwgan.py

import os
import torch
from torch import optim
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

from face1 import Face1
from dcgan import Generator, Discriminator, weights_init


# 1. load real image
root = "/home/elvis/dataset/seeprettyface_chs_wanghong/xinggan_face"
dataset = Face1(root)
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)


# 2. load generator and discriminator
net_G = Generator(ngpu=1, nz=100, ngf=64).cuda()
# net_G.apply(weights_init)   #  to ``mean=0``, ``stdev=0.02``.

net_D = Discriminator(1, nc=3, ndf=64).cuda()
# net_D.apply(weights_init)   # `to mean=0, stdev=0.2``.

# 3. optim
lr = 5e-5
opt_G = optim.RMSprop(net_G.parameters(), lr=lr)
opt_D = optim.RMSprop(net_D.parameters(), lr=lr)

# 4. super parameter
num_epochs = 20


# train
D_losses, G_losses = [], []
iter = 0
for epoch in range(num_epochs):
    for batch_id, data in enumerate(dataloader):
        net_D.zero_grad()
        real = data.cuda()
        real_loss = net_D(real)
        noise = torch.randn(128, 100, 1, 1).cuda()
        fake = net_G(noise)
        fake_loss = net_D(fake.detach())
        
        D_loss = -torch.mean(real_loss) + torch.mean(fake_loss)
        D_loss.backward()
        opt_D.step()
        
        # clip weights of discriminator
        for p in net_D.parameters(): p.data.clamp_(-0.01, 0.01)

        if batch_id % 5 == 0:
            net_G.zero_grad()

            g_loss = -torch.mean(net_D(fake))
            g_loss.backward()
            opt_G.step()
    
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                    % (epoch, num_epochs, batch_id, len(dataloader), D_loss.item(), g_loss.item()))
            iter += 1
            D_losses.append(D_loss.item())
            G_losses.append(g_loss.item())

        if iter % 100 == 0:   
            with torch.no_grad():
                noise = torch.randn(128, 100, 1, 1).cuda()
                fake = net_G(noise).detach().cpu()

                img_batch = vutils.make_grid(fake[:64], padding=2, normalize=False)
                img_batch = np.transpose(img_batch, (1,2,0)).numpy()
                img_batch = img_batch[:,:,:] * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]

                plt.figure(figsize=(8,8))
                plt.axis("off")
                plt.title("Generate Images")
                plt.imshow(img_batch)
                model_path = "result_wgan/"+str(iter)
                if not os.path.exists(model_path):
                    os.makedirs(model_path)
                plt.savefig(model_path+"/fake.png")
                torch.save(net_G.state_dict(), model_path+"/model.pth")


print("End Training Loop...")
x = range(iter)
plt.figure(figsize=(10,5))
plt.title("D Loss")
plt.xlabel("iters")
plt.ylabel("loss value")
plt.plot(x, D_losses,'-',label="D_loss", color='r')
plt.plot(x, G_losses,'-',label="G_loss", color='b')
plt.legend()
plt.grid(True)
plt.savefig("loss.png")

     

经过20个epoch之后,效果依旧不理想啊,看看损失函数,还是不收敛

生成的图像更是一言难尽。分析一下这个原因,还是GAN难以训练的问题。实际上,这也是GAN整个算法家族的问题,非常依赖于参数,多尝试几次,找一套完美的参数可以得到较好的效果,不然就效果很差。

损失函数往往不能可靠地收敛到鞍点,导致模型稳定性较差。即使有研究人员提出一些技巧来加强鞍点的稳定性,但还是不足以解决这个问题。

这也是GAN现在(2023年)逐渐被淘汰的原因,代替它的,将是效果更好的stable diffision。

  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值