PAMI-2021《Deep clustering On the link between discriminative models and K-means》


推荐一个机器学习前沿公众号,第一时间获取最有价值的前沿机器学习文章。

在这里插入图片描述

以下是对论文《Deep clustering: On the link between discriminative models and K-means》的核心思想、目标函数及其优化过程的详细分析,并针对目标函数的局限性提出改进建议。


1. 论文的核心思想

核心思想
论文揭示了深度聚类中看似不同的判别式模型(discriminative models)和生成式模型(如 K-means)之间的理论联系,证明了在特定条件下(例如使用逻辑回归后验和 L 2 L_2 L2 正则化),判别式模型(如基于互信息或 KL 散度的模型)与 K-means 等价。这种联系不仅弥合了两种聚类方法的理论差距,还通过推导提出了一种新的软正则化 K-means 算法(SR-K-means),并在图像聚类任务上取得了与判别式模型相当的性能。

  • 背景
    深度聚类近年来发展迅速,判别式模型(如 DEPICT、MI-based 模型)因其灵活性和较少的分布假设,通常比生成式模型(如 K-means)表现更好。然而,论文指出,尽管表面上判别式模型和 K-means 似乎无关,但它们在数学上可以通过特定的优化方法(如交替方向法 ADM)建立等价关系。

  • 意义

    • 理论上:通过数学推导,证明了判别式模型(如基于互信息的 MI-ADM 和基于 KL 散度的 DEPICT)可以转化为软正则化的 K-means 形式,统一了两种聚类范式。
    • 实践上:基于这一理论联系,提出了 SR-K-means 算法,结合深度神经网络和重建损失,实现了高效的深度聚类。

2. 目标函数

论文中讨论了多个目标函数,主要集中在判别式模型的互信息(MI)和 KL 散度,以及生成式模型的 K-means 损失。以下是主要目标函数及其形式:

2.1 判别式模型的目标函数
  • 互信息(MI)目标函数(Section 2.1):
    互信息用于衡量输入数据 X X X 和潜在聚类标签 K K K 之间的依赖性,目标是最大化互信息:
    I ( X , K ) = H ( K ) − H ( K ∣ X ) \mathcal{I}(\mathrm{X}, \mathrm{K}) = \mathcal{H}(\mathrm{K}) - \mathcal{H}(\mathrm{K} \mid \mathrm{X}) I(X,K)=H(K)H(KX)
    其中:

    • H ( K ) = − ∑ k = 1 K p ^ k log ⁡ ( p ^ k ) \mathcal{H}(\mathrm{K}) = -\sum_{k=1}^K \hat{p}_k \log (\hat{p}_k) H(K)=k=1Kp^klog(p^k) 是标签的边缘熵, p ^ k = 1 N ∑ i = 1 N p i k \hat{p}_k = \frac{1}{N} \sum_{i=1}^N p_{i k} p^k=N1i=1Npik 是标签的边缘分布。
    • H ( K ∣ X ) = − 1 N ∑ i = 1 N ∑ k = 1 K p i k log ⁡ ( p i k ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) = -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K p_{i k} \log (p_{i k}) H(KX)=N1i=1Nk=1Kpiklog(pik) 是条件熵, p i k p_{i k} pik 是后验概率。
    • p i k p_{i k} pik 通常使用逻辑回归形式建模:
      p i k ∝ exp ⁡ ( θ k T z i + b k ) p_{i k} \propto \exp(\boldsymbol{\theta}_k^T \boldsymbol{z}_i + b_k) pikexp(θkTzi+bk)
      其中 z i = ϕ W ( x i ) \boldsymbol{z}_i = \phi_{\mathcal{W}}(\boldsymbol{x}_i) zi=ϕW(xi) 是深度网络的嵌入, W \mathcal{W} W 是网络参数, θ k \boldsymbol{\theta}_k θk b k b_k bk 是分类器的权重和偏置。
  • KL 散度目标函数(DEPICT 模型)(Section 2.2):
    DEPICT 模型通过引入辅助目标分布 Q \boldsymbol{Q} Q,最小化 KL 散度来优化聚类:
    min ⁡ Φ , Q K L ( Q ∥ P ) + γ ∑ k = 1 K q ^ k log ⁡ ( q ^ k ) s.t. q i T 1 = 1 , q i ≥ 0 ∀ i \min_{\Phi, \boldsymbol{Q}} \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) + \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) \quad \text{s.t.} \quad \boldsymbol{q}_i^T \mathbf{1} = 1, \boldsymbol{q}_i \geq 0 \forall i Φ,QminKL(QP)+γk=1Kq^klog(q^k)s.t.qiT1=1,qi0∀i
    其中:

    • K L ( Q ∥ P ) = 1 N ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( q i k p i k ) \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) = \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log \left(\frac{q_{i k}}{p_{i k}}\right) KL(QP)=N1i=1Nk=1Kqiklog(pikqik) 是辅助分布 Q \boldsymbol{Q} Q 和后验 P \boldsymbol{P} P 之间的 KL 散度。
    • γ ∑ k = 1 K q ^ k log ⁡ ( q ^ k ) \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) γk=1Kq^klog(q^k) 是平衡项,促进聚类分配的均匀性。
    • Φ = { O , W } \Phi = \{\mathcal{O}, \mathcal{W}\} Φ={O,W} 包含分类器参数 O \mathcal{O} O 和网络参数 W \mathcal{W} W
  • 正则化互信息目标函数(Section 3):
    为了建立与 K-means 的联系,论文在互信息目标上加入 L 2 L_2 L2 正则化项:
    I ( X , K ) − λ ∑ k = 1 K θ k T θ k \mathcal{I}(\mathrm{X}, \mathrm{K}) - \lambda \sum_{k=1}^K \boldsymbol{\theta}_k^T \boldsymbol{\theta}_k I(X,K)λk=1KθkTθk
    其中 λ \lambda λ 是正则化参数。

