本文用 SystemVerilog 实现了分块矩阵乘法中输出矩阵某一块的计算,并采用乒乓(ping-pong)缓冲操作,以掩盖数据传输时间。
这是顶层模块的代码:
`timescale 1ns / 1ps
//
// Company:
// Engineer:
//
// Create Date: 2020/11/16 22:53:40
// Design Name:
// Module Name: compute_one_block
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//
// compute_one_block
//
// Accumulates one Tn x Tn output tile of a blocked matrix product.  The inner
// (k) dimension is walked in steps of Tn from 0 to N-Tn.  Two buffer sets are
// used in ping-pong fashion: while one loader fills buffer set 1, the
// multiplier attached to buffer set 2 runs on the previously loaded tiles, and
// vice versa.  Every partial product delivered by a multiplier is added into
// `result`.
//
// Parameters
//   Tn : tile edge length (elements).
//   N  : matrix edge length; block_k stops advancing at N-Tn.
//
// Ports
//   start                : one-cycle pulse; clears `result` and the sequencer.
//   dina / dinb          : read data for the A and B operands.
//   addra / addrb        : read addresses, taken from loader 1 or loader 2
//                          depending on `pingpang`.
//   block_row, block_col : coordinates forwarded to both loaders.
//   result               : accumulated tile (16-bit, wraps on overflow).
//   done                 : high while pingpang_done is high and
//                          pre_block_k == N-Tn.
//
// NOTE(review): with N == Tn both first_load and final_compute are true right
// after `start`, so neither start_load1 nor start_load2 can ever fire --
// confirm that N > Tn is a requirement of this module.
module compute_one_block
#(
    parameter Tn = 4,
    parameter N  = 16
)
(
    input  logic        clk,
    input  logic        rst,
    input  logic        start,
    input  logic [15:0] dina,
    input  logic [15:0] dinb,
    input  logic [7:0]  block_row,
    input  logic [7:0]  block_col,
    output logic [7:0]  addra,
    output logic [7:0]  addrb,
    output logic [15:0] result[0:Tn-1][0:Tn-1],
    output logic        done
);

    // Operand tiles (set 1 / set 2) and the partial products computed from them.
    logic [15:0] buff_a1[0:Tn-1][0:Tn-1];
    logic [15:0] buff_a2[0:Tn-1][0:Tn-1];
    logic [15:0] buff_b1[0:Tn-1][0:Tn-1];
    logic [15:0] buff_b2[0:Tn-1][0:Tn-1];
    logic [15:0] buff_o1[0:Tn-1][0:Tn-1];
    logic [15:0] buff_o2[0:Tn-1][0:Tn-1];

    // Sequencer state.
    logic pingpang;          // 0: load set 1 / compute set 2, 1: load set 2 / compute set 1
    logic pingpang_start;    // one-cycle kick-off pulse for a phase
    logic pingpang_done;     // one-cycle pulse: current phase finished
    logic busy;

    // Per-unit start pulses and completion pulses.
    logic start_load1;
    logic start_load2;
    logic start_compute1;
    logic start_compute2;
    logic load1_done;
    logic load2_done;
    logic compute1_done;
    logic compute2_done;

    // Sticky copies of the completion pulses, cleared at each phase boundary.
    logic load1_done_ff;
    logic load2_done_ff;
    logic compute1_done_ff;
    logic compute2_done_ff;

    // Addresses generated by the two loaders.
    logic [7:0] addra1;
    logic [7:0] addra2;
    logic [7:0] addrb1;
    logic [7:0] addrb2;

    logic [7:0] block_k;      // k offset of the tile being loaded in this phase
    logic [7:0] pre_block_k;  // k offset of the tile being computed in this phase
    logic first_load;         // phase has a load but nothing to compute yet
    logic final_compute;      // phase has a compute but nothing left to load

    // ------------------------------------------------------------------
    // Result accumulator: cleared on reset/start, then each multiplier's
    // partial product is added in the cycle its done pulse is high.
    // compute1_done takes priority if both pulses were ever high together.
    // ------------------------------------------------------------------
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            for (int i = 0; i < Tn; i++)
                for (int j = 0; j < Tn; j++)
                    result[i][j] <= 16'd0;
        end
        else if (start) begin
            for (int i = 0; i < Tn; i++)
                for (int j = 0; j < Tn; j++)
                    result[i][j] <= 16'd0;
        end
        else if (busy) begin
            if (compute1_done) begin
                for (int i = 0; i < Tn; i++)
                    for (int j = 0; j < Tn; j++)
                        result[i][j] <= result[i][j] + buff_o1[i][j];
            end
            else if (compute2_done) begin
                for (int i = 0; i < Tn; i++)
                    for (int j = 0; j < Tn; j++)
                        result[i][j] <= result[i][j] + buff_o2[i][j];
            end
        end

    // Phase qualifiers.
    assign first_load    = busy && (block_k == 0);
    assign final_compute = busy && (pre_block_k == N-Tn);

    // Completion: the phase that just ended was the one computing the last k tile.
    assign done = pingpang_done && (pre_block_k == N-Tn);

    // busy: set by start, cleared together with done.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            busy <= 1'b0;
        else if (start)
            busy <= 1'b1;
        else if (done)
            busy <= 1'b0;

    // pingpang_start: pulses one cycle after start and one cycle after every
    // phase boundary except the last one.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            pingpang_start <= 1'b0;
        else if (start)
            pingpang_start <= 1'b1;
        else if (pingpang_done && ~pingpang_start && busy && ~done)
            pingpang_start <= 1'b1;
        else
            pingpang_start <= 1'b0;

    // pingpang: selects which buffer set is loaded; flips at each phase boundary.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            pingpang <= 1'b0;
        else if (start)
            pingpang <= 1'b0;
        else if (pingpang_done)
            pingpang <= ~pingpang;

    // Sticky completion flags: latch each unit's done pulse until the phase ends.
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            load1_done_ff    <= 1'b0;
            load2_done_ff    <= 1'b0;
            compute1_done_ff <= 1'b0;
            compute2_done_ff <= 1'b0;
        end
        else if (start || pingpang_done) begin
            load1_done_ff    <= 1'b0;
            load2_done_ff    <= 1'b0;
            compute1_done_ff <= 1'b0;
            compute2_done_ff <= 1'b0;
        end
        else begin
            if (load1_done)    load1_done_ff    <= 1'b1;
            if (load2_done)    load2_done_ff    <= 1'b1;
            if (compute1_done) compute1_done_ff <= 1'b1;
            if (compute2_done) compute2_done_ff <= 1'b1;
        end

    // pingpang_done: single-cycle pulse once everything scheduled in the
    // current phase has finished.  The first phase only waits for the load,
    // the last phase only waits for the compute, all others wait for both.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            pingpang_done <= 1'b0;
        else if (pingpang_done)
            pingpang_done <= 1'b0;            // keep the pulse one cycle wide
        else if (pingpang == 1'b0)            // load set 1, compute set 2
            pingpang_done <= (first_load    && load1_done_ff)
                          || (final_compute && compute2_done_ff)
                          || (load1_done_ff && compute2_done_ff);
        else                                  // load set 2, compute set 1
            pingpang_done <= (first_load    && load2_done_ff)
                          || (final_compute && compute1_done_ff)
                          || (load2_done_ff && compute1_done_ff);

    // Start pulses for loaders and multipliers 1 and 2.
    assign start_load1    = ~pingpang && pingpang_start && ~final_compute;
    assign start_load2    =  pingpang && pingpang_start && ~final_compute;
    assign start_compute1 =  pingpang && pingpang_start && ~first_load;
    assign start_compute2 = ~pingpang && pingpang_start && ~first_load;

    // The active loader (chosen by pingpang) drives the address outputs.
    assign addra = (pingpang == 1'b1) ? addra2 : addra1;
    assign addrb = (pingpang == 1'b1) ? addrb2 : addrb1;

    // block_k advances by Tn per phase and saturates at N-Tn;
    // pre_block_k trails it by one phase.
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            block_k     <= 8'd0;
            pre_block_k <= 8'd0;
        end
        else if (start) begin
            block_k     <= 8'd0;
            pre_block_k <= 8'd0;
        end
        else if (pingpang_done) begin
            block_k     <= (block_k == N-Tn) ? block_k : block_k + Tn;
            pre_block_k <= block_k;
        end

    // ------------------------------------------------------------------
    // Sub-module instances.  Tn is forwarded so that the buffer shapes of
    // the children always match the arrays declared above.  Only Tn is
    // forwarded; the row stride (N) used by the loaders for address
    // generation must match this module's N.
    // ------------------------------------------------------------------
    load_two_block #(.Tn(Tn)) load1
    (
        .clk(clk),
        .rst(rst),
        .start(start_load1),
        .block_row(block_row),
        .block_col(block_col),
        .block_k(block_k),        // A[block_row +: Tn, block_k +: Tn]
        .dina(dina),              // B[block_k +: Tn, block_col +: Tn]
        .dinb(dinb),
        .addra(addra1),
        .addrb(addrb1),
        .block_mat_a(buff_a1),
        .block_mat_b(buff_b1),
        .done(load1_done)
    );

    load_two_block #(.Tn(Tn)) load2
    (
        .clk(clk),
        .rst(rst),
        .start(start_load2),
        .block_row(block_row),
        .block_col(block_col),
        .block_k(block_k),        // A[block_row +: Tn, block_k +: Tn]
        .dina(dina),              // B[block_k +: Tn, block_col +: Tn]
        .dinb(dinb),
        .addra(addra2),
        .addrb(addrb2),
        .block_mat_a(buff_a2),
        .block_mat_b(buff_b2),
        .done(load2_done)
    );

    block_mm #(.Tn(Tn)) compute1
    (
        .clk(clk),
        .rst(rst),
        .start(start_compute1),   // one-cycle pulse starts the multiply
        .A(buff_a1),
        .B(buff_b1),
        .O(buff_o1),
        .done(compute1_done)      // one-cycle pulse when finished
    );

    block_mm #(.Tn(Tn)) compute2
    (
        .clk(clk),
        .rst(rst),
        .start(start_compute2),   // one-cycle pulse starts the multiply
        .A(buff_a2),
        .B(buff_b2),
        .O(buff_o2),
        .done(compute2_done)      // one-cycle pulse when finished
    );

endmodule
block_mm模块,计算A中某一块和B中某一块的乘法。
`timescale 1ns / 1ps
//
// Company:
// Engineer:
//
// Create Date: 2020/11/13 16:04:32
// Design Name:
// Module Name: block_mm
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//
// block_mm
//
// Sequential Tn x Tn matrix multiply: O = A * B, one multiply-accumulate per
// clock.  After a one-cycle `start` pulse the module walks (row, col, k) with
// k as the fastest index, so a full product takes Tn*Tn*Tn busy cycles.
// All arithmetic is 16 bit wide and wraps on overflow.
//
// Ports
//   start : one-cycle pulse, restarts the counters and sets busy.
//   A, B  : operand tiles, read while busy (must be held stable by the caller).
//   O     : product tile; not cleared by reset, each element is overwritten
//           when its k == 0 term is computed.
//   done  : one-cycle pulse, registered on the same edge as the last
//           accumulation, so O is complete while done is high.
module block_mm
#(parameter Tn=4)
(
    input  logic        clk,
    input  logic        rst,
    input  logic        start,                    // one-cycle pulse: begin
    input  logic [15:0] A[0:Tn-1][0:Tn-1],
    input  logic [15:0] B[0:Tn-1][0:Tn-1],
    output logic [15:0] O[0:Tn-1][0:Tn-1],
    output logic        done                      // one-cycle pulse: finished
);

    int   row;
    int   col;
    int   k;
    logic busy;

    // Qualifiers are gated by busy so that an idle counter state which happens
    // to match the terminal value (always the case for Tn == 1 after reset)
    // can neither advance row/col nor raise done.
    logic k_last;      // last k term of the current output element
    logic elem_last;   // ... and last column of the current row
    logic mac_last;    // ... and last row: final accumulation of the tile

    assign k_last    = busy      && (k   == Tn-1);
    assign elem_last = k_last    && (col == Tn-1);
    assign mac_last  = elem_last && (row == Tn-1);

    // busy: set by start, cleared after the final accumulation.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            busy <= 1'b0;
        else if (start)
            busy <= 1'b1;
        else if (mac_last)
            busy <= 1'b0;

    // k: innermost index, wraps every Tn busy cycles.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            k <= 0;
        else if (start)
            k <= 0;
        else if (busy) begin
            if (k == Tn-1)
                k <= 0;
            else
                k <= k + 1;
        end

    // col: advances when k wraps.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            col <= 0;
        else if (start)
            col <= 0;
        else if (k_last) begin
            if (col == Tn-1)
                col <= 0;
            else
                col <= col + 1;
        end

    // row: advances when col wraps.  It ends at Tn after the last element;
    // that value is never used as an index because busy is low by then.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            row <= 0;
        else if (start)
            row <= 0;
        else if (elem_last)
            row <= row + 1;

    // done: one-cycle pulse following the final accumulation.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            done <= 1'b0;
        else
            done <= mac_last && !done;

    // Multiply-accumulate.  O has no reset value; the k == 0 term overwrites
    // whatever the element held before.
    always_ff @(posedge clk)
        if (!rst && busy) begin
            if (k == 0)
                O[row][col] <= A[row][k] * B[k][col];
            else
                O[row][col] <= O[row][col] + A[row][k] * B[k][col];
        end

