vae cvae cvaegan的区别

VAE

    最近在研究如何生成中间图像变量的问题,看vae,cvae百看不懂,论文读的也是迷迷糊糊,我相信有些人应该和我一样。为了能帮助大家,我将用实际操作给大家讲解一下我的理解。

  首先是vae。其实读起来VAE,我更多的是想起来深度特征插值的一种方法。其实vae的核心在于深度空间的规则化。我们可以想想gan的算法,使用gan的G和D,我们的生成器,也就是G生成方式是随机的,很有可能导致梯度消失或者梯度爆炸。再有可能会有一种投机取巧的方法,生成同一种图片骗过判别器。这种完全交给电脑的方法显然是不合理的,那么有没有一种方法,能很优雅的生成图片,而且不会梯度爆炸,梯度消失,而且很合理呢?

   vae就是这种方法,vae生成的图片虽然变化不大,但是图片可以源源不断的产生,虽然idea很棒,但是很多人读到什么后验分布,什么正态分布,什么匹配,什么变分,整个人都蒙了,还有什么不可求导问题,到底是个啥玩意?那我们直接上个代码。

class Reshape(nn.Module):
    def __init__(self, args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Vae(nn.Module):
    def __init__(self, batch_size):
        super(Vae, self).__init__()
        self.z_dim = 2
        self.encoder = nn.Sequential(
            OrderedDict([
                ('reshape1', Reshape((-1, 1, 28, 28))),
                ('conv1', nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)),
                ('norm1', nn.BatchNorm2d(16)),
                ('relu1', nn.LeakyReLU(0.2, inplace=True)),
                ('conv2', nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)),
                ('norm2', nn.BatchNorm2d(32)),
                ('relu2', nn.LeakyReLU(0.2, inplace=True)),
                ('conv3', nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)),
                ('norm3', nn.BatchNorm2d(32)),
                ('relu3', nn.LeakyReLU(0.2, inplace=True)),
                ('reshape2', Reshape((batch_size, -1)))
            ])
        )

        self.mean_linear = nn.Linear(32*7*7, self.z_dim)
        self.stds_linear = nn.Linear(32*7*7, self.z_dim)
        self.decoder = nn.Sequential(
            OrderedDict([
                ('fc_z', nn.Linear(self.z_dim, 32*7*7)),
                ('view', Reshape((-1, 32, 7, 7))),
                ('deconv1', nn.ConvTranspose2d(32, 16, 4, 2, 1)),
                ('relu1', nn.ReLU(inplace=True)),
                ('deconv2', nn.ConvTranspose2d(16, 1, 4, 2, 1)),
                ('sigmoid', nn.Sigmoid()),
            ])
        )

    def noise_get_z(self, mean, logvar):
        eps = torch.randn(logvar.shape).to('cpu')
        z = mean + eps * torch.exp(logvar)
        return z

    def forward(self, x):
        """

        :param x: 输入的图像
        :return: recon_x, mean, std
        """
        mean, logstd = self.mean_linear(self.encoder(x)), self.stds_linear(self.encoder(x))
        z = self.noise_get_z(mean, logstd)
        out = self.decoder(z)
        return out, mean, logstd

这就是VAE-MNIST的全部代码了,就是这么简单。但是想真的理解,还需要下一定的功夫。

首先,先看到mean_linear, logv_linear这两个全连接层,这两个全连接层是生成mean与std,也就是正态函数中最关键的均值和方差。但是,你光知道mean,std,没有函数,decoder的前向传播传的过去,后向传播没法传呀,因为得求函数的导数传播,mean和std只是一个数,这可咋办?那不行咱们就找个函数替代吧?啥函数一直是可以求导的呢?正态函数可以哈!那么直接从01分布力取一个偏置,mean+std*noise,那么z就是正态函数上的一个点了,那么就可以求导啦!

还有人说了,我不用noise,那也可以反向传播啊!是的,不用noise依然可以反向传播,只要你mean,std的映射函数处处可导。其实前面就是个人理解,但是如果你没有用noise,那么z的分布可能就不能映射在正态分布上,没有KL散度的支持,偏置std可能会逐渐变为0,loss使用MSE的话,图片会尽量与原图靠拢,那么z会变成一个很平庸的中间隐藏层,也没啥生成能力了。

在这个时候,vae就有一定的作用了,随便给一个图片,生成mean,std,加一个noise做偏置,就可以源源不断生成基于该图片的随机图片啦,或者直接使用z = torch.randn([batch_size, self.z_dim]),也可以生成随机图片。请记住一点,添加noise生成的z才是核心,输入的x只是起了一个提供mean和std的作用,生成什么图片,和输入的x没有啥直接关系。简单理解就是输入x是一个圆中点,它是输入网络生成了z,z是圆中间的一点,你不能说z就是x吧,或者咱们直接randn生成都没问题。

 

CVAE 

    上回说到,vae有可以生成源源不断的recon_x了,但是我没法用啊,虽然看着挺好,但是还是个辣鸡,都不受控制。没事,小老弟,我教你一个办法,很快就能控制了

上节中的输入只有x, 那么label空着不用也不是办法,label怎么贴着x一起放进去呢?

答案是one-hot化之后直接在最后一个维度沾一起,例如encoder中:

if self.conditional:
    c = idx2onehot(c, n=10)
    x = torch.cat((x, c), dim=-1)
x = self.MLP(x)

就这么简单?mnist里面就是这么简单,其他的就要靠你的聪明才智了。

那么在decoder中也不能忘了label要一起放进去,如下:


        if self.conditional:
            c = idx2onehot(c, n=10)
            z = torch.cat((z, c), dim=-1)

