【MPI】Connon矩阵乘(一)

17 篇文章 0 订阅
5 篇文章 1 订阅

Connon矩阵乘是通过循环移位,通过相邻节点上的数据进行交换,然后最终实现矩阵乘法。特点是每个节点占用空间比较少,传输比较少。本篇以MPI_Comm_split为循环移位基础。以下为具体做法。

以MPI_Comm_cart循环移位为基础的请参考:

https://blog.csdn.net/xll_bit/article/details/103114386

MPI_Type_vector和MPI_Scatterv的具体用法请参考:

https://blog.csdn.net/xll_bit/article/details/103112821

MPI_Comm_split的具体用法请参考:

https://blog.csdn.net/xll_bit/article/details/103091764

#include"mpi.h"
#include<stdio.h>
#include<math.h>
#include<malloc.h> 
#define N 8
// 串行矩阵乘
void serialMatrixMulMatrix(int *dA,int *dB,int *dC,int row_block_data){
         for(int i=0;i<row_block_data;i++)
    	     for(int j=0;j<row_block_data;j++)
    		     for(int k=0;k<row_block_data;k++)
    			     dC[i*row_block_data+j] += dA[i*row_block_data+k] * dB[k*row_block_data+j];
}
//按行打印矩阵
void printMatrix(int *data,int len){
	for(int i=0;i<len;i++){
		for(int j=0;j<len;j++)
			printf("%d	",data[i*len+j]);
		printf("\n");
	}
	printf("\n");
}

int main( int argc, char *argv[] )

{

     int rank, size;
     int X_rank, X_size,Y_rank,Y_size;
     int *dataA, *dA,*dataB,*dB,*dataC,*dC; 
     MPI_Comm X_comm,Y_comm,col_comm;
     MPI_Group group_world,new_group;
     MPI_Datatype matrix_block,tmp_block;
     MPI_Status status[4];
     MPI_Request request[4]; 
     MPI_Init(&argc, &argv);
     MPI_Comm_rank( MPI_COMM_WORLD, &rank );
     MPI_Comm_size(MPI_COMM_WORLD, &size); 
     int row_block_proc = sqrt((double)(size));
     int row_block_data = N/row_block_proc;
     //定义新的数据类型,MPI_Type_vector和MPI_Scatterv的用法可参考:
     //https://blog.csdn.net/xll_bit/article/details/103112821
     MPI_Type_vector(row_block_data,row_block_data,N,MPI_INT,&tmp_block);
     MPI_Type_create_resized(tmp_block,0,row_block_data*sizeof(int),&matrix_block);
     MPI_Type_commit(&matrix_block);
     //初始化矩阵
     if(rank == 0){
	     dataA =(int*)malloc(sizeof(int)*N*N); 
	     dataB =(int*)malloc(sizeof(int)*N*N); 
	     dataC =(int*)malloc(sizeof(int)*N*N); 
	     for(int i=0;i<N;i++)
		     for(int j=0;j<N;j++)
			     dataA[i*N+j] = i*8+j;
	     for(int i=0;i<N;i++)
		     for(int j=0;j<N;j++)
			     dataB[i*N+j] = -i+j*8;
	     for(int i=0;i<N;i++)
		     for(int j=0;j<N;j++)
			     dataC[i*N+j] = 0;
	     printf("Matrix A:\n");
             printMatrix(dataA,N);
	     printf("Matrix B:\n");
     	     printMatrix(dataB,N);
             serialMatrixMulMatrix(dataA,dataB,dataC,N);
	     printf("serialMatrixMulMatrix results:\n");
     	     printMatrix(dataC,N);
     }
     dA =(int*)malloc(sizeof(int)*row_block_data*row_block_data); 
     dB =(int*)malloc(sizeof(int)*row_block_data*row_block_data); 
     dC =(int*)malloc(sizeof(int)*row_block_data*row_block_data); 
     for(int i=0;i<row_block_data;i++)
	     for(int j=0;j<row_block_data;j++)
		    dC[i*row_block_data+j] = 0;
     //MPI_Comm_split函数可将原通信空间划分为新的通信空间,且新通信空间中包含多个互不包含的组,
     //当使用MPI_Bcast或者MPI_Send时是对各个组分别进行广播,或者各个组的节点进行发送。
     //具体可参考:https://blog.csdn.net/xll_bit/article/details/103091764
     MPI_Comm_split(MPI_COMM_WORLD, rank%row_block_proc, 0, &Y_comm);
     MPI_Comm_split(MPI_COMM_WORLD, rank/row_block_proc, 0, &X_comm);

     //得到新的rank和size

     MPI_Comm_rank(X_comm, &X_rank);

     MPI_Comm_size(X_comm, &X_size);     
     MPI_Comm_rank(Y_comm, &Y_rank);

     MPI_Comm_size(Y_comm, &Y_size);     
     int *scounts = (int *)malloc(size*sizeof(int));
     int *displs = (int *)malloc(size*sizeof(int));
     int disp = 0;
     for(int i=0;i<row_block_proc;i++){
	     disp = i*row_block_data*row_block_proc;
	     for(int j=0;j<row_block_proc;j++){
	     	displs[i*row_block_proc+j] = disp + (j + i + X_size)%X_size;
	     	scounts[i*row_block_proc+j] = 1;
	     }
     }
     MPI_Scatterv(dataA,scounts,displs,matrix_block,dA,row_block_data*row_block_data,MPI_INT,0,MPI_COMM_WORLD);
     disp = 0;
     for(int i=0;i<row_block_proc;i++){
	     for(int j=0;j<row_block_proc;j++){
	     	disp = j*row_block_proc;
	     	displs[i*row_block_proc+j] = j + ((j + i + Y_size)%Y_size) * row_block_data*row_block_proc;
	     }
     }
     MPI_Scatterv(dataB,scounts,displs,matrix_block,dB,row_block_data*row_block_data,MPI_INT,0,MPI_COMM_WORLD);
     MPI_Barrier(MPI_COMM_WORLD);
     // 组内循环通信
     int sourceA = (X_rank + 1)%X_size;
     int destA = (X_rank - 1 + X_size)%X_size;
     int sourceB = (Y_rank + 1)%Y_size;
     int destB = (Y_rank - 1 + Y_size)%Y_size;
     for(int step = 0;step < row_block_proc-1;step++){
         MPI_Isend(dA,row_block_data*row_block_data,MPI_INT,destA,99,X_comm,&request[0]);
         MPI_Isend(dB,row_block_data*row_block_data,MPI_INT,destB,99,Y_comm,&request[1]);
	 serialMatrixMulMatrix(dA,dB,dC,row_block_data);
	 MPI_Wait(&request[0],&status[0]);
	 MPI_Wait(&request[1],&status[1]);
         MPI_Irecv(dA,row_block_data*row_block_data,MPI_INT,sourceA,99,X_comm,&request[2]);
         MPI_Irecv(dB,row_block_data*row_block_data,MPI_INT,sourceB,99,Y_comm,&request[3]);
	 MPI_Wait(&request[2],&status[2]);
	 MPI_Wait(&request[3],&status[3]);
     }
     serialMatrixMulMatrix(dA,dB,dC,row_block_data);
     disp = 0;
     for(int i=0;i<row_block_proc;i++){
	     for(int j=0;j<row_block_proc;j++){
	     	displs[i*row_block_proc+j] = j + i * row_block_data*row_block_proc;
	     }
     }
     MPI_Gatherv(dC,row_block_data*row_block_data,MPI_INT,dataC,scounts,displs,matrix_block,0,MPI_COMM_WORLD); 
     MPI_Barrier(MPI_COMM_WORLD);

     if(rank == 0){
	     printf("parallelMatrixMulMatrix results:\n");
     	     printMatrix(dataC,N);
	     free(dataA);
	     free(dataB);
	     free(dataC);
     }
     free(dA);
     free(dB);
     free(dC);

     MPI_Type_free(&tmp_block);
     MPI_Type_free(&matrix_block);
     MPI_Comm_free(&X_comm);
     MPI_Comm_free(&Y_comm);
     MPI_Finalize(); 

    return 0;

}

