[原创]利用Java多线程计算矩阵乘法
前言
前段时间在一本操作系统书籍上,看到了可以利用多线程来计算矩阵乘法的思想。例如下图中,A矩阵和B矩阵相乘得到C矩阵,那么A矩阵的每一行和B矩阵的每一列的相乘和加和,都可以交给一个线程来计算,最终得到cij这个元素。A矩阵维度是m*s,B矩阵是s*n,那么这个计算需要m*n个线程的参与,它是否一定比串行计算快呢?本文使用Java多线程一探究竟。
串行计算
矩阵乘法的串行计算方法是不难想到的。三层循环遍历计算即可。从外到内分别遍历A矩阵的行、B矩阵的列、A矩阵的列(即为B矩阵的行)即可。在计算的开始和结束时刻分别获取系统当前时刻,最后可得计算时间。以下是简单的代码片段。
// 串行验证
startTime = System.currentTimeMillis();
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < A[0].length; k++)
serial_result[i][j] += A[i][k] * B[k][j];
}
}
endTime = System.currentTimeMillis();
System.out.println("串行计算开始时刻:" + (startTime));
System.out.println("串行计算结束时刻:" + (endTime));
System.out.println("串行计算运行时间:" + (endTime - startTime));
并行计算
并行计算需要考虑的问题就复杂一些。假设结果保存在C矩阵里,期间有m*n个线程参与计算,那么首先要保证C矩阵的计算正确性----即C矩阵是全部的子线程计算完成后得到的结果,而不是子线程还没结束,main线程已经继续执行并打印出了错误的输出结果。如果一个算法计算速度再快,结果是错误的,那就毫无意义了。
这里我采用CountDownLatch作为计数工具。
int threadNum = A.length*B[0].length;
CountDownLatch countDownLatch = new CountDownLatch(threadNum);
这样就声明了一个初始值为m*n个线程的countDownLatch,每当有一个线程完成其计算任务后,可调用countDownLatch实例的countDown()方法,令总线程数减1。这个操作可放在子线程的run()方法中实现。
使用for循环启动线程之后,在main()函数中调用countDownLatch的await()方法,这个操作的作用是,只要计数器的值不为0,其他已先计算完成的子线程就会等待直到计数器值变为0 。计数器变为0后,main线程就不会被阻塞。所以,这时得到的结果就是必然正确的。
按照上述思路,我顺利完成了代码的编写。但令人意外的是,对于我测试的所有维度的矩阵,其并行计算时间均慢于串行。对于维度为300*300的A、B矩阵,运行结果如下。随着矩阵维度的增大,两者的时间差距甚至越来越大:
我确信代码逻辑没有问题,那么一个最有可能的猜测是,创建这m*n个线程和上下文切换耗费太多时间,故对代码又做了改进,令一个子线程由负责C矩阵中一个元素的计算改为负责多行元素的计算,这样就大大减少了线程数量。使用10个线程,对于同样300*300的A、B矩阵计算时间如下:
验证了我的猜测。
另外,当声明的线程数量小于for循环中启动的线程总数时,会导致await()方法提前失效,main线程和子线程交替执行,那么有可能会导致结果错误,这也是需要注意的一点。
最终整体代码如下:
import java.util.concurrent.CountDownLatch;
public class CalculateTask extends Thread {
private int[][] A;
private int[][] B;
private int index;
private int gap;
private int[][] result;
private CountDownLatch countDownLatch;
public CalculateTask(int[][] A, int[][] B, int index, int gap, int[][] result, CountDownLatch countDownLatch) {
this.A = A;
this.B = B;
this.index = index;
this.gap = gap;
this.result = result;
this.countDownLatch = countDownLatch;
}
// 计算特定范围内的结果
public void run() {
// TODO Auto-generated method stub
for (int i = index * gap; i < (index + 1) * gap; i++)
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < B.length; k++)
result[i][j] += A[i][k] * B[k][j];
}
// 线程数减1
countDownLatch.countDown();
}
public static void main(String[] args) throws InterruptedException {
// 声明和初始化
long startTime;
long endTime;
int row_A = 300;
int col_A = 300;
int col_B = 300;
int[][] A = new int[row_A][col_A];
int[][] B = new int[col_A][col_B];
//存放并行计算结果
int[][] parallel_result = new int[A.length][B[0].length];
//存放串行计算结果
int[][] serial_result = new int[A.length][B[0].length];
//子线程数量
int threadNum = 10;
//子线程的分片计算间隔
int gap = A.length / threadNum;
CountDownLatch countDownLatch = new CountDownLatch(threadNum);
// 为A和B矩阵随机赋值
for (int i = 0; i < row_A; i++)
for (int j = 0; j < col_A; j++) {
A[i][j] = (int) (Math.random() * 100);
}
for (int i = 0; i < col_A; i++)
for (int j = 0; j < col_B; j++) {
B[i][j] = (int) (Math.random() * 100);
}
// 并行计算
startTime = System.currentTimeMillis();
for (int i = 0; i < threadNum; i++) {
CalculateTask ct = new CalculateTask(A, B, i, gap, parallel_result, countDownLatch);
ct.start();
}
countDownLatch.await();
endTime = System.currentTimeMillis();
System.out.println("并行计算开始时刻:" + (startTime));
System.out.println("并行计算结束时刻:" + (endTime));
System.out.println("并行计算运行时间:" + (endTime - startTime));
// 串行计算
startTime = System.currentTimeMillis();
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < A[0].length; k++)
serial_result[i][j] += A[i][k] * B[k][j];
}
}
endTime = System.currentTimeMillis();
System.out.println("串行计算开始时刻:" + (startTime));
System.out.println("串行计算结束时刻:" + (endTime));
System.out.println("串行计算运行时间:" + (endTime - startTime));
}
}