对抗生成网络原理
Generator:根据输入的随机向量生成Fake image,并使其骗过Discriminator。
Discriminator:正确识别Fake image和Real image。
两者之间是博弈的关系。
Generator网络定义
损失函数衡量的是Discriminator将Fake image误判为Real image的程度:误判越彻底,Generator的损失越小,因此Generator被训练去"骗过"Discriminator。
class generator(nn.Module):
    """MLP generator: maps a latent vector z to a fake image in [-1, 1].

    Its training loss (defined at the training site) rewards making the
    discriminator label the generated images as real.
    """

    def __init__(self):
        super(generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            # Linear -> (optional BatchNorm1d) -> LeakyReLU building block.
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BUG FIX: the original passed 0.8 positionally, which set
                # eps=0.8 (the numerical-stability epsilon, default 1e-5).
                # The intended hyper-parameter is momentum=0.8.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),  # no BN on the input layer
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # output in [-1, 1], matching Normalize([0.5], [0.5]) on the data
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors z of shape (N, latent_dim)."""
        img = self.model(z)
        # Reshape the flat output back to (N, C, H, W).
        img = img.view(img.size(0), *img_shape)
        return img
Discriminator网络定义
损失函数是对Real image和Fake image正确鉴别的损失
class discriminator(nn.Module):
    """MLP discriminator: scores each image with a probability of being real."""

    def __init__(self):
        super(discriminator, self).__init__()
        in_dim = int(np.prod(img_shape))
        # Flattened image -> 512 -> 256 -> 1, with a sigmoid probability head.
        self.model = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the realness probability for every image in the batch."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)
GAN的常见问题和评估方法
- 常见问题:
- 模式坍塌mode collapse:Generator生成的图片来来去去只有那几张。
  解决办法:在遇到mode collapse之前就结束Generator的训练。
- 模式缺失mode dropping:单纯看Generator生成的图片还不错,但其分布只是真实图片分布的一部分,多样性不够。
解决办法:可以把生成的图片丢入分类网络中,计算每个类的分布和均值,若其分布比较均衡,则说明生成的图片多样性是足够的。
- GAN可能生成和真实图片完全相同的图片,即只是记住了训练数据,而不是学到了真实分布。
- 模式坍塌mode collapse:Generator生成的图片来来去去只有那几张。
- 评估方法:
Frechet Inception Distance score:将生成的图片和真实图片丢入Inception Network中,取其进入softmax之前的隐藏层输出向量,再根据真实图片与生成图片在该特征空间中的分布计算Frechet distance。(此方法需要大量的样本)
完整代码
import argparse
import os

import numpy as np
import torch
import torch.cuda
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision.utils import save_image
# Command-line hyper-parameters for the GAN training run.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam:learning rate')
parser.add_argument('--b1', type=float, default=0.5, help="adam:decay of first order momentum of gradient")
# BUG FIX: help text said "first order"; b2 is Adam's SECOND-order moment decay.
parser.add_argument('--b2', type=float, default=0.999, help='adam:decay of second order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=1, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')
opt = parser.parse_args()
print(opt)

# Train on GPU when available; models and tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# (channels, height, width) of the images the GAN produces.
img_shape = (opt.channels, opt.img_size, opt.img_size)
class generator(nn.Module):
    """Fully-connected generator: latent vector z -> fake image in [-1, 1]."""

    def __init__(self):
        super(generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            # One Linear -> (optional BatchNorm1d) -> LeakyReLU stage.
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BUG FIX: `nn.BatchNorm1d(out_feat, 0.8)` set eps=0.8 via the
                # positional argument; 0.8 was meant to be the momentum.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),  # first layer: no BN
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # matches the [-1, 1] range of the normalized MNIST data
        )

    def forward(self, z):
        """Map latent vectors z of shape (N, latent_dim) to images (N, C, H, W)."""
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
class discriminator(nn.Module):
    """Binary classifier that outputs P(image is real) for each input image."""

    def __init__(self):
        super(discriminator, self).__init__()
        # Build the layer list first, then wrap it in a Sequential.
        layers = [
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten the image batch and return a (N, 1) tensor of realness scores."""
        batch = img.size(0)
        validity = self.model(img.view(batch, -1))
        return validity
# Adversarial loss: binary cross-entropy on the discriminator's sigmoid output.
adv_loss = nn.BCELoss().to(device)
# adv_loss = nn.BCEWithLogitsLoss().to(device)  # alternative: applies sigmoid before the BCE loss

# NOTE: from here on the instances shadow the class names above.
generator = generator().to(device)
discriminator = discriminator().to(device)

# Scale MNIST pixels to [-1, 1] so they match the generator's Tanh output.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])
dataset = torchvision.datasets.MNIST("./images", train=True, download=False, transform=transform)
data_iter = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

# One Adam optimizer per network; they are stepped alternately during training.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor factory bound to the active device.
Tensor = torch.cuda.FloatTensor if device == "cuda" else torch.FloatTensor
# Alternating GAN training: one generator update, then one discriminator
# update per mini-batch.
# FIX: save_image below would fail if the output folder does not exist.
os.makedirs("images", exist_ok=True)
for epoch in range(opt.n_epochs):
    for i, (img, _) in enumerate(data_iter):
        # Target labels: 1.0 for real, 0.0 for fake.
        # FIX: replaced the deprecated torch.autograd.Variable wrapper (a
        # no-op since PyTorch 0.4) and the cuda-Tensor-type factories with
        # device-aware tensor creation; behavior is unchanged.
        valid = torch.full((img.size(0), 1), 1.0, device=device)
        fake = torch.full((img.size(0), 1), 0.0, device=device)
        real_imgs = img.to(device)

        # --- Generator step: make D label the fakes as real. ---
        optimizer_G.zero_grad()
        z = torch.randn(img.size(0), opt.latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adv_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # --- Discriminator step: label reals as 1 and (detached) fakes as 0. ---
        optimizer_D.zero_grad()
        real_loss = adv_loss(discriminator(real_imgs), valid)
        # detach() keeps gradients from flowing back into the generator here.
        fake_loss = adv_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(data_iter),
                                                                         d_loss.item(), g_loss.item()))
        # Periodically save a 5x5 grid of generated samples.
        batches_done = epoch * len(data_iter) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)
训练的部分结果图
训练结果遇到了mode collapse和mode dropping。