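Note: this implementation is intended only for experiments and tests on small data sets; it is not suitable for production use.
The algorithm assumes that every attribute column of the training data holds discrete values. Non-discrete data must first be preprocessed and discretized.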
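As a rough illustration of that preprocessing step, the sketch below maps a numeric value to one of a fixed number of equal-width bins and returns the bin label as a string, so that it can be stored in a training tuple. The class name EqualWidthDiscretizer, the method signature, and the "binN" label format are hypothetical; the post does not prescribe a discretization method, and equal-width binning is only one option.

```java
// Hypothetical helper, not part of the original post: equal-width binning.
class EqualWidthDiscretizer {
    /**
     * Maps value to a bin label "bin0" .. "bin{bins-1}" over the range [min, max].
     * Values outside the range are clamped to the first or last bin.
     */
    static String discretize(double value, double min, double max, int bins) {
        if (bins <= 0 || max <= min) {
            throw new IllegalArgumentException("need bins > 0 and max > min");
        }
        double width = (max - min) / bins;
        int index = (int) Math.floor((value - min) / width);
        // Clamp so that value == max (or out-of-range input) still gets a valid bin.
        index = Math.max(0, Math.min(bins - 1, index));
        return "bin" + index;
    }

    public static void main(String[] args) {
        // An assumed numeric range 0..90 split into 3 bins of width 30.
        System.out.println(discretize(25, 0, 90, 3)); // bin0
        System.out.println(discretize(90, 0, 90, 3)); // bin2 (clamped)
    }
}
```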
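The algorithm uses the DecimalCalculate class, an extension of Java's BigDecimal class for high-precision floating-point arithmetic. Its implementation is the same as the Arith class in a blog post I reposted: 对BigDecimal常用方法的归类 (a categorization of commonly used BigDecimal methods).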
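Since that class is not listed in this post, the following is a minimal sketch of the two methods the algorithm calls, div and mul. The signatures are inferred from how the Bayes class calls them (div with two operands and a scale, mul with two operands); the half-up rounding mode and the conversion through Double.toString are assumptions, and the referenced Arith class may differ in detail.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

// Sketch of the helper assumed by the Bayes class; in the post it lives in package "util".
class DecimalCalculate {
    /** Divides v1 by v2 and rounds the quotient to scale decimal places (half-up is an assumption). */
    static double div(double v1, double v2, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("scale must be non-negative");
        }
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.divide(b2, scale, RoundingMode.HALF_UP).doubleValue();
    }

    /** Multiplies v1 by v2 in decimal arithmetic, then converts the product back to double. */
    static double mul(double v1, double v2) {
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.multiply(b2).doubleValue();
    }
}
```

Under these assumptions, div(9, 14, 3) yields 0.643, so every probability in the algorithm is rounded to three decimal places before it is multiplied.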
The code of the implementation follows:
- package Bayes;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.Map;
- import util.DecimalCalculate;
- /**
- * Main Bayes class
- * @author Rowen
- * @qq 443773264
- * @mail luowen3405@163.com
- * @blog blog.csdn.net/luowen3405
- * @date 2011.03.15
- */
- public class Bayes {
- /**
- * Partitions the training tuples by class
- * @param datas training tuples
- * @return Map<class label, training tuples belonging to that class>
- */
- Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){
- Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
- ArrayList<String> t = null;
- String c = "";
- for (int i = 0; i < datas.size(); i++) {
- t = datas.get(i);
- c = t.get(t.size() - 1);
- if (map.containsKey(c)) {
- map.get(c).add(t);
- } else {
- ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>();
- nt.add(t);
- map.put(c, nt);
- }
- }
- return map;
- }
- /**
- * Predicts the class of a test tuple from the training data
- * @param datas training tuples
- * @param testT test tuple
- * @return class of the test tuple
- */
- public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {
- Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);
- Object classes[] = doc.keySet().toArray();
- double maxP = -1.0; // below any probability, so a class is still chosen when every score is 0
- int maxPIndex = -1;
- for (int i = 0; i < doc.size(); i++) {
- String c = classes[i].toString();
- ArrayList<ArrayList<String>> d = doc.get(c);
- double pOfC = DecimalCalculate.div(d.size(), datas.size(), 3);
- for (int j = 0; j < testT.size(); j++) {
- double pv = this.pOfV(d, testT.get(j), j);
- pOfC = DecimalCalculate.mul(pOfC, pv);
- }
- if(pOfC > maxP){
- maxP = pOfC;
- maxPIndex = i;
- }
- }
- return classes[maxPIndex].toString();
- }
- /**
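- * Computes the probability of a given value in a given attribute column
- * @param d training tuples belonging to one class
- * @param value column value
- * @param index index of the attribute column
- * @return probability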
- */
- private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {
- double p = 0.00;
- int count = 0;
- int total = d.size();
- for (int i = 0; i < total; i++) {
- if(d.get(i).get(index).equals(value)){
- count++;
- }
- }
- p = DecimalCalculate.div(count, total, 3);
- return p;
- }
- }
Test class for the algorithm:
- package Bayes;
- import java.io.BufferedReader;
- import java.io.IOException;
- import java.io.InputStreamReader;
- import java.util.ArrayList;
- import java.util.StringTokenizer;
- /**
- * Test class for the Bayes algorithm
- * @author Rowen
- * @qq 443773264
- * @mail luowen3405@163.com
- * @blog blog.csdn.net/luowen3405
- * @date 2011.03.15
- */
- public class TestBayes {
- /**
- * Reads a test tuple
- * @return one test tuple
- * @throws IOException
- */
- public ArrayList<String> readTestData() throws IOException{
- ArrayList<String> candAttr = new ArrayList<String>();
- BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
- String str = "";
- while (!(str = reader.readLine()).equals("")) {
- StringTokenizer tokenizer = new StringTokenizer(str);
- while (tokenizer.hasMoreTokens()) {
- candAttr.add(tokenizer.nextToken());
- }
- }
- return candAttr;
- }
- /**
- * Reads the training tuples
- * @return the set of training tuples
- * @throws IOException
- */
- public ArrayList<ArrayList<String>> readData() throws IOException {
- ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
- BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
- String str = "";
- while (!(str = reader.readLine()).equals("")) {
- StringTokenizer tokenizer = new StringTokenizer(str);
- ArrayList<String> s = new ArrayList<String>();
- while (tokenizer.hasMoreTokens()) {
- s.add(tokenizer.nextToken());
- }
- datas.add(s);
- }
- return datas;
- }
- public static void main(String[] args) {
- TestBayes tb = new TestBayes();
- ArrayList<ArrayList<String>> datas = null;
- ArrayList<String> testT = null;
- Bayes bayes = new Bayes();
- try {
- System.out.println("请输入训练数据"); // prompt: "Please enter the training data"
- datas = tb.readData();
- while (true) {
- System.out.println("请输入测试元组"); // prompt: "Please enter a test tuple"
- testT = tb.readTestData();
- String c = bayes.predictClass(datas, testT);
- System.out.println("The class is: " + c);
- }
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- }
Training data (one tuple per line; the last column is the class label):
- youth high no fair no
- youth high no excellent no
- middle_aged high no fair yes
- senior medium no fair yes
- senior low yes fair yes
- senior low yes excellent no
- middle_aged low yes excellent yes
- youth medium no fair no
- youth low yes fair yes
- senior medium yes fair yes
- youth medium yes excellent yes
- middle_aged medium no excellent yes
- middle_aged high yes fair yes
- senior medium no excellent no
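Testing against the original training data gives the following results (the prompt 请输入测试元组 means "Please enter a test tuple"):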
- 请输入测试元组
- youth high no fair
- The class is: no
- 请输入测试元组
- youth high no excellent
- The class is: no
- 请输入测试元组
- middle_aged high no fair
- The class is: yes
- 请输入测试元组
- senior medium no fair
- The class is: yes
- 请输入测试元组
- senior low yes fair
- The class is: yes
- 请输入测试元组
- senior low yes excellent
- The class is: yes
- 请输入测试元组
- middle_aged low yes excellent
- The class is: yes
- 请输入测试元组
- youth medium no fair
- The class is: no
- 请输入测试元组
- youth low yes fair
- The class is: yes
- 请输入测试元组
- senior medium yes fair
- The class is: yes
- 请输入测试元组
- youth medium yes excellent
- The class is: yes
- 请输入测试元组
- middle_aged medium no excellent
- The class is: yes
- 请输入测试元组
- middle_aged high yes fair
- The class is: yes
- 请输入测试元组
- senior medium no excellent
- The class is: no
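The results show that 13 of the 14 test instances are classified correctly, an accuracy of about 93%. The one error is senior low yes excellent, which is predicted as yes although its training label is no. The algorithm can therefore give reasonably accurate predictions on this data, but it still needs improvement to raise the accuracy.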
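To see where the single error comes from, the sketch below recomputes the unnormalized score P(C) times the product of P(Xk|C) for the tuple senior low yes excellent directly from the 14 training tuples. It uses plain double arithmetic instead of DecimalCalculate, so no rounding to three decimals takes place. The class name ScoreDemo and its members are hypothetical and serve only this illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class, not part of the original post.
class ScoreDemo {
    // The 14 training tuples listed above; the last column is the class label.
    static final String[] ROWS = {
        "youth high no fair no",
        "youth high no excellent no",
        "middle_aged high no fair yes",
        "senior medium no fair yes",
        "senior low yes fair yes",
        "senior low yes excellent no",
        "middle_aged low yes excellent yes",
        "youth medium no fair no",
        "youth low yes fair yes",
        "senior medium yes fair yes",
        "youth medium yes excellent yes",
        "middle_aged medium no excellent yes",
        "middle_aged high yes fair yes",
        "senior medium no excellent no"
    };

    /** Unnormalized naive Bayes score P(C) * prod_k P(x_k | C) for class label c. */
    static double score(String[] test, String c) {
        List<String[]> inClass = new ArrayList<>();
        for (String row : ROWS) {
            String[] t = row.split(" ");
            if (t[t.length - 1].equals(c)) {
                inClass.add(t);
            }
        }
        double p = (double) inClass.size() / ROWS.length; // prior P(C)
        for (int k = 0; k < test.length; k++) {
            int count = 0;
            for (String[] t : inClass) {
                if (t[k].equals(test[k])) {
                    count++;
                }
            }
            p *= (double) count / inClass.size(); // P(x_k | C)
        }
        return p;
    }

    public static void main(String[] args) {
        String[] test = "senior low yes excellent".split(" ");
        System.out.println("score(yes) = " + score(test, "yes"));
        System.out.println("score(no)  = " + score(test, "no"));
    }
}
```

For this tuple the counts give 9/14 · 3/9 · 3/9 · 6/9 · 3/9 ≈ 0.0159 for yes and 5/14 · 2/5 · 1/5 · 1/5 · 3/5 ≈ 0.0034 for no, so yes receives the higher score even though the training label is no.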
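One possible improvement:
A single attribute value can carry too much weight in the classification result. For example, when an attribute value occurs 0 times in some class, that value alone rules out the class for the test instance, which can introduce errors. The probability calculation can therefore be modified as follows:
Replace the original P(Xk|Ci) = |Xk| / |Ci| with P(Xk|Ci) = (|Xk| + mp) / (|Ci| + m). Here |Xk| is the number of tuples in class Ci that take the value Xk, |Ci| is the number of tuples in Ci, m can be set to the number of training tuples, and p is the prior probability under the assumption that all values are equally likely.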
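A minimal sketch of this modified estimate follows. The class name MEstimate, the helper classNo, and the method signature are hypothetical; m and p are supplied by the caller, here m = 14 (the number of training tuples) and p = 1/3 (one reading of the equal-likelihood prior for an attribute with three distinct values).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the smoothed estimate, not part of the original post.
class MEstimate {
    /**
     * Returns (count + m * p) / (total + m), where count is the number of tuples in d
     * whose attribute at index equals value, and total is d.size().
     */
    static double pOfV(List<List<String>> d, String value, int index, double m, double p) {
        int count = 0;
        for (List<String> t : d) {
            if (t.get(index).equals(value)) {
                count++;
            }
        }
        return (count + m * p) / (d.size() + m);
    }

    /** The five training tuples of class "no" from the data above. */
    static List<List<String>> classNo() {
        List<List<String>> d = new ArrayList<>();
        d.add(Arrays.asList("youth", "high", "no", "fair", "no"));
        d.add(Arrays.asList("youth", "high", "no", "excellent", "no"));
        d.add(Arrays.asList("senior", "low", "yes", "excellent", "no"));
        d.add(Arrays.asList("youth", "medium", "no", "fair", "no"));
        d.add(Arrays.asList("senior", "medium", "no", "excellent", "no"));
        return d;
    }

    public static void main(String[] args) {
        // "middle_aged" never occurs in class "no": the plain estimate is 0/5 = 0,
        // while the smoothed estimate is (0 + 14/3) / (5 + 14), which is non-zero.
        System.out.println(pOfV(classNo(), "middle_aged", 0, 14, 1.0 / 3));
    }
}
```

With m = 0 the formula reduces to the original estimate |Xk| / |Ci|.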