矩阵运算是程序运行过程中非常耗时的部分,因而也出现了许多针矩阵计算的优化方法,尤其是矩阵乘法,对于
M
i
t
M_{it}
Mit和
N
t
k
N_{tk}
Ntk,其相乘的复杂度是
O
(
i
×
k
)
O(i \times k)
O(i×k)。熟悉矩阵运算的人都应知道,结果矩阵
P
P
P中的任意元素
p
i
k
p_{ik}
pik和
M
、
N
M、N
M、N中的第
i
i
i行、
k
k
k列均有关。
p
i
k
=
(
M
×
N
)
i
k
=
∑
j
m
i
j
n
j
k
p_{ik}=(M \times N)_{ik}=\sum_j m_{ij}n_{jk}
pik=(M×N)ik=j∑mijnjk
考虑到我们前面对于wordcount的解释,我们首先应当明白:在战术上,我们应当将其视作一般的java程序,只要设计好其中的一个程序,后面将其并行化就可以了。因为说白了,MapReduce就是将一个人做的事情分给好多其他人一起做。
首先,我们考虑,该如何设计Map和Reduce过程。
从矩阵相乘的过程出发,我们看到为了求出 p i k p_{ik} pik,我们需要先将 M M M中的第 i i i行的 t t t个元素和 N N N中的第 k k k列中的 j j j个元素对应相乘,将相乘得到的结果相加,就是元素 p i k p_{ik} pik。
首先,在Map阶段,我们要做的工作就是做好划分。将可以组成最终结果 p i k p_{ik} pik的所有元素都交给一个Reduce结点处理。这样,我们就可以让其来进行乘法和加和并得到最终的结果。那也就是说, ( i , k ) (i,k) (i,k)应该是我们partitioner过程的关键字,也就是map写入结果的关键字。
那么,对一个作为输入的元素,怎么知道他的关键字呢?
也简单,如果是 M M M,那么就看他的行号,列号并不重要。因为这个元素会对结果矩阵中所有行号为 i i i的元素发挥作用。
如果是 N N N就看他的列号。
但是,就算这样也并不足以支撑我们完成程序编写。到了reduce过程我们还要做元素相乘相加。那么还需要进行对每个元素做辨识的东西。
若一个元素的key值是 s,t。那么,我们会知道,这个元素最终会对结果矩阵中元素 p s t p_{st} pst发挥作用。但是,我们还缺少该元素的值,该元素的哪个矩阵中,以及他的行号或者列号。为此,我们将其value值设为:M,2,20这种形式。什么意思呢?
看value值,首先,第一值表示该元素是M矩阵中的。那么也就意味着,key值中的行号是该元素的行号,那么value中的第二个元素就表示该值得列号,也就意味着,该元素为M矩阵中第s行、第2列的元素20。此时,信息充分了起来。
如此,便可进行Map函数的编程。
首先,我们确定输入。丢掉输入文件不考虑的程序,都是刷流氓。这里,我们的输入是两个矩阵 M M M和 N N N。文件名就以 M M M和 N N N命名。文件中,前面是该元素的行号和列号,用逗号隔开,后面是元素值,用空格隔开,示例如下:
1,1 3
1,2 4
map的输出关键字是该元素在结果矩阵中为哪个元素做出了贡献,值为该值在哪个矩阵中,后面是他的列号以及元素值,示例如下:
1,1 M,3,3
public static class MatrixMapper extends Mapper <Object, Text, Text, Text> {
//定义私有变量,作为后续的写入内容
private Text map_key = new Text();
private Text map_value = new Text();
int columnN;
int rowM;
public void setup (Context context) throws IOException {
Configuration conf = context.getConfiguration();
columnN = conf.getInt("columnN", 4);//N的列数
rowM = conf.getInt("rowM", 4);
}
public void map (Object key, Text value, Context context) throws IOException, InterruptedException {
//首先要得到文件名,因为文件名中包含矩阵的大小
FileSplit fileSplit = (FileSplit) context.getInputSplit();
String fileName = fileSplit.getPath().getName(); //得到文件名
if (fileName.contains("M")) {
String[] tuple = value.toString().split(",");
int i = Integer.parseInt(tuple[0]);
String[] tuples = tuple[1].split(" ");
int j = Integer.parseInt(tuples[0]);
int Mij = Integer.parseInt(tuples[1]);
for (int k = 1; k <columnN+1; k++) {
map_key.set(i + "," + k);
map_value.set("M" + "," + j + "," + Mij); //一个Mij要与N中的很多元素相乘
context.write(map_key, map_value);
}
}
else if (fileName.contains("N")) {
String[] tuple = value.toString().split(",");
int j = Integer.parseInt(tuple[0]);
String[] tuples = tuple[1].split(" ");
int k = Integer.parseInt(tuples[0]);
int Njk = Integer.parseInt(tuples[1]);
for (int t = 1; t <rowM+1; t++) {
map_key.set(t + "," + k);
map_value.set("N" + "," + j + "," + Njk); //一个Mij要与N中的很多元素相乘
context.write(map_key, map_value);
}
}
}
}
在Reduce端,会收到同一key值的多个value。对于同一key的,查看其value,让value中第二个值相同的M和N相乘,再将结果相加,那就是最终结果了。
public static class MatrixReducer extends Reducer<Text, Text, Text, Text> {
int columnM;
private int sum = 0;
public void setup (Context context) throws IOException {
Configuration conf = context.getConfiguration();
columnM = conf.getInt("columnM", 4);//M的列
}
public void reduce (Text key, Iterable<Text> value, Context context) throws IOException, InterruptedException {
int[] M = new int[columnM+1];
int[] N = new int[columnM+1];
//value为: M,1,34
for (Text val: value) {
String[] tuples = val.toString().split(",");
if (tuples[0].equals("M")) {
M[Integer.parseInt(tuples[1])] = Integer.parseInt(tuples[2]);
}
else {
N[Integer.parseInt(tuples[1])] = Integer.parseInt(tuples[2]);
}
}
for (int j=1; j<columnM+1; j++) {
sum += M[j] * N[j];
}
context.write(key, new Text(Integer.toString(sum)));
sum = 0;
}
}
需要注意的是,这里利用了MapReduce中的一个高级用法:setup函数。他是初始化函数,在mapreduce的一次运行过程中只会运行一次。
这里,利用他们得到矩阵的行列数。
最后,给出全部代码:
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import java.io.IOException;
public class Matrix {
public static class MatrixMapper extends Mapper <Object, Text, Text, Text> {
//定义私有变量,作为后续的写入内容
private Text map_key = new Text();
private Text map_value = new Text();
int columnN;
int rowM;
public void setup (Context context) throws IOException {
Configuration conf = context.getConfiguration();
columnN = conf.getInt("columnN", 4);//N的列数
rowM = conf.getInt("rowM", 4);
}
public void map (Object key, Text value, Context context) throws IOException, InterruptedException {
//首先要得到文件名,因为文件名中包含矩阵的大小
FileSplit fileSplit = (FileSplit) context.getInputSplit();
String fileName = fileSplit.getPath().getName(); //得到文件名
if (fileName.contains("M")) {
String[] tuple = value.toString().split(",");
int i = Integer.parseInt(tuple[0]);
String[] tuples = tuple[1].split(" ");
int j = Integer.parseInt(tuples[0]);
int Mij = Integer.parseInt(tuples[1]);
for (int k = 1; k <columnN+1; k++) {
map_key.set(i + "," + k);
map_value.set("M" + "," + j + "," + Mij); //一个Mij要与N中的很多元素相乘
context.write(map_key, map_value);
}
}
else if (fileName.contains("N")) {
String[] tuple = value.toString().split(",");
int j = Integer.parseInt(tuple[0]);
String[] tuples = tuple[1].split(" ");
int k = Integer.parseInt(tuples[0]);
int Njk = Integer.parseInt(tuples[1]);
for (int t = 1; t <rowM+1; t++) {
map_key.set(t + "," + k);
map_value.set("N" + "," + j + "," + Njk); //一个Mij要与N中的很多元素相乘
context.write(map_key, map_value);
}
}
}
}
public static class MatrixReducer extends Reducer<Text, Text, Text, Text> {
int columnM;
private int sum = 0;
public void setup (Context context) throws IOException {
Configuration conf = context.getConfiguration();
columnM = conf.getInt("columnM", 4);//M的列
}
public void reduce (Text key, Iterable<Text> value, Context context) throws IOException, InterruptedException {
int[] M = new int[columnM+1];
int[] N = new int[columnM+1];
//value为: M,1,34
for (Text val: value) {
String[] tuples = val.toString().split(",");
if (tuples[0].equals("M")) {
M[Integer.parseInt(tuples[1])] = Integer.parseInt(tuples[2]);
}
else {
N[Integer.parseInt(tuples[1])] = Integer.parseInt(tuples[2]);
}
}
for (int j=1; j<columnM+1; j++) {
sum += M[j] * N[j];
}
context.write(key, new Text(Integer.toString(sum)));
sum = 0;
}
}
public static void main (String[] args) throws Exception {
Configuration conf = new Configuration();
Job job = Job.getInstance(conf, "Matrix Computation");
conf.setInt("columnN", 4);
conf.setInt("rowM", 4);
conf.setInt("columnM", 4);
job.setJarByClass(Matrix.class);
job.setMapperClass(MatrixMapper.class);
job.setReducerClass(MatrixReducer.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);
FileInputFormat.addInputPath(job, new Path("input"));
FileOutputFormat.setOutputPath(job, new Path("output"));
System.exit(job.waitForCompletion(true) ? 0:1);
}
}
ps:欢迎关注公众号,掌握实时更新: