【论文阅读】Conditional Generative Adversarial Nets

碎碎念

参加jittor比赛,热身赛中使用的GAN模型,想起自己还没真正使用过GAN,希望通过这个机会学习下

引入

在这里插入图片描述
想要真正地理解世界,就应该能够生成世界的种种组成。因此出现了生成模型;

生成模型是说,我们随机地生成一些图片(以图片任务为例),使得这些图片能够尽可能地描绘真实世界;

然而这个标准很难量化去衡量(怎么样才算真实?),因此提出生成对抗网络GAN,同时设计生成器和判别器两个部分,生成器的任务是努力生成能够以假乱真的图片,而判别器的任务是尽可能区分生成的图片和真实图片;这样评判标准就清晰了,即是真实的训练数据还是生成的数据,同时,生成器和判别器两部分的博弈使得整个系统稳步运行;

然而仅仅生成一些随机的真实图片可能没办法满足我们的需求,我们可能还需要生成有条件的真实图片,比如,我想生成猫的真实图像,那么条件就是猫,则提出Conditional GAN;

这些条件可以是类别标签(上面猫的例子),也可以是用于图像修复的部分数据(图像缺失,那么保留下的那部分数据就可以作为条件),甚至是其他模态的数据(给定一个文本查询,生成文本相关的图片,那么文本查询就可以是条件)

方法

整个网络可以通过一个公式来说明,在没有condition的情况下(也就是普通的生成网络),公式是这样的(其中 E E E代表期望):
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] min_Gmax_DV(D,G)=E_{x \sim p_{data}(x)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(G(z)))] minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))],记为公式 (1)
加上条件之后,公式变为:
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ∣ y ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ∣ y ) ) ) ] min_Gmax_DV(D,G)=E_{x \sim p_{data}(x|y)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(G(z|y)))] minGmaxDV(D,G)=Expdata(xy)[logD(x)]+Ezpz(z)[log(1D(G(zy)))],记为公式 (2)

对这两个公式的理解(只有条件部分不同,就放在一起说了),可以参考李宏毅的视频对抗生成网络(GAN),这里简单总结一下:
这里的两个目标 m i n G min_G minG m a x D max_D maxD,可以看作是fix D / G D/G D/G的参数,更新 G / D G/D G/D的参数,比如 m a x D max_D maxD就代表固定 G G G,更新 D D D,那么公式对应的详细算法就如下所示:
在这里插入图片描述

  1. 学习判别器 D D D的时候,固定生成器 G G G,先利用生成器 G G G对噪声生成 x ~ \tilde x x~,目标就变为: m a x D V = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( x ~ ) ) ] max_DV=E_{x \sim p_{data}(x)}[logD(x)]+E_{z \sim p_z(z)}[log(1-D(\tilde x))] maxDV=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(x~))],那么需要完成以下任务:
  • 判别器 D D D执行假图判别任务, D ( G ( z ) ) D(G(z)) D(G(z))
向量含义维度
输入gen_imgs假图 b s z , 32 , 32 bsz, 32, 32 bsz,32,32
gen_labels假图的label b s z , 1 bsz, 1 bsz,1
输出validity_fake假图的真实性 b s z , 1 bsz, 1 bsz,1
目标 m a x l o g ( 1 − D ( x ~ ) ) maxlog(1-D(\tilde x)) maxlog(1D(x~))尽可能将假图判定为假等价于 m i n D ( x ~ ) minD(\tilde x) minD(x~)
  • 判别器 D D D执行真图判别任务, D ( x ) D(x) D(x)
向量含义维度
输入real_imgs真图 b s z , 32 , 32 bsz, 32, 32 bsz,32,32
labels真图的label b s z , 1 bsz, 1 bsz,1
输出validity_real真图的真实性 b s z , 1 bsz, 1 bsz,1
目标 m a x l o g D ( x ) maxlogD(x) maxlogD(x)尽可能将真图判定为真等价于 m a x D ( x ) maxD(x) maxD(x)
  1. 学习生成器 G G G的时候,固定判别器 D D D,目标就变为: m i n G V = E z ∼ p z ( z ) [ l o g ( 1 − D ( x ~ ) ) ] min_GV=E_{z \sim p_z(z)}[log(1-D(\tilde x))] minGV=Ezpz(z)[log(1D(x~))],也就是等价于 m a x G V = E z ∼ p z ( z ) [ l o g D ( x ~ ) ] max_GV=E_{z \sim p_z(z)}[logD(\tilde x)] maxGV=Ezpz(z)[logD(x~)](后者的max是上面那张算法图中的写法,我们还是跟随公式使用前者min,因为当时因为写成max导致好久没看懂,无语…),那么需要完成以下任务:
  • 生成器 G G G执行假图生成任务, G ( z ) G(z) G(z),然后判别器 D D D执行假图判断任务, D ( G ( z ) ) D(G(z)) D(G(z))
