朴素贝叶斯分类器【java实现 + 从mysql数据库读数据】

目标:

1、用java写一个贝叶斯分类器,通过一个人的几项特征(性别、是否吸烟、是否纹身、是否戴眼镜、是否骑自行      车)来判断其是否为一个学生。

2、其中训练数据从mysql数据库中读取,测试数据从标准输入输出读取


关于贝叶斯分类器的算法原理很好理解,在此不再赘述。下面是实现:

建立项目,分为3个.java类:

1、       FetchData:读数据,包含两个方法

    a)    从数据库中读入训练数据。

    b)    从标准输入输出中读入测试数据。

2、       Bayes:就是实现贝叶斯分类器的核心算法,先将读入的训练数据按类别分类构造一个map,再对每个类别分别计算先验概率,进而算出贝叶斯公式中的分子(分母都一样故只用比较分子大小即可),取算出概率最大的类别返回即可。(这里了解原理就不难理解)

3、       Main:就是主程序,实现调用方法和一些给用户的提示和相应。

然后我自己设计了少量的数据测试了一下:



执行代码,验证结果如图:




代码如下:

1、FetchData.java

package IsStudent_bys;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class FetchData {

	//连接数据库,读取训练数据
	//输入:数据库
	//输出:可变长数组
	public ArrayList<ArrayList<String>> fetch_traindata(){
		ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();  //待返回
		
		Connection conn;    //与特定数据库的连接(会话)的变量
		String driver = "com.mysql.jdbc.Driver";  //驱动程序名?(待深入了解)
		String url = "jdbc:mysql://localhost:3306/Bayes";  //指向要访问的数据库!注意后面跟的是数据库名称
		String user = "root";   //navicat for sql配置的用户名
		String password = "123456";  //navicat for sql配置的密码
		try{
			Class.forName(driver);  //用class加载动态链接库——驱动程序
			conn = DriverManager.getConnection(url,user,password);  //利用信息链接数据库
			if(!conn.isClosed())
				System.out.println("Succeeded connecting to the Database!");
			
			Statement statement = conn.createStatement();  //用statement 来执行sql语句
			String sql = "select * from TrainData";   //这是sql语句中的查询某个表,注意后面的emp是表名!!!
			ResultSet rs = statement.executeQuery(sql);  //用于返回结果
			
			String str = null;
			while(rs.next()){  //一直读到最后一条表
				ArrayList<String> s= new ArrayList<String>();
				str = rs.getString("Sex");  //分别读取相应栏位的信息加入到可变长数组中
				s.add(str);
				str = rs.getString("tatto");
				s.add(str);
				str = rs.getString("smoking");
				s.add(str);
				str = rs.getString("wearglasses");
				s.add(str);
				str = rs.getString("ridebike");
				s.add(str);
				str = rs.getString("isStudent");
				s.add(str);
				dataSet.add(s);  //加入dataSet
				//System.out.println(s);  输出中间结果调试
			}
			rs.close();
			conn.close();
		}catch(ClassNotFoundException e){  //catch不同的错误信息,并报错
			System.out.println("Sorry,can`t find the Driver!");
			e.printStackTrace();
		}catch(SQLException e){
			e.printStackTrace();
		}catch (Exception e) {
			e.printStackTrace();
		}finally{
			System.out.println("数据库训练数据读取成功!");
		}
		return dataSet;
	}
	
	
	public ArrayList<String> read_testdata(String str) throws IOException  //将用户输入的一整行字符串分割解析成可变长数组
	{
		ArrayList<String> testdata=new ArrayList<String>();  //待返回
		StringTokenizer tokenizer = new StringTokenizer(str);  //这是借鉴学习了StringTokenizer类型(待深入了解)
		while (tokenizer.hasMoreTokens()) { 
			testdata.add(tokenizer.nextToken());
		}
		return testdata;
	}
}

2、Bayes.java

