Kmeans++ 对图像聚类

文章介绍了kmeans和kmeans++算法在图像聚类中的应用,特别是在CIFAR10数据集上。通过kmeans++选择初始质心以提高聚类效果,对200张图像进行聚类,形成10个簇,每个簇内的图像具有较高相似度。代码示例展示了如何用Python实现这一过程。
摘要由CSDN通过智能技术生成

        kmeans算法是较为常见的聚类算法,不仅可以对二维的坐标点进行聚类,还可以对高维的图像信息进行聚类。Kmeans算法对初始质心的选择比较敏感,Kmeans++算法针对初始质心的选择做了改进,使得几个初始质心尽可能的远。

        在使用kmeans算法对二维坐标进行聚类时,聚类的依据是坐标点与质心之间的距离;同样,对于高维度的图像信息,可以将像素点之间的差异看作距离,这样得到的每个簇,都是像素点差异较小的图,简单来说,每个簇内是图像相似度较高的图像。

        这里使用kmeans++算法,对CIFAR10数据集进行聚类。CIFAR10是一个用于图像分类的数据集,共有10个类别,每张图像的大小为32*32*3。程序在CIFAR10数据集内挑选出了200张图像,并对这200张图像进行聚类,k值设为10。

        聚类结果:

        

簇1

簇2

簇3

 簇4

 簇5

簇6

 簇7

 簇8

 簇9

 簇10

 其中,avg.jpg是质心,是簇内所有图像的平均值,是对簇内信息的抽象反映。

可以看出,聚类效果还是不错的,把一些较为相似的图像放在了一个簇内。

代码实现:

import torchvision
import torch
import numpy as np
from torch.utils.data import DataLoader
import os
import cv2
import random


def load_data():
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()]
    )

    train = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train, batch_size=1)

    for i, (img, _) in enumerate(train_loader):
        images.append(np.transpose(img[0].numpy(), (1, 2, 0)))
        if i > 200:
            break


def distance(img, centroids):
    return np.array([np.sum(centroid - img) ** 2 for centroid in centroids])


def kmeans_plus(k):
    centroids = []
    idx = random.randint(0, len(images) - 1)
    centroids.append(images.pop(idx))

    for _ in range(k - 1):
        sum = 0
        dx = np.zeros((len(images),))
        for i, img in enumerate(images):
            dx[i] = np.min([(img - centroid) ** 2 for centroid in centroids])
            sum += dx.sum()
        p = np.array(dx) / sum
        max_idx = np.argmax(p)
        centroids.append(images.pop(max_idx))
    print("finish")

    return centroids


def kmeans(k=10):
    centroids = kmeans_plus(k)

    clu = dict()

    for epoch in range(100):

        for i in range(k):
            clu[i] = []

        for img in images:
            index = distance(img, centroids).argmin()
            clu[index].append(img)

        for i in range(k):
            sum = np.zeros_like(img)
            for img in clu[i]:
                sum += img
            mean = sum / len(clu[i])
            centroids[i] = mean
        print(epoch)
    return clu, centroids


if __name__ == '__main__':
    images = []

    load_data()
    clu, centroids = kmeans()

    for i in range(len(centroids)):
        os.mkdir(f"./{i}")
        ims = clu[i]
        sum = np.zeros_like(images[0].shape)
        for idx, im in enumerate(ims):
            cv2.imwrite(f"./{i}/{idx}.jpg", cv2.resize(im * 255, (320, 320)))
            sum = sum + im
        sum = sum / (idx + 1)
        cv2.imwrite(f"./{i}/avg.jpg", cv2.resize((sum * 255), (320, 320)))

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

G.E.N.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值