介绍
K_Means其实用sklearn即可,TensorFlow1.0早期版本支持K_Means,在2.0之后,由于很多api废弃,导致实现K_Means有很多坑。以下为踩坑记录。
完整代码路径:https://github.com/lilihongjava/leeblog_python/tree/master/tensorflow_kmeans
数据集
采用sklearn iris.csv数据集,位于data目录下
训练方法
入口代码
tf_k_means_model(feature_column="sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)",
center_count=3, input1="./data/iris.csv", output1="./data/")
采用tf.compat.v1.estimator.experimental.KMeans api,此API是从1.X版本迁移来的,目前处于experimental阶段,用于生产环境要小心!
train方法需要接受输入函数(input function),input_fn用于将feature和target data传递给Estimator的train/evaluate/predict方法。这里,将numpy数据转换为Tensors。
def input_fn():
return tf.data.Dataset.from_tensors(tf.convert_to_tensor(points, dtype=tf.float32)).repeat(2)
model.train(input_fn)
模型导出
用的是tf.Estimator.export_saved_model方法,需要指定特征列的类型,这里用的是numeric_column
if output1:
my_feature_columns = []
for key in feature_colu