2.2 生成式模型的目标函数(K-means)
  • 标准 K-means 目标函数(Section 3):
    传统的 K-means 目标是最小化数据点到聚类中心的距离:
    ∑ i = 1 N ∑ k = 1 K s i k ∥ z i − μ k ∥ 2 s.t. ∑ k = 1 K s i k = 1 , s i k ∈ { 0 , 1 } ∀ i , k \sum_{i=1}^N \sum_{k=1}^K s_{i k} \|\boldsymbol{z}_i - \boldsymbol{\mu}_k\|^2 \quad \text{s.t.} \quad \sum_{k=1}^K s_{i k} = 1, \quad s_{i k} \in \{0, 1\} \forall i, k i=1Nk=1Ksikziμk2s.t.k=1Ksik=1,sik{0,1}i,k
    其中 s i k s_{i k} sik 是二元分配变量, μ k \boldsymbol{\mu}_k μk 是聚类中心。

  • 软正则化 K-means(SR-K-means)目标函数(Section 3):
    论文证明了正则化互信息的优化等价于以下软正则化 K-means 损失:
    ∑ i = 1 N ∑ k = 1 K q i k ∥ z i − θ k ′ ∥ 2 + λ K ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( q i k ) − ∑ i = 1 N z i T z i \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 + \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) - \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i i=1Nk=1Kqikziθk2+λKi=1Nk=1Kqiklog(qik)i=1NziTzi
    其中:

    • θ k ′ = ∑ i = 1 N q i k z i ∑ i = 1 N q i k \boldsymbol{\theta}_k^{\prime} = \frac{\sum_{i=1}^N q_{i k} \boldsymbol{z}_i}{\sum_{i=1}^N q_{i k}} θk=i=1Nqiki=1Nqikzi 是软聚类中心。
    • q i k ∈ [ 0 , 1 ] q_{i k} \in [0, 1] qik[0,1] 是软分配变量,替代了硬分配 s i k s_{i k} sik
    • 第二项 λ K ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( q i k ) \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) λKi=1Nk=1Kqiklog(qik) 是负熵项,促进分配的软性。
2.3 附加的重建损失

为了防止嵌入过拟合,所有模型都结合了重建损失(Section 5.1):
R ( Z ) = 1 N ∑ i = 1 N ∑ l = 0 L − 1 1 ∣ z i l ∣ ∥ z i l − z ^ i l ∥ 2 \mathcal{R}(\mathcal{Z}) = \frac{1}{N} \sum_{i=1}^N \sum_{l=0}^{L-1} \frac{1}{|\boldsymbol{z}_i^l|} \|\boldsymbol{z}_i^l - \hat{\boldsymbol{z}}_i^l\|^2 R(Z)=N1i=1Nl=0L1zil1zilz^il2
其中 z i l \boldsymbol{z}_i^l zil z ^ i l \hat{\boldsymbol{z}}_i^l z^il 分别是第 l l l 层的干净嵌入和重建嵌入。

