文章转自:https://blog.csdn.net/Angelababy_huan/article/details/53046151
贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。
以下为一个简单的例子:
数据:天气情况和每天是否踢足球的记录表
日期 | 踢足球 | 天气 | 温度 | 湿度 | 风速 |
1号 | 否(0) | 晴天(0) | 热(0) | 高(0) | 低(0) |
2号 | 否(0) | 晴天(0) | 热(0) | 高(0) | 高(1) |
3号 | 是(1) | 多云(1) | 热(0) | 高(0) | 低(0) |
4号 | 是(1) | 下雨(2) | 舒适(1) | 高(0) | 低(0) |
5号 | 是(1) | 下雨(2) | 凉爽(2) | 正常(1) | 低(0) |
6号 | 否(0) | 下雨(2) | 凉爽(2) | 正常(1) | 高(1) |
7号 | 是(1) | 多云(1) | 凉爽(2) | 正常(1) | 高(1) |
8号 | 否(0) | 晴天(0) | 舒适(1) | 高(0) | 低(0) |
9号 | 是(1) | 晴天(0) | 凉爽(2) | 正常(1) | 低(0) |
10号 | 是(1) | 下雨(2) | 舒适(1) | 正常(1) | 低(0) |
11号 | 是(1) | 晴天(0) | 舒适(1) | 正常(1) | 高(1) |
12号 | 是(1) | 多云(1) | 舒适(1) | 高(0) | 高(1) |
13号 | 是(1) | 多云(1) | 热(0) | 正常(1) | 低(0) |
14号 | 否(0) | 下雨(2) | 舒适(1) | 高(0) | 高(1) |
15号 | ? | 晴天(0) | 凉爽(2) | 高(0) | 高(1) |
需要预测15号,在这种天气情况下是否踢球。
假设15号去踢球,踢球的概率计算过程如下:
P(踢) = 踢球天数/总天数 = 9/14
P(晴天|踢) = 踢球天数中晴天踢球的次数/踢球次数 = 2/9
P(凉爽|踢) = 踢球天数中凉爽踢球的次数/踢球次数 = 3/9
P(湿度高|踢) = 踢球天数中湿度高踢球的次数/踢球次数 = 3/9
P(风速高|踢) = 踢球天数中风速高踢球的次数/踢球次数 = 3/9
则15号踢球的(未归一化)后验概率 P ∝ 9/14 * 2/9 * 3/9 * 3/9 * 3/9 ≈ 0.00529
按照上述步骤还可计算出15号不去踢球的(未归一化)后验概率 P ∝ 5/14 * 3/5 * 1/5 * 4/5 * 3/5 ≈ 0.02057
可以看出,15号不去踢球的概率大于去踢球的概率(归一化后约为 79.5% 对 20.5%),因此预测15号不去踢球。
理解朴素贝叶斯的流程之后,开始设计MR程序。在Mapper中,对训练数据进行拆分,也就是将每条训练数据拆分为类别标签和特征数据,将特征数据以自定义值类型来保存,然后以类别标签为键传递给Reducer。
Mapper:
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Mapper of the naive Bayes job.
 *
 * <p>Each input line is a comma-separated training record whose first field is the
 * class label and whose remaining fields are the feature values. The mapper emits
 * the label as key and the feature vector (wrapped in {@link MyWritable}) as value.
 *
 * <p>Lines that are blank or contain no fields at all are skipped instead of being
 * emitted as a bogus class-0 record or failing on the label lookup.
 */
public class BayesMapper extends Mapper<Object, Text, IntWritable, MyWritable> {
    private static final Logger log = LoggerFactory.getLogger(BayesMapper.class);

    /** Accepts tokens that consist solely of decimal digits; compiled once for all records. */
    private static final java.util.regex.Pattern UNSIGNED_INT =
            java.util.regex.Pattern.compile("^[0-9]+$");

    // Output objects are reused between map() calls.
    private final IntWritable myKey = new IntWritable();
    private final MyWritable myValue = new MyWritable();

