kmeans算法
定义:
k均值聚类算法是一种迭代求解的聚类分析算法,其步骤是随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。
方法:
1.随机取1000个点,其中随机5个为中心点
2.求其余995个点与每个中心点的距离
3.距离最小的,将此点和对应的中心点归为一类
4.类中求平均值作为新的中心点
5.如果新中心点和旧中心点的误差小于某个阈值,停止,否则重复以上操作
代码:
import random #因为用到了随机数,所以引用此模块
import matplotlib.pyplot as plt #因为需要成像,所以引入此模块
import csv #因为需要将需要的数据存为csv文件,所以引入此模块
class KMeans():
def __init__(self, k=1): #k代表分类数,默认为1,可以改变
self.__k = k
self.__data = [] # 存放原始数据
self.__pointCenter = [] # 存放中心点,第一次获得的中心点通过随机方式在__data里随机出来
self.__result = [] #类里的数据
for i in range(k):
self.__result.append([]) # [[],[],[],[],[]] #有几个类添加几个列表
pass
pass
#进行模型训练
def fit(self, data, threshold, times=500): #data: 训练数据 threshold: 阈值,退出条件
self.__data = data
self.randomCenter() #调用此方法随机出最开始的中心点
print(self.__pointCenter) #输出中心点列表,无实意
centerDistance = self.calPointCenterDistance(self.__pointCenter, self.__data) #计算每个点和每个中心点的距离
# 对原始数据进行分类,将每个点分到离它最近的中心点
i = 0 #从第一个点开始
for temp in centerDistance:
index = temp.index(min(temp))
self.__result[index].append(self.__data[i])
i += 1
pass
print(self.__result) # 打印分类结果
oldCenterPoint = self.__pointCenter #将中心点赋值给它
newCenterPoint = self.calNewPointCenter(self.__result) #调用此方发,计算新的中心点
#判断:如果前后两次中心点之间的距离是否小于某个阈值
while self.calCenterToCenterDistance(oldCenterPoint, newCenterPoint) > threshold: #如果大于,重复while语句的操作,次数在times次内
times -= 1
result = []
for i in range(self.__k):
result.append([])
pass
oldCenterPoint = newCenterPoint # 保存上次的中心点
centerDistance = self.calPointCenterDistance(newCenterPoint, self.__data) #计算新中心点和个点之间的距离
# 对原始数据进行分类,将每个点分到离它最近的中心点
i = 0
for temp in centerDistance:
index = temp.index(min(temp))
result[index].append(self.__data[i]) # result = [[[10,20]]]
i += 1
pass
newCenterPoint = self.calNewPointCenter(result)
print(self.calCenterToCenterDistance(oldCenterPoint, newCenterPoint))
self.__result = result
pass
self.__pointCenter = newCenterPoint
return newCenterPoint, self.__result #将新中心点和所有类中的点返回
pass
#计算两次中心点之间的距离,求和求均值
def calCenterToCenterDistance(self, old, new): #old:上次的中心点 new:新计算的中心点
total = 0
for point1, point2 in zip (old, new):
total += self.distance(point1, point2) #有几个中心点求几次
pass
return total / len(old) #返回前后两次中心点距离之和的平均值
pass
#计算每个点和每个中心点之间的距离
def calPointCenterDistance(self, center, data):
centerDistance = []
for temp in data:
centerDistance.append([self.distance(temp, point) for point in center]) #调用distance方法求各点到各中心点的距离
pass
print(centerDistance)
return centerDistance #返回距离
pass
#计算新的中心点
def calNewPointCenter(self, result):
newCenterPoint = []
for temp in result:
# 转置:将每个点与所有中心点的距离列表 转置成 每个中心点与同一类的所有点的列表
temps = [[temp[x][i] for x in range(len(temp)) ] for i in range(len(temp[0]))]
point = []
for t in temps:
# 对每个维度(类)求和,取平均即为新的中心点
point.append(sum(t)/len(t)) # mean
pass
newCenterPoint.append(point)
pass
print(newCenterPoint)
return newCenterPoint #返回新中心点的列表
pass
#计算两个点之间的距离,支持任意维度,欧式距离
def distance(self, pointer1, pointer2):
distance = (sum([(x1 - x2)**2 for x1, x2 in zip(pointer1, pointer2)]))**0.5 #公式,几个未知数和集合都可以,支持任意维度,这里是二维的
return distance #返回距离
pass
#从原始的__data里随机出最开始进行计算的k个中心点
def randomCenter(self):
while len(self.__pointCenter) < self.__k:
# 随机一个索引 所有的点中随机取几个作为中心点
index = random.randint(0, len(self.__data) - 1)
# 判断中心点是否重复,如果不重复,加入中心点列表 避免因重复导致的中心点个数不够
if self.__data[index] not in self.__pointCenter:
self.__pointCenter.append(self.__data[index])
pass
pass
pass
# 将中心点和类中的所有数据转成csv文件进行存储
def writetocsv(self):
with open('3.csv','w',encoding='gb18030',newline='')as fp:
writer = csv.writer(fp)
writer.writerow(['x','y'])
writer.writerows(centerPoint) #最后的中心点
writer.writerow('每个类的所有元素::')
writer.writerows(result) #最后的类
pass
pass
if __name__ == "__main__":
data = [[random.randint(1, 100), random.randint(1, 100)] for i in range(1000)] #随机1000个点
kmeans = KMeans(k=5) #实例化对象,参数为几个中心点(几个类)
centerPoint, result = kmeans.fit(data, 0.0001) #调用fit方法,参数中的0.0001为阈值
print(centerPoint)
kmeans.writetocsv() #调用此方法,将需要的数据存在csv文件中
#界面可视化,可以清晰的看到散点图,需引入 matplotlib模块
plt.plot()
plt.title("KMeans Classification")
i = 0
tempx = []
tempy = []
color = []
for temp in result: #遍历类中的点,用同一种颜色,需转置
temps = [[temp[x][i] for x in range(len(temp))] for i in range(len(temp[0]))]
color += [i] * len(temps[0])
tempx += temps[0]
tempy += temps[1]
i += 2
pass
plt.scatter(tempx, tempy, c=color, s=30) #x轴,y轴,点的颜色,点的大小
plt.show() #显示出来
pass
pass
因为用的类,所以里面的方法不是从上到下去看的,调用的时候才需要,注释比较详细,代码内容不算原创,只算是提供一个想法和方法,可以借鉴以了解何为kmeans算法。
ps:需要matplotlib模块才能看见图象,csv文件命名随意,需要用excel表打开