1 GAN
GAN (Generative Adversarial Network) was once the mainstream generative architecture in deep learning. Even though diffusion models have taken over in recent years, the idea behind GANs remains remarkably elegant.
For any generative task, the goal is essentially to exploit a neural network's ability to simulate and model: transform a simple distribution into a complex one, and thereby satisfy the demand for "creativity".
2 How GANs Work
A GAN consists of two neural networks: a generator G and a discriminator D. The generator takes random noise as input and produces fake data samples, while the discriminator receives both real data and the generated fakes and tries to tell them apart. The generator's goal is to produce ever more realistic fakes so that the discriminator can no longer distinguish them from real data; the discriminator's goal is to separate real from fake as accurately as possible. This "adversarial" training gradually drives the two networks toward an equilibrium, at which point the generator produces high-quality samples.
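Formally, this is the standard minimax objective from the original GAN paper (quoted here for completeness, not derived from the code below):

$$\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]+\mathbb{E}_{z\sim p_z(z)}\big[\log\big(1-D(G(z))\big)\big]$$

D is trained to push this value up (real toward 1, fake toward 0), while G is trained to push it down by making D(G(z)) approach 1.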
In my view, the essence of GAN lies in this clever adversarial pairing of generator G and discriminator D, which lets G gradually understand and approach a complex data distribution.
On one hand, the discriminator D greatly simplifies the problem of evaluating generated data, so that in principle the loss function is "easy" to design and G can be trained end to end without supervision (for an unconditional GAN with no labels). On the other hand, if you view G+D as one network, GAN alternates training between the two modules: they trade blows and pull each other upward until they (ideally) reach a Nash equilibrium.
In terms of the actual training procedure, it boils down to the following steps:
1. Freeze the generator G's gradients and train the discriminator D on real images plus images produced by G, so that D learns to tell real from fake.
2. Freeze the discriminator D's gradients and train the generator G on fake data paired with "real" labels, so that G's parameters move in the direction that fools the current D.
3. Repeat. D and G each exploit the window in which the other is frozen to improve themselves and beat the other, growing in alternation.
Seen this way, GAN is really an idea rather than a fixed network architecture: as long as gradients can flow, any two networks can be plugged into this scheme to ultimately train the generator G.
Of course, the theory is beautiful, but in practice GAN training is extremely unstable! Because everything hinges on the balance between the two networks, training stalls as soon as the discriminator D becomes too strong or the generator G finds a blind spot in the discriminator.
I have also always felt that discriminating is a much easier task than generating; in practice the discriminator often improves too fast, and I had to reset D to leave G room to keep improving.
On the theory side, the WGAN authors also explained why GAN training is so hard: roughly speaking, the distribution of natural images is extremely narrow, and most of the high-dimensional space is noise rather than images, so the overlap between the generated and real distributions is tiny or nonexistent. The JS divergence then provides almost no useful gradient, and WGAN replaces it with the earth-mover (Wasserstein) distance as the quantity to optimize.
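For reference, the earth-mover (Wasserstein-1) distance between the real distribution $P_r$ and the generated distribution $P_g$ is defined as (standard definitions from the WGAN paper, quoted for context):

$$W(P_r,P_g)=\inf_{\gamma\in\Pi(P_r,P_g)}\mathbb{E}_{(x,y)\sim\gamma}\big[\lVert x-y\rVert\big]$$

where $\Pi(P_r,P_g)$ is the set of joint distributions with marginals $P_r$ and $P_g$. Via the Kantorovich-Rubinstein duality this becomes a supremum over 1-Lipschitz critics $f$:

$$W(P_r,P_g)=\sup_{\lVert f\rVert_L\le 1}\;\mathbb{E}_{x\sim P_r}[f(x)]-\mathbb{E}_{x\sim P_g}[f(x)]$$

which is what the weight-clipped critic in Section 4 approximates.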
3 DCGAN
How could I not try GAN myself? I went straight to the official PyTorch DCGAN tutorial, using the Crypko dataset collected by Arvin Liu, which consists of anime-style avatar images.
Dataset preview:
The code below is basically copied from the PyTorch tutorial, with the paths adjusted.
3.1 Dataset Split
import os
import shutil
import random
def run():
    original_path="../../../../Dataset/AnimeFaces"
    filename=os.listdir(original_path)
    filename.remove("train")
    filename.remove("test")
    # randomly hold out 1500 images as the test set; the rest becomes the training set
    test_list=list(random.sample(filename,1500))
    train_list=list(filter(lambda x: x not in test_list,filename))
    # split the dataset by moving the files into the two folders
    for i in range(len(train_list)):
        src=original_path+"/"+train_list[i]
        dst=original_path+"/trainfolder/train/"+train_list[i]
        shutil.move(src,dst)
    for i in range(len(test_list)):
        src=original_path+"/"+test_list[i]
        dst=original_path+"/test/"+test_list[i]
        shutil.move(src,dst)
#run()
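After running the split, a quick sanity check of the result (my addition, reusing the same paths as the script above):

import os
base = "../../../../Dataset/AnimeFaces"
# count how many images ended up in the training and test folders
print(len(os.listdir(base + "/trainfolder/train")), "training images")
print(len(os.listdir(base + "/test")), "test images")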
3.2 Dataset
Note that the training images live in …/trainfolder/train rather than directly in …/trainfolder, because torchvision's ImageFolder expects the images to sit inside (class) subfolders of the root directory.
import os
import torch
import torchvision.datasets as Dataset
import torchvision.transforms as transforms
import numpy as np
dataroot="../../../../Dataset/AnimeFaces/trainfolder"
batch_size=256
dataset=Dataset.ImageFolder(root=dataroot,
                            transform=transforms.Compose([
                                transforms.Resize((64,64)),
                                transforms.ToTensor(),
                                # map pixel values from [0,1] to [-1,1] to match the generator's Tanh output
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
                            ]))
dataloader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=1)
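The "dataset preview" above was an image grid; a minimal sketch to reproduce such a preview from this dataloader (run it as a separate script, since Dataset.py is imported by the training code) could be:

import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from Dataset import dataloader

# grab one batch of real images and display the first 64 of them as a grid
real_batch = next(iter(dataloader))
grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
plt.show()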
3.3 Network
Looking at it with today's eyes, DCGAN really is simple and brute-force.
import torch.nn as nn
class Generator(nn.Module):
    def __init__(self,z_dim=100,g_feature=64):
        super(Generator, self).__init__()
        self.net=nn.Sequential(
            # z: (N, z_dim, 1, 1) -> (N, g_feature*8, 4, 4)
            nn.ConvTranspose2d(z_dim,g_feature*8,4,1,0,bias=False),
            nn.BatchNorm2d(g_feature*8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(g_feature*8, g_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(g_feature * 4, g_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(g_feature * 2, g_feature, 4, 2, 1, bias=False),
            nn.BatchNorm2d(g_feature),
            nn.ReLU(True),
            # 32x32 -> (N, 3, 64, 64); Tanh maps outputs to [-1, 1]
            nn.ConvTranspose2d(g_feature,3,4,2,1,bias=False),
            nn.Tanh()
        )
    def forward(self,input):
        return self.net(input)

class Discriminator(nn.Module):
    def __init__(self,d_feature=64):
        super(Discriminator, self).__init__()
        self.net=nn.Sequential(
            # (N, 3, 64, 64) -> 32x32
            nn.Conv2d(3,d_feature,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(d_feature, d_feature*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature*2),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(d_feature*2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(d_feature*4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> a single (N, 1, 1, 1) realness score in (0, 1)
            nn.Conv2d(d_feature*8,1,4,1,0,bias=False),
            nn.Sigmoid()
        )
    def forward(self,input):
        return self.net(input)
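A quick shape check (my addition, not part of the tutorial code) confirms what the two stacks do: G maps a (N, 100, 1, 1) noise tensor to a (N, 3, 64, 64) image, and D maps that image to one realness score per sample.

import torch
from Net import Generator, Discriminator  # assuming the classes above live in Net.py, as the training script expects

netG = Generator()
netD = Discriminator()
z = torch.randn(4, 100, 1, 1)   # (N, z_dim, 1, 1)
img = netG(z)                   # -> torch.Size([4, 3, 64, 64])
score = netD(img)               # -> torch.Size([4, 1, 1, 1])
print(img.shape, score.shape)

Note that the official tutorial additionally applies a weights_init routine (normal initialization with std 0.02); it is omitted in the code above.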
3.4 Training
Print some statistics along the way to monitor how training is going.
import torch.nn as nn
import torch
from Dataset import dataloader
import torch.optim as optim
from Net import Generator,Discriminator
import numpy as np
import time
if __name__=='__main__':
    device=torch.device("cuda")
    criterion=nn.BCELoss()
    # fixed noise, handy for visualizing the generator's progress
    fixed_noise=torch.randn(64,100,1,1,device=device)
    real_label=1.
    fake_label=0.
    netG=Generator().to(device)
    netD=Discriminator().to(device)
    # resume: G from epoch 600, D rolled back to epoch 400 (see Section 5.1)
    netG.load_state_dict(torch.load("model_parm/G_epoch600.pt"))
    netD.load_state_dict(torch.load("model_parm/D_epoch400.pt"))
    D_lr = 2e-6
    G_lr = 2e-6
    optimizerD=optim.Adam(netD.parameters(),lr=D_lr,betas=(0.5,0.999))
    optimizerG=optim.Adam(netG.parameters(),lr=G_lr,betas=(0.5,0.999))
    # Training Loop
    num_epochs=1000
    t1=time.time()
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(601,num_epochs+1):
        # Lists to keep track of progress (reset once per epoch so the log records epoch averages)
        G_losses = []
        D_losses = []
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Train discriminator D: maximize log(D(x)) + log(1 - D(G(z))), i.e. real -> 1, fake -> 0
            ############################
            ## First, train on real images
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # soft labels with noise (optional variant)
            # label = np.random.rand(b_size)*0.8+0.15
            # label = torch.tensor(label,dtype=torch.float,device=device)
            # hard labels: all ones
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # forward pass through D
            output = netD(real_cpu).view(-1)
            # compute the loss on real images
            errD_real = criterion(output, label)
            # backpropagate
            errD_real.backward()
            D_x = output.mean().item()
            ## Then, train on fake images
            # sample random noise
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            # generate fake images
            fake = netG(noise)
            # soft labels with noise (optional variant)
            # label = np.random.rand(b_size)*0.1+0.05
            # label = torch.tensor(label,dtype=torch.float,device=device)
            label.fill_(fake_label)
            # forward pass; detach so no gradients flow back into G
            output = netD(fake.detach()).view(-1)
            # compute the loss on fake images
            errD_fake = criterion(output, label)
            # backpropagate
            errD_fake.backward()
            # D_G_z1: how well G fools the current D, measured before D is updated
            D_G_z1 = output.mean().item()
            # total discriminator loss
            errD = errD_real + errD_fake
            # optimizer step for D
            optimizerD.step()
            ############################
            # (2) Train generator G: maximize log(D(G(z))), i.e. fool D
            ############################
            netG.zero_grad()
            # pair the fake images with the "real" label, so updating G pushes its outputs toward "real"
            # hard labels
            G_label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # run the fake images through the freshly updated D
            output = netD(fake).view(-1)
            # compute the generator loss
            errG = criterion(output, G_label)
            # backpropagate
            errG.backward()
            # D_G_z2: how well G fools the already-updated D, measured before G is updated
            # normally D_G_z2 < D_G_z1
            D_G_z2 = output.mean().item()
            # optimizer step for G
            optimizerG.step()
            # record losses
            G_losses.append(errG.item())
            D_losses.append(errD.item())
        # save model
        if epoch % 50 ==0:
            torch.save(netG.state_dict(), 'model_parm/G_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D_epoch' + str(epoch) + '.pt')
        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        if epoch<=605:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        # reset discriminator D (optional, used in the soft-label run)
        # if epoch % 15 ==0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))
        # record train log: epoch, mean G loss, mean D loss, D(x), D(G(z)) before/after the D update
        Gloss=np.mean(G_losses)
        Dloss=np.mean(D_losses)
        with open('model_parm/train_log2.txt','a+') as f:
            string=str(epoch)+'\t'+str(round(Gloss,5))+'\t'+str(round(Dloss,5))+'\t'+\
                   str(round(D_x,5))+'\t'+str(round(D_G_z1,5))+'\t'+str(round(D_G_z2,5))+'\n'
            f.write(string)
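Since the loop appends one line per epoch to model_parm/train_log2.txt, a small sketch to plot the curves from that log (assuming the file only contains the six-column lines written by this script):

import numpy as np
import matplotlib.pyplot as plt

# columns: epoch, G loss, D loss, D(x), D(G(z)) before D update, D(G(z)) after D update
log = np.loadtxt("model_parm/train_log2.txt")
plt.plot(log[:, 0], log[:, 1], label="G loss")
plt.plot(log[:, 0], log[:, 2], label="D loss")
plt.xlabel("epoch")
plt.legend()
plt.show()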
3.5 Evaluation
Besides eyeballing the samples, we of course also need the FID metric.
from Net import Generator
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutil
import os
import shutil
import re
def generate_img(epoch):
    device=torch.device("cuda")
    netG=Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch"+str(epoch)+".pt"))
    fixed_noise=torch.randn(64,100,1,1,device=device)
    with torch.no_grad():
        fake=netG(fixed_noise).detach().cpu()
    # show the 64 generated samples as a grid
    plt.imshow(np.transpose(vutil.make_grid(fake,padding=2,normalize=True),(1,2,0)))
    plt.show()

def cal_FID(epoch):
    # generate 1000 fake images with the checkpointed generator
    device=torch.device("cuda")
    netG=Generator().to(device)
    netG.load_state_dict(torch.load("model_parm/G_epoch"+str(epoch)+".pt"))
    fixed_noise=torch.randn(1000,100,1,1,device=device)
    with torch.no_grad():
        fake=netG(fixed_noise).detach().cpu()
    for i in range(1000):
        vutil.save_image(fake[i],"D:/Pycharm/Dataset/AnimeFaces/fakeimg/"+str(i)+".jpg",normalize=True)
    # compute FID between the test set and the generated images via pytorch_fid
    os.system("activate pytorch")
    result=os.popen(r"python -m pytorch_fid D:\Pycharm\Dataset\AnimeFaces\test D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    content=result.readlines()[0]
    # parse the FID value out of the command output
    fid=re.findall(r"\d+\.?\d*",content)
    fid=list(filter(lambda x : x!='0',fid))
    if(len(fid)==1):
        print("epoch",epoch,":",float(fid[0]))
    else:
        print("epoch",epoch,":",fid)
    # clear the fake-image folder for the next evaluation
    shutil.rmtree(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")
    os.mkdir(r"D:\Pycharm\Dataset\AnimeFaces\fakeimg")
generate_img(400)
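A typical way to use cal_FID is to sweep it over the saved checkpoints, for example (the fakeimg folder must already exist, since cal_FID only recreates it after each run):

# evaluate FID every 50 epochs, matching the checkpoint interval used during training
for ep in range(50, 1001, 50):
    cal_FID(ep)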
4 WGAN
For the motivation behind WGAN, just read the paper. I had no intention of re-implementing it with an MLP from scratch; I simply changed the network, the loss function, and the training loop on top of the DCGAN code.
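For reference, with weight clipping the critic and generator losses used in Section 4.2 are the standard WGAN ones:

$$L_D=-\mathbb{E}_{x\sim p_{\mathrm{data}}}\big[D(x)\big]+\mathbb{E}_{z\sim p_z}\big[D(G(z))\big],\qquad L_G=-\mathbb{E}_{z\sim p_z}\big[D(G(z))\big]$$

with the critic weights clipped to $[-c, c]$ after every update (here $c=0.01$).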
4.1 Net
The only change is removing the sigmoid at the end of discriminator D.
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self,d_feature=64):
        super(Discriminator, self).__init__()
        self.net=nn.Sequential(
            nn.Conv2d(3,d_feature,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(d_feature, d_feature*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature*2, d_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature*4, d_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(d_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_feature*8,1,4,1,0,bias=False),
            # sigmoid removed: the WGAN critic outputs an unbounded score, which avoids vanishing gradients
            #nn.Sigmoid()
        )
    def forward(self,input):
        return self.net(input)
4.2 Train
from Dataset import dataloader
import torch.optim as optim
from Net import Generator,Discriminator
import time
import torch
if __name__=='__main__':
    device=torch.device("cuda")
    # criterion=nn.BCELoss()   # no BCE loss in WGAN
    fixed_noise=torch.randn(64,100,1,1,device=device)
    # number of critic steps per generator step (not applied in this version, see below)
    n_iter=5
    # weight-clipping threshold for the critic
    clip_value=0.01
    netG=Generator().to(device)
    netD=Discriminator().to(device)
    # netG.load_state_dict(torch.load("model_parm/G_epoch340.pt"))
    # netD.load_state_dict(torch.load("model_parm/D_epoch340.pt"))
    D_lr = 3e-5
    G_lr = 3e-5
    # WGAN uses RMSprop rather than momentum-based optimizers
    optimizerD=optim.RMSprop(netD.parameters(),lr=D_lr)
    optimizerG=optim.RMSprop(netG.parameters(),lr=G_lr)
    # Training Loop
    num_epochs=1500
    t1=time.time()
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs+1):
        for i, data in enumerate(dataloader,0):
            # Configure input
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # ---------------------
            #  Train Discriminator (critic)
            # ---------------------
            optimizerD.zero_grad()
            # Sample noise as generator input
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            # Generate a batch of images (detached: no gradients flow into G here)
            fake_imgs = netG(noise).detach()
            # Critic loss: push real scores up, fake scores down
            loss_D = -torch.mean(netD(real_cpu)) + torch.mean(netD(fake_imgs))
            loss_D.backward()
            optimizerD.step()
            # Clip weights of the discriminator to enforce the Lipschitz constraint
            for p in netD.parameters():
                p.data.clamp_(-clip_value, clip_value)
            # -----------------
            #  Train Generator
            #  (here G is updated on every batch; the n_iter critic schedule is not used in this run)
            # -----------------
            optimizerG.zero_grad()
            # Generate a batch of images
            gen_imgs = netG(noise)
            # Generator loss: push the critic's scores on fakes up
            loss_G = -torch.mean(netD(gen_imgs))
            loss_G.backward()
            optimizerG.step()
        # save model
        if epoch % 50 ==0:
            torch.save(netG.state_dict(), 'model_parm/G2_epoch' + str(epoch) + '.pt')
            torch.save(netD.state_dict(), 'model_parm/D2_epoch' + str(epoch) + '.pt')
        # Output training stats
        if epoch % 20 == 0:
            print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                  % (epoch, num_epochs, loss_D.item(), loss_G.item()) + ' cost time : ' + str(round(time.time()-t1,4))+'s')
            t1=time.time()
        # reset discriminator D (optional)
        # if epoch % 15 ==0:
        #     netD.load_state_dict(torch.load("model_parm/D_epoch0.pt"))
        # record train log: epoch, G loss, D loss (last batch of the epoch)
        with open('model_parm/train_log2.txt','a+') as f:
            string=str(epoch)+'\t'+str(round(loss_G.item(),5))+'\t'+str(round(loss_D.item(),5))+'\n'
            f.write(string)
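The n_iter=5 defined above is intended for the usual WGAN schedule of several critic updates per generator update (the "intermittent generator update" variant in Section 5.3), but the version shown here updates G on every batch. A sketch of how that schedule could be wired in (my variation, not the exact code I ran for 5.3):

            # inside the batch loop, after the critic update and the weight clipping:
            if i % n_iter == 0:   # update G only once every n_iter critic steps
                optimizerG.zero_grad()
                gen_imgs = netG(noise)
                loss_G = -torch.mean(netD(gen_imgs))
                loss_G.backward()
                optimizerG.step()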
5 Results
I did four rather haphazard training runs. In the end the plain DCGAN still looked best to me; neither soft labels nor WGAN brought any real improvement. To be fair, the training was quite sloppy: each run used a different number of epochs, and I barely recorded the learning-rate changes or the criteria for rolling back the discriminator and generator. This was purely a hands-on exercise to get a feel for GANs.
5.1 Vanilla DCGAN
Trained for 1000 epochs in total. Around epoch 600 the discriminator broke down, so I rolled it back to the epoch-400 checkpoint and continued training to 1000 epochs. The best result came from epoch 400:
epoch | FID↓ |
---|---|
370 | 84.92 |
400 | 79.94 |
500 | 92.37 |
600 | 93.45 |
… | … |
900 | 88.99 |
1000 | 89.66 |
5.2 DCGAN + soft labels + intermittent discriminator updates
Trained for 350 epochs; the best result (epoch 350) is as follows:
epoch | FID↓ |
---|---|
0 | 345.39 |
150 | 271.35 |
350 | 164.42 |
5.3 WGAN + intermittent generator updates
I have to admit that although WGAN's final quality was only so-so, its training curve dropped steadily and mode collapse rarely appeared. Trained for 1100 epochs; the best result (epoch 1100) is as follows:
epoch | FID↓ |
---|---|
0 | 386.23 |
100 | 284.69 |
200 | 218.54 |
300 | 185.20 |
… | … |
900 | 154.90 |
1000 | 154.88 |
1100 | 148.92 |
5.4 WGAN
Trained for 1500 epochs; the best result (epoch 1400) is as follows:
epoch | FID↓ |
---|---|
0 | 306.68 |
100 | 142.88 |
200 | 125.99 |
300 | 124.07 |
… | … |
1300 | 113.03 |
1400 | 111.63 |
1500 | 112.65 |
6 Summary
All in all, GANs are a lot of fun and surprisingly light on VRAM: with a batch size of 256 the training used less than 2 GB. By contrast, merely fine-tuning a LoRA for Stable Diffusion eats 8 GB even at batch size 1; apparently only scale matters.
Also, inserting images in Markdown is a terrible experience.