【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板

概述

喜大普奔!(bushi

依旧是我们的版本答案pytorch,只不过这次想要说明的是,它依旧可以主要用于聚类任务。

还是老规矩,原理先不多讲(站中有很多讲到KMeans聚类的文章,如果大家对KMeans聚类原理有问题可以先移步其他文章或留言,笔者后面会考虑更新原理篇)

作为一种常见的机器学习范畴内的经典无监督学习方法,KMeans聚类常用于未确定具体标签的数据集分类任务,准确来说,KMeans通过计算衡量数据点之间的距离等多种方式确定点的大致类别。

数据集

本次数据采用xclara数据集,里面记录了3000余组点的坐标,如下图所示。

附数据集下载地址:
https://download.csdn.net/download/weixin_52456426/86725232
在这里插入图片描述

代码

本次将要通过本模板介绍的是pytorch中KMeans包的使用,但要注意的是,这个模块的维护是独立的,并非包含在pytorch源码下,所以需要我们单独下载。

首先打开终端或Anaconda prompt,输入

pip install kmeans-pytorch

等待安装完成。

注意:由于本包基于torch模块编写,需要环境中拥有numpy、torch等前置包。

接下来是我们的代码部分,具体注释大家可以参考代码

import numpy as np
import pandas as pd
import torch
from kmeans_pytorch import kmeans, kmeans_predict
import matplotlib.pyplot as plt

# 导入数据
data = pd.read_csv('./xclara.csv')
data = np.array(data.iloc[:, :])

# 设定:数据集数量,数据集维数,聚类的类别数
data_size, dims, num_clusters = len(data), 2, 3
data = torch.from_numpy(data)

# 训练阶段
# X:待聚类数据集(需要是torch.Tensor类型),维数,距离计算法则,训练设备
cluster_ids_x, cluster_centers = kmeans(
    X=data, num_clusters=num_clusters, distance='euclidean', device=torch.device("cuda:0")
)

# 数据集中数据类别所属
print(cluster_ids_x)
# 数据集各类别聚类中心
print(cluster_centers)

# ======================================================================================================================
# 测试阶段
# how to predict
test_data = np.array(pd.read_csv('./xclara.csv').iloc[:, :])
test = []
for item in test_data:
    point = np.random.uniform(0, 1)
    if point > 0.6:
        test.append(item)
# 测试集为在训练集随机取数据
test = torch.from_numpy(np.array(test))
# 预测阶段,需要额外提供测试集(Tensor)和训练阶段得到的聚类中心
cluster_ids_y = kmeans_predict(
    X=test, cluster_centers=cluster_centers, distance='euclidean', device=torch.device("cuda:0")
)
# 输出预测结果
# print(cluster_ids_y)

# ======================================================================================================================
# plot:绘图阶段————训练集上的聚类图
plt.figure()
# 训练集聚类点的分布
plt.scatter(data[:, 0], data[:, 1], c=cluster_ids_x, cmap='cool')
# 聚类中心点的分布
plt.scatter(
    cluster_centers[:, 0], cluster_centers[:, 1],
    c='white',
    alpha=0.6,
    edgecolors='black',
    linewidths=2
)

plt.tight_layout()
plt.show()

# ======================================================================================================================
# plot:绘图阶段————测试集上的聚类图
plt.figure()
# 测试集聚类点的分布
plt.scatter(test[:, 0], test[:, 1], c=cluster_ids_y, cmap='cool', marker='X')
# 聚类中心点的分布
plt.scatter(
    cluster_centers[:, 0], cluster_centers[:, 1],
    c='white',
    alpha=0.6,
    edgecolors='black',
    linewidths=2
)
plt.tight_layout()
plt.show()

参考资料:
https://www.cnpython.com/pypi/kmeans-pytorch

  • 6
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值