RTL Implementation of the LeNet-5 Convolutional Neural Network (Part 3): Convolutional Layer

Recap

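The previous article implemented the padding layer of the CNN with a FIFO and a padding circuit. This article describes in detail a convolution circuit built on a Shift RAM.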

Introduction to Convolution

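The core operation of a CNN convolutional layer multiplies the FMAP data inside a window by the corresponding kernel weights and accumulates the products. How the FMAP data for a convolution window is obtained was explained in Part 1 of this series, "RTL Implementation of the LeNet-5 Convolutional Neural Network (Part 1)". As mentioned there, the Shift RAM also produces windows that contain invalid data, and the convolution circuit must detect and discard them. A convolutional layer in a CNN usually has several FMAP channels and several kernels. This article implements only single-channel, single-kernel convolution; multi-channel, multi-kernel convolution will be covered in later articles.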

Convolution Circuit

Interface

The interface of the convolution circuit is as follows:

module CONV2D
#(
	parameter DATA_WIDTH = 16,
	parameter FMAP_SIZE = 32,							//input feature map size
	parameter KERNEL_SIZE = 5,							//kernel size
	parameter STRIDE = 1								//convolution stride
)
(
	input clk, 
	input rst_n,
	input ena,
	input clear,
	input [DATA_WIDTH*KERNEL_SIZE-1:0]tap,				//taps output by the Shift RAM
	input [KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0]w,	//kernel weights
	output [DATA_WIDTH-1:0]conv_out,					//convolution output
	output valid,										//output valid flag
	output done											//convolution done flag
);

Data Input

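The window data of the convolution module comes from the Shift RAM taps. Once the Shift RAM is full, it outputs one window column vector per clock cycle. The convolution module takes this column vector and shifts it into the window, as illustrated below: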

[Figure: a tap column vector being shifted into the convolution window]
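On each clock edge, every row of the window_data register array shifts left by one element, and the tail of the row receives the tap data. Once window_data has been filled, a valid data window is formed. The data input code is as follows: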

	//two generate-for loops: one shifts each row of 'window_data' (reg) left,
	//the other loads the 'tap' bus (wire) into the tail of each row
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 0; j < KERNEL_SIZE - 2; j = j + 1) begin
					always@(posedge clk)
						window_data[i*KERNEL_SIZE+j] <= window_data[i*KERNEL_SIZE+j+1];
				end
			end
	
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				always@(posedge clk)
					window_data[i*KERNEL_SIZE+KERNEL_SIZE-2] <= tap[(i+1)*DATA_WIDTH-1-:DATA_WIDTH];
			end
		end
		else begin
			always@(posedge clk)
				window_data[0] <= tap[DATA_WIDTH-1:0];
		end
	endgenerate

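The two assignments have different sources: window_data is a reg array, and tap is the wire bus from the Shift RAM. Two generate-for blocks are therefore used, one for the left shift inside each row and one for loading the tap data into the tail of each row. Each row keeps only KERNEL_SIZE - 1 registered columns, because the newest column is read directly from tap in the multiply stage. In the special case KERNEL_SIZE = 1, the tap is the input data itself and no window buffer is needed; this case is handled by the if(KERNEL_SIZE > 1) condition.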
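
The shift-in behaviour can be cross-checked against a small software model. The sketch below is a hypothetical Python model, not part of the original project. It assumes KERNEL_SIZE > 1, that column 0 of each window row is the oldest, and that tap element i feeds row i, matching the tap[(i+1)*DATA_WIDTH-1-:DATA_WIDTH] slice. None marks a register that has not been loaded yet.

```python
def shift_in_windows(columns, kernel_size):
    """Return the window seen by the multiply stage in each clock cycle.

    columns: one tap column (kernel_size values, row 0 first) per clock.
    Each window row = the kernel_size-1 window_data registers + the live tap value.
    """
    if kernel_size < 2:
        raise ValueError("window_data is only used when KERNEL_SIZE > 1")
    # window_data: kernel_size rows of kernel_size-1 registers, None = not loaded yet
    regs = [[None] * (kernel_size - 1) for _ in range(kernel_size)]
    windows = []
    for tap in columns:
        # combinational view this cycle: registered columns, then the tap column
        windows.append([regs[i] + [tap[i]] for i in range(kernel_size)])
        # posedge: every row shifts left by one, tap[i] enters at index kernel_size-2
        regs = [regs[i][1:] + [tap[i]] for i in range(kernel_size)]
    return windows


# columns c = 0..4 of a 3-row strip, value 10*c + row
cols = [[10 * c + r for r in range(3)] for c in range(5)]
wins = shift_in_windows(cols, 3)
```

From the third column onward, every window is complete. For example, wins[2] holds columns 0, 1 and 2 of the strip.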

Convolution Arithmetic

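The convolution is computed with multipliers and adders. To shorten the critical path, the computation is split into a three-stage pipeline:

1. Multiply the data window by the corresponding kernel weights.
2. Accumulate the products of each row.
3. Accumulate the row results into the convolution result.

Both accumulation stages are built from cascaded two-input ADDER instances, which the code comments call adder trees. The outputs of the three stages are stored in the intermediate registers product[i*KERNEL_SIZE+j], partial_sum[i] and sum. The code is as follows: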

	//stage 1: generate multiplication product
	//(each product keeps the low DATA_WIDTH bits of the full product)
	generate 
		if(KERNEL_SIZE > 1) begin
			//last column of each row: multiply the live tap data
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				always@(posedge clk)
					product[i*KERNEL_SIZE+KERNEL_SIZE-1] <= w[((i*KERNEL_SIZE+KERNEL_SIZE-1)+1)*DATA_WIDTH-1-:DATA_WIDTH] * tap[(i+1)*DATA_WIDTH-1-:DATA_WIDTH];
			end
	
			//other columns: multiply the buffered window data
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 0 ; j < KERNEL_SIZE-1; j = j + 1) begin
					always@(posedge clk)
						product[i*KERNEL_SIZE+j] <= w[((i*KERNEL_SIZE+j)+1)*DATA_WIDTH-1-:DATA_WIDTH] * window_data[i*KERNEL_SIZE+j];
				end
			end
		end
		
		else begin
			always @(posedge clk)
				product[0] <= w[DATA_WIDTH-1:0] * tap[DATA_WIDTH-1:0];
		end
	endgenerate
	
	//stage 2: generate partial sum 
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				always @(posedge clk or negedge rst_n)
					if(!rst_n)
						partial_sum[i] <= 0;
					else
						partial_sum[i] <= partial_adder_out[i*KERNEL_SIZE+KERNEL_SIZE-2];
			end
		end
		else begin
			always @(posedge clk or negedge rst_n)
				if(!rst_n)
					partial_sum[0] <= 0;
				else
					partial_sum[0] <= product[0];
		end
	endgenerate
	
	//stage 3: generate conv sum 
	generate 
		if(KERNEL_SIZE > 1) begin
			always@(posedge clk or negedge rst_n)
				if(!rst_n)
					sum <= 0;
				else
					sum <= sum_adder_out[KERNEL_SIZE-2];
		end
		else begin
			always@(posedge clk or negedge rst_n)
				if(!rst_n)
					sum <= 0;
				else
					sum <= partial_sum[0];
		end
	endgenerate

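Likewise, when the kernel size is 1, the product of the weight and the corresponding pixel is already the convolution output, and the intermediate stages are not needed. The three-stage pipeline is nevertheless kept for KERNEL_SIZE = 1 so that the timing is identical in all cases.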
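
The three register stages explain why a result appears three clock edges after its window. The following is a hypothetical Python sketch; conv_pipeline is not part of the project. It uses None for registers that have not yet received data, and it uses unbounded Python integers. The RTL differs on both points: it resets partial_sum and sum to 0 and keeps only DATA_WIDTH bits.

```python
def conv_pipeline(windows, weights):
    """Cycle-level sketch of the product -> partial_sum -> sum register stages.

    windows: one KxK window per clock (None = no window this cycle).
    Returns the value held in the `sum` register after each clock edge.
    """
    product = partial_sum = total = None   # register contents, None = empty
    out = []
    for win in windows:
        # all three registers load on the same edge from their *old* inputs
        new_product = None if win is None else [
            [w * x for w, x in zip(w_row, x_row)] for w_row, x_row in zip(weights, win)
        ]
        new_partial = None if product is None else [sum(row) for row in product]
        new_total = None if partial_sum is None else sum(partial_sum)
        product, partial_sum, total = new_product, new_partial, new_total
        out.append(total)
    return out


ones = [[1, 1], [1, 1]]
trace = conv_pipeline([[[1, 2], [3, 4]], [[5, 6], [7, 8]], None, None], ones)
```

With a 2x2 all-ones kernel, the sum of the first window (10) is held in the sum register after the third clock edge. The second window (26) follows one cycle later, so trace is [None, None, 10, 26].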
The signals partial_adder_out and sum_adder_out used in the two add stages are the outputs of the adder trees, which are instantiated as follows:

	//partial sum adder tree
	//generated when KERNEL_SIZE > 1
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				ADDER
				#(
					.DATA_WIDTH(DATA_WIDTH)
				)
				u_partial_adder_head
				(
					.ina(product[i*KERNEL_SIZE]),
					.inb(product[i*KERNEL_SIZE+1]),
					.out(partial_adder_out[i*KERNEL_SIZE])
				);
			end
		
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 2; j < KERNEL_SIZE; j = j + 1) begin
					ADDER
					#(
						.DATA_WIDTH(DATA_WIDTH)
					)
					u_partial_adder_remain
					(
						.ina(partial_adder_out[i*KERNEL_SIZE+j-2]),
						.inb(product[i*KERNEL_SIZE+j]),
						.out(partial_adder_out[i*KERNEL_SIZE+j-1])
					);
				end
			end
		end
	endgenerate
		
	//conv sum adder tree
	//generated when KERNEL_SIZE > 1
	generate
		if(KERNEL_SIZE > 1) begin
			ADDER
			#(
				.DATA_WIDTH(DATA_WIDTH)
			)
			u_sum_adder_head
			(
				.ina(partial_sum[0]),
				.inb(partial_sum[1]),
				.out(sum_adder_out[0])
			);
	
			for(i = 2; i < KERNEL_SIZE; i = i + 1) begin
				ADDER
				#(
					.DATA_WIDTH(DATA_WIDTH)
				)
				u_sum_adder_remain
				(
					.ina(sum_adder_out[i-2]),
					.inb(partial_sum[i]),
					.out(sum_adder_out[i-1])
				);
			end
		end
	endgenerate

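The first adder tree accumulates the elements of each row of products, and the second accumulates the resulting column vector of row sums. Each tree is split into an adder_head part and an adder_remain part because their inputs differ. The head adds the first two elements. Each remaining adder adds the previous adder's output to the next element, so the adders form a chain. For details of the ADDER module, see Part 1.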
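
The index bookkeeping of one row cascade can be checked with a short Python sketch; partial_adder_chain is an illustrative name, not a project module. It mirrors the generate loops:

- The head writes local index 0.
- Each remaining adder (j >= 2) writes local index j-1.
- The full row sum therefore ends at local index KERNEL_SIZE-2, the element that stage 2 registers into partial_sum[i].

Python integers do not wrap, whereas the RTL adders are DATA_WIDTH bits wide.

```python
def partial_adder_chain(products_row):
    """Outputs of u_partial_adder_head / u_partial_adder_remain for one row.

    products_row: the K products of row i (K >= 2). out[k] corresponds to
    partial_adder_out[i*K + k] in the RTL.
    """
    k = len(products_row)
    if k < 2:
        raise ValueError("the adder cascade is only generated for KERNEL_SIZE > 1")
    out = [products_row[0] + products_row[1]]     # head -> local index 0
    for j in range(2, k):                          # remain -> local index j - 1
        out.append(out[j - 2] + products_row[j])
    return out


row = [1, 2, 3, 4, 5]                 # KERNEL_SIZE = 5
chain = partial_adder_chain(row)      # [3, 6, 10, 15]
```

Each adder feeds the next, so a row of KERNEL_SIZE products passes through KERNEL_SIZE - 1 adders in series within one clock cycle.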

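In principle the convolution module could end here. However, discarding the invalid data output by the Shift RAM requires a good deal of additional checking logic. The rest of this article deals with detecting and discarding invalid data.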

Counters

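The convolution circuit uses four counters to judge data validity:

- cnt: the total number of input pixels.
- cnt_col: the column of the current window.
- cnt_stride_col and cnt_stride_row: the column and row position within the stride interval.

The counter code is as follows: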

	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt <= 0;
		else if(clear)
			cnt <= 0;
		else if(!ena)
			cnt <= cnt;
		else
			cnt <= cnt + 1;
	
	//valid data window detector
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_col <= 0;
		else if(clear)
			cnt_col <= 0;
		else if(!ena)
			cnt_col <= cnt_col;
		else if(shift) begin
			if(cnt_col == FMAP_SIZE - 1)
				cnt_col <= 0;
			else
				cnt_col <= cnt_col + 1;
		end
		
	//valid shift period detector
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			shift <= 0;
		else if(clear)
			shift <= 0;
		else if(!ena)
			shift <= 0;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1) - 1)
			shift <= 1'b1;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) - 1)
			shift <= 1'b0;
    
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_stride_col <= 0;
		else if(clear)
			cnt_stride_col <= 0;
		else if(ena) begin
			if(shift) begin
				if(cnt_stride_col < STRIDE - 1)
					cnt_stride_col <= cnt_stride_col + 1;
				else
					cnt_stride_col <= 0;
			end
		end
			
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_stride_row <= 0;
		else if(clear)
			cnt_stride_row <= 0;
		else if(ena) begin
			if(cnt_col == FMAP_SIZE - 1) begin
				if(cnt_stride_row < STRIDE - 1)
					cnt_stride_row <= cnt_stride_row + 1;
				else
					cnt_stride_row <= 0;
			end
		end

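The global counter cnt increments in every cycle in which the module is enabled. In contrast, cnt_col and cnt_stride_col change only while shift is high, and cnt_stride_row advances when cnt_col reaches the end of a row. The shift signal flags whether the convolution window is in a valid sliding period. For example, while the Shift RAM is not yet full, the window is not in a valid sliding period and shift is low, so cnt_col, cnt_stride_col and cnt_stride_row do not change.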

Convolution Data Validity

A convolution output is valid only when two conditions both hold: the window data is valid, and the convolution result is valid.

Window Data Validity

Window data validity is determined by three conditions together:

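  1. Whether the window is in a valid sliding period
  2. Whether the window wraps across rows
  3. Whether the window falls inside a stride interval

1. Whether the window is in a valid sliding period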

The valid sliding period is flagged by the shift signal. It lasts:

  • from the point where the Shift RAM is full and has output the first window
  • until the window reaches the last position in the feature map

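Filling the Shift RAM takes TAP_LENGTH × TAP_NUM cycles. Outputting the first window takes another TAP_NUM - 1 cycles, because the first tap is already output when the Shift RAM becomes full. The valid sliding period therefore starts at cycle TAP_LENGTH × TAP_NUM + TAP_NUM - 1.

Here TAP_LENGTH is the feature map size FMAP_SIZE and TAP_NUM is the kernel size KERNEL_SIZE. Substituting gives a start cycle of FMAP_SIZE × KERNEL_SIZE + (KERNEL_SIZE - 1). By the same reasoning, the valid sliding period ends at cycle FMAP_SIZE × KERNEL_SIZE + FMAP_SIZE × (FMAP_SIZE - KERNEL_SIZE + 1), which corresponds to the last convolution window of the FMAP. The code that generates shift is as follows: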
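
The two boundary formulas can be evaluated directly. This Python sketch evaluates them for the testbench parameters FMAP_SIZE = 28 and KERNEL_SIZE = 5; shift_period is a hypothetical helper. It follows the RTL convention: shift is a register, so it reads 1 from cnt == start and returns to 0 at cnt == end. For that reason, the comparisons in the code use start - 1 and end - 1.

```python
def shift_period(fmap_size, kernel_size):
    """(start, end) cnt values of the valid sliding period.

    start = FMAP_SIZE*KERNEL_SIZE + (KERNEL_SIZE-1)
    end   = FMAP_SIZE*KERNEL_SIZE + FMAP_SIZE*(FMAP_SIZE-KERNEL_SIZE+1)
    shift is high for start <= cnt <= end - 1.
    """
    start = fmap_size * kernel_size + (kernel_size - 1)
    end = fmap_size * kernel_size + fmap_size * (fmap_size - kernel_size + 1)
    return start, end


start, end = shift_period(28, 5)   # testbench parameters
print(start, end, end - start)     # 144 812 668
```

For the testbench, shift is high for 668 cycles. That count includes the cycles in which the window wraps across rows, which the next condition removes.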

	//valid shifting period detector
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			shift <= 0;
		else if(clear)
			shift <= 0;
		else if(!ena)
			shift <= 0;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1) - 1)
			shift <= 1'b1;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) - 1)
			shift <= 1'b0;
		
2. Whether the window wraps across rows

"Row wrapping" means that part of the window has moved into the next row of the FMAP, as shown in the figure:

[Figure: a window wrapping into the next FMAP row]
As the figure shows, row wrapping occurs when the first column of the window is at a position greater than or equal to FMAP_SIZE - (KERNEL_SIZE - 1).
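
The threshold follows from the window width. A window whose first column is col covers columns col to col + KERNEL_SIZE - 1, so it wraps exactly when the last of them would pass column FMAP_SIZE - 1. A tiny Python check of this equivalence is shown below; wraps_row is a hypothetical helper.

```python
def wraps_row(col, fmap_size, kernel_size):
    """True if a window whose first column is `col` spills into the next FMAP row."""
    return col >= fmap_size - (kernel_size - 1)


FMAP_SIZE, KERNEL_SIZE = 28, 5
for col in range(FMAP_SIZE):
    # last covered column col+K-1 beyond FMAP_SIZE-1  <=>  the threshold test above
    assert wraps_row(col, FMAP_SIZE, KERNEL_SIZE) == (col + KERNEL_SIZE - 1 > FMAP_SIZE - 1)

valid_cols = [c for c in range(FMAP_SIZE) if not wraps_row(c, FMAP_SIZE, KERNEL_SIZE)]
```

For the testbench parameters, each FMAP row therefore yields FMAP_SIZE - KERNEL_SIZE + 1 = 24 valid start columns (0 to 23).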

3. Whether the window falls inside a stride interval

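Some convolutional layers have a stride greater than 1. Windows that fall inside a stride interval are also invalid, and they are detected with cnt_stride_col and cnt_stride_row.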

Combining the three conditions, window validity is finally determined as:

	assign stride_valid = (cnt_stride_col == 0 & cnt_stride_row == 0) ? 1 : 0;
	assign shift_valid = (shift && (cnt_col < FMAP_SIZE - (KERNEL_SIZE - 1)));
	assign win_valid = (shift_valid && stride_valid);
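
The three conditions interact through the counters, so a cycle-level model is a convenient way to check the total window count. The Python sketch below is hypothetical; simulate_win_valid is not part of the project. It re-implements the shift, cnt_col, cnt_stride_col and cnt_stride_row registers and the win_valid logic above, assuming ena stays high and clear stays low. Registers are updated together at the end of each loop iteration, like nonblocking assignments.

```python
def simulate_win_valid(fmap_size, kernel_size, stride=1, extra_cycles=8):
    """Return the cnt values of every cycle in which win_valid is high."""
    start = fmap_size * kernel_size + (kernel_size - 1)
    end = fmap_size * kernel_size + fmap_size * (fmap_size - kernel_size + 1)
    cnt = shift = cnt_col = s_col = s_row = 0      # register values after reset
    valid_cnts = []
    for _ in range(end + extra_cycles):
        # combinational logic from the current register values
        stride_valid = s_col == 0 and s_row == 0
        shift_valid = shift == 1 and cnt_col < fmap_size - (kernel_size - 1)
        if shift_valid and stride_valid:
            valid_cnts.append(cnt)
        # next-state logic, mirroring the always blocks
        if cnt == start - 1:
            n_shift = 1
        elif cnt == end - 1:
            n_shift = 0
        else:
            n_shift = shift
        if shift:
            n_col = 0 if cnt_col == fmap_size - 1 else cnt_col + 1
            n_s_col = s_col + 1 if s_col < stride - 1 else 0
        else:
            n_col, n_s_col = cnt_col, s_col
        if cnt_col == fmap_size - 1:
            n_s_row = s_row + 1 if s_row < stride - 1 else 0
        else:
            n_s_row = s_row
        # clock edge: all registers update at once
        cnt, shift, cnt_col, s_col, s_row = cnt + 1, n_shift, n_col, n_s_col, n_s_row
    return valid_cnts


windows = simulate_win_valid(28, 5)
```

For FMAP_SIZE = 28 and KERNEL_SIZE = 5, the model reports 24 × 24 = 576 valid windows, the first at cnt = 144 and the last at cnt = 811. With STRIDE = 2 the count drops to 12 × 12 = 144.

Note that cnt_stride_col is not reset at the start of a row. For a stride that does not divide FMAP_SIZE, the selected columns would therefore shift from row to row. The model reproduces that behaviour rather than correcting it.
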

Convolution Result Validity

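Result validity reflects the fact that, once a data window has been formed, several clock cycles of computation are needed before the correct convolution result is output. The convolution is split into a three-stage pipeline (one multiply cycle and two add cycles), so the first result appears 3 clock cycles after the first data window. After that, one result is produced every cycle.

The first valid result cycle is therefore FMAP_SIZE × KERNEL_SIZE + (KERNEL_SIZE - 1) + 3. The last valid result cycle is FMAP_SIZE × KERNEL_SIZE + FMAP_SIZE × (FMAP_SIZE - KERNEL_SIZE + 1) + 3. The code that tracks result validity is as follows: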
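
The sum_valid window can be written down from the same formulas. The Python sketch below expresses the range as the cnt values the RTL compares against; sum_valid_cnt_range and PIPELINE_DEPTH are hypothetical names. sum_valid is set at cnt == first_window + 2 and cleared at cnt == shift_end + 2, so it is high for first_window + 3 <= cnt <= shift_end + 2.

```python
PIPELINE_DEPTH = 3  # 1 multiply stage + 2 add stages, as described in the text


def sum_valid_cnt_range(fmap_size, kernel_size, depth=PIPELINE_DEPTH):
    """Inclusive (first, last) cnt values for which sum_valid is high."""
    first_window = fmap_size * kernel_size + (kernel_size - 1)
    shift_end = fmap_size * kernel_size + fmap_size * (fmap_size - kernel_size + 1)
    # set when cnt == first_window + depth - 1, cleared when cnt == shift_end + depth - 1
    return first_window + depth, shift_end + depth - 1


first, last = sum_valid_cnt_range(28, 5)
```

This is exactly the shift period shifted by the three pipeline stages.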

	//valid sum period detector
	always @(posedge clk or negedge rst_n)
	    if(!rst_n)
	    	sum_valid <= 0;
		else if(clear)
	    	sum_valid <= 0;
		else if(!ena)
	    	sum_valid <= 0;
	    else if(cnt == FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1) + (1 + 2) - 1)
	    	sum_valid <= 1;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) + (1 + 2) - 1)
	    	sum_valid <= 0;

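The module output is valid only when both the window data and the convolution result are valid. The convolution result lags the window data by 3 cycles, one per pipeline stage. win_valid must therefore pass through 3 registers (win_valid_s1 to win_valid_s3) to stay aligned with sum: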

	 //'conv sum valid' sig comes 3 clocks later than 'window valid' sig,
	 //insert 3 regs
	 always @(posedge clk or negedge rst_n)
	 	if(!rst_n) begin
	 		win_valid_s1 <= 0;
	 		win_valid_s2 <= 0;
	 		win_valid_s3 <= 0;
	 	end
	 	else if(clear) begin
	 		win_valid_s1 <= 0;
	 		win_valid_s2 <= 0;
	 		win_valid_s3 <= 0;
	 	end
	 	else if(ena) begin
	 		win_valid_s1 <= win_valid;
	 		win_valid_s2 <= win_valid_s1;
	 		win_valid_s3 <= win_valid_s2;
	 	end

The valid flag of the convolution output is generated as follows:

assign valid = win_valid_s3 & sum_valid;
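
The alignment can be illustrated with a three-flop delay line in Python; delay_line is a hypothetical helper. The sketch assumes ena stays high. In that case win_valid_s3 is simply win_valid delayed by three cycles, and valid is its AND with sum_valid.

```python
def delay_line(signal, stages=3):
    """Shift register of `stages` flops reset to 0 (win_valid_s1 .. win_valid_s3)."""
    regs = [0] * stages
    out = []
    for bit in signal:
        out.append(regs[-1])           # last flop as seen during this cycle
        regs = [bit] + regs[:-1]       # posedge: s1 <= in, s2 <= s1, s3 <= s2
    return out


win_valid = [0, 1, 1, 0, 0, 0]
sum_valid = [0, 0, 0, 0, 1, 1]
valid = [a & b for a, b in zip(delay_line(win_valid), sum_valid)]
```

The two valid windows seen in cycles 1 and 2 show up as valid outputs in cycles 4 and 5, when their sums are in the sum register.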

Complete Circuit Code

/*CONV2D.v*/
module CONV2D
#(
	parameter DATA_WIDTH = 16,
	parameter FMAP_SIZE = 32,
	parameter KERNEL_SIZE = 5,
	parameter STRIDE = 1
)
(
	input clk, 
	input rst_n,
	input ena,
	input clear,
	input [DATA_WIDTH*KERNEL_SIZE-1:0]tap,
	input [KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0]w,
	output [DATA_WIDTH-1:0]conv_out,
	output valid,
	output done
);
	
	function integer clogb2 (input integer bit_depth);
	begin
        for(clogb2 = 0; bit_depth > 0; clogb2 = clogb2 + 1)
            bit_depth = bit_depth >> 1;
	end
	endfunction
	
	localparam CNT_BIT_NUM = clogb2((FMAP_SIZE * (FMAP_SIZE + 1)));
	localparam CNT_LINE_BIT_NUM = clogb2(FMAP_SIZE);
	
	reg [DATA_WIDTH-1:0] window_data[KERNEL_SIZE*KERNEL_SIZE-1:0];
	reg [DATA_WIDTH-1:0] product[KERNEL_SIZE*KERNEL_SIZE-1:0];
	reg [DATA_WIDTH-1:0] partial_sum[KERNEL_SIZE-1:0];
	reg [DATA_WIDTH-1:0] sum;
	wire[DATA_WIDTH-1:0] partial_adder_out[KERNEL_SIZE*KERNEL_SIZE-1:0];
	wire[DATA_WIDTH-1:0] sum_adder_out[KERNEL_SIZE-2:0];
	
	reg[CNT_BIT_NUM-1:0] cnt;
	reg[CNT_LINE_BIT_NUM-1:0] cnt_col;
	reg[CNT_LINE_BIT_NUM-1:0] cnt_row;
	reg[CNT_LINE_BIT_NUM-1:0] cnt_stride_col;
	reg[CNT_LINE_BIT_NUM-1:0] cnt_stride_row;
	
	reg shift;
	reg sum_valid;
	wire stride_valid;
	wire shift_valid;
	wire win_valid;
	reg win_valid_s1;
	reg win_valid_s2;
	reg win_valid_s3;
	
	genvar i, j;
	
/*******************************************generate signals************************************/
	assign stride_valid = (cnt_stride_col == 0 & cnt_stride_row == 0) ? 1 : 0;
	assign shift_valid = (shift && (cnt_col < FMAP_SIZE - (KERNEL_SIZE - 1)));
	assign win_valid = (shift_valid && stride_valid);
	assign valid = win_valid_s3 & sum_valid;
	assign done = (ena && (cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) + (2 + 1) - 1)) ? 1 : 0;
	assign conv_out = (valid) ? sum : 0;
	
/******************************************data window shift-in**********************************/
	//two generate-for loops: one shifts each row of 'window_data' (reg) left,
	//the other loads the 'tap' bus (wire) into the tail of each row
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 0; j < KERNEL_SIZE - 2; j = j + 1) begin
					always@(posedge clk)
						window_data[i*KERNEL_SIZE+j] <= window_data[i*KERNEL_SIZE+j+1];
				end
			end
	
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				always@(posedge clk)
					window_data[i*KERNEL_SIZE+KERNEL_SIZE-2] <= tap[(i+1)*DATA_WIDTH-1-:DATA_WIDTH];
			end
		end
		else begin
			always@(posedge clk)
				window_data[0] <= tap[DATA_WIDTH-1:0];
		end
	endgenerate
	
	
/******************************************adder tree**********************************/
	//partial sum adder tree
	//generated when KERNEL_SIZE > 1
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				ADDER
				#(
					.DATA_WIDTH(DATA_WIDTH)
				)
				u_partial_adder_head
				(
					.ina(product[i*KERNEL_SIZE]),
					.inb(product[i*KERNEL_SIZE+1]),
					.out(partial_adder_out[i*KERNEL_SIZE])
				);
			end
		
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 2; j < KERNEL_SIZE; j = j + 1) begin
					ADDER
					#(
						.DATA_WIDTH(DATA_WIDTH)
					)
					u_partial_adder_remain
					(
						.ina(partial_adder_out[i*KERNEL_SIZE+j-2]),
						.inb(product[i*KERNEL_SIZE+j]),
						.out(partial_adder_out[i*KERNEL_SIZE+j-1])
					);
				end
			end
		end
	endgenerate
		
	//conv sum adder tree
	//generated when KERNEL_SIZE > 1
	generate
		if(KERNEL_SIZE > 1) begin
			ADDER
			#(
				.DATA_WIDTH(DATA_WIDTH)
			)
			u_sum_adder_head
			(
				.ina(partial_sum[0]),
				.inb(partial_sum[1]),
				.out(sum_adder_out[0])
			);
	
			for(i = 2; i < KERNEL_SIZE; i = i + 1) begin
				ADDER
				#(
					.DATA_WIDTH(DATA_WIDTH)
				)
				u_sum_adder_remain
				(
					.ina(sum_adder_out[i-2]),
					.inb(partial_sum[i]),
					.out(sum_adder_out[i-1])
				);
			end
		end
	endgenerate
	
/*****************************************3-stage pipeline*********************************/
	//stage 1: generate multiplication product
	//(each product keeps the low DATA_WIDTH bits of the full product)
	generate 
		if(KERNEL_SIZE > 1) begin
			//last column of each row: multiply the live tap data
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				always@(posedge clk)
					product[i*KERNEL_SIZE+KERNEL_SIZE-1] <= w[((i*KERNEL_SIZE+KERNEL_SIZE-1)+1)*DATA_WIDTH-1-:DATA_WIDTH] * tap[(i+1)*DATA_WIDTH-1-:DATA_WIDTH];
			end
	
			//other columns: multiply the buffered window data
			for(i = 0 ; i < KERNEL_SIZE; i = i + 1) begin
				for(j = 0 ; j < KERNEL_SIZE-1; j = j + 1) begin
					always@(posedge clk)
						product[i*KERNEL_SIZE+j] <= w[((i*KERNEL_SIZE+j)+1)*DATA_WIDTH-1-:DATA_WIDTH] * window_data[i*KERNEL_SIZE+j];
				end
			end
		end
		
		else begin
			always @(posedge clk)
				product[0] <= w[DATA_WIDTH-1:0] * tap[DATA_WIDTH-1:0];
		end
	endgenerate
	
	//stage 2: generate partial sum 
	generate
		if(KERNEL_SIZE > 1) begin
			for(i = 0; i < KERNEL_SIZE; i = i + 1) begin
				always @(posedge clk or negedge rst_n)
					if(!rst_n)
						partial_sum[i] <= 0;
					else
						partial_sum[i] <= partial_adder_out[i*KERNEL_SIZE+KERNEL_SIZE-2];
			end
		end
		else begin
			always @(posedge clk or negedge rst_n)
				if(!rst_n)
					partial_sum[0] <= 0;
				else
					partial_sum[0] <= product[0];
		end
	endgenerate
	
	//stage 3: generate conv sum 
	generate 
		if(KERNEL_SIZE > 1) begin
			always@(posedge clk or negedge rst_n)
				if(!rst_n)
					sum <= 0;
				else
					sum <= sum_adder_out[KERNEL_SIZE-2];
		end
		else begin
			always@(posedge clk or negedge rst_n)
				if(!rst_n)
					sum <= 0;
				else
					sum <= partial_sum[0];
		end
	endgenerate
			
			
/**********************************************generate control logic*****************************/	
	//global counter
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt <= 0;
		else if(clear)
			cnt <= 0;
		else if(!ena)
			cnt <= cnt;
		else
			cnt <= cnt + 1;
	
	//valid data window detector
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_col <= 0;
		else if(clear)
			cnt_col <= 0;
		else if(!ena)
			cnt_col <= cnt_col;
		else if(shift) begin
			if(cnt_col == FMAP_SIZE - 1)
				cnt_col <= 0;
			else
				cnt_col <= cnt_col + 1;
		end
	
	//valid shift period detector
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			shift <= 0;
		else if(clear)
			shift <= 0;
		else if(!ena)
			shift <= 0;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1) - 1)
			shift <= 1'b1;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) - 1)
			shift <= 1'b0;
    
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_stride_col <= 0;
		else if(clear)
			cnt_stride_col <= 0;
		else if(ena) begin
			if(shift) begin
				if(cnt_stride_col < STRIDE - 1)
					cnt_stride_col <= cnt_stride_col + 1;
				else
					cnt_stride_col <= 0;
			end
		end
			
	always @(posedge clk or negedge rst_n)
		if(!rst_n)
			cnt_stride_row <= 0;
		else if(clear)
			cnt_stride_row <= 0;
		else if(ena) begin
			if(cnt_col == FMAP_SIZE - 1) begin
				if(cnt_stride_row < STRIDE - 1)
					cnt_stride_row <= cnt_stride_row + 1;
				else
					cnt_stride_row <= 0;
			end
		end
	
	//valid sum period detector
	always @(posedge clk or negedge rst_n)
	    if(!rst_n)
	    	sum_valid <= 0;
		else if(clear)
	    	sum_valid <= 0;
		else if(!ena)
	    	sum_valid <= 0;
	    else if(cnt == FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1) + (1 + 2) - 1)
	    	sum_valid <= 1;
		else if(cnt == FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1) + (1 + 2) - 1)
	    	sum_valid <= 0;

	 //'conv sum valid' sig comes 3 clocks later than 'window valid' sig,
	 //insert 3 regs
	 always @(posedge clk or negedge rst_n)
	 	if(!rst_n) begin
	 		win_valid_s1 <= 0;
	 		win_valid_s2 <= 0;
	 		win_valid_s3 <= 0;
	 	end
	 	else if(clear) begin
	 		win_valid_s1 <= 0;
	 		win_valid_s2 <= 0;
	 		win_valid_s3 <= 0;
	 	end
	 	else if(ena) begin
	 		win_valid_s1 <= win_valid;
	 		win_valid_s2 <= win_valid_s1;
	 		win_valid_s3 <= win_valid_s2;
	 	end
	 	
endmodule

TestBench

`timescale 1ns / 1ns

module tb_conv2d();

	parameter 	DATA_WIDTH = 16,
				FMAP_SIZE = 28,
				KERNEL_SIZE = 5;
   	
   	reg clk, rst_n;
   	reg ena;
   	reg[DATA_WIDTH-1:0] shift_in;
   	wire[DATA_WIDTH-1:0] shift_out;
   	wire[KERNEL_SIZE*DATA_WIDTH-1:0] tap;
   	reg [KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] w;
   	wire[DATA_WIDTH-1:0] conv_out;
   	wire conv_done;
   	wire valid;
   	
   	integer i;

	initial begin
		clk = 0;
		rst_n = 0;
		ena = 0;
		shift_in = 0;
		for(i = 0; i < KERNEL_SIZE*KERNEL_SIZE; i = i + 1)
			w[((i+1)*DATA_WIDTH-1)-:DATA_WIDTH] = 1;
		#100;
		rst_n = 1;
		ena = 1;
		
		repeat(784) begin
			#100
			clk = ~clk;
			#100
			clk = ~clk;			
			shift_in = shift_in + 1;
		end
		
		forever begin
			#100
			clk = ~clk;
			#100
			clk = ~clk;			
		end	
	end
	
	SHIFT_RAM
	# (	
		.DATA_WIDTH(DATA_WIDTH),
		.TAP_NUM(KERNEL_SIZE),
	   	.TAP_LENGTH(FMAP_SIZE)
	)
	u_shift_ram(
		.clk(clk),
		.rst_n(rst_n),
		.clear(1'b0),
		.ena(ena),
		.shift_in(shift_in),
		.shift_out(shift_out),
		.taps(tap)
	);
	
	CONV2D
	#(
		.DATA_WIDTH(DATA_WIDTH),
		.FMAP_SIZE(FMAP_SIZE),
		.KERNEL_SIZE(KERNEL_SIZE),
		.STRIDE(1)
	)
	u_conv
	(
		.clk(clk), 
		.rst_n(rst_n),
		.ena(ena),
		.clear(1'b0),
		.tap(tap),
		.w(w),
		.conv_out(conv_out),
		.valid(valid),
		.done(conv_done)
	);
	
endmodule

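The FMAP data consists of the values 0 to 783, generated in the repeat block and fed into the Shift RAM. The tap outputs of the Shift RAM are connected to the tap input of the convolution module, and every kernel weight in w is 1.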
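
A software golden model makes it easy to check conv_out in the waveform. The sketch below is a plain-Python "valid" convolution applied to the testbench stimulus; conv2d_valid is a hypothetical helper. It computes a cross-correlation without kernel flipping, which matches the w[i*KERNEL_SIZE+j] × window pairing in the RTL.

With an all-ones kernel the values do not depend on the tap row order of the SHIFT_RAM from Part 1. That results appear in raster order is an assumption here, not something verified against the RTL.

```python
def conv2d_valid(fmap, kernel, stride=1):
    """Single-channel 'valid' 2D convolution (cross-correlation, no kernel flip)."""
    n, k = len(fmap), len(kernel)
    out_size = (n - k) // stride + 1
    return [[sum(fmap[r * stride + a][c * stride + b] * kernel[a][b]
                 for a in range(k) for b in range(k))
             for c in range(out_size)]
            for r in range(out_size)]


# testbench stimulus: 28x28 FMAP holding 0..783 row by row, all-ones 5x5 kernel
FMAP_SIZE, KERNEL_SIZE = 28, 5
fmap = [[r * FMAP_SIZE + c for c in range(FMAP_SIZE)] for r in range(FMAP_SIZE)]
kernel = [[1] * KERNEL_SIZE for _ in range(KERNEL_SIZE)]
expected = conv2d_valid(fmap, kernel)
```

The model gives 24 × 24 outputs. The first is 1450 and the last is 18125. Neighbouring results in a row differ by 25, and results one row apart differ by 700. All values stay below 2^16, so DATA_WIDTH = 16 does not truncate anything in this test.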

Simulation Results

[Figures: simulation waveforms]
The waveforms show that valid convolution results are output from 29.4 us to 162.8 us. done is high for one clock cycle during the last convolution result.
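
The two times in the waveform follow from the cycle formulas above. In the testbench the first rising edge of clk is at 200 ns and the period is 200 ns, so cnt reaches the value v at 200 × v ns. The Python sketch below reproduces both numbers; time_cnt_reaches is a hypothetical helper. It assumes rst_n and ena go high before the first edge and ena stays high.

```python
CLK_PERIOD_NS = 200     # clk toggles every #100 under `timescale 1ns / 1ns
FIRST_POSEDGE_NS = 200  # #100 before reset release, then #100 before the first toggle


def time_cnt_reaches(value):
    """Simulation time (ns) of the posedge after which cnt == value."""
    return FIRST_POSEDGE_NS + (value - 1) * CLK_PERIOD_NS


FMAP_SIZE, KERNEL_SIZE, DEPTH = 28, 5, 3
first_window = FMAP_SIZE * KERNEL_SIZE + (KERNEL_SIZE - 1)                       # 144
shift_end = FMAP_SIZE * KERNEL_SIZE + FMAP_SIZE * (FMAP_SIZE - KERNEL_SIZE + 1)  # 812
valid_rise_ns = time_cnt_reaches(first_window + DEPTH)   # sum_valid goes high
done_ns = time_cnt_reaches(shift_end + DEPTH - 1)        # done decodes cnt directly
print(valid_rise_ns / 1000, done_ns / 1000)              # 29.4 162.8
```

Both values match the waveform, which supports the cycle accounting used for shift and sum_valid.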

Summary

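This article described in detail the implementation of a Shift-RAM-based convolution circuit; much of the logic is devoted to detecting and discarding invalid data. The next article will cover the pooling circuit.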