综合来看,实际优化目标通常是聚类损失(如互信息或 K-means 损失)与重建损失的组合,例如 SR-K-means 的目标为:
min ⁡ W 1 N λ K ∑ i = 1 N ∑ k = 1 K q i k ∥ z i − θ k ′ ∥ 2 − 1 N λ K ∑ i = 1 N z i T z i + R ( Z ) \min_{\mathcal{W}} \frac{1}{N \lambda K} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 - \frac{1}{N \lambda K} \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i + \mathcal{R}(\mathcal{Z}) WminNλK1i=1Nk=1Kqikziθk2NλK1i=1NziTzi+R(Z)


3. 目标函数的优化过程
3.1 互信息目标的优化(MI-ADM)
  • 方法:使用交替方向法(ADM)优化正则化互信息:
    max ⁡ Φ , Q 1 N ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( p i k ) − ∑ k = 1 K q ^ k log ⁡ ( q ^ k ) s.t. Q = P , q i T 1 = 1 , q i ≥ 0 \max_{\Phi, \boldsymbol{Q}} \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (p_{i k}) - \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) \quad \text{s.t.} \quad \boldsymbol{Q} = \boldsymbol{P}, \quad \boldsymbol{q}_i^T \mathbf{1} = 1, \quad \boldsymbol{q}_i \geq 0 Φ,QmaxN1i=1Nk=1Kqiklog(pik)k=1Kq^klog(q^k)s.t.Q=P,qiT1=1,qi0
    通过引入 KL 散度惩罚,将约束问题转化为无约束问题:
    max ⁡ Φ , Q 1 N ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( p i k ) − ∑ k = 1 K q ^ k log ⁡ ( q ^ k ) − K L ( Q ∥ P ) \max_{\Phi, \boldsymbol{Q}} \frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (p_{i k}) - \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) - \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) Φ,QmaxN1i=1Nk=1Kqiklog(pik)k=1Kq^klog(q^k)KL(QP)

  • 优化步骤

    1. 参数学习步(更新 Φ \Phi Φ:固定 Q \boldsymbol{Q} Q,优化网络参数 Φ \Phi Φ,等价于最小化交叉熵损失:
      min ⁡ Φ − 1 N ∑ i = 1 N ∑ k = 1 K q i k log ⁡ p i k \min_{\Phi} -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log p_{i k} ΦminN1i=1Nk=1Kqiklogpik
      使用 SGD(Adam 优化器)更新 Φ \Phi Φ
    2. 目标估计步(更新 Q \boldsymbol{Q} Q:固定 Φ \Phi Φ,优化 Q \boldsymbol{Q} Q,得到闭式解:
      q i k ∝ p i k 2 ( ∑ i ′ = 1 N p i ′ k 2 ) 1 / 2 q_{i k} \propto \frac{p_{i k}^2}{\left(\sum_{i^{\prime}=1}^N p_{i^{\prime} k}^2\right)^{1/2}} qik(i=1Npik2)1/2pik2
  • 理论保证(Section 4.1):
    Proposition 3 证明了 ADM 优化的单调性,即互信息 I ( X , K ) \mathcal{I}(\mathrm{X}, \mathrm{K}) I(X,K) 在每次迭代中不减少。

3.2 KL 散度目标的优化(DEPICT)
  • 方法:通过交替优化解决:
    min ⁡ Φ , Q K L ( Q ∥ P ) + γ ∑ k = 1 K q ^ k log ⁡ ( q ^ k ) \min_{\Phi, \boldsymbol{Q}} \mathrm{KL}(\boldsymbol{Q} \| \boldsymbol{P}) + \gamma \sum_{k=1}^K \hat{q}_k \log (\hat{q}_k) Φ,QminKL(QP)+γk=1Kq^klog(q^k)

  • 优化步骤

    1. 参数学习步:固定 Q \boldsymbol{Q} Q,优化 Φ \Phi Φ,等价于交叉熵损失(同上)。
    2. 目标估计步:固定 Φ \Phi Φ,优化 Q \boldsymbol{Q} Q,得到:
      q i k ∝ p i k ( ∑ i ′ = 1 N p i ′ k ) 1 / 2 q_{i k} \propto \frac{p_{i k}}{\left(\sum_{i^{\prime}=1}^N p_{i^{\prime} k}\right)^{1/2}} qik(i=1Npik)1/2pik
  • 与 MI 的联系(Proposition 1):
    DEPICT 的优化可以看作是互信息最大化的近似 ADM 解法。

3.3 SR-K-means 目标的优化
  • 方法:优化软正则化 K-means 损失:
    ∑ i = 1 N ∑ k = 1 K q i k ∥ z i − θ k ′ ∥ 2 + λ K ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( q i k ) − ∑ i = 1 N z i T z i + R ( Z ) \sum_{i=1}^N \sum_{k=1}^K q_{i k} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2 + \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) - \sum_{i=1}^N \boldsymbol{z}_i^T \boldsymbol{z}_i + \mathcal{R}(\mathcal{Z}) i=1Nk=1Kqikziθk2+λKi=1Nk=1Kqiklog(qik)i=1NziTzi+R(Z)

  • 优化步骤

    1. 更新聚类中心 θ k ′ \boldsymbol{\theta}_k^{\prime} θk
      θ k ′ = ∑ i = 1 N q i k z i ∑ i = 1 N q i k \boldsymbol{\theta}_k^{\prime} = \frac{\sum_{i=1}^N q_{i k} \boldsymbol{z}_i}{\sum_{i=1}^N q_{i k}} θk=i=1Nqiki=1Nqikzi
    2. 更新分配变量 q i k q_{i k} qik
      q i k ∝ exp ⁡ ( − 1 λ K ∥ z i − θ k ′ ∥ 2 ) q_{i k} \propto \exp \left(-\frac{1}{\lambda K} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2\right) qikexp(λK1ziθk2)
    3. 更新网络参数 W \mathcal{W} W:通过 SGD 优化嵌入 z i \boldsymbol{z}_i zi 和重建损失 R ( Z ) \mathcal{R}(\mathcal{Z}) R(Z)

4. 主要贡献点
  1. 理论联系

    • 证明了判别式模型(基于互信息和 KL 散度)和 K-means 的等价性(Proposition 2)。
    • 建立了 DEPICT 和互信息目标之间的联系(Proposition 1),表明 DEPICT 是互信息最大化的近似 ADM 解法。
  2. 新算法提出

    • 基于理论推导,提出了软正则化 K-means(SR-K-means)算法,结合深度网络和重建损失,提升了生成式模型的性能。
  3. 实验验证

    • 在多个图像聚类基准数据集(如 MNIST、USPS、YTF)上验证了 SR-K-means 和 MI-ADM 的性能,与 DEPICT 等判别式模型相当。
    • SR-K-means 比传统硬 K-means(DCN)提高了 11% 的性能(Table 2)。
  4. 优化分析

    • 证明了 ADM 优化的单调性(Proposition 3)。
    • 分析了 KL 散度和二次惩罚之间的关系,指出 KL 散度在单纯形约束下的计算优势。

5. 针对目标函数的局限性提出改进意见
5.1 目标函数的局限性
  1. 对平衡聚类的假设

    • SR-K-means 的推导假设了聚类分配是平衡的( q ^ k ≈ 1 K \hat{q}_k \approx \frac{1}{K} q^kK1),但实际数据(如 YTF、FRGC)往往是不平衡的(Table 1),这可能导致性能下降。
    • 改进建议:引入更灵活的平衡约束,例如使用可调的先验分布 d ^ k \hat{d}_k d^k(如论文中提到的 K L ( ( p ^ k ) ∥ ( d ^ k ) ) \mathrm{KL}((\hat{p}_k) \| (\hat{d}_k)) KL((p^k)(d^k))),以适应不平衡数据。
  2. 逻辑回归后验的限制

    • 目标函数依赖逻辑回归后验 p i k ∝ exp ⁡ ( θ k T z i + b k ) p_{i k} \propto \exp(\boldsymbol{\theta}_k^T \boldsymbol{z}_i + b_k) pikexp(θkTzi+bk),假设嵌入空间中的类别边界是线性的,可能无法捕捉复杂的非线性分布。
    • 改进建议:尝试更复杂的后验模型,例如使用核方法(如高斯核)或深度网络直接输出后验分布,以捕捉非线性关系。
  3. 软分配的熵正则化

    • SR-K-means 中的负熵项 λ K ∑ i = 1 N ∑ k = 1 K q i k log ⁡ ( q i k ) \lambda K \sum_{i=1}^N \sum_{k=1}^K q_{i k} \log (q_{i k}) λKi=1Nk=1Kqiklog(qik) 促进软分配,但可能导致分配过于平滑,削弱聚类的区分性。
    • 改进建议:引入可调的温度参数 τ \tau τ,修改软分配公式为:
      q i k ∝ exp ⁡ ( − 1 τ ∥ z i − θ k ′ ∥ 2 ) q_{i k} \propto \exp \left(-\frac{1}{\tau} \|\boldsymbol{z}_i - \boldsymbol{\theta}_k^{\prime}\|^2\right) qikexp(τ1ziθk2)
      通过调整 τ \tau τ 控制分配的软性程度。
  4. 对高维和流形结构的适应性不足

    • K-means 和 SR-K-means 依赖欧几里得距离,难以处理高维数据中的流形结构(Section 6)。
    • 改进建议:如论文结尾建议,探索其他原型方法(如 K-modes)或成对聚类目标(如归一化割 normalized cut),以更好地捕捉数据的流形结构。
  5. 优化方法的局限性

    • ADM 优化依赖交替更新,可能陷入局部最优,且对初始值敏感(Section 5.2.4)。
    • 改进建议:引入全局优化技术(如模拟退火)或多样化的初始化策略,减少对初始值的依赖。此外,可以尝试其他距离度量(如 Bhattacharyya 距离,Section 6)替代 KL 散度,改善收敛性。

总结
  • 核心思想:论文通过数学推导揭示了判别式模型和 K-means 的等价性,统一了深度聚类的两种范式,并提出了 SR-K-means 算法。
  • 目标函数:包括互信息、KL 散度和软正则化 K-means 损失,通过 ADM 或交替优化实现联合特征学习和聚类。
  • 主要贡献:理论联系、新算法、实验验证和优化分析。
  • 改进建议:针对不平衡数据、非线性分布、软分配平滑、高维流形结构和优化局限性,提出了引入灵活约束、更复杂后验模型、可调软分配、流形目标和全局优化的改进方向。
1. 理解这句话的含义

原句

  • 论文中判别式深度聚类模型使用逻辑回归后验 P ( K ∣ X ) P(K|X) P(KX)
    p i k ∝ exp ⁡ ( θ k T z i + b k ) p_{ik} \propto \exp(\theta_k^T z_i + b_k) pikexp(θkTzi+bk)

逐部分拆解

  1. 判别式深度聚类模型

    • 判别式模型关注条件概率 P ( K ∣ X ) P(K|X) P(KX),即给定输入 X X X,预测其属于某个聚类 K K K 的概率。
    • 在深度聚类中,输入 X X X 通常通过深度神经网络(DNN)映射到嵌入空间 Z Z Z,即 z i = ϕ W ( x i ) z_i = \phi_{\mathcal{W}}(x_i) zi=ϕW(xi),其中 W \mathcal{W} W 是网络参数。
    • 目标是同时学习嵌入 z i z_i zi 和聚类分配 p i k p_{ik} pik(样本 i i i 属于聚类 k k k 的概率)。
  2. 逻辑回归后验 P ( K ∣ X ) P(K|X) P(KX)

    • 判别式模型通过逻辑回归(Logistic Regression)建模条件概率 P ( K ∣ X ) P(K|X) P(KX)
    • p i k p_{ik} pik 表示样本 i i i 属于聚类 k k k 的后验概率,形式为:
      p i k ∝ exp ⁡ ( θ k T z i + b k ) p_{ik} \propto \exp(\theta_k^T z_i + b_k) pikexp(θkTzi+bk)
    • 这是一个softmax 形式的概率分布, ∝ \propto 表示比例关系,最终需要归一化以满足概率分布的要求。
  3. 公式中的符号

    • p i k p_{ik} pik:样本 i i i 属于聚类 k k k 的概率。
    • z i z_i zi:样本 i i i 的嵌入表示,由深度网络 ϕ W ( x i ) \phi_{\mathcal{W}}(x_i) ϕW(xi) 生成。
    • θ k \theta_k θk:聚类 k k k 对应的权重向量。
    • b k b_k bk:聚类 k k k 对应的偏置(bias)。
    • θ k T z i + b k \theta_k^T z_i + b_k θkTzi+bk:线性组合,表示样本 i i i 与聚类 k k k 的“匹配得分”。
  4. 归一化后的完整形式

    • 为了确保 p i k p_{ik} pik 是一个有效的概率分布(满足 ∑ k = 1 K p i k = 1 \sum_{k=1}^K p_{ik} = 1 k=1Kpik=1),需要对指数项进行归一化:
      p i k = exp ⁡ ( θ k T z i + b k ) ∑ j = 1 K exp ⁡ ( θ j T z i + b j ) p_{ik} = \frac{\exp(\theta_k^T z_i + b_k)}{\sum_{j=1}^K \exp(\theta_j^T z_i + b_j)} pik=j=1Kexp(θjTzi+bj)exp(θkTzi+bk)
    • 这是 softmax 函数的经典形式,广泛用于多分类问题中。
  5. 直观理解

    • exp ⁡ ( θ k T z i + b k ) \exp(\theta_k^T z_i + b_k) exp(θkTzi+bk) 是一个非负的得分,分数越高,样本 i i i 越可能属于聚类 k k k
    • 归一化后, p i k p_{ik} pik 是一个概率值,表示样本 i i i 属于聚类 k k k 的可能性。
    • θ k \theta_k θk b k b_k bk 是可学习的参数,通过优化目标函数(如互信息或 KL 散度)来调整,使得 p i k p_{ik} pik 能更好地反映样本的聚类分配。

2. 理论背景:为什么用这种形式?
  • 逻辑回归的适用性

    • 逻辑回归是一种经典的判别式模型,适用于多分类任务(如聚类)。
    • 在深度聚类中,逻辑回归后验可以无缝嵌入深度网络,通过梯度下降优化参数 θ k \theta_k θk b k b_k bk W \mathcal{W} W
  • 深度聚类的目标

    • 深度聚类希望联合学习特征嵌入 z i z_i zi 和聚类分配 p i k p_{ik} pik
    • 逻辑回归后验 p i k p_{ik} pik 提供了概率化的分配方式,便于优化目标函数(如互信息 I ( X , K ) \mathcal{I}(\mathrm{X}, \mathrm{K}) I(X,K) 或 KL 散度)。
  • 与论文的联系

    • 论文《Deep clustering: On the link between discriminative models and K-means》中,判别式模型(如 DEPICT、MI-ADM)使用这种逻辑回归后验来建模 P ( K ∣ X ) P(K|X) P(KX)
    • 这种形式还被用来证明判别式模型与 K-means 的等价性(Proposition 2),因为它可以被转化为软 K-means 形式。

3. 详细例子:基于 MNIST 数据集的深度聚类

为了更直观地理解这句话,我们通过一个具体的例子来说明:使用深度聚类模型对 MNIST 数据集(手写数字 0-9)进行聚类。

3.1 问题背景
  • 数据集:MNIST 包含 70,000 张 28x28 像素的手写数字图像,分为 10 个类别(0-9)。我们假设没有标签(无监督聚类任务),目标是将图像分为 10 个聚类。
  • 任务:使用判别式深度聚类模型,学习特征嵌入 z i z_i zi 并预测每个图像的聚类概率 p i k p_{ik} pik
3.2 模型设置
  1. 深度网络

    • 输入: x i ∈ R 28 × 28 x_i \in \mathbb{R}^{28 \times 28} xiR28×28,是第 i i i 张图像(展平后为 784 维向量)。
    • 深度网络 ϕ W \phi_{\mathcal{W}} ϕW:一个卷积神经网络(CNN)或全连接网络,将 x i x_i xi 映射到嵌入空间 z i ∈ R d z_i \in \mathbb{R}^{d} ziRd(假设 d = 10 d = 10 d=10)。
      • 例如:CNN 包含 2 个卷积层(带 ReLU 激活)和 1 个全连接层,最终输出 10 维嵌入 z i z_i zi
  2. 逻辑回归后验

    • 聚类数 K = 10 K = 10 K=10(对应 0-9 数字)。
    • 对于每个聚类 k k k k = 1 , … , 10 k = 1, \ldots, 10 k=1,,10),有权重 θ k ∈ R 10 \theta_k \in \mathbb{R}^{10} θkR10 和偏置 b k ∈ R b_k \in \mathbb{R} bkR
    • 后验概率:
      p i k = exp ⁡ ( θ k T z i + b k ) ∑ j = 1 10 exp ⁡ ( θ j T z i + b j ) p_{ik} = \frac{\exp(\theta_k^T z_i + b_k)}{\sum_{j=1}^{10} \exp(\theta_j^T z_i + b_j)} pik=j=110exp(θjTzi+bj)exp(θkTzi+bk)
  3. 目标函数

    • 假设使用互信息目标(类似论文中的 MI-ADM):
      I ( X , K ) = H ( K ) − H ( K ∣ X ) \mathcal{I}(\mathrm{X}, \mathrm{K}) = \mathcal{H}(\mathrm{K}) - \mathcal{H}(\mathrm{K} \mid \mathrm{X}) I(X,K)=H(K)H(KX)
      其中:
      • H ( K ) = − ∑ k = 1 10 p ^ k log ⁡ ( p ^ k ) \mathcal{H}(\mathrm{K}) = -\sum_{k=1}^{10} \hat{p}_k \log (\hat{p}_k) H(K)=k=110p^klog(p^k) p ^ k = 1 N ∑ i = 1 N p i k \hat{p}_k = \frac{1}{N} \sum_{i=1}^N p_{ik} p^k=N1i=1Npik
      • H ( K ∣ X ) = − 1 N ∑ i = 1 N ∑ k = 1 10 p i k log ⁡ ( p i k ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) = -\frac{1}{N} \sum_{i=1}^N \sum_{k=1}^{10} p_{ik} \log (p_{ik}) H(KX)=N1i=1Nk=110piklog(pik)
  4. 优化目标

    • 互信息目标会利用 p i k p_{ik} pik 计算 H ( K ∣ X ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) H(KX) H ( K ) \mathcal{H}(\mathrm{K}) H(K),并通过梯度下降更新 θ k \theta_k θk b k b_k bk W \mathcal{W} W,使 p i 3 p_{i3} pi3 进一步增大。
    • 例如, H ( K ∣ X ) \mathcal{H}(\mathrm{K} \mid \mathrm{X}) H(KX) 鼓励 p i k p_{ik} pik 更“尖锐”(接近 one-hot 分布),从而减少不确定性。
