基于朴素贝叶斯分类器的文本分类

  1. 题目要求

1、用MapReduce算法实现贝叶斯分类器的训练过程,并输出训练模型;

2、用输出的模型对测试集文档进行分类测试。测试过程可基于单机Java程序,也可以是MapReduce程序。输出每个测试文档的分类结果;

3、利用测试文档的真实类别,计算分类模型的PrecisionRecallF1值。

2.实验环境

实验平台:VMware Workstation10

虚拟机系统:Suse11

集群环境:主机名master  ip:192.168.226.129

从机名slave1  ip:192.168.226.130

  1. 贝叶斯分类器理论介绍

贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。

应用贝叶斯分类器进行分类主要分成两阶段。第一阶段是贝叶斯统计分类器的学习阶段,即根据训练数据集训练得出训练模型;第二阶段是贝叶斯分类器的推理阶段,即根据训练模型计算属于各个分类的概率,进行分类。

贝叶斯公式如下:

其中AB分别为两个不同的事件,P(A)A先验概率P(A|B)是已知B发生后A条件概率,也由于得自B的取值而被称作A后验概率。而上式就是用事件B的先验概率来求它的后验概率。

  1. 贝叶斯分类器训练的MapReduce算法设计

3.1贝叶斯文本分类流程图

3.2贝叶斯文本分类详细步骤

       整个文档归类过程可以分为以下步骤:

    1. 测试数据打包。将训练数据集中的大量小文本打包为SequencedFileMapReduce程序)。
    2. 文档统计以及单词统计。将1中输出的SequencedFile作为输入,分别进行文档个数统计DocCount(MapReduce程序)和各个分类下单词个数统计WordCount(MapReduce程序)
    3. 测试数据打包。将测试数据集中的大量小文本打包为SequcedFileMapReduce程序)。
    4. 文档归类。将2中输出的文档统计和单词统计的结果,分别计算文档的先验概率和单词在各个分类下的条件概率,然后将3中输出的SequencedFile作为输入,计算测试文本属于各个分类的概率,并选择其中最大概率的分类作为该文档的所属分类。

3.3具体算法设计

一个Country有多个news.txt, 一个news 有多个word

我们所设计的算法最后是要得到随机抽取一个txt文档,它最有可能属于哪个国家类别,也就是我们要得到它属于哪个国家的概率最大,把它转化为数学公式也就是:

         3-1

为了便于比较,我们将上式取对数得到:

       3-2

其中Num(Wi)表示该txt文档中单词Wi的个数;P(C|Wi) 表示拿出一个单词Wi,其属于国家C的后验概率。根据贝叶斯公式有:P(C|W) = P(W|C)*P(C)/P(W),其中:

P(W|C):国家Cnews中单词W出现的概率,根据式3-2,不能使该概率为0,所以我们约定每个国家都至少包含每个单词一次,也就是在统计单词数量时,都自动的加1,就有:

     3-3

P(C):国家C出现的概率(正比于其所含txt文件数);

P(W):单词W在整个测试集中出现的概率。

根据上面的贝叶斯公式我们设计的MapReduce算法如下:

  1. 按比例选取测试文档,其比例大致为国家包含文档数的相对比例;
  2. Map操作:一一遍历文档,得到<<C, Wi> , 1>
  3. Reduce操作:

合并<<C, W> , 1> 得到国家C中含有单词Wi的个数<<C, Wi> , ni>+1,记为N(C,Wi)

得到国家C中含有的单词总数,记为N(C)

得到测试集中单词W的总数,记为N(W)

再由得到测试集的单词总数,记为N

则可求得P(W|C) = N(C,W)/N(C)P(C) = N(C)/NP(W) = N(W)/N

 

3.4MapReduceData Flow示意图

  1. 源代码清单

本实验中的主要代码如下所示

4.1 SmallFilesToSequenceFileConverter.java      小文件集合打包工具类MapReduce程序

4.2 WholeFileInputFormat.java      支持类:递归读取指定目录下的所有文件

4.3 WholeFileRecordReader.java 支持类:读取单个文件的全部内容

4.4 DocCount.java     文档统计MapReduce程序

4.5 WordCount.java   单词统计MapReduce程序

4.6 DocClassification.java   测试文档分类MapReduce程序

详细代码如下:

4.1 SmallFilesToSequenceFileConverter.java 其中MapReduce关键代码如下:

