# CCandPCM

This is the code for the paper "Clustering Algorithm Based on Contrastive Learning and Partition Confidence Maximization".
**Abstract:** Existing deep clustering methods based on contrastive learning face two main problems. First, most methods require a large number of negative samples, so clustering performance can only be improved by increasing the number of negative samples. Second, these methods usually output the results of cluster-level contrastive learning directly, without aiming for higher confidence in the clustering assignments. To address these challenges, this paper proposes a novel clustering algorithm that integrates contrastive learning with partition confidence maximization. The algorithm builds on Contrastive Clustering (CC) and SimSiam (Exploring Simple Siamese Representation Learning). CC extends traditional instance-level contrastive learning to cluster-level contrastive learning, which improves model performance and allows the clustering result to be output directly. SimSiam removes the dependency on negative samples by applying a stop-gradient operation on one branch alongside the predictor. Specifically, the algorithm extends the dual-view structure to a tri-view structure (two prediction networks and one target network) and then performs both instance-level and cluster-level contrastive learning. It then maximizes partition confidence on the feature matrix produced by cluster-level contrastive learning to obtain high-confidence clustering assignments. Experimental results on six public image datasets, including CIFAR-10 and CIFAR-100, show that the algorithm achieves average improvements of 5.21%, 4.72%, and 7.34% in normalized mutual information (NMI), accuracy (ACC), and adjusted Rand index (ARI), respectively, over the CC algorithm, validating its effectiveness.
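For intuition, the cluster-level contrastive idea from CC can be sketched as follows. This is a simplified NumPy illustration of the general technique, not the code in this repository (which is implemented in PyTorch); the column-wise positive pairing and the entropy regularizer follow the CC formulation, while the function name and details here are illustrative.

```python
import numpy as np

def cluster_contrastive_loss(p_a, p_b, temperature=1.0):
    """Cluster-level contrastive loss in the style of CC (NumPy sketch).

    p_a, p_b: (N, K) soft cluster-assignment matrices from two augmented
    views of the same batch. Each *column* is treated as a representation
    of one cluster; column i of view a and column i of view b form a
    positive pair among the 2K columns.
    """
    n, k = p_a.shape
    # Stack the columns of both views as cluster representations: (2K, N)
    reps = np.concatenate([p_a.T, p_b.T], axis=0)
    reps = reps / np.linalg.norm(reps, axis=1, keepdims=True)
    sim = reps @ reps.T / temperature          # pairwise cosine similarities
    np.fill_diagonal(sim, -np.inf)             # exclude self-similarity
    # The positive for cluster i of view a is cluster i of view b, and vice versa
    targets = np.concatenate([np.arange(k) + k, np.arange(k)])
    # Row-wise softmax cross-entropy against the positive index
    logits = sim - sim.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    contrastive = -log_prob[np.arange(2 * k), targets].mean()
    # Entropy of the cluster-size distribution; subtracting it discourages
    # the degenerate solution that assigns every sample to one cluster
    freq = p_a.mean(axis=0)
    entropy = -(freq * np.log(freq + 1e-8)).sum()
    return contrastive - entropy
```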
## Dependency

An `environment.yml` file is provided with the code; you can use conda to create an environment from it:

```shell
conda env create -f environment.yml
conda activate myenv
conda list
```
## Performance

The clustering results (NMI, ACC, and ARI) on the six challenging image datasets are shown in the table below:
| Dataset | NMI | ACC | ARI |
| --- | --- | --- | --- |
| CIFAR-10 | 79.6 | 86.8 | 75.5 |
| CIFAR-100 | 46.3 | 45.9 | 29.1 |
| STL-10 | 75.0 | 83.7 | 70.7 |
| ImageNet-10 | 89.2 | 91.6 | 86.0 |
| ImageNet-dogs | 52.1 | 50.8 | 32.1 |
| Tiny-ImageNet | 34.4 | 14.3 | 6.8 |
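The three metrics in the table can be reproduced from predicted cluster labels and ground-truth labels. Below is a minimal sketch using scikit-learn and SciPy; it is a generic illustration, not this repository's evaluation code. ACC uses the standard best-match mapping, where the Hungarian algorithm first finds the optimal permutation between predicted cluster IDs and true labels.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

def clustering_accuracy(y_true, y_pred):
    """Best-match accuracy: optimally map predicted cluster IDs to true
    labels with the Hungarian algorithm, then count agreements."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    k = max(y_true.max(), y_pred.max()) + 1
    cost = np.zeros((k, k), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cost[t, p] += 1                       # co-occurrence counts
    row, col = linear_sum_assignment(-cost)   # negate to maximize matches
    return cost[row, col].sum() / len(y_true)

# Example: a perfect clustering up to a permutation of cluster IDs
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])
print(clustering_accuracy(y_true, y_pred))           # 1.0
print(normalized_mutual_info_score(y_true, y_pred))  # 1.0
print(adjusted_rand_score(y_true, y_pred))           # 1.0
```

All three metrics are invariant to relabeling the clusters, which is why a permuted but otherwise perfect partition scores 1.0 on each.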
Convergence process for the CIFAR-10 dataset
Confusion matrix for CIFAR-10 and STL-10 datasets
## Usage

### Configuration

All configuration options can be modified in `config/config.yaml`.
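For orientation, a config of this kind typically looks like the sketch below. The field names here are illustrative assumptions, not the repository's actual schema; consult the provided `config/config.yaml` for the real keys and values.

```yaml
# Illustrative sketch only; see config/config.yaml for the actual keys.
dataset: CIFAR-10        # which dataset to train on
batch_size: 256          # training batch size
epochs: 1000             # number of training epochs
learning_rate: 0.0003    # optimizer learning rate
model_path: save/        # where checkpoints are read from / written to
```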
### Datasets

Six image datasets were used for training and testing; their download links are as follows:

- Download link: Password: pfhx
### Training

Once the required configuration is in place, you can run the following command to start training. Note that the STL-10 dataset is trained differently from the other datasets because it additionally uses STL-10's unlabeled data:

```shell
# other datasets
python train.py

# STL-10 dataset
python train_STL.py
```
### Testing

Once training is complete, you can evaluate the clustering performance with the following command. Note that the ImageNet-10 dataset is tested on the full dataset:

```shell
python cluster.py
```

Alternatively, you can test with our trained models (download link below); just change the file names in the configuration file and in `cluster.py`.

- Download link: Password: t8wm
## Acknowledgments

We would like to acknowledge the following repository for providing valuable resources and inspiration for our project: