VAE
最近在研究如何生成中间图像变量的问题,看vae,cvae百看不懂,论文读的也是迷迷糊糊,我相信有些人应该和我一样。为了能帮助大家,我将用实际操作给大家讲解一下我的理解。
首先是vae。其实读起来VAE,我更多的是想起来深度特征插值的一种方法。其实vae的核心在于深度空间的规则化。我们可以想想gan的算法,使用gan的G和D,我们的生成器,也就是G生成方式是随机的,很有可能导致梯度消失或者梯度爆炸。再有可能会有一种投机取巧的方法,生成同一种图片骗过判别器。这种完全交给电脑的方法显然是不合理的,那么有没有一种方法,能很优雅的生成图片,而且不会梯度爆炸,梯度消失,而且很合理呢?
vae就是这种方法,vae生成的图片虽然变化不大,但是图片可以源源不断的产生,虽然idea很棒,但是很多人读到什么后验分布,什么正态分布,什么匹配,什么变分,整个人都蒙了,还有什么不可求导问题,到底是个啥玩意?那我们直接上个代码。
class Reshape(nn.Module):
def __init__(self, args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view(self.shape)
class Vae(nn.Module):
def __init__(self, batch_size):
super(Vae, self).__init__()
self.z_dim = 2
self.encoder = nn.Sequential(
OrderedDict([
('reshape1', Reshape((-1, 1, 28, 28))),
('conv1', nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)),
('norm1', nn.BatchNorm2d(16)),
('relu1', nn.LeakyReLU(0.2, inplace=True)),
('conv2', nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)),
('norm2', nn.BatchNorm2d(32)),
('relu2', nn.LeakyReLU(0.2, inplace=True)),
('conv3', nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)),
('norm3', nn.BatchNorm2d(32)),
('relu3', nn.LeakyReLU(0.2, inplace=True)),
('reshape2', Reshape((batch_size, -1)))
])
)
self.mean_linear = nn.Linear(32*7*7, self.z_dim)
self.stds_linear = nn.Linear(32*7*7, self.z_dim)
self.decoder = nn.Sequential(
OrderedDict([
('fc_z', nn.Linear(self.z_dim, 32*7*7)),
('view', Reshape((-1, 32, 7, 7))),
('deconv1', nn.ConvTranspose2d(32, 16, 4, 2, 1)),
('relu1', nn.ReLU(inplace=True)),
('deconv2', nn.ConvTranspose2d(16, 1, 4, 2, 1)),
('sigmoid', nn.Sigmoid()),
])
)
def noise_get_z(self, mean, logvar):
eps = torch.randn(logvar.shape).to('cpu')
z = mean + eps * torch.exp(logvar)
return z
def forward(self, x):
"""
:param x: 输入的图像
:return: recon_x, mean, std
"""
mean, logstd = self.mean_linear(self.encoder(x)), self.stds_linear(self.encoder(x))
z = self.noise_get_z(mean, logstd)
out = self.decoder(z)
return out, mean, logstd
这就是VAE-MNIST的全部代码了,就是这么简单。但是想真的理解,还需要下一定的功夫。
首先,先看到mean_linear, logv_linear这两个全连接层,这两个全连接层是生成mean与std,也就是正态函数中最关键的均值和方差。但是,你光知道mean,std,没有函数,decoder的前向传播传的过去,后向传播没法传呀,因为得求函数的导数传播,mean和std只是一个数,这可咋办?那不行咱们就找个函数替代吧?啥函数一直是可以求导的呢?正态函数可以哈!那么直接从01分布力取一个偏置,mean+std*noise,那么z就是正态函数上的一个点了,那么就可以求导啦!
还有人说了,我不用noise,那也可以反向传播啊!是的,不用noise依然可以反向传播,只要你mean,std的映射函数处处可导。其实前面就是个人理解,但是如果你没有用noise,那么z的分布可能就不能映射在正态分布上,没有KL散度的支持,偏置std可能会逐渐变为0,loss使用MSE的话,图片会尽量与原图靠拢,那么z会变成一个很平庸的中间隐藏层,也没啥生成能力了。
在这个时候,vae就有一定的作用了,随便给一个图片,生成mean,std,加一个noise做偏置,就可以源源不断生成基于该图片的随机图片啦,或者直接使用z = torch.randn([batch_size, self.z_dim]),也可以生成随机图片。请记住一点,添加noise生成的z才是核心,输入的x只是起了一个提供mean和std的作用,生成什么图片,和输入的x没有啥直接关系。简单理解就是输入x是一个圆中点,它是输入网络生成了z,z是圆中间的一点,你不能说z就是x吧,或者咱们直接randn生成都没问题。
CVAE
上回说到,vae有可以生成源源不断的recon_x了,但是我没法用啊,虽然看着挺好,但是还是个辣鸡,都不受控制。没事,小老弟,我教你一个办法,很快就能控制了
上节中的输入只有x, 那么label空着不用也不是办法,label怎么贴着x一起放进去呢?
答案是one-hot化之后直接在最后一个维度沾一起,例如encoder中:
if self.conditional:
c = idx2onehot(c, n=10)
x = torch.cat((x, c), dim=-1)
x = self.MLP(x)
就这么简单?mnist里面就是这么简单,其他的就要靠你的聪明才智了。
那么在decoder中也不能忘了label要一起放进去,如下:
if self.conditional:
c = idx2onehot(c, n=10)
z = torch.cat((z, c), dim=-1)
接下来只需要encoder的输入维度加num_class,decoder的输入维度加num_class就结束了。
结束了?对,其他啥也不要变,loss不变,model不变就完事了。如果你嫌麻烦,encoder中的label都不用加,直接加到decoder中,效果是一样的。个人猜想,这个方法其实是取巧了,你随机变量不是带着z+label一起的吗?z可以随便变,但是label不能变呀,指定的label对应的是指定的图片,label:1只能对应含有数字1的图片,那解码层其实也学到了分类信息了。
还有两个实用的操作,一般的讲解里没有仔细说。
a. 就是之前的输入x获取了z吗?那么我们输入一个1的图片,获取一个z,输入一个10的图片,获取一个z,var = (z10-z1)/ n,那z1+(1..n)var,就获取了层次变换。比如你获取了一个人正面的隐藏层z,获取侧面的隐藏层z,两个z之间的距离,就是从正面到侧面的过度层.
b. 隐藏层z包含了一些隐藏信息,可以做相同类型的检索操作。比较好理解,就是计算隐藏层z之间的距离,分辨是不是同一个人。
CVAE-GAN
要是实现了前两个方法,走到这一步的同学,就会发现一个问题,vae好是好,有两个优点,一个是稳定,例如在人脸,不会输 出一些奇形怪状的东西。另一个是隐藏层规则化,我想往哪里变就往那里变。但是还有个问题,就是图片的生成结果模糊。如果模糊问题解决了,那就起飞了。怎么解决?可以考虑拿判别器下手。之前的问题,生成效果不好,不是因为没能力,而是无论Decoder给你多大的网络,判别器mse+kl随随便便就过了,loss下降太快。那我直接来gan里的判别器。不仅要gan判别器,还要分类器,全部都加一起。说白了,就是缝合怪。
那么,该怎么玩?Encoder是生成正态分布的,它对应的是mse+kl的loss,生成器(也就是VAE里的解码器)对应生成结果迷惑判别器的,它对应的是min(D(G(z)))的loss,判别器中对应的是提升判别能力的,它对应的是max(D(fake_img)) + min(D(real_img)), C是分类器,对应的loss是min(C(fake_img), label),全部加一块就行了。实现代码如下:
class Discriminator(nn.Module):
def __init__(self, outputn=1):
super(Discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Conv2d(1, 32, 3, stride=1, padding=1),
nn.LeakyReLU(0.2, True),
nn.MaxPool2d((2, 2)),
nn.Conv2d(32, 64, 3, stride=1, padding=1),
nn.LeakyReLU(0.2, True),
nn.MaxPool2d((2, 2)),
)
self.fc = nn.Sequential(
nn.Linear(7 * 7 * 64, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, outputn),
nn.Sigmoid()
)
def forward(self, input):
x = self.dis(input)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x.squeeze(1)
def loss_function(recon_x, x, mean, logstd):
# BCE = F.binary_cross_entropy(recon_x,x,reduction='sum')
MSE = MSECriterion(recon_x, x)
# 因为var是标准差的自然对数,先求自然对数然后平方转换成方差
var = torch.pow(torch.exp(logstd), 2)
KLD = -0.5 * torch.sum(1 + torch.log(var) - torch.pow(mean, 2) - var)
return MSE + KLD
print("=====> 构建VAE")
vae = VAE().to(device)
print("=====> 构建D")
D = Discriminator(1).to(device)
print("=====> 构建C")
C = Discriminator(10).to(device)
criterion = nn.BCELoss().to(device)
MSECriterion = nn.MSELoss().to(device)
print("=====> Setup optimizer")
optimizerD = optim.Adam(D.parameters(), lr=0.0001)
optimizerC = optim.Adam(C.parameters(), lr=0.0001)
optimizerVAE = optim.Adam(vae.parameters(), lr=0.0001)
for epoch in range(nepoch):
for i, (data, label) in enumerate(dataloader, 0):
# 先处理一下数据
data = data.to(device)
label_onehot = torch.zeros((data.shape[0], 10)).to(device)
label_onehot[torch.arange(data.shape[0]), label] = 1
batch_size = data.shape[0]
# 先训练C
output = C(data)
real_label = label_onehot.to(device) # 定义真实的图片label为1
errC = criterion(output, real_label)
C.zero_grad()
errC.backward()
optimizerC.step()
# 再训练D
output = D(data)
real_label = torch.ones(batch_size).to(device) # 定义真实的图片label为1
fake_label = torch.zeros(batch_size).to(device) # 定义假的图片的label为0
errD_real = criterion(output, real_label)
z = torch.randn(batch_size, nz + 10).to(device)
fake_data = vae.decoder(z)
output = D(fake_data)
errD_fake = criterion(output, fake_label)
errD = errD_real + errD_fake
D.zero_grad()
errD.backward()
optimizerD.step()
# 更新VAE(G)1
z, mean, logstd = vae.encoder(data)
z = torch.cat([z, label_onehot], 1)
recon_data = vae.decoder(z)
vae_loss1 = loss_function(recon_data, data, mean, logstd)
# 更新VAE(G)2
output = D(recon_data)
real_label = torch.ones(batch_size).to(device)
vae_loss2 = criterion(output, real_label)
# 更新VAE(G)3
output = C(recon_data)
real_label = label_onehot
vae_loss3 = criterion(output, real_label)
vae.zero_grad()
vae_loss = vae_loss1 + vae_loss2 + vae_loss3
vae_loss.backward()
optimizerVAE.step()
读完三篇论文,我又怅然若失,看之前觉得可以改变世界,看完感觉能力有限,学不可以已。
看完不懂的朋友或者需要代码的同学和我说一声,我可以把代码传到GitHub