Matrix multiplication is a common operation. Given a matrix M with i rows and j columns and a matrix N with j rows and k columns, the two can be multiplied only if the number of columns of M equals the number of rows of N, and each entry of the result matrix O (i rows, k columns) is the dot product of a row of M with a column of N: O(i,k) = M(i,1)*N(1,k) + M(i,2)*N(2,k) + ... + M(i,j)*N(j,k). Matrix multiplication is associative but not commutative. Below is a fairly simple MapReduce implementation:
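To make the mapper code easier to follow, here is a sketch of what the two input files are assumed to look like. Each line holds one matrix cell as rowIndex,colIndex followed by a tab (<TAB>) and the cell value; the concrete numbers are only an illustration reconstructed from the counter listings in the comments below, not data shipped with the code.

M_1_2 (matrix M, 1 row x 2 columns):
1,1<TAB>1
1,2<TAB>2

N_2_3 (matrix N, 2 rows x 3 columns):
1,1<TAB>2
1,2<TAB>1
1,3<TAB>3
2,1<TAB>0
2,2<TAB>2
2,3<TAB>4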
import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MatrixMultiply {

    /** Three variables needed by the mapper and the reducer, obtained via conf.get() **/
    public static int rowM = 0;
    public static int columnM = 0;
    public static int columnN = 0;
    private static Logger logger = LoggerFactory.getLogger(MatrixMultiply.class);

    public static class MatrixMapper extends Mapper<Object, Text, Text, Text> {

        private Text map_key = new Text();
        private Text map_value = new Text();

        /**
         * Before map() runs, read the variables provided by main() via conf.get();
         * this is also one way of sharing variables in MapReduce.
         */
        public void setup(Context context) throws IOException {
            logger.info("map setup here");
            Configuration conf = context.getConfiguration();
            columnN = Integer.parseInt(conf.get("columnN"));
            rowM = Integer.parseInt(conf.get("rowM"));
        }
        // Context is a nested class of Mapper that implements the MapContext interface
        public void map(Object key, Text value, Context context)
                throws IOException, InterruptedException {
            /** Get the input file name, which tells the input matrix M apart from N **/
            // "Map input records=8" in the job counters shows there were 8 input records
            logger.info("map here");
            FileSplit fileSplit = (FileSplit) context.getInputSplit();
            String fileName = fileSplit.getPath().getName();
            // tally the file names of these 8 input records with a counter:
            // 2 of them come from M_1_2 and 6 from N_2_3
            Counter countFile = context.getCounter("Split File", fileName);
            countFile.increment(1l);
            Counter countKey = context.getCounter("Map Input oriKey", key.toString());
            countKey.increment(1l);
            Counter countValue = context.getCounter("Map Input oriValue", value.toString());
            countValue.increment(1l);
            /* The values of the 8 input records (counter name = count):
            1,1 1=1
            1,1 2=1
            1,2 1=1
            1,2 2=1
            1,3 3=1
            2,1 0=1
            2,2 2=1
            2,3 4=1
            */
            if (fileName.contains("M")) {
                String[] tuple = value.toString().split(",");
                int i = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int j = Integer.parseInt(tuples[0]);
                int Mij = Integer.parseInt(tuples[1]);
                // after parsing, each line yields i, j and Mij
                for (int k = 1; k < columnN + 1; k++) {
                    // build the map key from the output-matrix coordinate
                    map_key.set(i + "," + k);
                    // the map value is "M", the column index j and the cell value
                    map_value.set("M" + "," + j + "," + Mij);
                    context.write(map_key, map_value);
                }
            } else if (fileName.contains("N")) {
                String[] tuple = value.toString().split(",");
                int j = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int k = Integer.parseInt(tuples[0]);
                int Njk = Integer.parseInt(tuples[1]);
                for (int i = 1; i < rowM + 1; i++) {
                    map_key.set(i + "," + k);
                    // the map value is "N", the row index j and the cell value
                    map_value.set("N" + "," + j + "," + Njk);
                    context.write(map_key, map_value);
                    Counter countNKey = context.getCounter("map_key", map_key.toString());
                    countNKey.increment(1l);
                    Counter countNValue = context.getCounter("map_value", map_value.toString());
                    countNValue.increment(1l);
                }
            }
            /* The emitted map values (counter name = count):
            M,1,1=3
            M,2,2=3
            N,1,1=1
            N,1,2=1
            N,1,3=1
            N,2,0=1
            N,2,2=1
            N,2,4=1
            */
        }
    }
    public static class MatrixReducer extends Reducer<Text, Text, Text, Text> {

        private int sum = 0;

        public void setup(Context context) throws IOException {
            logger.info("reduce setup here");
            Configuration conf = context.getConfiguration();
            columnM = Integer.parseInt(conf.get("columnM"));
        }

        public void reduce(Text key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {
            logger.info("reduce here");
            int[] M = new int[columnM + 1];
            int[] N = new int[columnM + 1];
            Counter countReduceKey = context.getCounter("Reduce Input oriKey", key.toString());
            countReduceKey.increment(1l);
            // note: values is an Iterable, so values.toString() only records its default toString()
            Counter countReduceValue = context.getCounter("Reduce Input oriValue", values.toString());
            countReduceValue.increment(1l);
            for (Text val : values) {
                String[] tuple = val.toString().split(",");
                // fill the per-coordinate arrays, indexed by the shared index j
                if (tuple[0].equals("M")) {
                    M[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
                } else {
                    N[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
                }
            }
            /** For each j, multiply M[j] by N[j] and accumulate to get the product-matrix entry **/
            for (int j = 1; j < columnM + 1; j++) {
                sum += M[j] * N[j];
            }
            context.write(key, new Text(Integer.toString(sum)));
            sum = 0;
        }
    }
    /**
     * Driver entry point.
     * <p>
     * Usage:
     *
     * <p>
     * <code>MatrixMultiply inputPathM inputPathN outputPath</code>
     *
     * <p>
     * The row and column counts of matrix M and the column count of matrix N are
     * parsed from the input path names and passed on to the mapper and reducer as
     * essential parameters.
     *
     * @param args
     *            input paths of M and N, followed by the output path
     *
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 3) {
            System.err
                    .println("Usage: MatrixMultiply <inputPathM> <inputPathN> <outputPath>");
            System.exit(2);
        } else {
            // parse M_$1_$2
            String[] infoTupleM = args[0].split("_");
            rowM = Integer.parseInt(infoTupleM[1]);
            columnM = Integer.parseInt(infoTupleM[2]);
            // columnM equals rowN, so rowN does not need to be stored separately
            // parse N_$2_$3
            String[] infoTupleN = args[1].split("_");
            columnN = Integer.parseInt(infoTupleN[2]);
        }
        Configuration conf = new Configuration();
        /** Set the three globally shared variables **/
        conf.setInt("rowM", rowM);
        conf.setInt("columnM", columnM);
        conf.setInt("columnN", columnN);
        conf.set("mapreduce.framework.name", "local");
        Job job = new Job(conf, "MatrixMultiply");
        job.setJarByClass(MatrixMultiply.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
        // FileInputFormat builds the input splits; each record is (path, key, value),
        // where path is the file, key is the byte offset at which the line starts in
        // its file, and value is the line content (records are split on newlines)
        /* The "Map Input oriKey" counters: offsets 0 and 6 each appear twice because
           both files have their first line starting at offset 0 and their second at offset 6.
        0=2
        12=1
        18=1
        24=1
        30=1
        6=2
        */
        FileInputFormat
                .setInputPaths(job, new Path(args[0]), new Path(args[1]));
        FileOutputFormat.setOutputPath(job, new Path(args[2]));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}
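As a usage sketch (the jar file name is an assumption, and the class is shown without a package, matching the listing above), the job can be launched with the three arguments described in the javadoc:

hadoop jar matrixmultiply.jar MatrixMultiply M_1_2 N_2_3 output

Because main() forces mapreduce.framework.name to local, the job runs in the local runner, which is also what makes the logger.info() calls visible on the console; the result is written by the default TextOutputFormat as key<TAB>value lines, one per output-matrix coordinate.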
The main idea behind the MapReduce formulation: since each O(i,k) = M(i,1)*N(1,k) + M(i,2)*N(2,k) + ... + M(i,j)*N(j,k), the map/reduce key is chosen to be a coordinate of the result matrix. Because the reducer needs, for every (i,k), the whole set of factors M(i,1), M(i,2), ..., M(i,j) and N(1,k), N(2,k), ..., N(j,k), each map output value must carry, for every result coordinate it contributes to, the shared index j together with the cell value; the reducer then regroups those values by j into multiplication terms and sums them up.
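As a concrete trace, using the illustrative matrices from the input sketch near the top (so the numbers are assumptions, not output shipped with the article), the output cell (1,3) is produced as follows:

map emits for key "1,3":  M,1,1  M,2,2   (from M: j=1 with value 1, j=2 with value 2)
                          N,1,3  N,2,4   (from N: j=1 with value 3, j=2 with value 4)
reduce for key "1,3":     M[1]*N[1] + M[2]*N[2] = 1*3 + 2*4 = 11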
Other points worth noting:
1. In local mode, execution can be traced with a logger
2. Counters can be used to track variable values (in both local and yarn mode); a sketch for reading them back in the driver follows this list
3. The setup() method of the mapper or reducer can be used to initialize variables
4. FileInputFormat/FileSplit split the input files into per-line records
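For point 2, a minimal sketch of reading the counters back in the driver. It assumes the one-line System.exit(job.waitForCompletion(true) ? 0 : 1) at the end of main() is split up so the Job object can be queried after completion; Counters, CounterGroup and Counter come from org.apache.hadoop.mapreduce:

        boolean ok = job.waitForCompletion(true);
        // iterate over every counter group and every counter recorded during the job,
        // including the ad-hoc groups created with context.getCounter(group, name) above
        Counters counters = job.getCounters();
        for (CounterGroup group : counters) {
            for (Counter counter : group) {
                System.out.println(group.getDisplayName() + "\t"
                        + counter.getDisplayName() + "=" + counter.getValue());
            }
        }
        System.exit(ok ? 0 : 1);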