在矩阵乘法中,第一个矩阵的列数和第二个矩阵的行数必须是相同的。如果需要进行并行计算,一种简单的策略是可以将A矩阵进行水平分割,得到子矩阵A1和A2,B矩阵进行垂直分割,得到子矩阵B1和B2。此时,我们只要分别计算这些子矩阵的乘积,将结果进行拼接,就能得到原始矩阵A和B的乘积。
我们使用ForkJoin框架来实现这个并行矩阵相乘的想法。为了方便矩阵计算,我们使用jMatrces开源软件,作为矩阵计算的工具。其中,使用的主要API如下:
Matrix:代表一个矩阵
MatrixOperator.multiply(Matrix, Matrix):矩阵相乘
Matrix.row():获得矩阵的行数
Matrix.getSubMatrix():获得矩阵的子矩阵
MatrixOperator.horizontalConcatenation(Matrix, Matrix):将两个矩阵进行水平连接
MatrixOperator.verticalConcatenation(Matrix, Matrix):将两个矩阵进行垂直连接
并行算法代码如下:
public class MatrixMulTask extends RecursiveTask<Matrix> {
public static final int granularity = 3;
Matrix m1;
Matrix m2;
String pos;
public MatrixMulTask(Matrix m1,Matrix m2,String pos) {
this.m1 = m1;
this.m2 = m2;
this.pos = pos;
}
@Override
protected Matrix compute() {
if(m1.rows() <= MatrixMulTask.granularity || m2.cols()<=MatrixMulTask.granularity) {
Matrix mRe = MatrixOperator.multiply(m1, m2);
return mRe;
} else {
int rows;
rows = m1.rows();
Matrix m11 = m1.getSubMatrix(1, 1, rows/2, m1.cols());
Matrix m12 = m1.getSubMatrix(rows/2+1, 1, m1.rows(), m1.cols());
Matrix m21 = m2.getSubMatrix(1, 1, m2.rows(), m12.cols()/2);
Matrix m22 = m2.getSubMatrix(1, m2.cols()/2+1, m2.rows(), m2.cols());
ArrayList<MatrixMulTask> subTasks = new ArrayList<MatrixMulTask>();
MatrixMulTask tmp = null;
tmp = new MatrixMulTask(m11, m21, "m1");
subTasks.add(tmp);
tmp = new MatrixMulTask(m11, m22, "m2");
subTasks.add(tmp);
tmp = new MatrixMulTask(m12, m21, "m3");
subTasks.add(tmp);
tmp = new MatrixMulTask(m12, m22, "m4");
subTasks.add(tmp);
for(MatrixMulTask t : subTasks) {
t.fork();
}
Map<String, Matrix> matrixMap = new HashMap<String,Matrix>();
for(MatrixMulTask t :subTasks) {
matrixMap.put(t.pos, t.join());
}
Matrix tmp1 = MatrixOperator.horizontalConcatenation(matrixMap.get("m1"), matrixMap.get("m2"));
Matrix tmp2 = MatrixOperator.horizontalConcatenation(matrixMap.get("m3"), matrixMap.get("m4"));
Matrix reM = MatrixOperator.verticalConcatenation(tmp1, tmp2);
return reM;
}
}
public static void main(String[] args) throws InterruptedException, ExecutionException {
ForkJoinPool forkJoinPool = new ForkJoinPool();
Matrix m1 = MatrixFactory.getRandomMatrix(10, 10, null);
Matrix m2 = MatrixFactory.getRandomMatrix(10, 10, null);
MatrixMulTask task = new MatrixMulTask(m1, m2, null);
ForkJoinTask<Matrix> result = forkJoinPool.submit(task);
Matrix pr = result.get();
System.out.println(pr);
}
}
MatrixMULTask中的成员变量m1和m2表示要相乘的两个矩阵,pos表示这个乘积结果在父矩阵相乘结果中所处的位置,有m1,m2,m3,和m4等四种。先对矩阵进行分割,分割后得到m11、m12、m21和m22等四个任务,并将它们进行子任务的创建。然后计算这些子任务,最后将m1,m2,m3,和m4拼接成新的矩阵作为最终结果。