publicclass SmallFilesToSequenceFileConverter extends Configured implements Tool {

 

    staticclass SequenceFileMapper extends Mapper<NullWritable, BytesWritable, Text, BytesWritable> {

 

       private String fileNameKey// 被打包的小文件名作为key,表示为Text对象

       private String classNameKey// 当前文档所在的分类名

 

       @Override// 重新实现setup方法,进行map任务的初始化设置

       protectedvoid setup(Context contextthrows IOException, InterruptedException {

           InputSplit split = context.getInputSplit(); // context获取split

           Path path = ((FileSplit) split).getPath(); // split获取文件路径

           fileNameKey = path.getName(); // 将文件路径实例化为key对象

           classNameKey = path.getParent().getName();

       }

 

       @Override// 实现map方法

       protectedvoid map(NullWritable key, BytesWritable value, Context context)

              throws IOException, InterruptedException {

           // 注意sequencefilekeyvalue key:分类,文档名  value:文档的内容)

           context.write(new Text(classNameKey + "/" + fileNameKey), value);

       }

    }

}

4.2 WholeFileInputFormat.java 其中关键代码如下:

publicclass WholeFileInputFormat extends FileInputFormat<NullWritable, BytesWritable> {

    /**

     * <p>方法描述:递归遍历输入目录下的所有文件</p>

     * <p>备注:该写FileInputFormat,使支持多层目录的输入</p>

     *  @authormeify DateTime 2015113下午2:37:49

     *  @param fs

     *  @param path

     */

    void search(FileSystem fs, Path path) {

       try {

           if (fs.isFile(path)) {

              fileStatus.add(fs.getFileStatus(path));

           } elseif (fs.isDirectory(path)) {

              FileStatus[] fileStatus = fs.listStatus(path);

              for (inti = 0; i < fileStatus.lengthi++) {

                  FileStatus fileStatu = fileStatus[i];

                  search(fsfileStatu.getPath());

              }

           }

       } catch (IOException e) {

           e.printStackTrace();

       }

    }

    @Override

    public RecordReader<NullWritable, BytesWritable> createRecordReader(InputSplit split, TaskAttemptContext context)

           throws IOException, InterruptedException {

       WholeFileRecordReader reader = new WholeFileRecordReader();

       reader.initialize(splitcontext);

       returnreader;

    }

    @Override

    protected List<FileStatus> listStatus(JobContext jobthrows IOException {

      

       FileSystem fs = FileSystem.get(job.getConfiguration());

       // 输入根目录

       String rootDir = job.getConfiguration().get("mapred.input.dir""");

       // 递归获取输入目录下的所有文件

       search(fsnew Path(rootDir));

       returnthis.fileStatus;

    }

}

4.3 WholeFileRecordReader.java 其中关键代码如下:

publicclass WholeFileRecordReader extends RecordReader<NullWritable, BytesWritable>{

 

    private FileSplit fileSplit//保存输入的分片,它将被转换成一条( key value)记录

    private Configuration conf//配置对象

    private BytesWritable value = new BytesWritable(); //value对象,内容为空

    privatebooleanprocessed = false//布尔变量记录记录是否被处理过

    @Override

    publicboolean nextKeyValue() throws IOException, InterruptedException {

       if (!processed) { //如果记录没有被处理过

           //fileSplit对象获取split的字节数,创建byte数组contents

           byte[] contents = newbyte[(intfileSplit.getLength()];

           Path file = fileSplit.getPath(); //fileSplit对象获取输入文件路径

           FileSystem fs = file.getFileSystem(conf); //获取文件系统对象

           FSDataInputStream in = null//定义文件输入流对象

           try {

              in = fs.open(file); //打开文件,返回文件输入流对象

//从输入流读取所有字节到contents

              IOUtils.readFully(incontents, 0, contents.length);          value.set(contents, 0, contents.length); //contens内容设置到value对象中

           } finally {

              IOUtils.closeStream(in); //关闭输入流

           }

          

           processed = true//将是否处理标志设为true,下次调用该方法会返回false

           returntrue;

       }

           returnfalse//如果记录处理过,返回false,表示split处理完毕

    }

}

4.4 DocCount.java  其中MapReduce关键代码如下:

publicclass DocCount extends Configured implements Tool{

 

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, IntWritable> {

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           try {

              String currentKey = key.toString();

              String[] arr = currentKey.split("/");

              String className = arr[0];

              String fileName = arr[1];

              System.out.println(className + "," + fileName);

              context.write(new Text(className), new IntWritable(1));

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

   

    publicstaticclass Reduce extends Reducer<Text, IntWritable, Text, IntWritable> {

       private IntWritable result = new IntWritable();

       publicvoid reduce(Text key, Iterable<IntWritable> values, Context contextthrows IOException, InterruptedException {

           intsum = 0;

           for (IntWritable val : values) {

              sum ++;

           }

           result.set(sum);

           context.write(keyresult);  // 输出结果key: 分类 ,  value: 文档个数

       }

    }

}

4.5 WordCount.java  其中MapReduce关键代码如下:

publicclass WordCount extends Configured implements Tool{

 

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, IntWritable> {

 

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           try {

             

              String[] arr = key.toString().split("/");

              String className = arr[0];

              String fileName = arr[1];

              value.setCapacity(value.getSize()); // 剔除多余空间

              // 文本内容

               String content = new String(value.getBytes(), 0, value.getLength());

              StringTokenizer itr = new StringTokenizer(content);

              while (itr.hasMoreTokens()) {

                  String word = itr.nextToken();

                  if(StringUtil.isValidWord(word))

                  {

                     System.out.println(className + "/" + word);

                     context.write(new Text(className + "/" + word), new IntWritable(1));

                  }

              }

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

   

    publicstaticclass Reduce extends Reducer<Text, IntWritable, Text, IntWritable> {

       private IntWritable result = new IntWritable();

       publicvoid reduce(Text key, Iterable<IntWritable> values, Context contextthrows IOException, InterruptedException {

           intsum = 1; // 注意这里单词的个数从1开始计数

           for (IntWritable val : values) {

              sum ++;

           }

           result.set(sum);

           context.write(keyresult);  // 输出结果key: 分类单词 ,  value: 频次

       }

    }

}

4.6 DocClassification.java 其中MapReduce关键代码如下:

publicclass DocClassification extends Configured implements Tool {

    // 所有分类集合

    privatestatic List<String> classList = new ArrayList<String>();

    // 所有分类的先验概率(其中的概率取对数log

    privatestatic HashMap<String, Double> classProMap = new HashMap<String, Double>();

    // 所有单词在各个分类中的出现的频次

    privatestatic HashMap<String, Integer> classWordNumMap = new HashMap<String, Integer>();

    // 分类下的所有单词出现的总频次

    privatestatic HashMap<String, Integer> classWordSumMap = new HashMap<String, Integer>();

    privatestatic Configuration conf = new Configuration();

    static {

       // 初始化分类先验概率词典

       initClassProMap("hdfs://192.168.226.129:9000/user/hadoop/doc");

       // 初始化单词在各个分类中的条件概率词典

       initClassWordProMap("hdfs://192.168.226.129:9000/user/hadoop/word");

    }

   

    publicstaticclass Map extends Mapper<Text, BytesWritable, Text, Text> {

       @Override

       publicvoid map(Text key, BytesWritable value, Context context) {

           String fileName = key.toString();

           value.setCapacity(value.getSize()); // 剔除多余空间

           String content = new String(value.getBytes(), 0, value.getLength());

           try {

              for (String className : classList) {

                  doubleresult = Math.log(classProMap.get(className));

                  StringTokenizer itr = new StringTokenizer(content);

                  while (itr.hasMoreTokens()) {

                     String word = itr.nextToken();

                     if (StringUtil.isValidWord(word)) {

                         intwordSum = 1;

                         if(classWordNumMap.get(className + "/" + word) != null){

                            wordSum = classWordNumMap.get(className + "/" + word);

                         }

                         intclassWordSum = classWordSumMap.get(className);

                         doublepro_class_word = Math.log(((double)wordSum)/classWordSum);

                         result += pro_class_word;

                     }

                  }

                  // 输出的形式 key:文件名 value:分类名/概率

                  context.write(new Text(fileName), new Text(className + "/" + String.valueOf(result)));

              }

           } catch (IOException e) {

              e.printStackTrace();

           } catch (InterruptedException e) {

              e.printStackTrace();

           }

       }

    }

 

    publicstaticclass Reduce extends Reducer<Text, Text, Text, Text> {

 

       publicvoid reduce(Text key, Iterable<Text> values, Context contextthrows IOException, InterruptedException {

           String fileName = key.toString().split("/")[1];

    doublemaxPro = Math.log(Double.MIN_VALUE);

           String maxClassName = "unknown";

           for (Text value : values) {

              String[] arr = value.toString().split("/");

              String className = arr[0];

              doublepro = Double.valueOf(arr[1]);

              if (pro > maxPro) {

                  maxPro = pro;

                  maxClassName = className;

              }

           }

           System.out.println("fileName:" + fileName + ",belong class:" + maxClassName);

           // 输出 key:文件名 value:所属分类名以及概率

           context.write(new Text(fileName), new Text(maxClassName + ",pro=" + maxPro));

       }

    }

}

四、数据集说明

训练集:CHINA  文档数255

INDIA   文档数326

TAIWAN  文档数43.

测试集:CHINA   文档个数15

INDIA    文档个数20

TAIWAN  文档个数15

  1. 程序运行说明

5.1训练数据集打包程序

Map任务个数624(所有小文件的个数)   Reduce任务个数1

截图如下

 

5.2训练文档统计程序

Map任务个数1(输入为1SequencedFile   Reduce任务个数1

5.3训练单词统计程序

Map任务个数1(输入为1SequencedFile)   Reduce任务个数1

5.4测试数据集打包程序

Map任务个数50(测试数据集小文件个数为50)   Reduce任务个数1

5.5测试文档归类程序

Map任务个数1(输入为1SequencedFile   Reduce任务个数1

 

  1. 实验结果分析

测试集文档归类结果截图如下:

 针对CHINA TAIWAN INDIA三个分类下的测试文档进行测试结果如下表所示:

类别(国家)

正确率

召回率

F1

CHINA

18.4%

46.667%

26.38%

INDIA

42.1%

80%

55.67%

TAIWAN

39.47%

100%

56.60%

  • 8
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值