主从模式实现矩阵乘法

参考资料

主要思路参考这篇文章

两小时入门MPI与并行计算(六):主从模式(实现矩阵乘法) - 知乎

因为原文只有 Fortran 代码,而评论区的 C++ 代码只实现了矩阵与向量的乘法,所以在这里补充一个矩阵与矩阵相乘的 C++ 示例。

参考的 C++ 代码为:主从模式(实现矩阵乘法) - 简书

代码主体 

/* ************************************************************************
> File Name:     matmult.cpp
> Author:        Roy
> Created Time:  Monday, October 23, 2023 PM05:27:18
> Description:   
 ************************************************************************/
#include <mpi.h>
#include <iostream>
#include <iomanip>
#include <algorithm>

/***
 * An MPI program calculating A*b=ans, where A is an Arowsize x Acolsize matrix, b is an Acolsize x bcolsize matrix, and ans is an Arowsize x bcolsize matrix
 * ***/

// ---- Shared state used by master(), slave() and main() ----

// Matrix dimensions
const int Arowsize = 10; // number of rows of A (and of ans)
const int Acolsize = 10; // number of columns of A (= number of rows of b)
const int bcolsize = 5;  // number of columns of b (and of ans)

// Message-passing constants
const int m_id = 0;    // rank of the master process
const int end_tag = 0; // tag of the "no more work" message; row messages use tags >= 1

// Matrices (zero-initialised)
double A[Arowsize][Acolsize] = {};   // left operand
double b[Acolsize][bcolsize] = {};   // right operand, broadcast to every process
double ans[Arowsize][bcolsize] = {}; // product A*b, assembled on the master

// Per-message scratch buffers
double buffer[Acolsize] = {};     // one row of A in transit
double ans_buffer[bcolsize] = {}; // one row of ans in transit

// Run-time bookkeeping
int numsent = 0;    // how many rows of A have been handed out so far
int numprocess = 0; // total number of MPI processes, filled in by main()
MPI_Status status;  // status of the most recent MPI_Recv

// Master process: define the data, hand out rows of A to the workers, and
// collect the result rows into ans.
//
// Protocol (must match slave()):
//   * b is broadcast once from rank m_id.
//   * Row i of A is sent with tag i+1; the worker returns the matching row of
//     ans using the same tag, so the tag alone identifies the row.
//   * A zero-length message with tag end_tag tells a worker to stop.
//
// Every worker receives exactly one end_tag message, including workers that
// never get a row because there are more workers than rows; without that they
// would block in MPI_Recv forever. When the program runs with a single
// process there is nobody to delegate to, so the product is computed locally.
void master()
{
    // Define data: every entry in row i of A is i+1, every entry in column k of b is k+1
    for (int i = 0; i < Arowsize; i++)
    {
        for (int j = 0; j < Acolsize; j++)
        {
            A[i][j] = i + 1;
        }
    }
    for (int j = 0; j < Acolsize; j++)
    {
        for (int k = 0; k < bcolsize; k++)
        {
            b[j][k] = k + 1;
        }
    }

    // Broadcast b to all processes (the workers make the matching call in slave())
    MPI_Bcast(
        b,                   // void *buffer
        Acolsize * bcolsize, // int count
        MPI_DOUBLE,          // MPI_Datatype datatype
        m_id,                // int root
        MPI_COMM_WORLD);     // MPI_Comm comm

    const int numworkers = numprocess - 1; // every rank except the master

    // No workers at all: nobody would ever answer MPI_Recv, so do the work here
    if (numworkers < 1)
    {
        for (int i = 0; i < Arowsize; i++)
        {
            for (int k = 0; k < bcolsize; k++)
            {
                double sum = 0;
                for (int j = 0; j < Acolsize; j++)
                {
                    sum += A[i][j] * b[j][k];
                }
                ans[i][k] = sum;
            }
        }
        return;
    }

    // Start counting from zero so the dispatch does not depend on earlier calls
    numsent = 0;

    // Initial round: one row per worker, or fewer if there are more workers than rows
    const int initial = std::min(numworkers, Arowsize);
    for (int i = 0; i < initial; i++)
    {
        MPI_Send(
            A[i],            // const void *buf, row i is contiguous in memory
            Acolsize,        // int count
            MPI_DOUBLE,      // MPI_Datatype datatype
            i + 1,           // int dest, row i goes to rank i+1
            i + 1,           // int tag, tag i+1 labels row i
            MPI_COMM_WORLD); // MPI_Comm comm
        numsent++;
    }

    // Workers that did not get a row are released immediately
    for (int w = initial + 1; w <= numworkers; w++)
    {
        MPI_Send(
            buffer,          // const void *buf, contents unused because count is 0
            0,               // int count
            MPI_DOUBLE,      // MPI_Datatype datatype
            w,               // int dest
            end_tag,         // int tag
            MPI_COMM_WORLD); // MPI_Comm comm
    }

    // Collect one result row per row of A, refilling whichever worker just finished
    for (int i = 0; i < Arowsize; i++)
    {
        MPI_Recv(
            ans_buffer,      // void *buf
            bcolsize,        // int count
            MPI_DOUBLE,      // MPI_Datatype datatype
            MPI_ANY_SOURCE,  // int source, accept a result from any worker
            MPI_ANY_TAG,     // int tag, the tag tells us which row this is
            MPI_COMM_WORLD,  // MPI_Comm comm
            &status);        // MPI_Status *status

        // Tag t labels row t-1
        const int row = status.MPI_TAG - 1;
        for (int j = 0; j < bcolsize; j++)
        {
            ans[row][j] = ans_buffer[j];
        }

        const int sender = status.MPI_SOURCE; // the worker that is now idle
        if (numsent < Arowsize)
        {
            // Rows remain: give the next one to the worker that just finished
            MPI_Send(
                A[numsent],      // const void *buf
                Acolsize,        // int count
                MPI_DOUBLE,      // MPI_Datatype datatype
                sender,          // int dest
                numsent + 1,     // int tag, tag numsent+1 labels row numsent
                MPI_COMM_WORLD); // MPI_Comm comm
            numsent++;
        }
        else
        {
            // Nothing left to hand out: tell this worker to stop
            MPI_Send(
                buffer,          // const void *buf, contents unused because count is 0
                0,               // int count
                MPI_DOUBLE,      // MPI_Datatype datatype
                sender,          // int dest
                end_tag,         // int tag
                MPI_COMM_WORLD); // MPI_Comm comm
        }
    }
}

