做的一个结课作业,用Java实现了朴素贝叶斯算法
关于贝叶斯算法可以参考西瓜书、贝叶斯分类器-华校专
设计思路
- 创建数据类、朴素贝叶斯算法类和验证模型的类。
- 数据类DataFrame,用来加载和存储数据信息,并提供操作数据的方法,以及返回数据的方法。
- 朴素贝叶斯类NaiveBayes,贝叶斯算法算法的核心,提供拟合训练数据的方法,以及预测测试数据输出预测分类结果的方法,里面包含了计算先验概率、条件概率、后验概率等的方法。
- 验证模型类Measure,通过提供的真实标记数据和预测标记数据,提供计算预测结果的精度、混淆矩阵、查全率、查准率的方法。
- 其他类DataUtil,提供分割数据集的静态方法,和List与Map中Object对象转化为String对象的方法。
其他具体设计请参照代码清单和注释
先看下包结构和运行结果
注:图中蓝框的是需要的,ml.data其他的类构建好,ml.test用来测试
注:以上为数据集和验证结果,采用的验证方式为 用训练集作为测试集
代码清单
你可能需要下载 weka.jar 、opencsv-4.6.jar 第三方包
ml.data.DataFrame.java
package ml.data;
import java.util.*;
import com.opencsv.CSVReader;
import java.io.*;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
public class DataFrame {
/** 保存数据集的值,Map的key是属性名称,value是以List存储的属性值*/
/** Object对象,可存储String和Double*/
private Map<String,List<Object>> dataMap;
/** 存储所有属性名称*/
private List<String> columns;
/** 存储属性的名称和对应的类型*/
private Map<String,String> types;
/** 存储数据集大小,size[0]为行数,size[1]为列数 */
private int[] size;
/**
* 传入一个Map<String,List<Bbject>>数据,key作为属性名
* @param data
*/
public DataFrame(Map<String,List<Object>> data) {
dataMap = data;
columns = new ArrayList<String>(dataMap.keySet());
}
/**
* 根据输入的文件路径和列名称,加载数据
* @param path 文件路径
* @param columns 指定数据集的属性
* @throws Exception
*/
public DataFrame(String path,List<String> columns) throws Exception {
this(path);
this.columns = columns;
}
/**
* 加载指定数据集文件
* 仅支持 CSV 和 ARFF文件
* @param path 数据集文件路径
* @throws Exception 文件类型不支持
*/
public DataFrame(String path) throws Exception {
List<Object[]> list = null;
//根据文件类型,采用不同的加载器
try {
if(path.split("[.]")[1].equals("csv")) {
list = readCSV(path,true);
System.out.println("csv文件已加载...");
}else if(path.split("[.]")[1].equals("arff")){
list = readArff(path);
System.out.println("arff文件已加载...");
}else {
throw new Exception("错误: 数据文件类型或路径错误");
}
}catch(Exception e) {
e.printStackTrace();
}
//list返回的第一行是属性集合
String[] feats = (String[]) list.get(0);
//实例化类变量
dataMap = new HashMap<String, List<Object>>();
types = new HashMap<String, String>();
columns = new ArrayList<String>();
//获得列数据
List<Object> col_data ;
int len = feats.length;
for(int i=0;i<len;i++) {
// System.out.println(feats[i]);
col_data = new ArrayList<Object>();
for(Object[] elem : list) {
col_data.add(elem[i]);
}
col_data.remove(0);//第一个是属性值 删去
add(col_data,feats[i]);
}
}
/**
* 利用wekaAPI,读取 arff文件,返回所有的样本数据
* @param path 数据集文件路径
* @return List<Object[]>对象,元素类型为Object数组
* 其中List第一行为属性集合
* @throws Exception
*/
public List<Object[]> readArff(String path) throws Exception{
// 返回的list 第一行是属性信息
List<Object[]> list = new ArrayList<Object[]>();
DataSource soure = new DataSource(path);
Instances data = soure.getDataSet();
//得到属性集合
int numAttributes = data.numAttributes();
String[] attNames = new String[numAttributes];
for(int i=0;i<numAttributes;i++){
attNames[i] = data.attribute(i).name();
}
list.add(attNames);
//依次遍历得到所有的属性值
Object[] row ;
for(Object rowArff: data.toArray()) {
String rowStr = rowArff.toString();
//得到一个样本的值
String[] rowList = rowStr.split(",");
int len = rowList.length;
row = new Object[len];
for(int i=0; i<len; i++) {
try { //如果是数值类型
Double elem = Double.valueOf(rowList[i]);
row[i] = elem;
}catch(Exception e) {
//如果是字符串类型
row[i] = rowList[i];
}
}
list.add(row);
}
return list;
}
/**
* 读取CSV文件,默认csv第一行是属性行
* @param path CSV文件路径
* @return List 数据集的值
* @throws Exception
* @see #readCSV(String,boolean)
*/
public List<Object[]> readCSV(String path) throws Exception {
return readCSV(path, true);
}
/**
* 读取CSV文件,返回数据值
* @param path CSV文件路径
* @param columnsRow CSV第一行时候是列信息
* @return List<Object[]>对象,元素类型为Object数组
* <br/>其中List第一行为属性集合
* @throws Exception
*/
public List<Object[]> readCSV(String path,boolean columnsRow) throws Exception {
List<Object[]> list = new ArrayList<Object[]>();
InputStreamReader reader = new InputStreamReader(new FileInputStream(path),"UTF-8");
@SuppressWarnings("resource")
CSVReader csv = new CSVReader(reader);
List<String[]> csvList = null; //保存CSV文件数据
if(columnsRow) { // 返回的list第一个存储的是属性信息
list.add(csv.readNext());
csvList = csv.readAll();
}else {
csvList = csv.readAll();
}
//同readArff(Stirng)方法
Object[] row;
for(String[] csvRow : csvList) {
int len = csvRow.length;
row = new Object[csvRow.length];
for(int i=0;i<len;i++) {
try {
Double elem =