ClusterGAN: Latent Space Clustering in Generative Adversarial Networks论文理解

目的:使用GAN在隐空间中进行聚类

一、背景

普通GAN的训练过程: m i n θ G m a x θ D E x ∼ P x r q ( D ( x ) ) + E z ∼ P z q ( 1 − D ( G ( z ) ) ) min_{\theta G}max_{\theta D}\textbf{E}_{x\sim P^r_x}q(D(x))+\textbf{E}_{z\sim P_z}q(1-D(G(z))) minθGmaxθDExPxrq(D(x))+EzPzq(1D(G(z))),它无法在隐空间很好地聚类。
原因:GAN聚类的一个可能的方式是将数据反向传播到隐空间,并在隐空间聚类。但是即使反向传播成功也无法很好地聚类。一个关键问题是反向投影的数据分布和隐空间的分布应该是相似的,通常是高斯分布或均匀分布。因此即使隐空间包含了数据的所有信息,但是隐空间向量之间的几何距离不能反映类别信息,因此无法很好地聚类。

二、本文的方法

网络模型:
clusterGAN网络模型
1.从离散连续混合中抽样:
z = ( z n , z c ) , z n ∼ N ( 0 , σ 2 I d n ) , z c = e k , k ∼ U { 1 , 2 , . . . , K } z=(z_n,z_c),z_n\sim N(0,\sigma ^2I_{d_n}),z_c=e_k,k\sim U\left\{1,2,...,K\right\} z=(zn,zc),znN(0,σ2Idn),zc=ek,kU{1,2,...,K},ek是K维的向量,其中第k维是1,即zn是正态分布,zc是K维的离散型one-hot向量,K是类别数量,二者联合构成离散-连续向量z。
经过试验证明,相比于均匀分布、正态分布、高斯混合分布,从离散连续混合中抽样的聚类效果更好。从不同分布中采样得到的隐空间分布如下图所示:
从不同分布中采样得到的隐空间分布
2. 基于改进的反向传播解码
为了获得更好的隐向量,已有的工作是解决一个优化问题: z ∗ = a r g   m i n z L ( G ( z ) , x ) + λ ∣ ∣ z ∣ ∣ p z^*=arg\space min_zL(G(z),x)+\lambda ||z||_p z=arg minzL(G(z),x)+λzp,其中L是适宜的损失函数,但是这个方法对聚类是不够的。
在本论文中,我们让 L ( G ( z ) , x ) = ∣ ∣ G ( z ) − x ∣ ∣ 1 L(G(z),x)=||G(z)-x||_1 L(G(z),x)=G(z)x1,惩罚项为 ∣ ∣ z n ∣ ∣ 2 2 ||z_n||^2_2 zn22,只惩罚正态部分。再抽样K次,每次用不同的zc进行抽样,在优化时固定zc,用Adam优化正态部分。
3.使用线性分类器可以获得更好的聚类效果

引理:Clustering with only zn cannot recover a mixture of gaussian data in the linearly generated space. Further ∃ a linear G(·) mapping discrete-continuous mixtures to a mixture of Gaussians.

证明:
如果隐空间只包含连续部分,即 z = z n ∼ N ( 0 , σ 2 I d n ) z=z_n\sim N(0,\sigma ^2I_{d_n}) z=znN(0,σ2Idn),则通过线性生成器只能生成高斯分布。(线性生成器的本质是一个线性变换,高斯分布经过线性变换后还是高斯分布,不能变为混合高斯分布)
如果隐空间包含离散和连续维混合,即 z = ( z n , z c ) , z n ∼ N ( 0 , σ 2 I d n ) , z c = e k , k ∼ U { 1 , 2 , . . . , K } z=(z_n,z_c),z_n\sim N(0,\sigma ^2I_{d_n}),z_c=e_k,k\sim U\left\{1,2,...,K\right\} z=(zn,zc),znN(0,σ2Idn),zc=ek,kU{1,2,...,K},要得到生成数据 X ∼ N ( μ ω , σ 2 I d n ) , ω ∼ U { 1 , 2 , . . . , K } X \sim N(\mu_{\omega},\sigma ^2I_{d_n}),\omega \sim U\left\{1,2,...,K\right\} XN(μω,σ2Idn),ωU{1,2,...,K},需要构造一个生成器 G ( ⋅ ) G(\cdot) G(),使得 G : Z → X G:Z\to X GZX,可以得到 x g = G ( z ) = G ( z n , z c ) = z n + A z c , x_g=G(z)=G(z_n,z_c)=z_n+Az_c, xg=G(z)=G(zn,zc)=zn+Azc,其中 A = d i a g [ μ 1 , . . . , μ K ] A=diag[\mu_1,...,\mu_K] A=diag[μ1,...,μK] K × K K\times K K×K的对角矩阵。这里的X符合混合高斯分布。
4.使用插值

插值(Interpolation)是离散函数逼近的重要方法,利用它可通过函数在有限个点处的取值状况,估算出函数在其他点处的近似值。

