并行计算之并行前缀法(Parallel Prefix)

前缀和问题(Prefix sum)

首先介绍一下前缀和问题:

输入:一个二元可结合的运算符 ⊗ \otimes (比如说加减乘除),以及 n n n 个元素 x 0 , x 1 , x 2 , ⋯   , x n − 1 x_0, x_1, x_2,\cdots,x_{n-1} x0,x1,x2,,xn1

输出: n n n 个元素 s 0 , s 1 , ⋯   , s n − 1 s_0,s_1, \cdots,s_{n-1} s0,s1,,sn1,对任意 i ∈ [ 0 , n − 1 ] i\in[0,n-1] i[0,n1],都有 s i = x 0 ⊗ x 1 ⊗ ⋯ x i s_i=x_0 \otimes x_1 \otimes \cdots x_i si=x0x1xi

比如说,给定加法操作,以及输入 16 , 23 , 7 , 31 , 9 16,23,7,31,9 16,23,7,31,9

输入16237319
1616+2316+23+716+23+7+3116+23+7+31+9
输出1639467786

该问题使用串行算法实现很简单,时间复杂度为 O ( n ) O(n) O(n)

由于存在串行依赖,即 s i s_i si 依赖于 s i − 1 s_{i-1} si1,所以有必要考虑如何对这类问题进行并行化

并行前缀算法

首先来考虑这么一个问题:假设输入 x i x_i xi 9 , 8 , 3 , 2 , 7 , 1 , 6 , 4 9,8,3,2,7,1,6,4 9,8,3,2,7,1,6,4,操作为加法操作,一共8个进程,每个进程恰好拥有一个输入 x i x_i xi,那么可以按照下面这张图来计算输出 s i s_i si

在这里插入图片描述

看懂这张图的计算过程后,下面这个并行前缀和的算法就可以看懂了,其中id是当前进程的id,p是进程的数目

在这里插入图片描述

上面的例子中进程的数目恰好等于输入元素的个数,但实际情况却有可能是

  • n > p n>p n>p n n n 是输入元素的个数, p p p 是进程的数目
  • n n n 不是 p p p 的整数倍
    • 每个进程拥有的输入元素个数为 ⌈ n p ⌉ \lceil \frac{n}{p} \rceil pn ⌊ n p ⌋ \lfloor \frac{n}{p} \rfloor pn
  • p p p 不是2的整数次幂
    • d = ⌈ log ⁡ 2 p ⌉ d=\lceil \log_2p \rceil d=log2p
    • 在任何通信阶段,如果当前进程要通信的进程的id大于 p p p (假设进程从0依次编号到 p − 1 p-1 p1),则不执行任何操作

因此,更加通用的算法可以分为以下三步 (假设每个进程拥有的输入元素个数都为 n / p n/p n/p,如果n不是p的整数倍,其实步骤也差不多)

  1. 每个进程计算它自己拥有的这 n / p n/p n/p 个元素的前缀和
  2. 每个进程都调用上面的 PARALLEL_PREFIX_SUM(id, Xid, p) 函数,其中id是进程的id号,p是进程的数目,Xid是每个进程自己这 n / p n/p n/p 个元素的前缀和中的最后那个前缀和
  3. 最后每个进程再把函数PARALLEL_PREFIX_SUM的返回值加到已经计算的前缀和上即可

下面看个例子来体会这个过程

在这里插入图片描述

这个算法的时间复杂度为:

  • 步骤1:计算时间为 O ( n p ) O(\frac{n}{p}) O(pn),进程通信时间为 0
  • 步骤2:计算时间为 O ( log ⁡ p ) O(\log p) O(logp),进程通信时间为 O ( log ⁡ p ) O(\log p) O(logp)
  • 步骤3:计算时间为 O ( n p ) O(\frac{n}{p}) O(pn),进程通信时间为 0

最终计算时间为 O ( n p + log ⁡ p ) O(\frac{n}{p}+\log p) O(pn+logp),通信时间为 O ( log ⁡ p ) O(\log p) O(logp)

实际应用

多项式的计算

输入:一个实数 x 0 x_0 x0,以及 n n n 个整数系数 { a 0 , a 1 , ⋯   , a n − 1 } \{a_0,a_1,\cdots,a_{n-1}\} {a0,a1,,an1}

