朴素贝叶斯算法实现 | Java | 机器学习 | 贝叶斯

这是一篇关于使用Java实现朴素贝叶斯算法的结课作业介绍。文章详细介绍了设计思路,包括创建数据类DataFrame、朴素贝叶斯类NaiveBayes以及验证模型类Measure。NaiveBayes类负责计算概率并进行预测,而Measure类则用于评估模型性能。代码清单包含关键的Java文件,如DataFrame、NaiveBayes、DataUtil和Test。
摘要由CSDN通过智能技术生成

做的一个结课作业,用Java实现了朴素贝叶斯算法

关于贝叶斯算法可以参考西瓜书、贝叶斯分类器-华校专

设计思路

  1. 创建数据类、朴素贝叶斯算法类和验证模型的类。
  2. 数据类DataFrame,用来加载和存储数据信息,并提供操作数据的方法,以及返回数据的方法。
  3. 朴素贝叶斯类NaiveBayes,贝叶斯算法算法的核心,提供拟合训练数据的方法,以及预测测试数据输出预测分类结果的方法,里面包含了计算先验概率、条件概率、后验概率等的方法。
  4. 验证模型类Measure,通过提供的真实标记数据和预测标记数据,提供计算预测结果的精度、混淆矩阵、查全率、查准率的方法。
  5. 其他类DataUtil,提供分割数据集的静态方法,和List与Map中Object对象转化为String对象的方法。

其他具体设计请参照代码清单和注释


先看下包结构和运行结果

注:图中蓝框的是需要的,ml.data其他的类构建好,ml.test用来测试

注:以上为数据集和验证结果,采用的验证方式为 用训练集作为测试集

 

代码清单

你可能需要下载 weka.jar 、opencsv-4.6.jar 第三方包

ml.data.DataFrame.java

package ml.data;

import java.util.*;
import com.opencsv.CSVReader;

import java.io.*;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class DataFrame {
	
	/** 保存数据集的值,Map的key是属性名称,value是以List存储的属性值*/
	/** Object对象,可存储String和Double*/
	private Map<String,List<Object>> dataMap;
	
	/** 存储所有属性名称*/
	private List<String> columns;
	
	/** 存储属性的名称和对应的类型*/
	private Map<String,String> types;
	
	/** 存储数据集大小,size[0]为行数,size[1]为列数 */
	private int[] size;
	
	
	/**
	 * 传入一个Map<String,List<Bbject>>数据,key作为属性名
	 * @param data
	 */
	public DataFrame(Map<String,List<Object>> data) {
		dataMap = data;
		columns = new ArrayList<String>(dataMap.keySet());
	}
	
	
	/**
	 * 根据输入的文件路径和列名称,加载数据
	 * @param path 文件路径
	 * @param columns 指定数据集的属性
	 * @throws Exception
	 */
	public DataFrame(String path,List<String> columns) throws Exception {
		this(path);
		this.columns = columns;
	}
	
	
	/**
	 * 加载指定数据集文件 
	 * 仅支持 CSV 和  ARFF文件
	 * @param path 数据集文件路径
	 * @throws Exception 文件类型不支持
	 */
	public DataFrame(String path) throws Exception {
		
		List<Object[]> list = null;
		//根据文件类型,采用不同的加载器
		try {
			if(path.split("[.]")[1].equals("csv")) {
				list = readCSV(path,true);
				System.out.println("csv文件已加载...");
			}else if(path.split("[.]")[1].equals("arff")){
				list = readArff(path);
				System.out.println("arff文件已加载...");
			}else {
				throw new Exception("错误: 数据文件类型或路径错误");
			}
		}catch(Exception e) {
			e.printStackTrace();
		}
		//list返回的第一行是属性集合
		String[] feats = (String[]) list.get(0);
		//实例化类变量
		dataMap = new HashMap<String, List<Object>>();
		types = new HashMap<String, String>();
		columns = new ArrayList<String>();
		
		//获得列数据
		List<Object> col_data ;
		int len = feats.length;		
		for(int i=0;i<len;i++) {
//			System.out.println(feats[i]);
			col_data = new ArrayList<Object>();
			for(Object[] elem : list) {
				col_data.add(elem[i]);
			}
			col_data.remove(0);//第一个是属性值 删去
			add(col_data,feats[i]);
		}
		
	}
	
	/**
	 * 利用wekaAPI,读取 arff文件,返回所有的样本数据
	 * @param path 数据集文件路径
	 * @return List<Object[]>对象,元素类型为Object数组
	 * 			其中List第一行为属性集合
	 * @throws Exception
	 */
	public List<Object[]> readArff(String path) throws Exception{
		// 返回的list 第一行是属性信息
		List<Object[]> list = new ArrayList<Object[]>();
		DataSource soure = new DataSource(path);
		Instances data = soure.getDataSet();
		
		//得到属性集合
		int numAttributes = data.numAttributes();
		String[] attNames = new String[numAttributes];
		for(int i=0;i<numAttributes;i++){
			attNames[i] = data.attribute(i).name();
		}
		list.add(attNames);
		
		//依次遍历得到所有的属性值
		Object[] row ;
		for(Object rowArff: data.toArray()) {
			String rowStr = rowArff.toString();
			//得到一个样本的值
			String[] rowList = rowStr.split(",");
			int len = rowList.length;
			row = new Object[len];
			for(int i=0; i<len; i++) {
				try { //如果是数值类型
					Double elem = Double.valueOf(rowList[i]);
					row[i] = elem;
				}catch(Exception e) {
					//如果是字符串类型
					row[i] = rowList[i];
				}
			}
			list.add(row);
		}
		return list;
	}
	
	
	
	/**
	 * 读取CSV文件,默认csv第一行是属性行
	 * @param path CSV文件路径
	 * @return List 数据集的值
	 * @throws Exception
	 * @see #readCSV(String,boolean)
	 */
	public List<Object[]> readCSV(String path) throws Exception {
		return readCSV(path, true);
	}
	
	
	/**
	 * 读取CSV文件,返回数据值
	 * @param path CSV文件路径
	 * @param columnsRow CSV第一行时候是列信息
	 * @return List<Object[]>对象,元素类型为Object数组
	 * 			<br/>其中List第一行为属性集合
	 * @throws Exception
	 */
	public List<Object[]> readCSV(String path,boolean columnsRow) throws Exception {
		
		List<Object[]> list = new ArrayList<Object[]>();
		InputStreamReader reader = new InputStreamReader(new FileInputStream(path),"UTF-8");
		@SuppressWarnings("resource")
		CSVReader csv = new CSVReader(reader);
		List<String[]> csvList = null;	//保存CSV文件数据
		
		if(columnsRow) { 				// 返回的list第一个存储的是属性信息	
			list.add(csv.readNext()); 		
			csvList = csv.readAll();
		}else {			
			csvList = csv.readAll();
		}
		
		//同readArff(Stirng)方法
		Object[] row;
		for(String[] csvRow : csvList) {
			int len = csvRow.length;
			row = new Object[csvRow.length];
			for(int i=0;i<len;i++) {
				try {
					Double elem =
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值