VaDE模型理解-代码层面

本文介绍了如何在VAE的预训练阶段使用高斯混合模型(GMM)对编码器生成的(latentspace)进行聚类,并在后续的ELBO训练中利用GMM进行预测。作者通过sklearn库实现GMM的参数估计,并详细描述了训练和预测过程。
摘要由CSDN通过智能技术生成

高斯混合模型: 高斯混合模型是一种概率密度模型,它表示所有数据点都是由 K 个高斯分布生成的,每个高斯分布有自己的均值、协方差和混合系数(权重)。这些高斯分布通常代表数据中的不同子群体或聚类。

gmm sklearn api: 传入参数n_components代表有多少个高斯分布,covariance_type代表协方差矩阵的类型。可以用sklearn的fit方法去拟合高斯分布的参数,也就是高斯分布的均值和方差,一般是用EM算法(最大期望)。

code structure:

1、pretrain code:

模型结构类似,encoder和decoder,映射到一个latent space进行重构,这里像vae的训练,但是损失函数用的是mse loss进行重构,而且没有使用正则项,重构完成后,把所有encoder生成的均值向量追加到列表x里面,以及mnist的真实标签加到列表y里面。下面进行高斯混合模型的参数估计,首先利用sklearn这个库,初始化一个gmm模型。再对刚刚x列表,里面是追加的不同样本的latent space,如何进行fit,看看不同样本的数据的latent space属于哪一个高斯分布。接着训练完了,我们在模型初始化的时候已经把高斯混合模型的参数名初始化了,对应论文也就是π 均值 方差,我们把fit好的gmm对应参数赋值给初始化的参数,并保存到本地,这个就是预训练的过程。

2、training code:

预训练完了之后,会自动load训练好的权重,然后进行训练,训练没有什么区别,就是loss改成了elbo的loss,跟vae的loss相比多了一个c变量,也就是类别。在训练结束后,有了一个预测的y,利用了predict方法,让我们看看怎么写的。

首先取出了encoder的均值和向量,利用vade的elbo损失函数训练的,直接采样了一个z,注意用了重参数技巧,防止方差过大。接着把训练好的GMM的参数拿过来进行赋值,gaussian_pdf_log这个函数实际是算每个样本点是属于哪一个分布,通过高斯分布的概率密度函数去算。如何最后通过argmax取出来最大的簇的index。

def predict(self,x):
    z_mu, z_sigma2_log = self.encoder(x)
    z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
    pi = self.pi_
    log_sigma2_c = self.log_sigma2_c
    mu_c = self.mu_c
    yita_c = torch.exp(torch.log(pi.unsqueeze(0))+self.gaussian_pdfs_log(z,mu_c,log_sigma2_c))

    yita=yita_c.detach().cpu().numpy()
    return np.argmax(yita,axis=1)

整个process:

首先预训练是重构图像x,损失用的mse loss,重构好了我们取出encoder的均值向量丢给高斯混合模型,训练出来三个参数,一个是GMM的权重,一个分布的均值,一个是方差。

训练elbo,损失函数就是vade推导出来的elbo函数,主要说一下怎么预测的。还是从训练好的vade拿出均值和方差,利用重参数化技巧采样出一个z,把pretrain的GMM模型的三个参数拿过来赋值,最后y_c是怎么来的呢?就是高斯分布的参数加上每个x属于哪个簇的概率,最后argmax。计算每个簇的概率,利用到了高斯混合模型的概率密度函数直接进行计算就行。至此走完了全部流程。

ps: 论文写的我看不懂,代码我重拳出击。ELBO还没看懂推导,不过直接用就行。

reference:

https://github.com/GuHongyang/VaDE-pytorch

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值