Hadoop: A Simple Implementation of the Naive Bayes Algorithm

Reposted from: https://blog.csdn.net/Angelababy_huan/article/details/53046151

A Bayesian classifier works by combining an object's prior probability with Bayes' theorem to obtain its posterior probability, i.e., the probability that the object belongs to each class, and then assigning the object to the class with the largest posterior: P(class | features) = P(features | class) × P(class) / P(features). The "naive" variant additionally assumes the features are conditionally independent given the class, so P(features | class) factors into a product of per-attribute probabilities.

Here is a simple example.

Data: a table of weather conditions and whether football was played on each day (the numbers in parentheses are the integer codes used later in the code):

Day | Play? | Weather | Temperature | Humidity | Wind
1 | No (0) | Sunny (0) | Hot (0) | High (0) | Low (0)
2 | No (0) | Sunny (0) | Hot (0) | High (0) | High (1)
3 | Yes (1) | Overcast (1) | Hot (0) | High (0) | Low (0)
4 | Yes (1) | Rain (2) | Mild (1) | High (0) | Low (0)
5 | Yes (1) | Rain (2) | Cool (2) | Normal (1) | Low (0)
6 | No (0) | Rain (2) | Cool (2) | Normal (1) | High (1)
7 | Yes (1) | Overcast (1) | Cool (2) | Normal (1) | High (1)
8 | No (0) | Sunny (0) | Mild (1) | High (0) | Low (0)
9 | Yes (1) | Sunny (0) | Cool (2) | Normal (1) | Low (0)
10 | Yes (1) | Rain (2) | Mild (1) | Normal (1) | Low (0)
11 | Yes (1) | Sunny (0) | Mild (1) | Normal (1) | High (1)
12 | Yes (1) | Overcast (1) | Mild (1) | High (0) | High (1)
13 | Yes (1) | Overcast (1) | Hot (0) | Normal (1) | Low (0)
14 | No (0) | Rain (2) | Mild (1) | High (0) | High (1)
15 | ? | Sunny (0) | Cool (2) | High (0) | High (1)

We want to predict whether football is played on day 15 under these weather conditions.

Assuming football is played on day 15, the score for "play" is computed as follows:

P(play) = 9/14

P(sunny | play) = sunny playing days / playing days = 2/9

P(cool | play) = cool playing days / playing days = 3/9

P(humidity high | play) = high-humidity playing days / playing days = 3/9

P(wind high | play) = high-wind playing days / playing days = 3/9

So the score for playing on day 15 is P = 9/14 × 2/9 × 3/9 × 3/9 × 3/9 ≈ 0.00529

Following the same steps, the score for not playing on day 15 is P = 5/14 × 3/5 × 1/5 × 4/5 × 3/5 ≈ 0.02057

Since the score for not playing is larger than the score for playing, we predict that football is not played on day 15. (Strictly speaking these are unnormalized posteriors, P(features | class) × P(class); the shared evidence term P(features) cancels when the two classes are compared, so it can be ignored.)
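The two scores can be double-checked with a few lines of plain Java (a standalone sketch of the hand calculation above, not part of the MapReduce job):

public class HandCheck {
    public static void main(String[] args) {
        // score for "play": P(play) * P(sunny|play) * P(cool|play)
        //                   * P(humidity high|play) * P(wind high|play)
        double play = 9.0 / 14 * 2.0 / 9 * 3.0 / 9 * 3.0 / 9 * 3.0 / 9;
        // score for "not play", using the corresponding conditionals
        double notPlay = 5.0 / 14 * 3.0 / 5 * 1.0 / 5 * 4.0 / 5 * 3.0 / 5;
        System.out.printf("play: %.5f, not play: %.5f%n", play, notPlay);
        // prints play: 0.00529, not play: 0.02057 -> predict "not play"
    }
}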

With the naive Bayes workflow understood, we can design the MapReduce program. Each input line is a comma-separated list of integers: the first field is the class label and the remaining fields are binary attribute values. In the Mapper, each training record is split into its label and its attributes; the attributes are wrapped in a custom value type (MyWritable, shown below) and emitted to the Reducer keyed by the label.

Mapper:

import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BayesMapper extends Mapper<Object, Text, IntWritable, MyWritable> {
    private static final Logger log = LoggerFactory.getLogger(BayesMapper.class);
    private IntWritable myKey = new IntWritable();
    private MyWritable myValue = new MyWritable();

    @Override
    protected void map(Object key, Text value, Context context)
            throws IOException, InterruptedException {
        log.info("***" + value.toString());
        int[] values = getIntData(value);
        int label = values[0];                     // the class label
        int[] result = new int[values.length - 1]; // the attribute values
        for (int i = 1; i < values.length; i++) {
            result[i - 1] = values[i];
        }
        myKey.set(label);
        myValue.setValue(result);
        context.write(myKey, myValue);
    }

    // Parse a comma-separated line of integers; empty or non-numeric fields default to 0.
    private int[] getIntData(Text value) {
        String[] values = value.toString().split(",");
        int[] data = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            String s = values[i].trim();
            if (s.matches("^[0-9]+$")) {
                data[i] = Integer.parseInt(s);
            }
        }
        return data;
    }
}
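For example, the input line 1,0,0,1,1 produces key 1 (the class label) and a MyWritable wrapping [0, 0, 1, 1] (the attribute values), so all records of one class arrive at the same reduce() call.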

MyWritable:

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;

public class MyWritable implements Writable {
    private int[] value;

    public MyWritable() {
        // Hadoop needs a no-arg constructor to instantiate Writables via reflection.
    }

    public MyWritable(int[] value) {
        this.setValue(value);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        // Write the length first so readFields knows how many ints to read back.
        out.writeInt(value.length);
        for (int i = 0; i < value.length; i++) {
            out.writeInt(value[i]);
        }
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        int vLength = in.readInt();
        value = new int[vLength];
        for (int i = 0; i < vLength; i++) {
            value[i] = in.readInt();
        }
    }

    public int[] getValue() {
        return value;
    }

    public void setValue(int[] value) {
        this.value = value;
    }
}
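Because write() records the array length before the elements, readFields() can reconstruct the array exactly. A quick local round-trip check (a standalone sketch, not part of the job) could look like this:

import java.io.*;

public class MyWritableRoundTrip {
    public static void main(String[] args) throws IOException {
        MyWritable original = new MyWritable(new int[]{0, 1, 1, 0});
        // Serialize to a byte buffer, much as Hadoop does between map and reduce.
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        original.write(new DataOutputStream(bytes));
        // Deserialize into a fresh instance and compare.
        MyWritable copy = new MyWritable();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(java.util.Arrays.equals(original.getValue(), copy.getValue())); // true
    }
}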

 

Reducer:

 

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BayesReducer extends Reducer<IntWritable, MyWritable, IntWritable, IntWritable> {
    private static final Logger log = LoggerFactory.getLogger(BayesReducer.class);
    private String testFilePath;
    // Test records, one int[] per line: the label first, then the attribute values
    private ArrayList<int[]> testData = new ArrayList<>();
    // Per-class statistics collected across all reduce() calls
    private ArrayList<CountAll> allData = new ArrayList<>();
    private IntWritable myKey = new IntWritable();
    private IntWritable myValue = new IntWritable();

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        // Location of the test file: taken from the "test.file.path" property
        // (a property name chosen here, settable with -D test.file.path=<path>),
        // falling back to a hard-coded default.
        testFilePath = conf.get("test.file.path", "home/5.txt");
        Path path = new Path(testFilePath);
        FileSystem fs = path.getFileSystem(conf);
        readTestData(fs, path);
    }

    @Override
    protected void reduce(IntWritable key, Iterable<MyWritable> values, Context context)
            throws IOException, InterruptedException {
        // Number of attributes: a test record is one field longer because
        // it still carries its label in field 0.
        Double[] myTest = new Double[testData.get(0).length - 1];
        // Laplace (add-one) smoothing: start each "attribute == 1" count at 1 ...
        for (int i = 0; i < myTest.length; i++) {
            myTest[i] = 1.0;
        }
        // ... and the per-class record count at 2, so each conditional
        // probability (count + 1) / (n + 2) is never exactly 0 or 1.
        Long sum = 2L;
        // Count, per attribute position, how often the value 1 occurs in this class.
        for (MyWritable myWritable : values) {
            int[] myvalue = myWritable.getValue();
            for (int i = 0; i < myvalue.length; i++) {
                myTest[i] += myvalue[i];
            }
            sum += 1;
        }
        // Turn the counts into smoothed conditional probabilities P(attribute_i = 1 | class).
        for (int i = 0; i < myTest.length; i++) {
            myTest[i] = myTest[i] / sum;
        }
        allData.add(new CountAll(sum, myTest, key.get()));
    }

    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        // Prior probability of each class in the training data,
        // e.g. 0 -> 0.4, 1 -> 0.6
        HashMap<Integer, Double> labelG = new HashMap<>();
        Long allSum = getSum(allData); // total number of training records
        for (int i = 0; i < allData.size(); i++) {
            labelG.put(allData.get(i).getK(),
                    allData.get(i).getSum().doubleValue() / allSum);
        }
        int sum = 0;
        int yes = 0;
        for (int[] test : testData) {
            int value = getClasify(test, labelG);
            if (test[0] == value) {
                yes += 1;
            }
            sum += 1;
            myKey.set(test[0]);  // the true label
            myValue.set(value);  // the predicted label
            context.write(myKey, myValue);
        }
        System.out.println("Accuracy: " + (double) yes / sum);
    }

    /***
     * Total number of training records across all classes.
     */
    private Long getSum(ArrayList<CountAll> allData2) {
        Long allSum = 0L;
        for (CountAll countAll : allData2) {
            log.info("class: " + countAll.getK() + " probabilities: "
                    + myString(countAll.getValue()) + " count: " + countAll.getSum());
            allSum += countAll.getSum();
        }
        return allSum;
    }

    /***
     * Classify one test record: return the class with the largest log-score.
     */
    private int getClasify(int[] test, HashMap<Integer, Double> labelG) {
        double[] result = new double[allData.size()]; // one score per class
        for (int i = 0; i < allData.size(); i++) {
            double count = 0.0;
            CountAll ca = allData.get(i);
            Double[] pdata = ca.getValue();
            for (int j = 1; j < test.length; j++) { // field 0 is the label, skip it
                if (test[j] == 1) {
                    // log-probability of a 1 at this position in this class
                    count += Math.log(pdata[j - 1]);
                } else {
                    count += Math.log(1 - pdata[j - 1]);
                }
            }
            count += Math.log(labelG.get(ca.getK()));
            result[i] = count;
        }
        // Argmax over all classes.
        int best = 0;
        for (int i = 1; i < result.length; i++) {
            if (result[i] > result[best]) {
                best = i;
            }
        }
        return allData.get(best).getK();
    }

    /***
     * Read the test data: one comma-separated record per line.
     */
    private void readTestData(FileSystem fs, Path path) throws NumberFormatException, IOException {
        FSDataInputStream data = fs.open(path);
        BufferedReader bf = new BufferedReader(new InputStreamReader(data));
        String line = "";
        while ((line = bf.readLine()) != null) {
            String[] str = line.split(",");
            int[] myData = new int[str.length];
            for (int i = 0; i < str.length; i++) {
                // Keep only purely numeric fields; anything else defaults to 0.
                String s = str[i].trim();
                if (s.matches("^[0-9]+$")) {
                    myData[i] = Integer.parseInt(s);
                }
            }
            testData.add(myData);
        }
        bf.close();
        data.close();
    }

    public static String myString(Double[] arr) {
        String num = "";
        for (int i = 0; i < arr.length; i++) {
            if (i == arr.length - 1) {
                num += String.valueOf(arr[i]);
            } else {
                num += String.valueOf(arr[i]) + ',';
            }
        }
        return num;
    }
}
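Two details above are worth calling out. Initializing each count to 1 and the denominator to 2 in reduce() is Laplace (add-one) smoothing: every conditional probability has the form P(attribute = 1 | class) = (count + 1) / (n + 2), so it is never exactly 0 or 1 and Math.log never receives 0. For example, with 9 "play" records of which 2 are sunny, the smoothed estimate would be (2 + 1) / (9 + 2) = 3/11 ≈ 0.27 rather than the raw 2/9. And getClasify() sums log-probabilities instead of multiplying raw probabilities: with 22 attributes the raw product would underflow toward 0, while the log-sum is numerically stable and preserves the ranking of the classes.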

CountAll:

public class CountAll {  
    private Long sum;  
    private Double[] value;  
    private int k;  
    public CountAll(){}  
    public CountAll(Long sum, Double[] value,int k){  
        this.sum = sum;  
        this.value = value;  
        this.k = k;  
    }  
    public Double[] getValue() {  
        return value;  
    }  
    public void setValue(Double[] value) {  
        this.value = value;  
    }  
    public Long getSum() {  
        return sum;  
    }  
    public void setSum(Long sum) {  
        this.sum = sum;  
    }  
    public int getK() {  
        return k;  
    }  
    public void setK(int k) {  
        this.k = k;  
    }  
}  

MainJob:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class MainJob {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        // GenericOptionsParser also handles -D options, so the test-file
        // location can be supplied with -D test.file.path=<path>.
        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        if (otherArgs.length != 2) {
            System.err.println("Usage: bayes <in> <out>");
            System.exit(2);
        }
        long startTime = System.currentTimeMillis(); // measure the running time
        Job job = Job.getInstance(conf); // Job.getInstance replaces the deprecated new Job(conf)
        job.setJarByClass(MainJob.class);
        job.setMapperClass(BayesMapper.class);
        job.setReducerClass(BayesReducer.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(MyWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        // The reducer emits IntWritable values, not MyWritable.
        job.setOutputValueClass(IntWritable.class);
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
        job.waitForCompletion(true);
        long endTime = System.currentTimeMillis();
        System.out.println("time=" + (endTime - startTime));
        System.exit(0);
    }
}
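Assuming the classes are packaged into a jar named bayes.jar (the jar name here is illustrative), the job can be launched along these lines:

hadoop jar bayes.jar MainJob -D test.file.path=home/5.txt /input/train /output/bayes

GenericOptionsParser consumes the -D option, so only the input and output paths remain in otherArgs; the output directory must not already exist.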

Test data:

1,0,0,0,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,0,0,0,0,0  
1,0,0,0,1,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,1  
1,1,0,1,1,0,0,0,1,0,1,0,1,1,0,0,0,0,0,0,0,1,1  
1,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1  
1,0,0,1,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,1,0,1  
1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,0,1,1,0,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,1  
1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,1,0,0,1,1  
1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,1,1,0,1,1,0,1,1  
1,0,1,1,0,0,1,1,1,0,0,0,1,1,0,0,1,1,1,0,1,1,1  
1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,0,1,0  
1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1  
1,1,0,1,0,1,0,1,1,0,1,0,1,1,0,0,0,1,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0  
1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,0,0,1,1  
1,1,0,0,0,1,1,0,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0  
1,1,1,0,0,1,1,1,0,0,1,1,1,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0  
1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0  

Validation data:

1,1,0,0,1,1,0,0,0,1,1,0,0,0,1,1,1,0,0,1,1,0,0  
1,1,0,0,1,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0  
1,0,0,0,1,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1  
1,0,1,1,1,0,0,1,0,1,0,0,1,1,1,0,1,0,0,0,0,1,0  
1,0,0,1,0,0,0,0,1,0,0,1,0,1,1,0,1,0,0,0,0,0,1  
1,0,0,1,1,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1  
1,1,0,0,1,0,0,1,1,1,1,0,1,1,1,0,1,0,0,0,1,0,1  
1,1,0,0,1,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0  
1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0  
1,1,0,0,1,1,1,0,0,1,1,1,0,0,1,0,1,1,0,1,0,0,0  
1,1,0,0,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0  
1,1,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0  
1,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,1  
1,1,0,0,0,1,0,0,0,1,1,0,0,0,1,0,0,0,1,1,0,0,0  
1,1,0,0,1,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,0  
1,1,0,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,1,0,0,1,0  
1,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,0,0,0,1,1  
1,1,0,0,0,0,1,1,0,0,1,1,1,0,0,0,0,1,0,0,0,0,1  
1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0  
1,1,1,1,0,1,0,1,1,0,1,0,1,1,0,0,1,0,0,0,1,1,0  
1,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,1,0,0  
1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,1,0,1,1,1  
1,0,0,1,1,1,0,0,1,1,1,0,0,1,1,1,1,0,1,0,1,1,0  
1,1,1,0,1,1,1,1,0,0,0,1,1,0,0,0,1,1,0,0,1,0,0  
1,1,1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0  
1,1,1,0,0,1,1,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0  
1,1,0,1,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,1,0  
1,1,1,1,1,0,1,1,1,0,1,0,0,1,1,1,1,0,0,1,1,0,0 

Run results: (shown as a screenshot in the original post, not reproduced here)
