#include<errno.h>
#include<limits.h>
#include<stdio.h>
#include<stdlib.h>
#include<time.h>
#include "mpi.h"
#define N 3
/*
 * Number of matrix rows owned by `rank` when `m` rows are split over `size`
 * ranks: every rank gets m/size rows and the first m%size ranks get one more.
 * The per-rank values always sum to exactly m, and a rank may own 0 rows
 * when m < size.
 */
static int rows_for_rank(int m, int size, int rank)
{
    int base = m / size;
    int extra = m % size;
    return base + (rank < extra ? 1 : 0);
}

/*
 * Parse `s` as a strictly positive decimal int.
 * Returns 0 and stores the value in *out on success; returns -1 on empty
 * input, trailing garbage, a non-positive value, or a value above INT_MAX.
 */
static int parse_dim(const char *s, int *out)
{
    char *end = NULL;
    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v <= 0 || v > INT_MAX)
        return -1;
    *out = (int)v;
    return 0;
}

/*
 * Allocate room for `count` ints, aborting the whole MPI job on failure.
 * A zero count is rounded up to one element so that a NULL return from
 * malloc always means failure (ranks that own no rows request 0 elements).
 */
static int *alloc_ints(size_t count)
{
    if (count == 0)
        count = 1;
    int *p = NULL;
    if (count <= (size_t)-1 / sizeof *p)
        p = malloc(count * sizeof *p);
    if (p == NULL) {
        fprintf(stderr, "out of memory\n");
        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
        exit(EXIT_FAILURE); /* in case MPI_Abort returns */
    }
    return p;
}

/*
 * Distributed product of two random matrices.
 *
 * Usage: prog <m> <n>
 *
 * Rank 0 fills two m x n matrices A and B with random values in [0, N).
 * B is broadcast to every rank, the rows of A are scattered in contiguous
 * row blocks, and each rank computes, for each of its rows i and every k,
 *     RES[i][k] = sum_j A[i][j] * B[k][j]
 * i.e. the m x m product of A with the transpose of B. Rank 0 gathers the
 * row blocks and prints A, B (displayed transposed, as n x m) and the result.
 *
 * Returns 0 on success, EXIT_FAILURE when the arguments are missing or invalid.
 */
int main(int argc,char **argv)
{
    MPI_Init(&argc,&argv);
    int rank,size;
    MPI_Comm_rank(MPI_COMM_WORLD,&rank);
    MPI_Comm_size(MPI_COMM_WORLD,&size);

    /* Validate the command line before touching argv[1]/argv[2]. */
    int m = 0, n = 0;
    int bad_args = argc < 3
        || parse_dim(argv[1], &m) != 0
        || parse_dim(argv[2], &n) != 0;
    /* MPI counts are int, so both m*n and m*m element counts must fit. */
    if (!bad_args && (n > INT_MAX / m || m > INT_MAX / m))
        bad_args = 1;
    if (bad_args) {
        if (rank == 0)
            fprintf(stderr,
                    "usage: %s <m> <n>  (positive integers, m*n and m*m must fit in int)\n",
                    (argc > 0 && argv[0] != NULL) ? argv[0] : "prog");
        MPI_Finalize();
        return EXIT_FAILURE;
    }

    /* Rows of A handled by this rank, and the matching element count. */
    int my_rows = rows_for_rank(m, size, rank);
    int block = my_rows * n;

    int *A = NULL;              /* full m x n input, root only */
    int *RES = NULL;            /* full m x m result, root only */
    int *scatter_counts = NULL; /* per-rank element counts/offsets in A */
    int *scatter_offs = NULL;
    int *gather_counts = NULL;  /* per-rank element counts/offsets in RES */
    int *gather_offs = NULL;

    int *TA = alloc_ints((size_t)block);          /* this rank's rows of A */
    int *B = alloc_ints((size_t)m * (size_t)n);   /* contiguous m x n copy of B */
    int *RETA = alloc_ints((size_t)my_rows * (size_t)m); /* this rank's result rows */

    int i, j, k;
    if (rank == 0) {
        scatter_counts = alloc_ints((size_t)size);
        scatter_offs = alloc_ints((size_t)size);
        gather_counts = alloc_ints((size_t)size);
        gather_offs = alloc_ints((size_t)size);

        /* Lay the row blocks out back to back, in rank order. */
        int first_row = 0;
        for (i = 0; i < size; ++i) {
            int rows = rows_for_rank(m, size, i);
            scatter_counts[i] = rows * n;
            scatter_offs[i] = first_row * n;
            gather_counts[i] = rows * m;
            gather_offs[i] = first_row * m;
            first_row += rows;
        }

        A = alloc_ints((size_t)m * (size_t)n);
        RES = alloc_ints((size_t)m * (size_t)m);

        srand((unsigned)time(NULL));
        for (i = 0; i < m; ++i) {
            for (j = 0; j < n; ++j) {
                A[i*n+j] = rand() % N;
                B[i*n+j] = rand() % N;
            }
        }
    }

    /* B is contiguous, so a single broadcast replaces one call per row. */
    MPI_Bcast(B, m*n, MPI_INT, 0, MPI_COMM_WORLD);
    MPI_Scatterv(A, scatter_counts, scatter_offs, MPI_INT,
                 TA, block, MPI_INT, 0, MPI_COMM_WORLD);

    /* Local part of the product: one dot product per (local row, row of B). */
    for (i = 0; i < my_rows; ++i) {
        for (k = 0; k < m; ++k) {
            int sum = 0;
            for (j = 0; j < n; ++j)
                sum += TA[i*n+j] * B[k*n+j];
            RETA[i*m+k] = sum;
        }
    }

    MPI_Gatherv(RETA, my_rows*m, MPI_INT,
                RES, gather_counts, gather_offs, MPI_INT, 0, MPI_COMM_WORLD);

    if (rank == 0) {
        printf("A:\n");
        for (i = 0; i < m; ++i) {
            printf("[\t");
            for (j = 0; j < n; ++j)
                printf("%d\t", A[i*n+j]);
            printf("]\n");
        }
        /* B is shown transposed (n x m), matching how it is multiplied. */
        printf("B:\n");
        for (i = 0; i < n; ++i) {
            printf("[\t");
            for (j = 0; j < m; ++j)
                printf("%d\t", B[j*n+i]);
            printf("]\n");
        }
        printf("=:\n");
        for (i = 0; i < m; ++i) {
            printf("[\t");
            for (j = 0; j < m; ++j)
                printf("%d\t", RES[i*m+j]);
            printf("]\n");
        }
    }

    /* free(NULL) is a no-op, so the root-only buffers need no rank guard. */
    free(A);
    free(RES);
    free(scatter_counts);
    free(scatter_offs);
    free(gather_counts);
    free(gather_offs);
    free(TA);
    free(RETA);
    free(B);
    MPI_Finalize();
    return 0;
}
/* MPI - matrix multiplication */
/* Latest recommended article published 2024-06-04 19:49:49 */