    /**
     * Splits one training line into (label, features) and writes it to the context.
     *
     * @param key     input key supplied by the input format (not used)
     * @param value   one line of training data
     * @param context job context receiving the (label, features) pair
     * @throws IOException          if writing to the context fails
     * @throws InterruptedException if the task is interrupted while writing
     */
    @Override
    protected void map(Object key, Text value, Context context)
            throws IOException, InterruptedException {
        String line = value.toString();
        log.info("***{}", line);

        // A blank line would otherwise parse to the single value 0 and be counted
        // as a class-0 record without features.
        if (line.trim().isEmpty()) {
            return;
        }

        int[] values = getIntData(value);
        // A line made only of separators (e.g. ",") splits into zero fields.
        if (values.length == 0) {
            return;
        }

        int label = values[0]; // class label
        int[] result = new int[values.length - 1]; // feature values
        System.arraycopy(values, 1, result, 0, result.length);

        myKey.set(label);
        myValue.setValue(result);
        context.write(myKey, myValue);
    }

    /**
     * Parses a comma-separated line into integers.
     *
     * <p>Surrounding whitespace of each field is ignored. Fields that are empty or
     * not purely numeric are left at 0.
     *
     * @param value the raw input line
     * @return one int per comma-separated field
     * @throws NumberFormatException if a digit-only field does not fit into an int
     */
    private int[] getIntData(Text value) {
        String[] values = value.toString().split(",");
        int[] data = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            String token = values[i].trim();
            if (UNSIGNED_INT.matcher(token).matches()) {
                data[i] = Integer.parseInt(token);
            }
        }
        return data;
    }
}
MyWritable:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
public class MyWritable implements Writable{
private int[] value;
public MyWritable() {
}
public MyWritable(int[] value){
this.setValue(value);
}
public void write(DataOutput out) throws IOException {
out.writeInt(value.length);
for(int i=0; i<value.length;i++){
out.writeInt(value[i]);
}
}
public void readFields(DataInput in) throws IOException {
int vLength = in.readInt();
value = new int[vLength];
for(int i=0; i<vLength;i++){
value[i] = in.readInt();
}
}
public int[] getValue() {
return value;
}
public void setValue(int[] value) {
this.value = value;
}
}
Reducer:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Reducer of the naive Bayes job for binary (0/1) features.
 *
 * <p>{@link #setup} loads the labelled test rows, {@link #reduce} builds one
 * {@link CountAll} per class from the training vectors of that class, and
 * {@link #cleanup} classifies every test row and writes
 * {@code (actual label, predicted label)}.
 *
 * <p>The model is accumulated in memory across {@code reduce} calls and only used in
 * {@code cleanup}, so all classes must be processed by the same reducer instance
 * (i.e. the job has to run with a single reduce task).
 */
public class BayesReducer extends Reducer<IntWritable, MyWritable, IntWritable, IntWritable>{
    private static final Logger log = LoggerFactory.getLogger(BayesReducer.class);

    /** Configuration key that may override the location of the test data. */
    public static final String TEST_PATH_KEY = "bayes.test.path";
    /** Location of the test data when {@link #TEST_PATH_KEY} is not set. */
    public static final String DEFAULT_TEST_PATH = "home/5.txt";

    /** Accepts tokens that consist solely of decimal digits; compiled once. */
    private static final java.util.regex.Pattern UNSIGNED_INT =
            java.util.regex.Pattern.compile("^[0-9]+$");

    private String testFilePath;
    // Test rows: index 0 is the actual label, the rest are the feature values.
    private ArrayList<int[]> testData = new ArrayList<>();
    // One entry per class (reduce key) seen so far.
    private ArrayList<CountAll> allData = new ArrayList<>();

    private IntWritable myKey = new IntWritable();
    private IntWritable myValue = new IntWritable();

    /**
     * Resolves the test file path from the configuration (falling back to
     * {@link #DEFAULT_TEST_PATH}) and loads the test rows.
     *
     * @throws IOException if the file cannot be read or contains no usable row
     */
    @Override
    protected void setup(Context context)
            throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        testFilePath = conf.get(TEST_PATH_KEY, DEFAULT_TEST_PATH);
        Path path = new Path(testFilePath);
        FileSystem fs = path.getFileSystem(conf);
        readTestData(fs, path);
        if (testData.isEmpty()) {
            // reduce() derives the feature count from the first test row.
            throw new IOException("No test data found in " + path);
        }
    }

    /**
     * Computes, for one class, the smoothed probability of each feature being 1.
     *
     * <p>Every per-feature counter starts at 1 and the record counter starts at 2, so
     * the stored probability is {@code (sum of feature values + 1) / (records + 2)}.
     * The record counter, including that offset of 2, is what gets stored in the
     * {@link CountAll} and later used for the class prior.
     *
     * @param key     the class label
     * @param values  feature vectors of all training records with that label
     * @param context job context (nothing is written here)
     */
    @Override
    protected void reduce(IntWritable key, Iterable<MyWritable> values,
            Context context)
            throws IOException, InterruptedException {
        // The number of features is taken from the first test row (minus its label).
        int featureCount = testData.get(0).length - 1;
        double[] ones = new double[featureCount];
        for (int i = 0; i < featureCount; i++) {
            ones[i] = 1.0;
        }
        long sum = 2L;

        for (MyWritable myWritable : values) {
            int[] myvalue = myWritable.getValue();
            // Extra trailing features of an over-long training row are ignored
            // rather than indexing past the counters.
            int usable = Math.min(myvalue.length, featureCount);
            for (int i = 0; i < usable; i++) {
                ones[i] += myvalue[i];
            }
            sum += 1;
        }

        Double[] myTest = new Double[featureCount];
        for (int i = 0; i < featureCount; i++) {
            myTest[i] = ones[i] / sum;
        }
        allData.add(new CountAll(sum, myTest, key.get()));
    }

    /**
     * Classifies every test row with the collected per-class statistics, writes
     * {@code (actual label, predicted label)} and prints the accuracy to stdout.
     */
    @Override
    protected void cleanup(Context context)
            throws IOException, InterruptedException {
        if (allData.isEmpty()) {
            // No reduce() call happened on this reducer, so there is no model.
            log.warn("No training data reached this reducer; skipping classification");
            return;
        }

        // Share of each class in the (smoothed) training totals, e.g.
        // k,v 0,0.4
        // k,v 1,0.6
        HashMap<Integer, Double> labelG = new HashMap<>();
        long allSum = getSum(allData);
        for (CountAll countAll : allData) {
            labelG.put(countAll.getK(), countAll.getSum().doubleValue() / allSum);
        }

        // A test row is one element longer than a training vector (leading label).
        int sum = 0;
        int yes = 0;
        for (int[] test : testData) {
            int value = getClasify(test, labelG);
            if (test[0] == value) {
                yes += 1;
            }
            sum += 1;
            myKey.set(test[0]);
            myValue.set(value);
            context.write(myKey, myValue);
        }
        System.out.println("正确率为:" + (double) yes / sum);
    }

    /**
     * Adds up the record counters of all classes and logs each class summary.
     *
     * @param allData2 the per-class statistics
     * @return the sum of all {@link CountAll#getSum()} values (each of which
     *         includes the offset of 2 applied in {@link #reduce})
     */
    private Long getSum(ArrayList<CountAll> allData2) {
        long allSum = 0L;
        for (CountAll countAll : allData2) {
            log.info("类别:{}数据:{}总数:{}",
                    countAll.getK(), myString(countAll.getValue()), countAll.getSum());
            allSum += countAll.getSum();
        }
        return allSum;
    }

    /**
     * Returns the label of the class with the highest log score for a test row.
     *
     * <p>The score of a class is the sum of {@code log(p)} for features equal to 1
     * and {@code log(1 - p)} for all other features, plus the log of the class
     * share from {@code labelG}. When two classes score equally, the one that was
     * reduced later is chosen.
     *
     * @param test   test row; index 0 is the actual label and is not used for scoring
     * @param labelG class label to class share
     * @return the label of the best scoring class, or -1 if no class is known
     */
    private int getClasify(int[] test, HashMap<Integer, Double> labelG) {
        int bestLabel = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (CountAll ca : allData) {
            Double[] pdata = ca.getValue();
            // Score only positions available in both the row and the model.
            int limit = Math.min(test.length, pdata.length + 1);
            double count = 0.0;
            for (int j = 1; j < limit; j++) {
                if (test[j] == 1) {
                    // Probability of a 1 at this position within the class.
                    count += Math.log(pdata[j - 1]);
                } else {
                    count += Math.log(1 - pdata[j - 1]);
                }
                // Debug level: this runs once per feature, class and test row.
                log.debug("count: {}", count);
            }
            count += Math.log(labelG.get(ca.getK()));

            // Replace unless the current best is strictly greater.
            if (!(bestScore > count)) {
                bestScore = count;
                bestLabel = ca.getK();
            }
        }
        return bestLabel;
    }

    /**
     * Reads the comma-separated test rows into {@link #testData}.
     *
     * <p>Blank lines and lines without any field are skipped. Fields that are empty
     * or not purely numeric are left at 0.
     *
     * @param fs   file system holding the test file
     * @param path location of the test file
     * @throws NumberFormatException if a digit-only field does not fit into an int
     * @throws IOException           if the file cannot be opened or read
     */
    private void readTestData(FileSystem fs, Path path) throws NumberFormatException, IOException {
        // try-with-resources closes both streams even when reading fails.
        try (FSDataInputStream data = fs.open(path);
             BufferedReader bf = new BufferedReader(
                     new InputStreamReader(data, java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = bf.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] str = line.split(",");
                if (str.length == 0) {
                    // Line consisted only of separators.
                    continue;
                }
                int[] myData = new int[str.length];
                for (int i = 0; i < str.length; i++) {
                    String token = str[i].trim();
                    if (UNSIGNED_INT.matcher(token).matches()) {
                        myData[i] = Integer.parseInt(token);
                    }
                }
                testData.add(myData);
            }
        }
    }

    /**
     * Joins the array elements with commas.
     *
     * @param arr the values to join
     * @return the comma-separated string; empty for an empty array
     */
    public static String myString(Double[] arr){
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < arr.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(String.valueOf(arr[i]));
        }
        return joined.toString();
    }
}
CountAll:
/**
 * Mutable holder for the statistics of one class, as filled by
 * {@code BayesReducer.reduce}: the class label, the record counter and the
 * per-feature values computed for that class.
 */
