最近在做目标检测,为了合理地打标签,想到用聚类算法先对自己的数据进行分类,这样可以避免有的标签打得太多、有的标签又打得太少,浪费时间和精力。网上查了一下,大多注重讲解算法本身,不才在这里说一下我的使用流程,见笑。。。
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization import KMeans
import os
import cv2
# 导入MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
import time
# Load the MNIST dataset from /tmp/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Hyper-parameters shared by mytrain() and mytest().
num_steps = 50          # number of training steps
batch_size = 1024       # number of samples per batch
k = 25                  # number of clusters
num_classes = 10        # ten classes
num_features = 28 * 28  # each image is 28x28 pixels
def mytrain():
    """Fit K-Means on the MNIST training images and save a checkpoint.

    Builds a ``KMeans`` graph (``k`` clusters, cosine distance, mini-batch
    updates) over a ``[None, num_features]`` placeholder, runs ``num_steps``
    training steps feeding the whole training set on every step, prints the
    average distance on step 1 and on every 10th step, and finally writes the
    model variables to ``./model``.

    Returns:
        None.
    """
    full_data_x = mnist.train.images

    # Build everything in a private graph so the variable names written to
    # the checkpoint do not depend on whatever else has already been created
    # in the process-wide default graph.
    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, shape=[None, num_features])

        # K-Means parameters.
        kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                        use_mini_batch=True)

        # training_graph() yields either a 6- or a 7-element tuple (the longer
        # form carries an extra cluster_centers_var before init_op).  In both
        # layouts `scores` is element 2 and (init_op, train_op) are the last
        # two elements, so index by position instead of unpacking twice.
        training_graph = kmeans.training_graph()
        scores = training_graph[2]
        init_op, train_op = training_graph[-2], training_graph[-1]
        avg_distance = tf.reduce_mean(scores)

        # Variable initialiser (default values) and checkpoint writer.
        init_vars = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
        ckpt_file = r'./model'

        # The context manager guarantees the session is closed even if a
        # training step raises.
        with tf.Session() as sess:
            feed = {X: full_data_x}
            sess.run(init_vars, feed_dict=feed)
            sess.run(init_op, feed_dict=feed)

            # Training loop.
            for i in range(1, num_steps + 1):
                _, d = sess.run([train_op, avg_distance], feed_dict=feed)
                if i % 10 == 0 or i == 1:
                    print("Step %i, Avg Distance: %f" % (i, d))

            # The path is fixed, so a single save after the last step leaves
            # the same final checkpoint on disk.
            saver.save(sess, ckpt_file)
def mytest():
    """Restore the saved K-Means model and sort MNIST test images by cluster.

    Rebuilds the same ``KMeans`` graph used for training, restores the latest
    checkpoint found in the current directory, assigns every MNIST test image
    to its nearest cluster and writes it as a PNG into
    ``./images/<cluster id>/``.

    If no checkpoint is found, a message is printed and nothing is written.

    Returns:
        None.
    """
    # A private graph keeps the variable names identical to the ones stored
    # by mytrain(), even when other graphs were built earlier in this process.
    with tf.Graph().as_default():
        # Input images, flattened to num_features values each.
        X = tf.placeholder(tf.float32, shape=[None, num_features])
        kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
                        use_mini_batch=True)

        # The first three elements are the same in the 6- and 7-element
        # variants of the returned tuple.
        all_scores, cluster_idx, scores = kmeans.training_graph()[:3]

        with tf.Session() as sess:
            ckpt_state = tf.train.get_checkpoint_state('./')
            print(ckpt_state)
            if not ckpt_state:
                print("No checkpoint found in './'; run mytrain() first.")
                return

            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(r'./'))

            test_x = mnist.test.images
            all_scores_, scores_, cluster_idx_ = sess.run(
                [all_scores, scores, cluster_idx], feed_dict={X: test_x})
            print(all_scores_[0].shape)
            print(scores_[0].shape)

            # cluster_idx already holds the nearest-cluster id for every
            # input row, so there is no need to recover it by matching each
            # row's score against all_scores.
            assignments = cluster_idx_[0]

            for i, val in enumerate(assignments):
                path1 = os.path.join("./images", str(val))
                # Also creates ./images itself when it is missing.
                os.makedirs(path1, exist_ok=True)
                img = (test_x[i].reshape(28, 28) * 255).astype(np.uint8)
                # Prefix the sample index: time.time() alone can repeat on
                # coarse clocks, which would silently overwrite an image.
                file_name = "%05d_%s.png" % (i, time.time())
                cv2.imwrite(os.path.join(path1, file_name), img)
# Script entry: train and checkpoint the K-Means model, then reload it and
# write the MNIST test images into per-cluster folders.
mytrain()
mytest()
最后得到25个文件夹,对应我们设置的k值。
这是其中一个文件夹的图片
这里只是使用了mnist数据集做的测试,各位可以自行改成自己要打标签的数据集。代码只是对tensorflow官网的例子进行的改写。
这里遇到了一个问题,如果有大佬知道,望不吝赐教。
# NOTE: failing variant - the images are fed as a 3-D [N, 28, 28] array
# instead of the 2-D [N, 784] layout used in mytrain(). The traceback that
# follows reports "Input points should be a matrix" from the NearestNeighbors
# op for this input.
full_data_x = mnist.train.images
shape_=full_data_x.shape
full_data_x=full_data_x.reshape(shape_[0],28,28)
X = tf.placeholder(tf.float32, shape=[None, 28,28])
把输入改成这样,会报错:
Traceback (most recent call last):
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
return fn(*args)
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
[[{{node NearestNeighbors}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
mytrain()
File "I:/myPython/Kmeans/tensorflow_Kmeans.py", line 44, in mytrain
feed_dict={X: full_data_x})
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
run_metadata_ptr)
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
run_metadata)
File "I:\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input points should be a matrix.
[[node NearestNeighbors (defined at /myPython/Kmeans/tensorflow_Kmeans.py:25) ]]
Original stack trace for 'NearestNeighbors':
File "/myPython/Kmeans/tensorflow_Kmeans.py", line 87, in <module>
mytrain()
File "/myPython/Kmeans/tensorflow_Kmeans.py", line 25, in mytrain
training_graph = kmeans.training_graph()
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 377, in training_graph
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\contrib\factorization\python\ops\clustering_ops.py", line 257, in _infer_graph
inp, clusters, 1)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\ops\gen_clustering_ops.py", line 258, in nearest_neighbors
"NearestNeighbors", points=points, centers=centers, k=k, name=name)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
op_def=op_def)
File "\myPython\python\python3.7.6\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
self._traceback = tf_stack.extract_stack()
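根据提示,我输入的不是一个矩阵。我。。。。其实这里的“矩阵”指的是二维张量:KMeans 要求输入的形状为 [样本数, 特征数],而 [None, 28, 28] 是三维的,所以 NearestNeighbors 这个算子会报错。把每张图片展平成 784 维的向量(也就是前面代码里的写法)再输入即可。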