矩阵乘法(多线程版)
题外话
家中断网,月末没有多少流量,跟随网课学习redis的任务被阻塞了,数据算法最近又不想看,看看JAVA多线程吧
正题
矩阵乘法是针对矩阵的基本运算之一,也是并发和并行编程中经典问题
在串行的矩阵乘法运算中,核心代码就是如下内容
for(int i=0;i<rows;i++){
for(int j=0;j<col2;j++){
result[i][j]=0;
for(int k=0;k<col;k++){
result[i][j]+=matrix1[i][k]*matrix2[k][j];
}
}
}
由此代码可以看出,如果将该计算过程转化为多线程的话,可以从矩阵的每个元素的计算过程和矩阵每行计算过程开始入手进行分析
在这篇博客中,通过java代码随即生成两个2000行和2000列的矩阵,进行乘法运算
首先,随即生成2000行和2000列的矩阵程序如下
public class MatrixGenerator {
public static double[][] generate (int row,int col){
double[][] ret=new double[row][col];
Random random=new Random();
for(int i=0;i<row;i++){
for(int j=0;j<col;j++){
ret[i][j]=random.nextDouble()*10;
}
}
return ret;
}
}
从矩阵的每个元素计算,每个元素创建一个线程
两个2000行和2000列的矩阵进行计算,得到的矩阵元素有4 000 000,为了防止系统超载,将20个线程作为一组进行启动,新建类IndividualMultiplierTask类,该类实现每个元素计算的Thread,设计如下
public class IndividualMultiplier implements Runnable{
private final double[][] result;
private final double[][] matrix1;
private final double[][] matrix2;
private final int row;
private final int col;
public IndividualMultiplier(double[][] result,double[][] matrix1,double[][]matrix2,int i,int j){
this.result=result;
this.matrix1=matrix1;
this.matrix2=matrix2;
this.row=i;
this.col=j;
}
@Override
public void run() {
result[row][col]=0;
for(int k=0;k<matrix1[row].length;k++){
result[row][col]+=matrix1[row][k]*matrix2[k][col];
}
}
}
正如代码中所看到的那样,run方法中将矩阵中每个元素计算的过程写入到run方法中,现在再编写一个类创建所有必要的线程执行计算结果矩阵,调用IndividualMultiplier来实现计算矩阵的每一个元素
public class ParallelIndividualMultiplier {
public static void multiply(double[][] matrix1,double[][] matrix2, double[][] result) {
List<Thread> threads = new ArrayList<>();
int row1 = matrix1.length;
int row2 = matrix2.length;
for(int i=0;i<row1;i++){
for(int j=0;j<row1;j++){
IndividualMultiplier task = new IndividualMultiplier(result,matrix1,matrix2,i,j);
Thread thread = new Thread(task);
thread.start();
threads.add(thread);
if(threads.size()%10 ==0) { //每次只计算10个,这个自己设定,只要算力跟得上
waitForThreads(threads);
}
}
}
}
private static void waitForThreads(List<Thread> threads) { //等待资源
for(Thread thread: threads){
try{
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
threads.clear();
}
}
现在,还差一个驱动类
public class ParallelIndvidualMain {
public static void main(String[] args){
double matrix1[][]= MatrixGenerator.generate(2000,2000);
double matrix2[][]=MatrixGenerator.generate(2000,2000);
double resulet[][]=new double[matrix1.length][matrix2[0].length];
Date start=new Date(); //统计计算开始时间
ParallelIndividualMultiplier.multiply(matrix1,matrix2,resulet);
Date end=new Date();
System.out.println(end.getTime()-start.getTime()); //统计执行时间
}
}
这便是根据每个元素创建一个线程的并发矩阵计算代码,下面看看矩阵每行创建一个可执行线程的代码
从矩阵的每行计算,每行创建一个线程
正如前面所展示的那样,在这里还是以10个线程为一组进行计算,首先通过设计RowMultiplierTask类并继承Runnable接口来实现每个Thread,RowMultiplierTask类设计如下
public class RowMultiplierTask implements Runnable {
private final double[][] result;
private final double[][] matrix1;
private final double[][] matrix2;
private final int row;
public RowMultiplierTask(double[][] result,double[][] matrix1,double[][]matrix2,int row){
this.result=result;
this.matrix1=matrix1;
this.matrix2=matrix2;
this.row=row;
}
@Override
public void run() {
for( int j=0;j<matrix2[row].length; j++){
result[row][j]=0;
for(int k=0;k<matrix1[row].length;k++){
result[row][j]+=matrix1[row][k]*matrix2[k][j];
}
}
}
}
从上面的run方法可以看到,这次是将行计算的过程放入了run方法中,从而实现按每行创建一个线程
同理,设计一个类来创建计算结果矩阵所需要的所有执行线程,方法如下
public class ParallelRowMultiplier {
public static void multiply(double[][] matrix1, double[][] matrix2, double[][] result) {
List<Thread> threads = new ArrayList<>();
int rows1 = matrix1.length;
for(int i=0;i<rows1;i++){
RowMultiplierTask task = new RowMultiplierTask(result,matrix1,matrix2,i);
Thread thread = new Thread(task);
thread.start();
threads.add(thread);
if(threads.size()%10==0) {
waitForThreads(threads);
}
}
}
private static void waitForThreads(List<Thread> threads) {
for(Thread thread: threads){
try{
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
threads.clear();
}
}
再同理,驱动程序如下
public class ParallelRowMain {
public static void main(String[] args){
double matrix1[][]= MatrixGenerator.generate(2000,2000);
double matrix2[][]=MatrixGenerator.generate(2000,2000);
double resulet[][]=new double[matrix1.length][matrix2[0].length];
Date start=new Date();
ParallelRowMultiplier.multiply(matrix1,matrix2,resulet);
Date end=new Date();
System.out.println(end.getTime()-start.getTime());
}
}
上述两种代码的线程数量都是固定的,那是否可以将线程的数量交由处理器决定呢?当然可以
矩阵乘法线程数量由处理器决定
在块代码中,获取可用线程的数量可以通过Runtime类的availableProcessors()方法来进行计算,在这里在矩阵每行创建一个线程的基础上进行修改,首先设计实现创建的线程类GroupMultiplierTask
public class GroupMultiplierTask implements Runnable{
private final double[][] result;
private final double[][] matrix1;
private final double[][] matrix2;
private final int start;
private final int end;
public GroupMultiplierTask(double[][] result, double[][] matrix1, double[][] matrix2, int start, int end) {
this.result = result;
this.matrix1 = matrix1;
this.matrix2 = matrix2;
this.start = start;
this.end = end;
}
@Override
public void run() {
for(int i=start;i<end;i++){
for(int j=0;j<matrix2[0].length;j++){
result[i][j] = 0;
for(int k=0;k<matrix1[i].length;k++) {
result[i][j]=result[i][k]*result[k][j];
}
}
}
}
}
同理,设计一个类来创建计算结果矩阵所需要的所有执行线程,方法如下
public class ParallelGroupMultiplier {
public static void multiply(double[][] matrix1, double[][] matrix2, double[][] result) {
List<Thread> threads = new ArrayList<>();
int rows1= matrix1.length;
int numThreads=Runtime.getRuntime().availableProcessors();
int startIndex, endIndex,step;
step = rows1/numThreads; //计算分组,每个线程计算一个分组
startIndex = 0;
endIndex = step;
for(int i=0;i<numThreads;i++) {
GroupMultiplierTask task = new GroupMultiplierTask(result,matrix1,matrix2,startIndex,endIndex);
Thread thread = new Thread(task);
thread.start();
threads.add(thread);
startIndex = endIndex;
endIndex= i==numThreads-2?rows1:endIndex+step;
}
for(Thread thread: threads) {
try{
thread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}
再同理,驱动程序如下
public class ParalleGroupMain {
public static void main(String[] args){
double matrix1[][]= MatrixGenerator.generate(2000,2000);
double matrix2[][]=MatrixGenerator.generate(2000,2000);
double resulet[][]=new double[matrix1.length][matrix2[0].length];
Date start=new Date();
ParallelGroupMultiplier.multiply(matrix1,matrix2,resulet);
Date end=new Date();
System.out.println(end.getTime()-start.getTime());
}
}