public class CountAll {
    /** Class label. */
    private int k;
    /** Record counter of the class. */
    private Long sum;
    /** Per-feature values of the class; stored by reference. */
    private Double[] value;

    /** Creates an empty holder: {@code k} is 0, {@code sum} and {@code value} are {@code null}. */
    public CountAll() {
    }

    /**
     * Creates a fully populated holder.
     *
     * @param sum   record counter of the class
     * @param value per-feature values of the class (not copied)
     * @param k     class label
     */
    public CountAll(Long sum, Double[] value, int k) {
        this.k = k;
        this.value = value;
        this.sum = sum;
    }

    /** @return the class label */
    public int getK() {
        return this.k;
    }

    /** @param k the class label */
    public void setK(int k) {
        this.k = k;
    }

    /** @return the record counter, or {@code null} if never set */
    public Long getSum() {
        return this.sum;
    }

    /** @param sum the record counter */
    public void setSum(Long sum) {
        this.sum = sum;
    }

    /** @return the per-feature array itself (not a copy), or {@code null} if never set */
    public Double[] getValue() {
        return this.value;
    }

    /** @param value the per-feature array (stored by reference) */
    public void setValue(Double[] value) {
        this.value = value;
    }
}
MainJob:
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
/**
 * Driver that configures and runs the naive Bayes MapReduce job.
 *
 * <p>Usage: {@code MainJob <in> <out>}, where {@code <in>} is the training data and
 * {@code <out>} the job output directory. Generic Hadoop options are consumed by
 * {@link GenericOptionsParser} before the two positional arguments are read.
 */
