代码
这里对只有单个属性（一列数值）的 CSV 文件进行二分类：用聚类算法（K-Means / FarthestFirst / EM）把数据划分为 2 个簇。
package com.fly.cluster;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.DefaultXYDataset;
import sun.plugin2.util.ParameterNames;
import weka.clusterers.Clusterer;
import weka.clusterers.EM;
import weka.clusterers.FarthestFirst;
import weka.clusterers.SimpleKMeans;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.core.pmml.jaxbbindings.Cluster;
import java.awt.*;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class Main {
//kMEANS
public static SimpleKMeans kMeans(Instances dataSet ) throws Exception {
SimpleKMeans simpleKMeans = new SimpleKMeans();
//设置随机数种子
simpleKMeans.setSeed(20);
simpleKMeans.setNumClusters(2);
simpleKMeans.buildClusterer(dataSet);
return simpleKMeans;
}
//FarthestFirst
public static FarthestFirst farthestFirst(Instances dataSet) throws Exception {
FarthestFirst first = new FarthestFirst();
first.setNumClusters(2);
first.buildClusterer(dataSet);
return first;
}
//EM
public static EM em(Instances dataSet) throws Exception {
EM em = new EM();
em.setNumClusters(2);
em.buildClusterer(dataSet);
return em;
}
//输出分类
public static void output(String path, Clusterer cluster,Instances instances) throws Exception {
BufferedWriter out = new BufferedWriter(new FileWriter(path));
out.write("value");
for (int i=0;i
out.newLine();
int i1 = cluster.clusterInstance(instances.instance(i));
out.write(String.valueOf(i1));
}
out.close();
}
//输出混淆矩阵
public static void confusionMatrix(Clusterer clusterer,Instances instances,Instances test)throws Exception{
//这里label 0 等于 value 1
int TP=0,TN=0,FN=0,FP=0;
for(int i=0;i
int value = clusterer.clusterInstance(instances.instance(i));
double label = test.instance(i).value(0);
// if(label==1){
// if(value==0)TP++;
// else FN++;
// }else{
// if(value==1)TN++;
// else FP++;
// }
if(label==1){
if(value==1)TP++;
else FN++;
}else{
if(value==0)TN++;
else FP++;
}
}
System.out.println(TP+" "+FN);
System.out.println(FP+" "+TN);
int P=TP+FN;
int N=FP+TN;
System.out.println("准确率: "+1.0*(TP+TN)/(P+N));
System.out.println("错误率: "+1.0*(FP+FN)/(P+N));
System.out.println("召回率: "+1.0*(TP)/(P));
System.out.println("精度: "+1.0*(TP)/(TP+FP));
}
//散点图
public static void show(Clusterer clusterer,Instances instances) throws Exception {
Integer j=0;
DefaultXYDataset xydataset = new DefaultXYDataset();
//第一个参数j代表第j组数据,data为2*n大小的double数组
//xydataset.addSeries(j,data);
List a1=new ArrayList();
List a0=new ArrayList();
for(int i=0;i
if(clusterer.clusterInstance(instances.instance(i))==0){
a0.add(instances.instance(i).value(0));
}else {
a1.add(instances.instance(i).value(0));
}
}
double[][]d0=new double[2][a0.size()];
double[][]d1=new double[2][a1.size()];
for(int i=0;i
d0[0][i]=a0.get(i);
d0[1][i]=a0.get(i);
}
for(int i=0;i
d1[0][i]=a1.get(i);
d1[1][i]=a1.get(i);
}
xydataset.addSeries(0,d0);
xydataset.addSeries(1,d1);
JFreeChart chart = ChartFactory.createScatterPlot("分布情况","score","score",xydataset, PlotOrientation.VERTICAL, true, false, false);
ChartFrame frame = new ChartFrame("散点图", chart, true);
chart.setBackgroundPaint(Color.white);
chart.setBorderPaint(Color.GREEN);
chart.setBorderStroke(new BasicStroke(1.5f));
XYPlot xyplot = (XYPlot) chart.getPlot();
xyplot.setBackgroundPaint(new Color(255, 253, 246));
ValueAxis vaaxis = xyplot.getDomainAxis();
vaaxis.setAxisLineStroke(new BasicStroke(1.5f));
ValueAxis va = xyplot.getDomainAxis(0);
va.setAxisLineStroke(new BasicStroke(1.5f));
va.setAxisLineStroke(new BasicStroke(1.5f)); // 坐标轴粗细
va.setAxisLinePaint(new Color(215, 215, 215)); // 坐标轴颜色
xyplot.setOutlineStroke(new BasicStroke(1.5f)); // 边框粗细
va.setLabelPaint(new Color(10, 10, 10)); // 坐标轴标题颜色
va.setTickLabelPaint(new Color(102, 102, 102)); // 坐标轴标尺值颜色
ValueAxis axis = xyplot.getRangeAxis();
axis.setAxisLineStroke(new BasicStroke(1.5f));
XYLineAndShapeRenderer xylineandshaperenderer = (XYLineAndShapeRenderer) xyplot
.getRenderer();
xylineandshaperenderer.setSeriesOutlinePaint(0, Color.WHITE);
xylineandshaperenderer.setUseOutlinePaint(true);
NumberAxis numberaxis = (NumberAxis) xyplot.getDomainAxis();
numberaxis.setAutoRangeIncludesZero(false);
numberaxis.setTickMarkInsideLength(2.0F);
numberaxis.setTickMarkOutsideLength(0.0F);
numberaxis.setAxisLineStroke(new BasicStroke(1.5f));
//输出中文为方框的解决办法
Font font=new Font("黑体",Font.BOLD,18);//测试是可以的
chart.getTitle().setFont(font);
axis.setLabelFont(font);
va.setLabelFont(font);
frame.pack();
frame.setVisible(true);
}
public static void main(String[] args) throws Exception {
File file = new File(Main.class.getClassLoader().getResource("test_score.csv").getPath());
System.out.println(file.exists());
CSVLoader csvLoader = new CSVLoader();
csvLoader.setFile(file);
Instances dataSet = csvLoader.getDataSet();
Clusterer clusterer=kMeans(dataSet);
//输出评价结果
String path="D:\\新建文件夹 (2)\\数据挖掘\\论文\\km_20.csv";
output(path,clusterer,dataSet);
File test = new File(Main.class.getClassLoader().getResource("labeled.csv").getPath());
CSVLoader csvLoader1 = new CSVLoader();
csvLoader1.setFile(test);
Instances tesetData = csvLoader1.getDataSet();
//混淆矩阵
confusionMatrix(clusterer,dataSet,tesetData);
//散点图
show(clusterer,dataSet);
System.out.println(clusterer);
}
}