通用矩阵乘法GEMM的实现 --- 非阻塞通信版本

代码目录如下:

  • InitMatrix.c
#include "InitMatrix.h"

/* Constants used only to generate deterministic test data in this file;
 * `static` keeps the short names PI/E out of the program-wide namespace. */
static const float PI = 3.1415926;
static const float E  = 2.7182818;

/*
 * Fill this rank's m x p slice of A (row-major, leading dimension p).
 *
 * The slice holds global rows id * m .. id * m + m - 1, and element (i, j)
 * is set to (global_row) * PI + j * E.
 */
void InitMatrixA(float *matrix, int m, int p, int id)
{
	const int first_row = id * m;	/* global index of this rank's first row */
	float *cell = matrix;		/* walks the slice in row-major order */

	for (int row = 0; row < m; ++row)
	{
		for (int col = 0; col < p; ++col)
			*cell++ = (first_row + row) * PI + col * E;
	}
}

/*
 * Fill this rank's p x n slice of B (row-major, leading dimension n).
 *
 * The slice holds global columns id * n .. id * n + n - 1, and element (i, j)
 * is set to 2 * i * E - (global_col) * PI (evaluated in double, stored as float).
 */
void InitMatrixB(float *matrix, int p, int n, int id)
{
	const int first_col = id * n;	/* global index of this rank's first column */
	int idx = 0;			/* running row-major offset, equals i * n + j */

	for (int i = 0; i < p; ++i)
		for (int j = 0; j < n; ++j, ++idx)
			matrix[idx] = 2.0 * i * E - (first_col + j) * PI;
}
  • InitMatrix.h
#ifndef INIT_MATRIX_H
#define INIT_MATRIX_H

/* Fill this rank's m x p row-major slice of A (global rows id * m onward). */
void InitMatrixA(float *matrix, int m, int p, int id);

/* Fill this rank's p x n row-major slice of B (global columns id * n onward). */
void InitMatrixB(float *matrix, int p, int n, int id);

#endif /* INIT_MATRIX_H */
  • main.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <mpi.h>
#include "InitMatrix.h"
#include "MatrixMultiply.h"
#include "PrintMatrix.h"

/* Global problem size: C (M x N) = A (M x P) * B (P x N). */
#define M 512
#define P 512
#define N 512

/* Number of MPI processes the decomposition is built for (must match the launch -n). */
#define NP 8

/* Rows of A/C and columns of B owned by each rank. Parenthesized so that
 * expressions such as `x % ROW` or `id * COL` expand as intended. */
#define ROW (M / NP)
#define COL (N / NP)

/* The block layout (and the modulo arithmetic in matrixc2c) assumes even splits. */
#if (M % NP) != 0 || (N % NP) != 0
#error "M and N must be divisible by NP"
#endif

/*
 * Copy one r x c partial product into this rank's r x N strip of C.
 *
 * matrixc : r x c source block, row-major with leading dimension c.
 * Matrixc : r x N destination strip, row-major with leading dimension N.
 * id      : this rank; itern : ring iteration.
 *
 * The destination column is (id * COL + itern * COL + j) mod N. Given how main()
 * rotates B (send to rank - 1, receive from rank + 1), at iteration itern the
 * rank holds B column block (id + itern) mod NP, so this lands the block in the
 * matching column range of C.
 */
void matrixc2c(float *matrixc, float *Matrixc, int r, int c, int id, int itern)
{
	for (int row = 0; row < r; ++row)
	{
		const float *src_row = matrixc + row * c;
		float *dst_row = Matrixc + row * N;

		for (int col = 0; col < c; ++col)
			dst_row[(id * COL + itern * COL + col) % N] = src_row[col];
	}
}

/*
 * Copy the freshly received B block (bufb) into the working block (locb).
 *
 * Both are r x c row-major float arrays and must not overlap. Non-positive
 * dimensions copy nothing.
 */
void matrixbufb2b(const float *bufb, float *locb, int r, int c)
{
	if (r <= 0 || c <= 0)
		return;

	/* size_t arithmetic so r * c * sizeof(float) cannot overflow int. */
	memcpy(locb, bufb, (size_t)r * (size_t)c * sizeof *locb);
}

int main(int argc, char * argv [])
{
	int rank, size;
    MPI_Status status[2];
	MPI_Request request[2];

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

	float *matrixa = (float *)malloc(ROW * P   * sizeof(float));
	float *matrixb = (float *)malloc(P   * COL * sizeof(float));
	float *matrixc = (float *)malloc(ROW * COL * sizeof(float));
	float *Matrixc = (float *)malloc(ROW * N   * sizeof(float));
	float *matrixb_buf = (float *)malloc(P * COL * sizeof(float));

	float *MatrixC;
	if (rank == 0)
	{
		MatrixC = (float *)malloc(M * N * sizeof(float));
	}

	InitMatrixA(matrixa, ROW, P, rank);
	InitMatrixB(matrixb, P, COL, rank);

	double start = MPI_Wtime();

	int sed_proc_id = (rank == 0) ? (NP - 1) : (rank - 1);
    int rcv_proc_id = (rank == (NP - 1)) ? 0 : (rank + 1);

	for (int i = 0; i < NP; i++)
	{
		if (i < NP - 1)
		{
			MPI_Isend(matrixb, P * COL, MPI_FLOAT, sed_proc_id, 0, MPI_COMM_WORLD, &request[0]);
			MPI_Irecv(matrixb_buf, P * COL, MPI_FLOAT, rcv_proc_id, 0, MPI_COMM_WORLD, &request[1]);
		}
		
		memset(matrixc, 0, ROW * COL * sizeof(float));
		MatrixMultiply(matrixa, matrixb, matrixc, ROW, P, COL);
		matrixc2c(matrixc, Matrixc, ROW, COL, rank, i);
		
		if (i < NP - 1)
		{
			MPI_Waitall(2, &request[0], &status[0]); // 隐式释放
			matrixbufb2b(matrixb_buf, matrixb, P, COL);
		}
	}

	MPI_Gather(Matrixc, ROW * N, MPI_FLOAT, MatrixC, ROW * N, MPI_FLOAT, 0, MPI_COMM_WORLD);

	if (rank == 0)
	{
		PrintMatrix(MatrixC, M, N);
		double end = MPI_Wtime();
		printf("Total used time is %lf s\n", end - start);	
		
		free(MatrixC);
	}
	
	free(matrixa);
	free(matrixb);
	free(matrixc);
	free(Matrixc);
	free(matrixb_buf);

	MPI_Finalize();
	return 0;	
}
  • Makefile
# Build the MPI GEMM example; every *.c in this directory is one translation unit.
# (The compiler variable is named CXX for compatibility with existing overrides,
# even though these are C sources compiled by mpicc.)
CXX = mpicc
CFLAGS ?= -O2
TARGET = gemm
SRC = $(wildcard *.c)
OBJ = $(patsubst %.c,%.o,$(SRC))
HDR = $(wildcard *.h)

$(TARGET): $(OBJ)
	$(CXX) $(CFLAGS) -o $@ $^

# Headers are shared between units, so rebuild objects when any header changes.
%.o: %.c $(HDR)
	$(CXX) $(CFLAGS) -c $< -o $@

.PHONY: clean

clean:
	rm -f *.o $(TARGET)
  • MatrixMultiply.c
#include "MatrixMultiply.h"

/*
 * Accumulate the product of A (m x p) and B (p x n) into C (m x n):
 *     C += A * B          (all matrices row-major)
 *
 * C is NOT cleared here; callers wanting C = A * B must zero it first.
 * Non-positive dimensions leave C unchanged.
 *
 * Loop order is i-k-j so the innermost loop walks rows of B and C with unit
 * stride (the i-j-k form strides through B by n floats per step). For each
 * C(i, j) the products are still added in increasing k, so the summation
 * order is the same as the i-j-k form.
 */
void MatrixMultiply(float *matrixA, float *matrixB, float *matrixC, int m, int p, int n)
{
    if (m <= 0 || p <= 0 || n <= 0)
        return;

    for (int i = 0; i < m; i++)
    {
        float *c_row = matrixC + (size_t)i * (size_t)n;

        for (int k = 0; k < p; k++)
        {
            const float a_ik = matrixA[(size_t)i * (size_t)p + (size_t)k];
            const float *b_row = matrixB + (size_t)k * (size_t)n;

            for (int j = 0; j < n; j++)
                c_row[j] += a_ik * b_row[j];
        }
    }
}
  • MatrixMultiply.h
#ifndef MATRIX_MULTIPLY_H
#define MATRIX_MULTIPLY_H

/* Accumulate C += A * B for row-major A (m x p), B (p x n), C (m x n).
 * C is not cleared first. */
void MatrixMultiply(float *matrixA, float *matrixB, float *matrixC, int m, int p, int n);

#endif /* MATRIX_MULTIPLY_H */
  • PrintMatrix.c
#include "PrintMatrix.h"

/*
 * Write an m x n row-major matrix to "ResultC.dat" in the current directory.
 *
 * Layout: the first line is the element count m * n, followed by one element
 * per line with 10 digits after the decimal point. On an I/O error a message
 * is printed to stderr and the function returns.
 */
void PrintMatrix(float *matrix, int m, int n)
{
    FILE *pf = fopen("ResultC.dat", "w");
    if (pf == NULL)
    {
        perror("PrintMatrix: cannot open ResultC.dat");
        return;
    }

    /* "%d", not "%.d": a zero precision prints no digits at all for a count of 0. */
    fprintf(pf, "%d\n", m * n);

    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            fprintf(pf, "%.10f\n", matrix[i * n + j]);

    /* fclose flushes buffered output, so a write error may surface only here. */
    if (fclose(pf) != 0)
        perror("PrintMatrix: error writing ResultC.dat");
}
  • PrintMatrix.h
#ifndef PRINT_MATRIX_H
#define PRINT_MATRIX_H

#include <stdio.h>

/* Write an m x n row-major matrix to ResultC.dat: the element count on the
 * first line, then one value per line. */
void PrintMatrix(float * matrix, int m, int n);

#endif /* PRINT_MATRIX_H */
  • yhrun.sh
#!/bin/bash
# Launch gemm on one node of partition thcp1.
# The process count must equal NP in main.c (NP is a compile-time constant there,
# so rebuild after changing it). Override the default of 8 with e.g.:
#   NP=16 ./yhrun.sh
yhrun -p thcp1 -N 1 -n "${NP:-8}" ./gemm

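运行结果如下(总耗时,单位:秒):

| 维度 \ NP | 1 | 2 | 4 | 8 | 16 | 32 | 64 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 2.277681 | 2.332382 | 1.040534 | 0.661465 | 0.508084 | 0.364335 | 0.347066 |
| 1024 | 66.525249 | 20.356065 | 9.462509 | 3.634763 | 2.300170 | 1.743395 | 1.461965 |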

注:使用不同 NP 时,需同时修改 main.c 中的 `#define NP` 并重新编译,以及 yhrun.sh 中的进程数(`-n` 参数),两者必须一致;测试 1024 维度时,同样需要将 `M`、`P`、`N` 改为 1024 后重新编译。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值