针对脉动阵列 systolic 做了一些小改进。
以下是 systolic 阵列的矩阵计算原理:核心是控制每个单元计算所需的每行数据流在同一时刻流过该单元,并利用加法器和流水线乘法器实现矩阵乘法。此外,根据数据流入的两个方向,可以得出 systolic 阵列中数据流的实际传播速度(如红色箭头所示),并由此推算出各单元计算完成时间之间的关系。
在此基础上做了一点小扩充:
模块设计:
在此基础上思考输入与输出的格式对齐方案:我使用了一种简单的设计,在计算完成时刻对整体的数据向下或右移出,这样不用添加额外的寄存器或读取转换电路,只需要将向右侧流动的数据位宽改为结果位宽即可,这样可以实现阵列的复用。
可根据需求另外设计:
1.根据左右乘矩阵,修改generate中的数据线横向纵向的连接方式即可
2.根据外部数据输出格式:修改输入输出的矩阵行列元素顺序:顺序输入--顺序输出、逆序输入--逆序输出
以下是输入时刻和输出时刻的数据流,本设计使用自定义无符号流水线乘法器
代码如下:
最小单元设计
// Processing element (PE) of the systolic array.
//
// Every clock edge the PE does one of three things, selected by `mode`:
//   mode == 0      : compute - forward `left` to `right`, forward `up` to
//                    `down`, and add the multiplier output to `out_data`.
//   mode == 2      : unload  - copy the accumulator `out_data` onto `right`
//                    and clear the accumulator; `down` keeps its value.
//   any other mode : shift   - forward `left` to `right`; `down` and
//                    `out_data` keep their values.
//
// Parameters
//   WIDTH_left : number of low bits of `left` fed to the multiplier.
//   WIDTH_up   : width of `up` / `down` and of the multiplier operand.
//   WIDTH_out  : width of `left`, `right`, the accumulator and the product.
//
// Reset is synchronous and active-low: when `rst` is 0 at a clock edge,
// `right`, `down` and `out_data` are cleared.
module pulse_arrays_pe #(
    parameter WIDTH_left = 8,
    parameter WIDTH_up   = 8,
    parameter WIDTH_out  = 8
)(
    input  wire                 clk,
    input  wire                 rst,
    input  wire [1:0]           mode,
    input  wire [WIDTH_out-1:0] left,
    input  wire [WIDTH_up-1:0]  up,
    output reg  [WIDTH_out-1:0] right,
    output reg  [WIDTH_up-1:0]  down,
    output reg  [WIDTH_out-1:0] out_data
);

    // Mode encodings that have a dedicated action; every other value shifts.
    localparam [1:0] MODE_COMPUTE = 2'd0;
    localparam [1:0] MODE_UNLOAD  = 2'd2;

    // Product delivered by the multiplier instance below.
    wire [WIDTH_out-1:0] mul_product;

    // Multiplier: always enabled (valid tied high), `ready` is not used.
    FIX_unsigned_MUL #(
        .WIDTH_multiplicand (WIDTH_left),
        .WIDTH_multiplier   (WIDTH_up)
    ) uut (
        .clk          (clk),
        .rst          (rst),
        .valid        (1'd1),
        .multiplicand (left[WIDTH_left-1:0]),
        .multiplier   (up),
        .ready        (),
        .product      (mul_product)
    );

    always @(posedge clk) begin
        if (!rst) begin
            // Synchronous active-low reset.
            out_data <= 0;
            down     <= 0;
            right    <= 0;
        end
        else begin
            case (mode)
                MODE_COMPUTE: begin
                    // Pass operands on to the neighbours and accumulate.
                    out_data <= out_data + mul_product;
                    down     <= up;
                    right    <= left;
                end
                MODE_UNLOAD: begin
                    // Put the accumulated result on the row bus and clear it.
                    right    <= out_data;
                    out_data <= 0;
                end
                default: begin
                    // Relay whatever arrives from the left neighbour.
                    right <= left;
                end
            endcase
        end
    end

endmodule
通用阵列:
// Top level of the systolic ("pulse") array.
//
// Instantiates a Mritx_M x Mritx_L grid of pulse_arrays_pe cells, the
// shift_register instances that feed rows 1..Mritx_M-1 and columns
// 1..Mritx_L-1, and a three-state controller (idl -> state_in -> state_out)
// that drives the per-column PE mode through mode_control.
//
// Reset is synchronous and active-low (the controller resets when rst is 0).
// NOTE(review): shift_register and FIX_unsigned_MUL are not defined in this
// module; comments about their timing below are assumptions to confirm.
module pulse_arrays #(
// Width of one element packed in `left`.
parameter WIDTH_left = 8,
// Width of one element packed in `up`; also appears in the state_in cycle count.
parameter WIDTH_up = 8,
// Width of one element of `product` and of the PE-to-PE horizontal nets.
parameter WIDTH_out = 8,
parameter Mritx_M = 3,// number of PE rows (output rows)
parameter Mritx_N = 3,// only used in the state_in cycle count; presumably the shared inner dimension - TODO confirm
parameter Mritx_L = 3,// number of PE columns (output columns)
parameter Mritx_LOG2_size = 10 // width of the cycle counters; also passed to shift_register as PTR_SIZE
)(
input clk,
input rst, // synchronous reset, active-low
input wire valid_left, // gates left element 0 and enables the left shift registers
input wire valid_up, // gates up element 0 and enables the up shift registers
input wire [Mritx_M*WIDTH_left-1:0] left, // Mritx_M packed elements, element m in bits [(m+1)*WIDTH_left-1 : m*WIDTH_left]
input wire [Mritx_L*WIDTH_up-1:0] up, // Mritx_L packed elements, element n in bits [(n+1)*WIDTH_up-1 : n*WIDTH_up]
output reg ready, // set to 1 only while the controller is in idl
output wire [WIDTH_out*Mritx_M-1:0] product // slice i is driven by `right` of the last-column PE in row i
);
// Controller state encodings.
localparam idl = 4'd0;
localparam state_in = 4'd1;
localparam state_out= 4'd2;
// state register
reg [3:0] state;
// Flattened PE interconnect: index i*Mritx_L+j is the left / up input of PE (i,j).
wire [WIDTH_out-1:0] left_temp [Mritx_M*Mritx_L-1:0]; // horizontal nets
wire [WIDTH_up-1:0] up_temp [Mritx_M*Mritx_L-1:0]; // vertical nets
//enable signal
// reg [Mritx_M-1:0] left_shift_en;
// reg [Mritx_L-1:0] up_shift_en;
// Status flags: written by the controller below but not read anywhere in this module.
reg star=0,export=0,finish=0;//flag
reg [Mritx_LOG2_size-1:0] cnt_flow1; // cycles spent in state_in
reg [Mritx_LOG2_size-1:0] cnt_flow2; // cycles spent in state_out
// Row-synchronised mode control, initialised to 0 (original author's note:
// driving many bits to 1 at once is prone to crosstalk and instability).
// Holds one 2-bit mode field per PE column j, bits [(j+1)*2-1 : j*2].
reg [2*Mritx_L-1:0] mode_control=0;
// Accumulator outputs of every PE; connected to the PEs but not used further in this module.
wire [WIDTH_out-1:0] out_data[Mritx_M*Mritx_L-1:0];
// PE (0,0) is fed directly: element 0 of each input, forced to 0 while the
// matching valid is low; the left element is zero-extended to WIDTH_out.
assign left_temp[0] = valid_left?{{(WIDTH_out-WIDTH_left){1'd0}},left[WIDTH_left-1:0]}:0;
assign up_temp[0] = valid_up?up[WIDTH_up-1:0]:0;
// Controller FSM.
always @(posedge clk ) begin
if(!rst)begin
state <= idl;
star <= 0;
export<= 0;
ready <= 0;
finish<= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
end
else begin
case (state)
// idl: advertise ready, keep every column in mode 0, and wait until both
// valid inputs are high in the same cycle.
idl :begin
ready <= 1;
star <= 0;
export <= 0;
finish <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
if(valid_left&valid_up)
state <= state_in;
else
state <= idl;
end
// state_in: count cycles while mode_control stays 0 (PE compute mode).
state_in :begin
ready <= 0;
star <= 1;
export <= 0;
finish <= 0;
cnt_flow1 <= cnt_flow1 + 1;
cnt_flow2 <= 0;
// The comparison sees cnt_flow1 before this cycle's increment. On a match every
// column field is set to 2 (PE moves its accumulator onto `right`).
// NOTE(review): the threshold presumably covers array fill time plus the
// multiplier pipeline latency - confirm against FIX_unsigned_MUL.
if(cnt_flow1==Mritx_M+Mritx_N+Mritx_L-2+WIDTH_up-1)begin
state <= state_out;
mode_control <= {Mritx_L{2'd2}};
end
else
state <= state_in;
end
// state_out: results travel along each row towards `product`.
state_out:begin
ready <= 0;
star <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= cnt_flow2 + 1;
// First cycle: every column field becomes 1 except column 0, which becomes 0.
// Each later cycle shifts the fields up by one column, inserting another 0
// field at column 0.
if(cnt_flow2==0)
mode_control <= {Mritx_L{2'd1}}<<2;
else
mode_control <= mode_control<<2;
// Return to idl once cnt_flow2 (value before this cycle's increment) equals
// Mritx_L; finish is set on that cycle, export on all earlier ones.
if(cnt_flow2==Mritx_L) begin
state <= idl;
finish<= 1;
export<= 0;
end
else begin
state <= state_out;
finish<= 0;
export<= 1;
end
end
// Any unexpected state value: clear everything and go back to idl.
default: begin
ready <= 0;
star <= 0;
export <= 0;
finish <= 0;
cnt_flow1 <= 0;
cnt_flow2 <= 0;
mode_control <= 0;
state <= idl;
end
endcase
end
end
// PE grid. PE (i,j) reads left_temp[i*Mritx_L+j] and up_temp[i*Mritx_L+j],
// and takes its mode from column field j of mode_control. The five branches
// differ only in how the edges of the grid are wired.
generate
genvar i,j;
for(i=0;i<=Mritx_M-1;i=i+1)begin
for(j=0;j<=Mritx_L-1;j=j+1)begin:pulse_arrays_pex
// Top-left PE (0,0): inputs are the directly gated left_temp[0] / up_temp[0].
if(i==0&&j==0)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[0]),
.up (up_temp[0]),
.right (left_temp[i*Mritx_L+j+1]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end
// Bottom-right PE: `right` drives the product slice of row i, `down` is unconnected.
else if(i==Mritx_M-1&&j==Mritx_L-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
.down (),
.out_data (out_data[i*Mritx_L+j])
);
end
// Bottom row, not the last column: `down` is unconnected.
else if(i==Mritx_M-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (left_temp[i*Mritx_L+j+1]),
.down (),
.out_data (out_data[i*Mritx_L+j])
);
end
// Last column, not the bottom row: `right` drives the product slice of row i.
else if(j==Mritx_L-1)begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end
// All remaining PEs: `right` / `down` feed the left / up nets of the next PE.
else begin
pulse_arrays_pe #(
.WIDTH_left (WIDTH_left),
.WIDTH_up (WIDTH_up),
.WIDTH_out (WIDTH_out)
)pulse_arrays_pex(
.clk (clk),
.rst (rst),
.mode (mode_control[(j+1)*2-1:j*2]),
.left (left_temp[i*Mritx_L+j]),
.up (up_temp[i*Mritx_L+j]),
.right (left_temp[i*Mritx_L+j+1]),
.down (up_temp[(i+1)*Mritx_L+j]),
.out_data (out_data[i*Mritx_L+j])
);
end
end
end
endgenerate
// Input feeders for every row / column other than index 0.
// NOTE(review): DEEP presumably sets the delay in cycles, so that row m and
// column n are skewed by m and n cycles - confirm against shift_register.
generate
genvar m,n;
// Row m (m >= 1): element m of `left` feeds the left input of PE (m,0).
for(m=1;m<Mritx_M;m=m+1)begin:shift_register_left
shift_register #(
.WIDTH_in(WIDTH_left),
.WIDTH_out(WIDTH_out),
.DEEP(m),
.PTR_SIZE(Mritx_LOG2_size)
)shift_register_left(
.clk(clk),
.rst(rst),
.shift_en(valid_left),
.shift_in(left[(m+1)*WIDTH_left-1:(m)*WIDTH_left]),
.shift_out(left_temp[m*Mritx_L])
);
end
// Column n (n >= 1): element n of `up` feeds the up input of PE (0,n).
for(n=1;n<Mritx_L;n=n+1)begin:shift_register_up
shift_register #(
.WIDTH_in(WIDTH_up),
.WIDTH_out(WIDTH_up),
.DEEP(n),
.PTR_SIZE(Mritx_LOG2_size)
)shift_register_up(
.clk(clk),
.rst(rst),
.shift_en(valid_up),
.shift_in(up[(n+1)*WIDTH_up-1:n*WIDTH_up]),
.shift_out(up_temp[n])
);
end
endgenerate
endmodule
要想在这个架构上实现真正的流水线矩阵乘法器,需要按照第一张图中的单元计算完成梯度图来将数据读出,但我目前并没有找到在不消耗大量硬件资源的情况下可通用的方法,后续找到的话会继续更新。
仿真:
本项目已开源,所有代码已上传至 GitHub 仓库:Debug-xmh/Systiolic-Matrix-multiplier: Universal matrix multiplier, an improvement on the design of Systiolic Matrix multiplier in verilog. (github.com)