朴素贝叶斯文本分类java实现

转载 2015年05月21日 14:41:58

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.data.util.IoUtil;

public class NativeBayes {
    /**
     * 默认频率
     */
    private double defaultFreq = 0.1;

    /**
     * 训练数据的比例
     */
    private Double trainingPercent = 0.8;

    private Map<String, List<String>> files_all = new HashMap<String, List<String>>();

    private Map<String, List<String>> files_train = new HashMap<String, List<String>>();

    private Map<String, List<String>> files_test = new HashMap<String, List<String>>();

    public NativeBayes() {

    }

    /**
     * 每个分类的频率
     */
    private Map<String, Integer> classFreq = new HashMap<String, Integer>();

    private Map<String, Double> ClassProb = new HashMap<String, Double>();

    /**
     * 特征总数
     */
    private Set<String> WordDict = new HashSet<String>();

    private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>();

    private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>();

    private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>();

    /**
     * 计算准确率
     * @param reallist 真实类别
     * @param pridlist 预测类别
     */
    public void Evaluate(List<String> reallist, List<String> pridlist){
        double correctNum = 0.0;
        for (int i = 0; i < reallist.size(); i++) {
            if(reallist.get(i) == pridlist.get(i)){
                correctNum += 1;
            }
        }
        double accuracy = correctNum / reallist.size();
        System.out.println("准确率为:" + accuracy);
    }

    /**
     * 计算精确率和召回率
     * @param reallist
     * @param pridlist
     * @param classname
     */
    public void CalPreRec(List<String> reallist, List<String> pridlist, String classname){
        double correctNum = 0.0;
        double allNum = 0.0;//测试数据中,某个分类的文章总数
        double preNum = 0.0;//测试数据中,预测为该分类的文章总数

        for (int i = 0; i < reallist.size(); i++) {
            if(reallist.get(i) == classname){
                allNum += 1;
                if(reallist.get(i) == pridlist.get(i)){
                    correctNum += 1;
                }
            }
            if(pridlist.get(i) == classname){
                preNum += 1;
            }
        }
        System.out.println(classname + " 精确率(跟预测分类比较):" + correctNum / preNum + " 召回率(跟真实分类比较):" + correctNum / allNum);
    }

    /**
     * 用模型进行预测
     */
    public void PredictTestData() {
        List<String> reallist=new ArrayList<String>();
        List<String> pridlist=new ArrayList<String>();

        for (Entry<String, List<String>> entry : files_test.entrySet()) {
            String realclassname = entry.getKey();
            List<String> files = entry.getValue();


            for (String file : files) {
                reallist.add(realclassname);


                List<String> classnamelist=new ArrayList<String>();
                List<Double> scorelist=new ArrayList<Double>();
                for (Entry<String, Double> entry_1 : ClassProb.entrySet()) {
                    String classname = entry_1.getKey();
                    //先验概率
                    Double score = Math.log(entry_1.getValue());

                    String[] words = IoUtil.readFromFile(new File(file)).split(" ");
                    for (String word : words) {
                        if(!WordDict.contains(word)){
                            continue;
                        }

                        if(ClassFeaProb.get(classname).containsKey(word)){
                            score += Math.log(ClassFeaProb.get(classname).get(word));
                        }else{
                            score += Math.log(ClassDefaultProb.get(classname));
                        }
                    }

                    classnamelist.add(classname);
                    scorelist.add(score);
                }

                Double maxProb = Collections.max(scorelist);
                int idx = scorelist.indexOf(maxProb);
                pridlist.add(classnamelist.get(idx));
            }
        }

        Evaluate(reallist, pridlist);

        for (String cname : files_test.keySet()) {
            CalPreRec(reallist, pridlist, cname);
        }

    }

