Connon矩阵乘是通过循环移位,通过相邻节点上的数据进行交换,然后最终实现矩阵乘法。特点是每个节点占用空间比较少,传输比较少。本篇以MPI_Comm_split为循环移位基础。以下为具体做法。
以MPI_Comm_cart循环移位为基础的请参考:
https://blog.csdn.net/xll_bit/article/details/103114386
MPI_Type_vector和MPI_Scatterv的具体用法请参考:
https://blog.csdn.net/xll_bit/article/details/103112821
MPI_Comm_split的具体用法请参考:
https://blog.csdn.net/xll_bit/article/details/103091764
#include"mpi.h"
#include<stdio.h>
#include<math.h>
#include<malloc.h>
#define N 8
// 串行矩阵乘
void serialMatrixMulMatrix(int *dA,int *dB,int *dC,int row_block_data){
for(int i=0;i<row_block_data;i++)
for(int j=0;j<row_block_data;j++)
for(int k=0;k<row_block_data;k++)
dC[i*row_block_data+j] += dA[i*row_block_data+k] * dB[k*row_block_data+j];
}
//按行打印矩阵
void printMatrix(int *data,int len){
for(int i=0;i<len;i++){
for(int j=0;j<len;j++)
printf("%d ",data[i*len+j]);
printf("\n");
}
printf("\n");
}
int main( int argc, char *argv[] )
{
int rank, size;
int X_rank, X_size,Y_rank,Y_size;
int *dataA, *dA,*dataB,*dB,*dataC,*dC;
MPI_Comm X_comm,Y_comm,col_comm;
MPI_Group group_world,new_group;
MPI_Datatype matrix_block,tmp_block;
MPI_Status status[4];
MPI_Request request[4];
MPI_Init(&argc, &argv);
MPI_Comm_rank( MPI_COMM_WORLD, &rank );
MPI_Comm_size(MPI_COMM_WORLD, &size);
int row_block_proc = sqrt((double)(size));
int row_block_data = N/row_block_proc;
//定义新的数据类型,MPI_Type_vector和MPI_Scatterv的用法可参考:
//https://blog.csdn.net/xll_bit/article/details/103112821
MPI_Type_vector(row_block_data,row_block_data,N,MPI_INT,&tmp_block);
MPI_Type_create_resized(tmp_block,0,row_block_data*sizeof(int),&matrix_block);
MPI_Type_commit(&matrix_block);
//初始化矩阵
if(rank == 0){
dataA =(int*)malloc(sizeof(int)*N*N);
dataB =(int*)malloc(sizeof(int)*N*N);
dataC =(int*)malloc(sizeof(int)*N*N);
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
dataA[i*N+j] = i*8+j;
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
dataB[i*N+j] = -i+j*8;
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
dataC[i*N+j] = 0;
printf("Matrix A:\n");
printMatrix(dataA,N);
printf("Matrix B:\n");
printMatrix(dataB,N);
serialMatrixMulMatrix(dataA,dataB,dataC,N);
printf("serialMatrixMulMatrix results:\n");
printMatrix(dataC,N);
}
dA =(int*)malloc(sizeof(int)*row_block_data*row_block_data);
dB =(int*)malloc(sizeof(int)*row_block_data*row_block_data);
dC =(int*)malloc(sizeof(int)*row_block_data*row_block_data);
for(int i=0;i<row_block_data;i++)
for(int j=0;j<row_block_data;j++)
dC[i*row_block_data+j] = 0;
//MPI_Comm_split函数可将原通信空间划分为新的通信空间,且新通信空间中包含多个互不包含的组,
//当使用MPI_Bcast或者MPI_Send时是对各个组分别进行广播,或者各个组的节点进行发送。
//具体可参考:https://blog.csdn.net/xll_bit/article/details/103091764
MPI_Comm_split(MPI_COMM_WORLD, rank%row_block_proc, 0, &Y_comm);
MPI_Comm_split(MPI_COMM_WORLD, rank/row_block_proc, 0, &X_comm);
//得到新的rank和size
MPI_Comm_rank(X_comm, &X_rank);
MPI_Comm_size(X_comm, &X_size);
MPI_Comm_rank(Y_comm, &Y_rank);
MPI_Comm_size(Y_comm, &Y_size);
int *scounts = (int *)malloc(size*sizeof(int));
int *displs = (int *)malloc(size*sizeof(int));
int disp = 0;
for(int i=0;i<row_block_proc;i++){
disp = i*row_block_data*row_block_proc;
for(int j=0;j<row_block_proc;j++){
displs[i*row_block_proc+j] = disp + (j + i + X_size)%X_size;
scounts[i*row_block_proc+j] = 1;
}
}
MPI_Scatterv(dataA,scounts,displs,matrix_block,dA,row_block_data*row_block_data,MPI_INT,0,MPI_COMM_WORLD);
disp = 0;
for(int i=0;i<row_block_proc;i++){
for(int j=0;j<row_block_proc;j++){
disp = j*row_block_proc;
displs[i*row_block_proc+j] = j + ((j + i + Y_size)%Y_size) * row_block_data*row_block_proc;
}
}
MPI_Scatterv(dataB,scounts,displs,matrix_block,dB,row_block_data*row_block_data,MPI_INT,0,MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
// 组内循环通信
int sourceA = (X_rank + 1)%X_size;
int destA = (X_rank - 1 + X_size)%X_size;
int sourceB = (Y_rank + 1)%Y_size;
int destB = (Y_rank - 1 + Y_size)%Y_size;
for(int step = 0;step < row_block_proc-1;step++){
MPI_Isend(dA,row_block_data*row_block_data,MPI_INT,destA,99,X_comm,&request[0]);
MPI_Isend(dB,row_block_data*row_block_data,MPI_INT,destB,99,Y_comm,&request[1]);
serialMatrixMulMatrix(dA,dB,dC,row_block_data);
MPI_Wait(&request[0],&status[0]);
MPI_Wait(&request[1],&status[1]);
MPI_Irecv(dA,row_block_data*row_block_data,MPI_INT,sourceA,99,X_comm,&request[2]);
MPI_Irecv(dB,row_block_data*row_block_data,MPI_INT,sourceB,99,Y_comm,&request[3]);
MPI_Wait(&request[2],&status[2]);
MPI_Wait(&request[3],&status[3]);
}
serialMatrixMulMatrix(dA,dB,dC,row_block_data);
disp = 0;
for(int i=0;i<row_block_proc;i++){
for(int j=0;j<row_block_proc;j++){
displs[i*row_block_proc+j] = j + i * row_block_data*row_block_proc;
}
}
MPI_Gatherv(dC,row_block_data*row_block_data,MPI_INT,dataC,scounts,displs,matrix_block,0,MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
if(rank == 0){
printf("parallelMatrixMulMatrix results:\n");
printMatrix(dataC,N);
free(dataA);
free(dataB);
free(dataC);
}
free(dA);
free(dB);
free(dC);
MPI_Type_free(&tmp_block);
MPI_Type_free(&matrix_block);
MPI_Comm_free(&X_comm);
MPI_Comm_free(&Y_comm);
MPI_Finalize();
return 0;
}
最终结果如下:
Matrix A:
0 1 2 3 4 5 6 7
8 9 10 11 12 13 14 15
16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31
32 33 34 35 36 37 38 39
40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55
56 57 58 59 60 61 62 63
Matrix B:
0 8 16 24 32 40 48 56
-1 7 15 23 31 39 47 55
-2 6 14 22 30 38 46 54
-3 5 13 21 29 37 45 53
-4 4 12 20 28 36 44 52
-5 3 11 19 27 35 43 51
-6 2 10 18 26 34 42 50
-7 1 9 17 25 33 41 49
serialMatrixMulMatrix results:
-140 84 308 532 756 980 1204 1428
-364 372 1108 1844 2580 3316 4052 4788
-588 660 1908 3156 4404 5652 6900 8148
-812 948 2708 4468 6228 7988 9748 11508
-1036 1236 3508 5780 8052 10324 12596 14868
-1260 1524 4308 7092 9876 12660 15444 18228
-1484 1812 5108 8404 11700 14996 18292 21588
-1708 2100 5908 9716 13524 17332 21140 24948
parallelMatrixMulMatrix results:
-140 84 308 532 756 980 1204 1428
-364 372 1108 1844 2580 3316 4052 4788
-588 660 1908 3156 4404 5652 6900 8148
-812 948 2708 4468 6228 7988 9748 11508
-1036 1236 3508 5780 8052 10324 12596 14868
-1260 1524 4308 7092 9876 12660 15444 18228
-1484 1812 5108 8404 11700 14996 18292 21588
-1708 2100 5908 9716 13524 17332 21140 24948