package IsStudent_bys;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class Bayes {

	//按类别分类
	//输入:训练数据(dataSet)
	//输出:类别到训练数据的一个Map
	public Map<String,ArrayList<ArrayList<String>>> classify(ArrayList<ArrayList<String>> dataSet){
		Map<String,ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>(); //待返回的Map
		int num=dataSet.size();
		for(int i=0;i<num;i++)  //遍历所有数据项
		{
			ArrayList<String> Y = dataSet.get(i);  //将第i个训练样本的信息取出
			String Class = Y.get(Y.size()-1).toString();  //约定将类别信息放在最后一个字符串
			
			if(map.containsKey(Class)){  //判断map中是否已经有这个类了
				map.get(Class).add(Y);
			}else{  //若没有这个类就新建一个可变长数组记录并加入map
				ArrayList<ArrayList<String>> nlist = new ArrayList<ArrayList<String>>();
				nlist.add(Y);
				map.put(Class,nlist);
			}
		}
		return map;
	}
	
	//计算分类后每个类对应的样本中某个特征出现的概率
	//输入:某一类别对应的数据(classdata) 目标值(value) 相应的列值(index)
	//输出:该类数据中相应列上的值等于目标值得频率
	public double CalPro_yj_c(ArrayList<ArrayList<String>> classdata, String value, int index){
		int sum = 0;  //sum用于记录相同特征出现的频数
		int num = classdata.size();
		for(int i=0;i<num;i++)
		{
			ArrayList<String> Y = classdata.get(i);
			if(Y.get(index).equals(value)) sum++;  //相同则计数
		}
		return (double)sum/num;  //返回频率,以频率带概率
		
	}
	
	//贝叶斯分类器主函数
	//输入:训练集(可变长数组);待分类集
	//输出:概率最大的类别
	public String bys_Main(ArrayList<ArrayList<String>> dataSet, ArrayList<String> testSet){
		Map<String, ArrayList<ArrayList<String>>> doc = this.classify(dataSet);  //用本class中的分类函数构造映射
		
		Object classes[] = doc.keySet().toArray(); //把map中所有的key取出来(即所有类别) ,借鉴学习了object的使用(待深入了解)
		double Max_Value=0.0; //最大的概率
		int Max_Class=-1;     //用于记录最大类的编号
		for(int i=0;i<doc.size();i++)  //对每一个类分别计算,本程序只有两个类
		{
			String c = classes[i].toString();  //将类提取出
			ArrayList<ArrayList<String>> y = doc.get(c);  //提取该类对应的数据列表
			double prob = (double)y.size()/dataSet.size();  //计算比例
			
			System.out.println(c+" : "+prob);  //输出该类的样本占总样本个数的比例!
			
			for(int j=0;j<testSet.size();j++)  //对每个属性计算先验概率
			{
				double P_yj_c = CalPro_yj_c(y,testSet.get(j),j);
				//输出中间结果以便测试System.out.println("now in bys_Main!!"+P_yj_c);
				prob = prob*P_yj_c;
			}
			
			System.out.printf("P(%s | testcase) * P(testcase) = %f\n",c,prob);  //输出分子的概率大小
			if(prob>Max_Value)  //更新分子最大概率
			{
				Max_Value=prob;
				Max_Class=i;
			}
		}
		return classes[Max_Class].toString();
	}
}

3、Main.java

package IsStudent_bys;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;

public class Main {

	//主函数,读取数据库,并读入待判定数据,输出结果
	public static void main(String[] args) {
		FetchData Fdata = new FetchData();   //java对函数的调用要先声明相应的对象再调用
		Bayes bys = new Bayes();
        ArrayList<ArrayList<String>> dataSet = null; //训练数据列表
        ArrayList<String> testSet = null; //测试数据
        try{
        	System.out.println("从数据库读入训练数据:");
        	dataSet = Fdata.fetch_traindata();   //读取训练数据集合
        	System.out.println("请输入测试数据:"); 
        	Scanner cin = new Scanner(new BufferedInputStream(System.in));  //从标准输入输出中读取测试数据
    		while(cin.hasNext())  //支持多条测试数据读取
    		{
    			String str = cin.nextLine();   //先读入一行
    			testSet = Fdata.read_testdata(str);//将这一行进行字符串分隔解析后返回可变长数组类型
    			//System.out.println(testSet);  //输出中间结果
    			String ans = bys.bys_Main(dataSet, testSet);  //调用贝叶斯分类器
            	if(ans.equals("yes")) System.out.println("Yes!!! He's likely to be a student!");  //输出结果
            	else System.out.println("No!!! It is NOT likely to be a student!!!");
    		}
        	cin.close();
        }catch (IOException e) {  //处理异常
            e.printStackTrace();
        } 
	}

}














  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值