输出: P ( x 0 ) = a 0 + a 1 x 0 + a 2 x 0 2 + ⋯ + a n − 1 x 0 n − 1 P(x_0)=a_0+a_1x_0+a_2x_0^2+\cdots+a_{n-1}x_0^{n-1} P(x0)=a0+a1x0+a2x02++an1x0n1

可以使用并行前缀法解决这个问题:

  • 假设这 n n n 个整数系数分布在 p p p 个进程上,不妨认为进程 P i P_i Pi 拥有 a i n p a_{i\frac{n}{p}} aipn a ( i + 1 ) n p − 1 a_{(i+1)\frac{n}{p}-1} a(i+1)pn1

  • 进程 P i P_i Pi 负责计算局部和
    s u m ( i ) = ∑ j = 0 n p − 1 a i n P + j + x 0 i n p + j sum(i) = \sum_{j=0}^{\frac{n}{p}-1}a_{i\frac{n}{P}+j}+x_0^{i\frac{n}{p}+j} sum(i)=j=0pn1aiPn+j+x0ipn+j

  • 进程 P i P_i Pi 需要的 x 0 i n p x_0^{i\frac{n}{p}} x0ipn 可以通过并行前缀法来求得,即,每个进程先是计算 x 0 n p x_0^{\frac{n}{p}} x0pn,然后每个进程 P i P_i Pi 通过并行前缀法来求得自己需要的 x 0 i n p x_0^{i\frac{n}{p}} x0ipn

线性递归

输入:实数 x 0 , x 1 x_0,x_1 x0,x1,以及整数系数 a , b a, b a,b

输出:序列 { x 2 , x 3 , ⋯   , x n } \{x_2,x_3,\cdots,x_n\} {x2,x3,,xn} 使得 x i = a x i − 1 + b x i − 2 x_i=ax_{i-1}+bx_{i-2} xi=axi1+bxi2

上式可以重写为
[ x i x i − 1 ] = [ x i − 1 x i − 2 ] [ a 1 b 0 ] \begin{bmatrix} x_i & x_{i-1} \end{bmatrix} =\begin{bmatrix} x_{i-1} & x_{i-2} \end{bmatrix} \begin{bmatrix} a & 1\\ b & 0 \end{bmatrix} [xixi1]=[xi1xi2][ab10]
因此有
[ x i x i − 1 ] = [ x 1 x 0 ] [ a 1 b 0 ] i − 1 \begin{bmatrix} x_i & x_{i-1} \end{bmatrix} =\begin{bmatrix} x_{1} & x_{0} \end{bmatrix} \begin{bmatrix} a & 1\\ b & 0 \end{bmatrix}^{i-1} [xixi1]=[x1x0][ab10]i1
可以使用并行前缀法计算 [ a 1 b 0 ] i \begin{bmatrix}a & 1 \\b & 0 \end{bmatrix} ^ {i} [ab10]i

基于线性同余生成器的伪随机数序列发生器

输入:整数 A 和 B,以及大素数 P

输出:输出伪随机数序列 { x 1 , ⋯   , x n } \{x_1,\cdots,x_n\} {x1,,xn},其中 x i + 1 = ( A x i + B )   m o d   P x_{i+1}=(Ax_i+B)\space mod\space P xi+1=(Axi+B) mod P,不妨设 x 0 = 0 x_0=0 x0=0

类似地,我们有
[ x i 1 ] = [ x 0 1 ] [ A 0 B 1 ] i = [ 0 1 ] [ A 0 B 1 ] i \begin{bmatrix} x_i & 1 \end{bmatrix} =\begin{bmatrix} x_{0} & 1 \end{bmatrix} \begin{bmatrix} A & 0\\ B & 1 \end{bmatrix}^{i}=\begin{bmatrix} 0 & 1 \end{bmatrix} \begin{bmatrix} A & 0\\ B & 1 \end{bmatrix}^{i} [xi1]=[x01][AB01]i=[01][AB01]i
可以使用并行前缀法计算 [ A 0 B 1 ] i \begin{bmatrix}A & 0 \\B & 1 \end{bmatrix} ^ {i} [AB01]i

下面是实现的完整源码

#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>
#include <sys/time.h>

// #define PRINT_SERIES

/*
    return the difference between end and start in microseconds
*/
int timeDiff(struct timeval start, struct timeval end) {
	return (end.tv_sec-start.tv_sec)*1000000 + (end.tv_usec-start.tv_usec);
}

