Bayes' Theorem
In a universe of discourse, the probability that an event A occurs is written P(A). The conditional probability P(A|B) is defined as the probability that event A occurs given that event B has already occurred, and is computed as P(A|B) = P(AB) / P(B).
The classification problem is: given a known set of classes Y1, Y2, ..., Yk and an unclassified sample X, decide which of Y1, Y2, ..., Yk the sample X belongs to. With Bayes' theorem, the problem can be restated as: given the sample X, which of the k classes is X most likely to belong to?
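Written out with Bayes' theorem (the standard statement of the rule, in the notation above), the posterior probability of class Yi given the sample X is

\[
P(Y_i \mid X) = \frac{P(X \mid Y_i)\, P(Y_i)}{P(X)},
\]

and X is assigned to the class whose posterior is largest.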
Analysis of the Naive Bayes Algorithm
Suppose each data sample is described by an n-dimensional feature vector holding the values of its n attributes, X = {x1, x2, ..., xn}. Suppose there are m classes, denoted Y1, Y2, ..., Ym. Given an unclassified data sample X, the naive Bayes classifier assigns X to the class Yi for which P(Yi|X) >= P(Yj|X) holds for every 1 <= j <= m.
By Bayes' theorem, P(Yi|X) = P(X|Yi)P(Yi) / P(X). Since the probability P(X) of the unclassified sample X is the same constant for every class, only the relative sizes of the P(Yi|X) matter, so computing P(Yi|X) reduces to computing P(X|Yi)P(Yi).
If the attributes xj of the sample X were allowed to be correlated, computing P(X|Yi) would be very complicated. It is therefore usually assumed that the attributes of X are mutually independent, so that P(X|Yi) simplifies to the product of P(x1|Yi), P(x2|Yi), ..., P(xn|Yi); each P(xj|Yi) and each P(Yi) can be estimated from the training data.
Hence, for an unclassified sample X, we can compute P(X|Yi)P(Yi) for every class Yi and pick the Yi with the largest value as the predicted class.
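Putting the two previous steps together, the naive Bayes decision rule takes its standard form (n attributes, m classes):

\[
\hat{Y} = \arg\max_{1 \le i \le m} \; P(Y_i) \prod_{j=1}^{n} P(x_j \mid Y_i).
\]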
Steps
Based on the analysis above, the naive Bayes algorithm has two phases:
Training phase: count the frequency P(Yi) of each class Yi in the training set, and the frequency P(xj|Yi) with which each attribute value xj appears in class Yi.
Prediction phase: for an unclassified test sample X, take each of its attribute values xj, multiply together the P(xj|Yi) obtained from the training set to get the conditional probability P(X|Yi), then multiply by P(Yi) to get P(X|Yi)P(Yi) for each class Yi; the Yi with the largest value is the class of X (a minimal sketch of both phases follows).
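The following is a minimal, in-memory Java sketch of these two phases, separate from the MapReduce code below; the class and method names (NaiveBayesSketch, train, predict) are illustrative only and do not appear in the original code.

import java.util.HashMap;
import java.util.Map;

// Minimal in-memory sketch of the two phases described above.
// Raw counts play the role of the "frequencies" behind P(Yi) and P(xj|Yi).
public class NaiveBayesSketch {

    private final Map<String, Integer> classCount = new HashMap<>(); // count(Yi)
    private final Map<String, Integer> attrCount = new HashMap<>();  // count(Yi, j, xj)

    // Training phase: one labeled sample per call, e.g. train("cl1", "5", "6", "7").
    public void train(String label, String... attrs) {
        classCount.merge(label, 1, Integer::sum);
        for (int j = 0; j < attrs.length; j++) {
            attrCount.merge(label + ":" + (j + 1) + ":" + attrs[j], 1, Integer::sum);
        }
    }

    // Prediction phase: score(Yi) = P(Yi) * prod_j P(xj|Yi), estimated from the counts.
    public String predict(String... attrs) {
        String best = null;
        double bestScore = -1.0;
        int total = classCount.values().stream().mapToInt(Integer::intValue).sum();
        for (Map.Entry<String, Integer> e : classCount.entrySet()) {
            String label = e.getKey();
            int cy = e.getValue();
            double score = (double) cy / total;                      // P(Yi)
            for (int j = 0; j < attrs.length; j++) {
                int cxy = attrCount.getOrDefault(label + ":" + (j + 1) + ":" + attrs[j], 0);
                score *= (double) cxy / cy;                          // P(xj|Yi)
            }
            if (score > bestScore) {
                bestScore = score;
                best = label;
            }
        }
        return best;
    }
}

A production implementation would also add Laplace smoothing, so that an attribute value never seen together with a class does not force the whole product to zero. The MapReduce job below performs the same counting as the training phase, just distributed over HDFS.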
Organizing the training set
package org.bigdata.mapreduce.native_bayes;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.bigdata.util.HadoopCfg;

public class NBayesTrainMapReduce {

    // Only lines coming from the training file are counted.
    private static final String Train = "NBayes.train";

    // P(A|B) = P(AB) / P(B)
    public static class NBayesTrainMapper extends
            Mapper<LongWritable, Text, Text, IntWritable> {

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            FileSplit fileSplit = (FileSplit) context.getInputSplit();
            String fileName = fileSplit.getPath().getName();
            if (Train.equals(fileName)) {
                // Each training line: <classLabel> <x1> <x2> ... <xn>
                String[] terms = value.toString().split(" ");
                // Count for P(yi): key is the bare class label.
                context.write(new Text(terms[0]), new IntWritable(1));
                for (int i = 1, len = terms.length; i < len; i++) {
                    // Count for P(xj|yi): key is classLabel:attributeIndex:attributeValue.
                    context.write(
                            new Text(terms[0] + ":" + i + ":" + terms[i]),
                            new IntWritable(1));
                }
            }
        }
    }

    // Sums the 1s emitted by the mapper, producing the raw count for each key.
    public static class NBayesTrainReducer extends
            Reducer<Text, IntWritable, Text, IntWritable> {

        @Override
        protected void reduce(Text key, Iterable<IntWritable> values,
                Context context) throws IOException, InterruptedException {
            int sum = 0;
            for (IntWritable value : values) {
                sum += value.get();
            }
            context.write(key, new IntWritable(sum));
        }
    }

    private static final String JOB_NAME = "NB";

    public static void solve(String pathin, String pathout)
            throws ClassNotFoundException, InterruptedException {
        try {
            Configuration cfg = HadoopCfg.getConfiguration();
            Job job = Job.getInstance(cfg);
            job.setJobName(JOB_NAME);
            job.setJarByClass(NBayesTrainMapReduce.class);
            // mapper
            job.setMapperClass(NBayesTrainMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(IntWritable.class);
            // reducer
            job.setReducerClass(NBayesTrainReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(IntWritable.class);
            FileInputFormat.addInputPath(job, new Path(pathin));
            FileOutputFormat.setOutputPath(job, new Path(pathout));
            job.waitForCompletion(true);
        } catch (IllegalStateException | IllegalArgumentException | IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws ClassNotFoundException,
            InterruptedException {
        solve("/nb", "/nb_train");
    }
}
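A note on the key design: the map output key is either a bare class label, whose summed count backs P(Yi), or the composite form classLabel:attributeIndex:attributeValue, whose summed count backs P(xj|Yi). A single training job therefore produces every count the prediction phase needs; dividing a composite count by the matching class count gives the estimate of P(xj|Yi).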
Input
Training set (the NBayes.train file; each line is a class label followed by the sample's attribute values):
cl1 5 6 7
cl2 3 8 4
cl1 2 5 2
cl3 7 8 7
cl4 3 8 2
cl4 9 2 7
cl2 1 8 5
cl5 2 9 4
cl3 10 3 4
cl1 4 5 6
cl3 4 6 7
Unlabeled test samples (each line is a sample ID followed by its attribute values):
1 5 6 7
2 1 8 5
3 2 9 4
4 10 3 4
5 4 5 6
6 3 8 4
7 2 5 2
8 7 8 7
9 3 8 2
10 9 2 7
11 4 6 7
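For reference, running the training job on the training set above produces key/count pairs (tab-separated in the actual output). A few representative lines of the expected result, not the complete listing:
cl1 3
cl2 2
cl3 3
cl1:1:5 1
cl1:2:5 2
cl3:3:7 2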