这里是对鸢尾花进行分类,如果要修改,只需要换路径df还有种类k就可以了。具体逻辑在很多博客讲解已经很清楚了。
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
df = pd.read_csv("C:/Users/wxc/Desktop/xuexi/pythonProject/机器学习/分类与聚类/kmeans/ris.csv")
X_train = df.iloc[:, 0:4].values# 将数据集转换为numpy数组
k = 3# 设置聚类中心的个数
random_point = X_train[random.sample(range(len(X_train)), k)]# 这里是用的numpy数组,从 X_train 中随机选择 k 个样本作为初始的聚类中心
# 计算欧式距离,并将每个样本分配到最近的聚类中心
def assign_cluster(X_train, random_point):
cluster_assignment = []
for x in X_train:
distances = [np.linalg.norm(x - point) for point in random_point] #这里在进行欧式距离的计算
closest_cluster = np.argmin(distances) # 获取最小距离对应的聚类中心的索引值
cluster_assignment.append(closest_cluster) #添加进去
return cluster_assignment
# 更新聚类中心
def New_centers(X_train, cluster_assignment, k):
new_centers = []
for i in range(k):
cluster_mean = X_train[cluster_assignment == i].mean(axis=0) # 求取均值,这样子就能更新聚类中心了
new_centers.append(cluster_mean)
return new_centers
# kmeans算法
def kmeans(X_train, random_point, k):
for i in range(100): # 这个迭代次数可以修改
cluster_assignment = assign_cluster(X_train, random_point)
new_centers = New_centers(X_train, np.array(cluster_assignment), k)
if np.allclose(random_point, new_centers, rtol=1e-05): # 由于是nd.array类型,无法直接使用==,所以这里做差表示
break
random_point = new_centers
return cluster_assignment, new_centers
cluster_assignment, final_centers = kmeans(X_train, random_point, k)
final_centers = np.array(final_centers)
for i in range(len(X_train)):
print(f"样本{i+1}的具体分类: {cluster_assignment[i]}")
print("最终的中心点位置:", final_centers)# 输出每个样本所属的具体分类
# 绘制分类的图像
fig, ax = plt.subplots()
colors = ['red', 'green', 'blue'] # 设置每个聚类的颜色
for i in range(k):
cluster_samples = [X_train[j] for j in range(len(X_train)) if cluster_assignment[j] == i] # 遍历一遍所属类别的元素
cluster_samples = np.array(cluster_samples) #先变成np.array类型
ax.scatter(cluster_samples[:, 0], cluster_samples[:, 1], c=colors[i], label=f'Cluster {i+1}') #以第一特征为x轴,第二特征为y轴,开始绘制
ax.scatter(final_centers[:, 0], final_centers[:, 1], c='black', marker='x', label='Centroids')
ax.legend()
plt.show()