通用矩阵乘法GEMM的实现 --- Cannon 版本

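本文给出用 MPI 实现 Cannon 算法的通用矩阵乘法(GEMM)示例:NP×NP 个进程组成二维周期性笛卡尔网格,每个进程通过 InitMatrixA 和 InitMatrixB 直接生成已完成初始对齐(skew)的 A、B 子块,并用 MatrixMultiply 计算局部乘积;随后进行 NP−1 轮循环,每轮用 MPI_Sendrecv_replace 将 A 子块左移一格、B 子块上移一格并累加乘积;最后由 0 号进程用 MPI_Gather 收集各 C 子块并写入 ResultC.dat。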
摘要由CSDN通过智能技术生成

代码目录如下:

  • InitMatrix.c
#include "InitMatrix.h"

/* Constants used only by InitMatrixA/InitMatrixB in this file to generate
 * deterministic matrix entries. They are not declared in InitMatrix.h, so
 * they get internal linkage instead of exporting generic global symbols
 * named PI and E. Literals are unchanged, so the stored values are identical. */
static const float PI = 3.1415926;
static const float E  = 2.7182818;

/*
 * Fill one m x p row-major block of A with deterministic values.
 * `id` selects which block: block row id / NP, block column id % NP.
 * Each entry is (global row) * PI + (global column) * E.
 */
void InitMatrixA(float *matrix, int m, int p, int id)
{
	const int rowBase = (id / NP) * m;
	const int colBase = (id % NP) * p;
	float *out = matrix;

	for (int r = 0; r < m; ++r)
	{
		for (int c = 0; c < p; ++c)
		{
			*out++ = (rowBase + r) * PI + (colBase + c) * E;
		}
	}
}

/*
 * Fill one p x n row-major block of B with deterministic values.
 * `id` selects which block: block row id / NP, block column id % NP.
 * Each entry is 2 * (global row) * E - (global column) * PI; the 2.0 literal
 * makes the left term evaluate in double before storing to float.
 */
void InitMatrixB(float *matrix, int p, int n, int id)
{
	const int rowBase = (id / NP) * p;
	const int colBase = (id % NP) * n;

	for (int r = 0; r < p; ++r)
	{
		float *rowOut = matrix + r * n;
		for (int c = 0; c < n; ++c)
			rowOut[c] = 2.0 * (rowBase + r) * E - (colBase + c) * PI;
	}
}
  • InitMatrix.h
#ifndef INIT_MATRIX_H
#define INIT_MATRIX_H

/* Side length of the NP x NP process grid (the run uses NP * NP processes).
 * Guarded so a single value can be supplied for every file, e.g. -DNP=2. */
#ifndef NP
#define NP 4
#endif

/* Fill an m x p row-major block of A; id / NP and id % NP give its block row/column. */
void InitMatrixA(float *matrix, int m, int p, int id);
/* Fill a p x n row-major block of B; id / NP and id % NP give its block row/column. */
void InitMatrixB(float *matrix, int p, int n, int id);

#endif /* INIT_MATRIX_H */
  • main.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <mpi.h>
#include <math.h>
#include "InitMatrix.h"
#include "MatrixMultiply.h"
#include "PrintMatrix.h"

#define M 1024
#define P 1024
#define N 1024

#define NP 4

#define ROW M / NP
#define COL N / NP

#define UP    0
#define DOWN  1
#define LEFT  2
#define RIGHT 3

int main(int argc, char *argv[])
{
	int rank, size;
    MPI_Status status;

    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);

	float *matrixa = (float *)malloc(ROW * COL * sizeof(float));
	float *matrixb = (float *)malloc(ROW * COL * sizeof(float));
	float *matrixc = (float *)malloc(ROW * COL * sizeof(float));

	float *MatrixC;
	if (rank == 0)
	{
		MatrixC = (float *)malloc(M * N * sizeof(float));
	}

    int nbrs[4];
    int dims[2] = {4, 4};
    int periods[2] = {1, 1};
    int reorder = 0;
	MPI_Comm Cart_Comm_World;

	MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, reorder, &Cart_Comm_World);
	MPI_Comm_rank(Cart_Comm_World, &rank);

	double start = MPI_Wtime();
	
	int shiftc = rank / NP;
	int shiftr = rank % NP;
	MPI_Cart_shift(Cart_Comm_World, 0, shiftr, &nbrs[UP], &nbrs[DOWN]);
	MPI_Cart_shift(Cart_Comm_World, 1, shiftc, &nbrs[LEFT], &nbrs[RIGHT]);

	InitMatrixA(matrixa, ROW, COL, nbrs[RIGHT]);
	InitMatrixB(matrixb, ROW, COL, nbrs[DOWN]);
 
	memset(matrixc, 0, ROW * COL * sizeof(float));
	MatrixMultiply(matrixa, matrixb, matrixc, ROW, ROW, COL);
	
	MPI_Cart_shift(Cart_Comm_World, 0, 1, &nbrs[UP], &nbrs[DOWN]);
	MPI_Cart_shift(Cart_Comm_World, 1, 1, &nbrs[LEFT], &nbrs[RIGHT]);
	for (int i = 0; i < NP - 1; i++)
	{
		MPI_Sendrecv_replace(matrixa, ROW * COL, MPI_FLOAT, nbrs[LEFT], 0, nbrs[RIGHT], 0, Cart_Comm_World, &status);
		MPI_Sendrecv_replace(matrixb, ROW * COL, MPI_FLOAT, nbrs[UP],   0, nbrs[DOWN],  0, Cart_Comm_World, &status);

		MatrixMultiply(matrixa, matrixb, matrixc, ROW, ROW, COL);
	}
	
	MPI_Gather(matrixc, ROW * COL, MPI_FLOAT, MatrixC, ROW * COL, MPI_FLOAT, 0, Cart_Comm_World);

	if (rank == 0)
	{
		PrintMatrix(MatrixC, M, N);
		double end = MPI_Wtime();
		printf("Total used time is %lf s\n", end - start);	
		
		free(MatrixC);
	}
	
	free(matrixa);
	free(matrixb);
	free(matrixc);
	MPI_Comm_free(&Cart_Comm_World);

	MPI_Finalize();
	return 0;	
}
  • Makefile
# Build the MPI Cannon GEMM demo from every .c file in this directory.
# The compiler variable keeps its original name CXX so existing invocations
# such as `make CXX=mpiicc` still work; it holds the MPI *C* compiler wrapper.
CXX = mpicc
CFLAGS ?= -O2 -Wall
TARGET = gemm
SRC = $(wildcard *.c)
OBJ = $(patsubst %.c,%.o,$(SRC))
HDR = $(wildcard *.h)

$(TARGET): $(OBJ)
	$(CXX) $(CFLAGS) -o $@ $^

# Objects depend on the headers too, so changing e.g. NP in InitMatrix.h
# triggers a rebuild instead of silently linking stale objects.
%.o: %.c $(HDR)
	$(CXX) $(CFLAGS) -c $< -o $@

# Was "PHONY: clean", which defines an ordinary target named PHONY.
.PHONY: clean

clean:
	rm -f *.o $(TARGET)
  • MatrixMultiply.c
#include "MatrixMultiply.h"

/*
 * Accumulate C += A * B for row-major matrices:
 *   A is m x p, B is p x n, C is m x n.
 * C is not cleared here; the caller zeroes it once and may accumulate several
 * partial products into it.
 *
 * Loops run i-k-j so the innermost loop walks rows of B and C contiguously.
 * Each C element still receives its k-terms in increasing k order, so the
 * floating-point result matches the i-j-k ordering (assuming C does not
 * overlap A or B).
 */
void MatrixMultiply(float *matrixA, float *matrixB, float *matrixC, int m, int p, int n)
{
    for (int i = 0; i < m; i++)
    {
        float *rowC = matrixC + (size_t)i * n;
        const float *rowA = matrixA + (size_t)i * p;

        for (int k = 0; k < p; k++)
        {
            const float aik = rowA[k];
            const float *rowB = matrixB + (size_t)k * n;

            for (int j = 0; j < n; j++)
                rowC[j] += aik * rowB[j];
        }
    }
}
  • MatrixMultiply.h
#ifndef MATRIX_MULTIPLY_H
#define MATRIX_MULTIPLY_H

/* Accumulate matrixC += matrixA * matrixB for row-major matrices:
 * matrixA is m x p, matrixB is p x n, matrixC is m x n (not cleared first). */
void MatrixMultiply(float *matrixA, float *matrixB, float *matrixC, int m, int p, int n);

#endif /* MATRIX_MULTIPLY_H */
  • PrintMatrix.c
#include "PrintMatrix.h"

/*
 * Write an m x n row-major matrix to "ResultC.dat" (overwriting it).
 * The first line is the element count m * n; then one value per line
 * with 10 decimal places. Failures are reported on stderr; the function
 * returns without writing if the file cannot be opened.
 */
void PrintMatrix(float *matrix, int m, int n)
{
    FILE *pf = fopen("ResultC.dat", "w");
    if (pf == NULL)
    {
        perror("PrintMatrix: cannot open ResultC.dat");
        return;
    }

    /* "%d", not "%.d": zero precision would print nothing for a count of 0. */
    fprintf(pf, "%d\n", m * n);

    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            fprintf(pf, "%.10f\n", matrix[(size_t)i * n + j]);

    /* Check for write errors before closing; fclose also reports flush failures. */
    int write_failed = ferror(pf);
    if (fclose(pf) != 0 || write_failed)
        fprintf(stderr, "PrintMatrix: error while writing ResultC.dat\n");
}
  • PrintMatrix.h
#ifndef PRINT_MATRIX_H
#define PRINT_MATRIX_H

/* Kept here because PrintMatrix.c relies on this header for FILE/fopen/fprintf. */
#include <stdio.h>

/* Write an m x n row-major matrix to "ResultC.dat": element count, then one value per line. */
void PrintMatrix(float * matrix, int m, int n);

#endif /* PRINT_MATRIX_H */
  • yhrun.sh
#!/bin/bash
# Launch ./gemm on partition thcp1 using one node.
# NOTE(review): flags assumed to follow srun conventions (-p partition, -N nodes, -n tasks).
# The task count must equal NP * NP, where NP is the grid side set in InitMatrix.h (NP = 4 -> 16).
NPROCS=16
yhrun -p thcp1 -N 1 -n "${NPROCS}" ./gemm

运行结果如下:

| 维度 \ 进程数 | 16(NP = 4)用时 / 秒 |
|---|---|
| 512 | 0.663438 |
| 1024 | 2.368885 |

注:使用不同的 NP 时,需确保 InitMatrix.h 与 main.c 中的 NP 及进程网格维度保持一致,并将 yhrun.sh 中的进程数改为 NP×NP(例如 NP=4 时为 16)。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值