A Personal GAN Training Log

1 GAN

GAN (Generative Adversarial Network) was once the mainstream generative architecture in deep learning. Although diffusion models have been on the rise in recent years, the GAN idea has a genuinely elegant ingenuity of its own.

For a generative task, the goal is essentially to exploit a neural network's power to simulate and model: fit a mapping from a simple distribution to a complex one, and thereby satisfy the demand for "creativity".

2 GAN Principles

A GAN consists of two neural networks: a generator G and a discriminator D. The generator takes random noise as input and produces fake data samples, while the discriminator receives both real data and the generated fakes and tries to tell them apart. The generator's goal is to produce ever more realistic fakes until the discriminator can no longer distinguish them from real data; the discriminator's goal is to separate real from fake as accurately as possible. This "adversarial" training gradually drives generator and discriminator towards an equilibrium, ultimately yielding high-quality samples.
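Formally, this tug-of-war is the minimax objective from the original GAN paper:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

D pushes the value up by rating real samples near 1 and fakes near 0, while G pushes it down by making $D(G(z))$ approach 1.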

To me, the essence of GAN lies in the clever adversarial design between the generator G and the discriminator D, which lets G gradually understand and approach a complex data distribution.

On one hand, the discriminator D greatly simplifies the problem of evaluating generated data: it makes it "easy", in principle, to design a loss function, so that the generator G can be trained end-to-end without supervision (meaning an unconditional GAN with no labels). On the other hand, if you view G+D as a single network, then GAN is the interleaved training of the two modules D and G: trading blows, each stepping on the other to climb higher, until they finally reach a Nash equilibrium.

Concretely, training boils down to the following steps (sketched in code right after the list):

1. Freeze the generator G's gradients and train the discriminator D on real images plus images produced by G, so that D learns to tell real from fake.

2. Freeze the discriminator D's gradients and train the generator G with fake data + real labels, so that G's parameters are optimized towards fooling the current D.

3. Repeat. D and G each exploit the window in which the other is not being trained to grow stronger and beat it, improving in alternation.
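In PyTorch terms, one alternation looks roughly like this; a minimal sketch, assuming the nets `G` and `D`, their optimizers `g_opt`/`d_opt`, a `bce` loss, label tensors `ones`/`zeros`, and the batches `real`/`noise` already exist:

```python
# Step 1: update D on real + fake images (G is frozen via detach)
d_opt.zero_grad()
fake = G(noise).detach()                # detach: no gradients flow into G
d_loss = bce(D(real), ones) + bce(D(fake), zeros)
d_loss.backward()
d_opt.step()

# Step 2: update G with fake images + real labels
# (only g_opt steps, so D's parameters stay untouched)
g_opt.zero_grad()
g_loss = bce(D(G(noise)), ones)         # push D(G(z)) towards "real"
g_loss.backward()
g_opt.step()
```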


Seen this way, GAN is really an idea rather than a fixed network architecture: as long as gradients can flow, any two networks can use this scheme to ultimately train a generator G.

Of course, the theory is beautiful, but in practice GAN training is extremely unstable! Because it hinges on the balance between two networks, the moment the discriminator D gets too strong, or the generator G finds a blind spot of the discriminator, progress stalls.
And I keep coming back to the belief that discrimination is a far easier task than generation; indeed, during training the discriminator often improved too fast, and I had to reset D to give G room to keep advancing.

On the theory side, the WGAN authors also proved why GAN training is so hard: roughly, the distribution of images is extremely narrow, and the overwhelming majority of high-dimensional space is noise rather than images, so the overlap between the generated distribution and the real one is tiny or nonexistent. The JS divergence is then nearly constant and hard to optimize, and WGAN replaces it with the Earth Mover's (Wasserstein) distance.
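The Earth Mover's (Wasserstein-1) distance they substitute is

$$
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim\gamma}\big[\lVert x - y \rVert\big],
$$

which, unlike JS, still varies smoothly when the two distributions barely overlap. In practice it is optimized through the Kantorovich-Rubinstein dual, $\max_{\lVert f \rVert_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{z}[f(G(z))]$, which is exactly the critic loss that shows up in Section 4.2.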

3 DCGAN

How could I not play with GAN myself? I went straight to the official PyTorch DCGAN tutorial, using the Crypko dataset collected by Arvin Liu: all anime-style avatar faces.

Dataset preview:

[image: sample anime faces from the dataset]

On to the code; it is basically copied from the PyTorch tutorial with the paths adjusted.

3.1 Dataset Split

import os
import shutil
import random

def run():
    original_path="../../../../Dataset/AnimeFaces"
    # List all image files, skipping the two destination folders
    filename=[f for f in os.listdir(original_path) if f not in ("trainfolder","test")]

    # Hold out 1500 random images as a test set (used later for FID)
    test_list=set(random.sample(filename,1500))
    train_list=list(filter(lambda x: x not in test_list,filename))

    # ImageFolder expects one sub-directory per class, hence trainfolder/train
    os.makedirs(original_path+"/trainfolder/train",exist_ok=True)
    os.makedirs(original_path+"/test",exist_ok=True)

    # Split the dataset
    for name in train_list:
        shutil.move(original_path+"/"+name,original_path+"/trainfolder/train/"+name)
    for name in test_list:
        shutil.move(original_path+"/"+name,original_path+"/test/"+name)
#run()

3.2 Dataset

Note that the training images sit in …/trainfolder/train rather than directly under …/trainfolder, because ImageFolder expects one sub-directory per class.

import torch
import torchvision.datasets as Dataset
import torchvision.transforms as transforms

dataroot="../../../../Dataset/AnimeFaces/trainfolder"
batch_size=256
# Resize to 64x64 and normalize to [-1, 1] to match the generator's Tanh output
dataset=Dataset.ImageFolder(root=dataroot,
                         transform=transforms.Compose([
                             transforms.Resize((64,64)),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
                         ]))

dataloader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=1)
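To reproduce the dataset preview above, a quick sketch using this `dataloader` (on Windows, wrap it in `if __name__ == '__main__':` because of the worker processes):

```python
# Show one batch as an 8x8 grid; normalize=True undoes the [-1, 1] scaling
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

imgs = next(iter(dataloader))[0][:64]
grid = vutils.make_grid(imgs, padding=2, normalize=True)
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
```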

3.3 Networks

Seen with today's eyes, DCGAN really is simple and crude.

import torch.nn as nn

class Generator(nn.Module):
    """Maps a z_dim x 1 x 1 noise vector to a 3 x 64 x 64 image."""
    def __init__(self,z_dim=100,g_feature=64):
        super(Generator, self).__init__()
        self.net=nn.Sequential(
            # z_dim x 1 x 1 -> (g_feature*8) x 4 x 4
            nn.ConvTranspose2d(z_dim,g_feature*8,4,1,0,bias=False),
            nn.BatchNorm2d(g_feature*8),
            nn.ReLU(True),
            # -> (g_feature*4) x 8 x 8
            nn.ConvTranspose2d(g_feature*8, g_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 4),
            nn.ReLU(True),
            # -> (g_feature*2) x 16 x 16
            nn.ConvTranspose2d(g_feature * 4, g_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 2),
            nn.ReLU(True),
            # -> g_feature x 32 x 32
            nn.ConvTranspose2d(g_feature * 2, g_feature, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature),
            nn.ReLU(True),
            # -> 3 x 64 x 64; Tanh squashes to [-1, 1], matching the data normalization
            nn.ConvTranspose2d(g_feature,3,4,2,1,bias=False),
            nn.Tanh()
        )

    def forward(self,input):
        return self.net(input)

class Discriminator(nn.Module):
    """Maps a 3 x 64 x 64 image to a single real/fake probability."""
    def __init__(self,d_feature=64):
        super(Discriminator, self).__init__()
        self.net=nn.Sequential(
            # 3 x 64 x 64 -> d_feature x 32 x 32
            nn.Conv2d(3,d_feature,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            # -> (d_feature*2) x 16 x 16
            nn.Conv2d(d_feature, d_feature*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature*2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (d_feature*4) x 8 x 8
            nn.Conv2d(d_feature*2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (d_feature*8) x 4 x 4
            nn.Conv2d(d_feature*4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1 score; Sigmoid turns it into a probability
            nn.Conv2d(d_feature*8,1,4,1,0,bias=False),
            nn.Sigmoid()
        )

    def forward(self,input):
        return self.net(input)
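A quick shape check confirms the two nets mirror each other: the generator upsamples 1→4→8→16→32→64, and the discriminator downsamples back to a single score.

```python
# Sanity check: noise -> G -> 64x64 RGB image -> D -> one score per sample
import torch

z = torch.randn(4, 100, 1, 1)   # batch of 4 noise vectors
img = Generator()(z)
print(img.shape)                # torch.Size([4, 3, 64, 64])
score = Discriminator()(img)
print(score.shape)              # torch.Size([4, 1, 1, 1])
```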

3.4 Training

Print some statistics along the way to monitor the training process.

import torch.nn as nn
import torch
from Dataset import dataloader
import torch.optim as optim
from Net import Generator,Discriminator
import numpy as np
import time

if __name__=='__main__':
    device=torch.device("cuda")
    criterion=nn.BCELoss()

    # Fixed noise: visualize the generator's progress on the same inputs
    fixed_noise=torch.randn(64,100,1,1,device=device)

    real_label=1.
    fake_label=0.

    netG=Generator().to(device)
    netD=Discriminator().to(device)

    # Resume training: G from epoch 600, D rolled back to epoch 400 (see Section 5.1)
    netG.load_state_dict(torch.load("model_parm/G_epoch600.pt"))
    netD.load_state_dict(torch.load("model_parm/D_epoch400.pt"))

    D_lr = 2e-6
    G_lr = 2e-6

    optimizerD=optim.Adam(netD.parameters(),lr=D_lr,betas=(0.5,0.999))
    optimizerG=optim.Adam(netG.parameters(),lr=G_lr,betas=(0.5,0.999))

    # Training Loop

    num_epochs=1000
    t1=time.time()
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(601,num_epochs+1):
        # Per-epoch loss records (reset here, not per batch, so the epoch mean is correct)
        G_losses = []
        D_losses = []
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Train discriminator D: maximize log(D(x)) + log(1 - D(G(z))),
            #     i.e. push real images -> 1 and fake images -> 0
            ###########################
            ## First train on real images
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # Noisy soft labels (the variant used in Section 5.2)
            # label = np.random.rand(b_size)*0.8+0.15
            # label = torch.tensor(label,dtype=torch.float,device=device)
            # Build labels: all ones
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output = netD(real_cpu).view(-1)
            # Compute loss and backpropagate
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            ## Then train on fake images
            # Sample random noise and generate fakes
            noise = torch.randn(b_size,100 , 1, 1, device=device)
            fake = netG(noise)
            # Noisy soft labels
            # label = np.random.rand(b_size)*0.1+0.05
            # label = torch.tensor(label,dtype=torch.float,device=device)
            label.fill_(fake_label)
            # Detach so no gradients flow into G here
            output = netD(fake.detach()).view(-1)
            # Compute loss and backpropagate
            errD_fake = criterion(output, label)
            errD_fake.backward()
            # D_G_z1: how well G fools D *before* this D update
            D_G_z1 = output.mean().item()
            # Total discriminator loss, then update D
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Train generator G: maximize log(D(G(z))) to fool D
            ###########################
            netG.zero_grad()
            # Pair fake images with *real* labels, so that updating G's parameters
            # pulls its outputs towards what D currently considers real
            # (hard labels here)
            G_label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Run the fakes through D and compute the loss
            output = netD(fake).view(-1)
            errG = criterion(output, G_label)
            errG.backward()
            # D_G_z2: how well G fools the *already updated* D;
            # typically D_G_z2 < D_G_z1
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Record losses
            G_losses.append(errG.item())
            D_losses.append(errD.item())

        # save model
        if epoch % 50 ==0:
            torch.save(netG.state_dict(), 'model_parm/G_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D_epoch' + str(epoch) + '.pt')

        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        # Print every epoch right after resuming (training resumed at 601), for quick feedback
        if epoch<=605:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        # Reset the discriminator D (used when it grew too strong)
        # if epoch % 15 ==0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))

        # record train log
        Gloss=np.mean(G_losses)
        Dloss=np.mean(D_losses)
        with open('model_parm/train_log2.txt','a+') as f:
            string=str(epoch)+'\t'+str(round(Gloss,5))+'\t'+str(round(Dloss,5))+'\t'+\
                   str(round(D_x,5))+'\t'+str(round(D_G_z1,5))+'\t'+str(round(D_G_z2,5))+'\n'
            f.write(string)
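The log file makes it easy to eyeball the loss curves; a minimal sketch, assuming the tab-separated columns written above (epoch, mean G loss, mean D loss, ...):

```python
# Plot the per-epoch mean losses recorded in train_log2.txt
import numpy as np
import matplotlib.pyplot as plt

log = np.loadtxt('model_parm/train_log2.txt')   # one row per epoch
plt.plot(log[:, 0], log[:, 1], label='G loss')
plt.plot(log[:, 0], log[:, 2], label='D loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```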

3.5 Evaluation

Besides eyeballing the samples, the natural choice is of course the FID metric.

from Net import Generator
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutil
import os
import shutil
import re

def generate_img(epoch):
    """Visualize an 8x8 grid of samples from the given epoch's checkpoint."""
    device=torch.device("cuda")
    netG=Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch"+str(epoch)+".pt"))

    fixed_noise=torch.randn(64,100,1,1,device=device)
    with torch.no_grad():
        fake=netG(fixed_noise).detach().cpu()
        plt.imshow(np.transpose(vutil.make_grid(fake,padding=2,normalize=True),(1,2,0)))
        plt.show()

def cal_FID(epoch):
    # Generate 1000 images from the given epoch's checkpoint
    device=torch.device("cuda")
    netG=Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch"+str(epoch)+".pt"))

    fixed_noise=torch.randn(1000,100,1,1,device=device)
    with torch.no_grad():
        fake=netG(fixed_noise).detach().cpu()
    for i in range(1000):
        vutil.save_image(fake[i],"D:/Pycharm/Dataset/AnimeFaces/fakeimg/"+str(i)+".jpg",normalize=True)
    # Note: this spawns a separate shell, so it does not actually change
    # the environment of the os.popen call below
    os.system("activate pytorch")
    # Compare the held-out test set against the generated images
    result=os.popen(r"python -m pytorch_fid D:\Pycharm\Dataset\AnimeFaces\test D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    content=result.readlines()[0]
    # Pull the FID number out of the tool's stdout
    fid=re.findall(r"\d+\.?\d*",content)
    fid=list(filter(lambda x : x!='0',fid))
    if len(fid)==1:
        print("epoch",epoch,":",float(fid[0]))
    else:
        print("epoch",epoch,":",fid)
    # Clean up the generated images for the next call
    shutil.rmtree(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    os.mkdir(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")

generate_img(400)
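Shelling out with os.popen works but is clunky; pytorch-fid also exposes a Python entry point. A minimal sketch, assuming the fid_score module of the installed pytorch-fid package (the exact signature may differ across versions):

```python
# Compute FID directly in Python instead of parsing stdout
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(
    [r"D:\Pycharm\Dataset\AnimeFaces\test",      # real held-out images
     r"D:\Pycharm\Dataset\AnimeFaces\fakeimg"],  # generated images
    batch_size=50,
    device=torch.device("cuda"),
    dims=2048,  # InceptionV3 pool3 features, the usual default
)
print("FID:", fid)
```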

4 WGAN

For the motivation behind WGAN, go straight to the paper. I had no intention of grinding it out with an MLP; I simply modified the network, the loss function, and the training procedure on top of DCGAN.

4.1 Net

All I did was delete the discriminator's Sigmoid; under WGAN the discriminator becomes a "critic" that outputs an unbounded score rather than a probability.

class Discriminator(nn.Module):
    def __init__(self,d_feature=64):
        super(Discriminator, self).__init__()

        self.net=nn.Sequential(

            nn.Conv2d(3,d_feature,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),

            nn.Conv2d(d_feature, d_feature*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(d_feature*2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(d_feature*4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(d_feature*8,1,4,1,0,bias=False),
            # Sigmoid removed: the critic outputs an unbounded score
            # (also avoids vanishing gradients)
            #nn.Sigmoid()
        )

    def forward(self,input):
        return self.net(input)

4.2 Train

from Dataset import dataloader
import torch.optim as optim
from Net import Generator,Discriminator
import time
import torch

if __name__=='__main__':
    device=torch.device("cuda")
    #criterion=nn.BCELoss()  # unused: WGAN does not use a BCE loss

    fixed_noise=torch.randn(64,100,1,1,device=device)

    # Critic updates per generator update, and the weight-clipping bound
    n_iter=5
    clip_value=0.01

    netG=Generator().to(device)
    netD=Discriminator().to(device)
    # netG.load_state_dict(torch.load("model_parm/G_epoch340.pt"))
    # netD.load_state_dict(torch.load("model_parm/D_epoch340.pt"))

    D_lr = 3e-5
    G_lr = 3e-5

    optimizerD=optim.RMSprop(netD.parameters(),lr=D_lr)
    optimizerG=optim.RMSprop(netG.parameters(),lr=G_lr)

    # Training Loop
    num_epochs=1500
    t1=time.time()
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs+1):
        for i, data in enumerate(dataloader,0):
            # Configure input
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizerD.zero_grad()

            # Sample noise as generator input
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            # Generate a batch of images
            fake_imgs = netG(noise).detach()
            # Wasserstein critic loss: score real images high and fakes low
            loss_D = -torch.mean(netD(real_cpu)) + torch.mean(netD(fake_imgs))

            loss_D.backward()
            optimizerD.step()

            # Clip critic weights to (crudely) enforce the Lipschitz constraint
            for p in netD.parameters():
                p.data.clamp_(-clip_value, clip_value)

            # Train the generator only every n_iter critic updates
            # (the "intermittent generator update" variant of Section 5.3;
            # set n_iter=1 for the plain WGAN of Section 5.4)
            if i % n_iter == 0:
                # -----------------
                #  Train Generator
                # -----------------
                optimizerG.zero_grad()

                # Generate a batch of images
                gen_imgs = netG(noise)
                # Adversarial loss: raise the critic's score on fakes
                loss_G = -torch.mean(netD(gen_imgs))

                loss_G.backward()
                optimizerG.step()

        # save model
        if epoch % 50 ==0:
            torch.save(netG.state_dict(), 'model_parm/G2_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D2_epoch' + str(epoch) + '.pt')

        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
              % (epoch, num_epochs, loss_D.item(), loss_G.item()) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        # Reset the discriminator D (used when it grew too strong)
        # if epoch % 15 ==0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))

        # record train log
        with open('model_parm/train_log2.txt','a+') as f:
            string=str(epoch)+'\t'+str(round(loss_G.item(),5))+'\t'+str(round(loss_D.item(),5))+'\n'
            f.write(string)

5 Results

I ran four rather haphazard training rounds. In the end, the vanilla DCGAN felt best; neither soft labels nor WGAN brought much improvement. But the training was genuinely sloppy: the epoch counts differ between runs, and I barely recorded the learning-rate changes or the criteria for rolling back the discriminator and generator. This was purely a hands-on taste of GANs.

5.1 Vanilla DCGAN

Trained for 1000 epochs in total. Around epoch 600 the discriminator broke down, so I rolled it back to the epoch-400 checkpoint and continued to 1000 epochs. Here is the best result (epoch 400):

[image: best DCGAN samples, epoch 400]

| epoch | FID ↓ |
| ----- | ----- |
| 370   | 84.92 |
| 400   | 79.94 |
| 500   | 92.37 |
| 600   | 93.45 |
| 900   | 88.99 |
| 1000  | 89.66 |

5.2 DCGAN + SoftLabel + Intermittent Discriminator Updates

Trained for 350 epochs; the best result (epoch 350) is below:

[image: best samples, epoch 350]

| epoch | FID ↓  |
| ----- | ------ |
| 0     | 345.39 |
| 150   | 271.35 |
| 350   | 164.42 |

5.3 WGAN + Intermittent Generator Updates

Admittedly, although WGAN's final quality was only so-so, its training loss decreased steadily almost throughout, with essentially no mode collapse. Trained for 1100 epochs; the best result (epoch 1100) is below:

[image: best samples, epoch 1100]

| epoch | FID ↓  |
| ----- | ------ |
| 0     | 386.23 |
| 100   | 284.69 |
| 200   | 218.54 |
| 300   | 185.20 |
| 900   | 154.90 |
| 1000  | 154.88 |
| 1100  | 148.92 |

5.4 WGAN

Trained for 1500 epochs; the best result (epoch 1400) is below:

[image: best samples, epoch 1400]

| epoch | FID ↓  |
| ----- | ------ |
| 0     | 306.68 |
| 100   | 142.88 |
| 200   | 125.99 |
| 300   | 124.07 |
| 1300  | 113.03 |
| 1400  | 111.63 |
| 1500  | 112.65 |

6 Summary

All in all, GANs are great fun, and they barely eat VRAM: with a batch size of 256 I used less than 2 GB. By comparison, merely fine-tuning a LoRA for SD takes 8 GB even at batchsize=1; it really is a case of "only scale matters".
Also, inserting images in Markdown is a miserable experience.
