kmeans算法是较为常见的聚类算法,不仅可以对二维的坐标点进行聚类,还可以对高维的图像信息进行聚类。Kmeans算法对初始质心的选择比较敏感,Kmeans++算法针对初始质心的选择做了改进,使得几个初始质心尽可能的远。
在使用kmeans算法对二维坐标进行聚类时,聚类的依据是坐标点与质心之间的距离;同样,对于高维度的图像信息,可以将像素点之间的差异看作距离,这样得到的每个簇,都是像素点差异较小的图,简单来说,每个簇内是图像相似度较高的图像。
这里使用kmeans++算法,对CIFAR10数据集进行聚类。CIFAR10是一个用于图像分类的数据集,共有10个类别,每张图像的大小为32*32*3。程序在CIFAR10数据集内挑选出了200张图像,并对这200张图像进行聚类,k值设为10。
聚类结果:
簇1
簇2
簇3
簇4
簇5
簇6
簇7
簇8
簇9
簇10
其中,avg.jpg是质心,是簇内所有图像的平均值,是对簇内信息的抽象反映。
可以看出,聚类效果还是不错的,把一些较为相似的图像放在了一个簇内。
代码实现:
import torchvision
import torch
import numpy as np
from torch.utils.data import DataLoader
import os
import cv2
import random
def load_data():
transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
)
train = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train, batch_size=1)
for i, (img, _) in enumerate(train_loader):
images.append(np.transpose(img[0].numpy(), (1, 2, 0)))
if i > 200:
break
def distance(img, centroids):
return np.array([np.sum(centroid - img) ** 2 for centroid in centroids])
def kmeans_plus(k):
centroids = []
idx = random.randint(0, len(images) - 1)
centroids.append(images.pop(idx))
for _ in range(k - 1):
sum = 0
dx = np.zeros((len(images),))
for i, img in enumerate(images):
dx[i] = np.min([(img - centroid) ** 2 for centroid in centroids])
sum += dx.sum()
p = np.array(dx) / sum
max_idx = np.argmax(p)
centroids.append(images.pop(max_idx))
print("finish")
return centroids
def kmeans(k=10):
centroids = kmeans_plus(k)
clu = dict()
for epoch in range(100):
for i in range(k):
clu[i] = []
for img in images:
index = distance(img, centroids).argmin()
clu[index].append(img)
for i in range(k):
sum = np.zeros_like(img)
for img in clu[i]:
sum += img
mean = sum / len(clu[i])
centroids[i] = mean
print(epoch)
return clu, centroids
if __name__ == '__main__':
images = []
load_data()
clu, centroids = kmeans()
for i in range(len(centroids)):
os.mkdir(f"./{i}")
ims = clu[i]
sum = np.zeros_like(images[0].shape)
for idx, im in enumerate(ims):
cv2.imwrite(f"./{i}/{idx}.jpg", cv2.resize(im * 255, (320, 320)))
sum = sum + im
sum = sum / (idx + 1)
cv2.imwrite(f"./{i}/avg.jpg", cv2.resize((sum * 255), (320, 320)))