通用大规模可复用的 systolic 矩阵乘法器

        针对脉动阵列 systolic 做了一些小改进。
以下是systolic的矩阵计算原理,主要是控制每个单元计算所需的每行的数据流在同一时刻流过,利用加法器和流水乘法器实现矩阵乘法。此外根据数据流入的两个方向,可以计算出systolic阵列的数据流实际的速度为红色箭头所示。可以计算出完成时间的关系

systolic 矩阵乘法设计

在此基础上做了一点小扩充:

通用 systolic  乘法器

模块设计:

通用乘法器模块设计

在此基础上思考输入与输出的格式对齐方案:我使用了一种简单的设计,在计算完成时刻对整体的数据向下或右移出,这样不用添加额外的寄存器或读取转换电路,只需要将向右侧流动的数据位宽改为结果位宽即可,这样可以实现阵列的复用。

可根据需求另外设计:
1.根据左右乘矩阵,修改generate中的数据线横向纵向的连接方式即可

2.根据外部数据输出格式:修改输入输出的矩阵行列元素顺序:顺序输入--顺序输出、逆序输入--逆序输出

以下是输入时刻和输出时刻的数据流,本设计使用自定义无符号流水线乘法器

输入
输出
最小单元设计

代码如下:
最小单元设计

module pulse_arrays_pe  #(
    parameter   WIDTH_left     =   8,
    parameter   WIDTH_up       =   8,
    parameter   WIDTH_out      =   8
)(
    input   wire    clk,
    input   wire    rst,
    input   wire    [1:0]mode,           // mode=0 tpu_computer mode=1,2 shift out_data

    input   wire    [WIDTH_out-1:0]  left,
    input   wire    [WIDTH_up-1:0]  up,

    output  reg     [WIDTH_out-1:0]  right,
    output  reg     [WIDTH_up-1:0]  down,
    output  reg     [WIDTH_out-1:0]  out_data
);
//reg     [WIDTH_out-1:0]  out_data;
wire    [WIDTH_out-1:0]  temp_data;

always@(posedge  clk)begin
    if(!rst)begin
        right   <=  0;
        down   <=  0;
        out_data <=  0;
    end
    else begin
        //computer
        if(mode==0)begin
            right<= left;//{{(WIDTH_out-WIDTH_left){1'd0}},left}
            down     <=  up;
            out_data <=  temp_data+out_data;
        end
        //shift out_data
        else if(mode==2)begin                  
            right    <=  out_data;
            out_data <=  0;
        end

        else
            right    <=  left;

    end
end

    //multiply  module instantiation
    FIX_unsigned_MUL  #(.WIDTH_multiplicand(WIDTH_left), .WIDTH_multiplier(WIDTH_up))
    uut
    (
        .clk               (clk),
        .rst               (rst),
        .valid             (1'd1),
        .multiplicand      (left[WIDTH_left-1:0]),
        .multiplier        (up),
        .ready             (),
        .product           (temp_data)
    );

endmodule

通用阵列:

module pulse_arrays  #(
    parameter   WIDTH_left     =   8,
    parameter   WIDTH_up       =   8,
    parameter   WIDTH_out      =   8,
    parameter   Mritx_M   =   3,//output row
    parameter   Mritx_N   =   3,//
    parameter   Mritx_L   =   3,//output col
    parameter   Mritx_LOG2_size  =   10 //counter's width
)(
    input   clk,
    input   rst,

    input   wire    valid_left,
    input   wire    valid_up,
    input   wire  [Mritx_M*WIDTH_left-1:0]   left,
    input   wire  [Mritx_L*WIDTH_up-1:0]     up,

    output  reg     ready,                          //ready for input
    output  wire  [WIDTH_out*Mritx_M-1:0]    product 
);
    localparam  idl      = 4'd0;
    localparam  state_in = 4'd1;
    localparam  state_out= 4'd2;
    // state register
    reg [3:0] state;


    wire    [WIDTH_out-1:0]          left_temp [Mritx_M*Mritx_L-1:0];   //x_unite_wire
    wire    [WIDTH_up-1:0]           up_temp   [Mritx_M*Mritx_L-1:0];   //y_unite_wire

