Vitis HLS Study Notes: Stream Chain Matrix Multiplication

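Contents

1. Introduction

2. Example walkthrough

2.1 What the example does

2.2 Function descriptions

2.2.1 The mmult function

2.2.2 The mm2s function

2.2.3 The s2mm function

2.2.4 Overall diagram

3. Summary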


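1. Introduction

This kernel performs cascaded matrix multiplication using dataflow. It enables ap_ctrl_chain to show how queuing several kernel calls so that they overlap can give higher performance. With ap_ctrl_chain, the kernel can begin processing the next call before the current call has completed.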
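To make the benefit of overlapping calls concrete, here is a toy timing model. It is only a sketch: the function names, and the `latency` and `interval` parameters, are hypothetical and are not taken from the example or from any tool report. It assumes every call has the same latency and that a new call can be accepted every `interval` cycles.

```cpp
#include <cstdint>

// Toy timing model (an assumption for illustration, not measured data).
// latency:  cycles one kernel call needs from start to done
// interval: cycles after which the kernel can accept the next call

// Without overlap, each call waits for the previous one to finish.
std::uint64_t sequential_cycles(std::uint64_t calls, std::uint64_t latency) {
    return calls * latency;
}

// With overlap, a new call starts every `interval` cycles, and the
// last call still needs its full latency to finish.
std::uint64_t overlapped_cycles(std::uint64_t calls, std::uint64_t latency,
                                std::uint64_t interval) {
    if (calls == 0) return 0;
    return (calls - 1) * interval + latency;
}
```

For example, four calls with a latency of 100 cycles take 400 cycles back to back. If a new call can start every 30 cycles, the same four calls take 3 * 30 + 100 = 190 cycles in this model.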

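2. Example walkthrough

2.1 What the example does

This example shows the performance difference between a chained kernel and a simple kernel. The chained kernel declares its block-level control protocol as follows: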

#pragma HLS INTERFACE s_axilite port = return bundle = control
#pragma HLS INTERFACE ap_ctrl_chain port = return bundle = control

The example contains two kernels:

1. krnl_simple_mmult: a kernel that uses the ap_ctrl_hs protocol.

#pragma HLS INTERFACE ap_ctrl_hs port = return

2. krnl_chain_mmult: a kernel that uses the ap_ctrl_chain protocol.

#pragma HLS INTERFACE ap_ctrl_chain port = return

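ap_ctrl_chain implements a set of block-level control ports that start the design, let it continue, and indicate when it is idle, done, and ready for new input data. The ap_ctrl_chain interface mode is similar to ap_ctrl_hs, but it provides an additional input signal, ap_continue, to apply back-pressure. Xilinx recommends the ap_ctrl_chain block-level I/O protocol when chaining HLS blocks together.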
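The back-pressure behaviour can be illustrated with a small software model. This is a sketch under simplifying assumptions, not a cycle-accurate description of the generated RTL: `ChainBlock`, `tick`, and `run` are hypothetical names, each transaction takes a fixed number of cycles, and the block always spends one cycle with ap_done high before ap_continue is sampled.

```cpp
#include <vector>

// Toy model of the ap_done / ap_continue handshake (illustrative only).
struct ChainBlock {
    int work_cycles;       // hypothetical cycles needed per transaction
    int progress = 0;      // cycles spent on the current transaction
    bool ap_done = false;  // held high until ap_continue is seen
    int completed = 0;     // transactions accepted by the consumer

    explicit ChainBlock(int cycles) : work_cycles(cycles) {}

    // Advance one clock cycle with the given ap_continue input.
    void tick(bool ap_continue) {
        if (ap_done) {
            // Back-pressure: stay stalled until the consumer accepts.
            if (ap_continue) {
                ap_done = false;
                ++completed;
            }
            return;
        }
        if (++progress == work_cycles) {
            progress = 0;
            ap_done = true;
        }
    }
};

// Run the block against a per-cycle ap_continue trace and return how
// many transactions were accepted by the consumer.
int run(ChainBlock block, const std::vector<bool>& continue_trace) {
    for (bool c : continue_trace) block.tick(c);
    return block.completed;
}
```

With ap_continue always high, a block that needs two cycles per transaction completes two transactions in six cycles in this model. With ap_continue held low, it finishes the first transaction and then stalls, so nothing is lost or overwritten.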

#include "ap_axi_sdata.h"
#include "ap_int.h"
#include "hls_stream.h"
#include "krnl_mmult.hpp"
extern "C" {
void krnl_simple_mmult(int* a, int* b, int* c, int* d, int* output, int dim) {
#pragma HLS INTERFACE m_axi port = a offset = slave bundle = gmem0
#pragma HLS INTERFACE m_axi port = b offset = slave bundle = gmem1
#pragma HLS INTERFACE m_axi port = c offset = slave bundle = gmem2
#pragma HLS INTERFACE m_axi port = d offset = slave bundle = gmem3
#pragma HLS INTERFACE m_axi port = output offset = slave bundle = gmem3
#pragma HLS INTERFACE s_axilite port = a
#pragma HLS INTERFACE s_axilite port = b
#pragma HLS INTERFACE s_axilite port = c
#pragma HLS INTERFACE s_axilite port = d
#pragma HLS INTERFACE s_axilite port = output
#pragma HLS INTERFACE s_axilite port = dim
#pragma HLS INTERFACE s_axilite port = return
#pragma HLS INTERFACE ap_ctrl_hs port = return

#pragma HLS STABLE variable = a
#pragma HLS STABLE variable = b
#pragma HLS STABLE variable = c
#pragma HLS STABLE variable = d
#pragma HLS STABLE variable = output

    hls::stream<pkt> strm_a, strm_b, strm_c, strm_d;
    hls::stream<int> strm_ctrl_trans1, strm_ctrl_trans2, strm_ctrl_trans3, strm_ctrl_trans4, strm_ctrl_trans5;

    int tmp = dim;
    strm_ctrl_trans1.write(tmp);

#pragma HLS STREAM variable = strm_ctrl_trans1 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans2 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans3 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans4 depth = 2
#pragma HLS STREAM variable = strm_ctrl_trans5 depth = 2

#pragma HLS STREAM variable = strm_a depth = 64
#pragma HLS STREAM variable = strm_b depth = 64
#pragma HLS STREAM variable = strm_c depth = 64
#pragma HLS STREAM variable = strm_d depth = 64

#pragma HLS DATAFLOW

    mm2s(a, strm_a, strm_ctrl_trans1, strm_ctrl_trans2);
    mmult(strm_a, b, strm_ctrl_trans2, strm_b, strm_ctrl_trans3);
    mmult(strm_b, c, strm_ctrl_trans3, strm_c, strm_ctrl_trans4);
    mmult(strm_c, d, strm_ctrl_trans4, strm_d, strm_ctrl_trans5);
    s2mm(strm_d, output, strm_ctrl_trans5);
}
}
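When checking the kernel output on the host, a plain software reference is useful. The sketch below assumes row-major dim x dim matrices and that the chain computes ((a * b) * c) * d, which is how the three mmult calls above are wired. `matmul` and `chain_reference` are hypothetical helper names, not part of the example.

```cpp
#include <cstddef>
#include <vector>

// Row-major dim x dim product: returns x * y.
std::vector<int> matmul(const std::vector<int>& x,
                        const std::vector<int>& y, int dim) {
    std::vector<int> out(static_cast<std::size_t>(dim) * dim, 0);
    for (int row = 0; row < dim; ++row) {
        for (int col = 0; col < dim; ++col) {
            int sum = 0;
            for (int l = 0; l < dim; ++l) {
                sum += x[row * dim + l] * y[l * dim + col];
            }
            out[row * dim + col] = sum;
        }
    }
    return out;
}

// Host-side reference for the whole chain: ((a * b) * c) * d.
std::vector<int> chain_reference(const std::vector<int>& a,
                                 const std::vector<int>& b,
                                 const std::vector<int>& c,
                                 const std::vector<int>& d, int dim) {
    return matmul(matmul(matmul(a, b, dim), c, dim), d, dim);
}
```

The result can be compared element by element with the buffer written by s2mm.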

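2.2 Function descriptions

  • mm2s: reads the matrix from memory element by element and writes it to a stream. It also reads the matrix dimension from its input control stream and forwards it to the next stage.
  • mmult: implements the matrix multiplication. It first reads the control stream to get the matrix dimension and forwards it to the next stage. It declares four local arrays (buf_a, buf_b, buf_out, and temp_sum, the last of which is unused in the code shown), fills buf_a from the input stream and buf_b from memory, then uses a triple nested loop to compute each result element as the sum of products of the matching row and column, and finally writes the result to the output stream.
  • s2mm: reads the matrix dimension from the control stream, then reads the result matrix from the stream and writes it to memory. It is the last stage, so it does not forward the control value.

Each of the three functions is described below: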

2.2.1 The mmult function

// Template to avoid signature conflict in sw_em
template <int DUMMY = 0>
void mmult(hls::stream<pkt>& strm_a,
           int* b,
           hls::stream<int>& strm_ctrl_trans2,
           hls::stream<pkt>& strm_out,
           hls::stream<int>& strm_ctrl_trans3) {
    int dim = strm_ctrl_trans2.read();
    strm_ctrl_trans3.write(dim);
    int size = dim * dim;

    int buf_a[MAT_DIM][MAT_DIM];
    int buf_b[MAT_DIM][MAT_DIM];
    int buf_out[MAT_DIM][MAT_DIM];
    int temp_sum[MAT_DIM];
    int i, j, itr;

// Auto-pipeline is going to apply pipeline to these loops
read_strm_in1:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        pkt temp = strm_a.read();
        buf_a[i][j] = temp;
    }

read_mm_in2:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        buf_b[i][j] = b[i * dim + j];
    }

mmult_strm_1:
    for (int row = 0; row < dim; row++) {
    mmult_strm_2:
        for (int col = 0; col < dim; col++) {
            int result = 0;
        mmult_strm_3:
            for (int l = 0; l < dim; l++) {
// As the outer loop is not a perfect loop
#pragma HLS loop_flatten off
                result += buf_a[row][l] * buf_b[l][col];
            }
            buf_out[row][col] = result;
        }
    }

write_strm_out:
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {
            j = 0;
            i++;
        }
        pkt temp;
        temp = buf_out[i][j];

        strm_out.write(temp);
    }
}
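The read and write loops in mmult walk a two-dimensional buffer with a single loop counter and two wrap-around indices. The sketch below records the (i, j) pairs this pattern visits and compares them with the equivalent division and modulo form. The helper names are hypothetical. The wrap-around form presumably exists to avoid a divider for a run-time dim, although the original does not say so.

```cpp
#include <utility>
#include <vector>

// Index pairs visited by the wrap-around counters used in mmult.
std::vector<std::pair<int, int>> wrap_counter_indices(int dim) {
    std::vector<std::pair<int, int>> visited;
    int size = dim * dim;
    int i, j, itr;
    for (itr = 0, i = 0, j = 0; itr < size; itr++, j++) {
        if (j == dim) {  // end of a row: wrap the column, advance the row
            j = 0;
            i++;
        }
        visited.emplace_back(i, j);
    }
    return visited;
}

// The same visiting order written with division and modulo.
std::vector<std::pair<int, int>> div_mod_indices(int dim) {
    std::vector<std::pair<int, int>> visited;
    int size = dim * dim;
    for (int itr = 0; itr < size; itr++) {
        visited.emplace_back(itr / dim, itr % dim);
    }
    return visited;
}
```

Both functions return the same row-major sequence, so the single loop touches every element of a dim x dim buffer exactly once.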

2.2.2 The mm2s function

template <int DUMMY = 0>
void mm2s(       int* a,
    hls::stream<pkt>& strm_a,
    hls::stream<int>& strm_ctrl_trans1,
    hls::stream<int>& strm_ctrl_trans2)
{
    int dim = strm_ctrl_trans1.read();
    strm_ctrl_trans2.write(dim);
    int size = dim * dim;

// Auto-pipeline is going to apply pipeline to this loop
mm2s:
    for (int i = 0; i < size; i++) {
        pkt p1;
        p1 = a[i];

        strm_a.write(p1);
    }
}

2.2.3 The s2mm function

template <int DUMMY = 0>
void s2mm(hls::stream<pkt>& strm_in,
                       int* output,
          hls::stream<int>& strm_ctrl_trans5)
{
    int dim = strm_ctrl_trans5.read();
    int size = dim * dim;

write_output:
    for (int i = 0; i < size; i++) {
        pkt temp = strm_in.read();
        output[i] = temp;
    }
}
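The way mm2s and s2mm pass data and the dimension value can be tried out in plain C++ without the HLS headers. In this sketch std::queue stands in for hls::stream and a plain int stands in for pkt. Both substitutions are assumptions made for illustration, and all names ending in `_model`, plus `pop` and `round_trip`, are hypothetical.

```cpp
#include <queue>
#include <vector>

// std::queue stands in for hls::stream; int stands in for pkt.
using Stream = std::queue<int>;

// Blocking read replacement: take the front element and remove it.
int pop(Stream& s) {
    int value = s.front();
    s.pop();
    return value;
}

// Memory to stream: forward dim, then push dim * dim elements.
void mm2s_model(const std::vector<int>& a, Stream& data_out,
                Stream& ctrl_in, Stream& ctrl_out) {
    int dim = pop(ctrl_in);
    ctrl_out.push(dim);
    int size = dim * dim;
    for (int i = 0; i < size; i++) {
        data_out.push(a[i]);
    }
}

// Stream to memory: read dim, then pop dim * dim elements.
void s2mm_model(Stream& data_in, std::vector<int>& output, Stream& ctrl_in) {
    int dim = pop(ctrl_in);
    int size = dim * dim;
    output.resize(size);
    for (int i = 0; i < size; i++) {
        output[i] = pop(data_in);
    }
}

// Connect the two stages directly, skipping the mmult stages.
std::vector<int> round_trip(const std::vector<int>& a, int dim) {
    Stream ctrl1, ctrl2, data;
    ctrl1.push(dim);
    mm2s_model(a, data, ctrl1, ctrl2);
    std::vector<int> out;
    s2mm_model(data, out, ctrl2);
    return out;
}
```

Running the stages one after another works here only because std::queue is unbounded. In the kernel the streams have a fixed depth and the stages run concurrently under DATAFLOW.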

 

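2.2.4 Overall diagram

a -> mm2s -> strm_a -> mmult (x b) -> strm_b -> mmult (x c) -> strm_c -> mmult (x d) -> strm_d -> s2mm -> output

The dim value travels along the same chain through strm_ctrl_trans1 to strm_ctrl_trans5.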

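3. Summary

The core idea of this kernel is to improve performance by combining dataflow with cascaded matrix multiplication. It uses the ap_ctrl_chain protocol, which lets the kernel start the next call before the current one has completed, so consecutive calls overlap. The example comes in two versions: the simple krnl_simple_mmult and the chained krnl_chain_mmult. The simple version uses the ap_ctrl_hs protocol, while the chained version uses ap_ctrl_chain, which adds the ap_continue signal to manage the flow of kernel operations.

The example code shows how these kernels perform the matrix multiplication. mm2s reads the matrix from memory into a stream and forwards the dimension as a control value, mmult performs the multiplication, and s2mm writes the result from the stream back to memory. Data streams and control streams work together to pass the data from stage to stage.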
