GAN的原理
GAN是一种典型的生成网络模型,它类似于编解码结构,通过训练,他能够生成不同于训练集的各种图片。
首先先训练判别器,把真图通过判别器的输出和真标签作损失,把假图通过判别器的输出和假标签作损失,让它具备判别真图和假图的能力。然后再训练生成器,把生成器生成的假图通过判别器的输出和真标签作损失。经过反复的训练,让判别器难以分辨生成图的真假,也就是让它判别为真或为假的概率各为0.5
数据集下载
网上下载的动漫头像数据集有很多不清晰的奇异样本,对此我做了清洗,剩下的都是符合标准的,可直接下载
百度网盘:https://pan.baidu.com/s/1–zFrJdg1gtW2wJ6wtWQsQ
密码:bu55
网络结构
生成网络
相当于一个编码器
class NetD(nn.Module):
# 构建一个判别器,相当与一个二分类问题, 生成一个值
def __init__(self):
super(NetD, self).__init__()
ndf = opt.ndf
self.main = nn.Sequential(
# 输入96*96*3
nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
# 输入32*32*ndf
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, True),
# 输入16*16*ndf*2
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, True),
# 输入为8*8*ndf*4
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, True),
# 输入为4*4*ndf*8
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),
nn.Sigmoid() # 分类问题
)
def forward(self, x):
return self.main(x).view(-1)
生成器
相当于一个解码器
class NetG(nn.Module):
# 定义一个生成模型,通过输入噪声来产生一张图片
def __init__(self):
super(NetG, self).__init__()
ngf = opt.ngf
self.main = nn.Sequential(
# 假定输入为一张1*1*opt.nz维的数据(opt.nz维的向量)
nn.ConvTranspose2d(opt.nz , ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(inplace=True),
# 输入一个4*4*ngf*8
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# 输入一个8*8*ngf*4
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=True),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# 输入一个16*16*ngf*2
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(inplace=True),
# 输入一个32*32*ngf
nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
nn.Tanh()
# 输出一张96*96*3
)
def forward(self, x):
return self.main(x)
GAN网络结构设计要点
1、在D网络中用stride卷积(stride>1)代替pooling层,在G网络中用conv2d_transpose代替上采样层
2、在G和D网络中直接将BN应用到所有层会导致样本震荡和模型不稳定,通过在G网络输出层和D网络输入层不采用BN层可以有效防止这种现象
3、不使用全连接层作为输出
4、G网络中除了输出层用tanh激活,其他层都是用ReLu激活
5、D网络中都使用LeakyReLu激活
网络模型训练
训练细节
1、预处理环节,将图像scale到tanh的[-1,1]
2、所有的参数初始化由(0,0.02)的正态分布中随机得到
3、LeakyReLu的斜率是0.2(默认)
4、优化器Adam的learning rate=0.0002,momentum参数betas的beta1从0.9降为0.5,beta2默认,防止震荡和不稳定
5、可以G网络训练1次,然后D网络训练1次,如此反复;也可以G网络先训练几次后,D网络再训练1次,如此反复。前者效果出得较快,后者较慢。
训练代码
# opt参数
ngf=96
ndf=96
nz=256
img_size=96
batch_size=100
num_workers=4
netg_path=r"网络参数/netg_5.pt"
netd_path=r"网络参数/netd_5.pt"
lr1=0.0002
lr2=0.0002
beta1=0.5
epochs=200
d_every=1
g_every=5
save_every=20
from torchvision.utils import save_image
import Nets
import torch
from torch.utils.data import DataLoader
import opt
import torch.nn as nn
import dataset
if __name__=="__main__":
# 1. 加载数据
dataset = dataset.Dataset()
dataloader = DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=True)
# 2.初始化网络
netg, netd = Nets.NetG(), Nets.NetD()
# 3. 设定优化器参数
optimize_g = torch.optim.Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1,0.999))
optimize_d = torch.optim.Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1,0.999))
loss_func = nn.BCELoss()
# 4. 定义标签, 并且开始注入生成器的输入noise
true_labels = torch.ones(opt.batch_size)
fake_labels = torch.zeros(opt.batch_size)
noises = torch.randn(opt.batch_size, opt.nz, 1, 1)
# 6.训练网络
netg.train()
netd.train()
for epoch in range(opt.epochs):
for i, img in enumerate(dataloader):
real_img = img
# 训练判别器
if i % opt.d_every == 0:
optimize_d.zero_grad()
# 真图
real_out = netd(real_img)
error_d_real = loss_func(real_out, true_labels)
error_d_real.backward()
# 随机生成的假图
noises = noises.detach()
fake_image = netg(noises).detach()
fake_out = netd(fake_image)
error_d_fake = loss_func(fake_out, fake_labels)
error_d_fake.backward()
optimize_d.step()
# 计算loss
error_d = error_d_fake + error_d_real
print("第{0}轮: 判别网络 损失:{1} 对真图评分:{2} 对生成图评分:{3}".format(epoch+1,error_d.item(),real_out.data.mean(),fake_out.data.mean()))
# 训练生成器
if i % opt.g_every == 0 and i>0:
optimize_g.zero_grad()
noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))
fake_img = netg(noises)
output = netd(fake_img)
error_g = loss_func(output, true_labels)
error_g.backward()
optimize_g.step()
print(" 生成网络 损失:{0}".format(error_g.item()))
# 7.保存模型和图片
if i % opt.save_every == 0 and i>0:
fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1)
fix_fake_image = netg(fix_noises)
# save_image(real_img.data*0.5+0.5, "./img/{0}-{1}-real_img.jpg".format(epoch, i), nrow=10)
save_image(fix_fake_image.data*0.5+0.5, "./image/{0}-{1}-fake_img.jpg".format(epoch, i), nrow=10)
torch.save(netd.state_dict(), opt.netd_path)
torch.save(netg.state_dict(), opt.netg_path)
效果展示
生成网络随机生成的头像