官网示例路径https://github.com/deeplearning4j/deeplearning4j/blob/bca1df607f6e58ae73baa8e684130bfa7ad8c2e3/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java
代码
@Test
public void testKMeans() {
//设置一个底层生成数值,后面的方法产生的数相等
//如Nd4j.randn(5, 5),产生5行5列的矩阵,不管运行多少次该test用例,均生成相等的5行5列矩阵
Nd4j.getRandom().setSeed(7);
//声明一个KMeans聚类对象,参数分别是 最终聚类的类别数量,迭代次数,距离函数 距离函数的取值为(sum,max,min,norm1,norm2,prod,std,var,euclidean,cosine,cosinesimilarity,manhattan,mmul,tensorMmul)具体可参照org.nd4j.linalg.api.ops.factory.DefaultOpFactory
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, "euclidean");
//张量矩阵生成KMeans的点对象
List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
//将第一个点对象带入进行分类,可得到对象pointClassification,该对象getCenter得到该点所属的类别
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); //可以使用classifyPoint(points.get(0),false)使center中心店不进行更新移动
System.out.println(pointClassification);
}