简单实现朴素贝叶斯分类

package pusu;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Scanner;

import org.apdplat.word.WordSegmenter;
import org.apdplat.word.segmentation.Word;

 public class boss {
    static ArrayList<String> pe_term=new ArrayList<>();//记录从文件中出来的词项
    static ArrayList<String> hs_term=new ArrayList<>();
    static ArrayList<String> term = new ArrayList<String>();//库词项
    static ArrayList<forterm> aimToFile=new ArrayList<forterm>();
    static ArrayList<selecte_term> slt_term=new ArrayList<selecte_term>();
    public static void main(String[] args) throws IOException {
        int a=pe_readfile();
        int b=hs_readfile();
        System.out.println("文档集有: "+(a+b)+" 个");
         all_term(term);
//         System.out.println("all_term");
         removeSameElement(term);
//         System.out.println("removeSameElement");
         removeStopElement(term);
//         System.out.println("removeStopElement");
         statistic();
//         System.out.println("statistic");
         getPe_peprior_probability();
//         System.out.println("getPe_peprior_probability");
         precision(term,aimToFile);
         save();
         System.out.println("save");
         //selecte();
    }

    //******************************将体育所有文档中词项提出来    
    public static int  pe_readfile() throws IOException{
        
        File[] file=new File("tiyu").listFiles();
        for(int i=0;i<file.length;i++)
        {
            String temp=null;
            BufferedReader pe_file=new BufferedReader(new FileReader(file[i].getAbsolutePath())); 
            while((temp = pe_file.readLine()) != null) {
                pe_term.add(temp);
            }
            pe_file.close();
        }
        return (file.length);
    }
//******************************将历史所有文档中词项提出来    
    public static int  hs_readfile() throws IOException{
        File[] file=new File("lishi").listFiles();
        for(int i=0;i<file.length;i++)
        {
            String temp=null;
            BufferedReader hs_file=new BufferedReader(new FileReader(file[i].getAbsolutePath())); 
            while((temp = hs_file.readLine()) != null) {
                hs_term.add(temp);
            }
            hs_file.close();
        }
        return (file.length);
    }
//将所有词项合并在一个词项库中
    public static void all_term(ArrayList<String> term) {
        for(int i = 0; i < pe_term.size(); i++) {
            term.add(pe_term.get(i));
        }
        for(int i = 0; i < hs_term.size(); i++) {
            term.add(hs_term.get(i));
        }
    }
//*************************************将词项中重复的删掉
     public static ArrayList<String> removeSameElement(ArrayList<String> term){
         ArrayList<String> temp = new ArrayList<>();
         Iterator it = term.iterator();
         while(it.hasNext()) {
         Object o = (Object)it.next();
         if(!temp.contains(o))
         temp.add((String)o);
         }
         return temp;
     }
//**************************************去掉停用词
     public static ArrayList<String> removeStopElement(ArrayList<String> term) throws IOException{
         ArrayList<String> SW=new ArrayList<>();
         File file_WD=new File("停用词表.txt");
         Scanner file3=new Scanner(file_WD);
         while(file3.hasNextLine()){
                SW.add(file3.nextLine());
            }
         ArrayList<String> temp = new ArrayList<>();
         for(int i=0;i<term.size();i++) {
             if(term.get(i).compareTo("0")!=0)
             {
                 Object o = (Object)term.get(i);
                 temp.add((String)o);
             }
         }
         return temp;
     }
//********************************计算先验概率
    public static double[] getPe_peprior_probability() throws IOException {
        int pe_file_num=pe_readfile();
        int hs_file_num=hs_readfile();
        double[] peprior_probability= {(double)pe_file_num/(hs_file_num+pe_file_num),(double)hs_file_num/(hs_file_num+pe_file_num)};
        System.out.println("历史的先验概率  :"+peprior_probability[0]);
        System.out.println("体育的先验概率  :"+peprior_probability[1]);
        return peprior_probability;
    }
 
//********************************统计计算
    public static void statistic() {
        int[] pe_count=new int[term.size()];
        int[] hs_count=new int[term.size()];
        double[] pe_frequence=new double[term.size()];
        double[] hs_frequence=new double[term.size()];
        for(int i=0;i<term.size();i++){   //初始化
            pe_count[i]=1;
            hs_count[i]=1;
        }
        //计算词出现的次数
        for(int i=0;i<term.size();i++)
            for(int j=0;j<pe_term.size();j++)
                if(pe_term.get(j).compareTo(term.get(i))==0) 
                    pe_count[i]++;
        for(int i=0;i<term.size();i++)
            for(int j=0;j<hs_term.size();j++)
                if(hs_term.get(j).compareTo(term.get(i))==0) 
                    hs_count[i]++;
        //计算频率
        int sum1=0,sum2=0;
        for(int i=0;i<term.size();i++){
            sum1+=pe_count[i];
            sum2+=hs_count[i];
        }
        for(int i=0;i<term.size();i++){
            pe_frequence[i]=(double)pe_count[i]/sum1;
            hs_frequence[i]=(double)hs_count[i]/sum2;
        }
        for(int i=0;i<term.size();i++){
            forterm o=new forterm();
            o.setE_term(term.get(i));
            o.setPe_num(pe_count[i]);
            o.setHs_num(hs_count[i]);
            o.setPe_frequence(pe_frequence[i]);
            o.setHs_frequence(hs_frequence[i]);
            //all_count++;
            aimToFile.add(o);
        }
    }
//***************************************计算正确率、召回率
    public static void precision(ArrayList<String> term,ArrayList<forterm> aimToFile) throws IOException {
        ArrayList<ArrayList<String>> PE=new ArrayList<ArrayList<String>>();
        ArrayList<ArrayList<String>> HS=new ArrayList<ArrayList<String>>();
        double[] pro=new double[2];
        pro=getPe_peprior_probability();
        
        File[] pe_file=new File("测试体育").listFiles();
        for(int i=0;i<pe_file.length;i++)
        {
            ArrayList<String> pe_term = new ArrayList<String>();
            File txt_file=new File("测试体育" + "/" + pe_file[i].getName());
            if (!txt_file.exists()) {
                System.out.println("wrong");
            }
                Scanner file2=new Scanner(pe_file[i]); 
                while(file2.hasNextLine()) {
                    pe_term.add(file2.nextLine());
                }
                PE.add(pe_term);
                file2.close();
        }
        
        File[] hs_file=new File("测试历史").listFiles();
        for(int i=0;i<hs_file.length;i++)
        {
            ArrayList<String>hs_term = new ArrayList<String>();
            File txt_file=new File("测试历史" + "/" + hs_file[i].getName());
                Scanner file1=new Scanner(hs_file[i]); 
                while(file1.hasNextLine()) {
                    hs_term.add(file1.nextLine());
                    
                }
                HS.add(hs_term);
                file1.close();
        }
        int computer_pe=0;
        int computer_hs=0;
        int real_pe=0;
        int real_hs=0;
        int all_pefile=PE.size();
        int all_hsfile=HS.size();
        for(int i=0;i<PE.size();i++) {
            double pe_frequence=0;
            double hs_frequence=0;
            for(int j=0;j<PE.get(i).size();j++) {
                for(int k=0;k<term.size();k++) {
                    if(PE.get(i).get(j).compareTo(aimToFile.get(k).getE_term())==0){
                        pe_frequence+=Math.log(aimToFile.get(k).getPe_frequence());
                        hs_frequence+=Math.log(aimToFile.get(k).getHs_frequence());
                    }
                }
            }
//            每一个文档即刻判断
//            pe_frequence=pe_frequence*();
//            hs_frequence=hs_frequence*(o.getHs_peprior_probability());
            if(pe_frequence>hs_frequence) {
                computer_pe++;//计算机判断出来体育多少个
                real_pe++;
            }
            if(pe_frequence<hs_frequence) {//计算机判断出来体育多少个
                computer_hs++;
            }
            
        }
        for(int i=0;i<HS.size();i++) {
            double pe_frequence=0;
            double hs_frequence=0;
            for(int j=0;j<HS.get(i).size();j++) {
                for(int k=0;k<term.size();k++) {
                    if(HS.get(i).get(j).compareTo(aimToFile.get(k).getE_term())==0){
                        pe_frequence+=Math.log(aimToFile.get(k).getPe_frequence());
                        hs_frequence+=Math.log(aimToFile.get(k).getHs_frequence());
                    }
                }
            }
            //每一个文档即刻判断
//            pe_frequence=pe_frequence*(o.getPe_peprior_probability());
//            hs_frequence=hs_frequence*(o.getHs_peprior_probability());
            if(pe_frequence>hs_frequence) {
                computer_pe++;//计算机判断出来体育多少个
            }
            if(pe_frequence<hs_frequence) {//计算机判断出来体育多少个
                computer_hs++;
                real_hs++;
            }
        }
        System.out.println("体育正确率   :"+(double)real_pe/computer_pe);
        System.out.println("历史正确率   : "+(double)real_hs/computer_hs);
        System.out.println("体育召回率   :"+(double)real_pe/all_pefile);
        System.out.println("历史召回率  : "+(double)real_hs/all_hsfile);
    }
//**************************************将词项存到文档中去
    public static void save() throws IOException {
            File save=new File("allterm(1).txt");
        PrintWriter file=new PrintWriter(save);
        for(int i=0;i<aimToFile.size();i++){
            file.printf("%s   %d   %d   %.6f   %.6f",aimToFile.get(i).getE_term(),aimToFile.get(i).getPe_num(),aimToFile.get(i).getHs_num(),aimToFile.get(i).getPe_frequence(),aimToFile.get(i).getHs_frequence());
            file.println();
            if(i==aimToFile.size()-1) {
                file.printf("%s   %d   %d   %.6f   %.6f",aimToFile.get(i).getE_term(),aimToFile.get(i).getPe_num(),aimToFile.get(i).getHs_num(),aimToFile.get(i).getPe_frequence(),aimToFile.get(i).getHs_frequence());
            }
        }
        file.close();
    }
    
    public static void start_term(ArrayList<String> term) {
        for(int i=0;i<term.size();i++) {
            selecte_term o=new selecte_term(term.get(i));//存入,并且初始化
            slt_term.add(o);
        }
    }
    public static void selecte() throws IOException {
        start_term(term);//调用此函数
        File[] file1=new File("tiyu").listFiles();
        File[] file2=new File("lishi").listFiles();
        for(int i=0;i<slt_term.size();i++) {//每一个词遍历整个学习文档集
            for(int j=0;j<file1.length;j++)
            {
                String temp=null;
                BufferedReader pe_file=new BufferedReader(new FileReader(file1[j].getAbsolutePath())); 
                ArrayList<String> everyterm=new ArrayList<String>();
                while((temp = pe_file.readLine()) != null) {
                    everyterm.add(temp);
                }
                pe_file.close();
                
                for(int k=0;k<everyterm.size();k++) {
                             //如果这个词在这个01的文档中,并且这个文档属于体育类
                            if(slt_term.get(i).get_name().compareTo(everyterm.get(k))==0) {
                                slt_term.get(i).add_N11();
                            }
                            //这个词不在01这个文档中,但是这个文档属于类别体育
                            else if(slt_term.get(i).get_name().compareTo(everyterm.get(k))!=0){
                                slt_term.get(i).add_N01();
                            }    
                }
                            
            }
            for(int j=0;j<file2.length;j++)
            {
                String temp=null;
                BufferedReader hs_file=new BufferedReader(new FileReader(file2[j].getAbsolutePath())); 
                ArrayList<String> everyterm=new ArrayList<String>();
                while((temp = hs_file.readLine()) != null) {
                    everyterm.add(temp);
                }
                hs_file.close();
                
                for(int k=0;k<everyterm.size();k++) {
                             //如果这个词在这个01的文档中,并且这个文档不属于体育类
                            if(slt_term.get(i).get_name().compareTo(everyterm.get(k))==0) {
                                slt_term.get(i).add_N10();
                            }
                            //这个词不在01这个文档中,但是这个文档不属于类别体育
                            else if(slt_term.get(i).get_name().compareTo(everyterm.get(k))!=0){
                                slt_term.get(i).add_N00();
                            }    
                }
                            
            }
                
                
        }
        //套用公式:
        double[] c=new double[slt_term.size()];//每个词的MI值
        for(int i=0;i<slt_term.size();i++) {
            
            File save=new File("selecte_term.txt");
            PrintWriter file=new PrintWriter(save);
            int N,N1,N0,N00,N01,N10,N11;
            N00=slt_term.get(i).get_N00();
            N01=slt_term.get(i).get_N01();
            N10=slt_term.get(i).get_N10();
            N11=slt_term.get(i).get_N11();
            N=N00+N01+N10+N11;
            N1=N10+N11;
            N0=N00+N01;
            c[i]=(N11/N)*Log2((N*N11)/(N1*N1))+(N01/N)*Log2((N*N01)/(N0*N1))+(N10/N)*Log2((N*N10)/(N1*N0))+(N00/N)*Log2((N*N00)/(N0*N0));
            slt_term.get(i).setC(c[i]);
            if(i!=(slt_term.size()-1)) {
            file.printf("%s   %.6f",slt_term.get(i).get_name(),slt_term.get(i).getC());
            file.println();
        }
        file.printf("%s   %.6f",slt_term.get(i).get_name(),slt_term.get(i).getC());
        }
        
    }
    public static double Log2(double a) {
        return Math.log(a)/Math.log(2.0);
    }

}
 