public class MainJob {
    /**
     * Runs the job and exits with 0 on success, 1 if the job failed and 2 on
     * wrong usage.
     *
     * @param args generic Hadoop options followed by {@code <in> <out>}
     * @throws Exception if the job cannot be set up or submitted
     */
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        String[] otherArgs = new GenericOptionsParser(conf, args)
                .getRemainingArgs();
        if (otherArgs.length != 2) {
            System.err.println("Usage: MainJob <in> <out>");
            System.exit(2);
        }

        long startTime = System.currentTimeMillis(); // for the elapsed-time report

        Job job = Job.getInstance(conf);
        job.setJarByClass(MainJob.class);
        job.setMapperClass(BayesMapper.class);
        job.setReducerClass(BayesReducer.class);

        // BayesReducer collects the statistics of every class in memory and
        // classifies in cleanup(), so all keys must reach one reducer instance.
        job.setNumReduceTasks(1);

        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(MyWritable.class);
        // BayesReducer is declared to emit <IntWritable, IntWritable>.
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(IntWritable.class);

        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));

        boolean succeeded = job.waitForCompletion(true);

        long endTime = System.currentTimeMillis();
        System.out.println("time=" + (endTime - startTime));
        // Report job failure to the caller instead of always exiting with 0.
        System.exit(succeeded ? 0 : 1);
    }
}
测试数据:
1,0,0,0,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,0,0,0,0,0
1,0,0,0,1,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,1
1,1,0,1,1,0,0,0,1,0,1,0,1,1,0,0,0,0,0,0,0,1,1
1,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1
1,0,0,1,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1
1,0,1,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0
1,1,1,0,0,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,1,0,1
1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,0,1,1,0,0
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,1
1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,1,0,0,1,1
1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,1,1,0,1,1,0,1,1
1,0,1,1,0,0,1,1,1,0,0,0,1,1,0,0,1,1,1,0,1,1,1
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,0,1,0
1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
1,1,0,1,0,1,0,1,1,0,1,0,1,1,0,0,0,1,0,0,1,1,0
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0
1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,0,0,1,1
1,1,0,0,0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0
1,1,1,0,0,1,1,1,0,0,1,1,1,0,0,0,0,0,0,1,0,0,0
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0
1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0
验证数据:
1,1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,0,0,1,1,0,0
1,1,0,0,1,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0
1,0,0,0,1,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1
1,0,1,1,1,0,0,1,0,1,0,0,1,1,1,0,1,0,0,0,0,1,0
1,0,0,1,0,0,0,0,1,0,0,1,0,1,1,0,1,0,0,0,0,0,1
1,0,0,1,1,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1
1,1,0,0,1,0,0,1,1,1,1,0,1,1,1,0,1,0,0,0,1,0,1
1,1,0,0,1,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
1,1,0,0,1,1,1,0,0,1,1,1,0,0,1,0,1,1,0,1,0,0,0
1,1,0,0,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0
1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0
1,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1
1,1,0,0,0,1,0,0,0,1,1,0,0,0,1,0,0,0,1,1,0,0,0
1,1,0,0,1,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,0
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,1,0,0,1,0
1,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,0,0,0,1,1
1,1,0,0,0,0,1,1,0,0,1,1,1,0,0,0,0,1,0,0,0,0,1
1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
1,1,1,1,0,1,0,1,1,0,1,0,1,1,0,0,1,0,0,0,1,1,0
1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,1,0,0
1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,1,0,1,1,1
1,0,0,1,1,1,0,0,1,1,1,0,0,1,1,1,1,0,1,0,1,1,0
1,1,1,0,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,1,0,0
1,1,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,1,1,0,0,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0
1,1,0,1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0
1,1,1,1,1,0,1,1,1,0,1,0,0,1,1,1,1,0,0,1,1,0,0
运行结果: