[原创]利用Java多线程计算矩阵乘法

[原创]利用Java多线程计算矩阵乘法

前言

前段时间在一本操作系统书籍上,看到了可以利用多线程来计算矩阵乘法的思想。例如下图中,A矩阵和B矩阵相乘得到C矩阵,那么A矩阵的每一行和B矩阵的每一列的相乘和加和,都可以交给一个线程来计算,最终得到cij这个元素。A矩阵维度是m*s,B矩阵是s*n,那么这个计算需要m*n个线程的参与,它是否一定比串行计算快呢?本文使用Java多线程一探究竟。

串行计算

矩阵乘法的串行计算方法是不难想到的。三层循环遍历计算即可。从外到内分别遍历A矩阵的行、B矩阵的列、A矩阵的列(即为B矩阵的行)即可。在计算的开始和结束时刻分别获取系统当前时刻,最后可得计算时间。以下是简单的代码片段。

// 串行验证

		startTime = System.currentTimeMillis();
		for (int i = 0; i < A.length; i++) {
			for (int j = 0; j < B[0].length; j++) {
				for (int k = 0; k < A[0].length; k++)
					serial_result[i][j] += A[i][k] * B[k][j];
			}
		}

		endTime = System.currentTimeMillis();
		System.out.println("串行计算开始时刻:" + (startTime));
		System.out.println("串行计算结束时刻:" + (endTime));
		System.out.println("串行计算运行时间:" + (endTime - startTime));

​

并行计算

并行计算需要考虑的问题就复杂一些。假设结果保存在C矩阵里,期间有m*n个线程参与计算,那么首先要保证C矩阵的计算正确性----即C矩阵是全部的子线程计算完成后得到的结果,而不是子线程还没结束,main线程已经继续执行并打印出了错误的输出结果。如果一个算法计算速度再快,结果是错误的,那就毫无意义了。

这里我采用CountDownLatch作为计数工具。

int threadNum = A.length*B[0].length;
CountDownLatch countDownLatch = new CountDownLatch(threadNum);

这样就声明了一个初始值为m*n个线程的countDownLatch,每当有一个线程完成其计算任务后,可调用countDownLatch实例的countDown()方法,令总线程数减1。这个操作可放在子线程的run()方法中实现。

使用for循环启动线程之后,在main()函数中调用countDownLatch的await()方法,这个操作的作用是,只要计数器的值不为0,其他已先计算完成的子线程就会等待直到计数器值变为0 。计数器变为0后,main线程就不会被阻塞。所以,这时得到的结果就是必然正确的。

按照上述思路,我顺利完成了代码的编写。但令人意外的是,对于我测试的所有维度的矩阵,其并行计算时间均慢于串行。对于维度为300*300的A、B矩阵,运行结果如下。随着矩阵维度的增大,两者的时间差距甚至越来越大:

我确信代码逻辑没有问题,那么一个最有可能的猜测是,创建这m*n个线程和上下文切换耗费太多时间,故对代码又做了改进,令一个子线程由负责C矩阵中一个元素的计算改为负责多行元素的计算,这样就大大减少了线程数量。使用10个线程,对于同样300*300的A、B矩阵计算时间如下:

验证了我的猜测。

另外,当声明的线程数量小于for循环中启动的线程总数时,会导致await()方法提前失效,main线程和子线程交替执行,那么有可能会导致结果错误,这也是需要注意的一点。

最终整体代码如下:

import java.util.concurrent.CountDownLatch;
public class CalculateTask extends Thread {
	private int[][] A;
	private int[][] B;
	private int index;
	private int gap;
	private int[][] result;
	private CountDownLatch countDownLatch;

	public CalculateTask(int[][] A, int[][] B, int index, int gap, int[][] result, CountDownLatch countDownLatch) {
		this.A = A;
		this.B = B;
		this.index = index;
		this.gap = gap;
		this.result = result;
		this.countDownLatch = countDownLatch;
	}

	// 计算特定范围内的结果
	public void run() {
		// TODO Auto-generated method stub
		for (int i = index * gap; i < (index + 1) * gap; i++)
			for (int j = 0; j < B[0].length; j++) {
				for (int k = 0; k < B.length; k++)
					result[i][j] += A[i][k] * B[k][j];
			}
		// 线程数减1
		countDownLatch.countDown();
	}

	public static void main(String[] args) throws InterruptedException {
		// 声明和初始化
		long startTime;
		long endTime;
		int row_A = 300;
		int col_A = 300;
		int col_B = 300;
		int[][] A = new int[row_A][col_A];
		int[][] B = new int[col_A][col_B];
        //存放并行计算结果
		int[][] parallel_result = new int[A.length][B[0].length];
        //存放串行计算结果
		int[][] serial_result = new int[A.length][B[0].length];
        //子线程数量
		int threadNum = 10;
        //子线程的分片计算间隔
		int gap = A.length / threadNum;
		CountDownLatch countDownLatch = new CountDownLatch(threadNum);
		// 为A和B矩阵随机赋值
		for (int i = 0; i < row_A; i++)
			for (int j = 0; j < col_A; j++) {
				A[i][j] = (int) (Math.random() * 100);
			}
		for (int i = 0; i < col_A; i++)
			for (int j = 0; j < col_B; j++) {
				B[i][j] = (int) (Math.random() * 100);
			}
		// 并行计算
		startTime = System.currentTimeMillis();
		for (int i = 0; i < threadNum; i++) {
			CalculateTask ct = new CalculateTask(A, B, i, gap, parallel_result, countDownLatch);
			ct.start();
		}
		countDownLatch.await();
		endTime = System.currentTimeMillis();
		System.out.println("并行计算开始时刻:" + (startTime));
		System.out.println("并行计算结束时刻:" + (endTime));
		System.out.println("并行计算运行时间:" + (endTime - startTime));

		// 串行计算
		startTime = System.currentTimeMillis();
		for (int i = 0; i < A.length; i++) {
			for (int j = 0; j < B[0].length; j++) {
				for (int k = 0; k < A[0].length; k++)
					serial_result[i][j] += A[i][k] * B[k][j];
			}
		}
		endTime = System.currentTimeMillis();
		System.out.println("串行计算开始时刻:" + (startTime));
		System.out.println("串行计算结束时刻:" + (endTime));
		System.out.println("串行计算运行时间:" + (endTime - startTime));
	}
}

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值