endmodule
load_two_block模块,分别加载A和B中的某一块。
`timescale 1ns / 1ps
//
// Company:
// Engineer:
//
// Create Date: 2020/11/14 10:30:18
// Design Name:
// Module Name: load_two_block
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//
// load_two_block
//
// Starts two load_block instances with the same `start` pulse:
//   block_a reads the tile at (block_row, block_k)  through addra/dina,
//   block_b reads the tile at (block_k, block_col)  through addrb/dinb.
// `done` is the AND of the two loaders' done outputs.
//
// Parameters
//   Tn : tile edge length, forwarded to both loaders.
//   N  : row stride forwarded to both loaders for address generation.
module load_two_block
#(
    parameter Tn = 4,
    parameter N  = 16
)
(
    input  logic        clk,
    input  logic        rst,
    input  logic        start,
    input  logic [7:0]  block_row,
    input  logic [7:0]  block_col,
    input  logic [7:0]  block_k,      // A[block_row +: Tn, block_k +: Tn]
    input  logic [15:0] dina,         // B[block_k +: Tn, block_col +: Tn]
    input  logic [15:0] dinb,
    output logic [7:0]  addra,
    output logic [7:0]  addrb,
    output logic [15:0] block_mat_a[0:Tn-1][0:Tn-1],
    output logic [15:0] block_mat_b[0:Tn-1][0:Tn-1],
    output logic        done
);

    logic done_a;
    logic done_b;

    // Both loaders are kicked off by the same pulse; report completion only
    // when both signal done.
    assign done = done_a && done_b;

    // Tile of A: rows from block_row, columns from block_k.
    load_block #(.Tn(Tn), .N(N)) block_a (
        .start(start),
        .clk(clk),
        .rst(rst),
        .din(dina),
        .addr(addra),
        .block_row(block_row),
        .block_col(block_k),
        .block_mat(block_mat_a),
        .done(done_a)
    );

    // Tile of B: rows from block_k, columns from block_col.
    load_block #(.Tn(Tn), .N(N)) block_b (
        .start(start),
        .clk(clk),
        .rst(rst),
        .din(dinb),
        .addr(addrb),
        .block_row(block_k),
        .block_col(block_col),
        .block_mat(block_mat_b),
        .done(done_b)
    );

endmodule
load_block模块,load_two_block的子模块,加载一个分块矩阵。
`timescale 1ns / 1ps
//
// Company:
// Engineer:
//
// Create Date: 2020/11/13 18:10:01
// Design Name:
// Module Name: load_block
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//
// load_block
//
// Reads the Tn x Tn tile M[block_row +: Tn, block_col +: Tn] into `block_mat`.
// After a one-cycle `start` pulse the module issues one address per clock
// (addr = row*N + col, row-major) and stores `din` two clock edges after the
// corresponding address was presented.
// NOTE(review): this matches a memory with a two-cycle read latency --
// confirm against the memory configuration actually used.
//
// Parameters
//   Tn : tile edge length.
//   N  : row stride (elements per matrix row) used to form the address.
//
// Ports
//   block_row, block_col : top-left coordinate of the tile; they are also
//                          used when the delayed data is written, so they
//                          must stay stable until `done`.
//   done                 : one-cycle pulse, raised in the cycle after the
//                          last element has been written to block_mat.
module load_block
#(
    parameter Tn = 4,
    parameter N  = 16
)
(
    input  logic        start,
    input  logic        clk,
    input  logic        rst,
    input  logic [15:0] din,
    output logic [7:0]  addr,
    input  logic [7:0]  block_row,
    input  logic [7:0]  block_col,
    output logic [15:0] block_mat[0:Tn-1][0:Tn-1],
    output logic        done
);

    // Address counters and their two-stage delayed copies (write index).
    logic [7:0] row;
    logic [7:0] col;
    logic [7:0] row_ff1;
    logic [7:0] row_ff2;
    logic [7:0] col_ff1;
    logic [7:0] col_ff2;

    // busy marks address cycles; busy_ff2 marks the matching data cycles.
    logic busy;
    logic busy_ff1;
    logic busy_ff2;

    // Completion pulse, delayed like the data path and then one more stage.
    logic done_ff0;
    logic done_ff1;
    logic done_ff2;

    // Qualifiers are gated by busy: while idle, a change on block_row /
    // block_col (or Tn == 1 right after reset) may make the counters match
    // the terminal values, which must not advance row or raise done.
    logic col_last;    // last column of the current tile row is being addressed
    logic elem_last;   // last element of the tile is being addressed

    assign col_last  = busy     && (col == block_col+Tn-1);
    assign elem_last = col_last && (row == block_row+Tn-1);

    assign done = done_ff2;

    // done_ff0: one-cycle pulse after the last address has been issued.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            done_ff0 <= 1'b0;
        else
            done_ff0 <= elem_last && !done_ff0;

    // Delay the pulse so that done follows the final write.
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            done_ff1 <= 1'b0;
            done_ff2 <= 1'b0;
        end
        else begin
            done_ff1 <= done_ff0;
            done_ff2 <= done_ff1;
        end

    // busy: set by start, cleared once the last address has been issued.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            busy <= 1'b0;
        else if (start)
            busy <= 1'b1;
        else if (elem_last)
            busy <= 1'b0;

    // Two-stage delay of busy, aligned with the returning data.
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            busy_ff1 <= 1'b0;
            busy_ff2 <= 1'b0;
        end
        else begin
            busy_ff1 <= busy;
            busy_ff2 <= busy_ff1;
        end

    // row: loaded on start, advances when the column counter wraps.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            row <= 8'd0;
        else if (start)
            row <= block_row;
        else if (col_last)
            row <= row + 1;

    // col: loaded on start, sweeps block_col .. block_col+Tn-1 while busy.
    always_ff @(posedge clk, posedge rst)
        if (rst)
            col <= 8'd0;
        else if (start)
            col <= block_col;
        else if (busy) begin
            if (col == block_col+Tn-1)
                col <= block_col;
            else
                col <= col + 1;
        end

    // Two-stage delay of the coordinates, aligned with the returning data.
    always_ff @(posedge clk, posedge rst)
        if (rst) begin
            row_ff1 <= 8'd0;
            row_ff2 <= 8'd0;
            col_ff1 <= 8'd0;
            col_ff2 <= 8'd0;
        end
        else begin
            row_ff1 <= row;
            row_ff2 <= row_ff1;
            col_ff1 <= col;
            col_ff2 <= col_ff1;
        end

    // Row-major address; the result is truncated to the 8-bit address port.
    assign addr = (row*N + col);

    // Store the returning data at its tile-relative position.
    // block_mat has no reset value.
    always_ff @(posedge clk)
        if (!rst && busy_ff2)
            block_mat[row_ff2-block_row][col_ff2-block_col] <= din;

