KMeans 的使用

最近在做目标检测,为了合理地打标签,想到了用聚类算法对自己的数据进行分类,这样可以避免有的标签打得太多、有的标签又打得太少,浪费时间和精力。网上查了一下,大多是注重讲解算法本身,不才来说一下我的使用流程,见笑。。。

import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization import KMeans
import os
import cv2
# 导入MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
import time
# Load the MNIST dataset (stored under /tmp/data/) with one-hot labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Data / model dimensions.
num_features = 784  # each image is 28*28 pixels, flattened into one row
num_classes = 10  # number of MNIST digit classes (not referenced by mytrain/mytest)
k = 25  # number of clusters; one output folder per cluster

# Training schedule.
num_steps = 50  # number of training steps
batch_size = 1024  # samples per batch (not referenced: training feeds the full set)
def mytrain():
    """Fit K-Means on the MNIST training images and checkpoint the model.

    Builds a ``KMeans`` graph (``k`` clusters, cosine distance, mini-batch
    updates) over the flattened ``[None, num_features]`` training images.
    It runs ``num_steps`` training steps on the full training set and prints
    the average distance to the nearest center. All variables are saved to
    ``./model`` on step 1, on every 10th step, and on the final step.
    """
    full_data_x = mnist.train.images
    # Use a private graph. If another KMeans graph (e.g. mytest's) is later
    # built in the same default graph, its variables get uniquified names
    # (such as "clusters_1") that are not in this checkpoint, and restoring
    # it fails.
    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, shape=[None, num_features])
        kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                        use_mini_batch=True)
        training_graph = kmeans.training_graph()
        # training_graph() returns either 7 items (including
        # cluster_centers_var) or 6 items; accept both layouts.
        if len(training_graph) > 6:
            (all_scores, cluster_idx, scores, cluster_centers_initialized,
             cluster_centers_var, init_op, train_op) = training_graph
        else:
            (all_scores, cluster_idx, scores, cluster_centers_initialized,
             init_op, train_op) = training_graph
        avg_distance = tf.reduce_mean(scores)
        # Initialize variables with their default values.
        init_vars = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
        ckpt_file = r'./model'
        # The context manager closes the session (the original leaked it).
        with tf.Session() as sess:
            sess.run(init_vars, feed_dict={X: full_data_x})
            # KMeans' own init op, fed with the data (presumably this picks
            # the initial cluster centers).
            sess.run(init_op, feed_dict={X: full_data_x})
            for i in range(1, num_steps + 1):
                _, d = sess.run([train_op, avg_distance],
                                feed_dict={X: full_data_x})
                # Also save on the last step, so the final centers are kept
                # even when num_steps is not a multiple of 10.
                if i % 10 == 0 or i == 1 or i == num_steps:
                    print("Step %i, Avg Distance: %f" % (i, d))
                    saver.save(sess, ckpt_file)


def mytest():
    """Assign each MNIST test image to a cluster and save it as a PNG.

    Rebuilds the same KMeans graph as mytrain and restores the latest
    checkpoint found in the current directory. It then computes the nearest
    cluster index for every test image and writes each image to
    ``./images/<cluster index>/<test index>.png``.
    If no checkpoint is found, it prints a message and returns without
    writing anything.
    """
    # Use a private graph so the variable names match the checkpoint written
    # by mytrain. If mytrain already populated the default graph in this
    # process, a second KMeans there would get names like "clusters_1".
    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, shape=[None, num_features])
        kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                        use_mini_batch=True)
        # Both the 6- and 7-item layouts returned by training_graph() hold
        # the per-input cluster assignments at index 1. There is exactly one
        # input here, hence [0]. Using the assignment directly replaces the
        # float argmin over |all_scores - scores|, which compared distances
        # computed by different ops.
        cluster_idx = kmeans.training_graph()[1][0]
        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt_state = tf.train.get_checkpoint_state('./')
            print(ckpt_state)
            if not ckpt_state:
                print("No checkpoint found in ./ - run mytrain() first.")
                return
            saver.restore(sess, tf.train.latest_checkpoint(r'./'))
            test_x = mnist.test.images
            assignments = sess.run(cluster_idx, feed_dict={X: test_x})

    for i, (flat_img, cluster) in enumerate(zip(test_x, assignments)):
        out_dir = os.path.join("./images", str(cluster))
        # makedirs also creates ./images itself; os.mkdir failed when it was missing.
        os.makedirs(out_dir, exist_ok=True)
        img = (flat_img.reshape(28, 28) * 255).astype(np.uint8)
        # Name files by test index: time.time()-based names can collide on
        # coarse clocks and silently overwrite earlier images.
        cv2.imwrite(os.path.join(out_dir, "%d.png" % i), img)


if __name__ == "__main__":
    # Train first (writes the ./model checkpoint). Then cluster the test set
    # into ./images/<cluster index>/ using that checkpoint.
    mytrain()
    mytest()

最后得到25个文件夹,对应我们设置的k值。
[图片]
这是其中一个文件夹的图片
[图片]

这里只是使用了 MNIST 数据集做测试,各位可以自行改成自己要打标签的数据集。代码只是对 TensorFlow 官网例子的改写。

这里遇到了一个问题,如果有大佬知道,望不吝赐教。

full_data_x = mnist.train.images
shape_=full_data_x.shape
# Reshape the flat training rows (784 values each) into 28x28 images.
full_data_x=full_data_x.reshape(shape_[0],28,28)
# NOTE: this is the failing variant. KMeans' NearestNeighbors op requires a
# 2-D matrix [num_samples, num_features] and rejects this 3-D [None, 28, 28]
# input with "Input points should be a matrix." (see the traceback below).
# Keep the data flattened to [None, 784] when feeding KMeans.
X = tf.placeholder(tf.float32, shape=[None, 28,28])

把输入改成这样,会报错:

Traceback (most recent call last):
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
    return fn(*args)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
	 [[{{node NearestNeighbors}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
    mytrain()
  File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 44, in mytrain
    feed_dict={X: full_data_x})
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
    run_metadata_ptr)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
    run_metadata)
  File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
	 [[node NearestNeighbors (defined at /myPython/Kmeans/tensorflow_Kmeans.py:25) ]]

Original stack trace for 'NearestNeighbors':
  File "/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
    mytrain()
  File "/myPython/Kmeans/tensorflow_Kmeans.py", line 25, in mytrain
    training_graph = kmeans.training_graph()
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 377, in training_graph
    all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 257, in _infer_graph
    inp, clusters, 1)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\ops\gen_clustering_ops.py", line 258, in nearest_neighbors
    "NearestNeighbors", points=points, centers=centers, k=k, name=name)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
    op_def=op_def)
  File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

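根据提示,我输入的不是一个矩阵。我。。。。(原因:KMeans 内部的 NearestNeighbors 运算要求输入是形如 [样本数, 特征数] 的二维矩阵,而这里的输入是 [N, 28, 28] 的三维张量,因此需要先展平成 [N, 784] 再送入 KMeans。)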

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值