Abstract
非监督图像聚类算法通常是提出一个辅助目标函数间接训练模型,并且聚类结果受到错误的预测和过于自信(overconfidence)的结果的影响,作者通过提出RUC (Robust learning for Unsupervised Clustering)模块解决这个问题,该模块将现有聚类算法生成的伪标签(可能会包含错误分类的样本)看作噪声样本,而它的重新训练过程可以纠正错误分类并缓解过度自信的问题。该模块可以作为其他聚类算法的附加模块用来提高精度
RUC主要由两个部分组成:1. extracting clean samples 2. retraining with the refined dataset
作者探索了confidence-based,metric-based,hybrid 三个策略用来过滤掉误分类的伪标签
- confidence-based
将原聚类算法给出高置信度的样本看作干净样本,剔除置信度小的样本
- metric-based
利用无监督embedding模块的相似度度量,使用非参数化分类器通过检查给定的实例与k-nearest样本的label是否相同来检测干净样本
- hybrid
同时根据两个方法筛选干净样本
紧接着作者使用半监督方法MixMatch来retrain模型,该方法主要将干净样本看作有标签数据,不干净样本看作无标签样本,同时还采用了smooth label。最后使用co-training模型减少训练过程中不干净样本的噪声积累,提高性能
Method
RUC模块的过程结构图如下
- Extracting Clean Samples
定义: 为训练数据集( 为图片, 为伪标签),数据集 可以被分为两个部分 ( 为干净数据)
- Confidence-based strategy
给定训练样本 如果 (即属于某个cluster的置信度高于某个阈值),则将其加入集合 否则加入集合 。通常阈值 设得很高,以消除尽可能多的不确定样本。
- Metric-based strategy
上面方法的缺陷在于其全部依赖于无监督的分类器,本方法利用通过无监督方法(如:SimCLR)训练的embedding网络 ,根据伪标签与使用 得到的分类结果的一致程度来衡量伪标签的可信度
对于每个 ,计算其embedding 并且使用基于kNN的无参数分类器得到 ,如果 ,则将其加入 ,否则加入
- Hybrid strategy
如果一个样本同时满足上两个方法,则加入 ,否则加入
2. Retraining via Robust Learning
给定 和 ,下一步是refine分类器 纠正原始无监督聚类算法的错误。
- Vanilla semi-supervised learning
作者使用MixMatch作为baseline,该算法从使用MixUp数据增强方法得到的无标签数据中估计低熵混合标签(low-entropy mixed label),具体来说,给定从有标签或无标签数据集中采样的一对样本 ,数据增强操作如下
MixMatch采用一个代理标签 ,其是锐化后多个增强图片上模型预测的平均
经过MixMatch得到 ,半监督模型存在两个独立的loss:1. 有标签数据集 上的交叉熵 2. 无标签数据集 的一致性正则化,以下是具体过程
其中 表示 和 之间的交叉熵
- Label Smoothing
在半监督学习模型上使用label smoothing改进模型的预测校准,其label smoothing通过混合均匀分布来实现。
其中 是类别数量, 为噪音
计算soft label 和随机增强后的强增强样本 的预测标签之间的交叉熵,我们发现,强增强可以使噪声样本的记忆最小化
则最终的训练优化目标为
- Co-training
单一的网络存在对不正确的伪标签过拟合的缺陷,因此加入co-training模块
模块中两个网络 ,它们平行训练并且通过在MixMatch基础上添加co-refinement来交换它们的guesses以便于相互teaching,其中co-refinement是标签refinement的过程,目标是通过合并两个网络的预测结果产生可靠的label。我们在 和 上都进行co-refinement操作,下面是从 的角度(给定一个样本 ,其原标签为 )展现co-refinement的过程
其中 为对立网络关于 的置信度, 为sharpen temperature
对于无标签数据集 ,使用两个网络预测的结果猜测样本 的伪标签
其中 为 的第 个弱增强样本
通过上面操作,co-refinement构建了refined数据集 ,代替了原始数据集 ,则将两个数据集作为MixMatch的输入
最后网络的优化目标为
即,将上面定义的 换成
- Co-refurbishing
在训练过程中的每个epoch的最后,作者翻新噪声样本来得到额外干净样本,如果给定不干净样本 ,至少一个网络的置信度超过阈值 ,则用网络的预测 更新对应样本的标签,并且该样本被认为是干净样本,加入 中
其中 表示 的one-hot编码(第 个元素值为1, )
整体过程的伪代码如下
编辑于 03-28