clusterGAN中构造插值点是通过 z = ( z n , μ z c ( 1 ) + ( 1 − μ ) z c ( 2 ) ) , μ ∈ [ 0 , 1 ] z=(z_n,\mu z^{(1)}_c+(1-\mu )z_c^{(2)}),\mu \in [0,1] z=(zn,μzc(1)+(1μ)zc(2)),μ[0,1],可以达到渐变效果(不同类间的过渡)。通过构造不同类间的插值点,使得生成的不同类的 可以明显地区分开,是一种提升训练精度的手段
5.几种精度
(1)模型精度
从Z选中第k个簇生成的样本xg,然后用分类器判断生成样本的类别 y ^ \hat{y} y^,计算正确率, ( k , y ^ ) (k,\hat{y}) (k,y^)即为模型精度
(2)重构精度
X中的属于类 y y y的x解码得到z,z再生成xg,xg经过分类得到的类标签为 y ^ , ( y , y ^ ) \hat{y},(y,\hat{y}) y^,(y,y^)的精度为重构精度。
(3)聚类精度
X空间中同一类中所有点的映射生成具有相同的one-hot编码,这些点占总点数的比率为聚类精度
6.在原来GAN结构基础上加一个编码器
在目标函数中,编码器的损失函数作为正则化项。因此,加入编码器的其中一个目的是防止GAN的过拟合现象。另一个目的是聚类
7.目标函数
m i n Θ G , Θ E   m a x Θ D   E x ∼ P x r   q ( D ( x ) ) + E z ∼ P z   q ( 1 − D ( G ( z ) ) ) + β n E z ∼ P z ∣ ∣ z n − E ( G ( z n ) ) ∣ ∣ 2 2 + β c E z ∼ P z H ( z c , E ( G ( z c ) ) ) min_{\Theta _G,\Theta _E}\space max_{\Theta_D}\space \textbf{E}_{x\sim \textbf{P}^r_x}\space q(D(x))+\textbf{E}_{z\sim \textbf{P}_z}\space q(1-D(G(z))) +\beta_n\textbf{E}_{z\sim \textbf{P}_z}||z_n-E(G(z_n))||^2_2+\beta_c\textbf{E}_{z\sim \textbf{P}_z}H(z_c,E(G(z_c))) minΘG,ΘE maxΘD ExPxr q(D(x))+EzPz q(1D(G(z)))+βnEzPzznE(G(zn))22+βcEzPzH(zc,E(G(zc)))
其中H是交叉熵损失。可将其分为两部分来看,前两项为第一部分,后两项为第二部分。
前两项的含义是使得xg尽可能地与xr相似,可重新表示为 E x r ∼ P x r   q ( D ( x r ) ) + E x g ∼ P x g   q ( 1 − D ( x g ) ) \textbf{E}_{x_r\sim \textbf{P}^r_x}\space q(D(x_r))+\textbf{E}_{x_g\sim \textbf{P}^g_x}\space q(1-D(x_g)) ExrPxr q(D(xr))+ExgPxg q(1D(xg)),其中 x ∼ P x r x\sim \textbf{P}^r_x xPxr表示为真实数据 x ∈ X x\in X xX服从分布 P x r \textbf{P}^r_x Pxr x ∼ P x g x\sim \textbf{P}^g_x xPxg表示为真实数据 x ∈ X x\in X xX服从分布 P x g \textbf{P}^g_x Pxg
后两项的含义是使得xg尽可能地保留z的信息,即信息损失尽可能的小。可重新表示为 β n E z ∼ P z , x g ∼ P x g ∣ ∣ z n − E ( x g , n ) ∣ ∣ 2 2 + β c E z ∼ P z , x g ∼ P x g H ( z c , E ( x g , c ) ) = β n E z ∼ P z ∣ ∣ z n − z n ^ ∣ ∣ 2 2 + β c E z ∼ P z H ( z c , z c ^ ) \beta _n \textbf{E}_{z \sim \textbf{P}_z,x_g\sim \textbf{P}^g_x}||z_n-E(x_{g,n})||^2_2+\beta _c\textbf{E}_{z\sim \textbf{P}_z,x_g\sim\textbf{P}^g_x}H(z_c,E(x_{g,c})) =\beta _n\textbf{E}_{z\sim \textbf{P}_z}||z_n-\hat{z_n}||^2_2+\beta _c\textbf{E}_{z\sim \textbf{P}_z}H(z_c,\hat{z_c}) βnEzPz,xgPxgznE(xg,n)22+βcEzPz,xgPxgH(zc,E(xg,c))=βnEzPzznzn^22+βcEzPzH(zc,zc^),其中, z = ( z n , z c ) , z ^ = ( z n ^ , z c ^ ) , x g = ( x g , n , x g , c ) z=(z_n,z_c),\hat{z}=(\hat{z_n},\hat{z_c}),x_g=(x_{g,n},x_{g,c}) z=(zn,zc),z^=(zn^,zc^),xg=(xg,n,xg,c)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值