一、原理
- 请见这篇博客。
- K-Means分类的主要步骤为:
① 给定聚类中心个数K;
② 随机选取K个聚类中心点μ1, μ2, … ,μK;每个中心点对应一个类别;
③ 对于每一个特征点xi,根据到聚类中心的距离计算其所属类别 ck (k=1,2,…,K),其中ck=argminj||xi-μj||2;
④对每一类k,根据其包含的所有特征点更新均值μk;
⑤重复步骤3-4,直至μ1, μ2, … ,μK不再改变。
二、代码
(一)调用opencv函数
参考了这位博主的代码:
import cv2
import numpy as np
########## Begin ##########
# 对图像用kmeans聚类
# 显示图片的函数
def show(winname,src):
cv2.namedWindow(winname,cv2.WINDOW_GUI_NORMAL)
cv2.imshow(winname,src)
cv2.waitKey()
img = cv2.imread('D:/cherry.png')
o = img.copy()
print(img.shape)
# 将一个像素点的rgb值作为一个单元处理,这一点很重要
data = img.reshape((-1,3))
print(data.shape)
# 转换数据类型
data = np.float32(data)
# 设置Kmeans参数
critera = (cv2.TermCriteria_EPS+cv2.TermCriteria_MAX_ITER,10,0.1)
flags = cv2.KMEANS_RANDOM_CENTERS
# 对图片进行四分类
r,best,center = cv2.kmeans(data,3,None,criteria=critera,attempts=10,flags=flags)
print(r)
print(best.shape)
print(center)
center = np.uint8(center)
# 将不同分类的数据重新赋予另外一种颜色,实现分割图片
data[best.ravel()==1] = (0,0,0)
data[best.ravel()==0] = (255,0,0)
data[best.ravel()==2] = (0,0,255)
data[best.ravel()==3] = (0,255,0)
# 将结果转换为图片需要的格式
data = np.uint8(data)
oi = data.reshape((img.shape))
# 显示图片
show('img',img)
show('res',oi)
检测效果如下:
(二)不调用函数
1. 处理图片
要将图片中一个像素点的rgb值作为一个单元处理,首先要搞清楚图片在python中是如何存储的。
假如一张彩色图片的大小为 384 * 512,即有 384 行,512 列,那么我们看python中显示的 img 数组如下:
这是 img[0],即图像第一行每个像素点 rgb 的值。同理, img[383] 就是图像最后一行每个像素点 rgb 的值:
当我们输入 img[384] 时,显示没有这一层,也就是说彩色图像是按每一行的像素点来保存的:
彩色图像显然是用三维数组来表示的,我们要将每个像素点的rgb值作为一个单元来处理,那么我们首先要将它转变为一个二维数组,这里用到了 reshape() 函数,关于这个函数的讲解请看这篇博客。总之就是将图像这个三维数组转变成二维数组(也就是将二维图像转变成一维,分成3列), reshape((-1, 3)) 即转变为有三列的二维数组,故每一行变成了
(384 * 512 * 3)/ 3 = 196608 个元素:
也就是每个数组由一个像素点的rgb三个值组成,也就是将一个像素点的rgb当作了一个单元。
(三)改进距离
(四)错误
1. 将数组进行拼接
2. 布尔操作错误
原本代码如图所示:
def newCenter2(point, new_kclass): # 求每组的平均值以确定中心
newCenter = np.zeros((3, 3)) # 初始化一个(3,3)的矩阵,用来存放分类类别
point2 = np.c_[point, new_kclass] # 将像素点与类别拼接成一个数组
point_class = point2.astype(int)
for i in range(3): # 分别计算类别为0,1的平均值,保存为一个(K,2)的二维数组
point3 = point_class[point_class[3] == i]
newCenter[i] = point3[["x", "y"]].mean(axis=0)
return newCenter
错误如下图所示:
在CSDN上看了一下别人的回答,大概是这个样子。
但我没找到那个get_support()函数,错误中也没体现,所以我将代码改为如下形式:
def newCenter2(point, new_kclass): # 求每组的平均值以确定中心
newCenter = np.zeros((3, 3)) # 初始化一个(3,3)的矩阵,用来存放分类类别
point2 = np.c_[point, new_kclass] # 将像素点与类别拼接成一个数组
point_class = point2.astype(int)
for i in range(3): # 分别计算类别为0,1的平均值,保存为一个(K,2)的二维数组
if point_class[3] == i:
point3 = point_class
newCenter[i] = point3[["x", "y"]].mean(axis=0)
return newCenter
然后报错如下:
这时