[deeplearning-022] tf kmeans

这篇博客详细介绍了如何使用TensorFlow实现K-means聚类算法,包括参考文献和完整的源代码,适合想要在深度学习框架中进行聚类分析的读者。
摘要由CSDN通过智能技术生成

1.参考文献

https://blog.csdn.net/freedom098/article/details/56021013

https://blog.csdn.net/freedom098/article/details/56021013

2.源码


# 源码来自 https://blog.csdn.net/freedom098/article/details/56021013

import tensorflow as tf
import matplotlib.pyplot as plt

from sklearn.datasets.samples_generator import make_blobs

#下面三种配置任选一种
K = 4 #类别数
N = 8 #样本数
centers_processing = [[-2, -2], [-2, 1.5], [1.5, -2], [2, 1.5]]

K = 2 #类别数
N = 4 #样本数
centers_processing = [[-2, -2], [1.5, -2]]

K = 3 #类别数
N = 4 #样本数
centers_processing = [[-2, -2], [-2, 1.5], [1.5, -2]]


MAX_ITERS = 1000 #最大迭代次数

data, labels = make_blobs(n_samples=N, centers=centers_processing, n_features=2, cluster_std=0.8, shuffle=False, random_state=42)

# #绘制data图
# legend_type = ['ro','bo','go','r+']
# for i in range(4):
#     print('i=',i)
#     ni = [ii for ii in range(len(labels)) if labels[ii] == i]
#     print(ni)
#     plt.plot(data[ni,0], data[ni,1], legend_type[i])
# plt.show()

# 计算类内平均值函数
def clusterMean(points, bestCenter, K):
    #注意,unsorted_segment_sum是没有axis参数的,所以它始终是按照第一维度进行计算的,这是关键
    #如果要求另一个维度,要做tf.transpose(data)
    total = tf.unsorted_segment_sum(points, bestCenter, K) # 第一个参数是tensor,第二个参数是簇标签,第三个是簇数目
    count = tf.unsorted_segment_sum(tf.ones_like(points), bestCenter, K)
    return total/count

points = tf.Variable(data)
label_processing = tf.Variable(tf.zeros([N], dtype=tf.int64))
centers_processing = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, 2]))

#把数据复制成tensor,以便计算每个样本和每个聚类中心的距离
#对数据也做tile然后reshape,生成一个NxKx2的tensor
repPoints = tf.reshape(tf.tile(points,[1,K]),[N,K,2])
#centers_processing是一个4x2的矩阵,对它进行复制,在第一个维度上复制N次,在第二个维度上复制1次,结果是一个4Nx2的矩阵,N=8的时候,是32x2矩阵
cetnersTile = tf.tile(centers_processing, [N,1])
#如果要计算每个样本和聚类中心的距离,做reshape:第一个维度表示样本数量,第二个维度表示 聚类中心数量,第三个样本表示数据的维度
#reshape(t, shape) => reshape(t, [-1]) => reshape(t, shape) 首先将矩阵t变为一维矩阵,然后再对矩阵的形式更改就可以了。
#从一维变成多维,首先,将一维切成N个,其次,将结果的每个切成K个,最后2维,不用切,自然就是了。
repCenters = tf.reshape(cetnersTile, [N, K, 2])
#计算距离,reduction_indices可以理解成,要消掉哪一维,比如,这条语句,就是把结果变成[N,K],N是样本,K是聚类中心数量,这样结果就是每个样本和每个聚类中心的距离
sumSquare = tf.reduce_sum(tf.square(repCenters - repPoints), reduction_indices=2)
#axis表示,对每个样本,在聚类中心上取最小值的序号
bestCenter = tf.argmin(sumSquare, axis=1)
#判断哪些样本的的类标记是否发生变化,如果不再变化,就提前终止计算--因为已经稳定了。reduce_any是 逻辑或 操作
change = tf.reduce_any(tf.not_equal(bestCenter, label_processing))

#测试中间变量
total = tf.unsorted_segment_sum(points, bestCenter, K) # 第一个参数是tensor,第二个参数是簇标签,第三个是簇数目
count = tf.unsorted_segment_sum(tf.ones_like(points), bestCenter, K)

#计算新的聚类中心坐标
means = clusterMean(points, bestCenter, K)
#计算之前,先计算change
with tf.control_dependencies([change]):
    update = tf.group(centers_processing.assign(means), label_processing.assign(bestCenter))

# #展示tf的数据
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     ret = sess.run(points)
#     print('\npoints = \n', ret)
#
#     ret = sess.run(label_processing)
#     print('\nlabel_processing=',ret)
#
#     ret = sess.run(centers_processing)
#     print('\ncenters_processing = \n', ret)
#
#     ret = sess.run(cetnersTile)
#     print('\ncetnersTile = \n', ret)
#
#     ret = sess.run(repCenters)
#     print('\nrepCenters = \n', ret)
#
#     ret = sess.run(sumSquare)
#     print('\nsumSquare = \n', ret)
#
#     ret = sess.run(bestCenter)
#     print('\nbestCenter = \n', ret)
#
#     ret = sess.run(change)
#     print('\nchange = \n', ret)
#
#     ret = sess.run(total)
#     print('\ntotal = \n', ret)
#
#     ret = sess.run(count)
#     print('\ncount = \n', ret)
#
#     ret = sess.run(means)
#     print('\nmeans = \n', ret)
#
#     #update没有返回结果
#     ret = sess.run(update)
#     print('\nupdate = \n', ret)
# exit(1)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    changed = True
    iterNum = 0
    while changed and iterNum < MAX_ITERS:
        print(iterNum)
        iterNum += 1
        [changed, _] = sess.run([change, update])
        [centersArr, clusterArr] = sess.run([centers_processing, label_processing])
        print(clusterArr)
        print(centersArr)

        # # 显示图像
        # fig, ax = plt.subplots()
        # ax.scatter(data.transpose()[0], data.transpose()[1], marker='o', s=100, c=clusterArr)
        # plt.plot()
        # plt.show()

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值