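1. Introduction
This example is a kernel that performs cascaded matrix multiplication using dataflow. The kernel enables ap_ctrl_chain to show how overlapping multiple enqueued kernel calls can deliver higher performance. ap_ctrl_chain allows the kernel to start processing the next kernel operation before the current one has completed.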
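To make the benefit of overlapping calls concrete, here is a back-of-the-envelope timing model rather than anything measured on hardware. It assumes every kernel call has the same latency and that an overlapped kernel can accept a new call every `interval` cycles. The function names and the numbers in the note below are hypothetical.

```cpp
#include <cstdint>

// Toy timing model (an assumption for illustration, not measured data).
// latency : cycles one kernel call needs from start to done
// interval: cycles until the kernel can accept the next call

// Back-to-back execution: the next call starts only after the previous one is done.
std::uint64_t sequential_cycles(std::uint64_t calls, std::uint64_t latency) {
    return calls * latency;
}

// Overlapped execution: a new call starts every `interval` cycles,
// so only the last call exposes its full latency.
std::uint64_t overlapped_cycles(std::uint64_t calls,
                                std::uint64_t latency,
                                std::uint64_t interval) {
    if (calls == 0) {
        return 0;
    }
    return latency + (calls - 1) * interval;
}
```

With 10 calls, a latency of 500 cycles, and an interval of 100 cycles, this model gives 5000 cycles for back-to-back execution and 1400 cycles for overlapped execution. Real kernels will differ, since the interval depends on the slowest dataflow stage and on memory behavior.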
2. Example walkthrough
2.1 What the example does
This example shows the performance difference between a chained kernel and a simple kernel.
#pragma HLS INTERFACE s_axilite port = return bundle = control
#pragma HLS INTERFACE ap_ctrl_chain port = return bundle = control
The example contains two kernels:
1. krnl_simple_mmult: a kernel that uses the ap_ctrl_hs protocol.
#pragma HLS INTERFACE ap_ctrl_hs port = return
2. krnl_chain_mmult: a kernel that uses the ap_ctrl_chain protocol.
#pragma HLS INTERFACE ap_ctrl_chain port = return
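ap_ctrl_chain implements a set of block-level control ports that start the design's operation, continue operation, and indicate when the design is idle, done, and ready for new input data. The ap_ctrl_chain interface mode is similar to ap_ctrl_hs but provides an additional input signal, ap_continue, to apply back pressure. Xilinx recommends using the ap_ctrl_chain block-level I/O protocol when chaining Vivado HLS blocks together.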
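The back-pressure role of ap_continue can be sketched with a small behavioral model. This is a simplified illustration and not generated RTL: the `ChainBlock` type, its fixed latency, and its three states are assumptions made for the sketch, and it models neither ap_ready nor overlapping transactions. It only shows that the block stays in its done state until the consumer asserts ap_continue.

```cpp
// Simplified cycle-level model of a block with an ap_continue input.
// Hypothetical type for illustration; real hardware has more signals.
struct ChainBlock {
    enum class State { Idle, Busy, DoneWait };

    State state = State::Idle;
    int latency;        // hypothetical fixed number of busy cycles
    int remaining = 0;  // busy cycles left in the current transaction
    int completed = 0;  // transactions accepted by the consumer

    explicit ChainBlock(int lat) : latency(lat) {}

    bool ap_idle() const { return state == State::Idle; }
    bool ap_done() const { return state == State::DoneWait; }

    // Advance one clock cycle with the given input signals.
    void tick(bool ap_start, bool ap_continue) {
        switch (state) {
        case State::Idle:
            if (ap_start) {
                state = State::Busy;
                remaining = latency;
            }
            break;
        case State::Busy:
            if (--remaining == 0) {
                state = State::DoneWait;
            }
            break;
        case State::DoneWait:
            // Back pressure: done is held until the consumer continues.
            if (ap_continue) {
                ++completed;
                state = State::Idle;
            }
            break;
        }
    }
};
```

In this sketch, a consumer that keeps ap_continue low holds the block in `DoneWait` indefinitely, which is the stalling behavior that distinguishes ap_ctrl_chain from ap_ctrl_hs.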
#include "ap_axi_sdata.h"
#include "ap_int.h"
#include "hls_stream.h"
#include "krnl_mmult.hpp"

extern "C" {
void krnl_simple_mmult(int* a, int* b, int* c, int* d, int* output, int dim) {
#pragma HLS INTERFACE m_axi port = a offset = slave bundle = gmem0
#pragma HLS INTERFACE m_axi port = b offset = slave bundle = gmem1
#pragma HLS INTERFACE m_axi port = c offset = slave bundle = gmem2
#pragma HLS INTERFACE m_axi port = d offset = slave bundle = gmem3
#pragma HLS INTERFACE m_axi port = output offset = slave bundle = gmem3
#pragma HLS INTERFACE s_axilite port = a
#pragma HLS INTERFACE s_axilite port = b
#pragma HLS INTERFACE s_axilite port = c
#pragma HLS INTERFACE s_axilite port = d
#pragma HLS INTERFACE s_axilite port = output
#pragma HLS INTERFACE s_axilite port = dim
#pragma HLS INTERFACE s_axilite port = return
#pragma HLS INTERFACE ap_ctrl_hs port = return
#pragma HLS STABLE variable = a
#pragma HLS STABLE variable = b
#pragma HLS STABLE variable = c
#pragma HLS STABLE variable = d
#pragma HLS STABLE variable = output

    hls::stream<pkt> strm_a, strm_b, strm_c, strm_d;
    hls::stream<int> strm_ctrl_trans1, strm_ctrl_trans2, strm_ctrl_trans3, strm_ctrl_trans4, strm_ctrl_trans5;
    int tmp = dim;
    strm_ctrl_trans1.write(tmp);

#pragma HLS STREAM variable = strm_ctrl_trans1 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans2 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans3 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans4 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans5 depth = 2
#pragma HLS STREAM variable = strm_a depth = 64
#pragma HLS STREAM variable = strm_b depth = 64
#pragma HLS STREAM variable = strm_c depth = 64
#pragma HLS STREAM variable = strm_d depth = 64

#pragma HLS DATAFLOW
    mm2s(a, strm_a, strm_ctrl_trans1, strm_ctrl_trans2);
    mmult(strm_a, b, strm_ctrl_trans2, strm_b, strm_ctrl_trans3);
    mmult(strm_b, c, strm_ctrl_trans3, strm_c, strm_ctrl_trans4);
    mmult(strm_c, d, strm_ctrl_trans4, strm_d, strm_ctrl_trans5);
    s2mm(strm_d, output, strm_ctrl_trans5);
}
}
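The five-stage chain in the kernel body can be mimicked in ordinary C++ to check what the cascade computes. The sketch below is a software stand-in, not the HLS code: it assumes `std::queue<int>` in place of `hls::stream`, runs the stages one after another instead of concurrently, and uses hypothetical names (`Stream`, `pop_front`, and the `*_model` functions). Under those assumptions the result is ((A x B) x C) x D for square, row-major matrices.

```cpp
#include <queue>
#include <vector>

using Stream = std::queue<int>;  // stand-in for hls::stream (assumption)

// Read and remove the oldest element of a stream.
static int pop_front(Stream& s) {
    int v = s.front();
    s.pop();
    return v;
}

// Memory to stream: forward dim, then push all dim*dim elements of a.
void mm2s_model(const std::vector<int>& a, Stream& out, Stream& ctrl_in, Stream& ctrl_out) {
    int dim = pop_front(ctrl_in);
    ctrl_out.push(dim);
    for (int i = 0; i < dim * dim; ++i) {
        out.push(a[i]);
    }
}

// One multiplication stage: left operand from the stream, right operand from memory.
void mmult_model(Stream& in, const std::vector<int>& b, Stream& ctrl_in,
                 Stream& out, Stream& ctrl_out) {
    int dim = pop_front(ctrl_in);
    ctrl_out.push(dim);
    std::vector<int> lhs(dim * dim);
    for (int& v : lhs) {
        v = pop_front(in);
    }
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            int acc = 0;
            for (int l = 0; l < dim; ++l) {
                acc += lhs[row * dim + l] * b[l * dim + col];
            }
            out.push(acc);
        }
    }
}

// Stream to memory: consume dim, then drain dim*dim results.
void s2mm_model(Stream& in, std::vector<int>& output, Stream& ctrl_in) {
    int dim = pop_front(ctrl_in);
    output.resize(dim * dim);
    for (int& v : output) {
        v = pop_front(in);
    }
}

// Same call order as the kernel body above.
std::vector<int> cascade_model(const std::vector<int>& a, const std::vector<int>& b,
                               const std::vector<int>& c, const std::vector<int>& d, int dim) {
    Stream sa, sb, sc, sd, t1, t2, t3, t4, t5;
    t1.push(dim);
    mm2s_model(a, sa, t1, t2);
    mmult_model(sa, b, t2, sb, t3);
    mmult_model(sb, c, t3, sc, t4);
    mmult_model(sc, d, t4, sd, t5);
    std::vector<int> output;
    s2mm_model(sd, output, t5);
    return output;
}
```

A model like this can serve as a golden reference when comparing kernel output on the host. It says nothing about timing, because the hardware stages overlap under DATAFLOW while these calls run strictly in sequence.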
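2.2 Function descriptions
- mm2s: reads the input matrix from memory and writes its elements to a stream. It also forwards the control value that carries the matrix dimension.
- mmult: implements matrix multiplication. It first reads the control value to determine the matrix dimension and forwards it to the next stage. It then declares local arrays (buf_a, buf_b, and buf_out, plus temp_sum, which the listing does not use), fills buf_a from the input stream and buf_b from memory, and uses three nested loops to compute each result element by multiplying the corresponding row and column elements and accumulating the products. Finally, it writes the result to the output stream.
- s2mm: reads the result matrix from a stream and writes it to the output buffer in memory. It reads the control value to learn the matrix dimension and does not forward it.
The code for the three functions follows:
2.2.1 The mmult function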
// Template to avoid signature conflict in sw_em
template <int DUMMY = 0>
void mmult(hls::stream<pkt>& strm_a,
           int* b,
           hls::stream<int>& strm_ctrl_trans2,
           hls::stream<pkt>& strm_out,
           hls::stream<int>& strm_ctrl_trans3) {
    int dim = strm_ctrl_trans2.read();
    strm_ctrl_trans3.write(dim);
    int size = dim * dim;

    int buf_a[MAT_DIM][MAT_DIM];
    int buf_b[MAT_DIM][MAT_DIM];
    int buf_out[MAT_DIM][MAT_DIM];
    int temp_sum[MAT_DIM];
    int i, j, itr;

// Auto-pipeline is going to apply pipeline to these loops
read_strm_in1:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        pkt temp = strm_a.read();
        buf_a[i][j] = temp;
    }

read_mm_in2:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        buf_b[i][j] = b[i * dim + j];
    }

mmult_strm_1:
    for (int row = 0; row < dim; row++) {
    mmult_strm_2:
        for (int col = 0; col < dim; col++) {
            int result = 0;
        mmult_strm_3:
            for (int l = 0; l < dim; l++) {
// As the outer loop is not a perfect loop
#pragma HLS loop_flatten off
                result += buf_a[row][l] * buf_b[l][col];
            }
            buf_out[row][col] = result;
        }
    }

write_strm_out:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        pkt temp;
        temp = buf_out[i][j];
        strm_out.write(temp);
    }
}
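The read and write loops in mmult walk a dim x dim matrix with a single loop counter plus a pair of row and column counters that wrap by hand. The sketch below checks that this pattern visits the same (row, column) pairs as the more familiar division and modulo form. The function names are hypothetical. One plausible motivation for the counter form, which the source does not state, is that it avoids division and modulo operators in hardware.

```cpp
#include <utility>
#include <vector>

// Row-major walk using the same counter pattern as the mmult loops.
std::vector<std::pair<int, int>> walk_with_counters(int dim) {
    std::vector<std::pair<int, int>> visited;
    int size = dim * dim;
    int i, j, itr;
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {  // end of a row: wrap the column, advance the row
            j = 0;
            i++;
        }
        visited.emplace_back(i, j);
    }
    return visited;
}

// The same walk expressed with division and modulo.
std::vector<std::pair<int, int>> walk_with_divmod(int dim) {
    std::vector<std::pair<int, int>> visited;
    int size = dim * dim;
    for (int itr = 0; itr < size; itr++) {
        visited.emplace_back(itr / dim, itr % dim);
    }
    return visited;
}
```

Both functions return the same sequence for any non-negative `dim`, so the counter form can be read as plain row-major iteration.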
2.2.2 The mm2s function
template <int DUMMY = 0>
void mm2s(int* a,
          hls::stream<pkt>& strm_a,
          hls::stream<int>& strm_ctrl_trans1,
          hls::stream<int>& strm_ctrl_trans2) {
    int dim = strm_ctrl_trans1.read();
    strm_ctrl_trans2.write(dim);
    int size = dim * dim;

// Auto-pipeline is going to apply pipeline to this loop
mm2s:
    for (int i = 0; i < size; i++) {
        pkt p1;
        p1 = a[i];
        strm_a.write(p1);
    }
}
2.2.3 The s2mm function
template <int DUMMY = 0>
void s2mm(hls::stream<pkt>& strm_in,
          int* output,
          hls::stream<int>& strm_ctrl_trans5) {
    int dim = strm_ctrl_trans5.read();
    int size = dim * dim;

write_output:
    for (int i = 0; i < size; i++) {
        pkt temp = strm_in.read();
        output[i] = temp;
    }
}
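2.2.4 Overall diagram
3. Summary
The core of this example is a kernel that uses dataflow and cascaded matrix multiplication to improve performance. It relies on the ap_ctrl_chain protocol, which allows the kernel to start the next operation before the current one has completed, so that consecutive kernel calls can overlap. The example comes in two versions: the simple krnl_simple_mmult and the chained krnl_chain_mmult. The simple version uses the ap_ctrl_hs protocol, while the chained version uses ap_ctrl_chain, which adds the ap_continue input signal for managing the flow of kernel operations.
The example code shows how these kernels perform matrix multiplication. mm2s reads the matrix from memory into a stream and forwards the control value, mmult performs the matrix multiplication, and s2mm writes the result from the stream back to memory. The data streams and control streams connect these stages so that they can run concurrently as a dataflow pipeline.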