k_means聚类算法学习经历
图像分析与理解课程的一个编程小作业,python纯小白,踩坑无数,为了让之间的过程以后也有印象,所以记录下来
1.首先是对元组,序列,numpy数组,矩阵的理解,到现在还没完全明白,希望后续实践可以摸清
python中列表,矩阵,数组之间的转换
Python中的数组、列表、元组、Numpy数组、Numpy矩阵
2.while和for后面要加:(这个基础语法报错竟然怎么也没发现)
3.for…else的神奇用法:python有时候就感觉像是一个个句子,for…else是指如果for中所有的都遍历了就会执行else的语句,如果for+break然后else会很好用
4.for…in…的变式两个for,for…in…for…in…(震惊)相当于在一个for的循环内再加一个for的循环,两个for的叠加:
[j for i in a for j in i]
相当于
for i in a:
for j in i:
5.如果对一个列表,既要遍历索引又要遍历元素时,for index, item in enumerate(list1)
6.查资料读取图片数据的方法有很多,Python 读取、显示、保存图像的各种方法(全!)因为后续可能涉及计算机视觉相关领域,下载了opencv库,下载这个库的过程也一言难尽,找了很多方法,实测有效的在anaconda+pycharm环境下装cv2的方法,记录一个资源丰富的库https://www.lfd.uci.edu/~gohlke/pythonlibs/#numpy)
import cv2
import numpy as np
# 读取图片数据到数组
def loadDataSet(arrpict):
"""
cv2默认为 BGR顺序
读取图片数据(三维向量b,g,r) 到 character(一维列表) 数据集
return: 特征向量组character
"""
row = arrpict.shape[0]
col = arrpict.shape[1]
# print("数组行数:",row):500
# print("数组列数:",col):500
# 特征向量集合
characters = []
print("正在读取图片信息,请稍等......")
# (行,列):读取像素点bgr数据
for i in range(row):
for j in range(col):
b = arrpict[i, j, 0]
g = arrpict[i, j, 1]
r = arrpict[i, j, 2]
characters.append([b, g, r])
character = np.array(characters, 'f') # 转换成numpy矩阵计算形式,浮点型f
print(character,"以上为像素点集合")
print("像素点集合的shape:", character.shape)
return character
# 随机选取k个初始值
def sel_init_cen(n,characters):
rand = []
center = []
# 选取随机数,选取初始中心
for i in range(int(n)):
rand.append((int)(np.random.random() * (characters.shape[0])))
center.append(characters[rand[i]])
print("初始聚类中心", i, center[i], ",")
return center
# 判断两个点之间的‘距离’
def get_distance(n1, n2):
a= []
for i in range(len(n1)):
a.append(np.power(n1[i] - n2[i], 2))
return np.sqrt(sum(a))
# 判断两个中心点是否相同
def compare_center(center, new_center):
for n, m in zip(center, new_center):
if n == m:
return True
return False
# 计算一个类别的许多点的中心值
def calculate_center_pot(clus):
n = len(clus)
x = np.matrix(clus).transpose().tolist()
'''
matrix():转换成矩阵
tolist():转换成列表
transpose(): 对二维数组是转置,对多维设置参数,某些轴交换
'''
center_pot = [sum(j)/n for j in x]
return center_pot
# 显示聚类图像
def imageshow(center):
row = pict.shape[0]
col = pict.shape[1]
distances = []
'''
# 可自定义显示颜色[蓝色,白色,红色,绿色,黑色] 注意bgr和rgb
colors = [[255, 0, 0], [255, 255, 255], [0, 0, 255], [0, 255, 0], [0, 0, 0]]
'''
# (行,列):设置像素点bgr数据
for i in range(row):
for j in range(col):
for center_pot in center:
distances.append(get_distance(pict[i][j], center_pot))
index = distances.index(min(distances))
distances.clear()
# print(index)
pict[i][j] = center[index]
'''
win = cv2.namedWindow('pict win', flags=0) # 窗口可调整图像大小
'''
cv2.imshow('pict win', pict)
cv2.waitKey(0)
"""
主程序
"""
# K-means算法的实现
def K_means(pots, center_pots):
'''
N = len(pots) # 样本个数
n = len(pots[0]) # 单个样本的维度
'''
k = len(center_pots) # k值大小
while True: # 迭代
new_center_pots = [] # 记录中心点
clusters = [] # 记录聚类的结果
for c in range(k):
clusters.append([]) # 初始化
# 针对每个点,寻找距离其最近的中心点
for i, data in enumerate(pots):
distances = []
for center_pot in center_pots:
distances.append(get_distance(data, center_pot))
index = distances.index(min(distances)) # 找到最小的距离的那个中心点的索引
clusters[index].append(data) # 那么这个中心点代表的类中增加一个样本
# 重新计算中心点
for cluster in clusters:
new_center_pots.append(calculate_center_pot(cluster))
# 如果初始的中心点相同,那么有些cluster就会是空集,后续的计算就会有问题
for j in range(k):
if len(clusters[j]) == 0:
new_center_pots[j] = center_pots[j]
# print(new_center_pots)
# 判断中心点是否发生变化:即,判断聚类前后样本的类别是否发生变化
for n, m in zip(center_pots, new_center_pots):
if not compare_center(n, m):
center_pots = new_center_pots[:] # 复制一份
break
else:
# 如果没有变化,那么退出迭代,聚类结束 python中的for...else用法
# 用于输出各个分类的像素点的多少
for i in range(k):
print(len(clusters[i]))
break
return new_center_pots # 返回聚类的结果
if __name__ == '__main__':
pictpath = 'E:/wo/cherry.png'
pict = cv2.imread(pictpath)
'''
print(type(pict))
print(pict.shape) # numpy.ndarray (500, 500, 3)
'''
k = input('请输入k值:') # 自定义将分成几类
# 获取像素点的集合
character = loadDataSet(pict)
'''
print(type(character))
print(character.shape) #numpy.ndarray (250000, 3)
'''
center_pots = sel_init_cen(k, character) # 获取初始随机的k个中心点
character = character.tolist() # 将numpy里的数组转换成列表形式
# 调用k_means算法得到最终的中心点
cen = K_means(character, center_pots)
print('----------最终结果----------')
# 将新的中心点的值打印
for i, center in enumerate(cen):
print('center', i, ' ', center)
# 图片分类及展示图片
imageshow(cen)
代码很多借鉴了以下两个(非常感谢):https://blog.csdn.net/ten_sory/article/details/81016748
https://blog.csdn.net/wsh596823919/article/details/79981703
老师的实践课讲了关于内存与运行时间的关系,在关于这个代码中,有个遍历的部分:关于求两个点的距离部分。相比较于遍历,用矩阵运算时间会大大减小