朴素贝叶斯算法,相信做数据挖掘和推荐系统的小伙伴们都已经耳熟能详了,算法原理我就不啰嗦了。这里主要想通过 Java 代码实现朴素贝叶斯算法,思路如下:
1. 用 JavaBean + ArrayList 存储训练数据
2. 对样本数据进行训练(统计各属性在各类别下出现的次数),再对未知样本进行判断
具体的代码如下:
- package NB;
/**
 * One training sample for the naive Bayes example: four attributes
 * (age, income, student, credit rating) plus the class label
 * (buys_computer).
 *
 * <p>Plain mutable bean. Fields are private so that all access goes through
 * the getters and setters below.
 */
public class JavaBean {
    /** Age of the customer, stored as the raw integer read from the data. */
    private int age;
    /** Income attribute value. */
    private String income;
    /** Student attribute value. */
    private String student;
    /** Credit rating attribute value. */
    private String credit_rating;
    /** Class label: whether the customer bought a computer. */
    private String buys_computer;

    /** Creates an empty sample: age is 0 and every string attribute is {@code null}. */
    public JavaBean() {
    }

    /**
     * Creates a fully populated sample.
     *
     * @param age           customer age
     * @param income        income attribute value
     * @param student       student attribute value
     * @param credit_rating credit rating attribute value
     * @param buys_computer class label
     */
    public JavaBean(int age, String income, String student, String credit_rating, String buys_computer) {
        this.age = age;
        this.income = income;
        this.student = student;
        this.credit_rating = credit_rating;
        this.buys_computer = buys_computer;
    }

    /** @return the customer age */
    public int getAge() {
        return age;
    }

    /** @param age the customer age */
    public void setAge(int age) {
        this.age = age;
    }

    /** @return the income attribute value */
    public String getIncome() {
        return income;
    }

    /** @param income the income attribute value */
    public void setIncome(String income) {
        this.income = income;
    }

    /** @return the student attribute value */
    public String getStudent() {
        return student;
    }

    /** @param student the student attribute value */
    public void setStudent(String student) {
        this.student = student;
    }

    /** @return the credit rating attribute value */
    public String getCredit_rating() {
        return credit_rating;
    }

    /** @param credit_rating the credit rating attribute value */
    public void setCredit_rating(String credit_rating) {
        this.credit_rating = credit_rating;
    }

    /** @return the class label */
    public String getBuys_computer() {
        return buys_computer;
    }

    /** @param buys_computer the class label */
    public void setBuys_computer(String buys_computer) {
        this.buys_computer = buys_computer;
    }