3. 优化过程
  • 交替优化(类似论文中的 ADM):
    1. 更新网络参数 W \mathcal{W} W:固定 p i k p_{ik} pik,优化嵌入 z i z_i zi,使特征更适合聚类。
    2. 更新分类参数 θ k , b k \theta_k, b_k θk,bk:固定 z i z_i zi,调整 θ k \theta_k θk b k b_k bk,使 p i k p_{ik} pik 更准确。
  • 结果:经过多轮迭代,模型可能学习到:
    • 聚类 3 对应数字“3”, p i 3 p_{i3} pi3 接近 1。
    • 嵌入空间 z i z_i zi 将不同数字的图像分离开来。

4. 直观理解
  • 逻辑回归后验的作用

    • θ k T z i + b k \theta_k^T z_i + b_k θkTzi+bk 可以看作样本 i i i 与聚类 k k k 的“相似度得分”。
    • 指数化 exp ⁡ ( ⋅ ) \exp(\cdot) exp() 将得分映射为正值,分数越高,概率越大。
    • 归一化(softmax)确保 p i k p_{ik} pik 是概率分布,反映样本 i i i 属于各聚类的相对可能性。
  • 深度聚类的优势

    • 嵌入 z i z_i zi 由深度网络学习,能捕捉高维数据(如图像)的复杂特征。
    • p i k p_{ik} pik 的概率化分配支持软聚类,便于优化(如通过互信息或 KL 散度)。
  • 与 K-means 的联系

    • 论文证明,这种逻辑回归后验在特定条件下( L 2 L_2 L2 正则化和平衡假设)等价于软 K-means 损失。
    • 直观上, θ k \theta_k θk 可以看作聚类中心, θ k T z i \theta_k^T z_i θkTzi 类似 z i z_i zi 到中心的距离(经过变换)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Christo3

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值