class forterm{
    private String e_term;//不允许用户修改 只能读取文件之后修改
    private int pe_num;
    private int hs_num;
    private double pe_frequence;
    private double hs_frequence;
    public String getE_term() {
        return e_term;
    }
    public void setE_term(String e_term) {
        this.e_term = e_term;
    }
    public int getPe_num() {
        return pe_num;
    }
    public void setPe_num(int pe_num) {
        this.pe_num = pe_num;
    }
    public int getHs_num() {
        return hs_num;
    }
    public void setHs_num(int hs_num) {
        this.hs_num = hs_num;
    }
    public double getPe_frequence() {
        return pe_frequence;
    }
    public void setPe_frequence(double pe_frequence) {
        this.pe_frequence = pe_frequence;
    }
    public double getHs_frequence() {
        return hs_frequence;
    }
    public void setHs_frequence(double hs_frequence) {
        this.hs_frequence = hs_frequence;
    }
}
class selecte_term implements Comparable<selecte_term>{
    private String name;
    private int N00=1;
    private int N01=1;
    private int N10=1;
    private int N11=1;
    private double c;
    selecte_term(String name){
        this.name=name;
        this.N00=1;
        this.N01=1;
        this.N10=1;
        this.N11=1;
    }
    public void set_name(String name) {
        this.name=name;
    }
    public String get_name() {
        return name;
    }
    public void add_N00() {
        N00++;
    }
    public int get_N00() {
        return N00;
    }
    public void add_N01() {
        N01++;
    }
    public int get_N01() {
        return N01;
    }
    public void add_N10() {
        N10++;
    }
    public int get_N10() {
        return N10;
    }
    public void add_N11() {
        N11++;
    }
    public double getC() {
        return c;
    }
    public void setC(double c) {
        this.c = c;
    }
    public int get_N11() {
        return N11;
    }
    @Override
    public int compareTo(selecte_term o) {
        // TODO Auto-generated method stub
        return 0;
    }
    
}

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值