    /** @return all five fields in the form {@code JavaBean [age=..., income=..., ...]} */
    @Override
    public String toString() {
        return "JavaBean [age=" + age + ", income=" + income + ", student="
                + student + ", credit_rating=" + credit_rating + ", buys_computer="
                + buys_computer + "]";
    }
}
- package NB;
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.util.ArrayList;
- public class TestNB {
- /**data_length
- * 算法的思想
- */
- public static ArrayList<JavaBean> list = new ArrayList<JavaBean>();;
- static int data_length=0;
- public static void main(String[] args) {
- // 1.读取数据,放入list容器中
- File file = new File("E://test.txt");
- txt2String(file);
- //数据测试样本
- testData(25,"Medium","Yes","Fair");
- }
- // 读取样本数据
- public static void txt2String(File file) {
- try {
- BufferedReader br = new BufferedReader(new FileReader(file));// 构造一个BufferedReader类来读取文件
- String s = null;
- while ((s = br.readLine()) != null) {// 使用readLine方法,一次读一行
- data_length++;
- splitt(s);
- }
- br.close();
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- // 存入ArrayList中
- public static void splitt(String str){
- String strr = str.trim();
- String[] abc = strr.split("[\\p{Space}]+");
- int age=Integer.parseInt(abc[0]);
- JavaBean bean=new JavaBean(age, abc[1], abc[2], abc[3], abc[4]);
- list.add(bean);
- }
- // 训练样本,测试
- public static void testData(int age,String a,String b,String c){
- //训练样本
- int number_yes=0;
- int bumber_no=0;
- // age情况 个数
- int num_age_yes=0;
- int num_age_no=0;
- // income
- int num_income_yes=0;
- int num_income_no=0;
- // student
- int num_student_yes=0;
- int num_stdent_no=0;
- //credit
- int num_credit_yes=0;
- int num_credit_no=0;
- //遍历List 获得数据
- for(int i=0;i<list.size();i++){
- JavaBean bb=list.get(i);
- if(bb.getBuys_computer().equals("Yes")){ //Yes
- number_yes++;
- if(bb.getIncome().equals(a)){//income
- num_income_yes++;
- }
- if(bb.getStudent().equals(b)){//student
- num_student_yes++;
- }
- if(bb.getCredit_rating().equals(c)){//credit
- num_credit_yes++;
- }
- if(bb.getAge()==age){//age
- num_age_yes++;
- }
- }else {//No
- bumber_no++;
- if(bb.getIncome().equals(a)){//income
- num_income_no++;
- }
- if(bb.getStudent().equals(b)){//student
- num_stdent_no++;
- }
- if(bb.getCredit_rating().equals(c)){//credit
- num_credit_no++;
- }
- if(bb.getAge()==age){//age
- num_age_no++;
- }
- }
- }
- System.out.println("购买的历史个数:"+number_yes);
- System.out.println("不买的历史个数:"+bumber_no);
- System.out.println("购买+age:"+num_age_yes);
- System.out.println("不买+age:"+num_age_no);
- System.out.println("购买+income:"+num_income_yes);
- System.out.println("不买+income:"+num_income_no);
- System.out.println("购买+stundent:"+num_student_yes);
- System.out.println("不买+student:"+num_stdent_no);
- System.out.println("购买+credit:"+num_credit_yes);
- System.out.println("不买+credit:"+num_credit_no);
- 概率判断
- double buy_yes=number_yes*1.0/data_length; // 买的概率
- double buy_no=bumber_no*1.0/data_length; // 不买的概率
- System.out.println("训练数据中买的概率:"+buy_yes);
- System.out.println("训练数据中不买的概率:"+buy_no);
- /// 未知用户的判断
- double nb_buy_yes=(1.0*num_age_yes/number_yes)*(1.0*num_income_yes/number_yes)*(1.0*num_student_yes/number_yes)*(1.0*num_credit_yes/number_yes)*buy_yes;
- double nb_buy_no=(1.0*num_age_no/bumber_no)*(1.0*num_income_no/bumber_no)*(1.0*num_stdent_no/bumber_no)*(1.0*num_credit_no/bumber_no)*buy_no;
- System.out.println("新用户买的概率:"+nb_buy_yes);
- System.out.println("新用户不买的概率:"+nb_buy_no);
- if(nb_buy_yes>nb_buy_no){
- System.out.println("新用户买的概率大");
- }else {
- System.out.println("新用户不买的概率大");
- }
- }
- }
样本数据(即 E://test.txt 的内容,各列依次为 age、income、student、credit_rating、buys_computer):
- 25 High No Fair No
- 25 High No Excellent No
- 33 High No Fair Yes
- 41 Medium No Fair Yes
- 41 Low Yes Fair Yes
- 41 Low Yes Excellent No
- 33 Low Yes Excellent Yes
- 25 Medium No Fair No
- 25 Low Yes Fair Yes
- 41 Medium Yes Fair Yes
- 25 Medium Yes Excellent Yes
- 33 Medium No Excellent Yes
- 33 High Yes Fair Yes
- 41 Medium No Excellent No
对未知用户(age=25,income=Medium,student=Yes,credit_rating=Fair)运行程序得到的结果:
- 购买的历史个数:9
- 不买的历史个数:5
- 购买+age:2
- 不买+age:3
- 购买+income:4
- 不买+income:2
- 购买+stundent:6
- 不买+student:1
- 购买+credit:6
- 不买+credit:2
- 训练数据中买的概率:0.6428571428571429
- 训练数据中不买的概率:0.35714285714285715
- 新用户买的概率:0.028218694885361547
- 新用户不买的概率:0.006857142857142858
- 新用户买的概率大