最终结果如下:

Matrix A:
0	1	2	3	4	5	6	7	
8	9	10	11	12	13	14	15	
16	17	18	19	20	21	22	23	
24	25	26	27	28	29	30	31	
32	33	34	35	36	37	38	39	
40	41	42	43	44	45	46	47	
48	49	50	51	52	53	54	55	
56	57	58	59	60	61	62	63	

Matrix B:
0	8	16	24	32	40	48	56	
-1	7	15	23	31	39	47	55	
-2	6	14	22	30	38	46	54	
-3	5	13	21	29	37	45	53	
-4	4	12	20	28	36	44	52	
-5	3	11	19	27	35	43	51	
-6	2	10	18	26	34	42	50	
-7	1	9	17	25	33	41	49	

serialMatrixMulMatrix results:
-140	84	308	532	756	980	1204	1428	
-364	372	1108	1844	2580	3316	4052	4788	
-588	660	1908	3156	4404	5652	6900	8148	
-812	948	2708	4468	6228	7988	9748	11508	
-1036	1236	3508	5780	8052	10324	12596	14868	
-1260	1524	4308	7092	9876	12660	15444	18228	
-1484	1812	5108	8404	11700	14996	18292	21588	
-1708	2100	5908	9716	13524	17332	21140	24948	

parallelMatrixMulMatrix results:
-140	84	308	532	756	980	1204	1428	
-364	372	1108	1844	2580	3316	4052	4788	
-588	660	1908	3156	4404	5652	6900	8148	
-812	948	2708	4468	6228	7988	9748	11508	
-1036	1236	3508	5780	8052	10324	12596	14868	
-1260	1524	4308	7092	9876	12660	15444	18228	
-1484	1812	5108	8404	11700	14996	18292	21588	
-1708	2100	5908	9716	13524	17332	21140	24948

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值