MR矩阵乘法

矩阵乘法是比较常见的操作,其表示为M(i,j)和N(j,k)的两个矩阵,进行相乘,其中M的列数必须等于N的行数,并且结果矩阵O(i,j)的每个坐标的值等于M中第i行的每一个数乘以N中第k列的每一个数的行。矩阵乘法是满足结合率,而不满足交换率的。以下是一个比较简单的MR实现:
public class MatrixMultiply {
    /** mapper和reducer需要的三个必要变量,由conf.get()方法得到 **/
    public static int rowM = 0;
    public static int columnM = 0;
    public static int columnN = 0;
    private static Logger logger=LoggerFactory.getLogger(MatrixMultiply.class);

    public static class MatrixMapper extends Mapper<Object, Text, Text, Text> {
        private Text map_key = new Text();
        private Text map_value = new Text();

        /**
         * 执行map()函数前先由conf.get()得到main函数中提供的必要变量, 这也是MapReduce中共享变量的一种方式
         */
        public void setup(Context context) throws IOException {
            logger.info("map setup here");
            Configuration conf = context.getConfiguration();
            columnN = Integer.parseInt(conf.get("columnN"));
            rowM = Integer.parseInt(conf.get("rowM"));
        }
// Context类为Map类下面的一个需类,其实现了MapContext接口
        public void map(Object key, Text value, Context context)
                throws IOException, InterruptedException {
            /** 得到输入文件名,从而区分输入矩阵M和N **/
            //Map input records=8,表明分成了8个input record

            logger.info("map here");

            FileSplit fileSplit = (FileSplit) context.getInputSplit();
            String fileName = fileSplit.getPath().getName();
            //通过counter来统计这8个input record的filename,其中有2个M_1_2,6个N_2_3
            Counter countFile = context.getCounter("Split File", fileName);
            countFile.increment(1l);
            Counter countKey = context.getCounter("Map Input oriKey", key.toString());
            countKey.increment(1l);
            Counter countValue = context.getCounter("Map Input oriValue", value.toString());
            countValue.increment(1l);
            /*这8个split的value值
            1,1 1=1
            1,1 2=1
            1,2 1=1
            1,2 2=1
            1,3 3=1
            2,1 0=1
            2,2 2=1
            2,3 4=1
             */
             if (fileName.contains("M")) {
                String[] tuple = value.toString().split(",");

                int i = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int j = Integer.parseInt(tuples[0]);
                int Mij = Integer.parseInt(tuples[1]);
                //经过解析后每行内容为i,j,Mij

                for (int k = 1; k < columnN + 1; k++) {
                    //使用输出矩阵坐标构造map key
                    map_key.set(i + "," + k);
                    //map value是M_某列_值
                    map_value.set("M" + "," + j + "," + Mij);

                    context.write(map_key, map_value);

                }
            }
            else if (fileName.contains("N")) {

                String[] tuple = value.toString().split(",");
                int j = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int k = Integer.parseInt(tuples[0]);
                int Njk = Integer.parseInt(tuples[1]);

                for (int i = 1; i < rowM + 1; i++) {
                    map_key.set(i + "," + k);
                    //map value是N_某行_值
                    map_value.set("N" + "," + j + "," + Njk);

                    context.write(map_key, map_value);

                    Counter countNKey = context.getCounter("map_key", map_key.toString());
                    countNKey.increment(1l);
                    Counter countNValue = context.getCounter("map_value", map_value.toString());
                    countNValue.increment(1l);
                }


            }
            //输出的map value为
            /*
             M,1,1=3
             M,2,2=3
             N,1,1=1
             N,1,2=1
             N,1,3=1
             N,2,0=1
             N,2,2=1
             N,2,4=1
             */


        }
    }
    public static class MatrixReducer extends Reducer<Text, Text, Text, Text> {
        private int sum = 0;

        public void setup(Context context) throws IOException {
            logger.info("reduce setup here");
            Configuration conf = context.getConfiguration();
            columnM = Integer.parseInt(conf.get("columnM"));
        }

        public void reduce(Text key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {

            logger.info("reduce here");

            int[] M = new int[columnM + 1];
            int[] N = new int[columnM + 1];

            Counter countReduceKey = context.getCounter("Reduce Input oriKey", key.toString());
            countReduceKey.increment(1l);
            Counter countReduceValue = context.getCounter("Reduce Input oriValue", values.toString());
            countReduceValue.increment(1l);

            for (Text val : values) {
                String[] tuple = val.toString().split(",");
                //构造每个输出矩阵坐标的处理数组       
                if (tuple[0].equals("M")) {
                    M[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
                } else
                    N[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
            }
            /** 根据j值,对M[j]和N[j]进行相乘累加得到乘积矩阵的数据 **/
            for (int j = 1; j < columnM + 1; j++) {
                sum += M[j] * N[j];
            }
            context.write(key, new Text(Integer.toString(sum)));
            sum = 0;
        }
    }

    /**
     * main函数
     * <p>
     * Usage:
     * 
     * <p>
     * <code>MatrixMultiply  inputPathM inputPathN outputPath</code>
     * 
     * <p>
     * 从输入文件名称中得到矩阵M的行数和列数,以及矩阵N的列数,作为重要参数传递给mapper和reducer
     * 
     * @param args
     *            输入文件目录地址M和N以及输出目录地址
     * 
     * @throws Exception
     */
     public static void main(String[] args) throws Exception {

        if (args.length != 3) {
            System.err
                    .println("Usage: MatrixMultiply <inputPathM> <inputPathN> <outputPath>");
            System.exit(2);
        } else {
            // 解析M_$1_$2
            String[] infoTupleM = args[0].split("_");
            rowM = Integer.parseInt(infoTupleM[1]);
            columnM = Integer.parseInt(infoTupleM[2]);
            // columnM等于rowN,所以不用写了
            // 解析N_$2_$3
            String[] infoTupleN = args[1].split("_");
            columnN = Integer.parseInt(infoTupleN[2]);
        }

        Configuration conf = new Configuration();
        /** 设置三个全局共享变量 **/
        conf.setInt("rowM", rowM);
        conf.setInt("columnM", columnM);
        conf.setInt("columnN", columnN);

        conf.set("mapreduce.framework.name", "local");

        Job job = new Job(conf, "MatrixMultiply");
        job.setJarByClass(MatrixMultiply.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
        //使用FileInputFormat划分inputsplit
        //其中得到的是(path,key,value),其中path是文件名,key是该行在其文件的起始位置,value是该行内容,按照换行符进行划分
        /*可以看出0和6出现了2次,0是因为第一行开始位置,6是第二行开始位置
        0=2
        12=1
        18=1
        24=1
        30=1
        6=2 
         */
        FileInputFormat
                .setInputPaths(job, new Path(args[0]), new Path(args[1]));
        FileOutputFormat.setOutputPath(job, new Path(args[2]));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}

其中,主要MR化的思路为,由于每个O(i,k)=M(i,1)*N(1,k)+M(i,2)*N(2,k)+…M(i,j)*N(j,k),所以map/reduce中的key就选为新矩阵的坐标,而因为要在reduce中获取每个(i,k)的数组,也就是M(i,1)、M(i,2)..M(i,j)和N(1,k)、N(2,k)..N(j,k),所以每个map所输出的value要包括其要计算的每个坐标及其坐标值,然后在reduce中再根据坐标来重新划分为计算项,并计算

其他需要注意的是:
1.可以在local模式下用logger进行跟踪
2.可以使用counter来跟踪变量信息(local和yarn)
3.可以用mapper或者reducer的setup方法初始化变量
4.使用FileInputFormat/FileSplit来将文件按行划分

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值