    /**
     * 模型训练
     */
    public void createModel() {
        double sum = 0.0;
        for (Entry<String, Integer> entry : classFreq.entrySet()) {
            sum+=entry.getValue();
        }
        for (Entry<String, Integer> entry : classFreq.entrySet()) {
            ClassProb.put(entry.getKey(), entry.getValue()/sum);
        }


        for (Entry<String, Map<String, Integer>> entry : classFeaFreq.entrySet()) {
            sum = 0.0;
            String classname = entry.getKey();
            for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
                sum += entry_1.getValue();
            }
            double newsum = sum + WordDict.size()*defaultFreq;

            Map<String, Double> feaProb = new HashMap<String, Double>();
            ClassFeaProb.put(classname, feaProb);

            for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
                String word = entry_1.getKey();
                feaProb.put(word, (entry_1.getValue() +defaultFreq) /newsum);
            }
            ClassDefaultProb.put(classname, defaultFreq/newsum);
        }
    }

    /**
     * 加载训练数据
     */
    public void loadTrainData(){
        for (Entry<String, List<String>> entry : files_train.entrySet()) {
            String classname = entry.getKey();
            List<String> docs = entry.getValue();

            classFreq.put(classname, docs.size());

            Map<String, Integer> feaFreq = new HashMap<String, Integer>();
            classFeaFreq.put(classname, feaFreq);

            for (String doc : docs) {
                String[] words = IoUtil.readFromFile(new File(doc)).split(" ");
                for (String word : words) {

                    WordDict.add(word);

                    if(feaFreq.containsKey(word)){
                        int num = feaFreq.get(word) + 1;
                        feaFreq.put(word, num);
                    }else{
                        feaFreq.put(word, 1);
                    }
                }
            }    


        }
        System.out.println(classFreq.size()+" 分类, " + WordDict.size()+" 特征词");
    }

    /**
     * 将数据分为训练数据和测试数据
     * 
     * @param dataDir
     */
    public void splitData(String dataDir) {
        // 用文件名区分类别
        Pattern pat = Pattern.compile("\\d+([a-z]+?)\\.");
        dataDir = "testdata/allfiles";
        File f = new File(dataDir);
        File[] files = f.listFiles();
        for (File file : files) {
            String fname = file.getName();
            Matcher m = pat.matcher(fname);
            if (m.find()) {
                String cname = m.group(1);
                if (files_all.containsKey(cname)) {
                    files_all.get(cname).add(file.toString());
                } else {
                    List<String> tmp = new ArrayList<String>();
                    tmp.add(file.toString());
                    files_all.put(cname, tmp);
                }
            } else {
                System.out.println("err: " + file);
            }
        }

        System.out.println("统计数据:");
        for (Entry<String, List<String>> entry : files_all.entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            // System.out.println(cname + " : " + value.size());

            List<String> train = new ArrayList<String>();
            List<String> test = new ArrayList<String>();

            for (String str : value) {
                if (Math.random() <= trainingPercent) {// 80%用来训练 , 20%测试
                    train.add(str);
                } else {
                    test.add(str);
                }
            }

            files_train.put(cname, train);
            files_test.put(cname, test);
        }

        System.out.println("所有文件数:");
        printStatistics(files_all);
        System.out.println("训练文件数:");
        printStatistics(files_train);
        System.out.println("测试文件数:");
        printStatistics(files_test);

    }

    /**
     * 打印统计信息
     * 
     * @param m
     */
    public void printStatistics(Map<String, List<String>> m) {
        for (Entry<String, List<String>> entry : m.entrySet()) {
            String cname = entry.getKey();
            List<String> value = entry.getValue();
            System.out.println(cname + " : " + value.size());
        }
        System.out.println("--------------------------------");
    }

    public static void main(String[] args) {
        NativeBayes bayes = new NativeBayes();
        bayes.splitData(null);
        bayes.loadTrainData();
        bayes.createModel();
        bayes.PredictTestData();

    }

}

所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 791
auto : 812
business : 808
--------------------------------
测试文件数:
sports : 227
auto : 208
business : 220
--------------------------------
分类, 39613 特征词
准确率为:0.9801526717557252
sports 精确率(跟预测分类比较):0.9956140350877193 召回率(跟真实分类比较):1.0
auto 精确率(跟预测分类比较):0.9579439252336449 召回率(跟真实分类比较):0.9855769230769231
business 精确率(跟预测分类比较):0.9859154929577465 召回率(跟真实分类比较):0.9545454545454546

统计数据:
所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 827
auto : 833
business : 825
--------------------------------
测试文件数:
sports : 191
auto : 187
business : 203
--------------------------------
分类, 39907 特征词
准确率为:0.9759036144578314
sports 精确率(跟预测分类比较):0.9894736842105263 召回率(跟真实分类比较):0.9842931937172775
auto 精确率(跟预测分类比较):0.9836956521739131 召回率(跟真实分类比较):0.9679144385026738
business 精确率(跟预测分类比较):0.9565217391304348 召回率(跟真实分类比较):0.9753694581280788

相关文章推荐

文本分类算法之--贝叶斯分类算法的实现Java版本

package com.vista; import java.io.IOException;       import jeasy.analysis.MMAnalyzer; /** * 中...

java实现文本分类中卡方特征选择

java在文本分类中卡方的特征选择,  在文本分类的特征选择阶段,一般使用“词汇t与类别c不相关”来做原假设,计算出的开方值越大,说明对原假设的偏离越大,我们越倾向于认为原 假设的反面...

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载

数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文本分类器的JAVA实现(上)

(update 2012.12.28 关于本项目下载及运行的常见问题 FAQ见 newsgroup18828文本分类器、文本聚类器、关联分析频繁模式挖掘算法的Java实现工程下载及运行FAQ ) ...

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载

基于的朴素贝叶斯的文本分类(附完整代码(spark/java)

本文主要包括以下内容: 1)模型训练数据生成(demo) 2 ) 模型训练(spark+java),数据存储在hdfs上 3)预测数据生成(demo) 4)使用生成的模型进行文本分类。...

朴素贝叶斯文本分类算法java实现

在学习了朴素贝叶斯的概念后,下来我们来看看它的java实现。 有一个网友已经实现了其java的算法,具体详见: 数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文本分类器...

Naive Bayes 朴素贝叶斯(文本)分类器Java实现

参照:《机器学习实战》 算法原理推导 伪代码 java实现代码 算法原理推导优缺点分析优点:在数据较少的情况下,仍然有效,可以处理多分类问题 缺点:对于输入数据的准备方式比较敏感 适用数据类型:标...

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:朴素贝叶斯文本分类java实现
举报原因:
原因补充:

(最多只允许输入30个字)