【GAN系列一】普通GAN--随机向量生成真实图片

对抗生成网络原理

在这里插入图片描述
Generator:根据输入的随机向量生成Fake image,并使其骗过Discriminator。
Discriminator:正确识别Fake image和Real image。
两者之间是博弈的关系。

Generator网络定义

损失函数是鉴别器对Fake image错误鉴别的损失

class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        def block(in_feat,out_feat,normalize=True):
            layers=[nn.Linear(in_feat,out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model=nn.Sequential(
            *block(opt.latent_dim,128,normalize=False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,int(np.prod(img_shape))),
            nn.Tanh()
        )
    def forward(self,z):
        img=self.model(z)
        img=img.view(img.size(0),*img_shape)
        return img

Discriminator网络定义

损失函数是对Real image和Fake image正确鉴别的损失

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.model=nn.Sequential(
            nn.Linear(int(np.prod(img_shape)),512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1),
            nn.Sigmoid()
        )
    def forward(self,img):
        img_flat=img.view(img.size(0),-1)
        validity=self.model(img_flat)
        return validity

GAN的常见问题和评估方法

  • 常见问题:
    1. 模式坍塌mode collapse:Generator生成的图片来来去去只有那几张。
      解决办法:在遇到mode collapse之前就结束Generator的训练。
    2. 模式崩溃mode dropping:单纯看Generator生成的图片还不错,但其分布只是真实图片分布的一部分,多样性不够。
      解决办法:可以把生成的图片丢入分类网络中,计算每个类的分布和均值,若其分布比较均衡,则说明生成的图片多样性是足够的。
      3.Gan可能生成的图片和真实图片相同。
  • 评估方法:
    Frechet Inception Distance score:将生成的图片和真实图片丢入Inception Network中,获得其输入softmax前的隐藏层输出向量,根据真实图片与生成图片的分布做Frechet distance。(此方法需要大量的样本)

完整代码

import argparse
import torch.cuda
import torchvision
import numpy as np
from torch import nn

from torch.autograd import Variable
from torchvision.utils import save_image
parser=argparse.ArgumentParser()
parser.add_argument('--n_epochs',type=int,default=200,help='number of epochs of training')
parser.add_argument('--batch_size',type=int,default=64,help='size of batches')
parser.add_argument('--lr',type=float,default=0.0002,help='adam:learning rate')
parser.add_argument('--b1',type=float,default=0.5,help="adam:decay of first order momentum of gradient")
parser.add_argument('--b2',type=float,default=0.999,help='adam:decay of first order momentum of gradient')
parser.add_argument('--n_cpu',type=int,default=1,help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim',type=int,default=100,help='dimensionality of the latent space')
parser.add_argument('--channels',type=int,default=1,help='number of image channels')
parser.add_argument('--img_size',type=int,default=28,help='size of each image dimension')
parser.add_argument('--sample_interval',type=int,default=400,help='interval between image samples')
opt=parser.parse_args()
print(opt)
device="cuda" if torch.cuda.is_available() else "cpu"
img_shape=(opt.channels,opt.img_size,opt.img_size)
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        def block(in_feat,out_feat,normalize=True):
            layers=[nn.Linear(in_feat,out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model=nn.Sequential(
            *block(opt.latent_dim,128,normalize=False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,int(np.prod(img_shape))),
            nn.Tanh()
        )
    def forward(self,z):
        img=self.model(z)
        img=img.view(img.size(0),*img_shape)
        return img
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.model=nn.Sequential(
            nn.Linear(int(np.prod(img_shape)),512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1),
            nn.Sigmoid()
        )
    def forward(self,img):
        img_flat=img.view(img.size(0),-1)
        validity=self.model(img_flat)
        return validity
adv_loss=nn.BCELoss().to(device)
# adv_loss=nn.BCEWithLogitsLoss().to(device)   #在做BCE损失之前加上了sigmoid变换
generator=generator().to(device)
discriminator=discriminator().to(device)
transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5],[0.5])
])
dataset=torchvision.datasets.MNIST("./images",train=True,download=False,transform=transform)

data_iter=torch.utils.data.DataLoader(dataset,batch_size=opt.batch_size,shuffle=True)
optimizer_G=torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
Tensor=torch.cuda.FloatTensor if device=="cuda" else torch.FloatTensor

for epoch in range(opt.n_epochs):
    for i,(img,_) in enumerate(data_iter):
        valid=Variable(Tensor(img.size(0),1).fill_(1.0),requires_grad=False)
        fake=Variable(Tensor(img.size(0),1).fill_(0.0),requires_grad=False)
        real_imgs=Variable(img.type(Tensor))
        optimizer_G.zero_grad()
        z=Variable(Tensor(np.random.normal(0,1,(img.shape[0],opt.latent_dim))))
        gen_imgs=generator(z)
        g_loss=adv_loss(discriminator(gen_imgs),valid)
        g_loss.backward()
        optimizer_G.step()

        optimizer_D.zero_grad()
        real_loss=adv_loss(discriminator(real_imgs),valid)
        fake_loss=adv_loss(discriminator(gen_imgs.detach()),fake)
        d_loss=(real_loss+fake_loss)/2
        d_loss.backward()
        optimizer_D.step()
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(data_iter),
                                                            d_loss.item(), g_loss.item()))
        batches_done = epoch * len(data_iter) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

训练的部分结果图

训练结果遇到了mode collapse和mode dropping。
在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值