推荐一个机器学习前沿公众号,第一时间获取最有价值的前沿机器学习文章。
以下是对论文《Deep clustering: On the link between discriminative models and K-means》的核心思想、目标函数及其优化过程的详细分析,并针对目标函数的局限性提出改进建议。
1. 论文的核心思想
核心思想:
论文揭示了深度聚类中看似不同的判别式模型(discriminative models)和生成式模型(如 K-means)之间的理论联系,证明了在特定条件下(例如使用逻辑回归后验和
L
2
L_2
L2 正则化),判别式模型(如基于互信息或 KL 散度的模型)与 K-means 等价。这种联系不仅弥合了两种聚类方法的理论差距,还通过推导提出了一种新的软正则化 K-means 算法(SR-K-means),并在图像聚类任务上取得了与判别式模型相当的性能。
-
背景:
深度聚类近年来发展迅速,判别式模型(如 DEPICT、MI-based 模型)因其灵活性和较少的分布假设,通常比生成式模型(如 K-means)表现更好。然而,论文指出,尽管表面上判别式模型和 K-means 似乎无关,但它们在数学上可以通过特定的优化方法(如交替方向法 ADM)建立等价关系。 -
意义:
- 理论上:通过数学推导,证明了判别式模型(如基于互信息的 MI-ADM 和基于 KL 散度的 DEPICT)可以转化为软正则化的 K-means 形式,统一了两种聚类范式。
- 实践上:基于这一理论联系,提出了 SR-K-means 算法,结合深度神经网络和重建损失,实现了高效的深度聚类。
2. 目标函数
论文中讨论了多个目标函数,主要集中在判别式模型的互信息(MI)和 KL 散度,以及生成式模型的 K-means 损失。以下是主要目标函数及其形式:
2.1 判别式模型的目标函数
-
互信息(MI)目标函数(Section 2.1):
互信息用于衡量输入数据 X X X 和潜在聚类标签 K K K 之间的依赖性,目标是最大化互信息:
I ( X , K ) = H ( K ) − H ( K ∣ X ) \mathcal{I}(\mathrm{X}, \mathrm{K}) = \mathcal{H}(\mathrm{K}) - \mathcal{H}(\mathrm{K} \mid \mathrm{X}) I(X,K)=H(K)−H(K∣X)
其中:- H ( K ) = − ∑ k = 1 K p ^ k log ( p ^ k ) \mathcal{H}(\mathrm{K}) = -\sum_{k=1}^K \hat{p}_k \log (\hat{p}_k) H(K)=−∑k=1Kp^klog(p^k) 是标签的边缘熵, p ^ k = 1 N ∑ i = 1 N p i k \hat{p}_k = \frac{1}{N} \sum_{i=1}^N p_{i k} p^k=N1∑i=1Npik 是标签的边缘分布。
- H ( K ∣ X ) = − 1 N ∑ i = 1 N ∑ k = 1 K p i k log ( p i k ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) = -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K p_{i k} \log (p_{i k}) H(K∣X)=−N1∑i=1N∑k=1Kpiklog(pik) 是条件熵, p i k p_{i k} pik 是后验概率。
-
p
i
k
p_{i k}
pik 通常使用逻辑回归形式建模:
p i k ∝ exp ( θ k T z i + b k ) p_{i k} \propto \exp(\boldsymbol{\theta}_k^T \boldsymbol{z}_i + b_k) pik∝exp(θkTzi+bk)
其中 z i = ϕ W ( x i ) \boldsymbol{z}_i = \phi_{\mathcal{W}}(\boldsymbol{x}_i) zi=ϕW(xi) 是深度网络的嵌入, W \mathcal{W} W 是网络参数, θ k \boldsymbol{\theta}_k θk 和 b k b_k bk 是分类器的权重和偏置。
-
KL 散度目标函数(DEPICT 模型)(Section 2.2):
DEPICT 模型通过引入辅助目标分布 Q \boldsymbol{Q} Q,最小化 KL 散度来优化聚类:
min Φ , Q K L ( Q ∥ P ) + γ ∑ k = 1 K q ^ k log ( q ^ k ) s.t. q i T 1 = 1 , q i ≥ 0 ∀ i \min_{\Phi, \boldsymbol{Q}} \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) + \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) \quad \text{s.t.} \quad \boldsymbol{q}_i^T \mathbf{1} = 1, \boldsymbol{q}_i \geq 0 \forall i Φ,QminKL(Q∥P)+γk=1∑Kq^klog(q^k)s.t.qiT1=1,qi≥0∀i
其中:- K L ( Q ∥ P ) = 1 N ∑ i = 1 N ∑ k = 1 K q i k log ( q i k p i k ) \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) = \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log \left(\frac{q_{i k}}{p_{i k}}\right) KL(Q∥P)=N1∑i=1N∑k=1Kqiklog(pikqik) 是辅助分布 Q \boldsymbol{Q} Q 和后验 P \boldsymbol{P} P 之间的 KL 散度。
- γ ∑ k = 1 K q ^ k log ( q ^ k ) \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) γ∑k=1Kq^klog(q^k) 是平衡项,促进聚类分配的均匀性。
- Φ = { O , W } \Phi = \{\mathcal{O}, \mathcal{W}\} Φ={O,W} 包含分类器参数 O \mathcal{O} O 和网络参数 W \mathcal{W} W。
-
正则化互信息目标函数(Section 3):
为了建立与 K-means 的联系,论文在互信息目标上加入 L 2 L_2 L2 正则化项:
I ( X , K ) − λ ∑ k = 1 K θ k T θ k \mathcal{I}(\mathrm{X}, \mathrm{K}) - \lambda \sum_{k=1}^K \boldsymbol{\theta}_k^T \boldsymbol{\theta}_k I(X,K)−λk=1∑KθkTθk
其中 λ \lambda λ 是正则化参数。
2.2 生成式模型的目标函数(K-means)
-
标准 K-means 目标函数(Section 3):
传统的 K-means 目标是最小化数据点到聚类中心的距离:
∑ i = 1 N ∑ k = 1 K s i k ∥ z i − μ k ∥ 2 s.t. ∑ k = 1 K s i k = 1 , s i k ∈ { 0 , 1 } ∀ i , k \sum_{i=1}^N \sum_{k=1}^K s_{i k} \|\boldsymbol{z}_i - \boldsymbol{\mu}_k\|^2 \quad \text{s.t.} \quad \sum_{k=1}^K s_{i k} = 1, \quad s_{i k} \in \{0, 1\} \forall i, k i=1∑Nk=1∑Ksik∥zi−μk∥2s.t.k=1∑Ksik=1,sik∈{0,1}∀i,k
其中 s i k s_{i k} sik 是二元分配变量, μ k \boldsymbol{\mu}_k μk 是聚类中心。 -
软正则化 K-means(SR-K-means)目标函数(Section 3):
论文证明了正则化互信息的优化等价于以下软正则化 K-means 损失:
∑ i = 1 N ∑ k = 1 K q i k ∥ z i − θ k ′ ∥ 2 + λ K ∑ i = 1 N ∑ k = 1 K q i k log ( q i k ) − ∑ i = 1 N z i T z i \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 + \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) - \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i i=1∑Nk=1∑Kqik∥zi−θk′∥2+λKi=1∑Nk=1∑Kqiklog(qik)−i=1∑NziTzi
其中:- θ k ′ = ∑ i = 1 N q i k z i ∑ i = 1 N q i k \boldsymbol{\theta}_k^{\prime} = \frac{\sum_{i=1}^N q_{i k} \boldsymbol{z}_i}{\sum_{i=1}^N q_{i k}} θk′=∑i=1Nqik∑i=1Nqikzi 是软聚类中心。
- q i k ∈ [ 0 , 1 ] q_{i k} \in [0, 1] qik∈[0,1] 是软分配变量,替代了硬分配 s i k s_{i k} sik。
- 第二项 λ K ∑ i = 1 N ∑ k = 1 K q i k log ( q i k ) \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) λK∑i=1N∑k=1Kqiklog(qik) 是负熵项,促进分配的软性。
2.3 附加的重建损失
为了防止嵌入过拟合,所有模型都结合了重建损失(Section 5.1):
R
(
Z
)
=
1
N
∑
i
=
1
N
∑
l
=
0
L
−
1
1
∣
z
i
l
∣
∥
z
i
l
−
z
^
i
l
∥
2
\mathcal{R}(\mathcal{Z}) = \frac{1}{N} \sum_{i=1}^N \sum_{l=0}^{L-1} \frac{1}{|\boldsymbol{z}_i^l|} \|\boldsymbol{z}_i^l - \hat{\boldsymbol{z}}_i^l\|^2
R(Z)=N1i=1∑Nl=0∑L−1∣zil∣1∥zil−z^il∥2
其中
z
i
l
\boldsymbol{z}_i^l
zil 和
z
^
i
l
\hat{\boldsymbol{z}}_i^l
z^il 分别是第
l
l
l 层的干净嵌入和重建嵌入。
综合来看,实际优化目标通常是聚类损失(如互信息或 K-means 损失)与重建损失的组合,例如 SR-K-means 的目标为:
min
W
1
N
λ
K
∑
i
=
1
N
∑
k
=
1
K
q
i
k
∥
z
i
−
θ
k
′
∥
2
−
1
N
λ
K
∑
i
=
1
N
z
i
T
z
i
+
R
(
Z
)
\min_{\mathcal{W}} \frac{1}{N \lambda K} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 - \frac{1}{N \lambda K} \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i + \mathcal{R}(\mathcal{Z})
WminNλK1i=1∑Nk=1∑Kqik∥zi−θk′∥2−NλK1i=1∑NziTzi+R(Z)
3. 目标函数的优化过程
3.1 互信息目标的优化(MI-ADM)
-
方法:使用交替方向法(ADM)优化正则化互信息:
max Φ , Q 1 N ∑ i = 1 N ∑ k = 1 K q i k log ( p i k ) − ∑ k = 1 K q ^ k log ( q ^ k ) s.t. Q = P , q i T 1 = 1 , q i ≥ 0 \max_{\Phi, \boldsymbol{Q}} \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (p_{i k}) - \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) \quad \text{s.t.} \quad \boldsymbol{Q} = \boldsymbol{P}, \quad \boldsymbol{q}_i^T \mathbf{1} = 1, \quad \boldsymbol{q}_i \geq 0 Φ,QmaxN1i=1∑Nk=1∑Kqiklog(pik)−k=1∑Kq^klog(q^k)s.t.Q=P,qiT1=1,qi≥0
通过引入 KL 散度惩罚,将约束问题转化为无约束问题:
max Φ , Q 1 N ∑ i = 1 N ∑ k = 1 K q i k log ( p i k ) − ∑ k = 1 K q ^ k log ( q ^ k ) − K L ( Q ∥ P ) \max_{\Phi, \boldsymbol{Q}} \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (p_{i k}) - \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) - \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) Φ,QmaxN1i=1∑Nk=1∑Kqiklog(pik)−k=1∑Kq^klog(q^k)−KL(Q∥P) -
优化步骤:
- 参数学习步(更新
Φ
\Phi
Φ):固定
Q
\boldsymbol{Q}
Q,优化网络参数
Φ
\Phi
Φ,等价于最小化交叉熵损失:
min Φ − 1 N ∑ i = 1 N ∑ k = 1 K q i k log p i k \min_{\Phi} -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log p_{i k} Φmin−N1i=1∑Nk=1∑Kqiklogpik
使用 SGD(Adam 优化器)更新 Φ \Phi Φ。 - 目标估计步(更新
Q
\boldsymbol{Q}
Q):固定
Φ
\Phi
Φ,优化
Q
\boldsymbol{Q}
Q,得到闭式解:
q i k ∝ p i k 2 ( ∑ i ′ = 1 N p i ′ k 2 ) 1 / 2 q_{i k} \propto \frac{p_{i k}^2}{\left(\sum_{i^{\prime}=1}^N p_{i^{\prime} k}^2\right)^{1/2}} qik∝(∑i′=1Npi′k2)1/2pik2
- 参数学习步(更新
Φ
\Phi
Φ):固定
Q
\boldsymbol{Q}
Q,优化网络参数
Φ
\Phi
Φ,等价于最小化交叉熵损失:
-
理论保证(Section 4.1):
Proposition 3 证明了 ADM 优化的单调性,即互信息 I ( X , K ) \mathcal{I}(\mathrm{X}, \mathrm{K}) I(X,K) 在每次迭代中不减少。
3.2 KL 散度目标的优化(DEPICT)
-
方法:通过交替优化解决:
min Φ , Q K L ( Q ∥ P ) + γ ∑ k = 1 K q ^ k log ( q ^ k ) \min_{\Phi, \boldsymbol{Q}} \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) + \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) Φ,QminKL(Q∥P)+γk=1∑Kq^klog(q^k) -
优化步骤:
- 参数学习步:固定 Q \boldsymbol{Q} Q,优化 Φ \Phi Φ,等价于交叉熵损失(同上)。
- 目标估计步:固定
Φ
\Phi
Φ,优化
Q
\boldsymbol{Q}
Q,得到:
q i k ∝ p i k ( ∑ i ′ = 1 N p i ′ k ) 1 / 2 q_{i k} \propto \frac{p_{i k}}{\left(\sum_{i^{\prime}=1}^N p_{i^{\prime} k}\right)^{1/2}} qik∝(∑i′=1Npi′k)1/2pik
-
与 MI 的联系(Proposition 1):
DEPICT 的优化可以看作是互信息最大化的近似 ADM 解法。
3.3 SR-K-means 目标的优化
-
方法:优化软正则化 K-means 损失:
∑ i = 1 N ∑ k = 1 K q i k ∥ z i − θ k ′ ∥ 2 + λ K ∑ i = 1 N ∑ k = 1 K q i k log ( q i k ) − ∑ i = 1 N z i T z i + R ( Z ) \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 + \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) - \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i + \mathcal{R}(\mathcal{Z}) i=1∑Nk=1∑Kqik∥zi−θk′∥2+λKi=1∑Nk=1∑Kqiklog(qik)−i=1∑NziTzi+R(Z) -
优化步骤:
- 更新聚类中心
θ
k
′
\boldsymbol{\theta}_k^{\prime}
θk′:
θ k ′ = ∑ i = 1 N q i k z i ∑ i = 1 N q i k \boldsymbol{\theta}_k^{\prime} = \frac{\sum_{i=1}^N q_{i k} \boldsymbol{z}_i}{\sum_{i=1}^N q_{i k}} θk′=∑i=1Nqik∑i=1Nqikzi - 更新分配变量
q
i
k
q_{i k}
qik:
q i k ∝ exp ( − 1 λ K ∥ z i − θ k ′ ∥ 2 ) q_{i k} \propto \exp \left(-\frac{1}{\lambda K} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2\right) qik∝exp(−λK1∥zi−θk′∥2) - 更新网络参数 W \mathcal{W} W:通过 SGD 优化嵌入 z i \boldsymbol{z}_i zi 和重建损失 R ( Z ) \mathcal{R}(\mathcal{Z}) R(Z)。
- 更新聚类中心
θ
k
′
\boldsymbol{\theta}_k^{\prime}
θk′:
4. 主要贡献点
-
理论联系:
- 证明了判别式模型(基于互信息和 KL 散度)和 K-means 的等价性(Proposition 2)。
- 建立了 DEPICT 和互信息目标之间的联系(Proposition 1),表明 DEPICT 是互信息最大化的近似 ADM 解法。
-
新算法提出:
- 基于理论推导,提出了软正则化 K-means(SR-K-means)算法,结合深度网络和重建损失,提升了生成式模型的性能。
-
实验验证:
- 在多个图像聚类基准数据集(如 MNIST、USPS、YTF)上验证了 SR-K-means 和 MI-ADM 的性能,与 DEPICT 等判别式模型相当。
- SR-K-means 比传统硬 K-means(DCN)提高了 11% 的性能(Table 2)。
-
优化分析:
- 证明了 ADM 优化的单调性(Proposition 3)。
- 分析了 KL 散度和二次惩罚之间的关系,指出 KL 散度在单纯形约束下的计算优势。
5. 针对目标函数的局限性提出改进意见
5.1 目标函数的局限性
-
对平衡聚类的假设:
- SR-K-means 的推导假设了聚类分配是平衡的( q ^ k ≈ 1 K \hat{q}_k \approx \frac{1}{K} q^k≈K1),但实际数据(如 YTF、FRGC)往往是不平衡的(Table 1),这可能导致性能下降。
- 改进建议:引入更灵活的平衡约束,例如使用可调的先验分布 d ^ k \hat{d}_k d^k(如论文中提到的 K L ( ( p ^ k ) ∥ ( d ^ k ) ) \mathrm{KL}((\hat{p}_k) \| (\hat{d}_k)) KL((p^k)∥(d^k))),以适应不平衡数据。
-
逻辑回归后验的限制:
- 目标函数依赖逻辑回归后验 p i k ∝ exp ( θ k T z i + b k ) p_{i k} \propto \exp(\boldsymbol{\theta}_k^T \boldsymbol{z}_i + b_k) pik∝exp(θkTzi+bk),假设嵌入空间中的类别边界是线性的,可能无法捕捉复杂的非线性分布。
- 改进建议:尝试更复杂的后验模型,例如使用核方法(如高斯核)或深度网络直接输出后验分布,以捕捉非线性关系。
-
软分配的熵正则化:
- SR-K-means 中的负熵项 λ K ∑ i = 1 N ∑ k = 1 K q i k log ( q i k ) \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) λK∑i=1N∑k=1Kqiklog(qik) 促进软分配,但可能导致分配过于平滑,削弱聚类的区分性。
- 改进建议:引入可调的温度参数
τ
\tau
τ,修改软分配公式为:
q i k ∝ exp ( − 1 τ ∥ z i − θ k ′ ∥ 2 ) q_{i k} \propto \exp \left(-\frac{1}{\tau} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2\right) qik∝exp(−τ1∥zi−θk′∥2)
通过调整 τ \tau τ 控制分配的软性程度。
-
对高维和流形结构的适应性不足:
- K-means 和 SR-K-means 依赖欧几里得距离,难以处理高维数据中的流形结构(Section 6)。
- 改进建议:如论文结尾建议,探索其他原型方法(如 K-modes)或成对聚类目标(如归一化割 normalized cut),以更好地捕捉数据的流形结构。
-
优化方法的局限性:
- ADM 优化依赖交替更新,可能陷入局部最优,且对初始值敏感(Section 5.2.4)。
- 改进建议:引入全局优化技术(如模拟退火)或多样化的初始化策略,减少对初始值的依赖。此外,可以尝试其他距离度量(如 Bhattacharyya 距离,Section 6)替代 KL 散度,改善收敛性。
总结
- 核心思想:论文通过数学推导揭示了判别式模型和 K-means 的等价性,统一了深度聚类的两种范式,并提出了 SR-K-means 算法。
- 目标函数:包括互信息、KL 散度和软正则化 K-means 损失,通过 ADM 或交替优化实现联合特征学习和聚类。
- 主要贡献:理论联系、新算法、实验验证和优化分析。
- 改进建议:针对不平衡数据、非线性分布、软分配平滑、高维流形结构和优化局限性,提出了引入灵活约束、更复杂后验模型、可调软分配、流形目标和全局优化的改进方向。
1. 理解这句话的含义
原句:
- 论文中判别式深度聚类模型使用逻辑回归后验
P
(
K
∣
X
)
P(K|X)
P(K∣X):
p i k ∝ exp ( θ k T z i + b k ) p_{ik} \propto \exp(\theta_k^T z_i + b_k) pik∝exp(θkTzi+bk)
逐部分拆解:
-
判别式深度聚类模型:
- 判别式模型关注条件概率 P ( K ∣ X ) P(K|X) P(K∣X),即给定输入 X X X,预测其属于某个聚类 K K K 的概率。
- 在深度聚类中,输入 X X X 通常通过深度神经网络(DNN)映射到嵌入空间 Z Z Z,即 z i = ϕ W ( x i ) z_i = \phi_{\mathcal{W}}(x_i) zi=ϕW(xi),其中 W \mathcal{W} W 是网络参数。
- 目标是同时学习嵌入 z i z_i zi 和聚类分配 p i k p_{ik} pik(样本 i i i 属于聚类 k k k 的概率)。
-
逻辑回归后验 P ( K ∣ X ) P(K|X) P(K∣X):
- 判别式模型通过逻辑回归(Logistic Regression)建模条件概率 P ( K ∣ X ) P(K|X) P(K∣X)。
-
p
i
k
p_{ik}
pik 表示样本
i
i
i 属于聚类
k
k
k 的后验概率,形式为:
p i k ∝ exp ( θ k T z i + b k ) p_{ik} \propto \exp(\theta_k^T z_i + b_k) pik∝exp(θkTzi+bk) - 这是一个softmax 形式的概率分布, ∝ \propto ∝ 表示比例关系,最终需要归一化以满足概率分布的要求。
-
公式中的符号:
- p i k p_{ik} pik:样本 i i i 属于聚类 k k k 的概率。
- z i z_i zi:样本 i i i 的嵌入表示,由深度网络 ϕ W ( x i ) \phi_{\mathcal{W}}(x_i) ϕW(xi) 生成。
- θ k \theta_k θk:聚类 k k k 对应的权重向量。
- b k b_k bk:聚类 k k k 对应的偏置(bias)。
- θ k T z i + b k \theta_k^T z_i + b_k θkTzi+bk:线性组合,表示样本 i i i 与聚类 k k k 的“匹配得分”。
-
归一化后的完整形式:
- 为了确保
p
i
k
p_{ik}
pik 是一个有效的概率分布(满足
∑
k
=
1
K
p
i
k
=
1
\sum_{k=1}^K p_{ik} = 1
∑k=1Kpik=1),需要对指数项进行归一化:
p i k = exp ( θ k T z i + b k ) ∑ j = 1 K exp ( θ j T z i + b j ) p_{ik} = \frac{\exp(\theta_k^T z_i + b_k)}{\sum_{j=1}^K \exp(\theta_j^T z_i + b_j)} pik=∑j=1Kexp(θjTzi+bj)exp(θkTzi+bk) - 这是 softmax 函数的经典形式,广泛用于多分类问题中。
- 为了确保
p
i
k
p_{ik}
pik 是一个有效的概率分布(满足
∑
k
=
1
K
p
i
k
=
1
\sum_{k=1}^K p_{ik} = 1
∑k=1Kpik=1),需要对指数项进行归一化:
-
直观理解:
- exp ( θ k T z i + b k ) \exp(\theta_k^T z_i + b_k) exp(θkTzi+bk) 是一个非负的得分,分数越高,样本 i i i 越可能属于聚类 k k k。
- 归一化后, p i k p_{ik} pik 是一个概率值,表示样本 i i i 属于聚类 k k k 的可能性。
- θ k \theta_k θk 和 b k b_k bk 是可学习的参数,通过优化目标函数(如互信息或 KL 散度)来调整,使得 p i k p_{ik} pik 能更好地反映样本的聚类分配。
2. 理论背景:为什么用这种形式?
-
逻辑回归的适用性:
- 逻辑回归是一种经典的判别式模型,适用于多分类任务(如聚类)。
- 在深度聚类中,逻辑回归后验可以无缝嵌入深度网络,通过梯度下降优化参数 θ k \theta_k θk、 b k b_k bk 和 W \mathcal{W} W。
-
深度聚类的目标:
- 深度聚类希望联合学习特征嵌入 z i z_i zi 和聚类分配 p i k p_{ik} pik。
- 逻辑回归后验 p i k p_{ik} pik 提供了概率化的分配方式,便于优化目标函数(如互信息 I ( X , K ) \mathcal{I}(\mathrm{X}, \mathrm{K}) I(X,K) 或 KL 散度)。
-
与论文的联系:
- 论文《Deep clustering: On the link between discriminative models and K-means》中,判别式模型(如 DEPICT、MI-ADM)使用这种逻辑回归后验来建模 P ( K ∣ X ) P(K|X) P(K∣X)。
- 这种形式还被用来证明判别式模型与 K-means 的等价性(Proposition 2),因为它可以被转化为软 K-means 形式。
3. 详细例子:基于 MNIST 数据集的深度聚类
为了更直观地理解这句话,我们通过一个具体的例子来说明:使用深度聚类模型对 MNIST 数据集(手写数字 0-9)进行聚类。
3.1 问题背景
- 数据集:MNIST 包含 70,000 张 28x28 像素的手写数字图像,分为 10 个类别(0-9)。我们假设没有标签(无监督聚类任务),目标是将图像分为 10 个聚类。
- 任务:使用判别式深度聚类模型,学习特征嵌入 z i z_i zi 并预测每个图像的聚类概率 p i k p_{ik} pik。
3.2 模型设置
-
深度网络:
- 输入: x i ∈ R 28 × 28 x_i \in \mathbb{R}^{28 \times 28} xi∈R28×28,是第 i i i 张图像(展平后为 784 维向量)。
- 深度网络
ϕ
W
\phi_{\mathcal{W}}
ϕW:一个卷积神经网络(CNN)或全连接网络,将
x
i
x_i
xi 映射到嵌入空间
z
i
∈
R
d
z_i \in \mathbb{R}^{d}
zi∈Rd(假设
d
=
10
d = 10
d=10)。
- 例如:CNN 包含 2 个卷积层(带 ReLU 激活)和 1 个全连接层,最终输出 10 维嵌入 z i z_i zi。
-
逻辑回归后验:
- 聚类数 K = 10 K = 10 K=10(对应 0-9 数字)。
- 对于每个聚类 k k k( k = 1 , … , 10 k = 1, \ldots, 10 k=1,…,10),有权重 θ k ∈ R 10 \theta_k \in \mathbb{R}^{10} θk∈R10 和偏置 b k ∈ R b_k \in \mathbb{R} bk∈R。
- 后验概率:
p i k = exp ( θ k T z i + b k ) ∑ j = 1 10 exp ( θ j T z i + b j ) p_{ik} = \frac{\exp(\theta_k^T z_i + b_k)}{\sum_{j=1}^{10} \exp(\theta_j^T z_i + b_j)} pik=∑j=110exp(θjTzi+bj)exp(θkTzi+bk)
-
目标函数:
- 假设使用互信息目标(类似论文中的 MI-ADM):
I ( X , K ) = H ( K ) − H ( K ∣ X ) \mathcal{I}(\mathrm{X}, \mathrm{K}) = \mathcal{H}(\mathrm{K}) - \mathcal{H}(\mathrm{K} \mid \mathrm{X}) I(X,K)=H(K)−H(K∣X)
其中:- H ( K ) = − ∑ k = 1 10 p ^ k log ( p ^ k ) \mathcal{H}(\mathrm{K}) = -\sum_{k=1}^{10} \hat{p}_k \log (\hat{p}_k) H(K)=−∑k=110p^klog(p^k), p ^ k = 1 N ∑ i = 1 N p i k \hat{p}_k = \frac{1}{N} \sum_{i=1}^N p_{ik} p^k=N1∑i=1Npik。
- H ( K ∣ X ) = − 1 N ∑ i = 1 N ∑ k = 1 10 p i k log ( p i k ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) = -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^{10} p_{ik} \log (p_{ik}) H(K∣X)=−N1∑i=1N∑k=110piklog(pik)。
- 假设使用互信息目标(类似论文中的 MI-ADM):
-
优化目标:
- 互信息目标会利用 p i k p_{ik} pik 计算 H ( K ∣ X ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) H(K∣X) 和 H ( K ) \mathcal{H}(\mathrm{K}) H(K),并通过梯度下降更新 θ k \theta_k θk、 b k b_k bk 和 W \mathcal{W} W,使 p i 3 p_{i3} pi3 进一步增大。
- 例如, H ( K ∣ X ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) H(K∣X) 鼓励 p i k p_{ik} pik 更“尖锐”(接近 one-hot 分布),从而减少不确定性。
3. 优化过程
- 交替优化(类似论文中的 ADM):
- 更新网络参数 W \mathcal{W} W:固定 p i k p_{ik} pik,优化嵌入 z i z_i zi,使特征更适合聚类。
- 更新分类参数 θ k , b k \theta_k, b_k θk,bk:固定 z i z_i zi,调整 θ k \theta_k θk 和 b k b_k bk,使 p i k p_{ik} pik 更准确。
- 结果:经过多轮迭代,模型可能学习到:
- 聚类 3 对应数字“3”, p i 3 p_{i3} pi3 接近 1。
- 嵌入空间 z i z_i zi 将不同数字的图像分离开来。
4. 直观理解
-
逻辑回归后验的作用:
- θ k T z i + b k \theta_k^T z_i + b_k θkTzi+bk 可以看作样本 i i i 与聚类 k k k 的“相似度得分”。
- 指数化 exp ( ⋅ ) \exp(\cdot) exp(⋅) 将得分映射为正值,分数越高,概率越大。
- 归一化(softmax)确保 p i k p_{ik} pik 是概率分布,反映样本 i i i 属于各聚类的相对可能性。
-
深度聚类的优势:
- 嵌入 z i z_i zi 由深度网络学习,能捕捉高维数据(如图像)的复杂特征。
- p i k p_{ik} pik 的概率化分配支持软聚类,便于优化(如通过互信息或 KL 散度)。
-
与 K-means 的联系:
- 论文证明,这种逻辑回归后验在特定条件下( L 2 L_2 L2 正则化和平衡假设)等价于软 K-means 损失。
- 直观上, θ k \theta_k θk 可以看作聚类中心, θ k T z i \theta_k^T z_i θkTzi 类似 z i z_i zi 到中心的距离(经过变换)。