//enable  signal

    // reg     [Mritx_M-1:0]           left_shift_en;
    // reg     [Mritx_L-1:0]           up_shift_en;
    reg     star=0,export=0,finish=0;//flag
    reg     [Mritx_LOG2_size-1:0]      cnt_flow1;
    reg     [Mritx_LOG2_size-1:0]      cnt_flow2;

    reg     [2*Mritx_L-1:0]      mode_control=0;// 行同步控制 初始赋值0 大量变1易串扰不稳定
//output 
    wire    [WIDTH_out-1:0] out_data[Mritx_M*Mritx_L-1:0];  


assign left_temp[0] = valid_left?{{(WIDTH_out-WIDTH_left){1'd0}},left[WIDTH_left-1:0]}:0;
assign up_temp[0]   = valid_up?up[WIDTH_up-1:0]:0;

always @(posedge clk ) begin
    if(!rst)begin
        state <= idl;
        star  <= 0;
        export<= 0;
        ready <= 0;
        finish<= 0;
        cnt_flow1 <= 0;
        cnt_flow2 <= 0;
        mode_control   <= 0;
    end

    else begin
        case (state)
        idl :begin
            ready     <= 1;
            star      <= 0;
            export    <= 0;
            finish    <= 0;
            cnt_flow1 <= 0;
            cnt_flow2 <= 0;
            mode_control <= 0;
            if(valid_left&valid_up)
                state <= state_in;
            else
                state <= idl;
        end 
        state_in :begin
            ready     <= 0;
            star      <= 1;
            export    <= 0;
            finish    <= 0;
            cnt_flow1 <= cnt_flow1 + 1;
            cnt_flow2 <= 0;
            if(cnt_flow1==Mritx_M+Mritx_N+Mritx_L-2+WIDTH_up-1)begin
                state <= state_out;
                mode_control <= {Mritx_L{2'd2}};
            end
            else
                state <= state_in;
        end 
        state_out:begin
            ready     <= 0;
            star      <= 0;
            cnt_flow1 <= 0;
            cnt_flow2 <= cnt_flow2 + 1;
            if(cnt_flow2==0)
                mode_control <= {Mritx_L{2'd1}}<<2;
            else
                mode_control <= mode_control<<2;
            if(cnt_flow2==Mritx_L) begin
                state <= idl;
                finish<= 1;
                export<= 0;
            end
            else begin
                state <= state_out;
                finish<= 0;
                export<= 1;
            end
        end 

        default: begin
            ready     <= 0;
            star      <= 0;
            export    <= 0;
            finish    <= 0;
            cnt_flow1 <= 0;
            cnt_flow2 <= 0;
            mode_control <= 0;
            state <= idl;
        end
        endcase
    end
end

generate 
    genvar  i,j;
    for(i=0;i<=Mritx_M-1;i=i+1)begin
        for(j=0;j<=Mritx_L-1;j=j+1)begin:pulse_arrays_pex
            if(i==0&&j==0)begin
                pulse_arrays_pe #(
                    .WIDTH_left (WIDTH_left),
                    .WIDTH_up   (WIDTH_up),
                    .WIDTH_out  (WIDTH_out)
                )pulse_arrays_pex(
                    .clk        (clk),
                    .rst        (rst),
                    .mode       (mode_control[(j+1)*2-1:j*2]),
                    .left       (left_temp[0]),
                    .up         (up_temp[0]),
                    .right      (left_temp[i*Mritx_L+j+1]),
                    .down       (up_temp[(i+1)*Mritx_L+j]),
                    .out_data   (out_data[i*Mritx_L+j])
                );
            end
            
            else if(i==Mritx_M-1&&j==Mritx_L-1)begin
                pulse_arrays_pe #(
                    .WIDTH_left (WIDTH_left),
                    .WIDTH_up   (WIDTH_up),
                    .WIDTH_out  (WIDTH_out)
                )pulse_arrays_pex(
                    .clk        (clk),
                    .rst        (rst),
                    .mode       (mode_control[(j+1)*2-1:j*2]),
                    .left       (left_temp[i*Mritx_L+j]),
                    .up         (up_temp[i*Mritx_L+j]),
                    .right      (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
                    .down       (),
                    .out_data   (out_data[i*Mritx_L+j])
                );
            end
            else if(i==Mritx_M-1)begin
                pulse_arrays_pe #(
                    .WIDTH_left (WIDTH_left),
                    .WIDTH_up   (WIDTH_up),
                    .WIDTH_out  (WIDTH_out)
                )pulse_arrays_pex(
                    .clk        (clk),
                    .rst        (rst),
                    .mode       (mode_control[(j+1)*2-1:j*2]),
                    .left       (left_temp[i*Mritx_L+j]),
                    .up         (up_temp[i*Mritx_L+j]),
                    .right      (left_temp[i*Mritx_L+j+1]),
                    .down       (),
                    .out_data   (out_data[i*Mritx_L+j])
                );
            end
            else if(j==Mritx_L-1)begin
                pulse_arrays_pe #(
                    .WIDTH_left (WIDTH_left),
                    .WIDTH_up   (WIDTH_up),
                    .WIDTH_out  (WIDTH_out)
                )pulse_arrays_pex(
                    .clk        (clk),
                    .rst        (rst),
                    .mode       (mode_control[(j+1)*2-1:j*2]),
                    .left       (left_temp[i*Mritx_L+j]),
                    .up         (up_temp[i*Mritx_L+j]),
                    .right      (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
                    .down       (up_temp[(i+1)*Mritx_L+j]),
                    .out_data   (out_data[i*Mritx_L+j])
                );
            end
            else begin
                pulse_arrays_pe #(
                    .WIDTH_left (WIDTH_left),
                    .WIDTH_up   (WIDTH_up),
                    .WIDTH_out  (WIDTH_out)
                )pulse_arrays_pex(
                    .clk        (clk),
                    .rst        (rst),
                    .mode       (mode_control[(j+1)*2-1:j*2]),
                    .left       (left_temp[i*Mritx_L+j]),
                    .up         (up_temp[i*Mritx_L+j]),
                    .right      (left_temp[i*Mritx_L+j+1]),
                    .down       (up_temp[(i+1)*Mritx_L+j]),
                    .out_data   (out_data[i*Mritx_L+j])
                );
            end
        end
    end
endgenerate

generate
    genvar m,n;
        for(m=1;m<Mritx_M;m=m+1)begin:shift_register_left
            shift_register #(
                .WIDTH_in(WIDTH_left),
                .WIDTH_out(WIDTH_out),
                .DEEP(m),
                .PTR_SIZE(Mritx_LOG2_size)
            )shift_register_left(
                .clk(clk),
                .rst(rst),
                .shift_en(valid_left),
                .shift_in(left[(m+1)*WIDTH_left-1:(m)*WIDTH_left]),
                .shift_out(left_temp[m*Mritx_L])
            );
        end

        for(n=1;n<Mritx_L;n=n+1)begin:shift_register_up
            shift_register #(
                .WIDTH_in(WIDTH_up),
                .WIDTH_out(WIDTH_up),
                .DEEP(n),
                .PTR_SIZE(Mritx_LOG2_size)
            )shift_register_up(
                .clk(clk),
                .rst(rst),
                .shift_en(valid_up),
                .shift_in(up[(n+1)*WIDTH_up-1:n*WIDTH_up]),
                .shift_out(up_temp[n])
            );
        end
endgenerate

endmodule

要想在这个架构上实现真正的流水线矩阵乘法器,需要按照第一张图中的单元计算完成梯度图来将数据读出,但我目前并没有找到在不消耗大量硬件资源的情况下可通用的方法,后续找到的话会继续更新。

仿真:

本项目开源所有代码已上传至gihub仓库:Debug-xmh/Systiolic-Matrix-multiplier: Universal matrix multiplier, an improvement on the design of Systiolic Matrix multiplier in verilog. (github.com)

  • 5
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

OTITA

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值