ICLR-2025《Towards Calibrated Deep Clustering Network》


推荐一个机器学习前沿公众号,第一时间获取最有价值的前沿机器学习文章。

在这里插入图片描述


核心思想

该论文提出了一种校准深度聚类网络(Calibrated Deep Clustering, CDC),旨在解决传统深度聚类方法中的过自信问题,即模型预测的置信度远超其实际预测准确率的问题。核心思想是通过引入一个双头网络结构(聚类头和校准头),实现置信度的校准和聚类性能的提升。校准头负责调整聚类头的过高置信度,使预测置信度与模型的实际学习状态相匹配,而聚类头利用校准头提供的可靠高置信度样本进行伪标签自训练。此外,论文提出了一种基于特征原型的网络初始化策略,增强训练速度和网络鲁棒性。实验表明,该方法在校准误差(ECE)和聚类准确率上均显著优于现有方法。

目标函数

CDC的目标函数由两部分组成,分别对应校准头和聚类头的优化:

  1. 校准头的目标函数
    校准头的目标是使模型的预测置信度与实际准确率对齐,同时避免所有样本被分配到同一簇。校准头的损失函数由两部分组成:

    • 校准损失( L cal \mathcal{L}_{\text{cal}} Lcal:通过K均值算法将特征嵌入分为 K K K个迷你簇,计算每个迷你簇在聚类头的平均预测 q ^ k = ∑ x i ∈ Q k p i clu ∣ Q k ∣ \hat{\boldsymbol{q}}_k = \frac{\sum_{\boldsymbol{x}_i \in Q_k} \boldsymbol{p}_i^{\text{clu}}}{|Q_k|} q^k=QkxiQkpiclu,并将其作为校准头中该迷你簇样本的目标分布。校准损失定义为:
      L cal = − 1 B ∑ k ∑ x i ∈ Q k q ^ k log ⁡ ( p i cal ) , \mathcal{L}_{\text{cal}} = -\frac{1}{B} \sum_k \sum_{\boldsymbol{x}_i \in Q_k} \hat{\boldsymbol{q}}_k \log \left( \boldsymbol{p}_i^{\text{cal}} \right), Lcal=B1kxiQkq^klog(pical),
      其中 p i cal = σ ( g ( θ cal ; f ( Θ ; x i ) ) ) \boldsymbol{p}_i^{\text{cal}} = \sigma\left(g\left(\boldsymbol{\theta}_{\text{cal}}; f\left(\boldsymbol{\Theta}; \boldsymbol{x}_i\right)\right)\right) pical=σ(g(θcal;f(Θ;xi)))是校准头的预测, B B B是批次大小。
    • 负熵损失( L en \mathcal{L}_{\text{en}} Len:用于使预测类分布更均匀,防止所有样本被分配到同一簇:
      L en = 1 C ∑ j = 1 C p c , j cal log ⁡ p c , j cal , \mathcal{L}_{\text{en}} = \frac{1}{C} \sum_{j=1}^C \boldsymbol{p}_{c,j}^{\text{cal}} \log \boldsymbol{p}_{c,j}^{\text{cal}}, Len=C1j=1Cpc,jcallogpc,jcal,
      其中 p c , j cal \boldsymbol{p}_{c,j}^{\text{cal}} pc,jcal是校准头对批次样本的第 j j j类预测。
    • 总校准损失
      L = L cal + w en L en , \mathcal{L} = \mathcal{L}_{\text{cal}} + w_{\text{en}} \mathcal{L}_{\text{en}}, L=Lcal+wenLen,
      其中 w en = 1 w_{\text{en}} = 1 wen=1为超参数。
  2. 聚类头的目标函数
    聚类头通过伪标签自训练优化,使用校准头提供的置信度动态选择高置信度样本。聚类头的损失函数为交叉熵损失:
    L clu = − 1 ∣ S ∣ ∑ x i ∈ S y i log ⁡ p i s,clu , \mathcal{L}_{\text{clu}} = -\frac{1}{|S|} \sum_{\boldsymbol{x}_i \in S} y_i \log \boldsymbol{p}_i^{\text{s,clu}}, Lclu=S1xiSyilogpis,clu,
    其中 S S S是选出的伪标签样本集合, y i = arg ⁡ max ⁡ p i w,cal y_i = \arg\max \boldsymbol{p}_i^{\text{w,cal}} yi=argmaxpiw,cal是伪标签, p i s,clu \boldsymbol{p}_i^{\text{s,clu}} pis,clu是聚类头对强增强样本的预测, ∣ S ∣ |S| S是伪标签样本数量。

目标函数的优化过程

CDC的优化过程分为以下步骤:

  1. 预训练

    • 使用MoCo-v2(一种自监督学习方法)预训练特征提取器 f ( Θ ; ⋅ ) f(\boldsymbol{\Theta}; \cdot) f(Θ;),以获得初始的判别性特征表示。
  2. 初始化

    • 对聚类头和校准头(均为三层MLP)进行基于特征原型的初始化:
      • 对特征 z \boldsymbol{z} z进行K均值聚类,得到 H H H个原型,初始化第一层权重 W ( 1 ) = Kmeans H ( z ) \boldsymbol{W}^{(1)} = \text{Kmeans}_H(\boldsymbol{z}) W(1)=KmeansH(z)
      • 对隐藏层输出 h \boldsymbol{h} h进行K均值聚类,初始化第二层权重 W ( 2 ) = Kmeans C ( h ) \boldsymbol{W}^{(2)} = \text{Kmeans}_C(\boldsymbol{h}) W(2)=KmeansC(h)
      • 对权重进行正交化以增强判别能力。
    • 这一初始化策略通过Proposition 1证明能有效传递预训练特征的判别性。
  3. 联合优化

    • 校准头优化
      • 在每个批次中,使用K均值将特征分为 K K K个迷你簇,计算 q ^ k \hat{\boldsymbol{q}}_k q^k
      • 根据 L cal \mathcal{L}_{\text{cal}} Lcal L en \mathcal{L}_{\text{en}} Len计算总损失 L \mathcal{L} L,仅优化校准头参数 θ cal \boldsymbol{\theta}_{\text{cal}} θcal(使用停止梯度策略,避免影响整个网络)。
    • 聚类头优化
      • 使用校准头的置信度 p i w,cal \boldsymbol{p}_i^{\text{w,cal}} piw,cal动态选择伪标签样本:
        • 对每个类 c c c,按置信度降序排序样本,选择前 ⌊ B / C ⌋ \lfloor B/C \rfloor B/C个样本构成 T O P ( c ) TOP(c) TOP(c)
        • 计算该类的伪标签样本数量 M ( c ) = ⌊ ∑ x i ∈ T O P ( c ) p i w,cal ⌋ M(c) = \lfloor \sum_{\boldsymbol{x}_i \in TOP(c)} \boldsymbol{p}_i^{\text{w,cal}} \rfloor M(c)=xiTOP(c)piw,cal
        • 选择前 M ( c ) M(c) M(c)个样本作为伪标签样本 S S S
      • 使用 L clu \mathcal{L}_{\text{clu}} Lclu优化特征提取器 Θ \boldsymbol{\Theta} Θ和聚类头参数 θ clu \boldsymbol{\theta}_{\text{clu}} θclu
    • 校准头和聚类头同时优化,相互协作:校准头提供可靠置信度,聚类头利用这些置信度选择高质量伪标签。
  4. 最终预测

    • 使用校准头的输出 p i cal = σ ( g ( θ cal ; f ( Θ ; x i ) ) ) \boldsymbol{p}_i^{\text{cal}} = \sigma\left(g\left(\boldsymbol{\theta}_{\text{cal}}; f\left(\boldsymbol{\Theta}; \boldsymbol{x}_i\right)\right)\right) pical=σ(g(θcal;f(Θ;xi)))作为最终聚类预测,因其校准误差(ECE)低于聚类头。

优化使用Adam优化器,训练100个epoch,学习率分别为编码器 5 × 1 0 − 5 5 \times 10^{-5} 5×105(CIFAR-20和Tiny-ImageNet调整为 1 0 − 5 10^{-5} 105)和MLP 1 0 − 4 10^{-4} 104

主要贡献点

  1. 提出校准深度聚类框架

    • 首次针对深度聚类的过自信问题,提出双头网络(聚类头和校准头),通过校准头调整置信度,使其与实际准确率对齐,显著降低校准误差(ECE降低约5倍)。
  2. 动态伪标签选择策略

    • 利用校准头的置信度动态选择伪标签样本,克服固定阈值带来的问题(如早期样本不足、后期噪声伪标签增加),并通过类特定阈值缓解类不平衡问题。
  3. 特征原型初始化策略

    • 提出基于K均值原型的初始化方法,将预训练特征的判别性传递到聚类头和校准头,显著提升初始聚类准确率(例如,CIFAR-20从10.4%提升到更高水平)。
  4. 理论保证

    • 通过Theorem 1证明校准方法仅对不可靠区域的置信度进行惩罚,保留高置信度样本的可靠性。
    • 通过Theorem 2证明校准头的ECE低于聚类头。
    • 通过Proposition 1证明初始化策略能有效传递特征判别性。
  5. 广泛实验验证

    • 在六个基准数据集(CIFAR-10、CIFAR-20、STL-10、ImageNet-10、ImageNet-Dogs、Tiny-ImageNet)上验证,CDC在聚类准确率(ACC)、归一化互信息(NMI)、调整兰德指数(ARI)和ECE上全面优于现有方法。
    • 展示了CDC作为自标签阶段的通用性,应用于SCAN和SeCu后进一步提升性能。
  6. 实际应用潜力

    • 校准后的置信度增强了模型在可信决策系统(如医疗诊断、自动驾驶)的适用性。
    • 改进的OOD检测能力(AUROC提升2.9%,FPR95降低19.5%)显示其在处理未知样本时的鲁棒性。

总结

该论文通过创新的双头网络结构、动态伪标签选择和特征原型初始化,成功解决了深度聚类的过自信问题,同时提升了聚类性能。其目标函数设计合理,优化过程高效,理论分析严谨,实验结果令人信服,为深度聚类领域提供了重要的理论和实践贡献。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Christo3

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值