// Worker process: take part in the broadcast of b, then repeatedly receive a
// row of A from the master, multiply it by b, and send the resulting row back
// under the same tag. A message tagged end_tag ends the loop.
void slave()
{
    // Matching call for the master's broadcast; fills the local copy of b
    MPI_Bcast(b, Acolsize * bcolsize, MPI_DOUBLE, m_id, MPI_COMM_WORLD);

    for (;;)
    {
        // Wait for the next message from the master; its tag says what it is
        MPI_Recv(buffer, Acolsize, MPI_DOUBLE, m_id, MPI_ANY_TAG, MPI_COMM_WORLD, &status);

        const int row_tag = status.MPI_TAG;
        if (row_tag == end_tag)
        {
            // No more work for this process
            break;
        }

        // ans_buffer = (received row) * b, one output column at a time
        for (int col = 0; col < bcolsize; col++)
        {
            ans_buffer[col] = 0;
            for (int k = 0; k < Acolsize; k++)
            {
                ans_buffer[col] += buffer[k] * b[k][col];
            }
        }

        // Return the row under the tag it arrived with so the master can place it
        MPI_Send(ans_buffer, bcolsize, MPI_DOUBLE, m_id, row_tag, MPI_COMM_WORLD);
    }
}

// Entry point: start MPI, run master() on rank m_id and slave() on every other
// rank, print A, b and ans on the master, then shut MPI down.
int main(int argc, char **argv)
{
    // Start MPI and find out how many processes there are and which one this is
    MPI_Init(&argc, &argv);
    MPI_Comm_size(MPI_COMM_WORLD, &numprocess);
    int myid;
    MPI_Comm_rank(MPI_COMM_WORLD, &myid);

    if (myid != m_id)
    {
        // Every rank other than the master works on rows
        slave();
    }
    else
    {
        // The master distributes the work, then reports the matrices
        master();

        // Matrix A
        std::cout << "A = " << std::endl;
        for (int r = 0; r < Arowsize; r++)
        {
            for (int c = 0; c < Acolsize; c++)
            {
                std::cout << std::setw(10) << A[r][c];
            }
            std::cout << std::endl;
        }

        // Matrix b
        std::cout << "b = " << std::endl;
        for (int r = 0; r < Acolsize; r++)
        {
            for (int c = 0; c < bcolsize; c++)
            {
                std::cout << std::setw(10) << b[r][c];
            }
            std::cout << std::endl;
        }

        // Matrix ans
        std::cout << "ans = " << std::endl;
        for (int r = 0; r < Arowsize; r++)
        {
            for (int c = 0; c < bcolsize; c++)
            {
                std::cout << std::setw(10) << ans[r][c];
            }
            std::cout << std::endl;
        }
    }

    // Shut MPI down on every rank
    MPI_Finalize();
    return 0;
}

 运行结果

运行下列命令

mpirun -n 4 ./matmult

 输出结果:

A = 
         1         1         1         1         1         1         1         1         1         1
         2         2         2         2         2         2         2         2         2         2
         3         3         3         3         3         3         3         3         3         3
         4         4         4         4         4         4         4         4         4         4
         5         5         5         5         5         5         5         5         5         5
         6         6         6         6         6         6         6         6         6         6
         7         7         7         7         7         7         7         7         7         7
         8         8         8         8         8         8         8         8         8         8
         9         9         9         9         9         9         9         9         9         9
        10        10        10        10        10        10        10        10        10        10
b = 
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
         1         2         3         4         5
ans = 
        10        20        30        40        50
        20        40        60        80       100
        30        60        90       120       150
        40        80       120       160       200
        50       100       150       200       250
        60       120       180       240       300
        70       140       210       280       350
        80       160       240       320       400
        90       180       270       360       450
       100       200       300       400       500

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值