Computing PMI with MapReduce

1 Problem Statement

Given a collection of documents, compute the pointwise mutual information (PMI) of every pair of words in the corpus, once with the pair data structure and once with the stripes data structure. PMI is defined as

$$PMI(x,y) = \log\frac{p(x,y)}{p(x)\,p(y)}$$

where $x$ and $y$ are distinct words;
$p(x,y)$ is the fraction of documents in which $x$ and $y$ co-occur (number of documents containing both / total number of documents);
$p(x)$ is the fraction of documents containing $x$ (number of documents containing $x$ / total number of documents);
$p(y)$ is the fraction of documents containing $y$ (number of documents containing $y$ / total number of documents).
Note: the core of the problem is counting, for each word, the number of documents it appears in and, for each word pair, the number of documents in which both words appear.
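
As a quick sanity check with hypothetical counts (the corpus used in the code below contains 316 documents): if $x$ appears in 10 documents, $y$ in 8, and the two co-occur in 4, then

$$PMI(x,y) = \log\frac{4/316}{(10/316)(8/316)} = \log\frac{4 \cdot 316}{10 \cdot 8} = \log 15.8 \approx 2.76,$$

using the natural logarithm, which is what Math.log computes in the Java code below.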

2 Computing PMI with the pair structure

2.1 Approach with the pair structure

  • Two-stage MapReduce
    The first MapReduce stage counts the number of documents each word appears in; the second stage counts the number of documents each word pair appears in and computes the PMI.
  • First MapReduce stage
    The mapper takes one document as input, tokenizes it, and deduplicates the tokens into a word set (word_set), then emits a (word, 1) pair for each word. The reducer simply sums the counts for each word.
  • Second MapReduce stage
    The mapper takes one document as input, tokenizes it, and deduplicates and sorts the tokens into a sorted word set (sorted_word_set). A double loop over this set emits ((first_word, second_word), 1) pairs, where first_word sorts lexicographically before second_word; for example, a document reducing to the sorted word set {apple, banana, cherry} yields ((apple, banana), 1), ((apple, cherry), 1) and ((banana, cherry), 1). The reducer sums the counts for each word_pair and uses them, together with the per-word counts from stage one, to compute the PMI.

2.2 Code

Note: a TextPair data structure is needed to carry word pairs between the mapper and the reducer; see the separate TextPair write-up for the full implementation.
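
Since TextPair is not reproduced in this post, here is a minimal sketch of the interface the code below relies on, assuming a standard WritableComparable implementation (this sketch is illustrative and may differ in detail from the original class):

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;

// Minimal TextPair key: two Text fields, ordered by first word and then by second word.
public class TextPair implements WritableComparable<TextPair> {
    private Text first;
    private Text second;

    public TextPair() { // no-arg constructor required for Hadoop's reflection-based instantiation
        this.first = new Text();
        this.second = new Text();
    }

    public TextPair(Text first, Text second) {
        this.first = new Text(first);
        this.second = new Text(second);
    }

    public void set(Text first, Text second) {
        this.first.set(first);
        this.second.set(second);
    }

    public Text getFirst() { return first; }
    public Text getSecond() { return second; }

    @Override
    public void write(DataOutput out) throws IOException {
        first.write(out);
        second.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        first.readFields(in);
        second.readFields(in);
    }

    @Override
    public int compareTo(TextPair other) {
        int cmp = first.compareTo(other.first);
        return cmp != 0 ? cmp : second.compareTo(other.second);
    }

    @Override
    public int hashCode() { return first.hashCode() * 163 + second.hashCode(); }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof TextPair)) return false;
        TextPair other = (TextPair) o;
        return first.equals(other.first) && second.equals(other.second);
    }

    @Override
    public String toString() { return first + "\t" + second; }
}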

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.*;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.*;


public class PairPMI {
    /**
     * Two MapReduce stages:
     * stage 1 counts the number of documents each word appears in,
     * stage 2 counts the number of documents each word pair appears in and computes the PMI.
     */

    // Stage 1: count the number of documents each word appears in

    // Stage-1 mapper
    public static class WordMapper extends Mapper<LongWritable, Text, Text, IntWritable> {

        @Override
        protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
            Set<String> word_set = new HashSet<String>();
            String word = new String();
            Text KEY = new Text();
            IntWritable ONE = new IntWritable(1);

            String clean_doc = value.toString().replaceAll("[^a-z A-Z]", " "); // keep only letters and spaces; every other character becomes a separator
            StringTokenizer doc_tokenizer = new StringTokenizer(clean_doc);

            while (doc_tokenizer.hasMoreTokens()) {
                word = doc_tokenizer.nextToken();
                if (word_set.add(word)) { // emit each distinct word only once per document
                    KEY.set(word);
                    context.write(KEY, ONE);
                }
            }
        }
    }

    // Stage-1 reducer: sums the document counts for each word
    public static class CountReducer extends Reducer<Text, IntWritable, Text, IntWritable> {

        @Override
        protected void reduce(Text key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            int sum = 0;
            for (IntWritable val : values) {
                sum += val.get();
            }
            context.write(key, new IntWritable(sum));
        }
    }

    // Stage 2

    // Stage-2 mapper: emits ((x, y), 1) for every co-occurring word pair
    public static class PairPMIMapper extends Mapper<LongWritable, Text, TextPair, IntWritable> {

        @Override
        protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
            StringTokenizer doc_tokenizer = new StringTokenizer(value.toString().replaceAll("[^a-z A-Z]", " "));

            Set<String> word_set_sorted = new TreeSet<String>();
            String word = new String();
            TextPair KEY = new TextPair();
            Text first = new Text();
            Text second = new Text();
            IntWritable ONE = new IntWritable(1);

            // collect the distinct words; TreeSet keeps them in lexicographic order
            while (doc_tokenizer.hasMoreTokens()) {
                word = doc_tokenizer.nextToken();
                word_set_sorted.add(word);
            }

            String[] word_list = new String[word_set_sorted.size()];
            word_set_sorted.toArray(word_list);

            // emit TextPair keys with the lexicographically smaller word first
            for (int i = 0; i < word_list.length; i++) {
                for (int j = i + 1; j < word_list.length; j++) {
                    first.set(word_list[i]);
                    second.set(word_list[j]);
                    KEY.set(first, second);
                    context.write(KEY, ONE);
                }
            }
        }
    }

    // Stage-2 combiner: pre-aggregates pair counts on the map side
    public static class PairPMICombiner extends Reducer<TextPair, IntWritable, TextPair, IntWritable> {
        @Override
        protected void reduce(TextPair key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            int sum = 0;
            for (IntWritable val : values) {
                sum += val.get();
            }
            context.write(key, new IntWritable(sum));
        }
    }

    // Stage-2 reducer: sums the document count for each pair and computes the PMI
    public static class PairPMIReducer extends Reducer<TextPair, IntWritable, TextPair, DoubleWritable> {
        private static Map<String, Integer> word_total_map = new HashMap<String, Integer>();
        private static double docs_num = 316.0;

        // Load the per-word document counts produced by stage 1 before any reduce calls
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Path middle_result_path = new Path("hdfs://master:9000/homework/HW2/output/PairOutput/middle_result/part-r-00000");
            Configuration middle_conf = new Configuration();
            try {
                // get a FileSystem instance for the intermediate result path
                FileSystem fs = FileSystem.get(URI.create(middle_result_path.toString()), middle_conf);

                if (!fs.exists(middle_result_path)) {
                    throw new IOException(middle_result_path.toString() + " does not exist!");
                }
                // open the file via FileSystem.open
                FSDataInputStream in = fs.open(middle_result_path);
                InputStreamReader inStream = new InputStreamReader(in);
                BufferedReader reader = new BufferedReader(inStream);

                // read line by line: each line is "word<TAB>document_count"
                System.out.println("Reading per-word document counts...");
                String line = reader.readLine();
                String[] line_terms;
                while (line != null) {
                    line_terms = line.split("\t");
                    word_total_map.put(line_terms[0], Integer.valueOf(line_terms[1]));
                    line = reader.readLine();
                }
                reader.close();
                System.out.println("读取完毕!");
            } catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }

        @Override
        protected void reduce(TextPair key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
            int sum = 0;
            double PMI;
            for (IntWritable val : values) {
                sum += val.get();
            }
            double first_total = word_total_map.get(key.getFirst().toString());
            double second_total = word_total_map.get(key.getSecond().toString());
            PMI = Math.log(sum * docs_num / (first_total * second_total));
            context.write(key, new DoubleWritable(PMI));
        }
    }


    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException, URISyntaxException {
        // paths
        String inputPath = "hdfs://master:9000/homework/HW2/input/HarryPotter_new"; // input path
        String middlePath = "hdfs://master:9000/homework/HW2/output/PairOutput/middle_result"; // intermediate per-word document counts
        String outputPath = "hdfs://master:9000/homework/HW2/output/PairOutput/PMIresult"; // final PMI results

        FileSystem fs = FileSystem.get(new URI("hdfs://master:9000"), new Configuration(), "geosot");

        // configure the first MapReduce job
        Configuration conf1 = new Configuration();
        Job job1 = Job.getInstance(conf1, "WordCount");

        job1.setJarByClass(PairPMI.class);
        job1.setMapperClass(WordMapper.class);
        job1.setReducerClass(CountReducer.class);
        job1.setOutputKeyClass(Text.class);
        job1.setOutputValueClass(IntWritable.class);

        FileInputFormat.setInputPaths(job1, new Path(inputPath));
        FileOutputFormat.setOutputPath(job1, new Path(middlePath));

        // run stage 1
        System.out.println("Stage 1 MapReduce starting...");
        long startTime = System.currentTimeMillis();
        fs.delete(new Path(middlePath), true); // remove results from previous runs
        job1.waitForCompletion(true);
        double runtime1 = (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("Stage 1 MapReduce finished in " + runtime1 + " s");

        // configure the second MapReduce job
        Configuration conf2 = new Configuration();
        Job job2 = Job.getInstance(conf2, "compute PMI");

        job2.setJarByClass(PairPMI.class);
        job2.setMapperClass(PairPMIMapper.class);
        job2.setCombinerClass(PairPMICombiner.class);
        job2.setReducerClass(PairPMIReducer.class);

        job2.setOutputKeyClass(TextPair.class);
        job2.setOutputValueClass(DoubleWritable.class); // setOutputKeyClass/setOutputValueClass set the job (reduce) output types; the map output types default to them unless set explicitly
        job2.setMapOutputValueClass(IntWritable.class); // the map output value (IntWritable) differs from the reduce output value (DoubleWritable), so set it explicitly

        FileInputFormat.setInputPaths(job2, new Path(inputPath));
        FileOutputFormat.setOutputPath(job2, new Path(outputPath));


        // run stage 2
        System.out.println("Stage 2 MapReduce starting...");
        startTime = System.currentTimeMillis();
        fs.delete(new Path(outputPath), true); // remove results from previous runs
        job2.waitForCompletion(true);
        double runtime2 = (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("Stage 2 MapReduce finished in " + runtime2 + " s");

        // done
        System.out.println("PairPMI finished:");
        System.out.println("Total time: " + (runtime1 + runtime2) + " s");
        System.out.println("Stage 1 MapReduce time: " + runtime1 + " s");
        System.out.println("Stage 2 MapReduce time: " + runtime2 + " s");
    }
}
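
Both drivers run as ordinary Hadoop jobs. A usage sketch, assuming the classes above (including TextPair) are packaged into a jar whose name, pmi.jar, is made up for this example, and that the HDFS paths in main() exist:

hadoop jar pmi.jar PairPMI

The StripePMI driver in the next section is launched the same way, with StripePMI as the main class.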

3 Computing PMI with the Stripes structure

3.1 Approach with the Stripes structure

  • Two-stage MapReduce
    The first MapReduce stage again counts the number of documents each word appears in; the second stage counts the number of documents each word pair appears in and computes the PMI.
  • First MapReduce stage
    Identical to the first stage of the pair approach, so it is not repeated here.
  • Second MapReduce stage
  1. The mapper takes one document as input, tokenizes it, and deduplicates and sorts the tokens into a sorted word set (sorted_word_set). It then emits one (word, stripe) pair per word, where the stripe is implemented as a MapWritable and every word in the stripe sorts lexicographically after word. A few example mapper outputs: (a, {b:1, c:3, d:2}), (a, {b:3, d:4}), (b, {c:2, d:3}), (c, {d:6}), where a, b, c, d are words and the numbers are co-occurrence counts.
  2. The reducer merges all stripes for the same word by adding up the counts of matching entries; for example, (a, {b:1, c:3, d:2}) and (a, {b:3, d:4}) merge into (a, {b:4, c:3, d:6}); a standalone sketch of this merge follows the list. From the per-word document counts and the per-pair document counts, the PMI can then be computed.
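
To show the merge step in isolation from the Hadoop types, here is a minimal sketch of the element-wise stripe addition using plain java.util maps (the class and method names are illustrative only):

import java.util.HashMap;
import java.util.Map;

public class StripeMergeSketch {
    // Adds every count in 'increment' into 'accumulator', creating entries that do not exist yet.
    static void mergeStripes(Map<String, Integer> accumulator, Map<String, Integer> increment) {
        for (Map.Entry<String, Integer> entry : increment.entrySet()) {
            accumulator.merge(entry.getKey(), entry.getValue(), Integer::sum);
        }
    }

    public static void main(String[] args) {
        Map<String, Integer> total = new HashMap<>();
        Map<String, Integer> stripe1 = new HashMap<>();
        stripe1.put("b", 1); stripe1.put("c", 3); stripe1.put("d", 2);
        Map<String, Integer> stripe2 = new HashMap<>();
        stripe2.put("b", 3); stripe2.put("d", 4);
        mergeStripes(total, stripe1);
        mergeStripes(total, stripe2);
        System.out.println(total); // prints {b=4, c=3, d=6} (HashMap iteration order may vary)
    }
}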

3.2 Code

A pitfall I ran into: when I used SortedMapWritable to carry the stripes, the data the reducer received did not match what the mapper emitted. In the mapper output, every word in a value sorted after the key, yet in the reducer some values contained words that sort before the key. I did not track down the exact cause (one plausible mechanism is Hadoop's reuse of Writable objects in the reduce-side values iterator, which can leave stale entries behind if deserialization does not clear the previous contents); switching to MapWritable fixed it.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.*;


public class StripePMI {
    // Two MapReduce stages: the first is identical to the pair version, the second uses the stripe structure (MapWritable)

    // Stage 1 reuses WordMapper and CountReducer from PairPMI

    // Stage 2
    public static class StripePMIMapper extends Mapper<LongWritable, Text, Text, MapWritable> {
        @Override
        protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
            Set<String> sorted_word_set = new TreeSet<String>(); // distinct words in lexicographic order
            String doc_clean = value.toString().replaceAll("[^a-z A-Z]", " "); // keep only letters and spaces
            StringTokenizer doc_tokenizers = new StringTokenizer(doc_clean);
            while (doc_tokenizers.hasMoreTokens()) {
                sorted_word_set.add(doc_tokenizers.nextToken());
            }

            // copy the set into an array
            String[] word_list = new String[sorted_word_set.size()];
            sorted_word_set.toArray(word_list);

            Text KEY = new Text();
            IntWritable ONE = new IntWritable(1);

            // emit one (word, stripe) pair per word
            for (int i = 0; i < word_list.length; i++) {
                KEY.set(word_list[i]);
                MapWritable VALUE = new MapWritable(); // fresh stripe for each word
                for (int j = i + 1; j < word_list.length; j++) {
                    VALUE.put(new Text(word_list[j]), ONE);
                }
                context.write(KEY, VALUE);
            }
        }
    }

    // Stage-2 combiner: merges stripes for the same word on the map side
    public static class StripePMICombiner extends Reducer<Text, MapWritable, Text, MapWritable> {
        static IntWritable ZERO = new IntWritable(0);

        @Override
        protected void reduce(Text key, Iterable<MapWritable> values, Context context) throws IOException, InterruptedException {
            MapWritable stripe_map = new MapWritable(); // accumulates the merged stripe for this key
            Text temp_text;
            IntWritable temp_count;
            IntWritable total_count;
            for (MapWritable val : values) {
                for (Map.Entry<Writable, Writable> entry : val.entrySet()) {
                    temp_text = (Text) entry.getKey();
                    temp_count = (IntWritable) entry.getValue();
                    total_count = (IntWritable) stripe_map.getOrDefault(temp_text, ZERO);
                    stripe_map.put(temp_text, new IntWritable(total_count.get() + temp_count.get()));
                }
            }
            context.write(key, stripe_map);
        }
    }

    // Stage-2 reducer: merges the stripes for each word and computes the PMI
    public static class StripePMIReducer extends Reducer<Text, MapWritable, TextPair, DoubleWritable> {
        private static Map<String, Integer> word_total_map = new HashMap<String, Integer>();
        private static double docs_num = 316.0;
        private static IntWritable ZERO = new IntWritable(0);

        // Load the per-word document counts produced by stage 1 before any reduce calls
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Path middle_result_path = new Path("hdfs://master:9000/homework/HW2/output/StripeOutput/middle_result/part-r-00000");
            Configuration middle_conf = new Configuration();
            try {
                // get a FileSystem instance for the intermediate result path
                FileSystem fs = FileSystem.get(URI.create(middle_result_path.toString()), middle_conf);

                if (!fs.exists(middle_result_path)) {
                    throw new IOException(middle_result_path.toString() + " does not exist!");
                }
                // open the file via FileSystem.open
                FSDataInputStream in = fs.open(middle_result_path);
                InputStreamReader inStream = new InputStreamReader(in);
                BufferedReader reader = new BufferedReader(inStream);

                // read line by line: each line is "word<TAB>document_count"
                System.out.println("Reading per-word document counts...");
                String line = reader.readLine();
                String[] line_terms;
                while (line != null) {
                    line_terms = line.split("\t");
                    word_total_map.put(line_terms[0], Integer.valueOf(line_terms[1]));
                    line = reader.readLine();
                }
                reader.close();
                System.out.println("读取完毕!总共" + word_total_map.size() + "个word!");
            } catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }

        @Override
        protected void reduce(Text key, Iterable<MapWritable> values, Context context) throws IOException, InterruptedException {

            // merge all stripes for this key
            MapWritable stripe_map = new MapWritable(); // accumulates the merged stripe for this key
            Text temp_text;
            IntWritable temp_count;
            IntWritable total_count;
            for (MapWritable val : values) {
                for (Map.Entry<Writable, Writable> entry : val.entrySet()) {
                    temp_text = (Text) entry.getKey();
                    temp_count = (IntWritable) entry.getValue();
                    total_count = (IntWritable) stripe_map.getOrDefault(temp_text, ZERO);
                    stripe_map.put(temp_text, new IntWritable(total_count.get() + temp_count.get()));
                }
            }

            // compute the PMI for each pair (key, other_word) from the merged stripe
            double PMI;
            for (Map.Entry<Writable, Writable> entry : stripe_map.entrySet()) {
                temp_text = (Text) entry.getKey();
                temp_count = (IntWritable) entry.getValue();
                PMI = Math.log(docs_num * temp_count.get() / word_total_map.get(key.toString()) / word_total_map.get(temp_text.toString()));
                context.write(new TextPair(key, temp_text), new DoubleWritable(PMI));
            }
        }
    }


    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException, URISyntaxException {
        // paths
        String inputPath = "hdfs://master:9000/homework/HW2/input/HarryPotter_new"; // input path
//        String inputPath = "hdfs://master:9000/homework/HW2/input/HarryPotter_small_new"; // smaller test input
        String middlePath = "hdfs://master:9000/homework/HW2/output/StripeOutput/middle_result"; // intermediate per-word document counts
        String outputPath = "hdfs://master:9000/homework/HW2/output/StripeOutput/PMIresult"; // final PMI results

        FileSystem fs = FileSystem.get(new URI("hdfs://master:9000"), new Configuration(), "geosot");


        // configure the first MapReduce job (reuses the PairPMI mapper and reducer)
        Configuration conf1 = new Configuration();
        Job job1 = Job.getInstance(conf1, "WordCount");

        job1.setJarByClass(PairPMI.class);
        job1.setMapperClass(PairPMI.WordMapper.class);
        job1.setReducerClass(PairPMI.CountReducer.class);

        job1.setOutputKeyClass(Text.class);
        job1.setOutputValueClass(IntWritable.class);
        FileInputFormat.setInputPaths(job1, new Path(inputPath));
        FileOutputFormat.setOutputPath(job1, new Path(middlePath));

        // run stage 1
        System.out.println("Stage 1 MapReduce starting...");
        long startTime = System.currentTimeMillis();
        fs.delete(new Path(middlePath), true); // remove results from previous runs
        job1.waitForCompletion(true);
        double runtime1 = (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("Stage 1 MapReduce finished in " + runtime1 + " s");

        // configure the second MapReduce job
        Configuration conf2 = new Configuration();
        Job job2 = Job.getInstance(conf2);

        job2.setJarByClass(StripePMI.class);
        job2.setMapperClass(StripePMI.StripePMIMapper.class); 
        job2.setCombinerClass(StripePMI.StripePMICombiner.class);
        job2.setReducerClass(StripePMI.StripePMIReducer.class);

        job2.setOutputKeyClass(TextPair.class);
        job2.setOutputValueClass(DoubleWritable.class);
        job2.setMapOutputKeyClass(Text.class);
        job2.setMapOutputValueClass(MapWritable.class);
        FileInputFormat.setInputPaths(job2, new Path(inputPath));
        FileOutputFormat.setOutputPath(job2, new Path(outputPath));

        // run stage 2
        System.out.println("Stage 2 MapReduce starting...");
        startTime = System.currentTimeMillis();
        fs.delete(new Path(outputPath), true); // remove results from previous runs
        job2.waitForCompletion(true);
        double runtime2 = (System.currentTimeMillis() - startTime) / 1000.0;
        System.out.println("Stage 2 MapReduce finished in " + runtime2 + " s");

        // done
        System.out.println("StripePMI finished:");
        System.out.println("Total time: " + (runtime1 + runtime2) + " s");
        System.out.println("Stage 1 MapReduce time: " + runtime1 + " s");
        System.out.println("Stage 2 MapReduce time: " + runtime2 + " s");
    }
}

4 Summary

The table below lists the running time of the pair and stripes programs in this experiment, with and without a combiner.

Combiner            Pair       Stripes
With combiner       500.9 s    336.94 s
Without combiner    436.0 s    230.7 s

For both data structures, enabling the combiner made the job slower rather than faster, and the stripes implementation ran faster than the pair implementation in both configurations. One plausible explanation for the combiner overhead is that the pairs and stripes emitted within a single map task contain too little duplication for the combiner's extra serialization and merging work to pay off; a sketch of how the combiner was toggled for these runs follows.
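
The only difference between the two configurations is whether a combiner is registered on the second job. A minimal sketch, assuming a useCombiner flag that is not part of the original code:

        boolean useCombiner = true; // set to false for the "without combiner" runs
        if (useCombiner) {
            job2.setCombinerClass(PairPMICombiner.class); // or StripePMI.StripePMICombiner.class in the stripes driver
        }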