endmodule
testbench:
`timescale 1ns / 1ps
//
// Company:
// Engineer:
//
// Create Date: 2020/11/17 08:43:39
// Design Name:
// Module Name: compute_one_block_test
// Project Name:
// Target Devices:
// Tool Versions:
// Description:
//
// Dependencies:
//
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
//
//
// compute_one_block_test
//
// Testbench for compute_one_block.  After reset it fills two `Matrix`
// memory instances (A and B, declared outside this file) with an
// incrementing pattern, then pulses `start` once with block_row = 0 and
// block_col = 0 and lets the DUT read its operands back from the memories.
// The bench has no self-check and no $finish; results are inspected in the
// waveform.
module compute_one_block_test;

    parameter N  = 16;
    parameter Tn = 4;

    // Clock / reset / DUT interface.
    logic        clk;
    logic        rst;
    logic        start;
    logic [7:0]  block_row;
    logic [7:0]  block_col;
    logic [15:0] dina;
    logic [15:0] dinb;
    logic [7:0]  addra;
    logic [7:0]  addrb;
    logic [15:0] result[0:Tn-1][0:Tn-1];
    logic        done;

    // Memory-side signals.
    logic        wea;
    logic        web;
    logic [7:0]  address_a;
    logic [7:0]  address_b;
    logic [7:0]  write_addra;
    logic [7:0]  write_addrb;
    logic [15:0] write_data_a;
    logic [15:0] write_data_b;
    logic [15:0] read_data_a;
    logic [15:0] read_data_b;

    // Initialisation handshake.
    logic        init_done;
    logic        init_done_ff;

    // 10 ns clock period.
    initial clk = 1'b0;
    always #5 clk = ~clk;

    // Reset is held for the first 10 ns.
    initial begin
        rst = 1'b1;
        #10 rst = 1'b0;
    end

    // Memory initialisation for A and B: address and data count up together
    // until the write address reaches N*N-1.
    // NOTE(review): the write enables are first high together with address 1,
    // so address 0 is never written here; the bench depends on whatever the
    // memories hold at address 0 initially -- confirm this is 0.
    always_ff @(posedge clk, posedge rst) begin
        if (rst) begin
            wea          <= 1'b0;
            web          <= 1'b0;
            write_data_a <= 16'd0;
            write_data_b <= 16'd0;
            write_addra  <= 8'd0;
            write_addrb  <= 8'd0;
        end
        else begin
            if (write_addra < N*N-1) begin
                wea          <= 1'b1;
                web          <= 1'b1;
                write_data_a <= write_data_a + 1;
                write_data_b <= write_data_b + 1;
                write_addra  <= write_addra + 1;
                write_addrb  <= write_addrb + 1;
            end
            else begin
                wea <= 1'b0;
                web <= 1'b0;
            end
        end
    end

    // init_done latches once the last write address has been reached;
    // init_done_ff is its one-cycle delayed copy, used for edge detection.
    always_ff @(posedge clk, posedge rst) begin
        if (rst) begin
            init_done    <= 1'b0;
            init_done_ff <= 1'b0;
        end
        else begin
            if (write_addra == N*N-1)
                init_done <= 1'b1;
            init_done_ff <= init_done;
        end
    end

    // start: a single-cycle pulse on the rising edge of init_done.
    always_ff @(posedge clk, posedge rst) begin
        if (rst) begin
            start     <= 1'b0;
            block_row <= 8'd0;
            block_col <= 8'd0;
        end
        else begin
            start <= 1'b0;                       // default: no pulse
            if (init_done && !init_done_ff && !start) begin
                start     <= 1'b1;
                block_row <= 8'd0;
                block_col <= 8'd0;
            end
        end
    end

    // Memory read data feeds the DUT; the memory address comes from the
    // initialiser until init_done, from the DUT afterwards.
    assign dina      = read_data_a;
    assign dinb      = read_data_b;
    assign address_a = init_done ? addra : write_addra;
    assign address_b = init_done ? addrb : write_addrb;

    // Device under test.
    compute_one_block U (
        .clk       (clk),
        .rst       (rst),
        .start     (start),
        .dina      (dina),
        .dinb      (dinb),
        .block_row (block_row),
        .block_col (block_col),
        .addra     (addra),
        .addrb     (addrb),
        .result    (result),
        .done      (done)
    );

    // Operand memory for matrix A.
    Matrix A (
        .clka  (clk),           // input wire clka
        .ena   (1'b1),          // input wire ena
        .wea   (wea),           // input wire [0 : 0] wea
        .addra (address_a),     // input wire [7 : 0] addra
        .dina  (write_data_a),  // input wire [15 : 0] dina
        .douta (read_data_a)    // output wire [15 : 0] douta
    );

    // Operand memory for matrix B.
    Matrix B (
        .clka  (clk),           // input wire clka
        .ena   (1'b1),          // input wire ena
        .wea   (web),           // input wire [0 : 0] wea
        .addra (address_b),     // input wire [7 : 0] addra
        .dina  (write_data_b),  // input wire [15 : 0] dina
        .douta (read_data_b)    // output wire [15 : 0] douta
    );

endmodule
仿真波形及结果
上图中的result数组即计算结果,与C++的计算结果相同(这里采用的是16位无符号整数,因此C++中int型的最终结果需要对 $2^{16}$ 取模)。