接下来只需要encoder的输入维度加num_class,decoder的输入维度加num_class就结束了。

结束了?对,其他啥也不要变,loss不变,model不变就完事了。如果你嫌麻烦,encoder中的label都不用加,直接加到decoder中,效果是一样的。个人猜想,这个方法其实是取巧了,你随机变量不是带着z+label一起的吗?z可以随便变,但是label不能变呀,指定的label对应的是指定的图片,label:1只能对应含有数字1的图片,那解码层其实也学到了分类信息了。

 

还有两个实用的操作,一般的讲解里没有仔细说。

a. 就是之前的输入x获取了z吗?那么我们输入一个1的图片,获取一个z,输入一个10的图片,获取一个z,var = (z10-z1)/ n,那z1+(1..n)var,就获取了层次变换。比如你获取了一个人正面的隐藏层z,获取侧面的隐藏层z,两个z之间的距离,就是从正面到侧面的过度层.

b. 隐藏层z包含了一些隐藏信息,可以做相同类型的检索操作。比较好理解,就是计算隐藏层z之间的距离,分辨是不是同一个人。

 

CVAE-GAN

       要是实现了前两个方法,走到这一步的同学,就会发现一个问题,vae好是好,有两个优点,一个是稳定,例如在人脸,不会输 出一些奇形怪状的东西。另一个是隐藏层规则化,我想往哪里变就往那里变。但是还有个问题,就是图片的生成结果模糊。如果模糊问题解决了,那就起飞了。怎么解决?可以考虑拿判别器下手。之前的问题,生成效果不好,不是因为没能力,而是无论Decoder给你多大的网络,判别器mse+kl随随便便就过了,loss下降太快。那我直接来gan里的判别器。不仅要gan判别器,还要分类器,全部都加一起。说白了,就是缝合怪。

     那么,该怎么玩?Encoder是生成正态分布的,它对应的是mse+kl的loss,生成器(也就是VAE里的解码器)对应生成结果迷惑判别器的,它对应的是min(D(G(z)))的loss,判别器中对应的是提升判别能力的,它对应的是max(D(fake_img)) + min(D(real_img)), C是分类器,对应的loss是min(C(fake_img), label),全部加一块就行了。实现代码如下:

class Discriminator(nn.Module):
    def __init__(self, outputn=1):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),

            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        )
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, outputn),
            nn.Sigmoid()
        )

    def forward(self, input):
        x = self.dis(input)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x.squeeze(1)


def loss_function(recon_x, x, mean, logstd):
    # BCE = F.binary_cross_entropy(recon_x,x,reduction='sum')
    MSE = MSECriterion(recon_x, x)
    # 因为var是标准差的自然对数,先求自然对数然后平方转换成方差
    var = torch.pow(torch.exp(logstd), 2)
    KLD = -0.5 * torch.sum(1 + torch.log(var) - torch.pow(mean, 2) - var)
    return MSE + KLD

   print("=====> 构建VAE")
    vae = VAE().to(device)
    print("=====> 构建D")
    D = Discriminator(1).to(device)
    print("=====> 构建C")
    C = Discriminator(10).to(device)
    criterion = nn.BCELoss().to(device)
    MSECriterion = nn.MSELoss().to(device)

    print("=====> Setup optimizer")
    optimizerD = optim.Adam(D.parameters(), lr=0.0001)
    optimizerC = optim.Adam(C.parameters(), lr=0.0001)
    optimizerVAE = optim.Adam(vae.parameters(), lr=0.0001)

    for epoch in range(nepoch):
        for i, (data, label) in enumerate(dataloader, 0):
            # 先处理一下数据
            data = data.to(device)
            label_onehot = torch.zeros((data.shape[0], 10)).to(device)
            label_onehot[torch.arange(data.shape[0]), label] = 1
            batch_size = data.shape[0]
            # 先训练C
            output = C(data)
            real_label = label_onehot.to(device)  # 定义真实的图片label为1
            errC = criterion(output, real_label)
            C.zero_grad()
            errC.backward()
            optimizerC.step()
            # 再训练D
            output = D(data)
            real_label = torch.ones(batch_size).to(device)  # 定义真实的图片label为1
            fake_label = torch.zeros(batch_size).to(device)  # 定义假的图片的label为0
            errD_real = criterion(output, real_label)

            z = torch.randn(batch_size, nz + 10).to(device)
            fake_data = vae.decoder(z)
            output = D(fake_data)
            errD_fake = criterion(output, fake_label)

            errD = errD_real + errD_fake
            D.zero_grad()
            errD.backward()
            optimizerD.step()
            # 更新VAE(G)1
            z, mean, logstd = vae.encoder(data)
            z = torch.cat([z, label_onehot], 1)
            recon_data = vae.decoder(z)
            vae_loss1 = loss_function(recon_data, data, mean, logstd)
            # 更新VAE(G)2
            output = D(recon_data)
            real_label = torch.ones(batch_size).to(device)
            vae_loss2 = criterion(output, real_label)
            # 更新VAE(G)3
            output = C(recon_data)
            real_label = label_onehot
            vae_loss3 = criterion(output, real_label)

            vae.zero_grad()
            vae_loss = vae_loss1 + vae_loss2 + vae_loss3
            vae_loss.backward()
            optimizerVAE.step()

读完三篇论文,我又怅然若失,看之前觉得可以改变世界,看完感觉能力有限,学不可以已。

看完不懂的朋友或者需要代码的同学和我说一声,我可以把代码传到GitHub

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值