void compute_Mi_mul_M(int *Mi, int A, int B, int P);
void serial_matrix(int *base, int A, int B, int P, int iter_times, int *random_numbers);

int main(int argc, char *argv[]){
    int my_rank, comm_sz, n, A, B, P, *local_series, i, local_count, *total_series;
    struct timeval start_time, end_time;
    int local_runtime, total_runtime;

	MPI_Init(&argc, &argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
	MPI_Comm_size(MPI_COMM_WORLD, &comm_sz);

    // get input from the command
    n = atoi(argv[1]); 

    // {A,B,P} are user-defined parameters
    A = 2;
    B = 101;
    P = 93563;
    if (argc > 2)
        A = atoi(argv[2]);
    else if (argc > 3)
        B = atoi(argv[3]);
    else if (argc > 4)
        P = atoi(argv[4]);

    // Suppose n is divisible by comm_sz
    local_count = n / comm_sz;
    local_series = (int*)malloc(sizeof(int) * local_count);

    gettimeofday(&start_time, NULL);

    int base_M[4] = {1, 0, 0, 1}, send_M[4] = {1, 0, 0, 1}, recv_M[4];
    // calculate M^local_count
    for(i = 0; i < local_count; ++i){
        compute_Mi_mul_M(send_M, A, B, P);
    }
    MPI_Status status;
    // Use parallel prefix to compute M^i
    for(i = 1; i < comm_sz; i = i << 1){
        int partner;
        if (my_rank % (i << 1) < i){
            partner = my_rank + i;
            if (partner < comm_sz){
                MPI_Send(send_M, 4, MPI_INT, partner, 0, MPI_COMM_WORLD);
                MPI_Recv(recv_M, 4, MPI_INT, partner, 0, MPI_COMM_WORLD, &status);
                send_M[0] = ((long long)send_M[0] * recv_M[0]) % P;
                send_M[2] = ((long long)send_M[2] * recv_M[0] + recv_M[2]) % P;
            }
        }else{
            partner = my_rank - i;
            MPI_Recv(recv_M, 4, MPI_INT, partner, 0, MPI_COMM_WORLD, &status);
            MPI_Send(send_M, 4, MPI_INT, partner, 0, MPI_COMM_WORLD);
            send_M[0] = ((long long)send_M[0] * recv_M[0]) % P;
            send_M[2] = ((long long)send_M[2] * recv_M[0] + recv_M[2]) % P;
            base_M[0] = ((long long)base_M[0] * recv_M[0]) % P;
            base_M[2] = ((long long)base_M[2] * recv_M[0] + recv_M[2]) % P;
        }
    }
    // calculate local random numbers
    serial_matrix(base_M, A, B, P, local_count, local_series);

    gettimeofday(&end_time, NULL);

    // Gather all the local_series to rank 0
    if(my_rank == 0){
        total_series = (int*)malloc(sizeof(int) * n);
        MPI_Gather(local_series, local_count, MPI_INT, total_series, local_count, MPI_INT, 0, MPI_COMM_WORLD);
#ifdef PRINT_SERIES
        for(i = 0; i < n; ++i){
            printf("%d\n", total_series[i]);
        }
#endif
        free(total_series);
    }else{
        MPI_Gather(local_series, local_count, MPI_INT, total_series, local_count, MPI_INT, 0, MPI_COMM_WORLD);
    }

    local_runtime = timeDiff(start_time, end_time);
    MPI_Reduce(&local_runtime, &total_runtime, 1, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);
    if (my_rank == 0)
        printf("Total runtime (in microseconds): %d\n", total_runtime);

    free(local_series);
    MPI_Finalize();
    return 0;
}

void compute_Mi_mul_M(int *Mi, int A, int B, int P){
    Mi[0] = (Mi[0] * A) % P;
    Mi[2] = (Mi[2] * A + B) % P;
}

/*
    NOTE: base is a 2*2 matrix
*/
void serial_matrix(int *base, int A, int B, int P, int iter_times, int *random_numbers){
    int i;
    if(iter_times <= 0){
        return;
    }
    compute_Mi_mul_M(base, A, B, P);
    random_numbers[0] = base[2];
    for(i = 1; i < iter_times; ++i){
        compute_Mi_mul_M(base, A, B, P);
        random_numbers[i] = base[2];
    }
}

致谢

本文参考了文件,该文件还包括了该算法的另外两个应用,感兴趣的可以下载查看

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值