生成器向量含义维度
输入 z z z随机向量 b s z , 100 bsz, 100 bsz,100
gen_labels要生成假图的label,相当于condition b s z , 1 bsz, 1 bsz,1
输出gen_imgs生成的假图 b s z , 32 , 32 bsz, 32, 32 bsz,32,32
判别器向量含义维度
输入gen_imgs生成的假图 b s z , 32 , 32 bsz, 32, 32 bsz,32,32
gen_labels假图的label b s z , 1 bsz, 1 bsz,1
输出validity假图的真实性 b s z , 1 bsz, 1 bsz,1
目标 m i n l o g ( 1 − D ( G ( z ) ) ) minlog(1-D(G(z))) minlog(1D(G(z)))尽可能令生成器生成能够以假乱真的图片,从而让判别器将生成的图片判别为真实等价于 m a x D ( G ( z ) ) maxD(G(z)) maxD(G(z))

代码学习(jittor版)

import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn

if jt.has_cuda:
    jt.flags.use_cuda = 1

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
# sample_interval 暂不知用途
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()
print(opt)

# img_shape: 1, 32, 32
img_shape = (opt.channels, opt.img_size, opt.img_size)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # nn.Embedding创建了一个可查询的embedding字典
        # 参数1是num,指embedding字典的大小,参数2是dim,embedding变量的大小
        # 也就是说会创建一个10*10的查询表
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
        # nn.Linear(in_dim, out_dim)表示全连接层
        # in_dim:输入向量维度
        # out_dim:输出向量维度
        # block定义了一个层,fc+bn+LeakyReLU,以list的形式存在
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            # nn.LeakyReLU代表泄露的ReLU(负值部分泄露,0.2表示泄露的斜率)
            layers.append(nn.LeakyReLU(0.2))
            return layers
        # *用在实参前,相当于对tuple的解构(对应的**是对dict的解构),长度需要跟函数需要的参数对应
        # *用在形参前,可以表示一个可变长度的tuple(对应的**表示可变长度的dict,同时有*和**时,*在前,**在后)
        # 这里是指把block作为一个tuple传递给nn.Sequential函数
        # nn.Sequential需要的参数为*args,也就是说可以传递任意长度的参数
        # block(100+10, 128, normalize=False)
        # model的维度不断增加,看来应该是个生成器(reshape的操作在后面)
        # nn.Linear(1024, 1*32*32=1024)
        self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False), 
                                   *block(128, 256), 
                                   *block(256, 512), 
                                   *block(512, 1024), 
                                   nn.Linear(1024, int(np.prod(img_shape))), 
                                   nn.Tanh())

    # 相当于pytorch中的forward函数,表示网络向量的传递过程
    def execute(self, noise, labels):
        # label是条件,noise是生成的随机变量
        # label: , noise: 
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        # 将img从1024维向量变为32*32矩阵
        img = img.view((img.shape[0], *img_shape))
        return img

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
        # nn.Linear(10+1*32*32=1034, 512)
        self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512), 
                                   nn.LeakyReLU(0.2), 
                                   nn.Linear(512, 512), 
                                   nn.Dropout(0.4), 
                                   nn.LeakyReLU(0.2), 
                                   nn.Linear(512, 512), 
                                   nn.Dropout(0.4), 
                                   nn.LeakyReLU(0.2), 
                                   # TODO: 添加最后一个线性层,最终输出为一个实数
                                   nn.Linear(512, 1)
                                   )

    def execute(self, img, labels):
        # img: bsz, 1*32*32=1024; label: bsz, 10
        d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
        # TODO: 将d_in输入到模型中并返回计算结果
        validity = self.model(d_in)
        return validity

# 损失函数:平方误差
# 调用方法:adversarial_loss(网络输出A, 分类标签B)
# 计算结果:(A-B)^2
adversarial_loss = nn.MSELoss()

generator = Generator()
discriminator = Discriminator()

# 导入MNIST数据集
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
# 以链式的方式组合多个transform
transform = transform.Compose([
    transform.Resize(opt.img_size),
    # transform.Gray()只在jittor中有的一个函数
    # 将任意形式的PIL图像(RGB, HSV, LAB等)转化为灰度图
    # 应该有参数啊,为什么没给参数?
    transform.Gray(),
    # transform.ImageNormalize只在jittor中有的一个函数
    transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
# set_attrs只在jittor中有的一个函数
# 顾名思义,为数据集设置一些属性
# 这个函数的确很方便,方便取用一些重要的参数
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

from PIL import Image
def save_image(img, path, nrow=10, padding=5):
    N,C,W,H = img.shape
    if (N%nrow!=0):
        print("N%nrow!=0")
        return
    # 共N个图片,把他们分成nrow行ncol列
    ncol=int(N/nrow)
    img_all = []
    for i in range(ncol):
        # 每一列都建立一个新的img_ list
        img_ = []
        for j in range(nrow):
            # 将第i*nrow+j个图片存入,其维度为C, W, H
            img_.append(img[i*nrow+j])
            # 生成一个C, W, padding的0矩阵加入(干嘛用的,猜想是用来放label的,所以padding应该为10)
            img_.append(np.zeros((C,W,padding)))
        # np.concatenate可以一次完成多个数组的拼接
        # 对img_ 这个list在第2维度上进行拼接
        # 则生成C, W, (H+padding)*nrow+的矩阵,加入img_all list
        img_all.append(np.concatenate(img_, 2))
        # 生成C, padding, (H+padding)*nrow的矩阵,也加入img_all list(干嘛用的)
        img_all.append(np.zeros((C,padding,img_all[0].shape[2])))
    # 在img_all的第1维度上拼接
    # img: C, (W+padding)*ncol, (H+padding)*nrow
    img = np.concatenate(img_all, 1)
    # 生成C, padding, (H+padding)*nrow的0矩阵
    # 再与img在第1维度上拼接,得到C, (W+padding)*ncol+padding, (H+padding)*nrow
    img = np.concatenate([np.zeros((C,padding,img.shape[2])), img], 1)
    # 生成C, (W+padding)*ncol+padding, padding的0矩阵
    # 再与img在第2维度上拼接,得到C, (W+padding)*ncol+padding, (H+padding)*nrow+padding
    img = np.concatenate([np.zeros((C,img.shape[1],padding)), img], 2)
    min_=img.min()
    max_=img.max()
    # 标准化到0~255
    img=(img-min_)/(max_-min_)*255
    # (W+padding)*ncol+padding, (H+padding)*nrow+padding, C
    img=img.transpose((1,2,0))
    if C==3:
        # 把C的维度倒序输出?why
        img = img[:,:,::-1]
    elif C==1:
        img = img[:,:,0]
    Image.fromarray(np.uint8(img)).save(path)

def sample_image(n_row, batches_done):
    # 随机采样输入并保存生成的图片
    # 从正态分布中生成随机变量,维度n_row^2, latent_dim=100
    # 推测含义是生成nrow^2个向量,每个向量维度是100
    # 所以为什么要停止梯度的计算?(又不是叶子节点)
    z = jt.array(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).float32().stop_grad()
    # 生成一个array,nrow*nrow的长度
    labels = jt.array(np.array([num for _ in range(n_row) for num in range(n_row)])).float32().stop_grad()
    # 生成器生成图片32*32
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.numpy(), "%d.png" % batches_done, nrow=n_row)

# ----------
#  模型训练
# ----------

for epoch in range(opt.n_epochs):
    # imgs: 64, 1, 32, 32
    # labels: 64
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # 数据标签,valid=1表示真实的图片,fake=0表示生成的图片
        # 表示有batch_size个真实图片和batch_size个生成图片
        valid = jt.ones([batch_size, 1]).float32().stop_grad()
        fake = jt.zeros([batch_size, 1]).float32().stop_grad()

        # 真实图片及其类别
        real_imgs = jt.array(imgs)
        labels = jt.array(labels)

        # -----------------
        #  训练生成器
        # -----------------

        # 采样随机噪声和数字类别作为生成器输入
        # z: batch_size, 100
        z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
        # gen_labels: batch_size(从0~10生成随机整数)
        gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()

        # 生成一组图片
        # 放到生成器中的是生成的label,不是真正的label
        # 生成器会返回64, 32, 32的图片
        # 此处完成了公式(2)中的G(z|y),在生成标签的条件下,为噪声z生成图片
        gen_imgs = generator(z, gen_labels)
        # 损失函数衡量生成器欺骗判别器的能力,即希望判别器将生成图片分类为valid
        # 判别器对生成的图片和生成的label进行判断,输出一个实数
        # validity: 64, 1
        # 此处是完成了公式(2)中的D(G(z|y)),获得了生成标签下生成图片的真实性validity
        validity = discriminator(gen_imgs, gen_labels)
        # 令生成的图片与真实的图片标签进行均方误差
        # 此处是完成了公式(2)中的目标:maxD,也就是令D(G(z|y))逼近于1
        g_loss = adversarial_loss(validity, valid)
        g_loss.sync()
        optimizer_G.step(g_loss)

        # ---------------------
        #  训练判别器
        # ---------------------

        # 判别器输出真实图片的实数
        # validity_real: 64, 1
        # 此处是完成公式(2)中的D(x|y),判断真实标签下真实图片的真实性
        validity_real = discriminator(real_imgs, labels)
        # d_real_loss = adversarial_loss("""TODO: 计算真实类别的损失函数""")
        # 此处是完成公式(2)中的
        d_real_loss = adversarial_loss(validity_real, valid)

        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        # d_fake_loss = adversarial_loss("""TODO: 计算虚假类别的损失函数""")
        # 这里的目标是令对生成图片的预测努力接近生成标签
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # 总的判别器损失
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.sync()
        optimizer_D.step(d_loss)
        if i  % 50 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

    if epoch % 10 == 0:
        generator.save("generator_last.pkl")
        discriminator.save("discriminator_last.pkl")

generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')

number = 12312345678#TODO: 写入你注册时绑定的手机号(字符串类型)
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)

img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")

结束语

感觉GAN应该是深度学习中最难理解的网络了(我比较孤陋寡闻),不过最后最算是理解了,开心。姑且就把这篇文章定义为中级文章吧hh(本人的第一篇中级)

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值