halide编程技术指南(连载八)

本文深入介绍了Halide编程语言中的元组和类型系统。元组允许函数返回多个值,同时提供了处理多值集合的灵活性。文章通过示例展示了如何创建和使用元组,以及元组在函数调度和归约中的应用。此外,还详细讨论了Halide的类型提升规则,解释了不同类型之间的运算行为。最后,提到了用户定义的类型和类型转换在Halide中的使用。
摘要由CSDN通过智能技术生成

本文是halide编程指南的连载,已同步至公众号

第13章 元组

// 本课程介绍如何编写求多个值的函数.
// 在linux系统, 按如下编译运行:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13
// 在 os x:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_13 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13
// 如果你有halide的源码,也可以在源码的最顶层目录,这样:
//    make tutorial_lesson_13_tuples
#include "Halide.h"
#include <algorithm>
#include <stdio.h>
using namespace Halide;
int main(int argc, char **argv) {

//到目前为止,Funcs(如下面的函数)已经为其域中的每个点计算了一个标量值    
Func single_valued;
    Var x, y;
    single_valued(x, y) = x + y;

    // 编写返回值集合的Func的一种方法是给返回值添加索引。这就是我们通常处理颜色的方式。例如,下面的Func表示由c索引的每个x,y坐标的三个值的集合.
    Func color_image;
    Var c;
    color_image(x, y, c) = select(c == 0, 245,  // Red value
                                  c == 1, 42,   // Green value
                                  132);         // Blue value

    // 由于这种模式经常出现,Halide使用“mux”函数提供了一个syntactic sugar来编写上面的代码,如下所示.
    // color_image(x, y, c) = mux(c, {245, 42, 132});

    // 这种方法通常是方便的,因为它对这个Func的操作变得容易,并且对集合中的每个项都一视同仁:
    Func brighter;
    brighter(x, y, c) = color_image(x, y, c) + 10;

    // 然而,这种方法也不方便有三个原因.
    //
    // 1) Func是在一个无限域上定义的,因此该Func的用户可以访问例如color_image(x,y,-17),这不是一个有意义的值,可能表示有bug.
    //
    // 2) 它需要一个select,如果没有绑定和展开,它会影响性能:
    // brighter.bound(c, 0, 3).unroll(c);
    //
    // 3) 使用此方法,集合中的所有值必须具有相同的类型。虽然上述两个问题只是不方便,但这是一个硬性的限制,无法用这种方式表达某些东西.

    // 也可以将值的集合表示为func的集合:
    Func func_array[3];
    func_array[0](x, y) = x + y;
    func_array[1](x, y) = sin(x);
    func_array[2](x, y) = cos(y);

    // 这种方法避免了上述三个问题,但引入了一个新的烦恼。因为这些函数是独立的,所以很难对它们进行调度,以便它们都在x,y上的一个循环中一起计算.

    // 第三种方法是将Func定义为对元组求值,而不是对表达式求值。元组是表达式的固定大小集合。元组中的每个表达式可能有不同的类型。以下函数的计算结果为整数值(x+y)和浮点值(sin(x*y)).
    Func multi_valued;
    multi_valued(x, y) = Tuple(x + y, sin(x * y));

    // 实现一个元组值Func返回一个缓冲区集合。我们称之为实现。它相当于缓冲区对象的std::vector:
    {
        Realization r = multi_valued.realize(80, 60);
        assert(r.size() == 2);
        Buffer<int> im0 = r[0];
        Buffer<float> im1 = r[1];
        assert(im0(30, 40) == 30 + 40);
        assert(im1(30, 40) == sinf(30 * 40));
    }

    // 所有元组元素在同一循环嵌套中的同一域上一起计算,但存储在不同的分配中。上面的C++代码是:
    {
        int multi_valued_0[80 * 60];
        float multi_valued_1[80 * 60];
        for (int y = 0; y < 80; y++) {
            for (int x = 0; x < 60; x++) {
                multi_valued_0[x + 60 * y] = x + y;
                multi_valued_1[x + 60 * y] = sinf(x * y);
            }
        }
    }

    // 在提前编译时,元组值Func计算为多个不同的输出halide_buffer_t结构体。它们依次出现在函数签名的末尾:
    // int multi_valued(...input buffers and params...,
    //                  halide_buffer_t *output_1, halide_buffer_t *output_2);

    // 您可以通过向元组构造函数传递多个表达式来构造元组,就像我们上面所做的那样。也许更优雅,您还可以利用C++ 11初始化列表,只需在括号中包含ExpRs即可:
    Func multi_valued_2;
    multi_valued_2(x, y) = {x + y, sin(x * y)};

    // 对多值函数的调用不能视为表达式。以下是语法错误:
    // Func consumer;
    // consumer(x, y) = multi_valued_2(x, y) + 10;

    // 相反,您必须用方括号索引一个元组来检索各个表达式:
    Expr integer_part = multi_valued_2(x, y)[0];
    Expr floating_part = multi_valued_2(x, y)[1];
    Func consumer;
    consumer(x, y) = {integer_part + 10, floating_part + 10.0f};

    // 元组约化.
    {
        // 元组在归约中特别有用,因为它们允许归约在其域中运行时保持复杂状态。最简单的例子是argmax.

        // 首先,我们创建一个缓冲区来接管argmax.
        Func input_func;
        input_func(x) = sin(x);
        Buffer<float> input = input_func.realize(100);

        // 然后我们定义一个二值元组来跟踪最大值的索引和值本身.
        Func arg_max;

        // 纯定义.
        arg_max() = {0, input(0)};

        // 更新.
        RDom r(1, 99);
        Expr old_index = arg_max()[0];
        Expr old_max = arg_max()[1];
        Expr new_index = select(old_max < input(r), r, old_index);
        Expr new_max = max(input(r), old_max);
        arg_max() = {new_index, new_max};

        // 等效C:
        int arg_max_0 = 0;
        float arg_max_1 = input(0);
        for (int r = 1; r < 100; r++) {
            int old_index = arg_max_0;
            float old_max = arg_max_1;
            int new_index = old_max < input(r) ? r : old_index;
            float new_max = std::max(input(r), old_max);
            // 在元组更新定义中,所有的加载和计算都是在任何存储之前完成的,因此所有元组元素都是相对于对同一Func的递归调用进行原子更新的.
            arg_max_0 = new_index;
            arg_max_1 = new_max;
        }

        // 让我们验证halide和C++找到相同的最大值和索引.
        {
            Realization r = arg_max.realize();
            Buffer<int> r0 = r[0];
            Buffer<float> r1 = r[1];
            assert(arg_max_0 == r0(0));
            assert(arg_max_1 == r1(0));
        }

        // halide提供argmax和argmin作为内置函数,类似于总和、乘积、最大值和最小值。它们返回一个元组,元组由对应于该值的归约域中的点和值本身组成。对于一个tie,它们返回找到的第一个值。我们将在下一节中使用其中一个.
    }

    // 用户定义的元组.
    {
        // 元组也是表示复合对象(如复数)的方便方法。定义一个可以与元组进行转换的对象是用用户定义的类型扩展Halide的类型系统的一种方法.
        struct Complex {
            Expr real, imag;

            // 从元组构建
            Complex(Tuple t)
                : real(t[0]), imag(t[1]) {
            }

            // 从一对 Exprs构建
            Complex(Expr r, Expr i)
                : real(r), imag(i) {
            }

            // 通过将Func当作元组来构造对它的调用
            Complex(FuncRef t)
                : Complex(Tuple(t)) {
            }

            // 转换为元组
            operator Tuple() const {
                return {real, imag};
            }

            // 复合加法
            Complex operator+(const Complex &other) const {
                return {real + other.real, imag + other.imag};
            }

            // 复数乘法
            Complex operator*(const Complex &other) const {
                return {real * other.real - imag * other.imag,
                        real * other.imag + imag * other.real};
            }

            // 复振幅,效率平方
            Expr magnitude_squared() const {
                return real * real + imag * imag;
            }

            // 其他复杂的操作符会在这里。对于这个例子,上面的内容就足够了.
        };

        // 让我们使用复杂结构来计算一个Mandelbrot集.
        Func mandelbrot;

        // 函数中x,y坐标对应的初始复值.
        Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f);

        // 纯定义.
        Var t;
        mandelbrot(x, y, t) = Complex(0.0f, 0.0f);

        // 我们将使用更新定义来执行12个步骤.
        RDom r(1, 12);
        Complex current = mandelbrot(x, y, r - 1);

        // 下面一行使用我们上面定义的复数乘法和加法.
        mandelbrot(x, y, r) = current * current + initial;

        // 我们将使用另一个元组来计算迭代次数,其中值首先转义半径为4的圆。这可以表示为布尔表达式的argmin—我们希望第一次给定布尔表达式的索引为false(我们认为false小于true)。argmax将返回表达式第一次为真时的索引.

        Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f;
        Tuple first_escape = argmin(escape_condition);

        // 我们只需要索引,不需要值,但是argmin返回这两个值,所以我们将使用方括号索引argmin元组表达式,以获得表示索引的表达式.
        Func escape;
        escape(x, y) = first_escape[0];

        // 实现流水线并以ascii格式打印结果.
        Buffer<int> result = escape.realize(61, 25);
        const char *code = " .:-~*={}&%#@";
        for (int y = 0; y < result.height(); y++) {
            for (int x = 0; x < result.width(); x++) {
                printf("%c", code[result(x, y)]);
            }
            printf("\n");
        }
    }

    printf("Success!\n");

return 0;
}

第14章 类型系统

// 这一课更精确地描述halide的类型系统.
// linux, 
// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_14 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_14
// os x:// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_14 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_14
// 在源码树,可以运行
//    make tutorial_lesson_14_types
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// 此函数用于在本课程结束时演示通用代码.
Expr average(Expr a, Expr b);
int main(int argc, char **argv) {

    // 所有表达式都有一个标量类型,所有函数的计算结果都是一个或多个标量类型。Halide中的标量类型是各种位宽度的无符号整数、同一组位宽度的有符号整数、单精度和双精度的浮点数以及不透明句柄(相当于void*)。以下数组包含所有合法类型.

    Type valid_halide_types[] = {
        UInt(8), UInt(16), UInt(32), UInt(64),
        Int(8), Int(16), Int(32), Int(64),
        Float(32), Float(64), Handle()};

    // 构造和检查类型.
    {
        // 可以通过编程方式检查halide类型的属性。当编写具有ExPR参数的C++函数时,这是有用的,并且您希望检查它们的类型。:
        assert(UInt(8).bits() == 8);
        assert(Int(8).is_int());

        // 也可以通过编程方式将类型构造为其他类型的函数.
        Type t = UInt(8);
        t = t.with_bits(t.bits() * 2);
        assert(t == UInt(16));

        // 或者从C++标量类型构造类型
        assert(type_of<float>() == Float(32));

        // 类型结构也能够表示向量类型,但这是为Halide的内部使用而保留的。应该使用Func::vectorize对代码进行矢量化,而不是试图直接构造矢量表达式。如果以编程方式操作低halide代码,可能会遇到向量类型,但这是一个高级主题 (查看Func::add_custom_lowering_pass).

        // 您可以查询任何卤化物表达式的类型。Expr代表Var是Int(32)类型:
        Var x;
        assert(Expr(x).type() == Int(32));

        // halide中大多数transcendental 函数将输入类型转换为Float(32) 并返回Float(32):
        assert(sin(x).type() == Float(32));

        // 可以通过如下操作将 Expr从一个类型转换为另一个类型:
        assert(cast(UInt(8), x).type() == UInt(8));

        // 这也是以C++形式的模板形式出现的.
        assert(cast<uint8_t>(x).type() == UInt(8));

        // 您还可以查询任何已定义的Func以获取其生成的类型.
        Func f1;
        f1(x) = cast<uint8_t>(x);
        assert(f1.output_types()[0] == UInt(8));

        Func f2;
        f2(x) = {x, sin(x)};
        assert(f2.output_types()[0] == Int(32) &&
               f2.output_types()[1] == Float(32));
    }

    // 类型提升规则.
    {
        // 当您组合不同类型的表达式(例如使用“+”、“*”等)时,Halide使用类型提升规则系统。这些与C的规则不同。为了演示这些,我们将对每种类型进行一些表达式.
        Var x;
        Expr u8 = cast<uint8_t>(x);
        Expr u16 = cast<uint16_t>(x);
        Expr u32 = cast<uint32_t>(x);
        Expr u64 = cast<uint64_t>(x);
        Expr s8 = cast<int8_t>(x);
        Expr s16 = cast<int16_t>(x);
        Expr s32 = cast<int32_t>(x);
        Expr s64 = cast<int64_t>(x);
        Expr f32 = cast<float>(x);
        Expr f64 = cast<double>(x);

        // 规则如下所示,并按以下顺序应用.

        // 1) 对Handle()类型的表达式强制转换或使用算术运算符是错误的.

        // 2) 如果类型相同,则不会发生类型转换.
        for (Type t : valid_halide_types) {
            // 跳过句柄类型.
            if (t.is_handle()) continue;
            Expr e = cast(t, x);
            assert((e + e).type() == e.type());
        }

        // 3) 如果一个类型是float而另一个不是,那么non-float参数将提升为float(可能导致大整数的精度损失).
        assert((u8 + f32).type() == Float(32));
        assert((f32 + s64).type() == Float(32));
        assert((u16 + f64).type() == Float(64));
        assert((f64 + s32).type() == Float(64));

        // 4) 如果这两种类型都是float,则较窄的参数将提升为较宽的位宽度.
        assert((f64 + f32).type() == Float(64));

        // 上面的规则处理所有浮点情况。以下三条规则处理整数情况.

        // 5) 如果其中一个参数是C++ int,而另一个是halide::ExPR,则int被强制转换为表达式的类型。.
        assert((u32 + 3).type() == UInt(32));
        assert((3 + s16).type() == Int(16));

        // 如果此规则会导致整数溢出,则Halide将触发错误,例如,取消对以下行的注释将导致此程序以错误终止.
        // Expr bad = u8 + 257;

        // 6) 如果两种类型都是无符号整数,或者两种类型都是有符号整数,则较窄的参数将提升为较宽的类型.
        assert((u32 + u8).type() == UInt(32));
        assert((s16 + s64).type() == Int(64));

        // 7) 如果一种类型是有符号的,而另一种是无符号的,则两个参数都将提升为一个有符号整数,其宽度为两个位宽度中的较大值.
        assert((u8 + s32).type() == Int(32));
        assert((u32 + s8).type() == Int(32));

        // 注意,在比特宽度相同的情况下,这可能会悄悄地溢出无符号类型.
        assert((u32 + s32).type() == Int(32));

        // 以这种方式将无符号表达式转换为更宽的有符号类型时,首先将其扩展为更宽的无符号类型(零扩展),然后重新解释为有符号整数。即,将UInt(8)值255转换为Int(32)产生255,而不是-1.
        int32_t result32 = evaluate<int>(cast<int32_t>(cast<uint8_t>(255)));
        assert(result32 == 255);

        // 当使用强制转换运算符将有符号类型显式转换为更宽的无符号类型(类型提升规则不会自动执行此操作)时,首先将其转换为更宽的有符号类型(符号扩展),然后重新解释为无符号整数。即,将Int(8)值-1转换为UInt(16)产生65535,而不是255.
        uint16_t result16 = evaluate<uint16_t>(cast<uint16_t>(cast<int8_t>(-1)));
        assert(result16 == 65535);
    }

    // Handle()类型.
    {
        // 句柄用于表示不透明指针。将type_of应用于任何指针类型将返回Handle()
        assert(type_of<void *>() == Handle());
        assert(type_of<const char *const **>() == Handle());

        // 不管编译目标是什么,句柄始终存储为64位.
        assert(Handle().bits() == 64);

        // Handle类型的Expr的主要用途是将它通过halide传递给其他外部代码.
    }

    // 通用代码.
    {
        // Type在Halide中的主要显式用法是编写由Type参数化的Halide代码。在C++中,你可以用模板来完成这个操作。在halide中,不需要 —— 可以在C++运行时动态地检查和修改类型。下面定义的函数用来平均任意相等数值类型的两个表达式.
        Var x;
        assert(average(cast<float>(x), 3.0f).type() == Float(32));
        assert(average(x, 3).type() == Int(32));
        assert(average(cast<uint8_t>(x), cast<uint8_t>(3)).type() == UInt(8));
    }

    printf("Success!\n");

    return 0;}

Expr average(Expr a, Expr b) {
    // 类型必须匹配.
    assert(a.type() == b.type());

    // 对于浮点类型:
    if (a.type().is_float()) {
        // 由于上面的规则3,“2”将升级为浮点类型.
        return (a + b) / 2;
    }

    // 对于整数类型,我们必须在更宽的类型中计算中间值以避免溢出.
    Type narrow = a.type();
    Type wider = narrow.with_bits(narrow.bits() * 2);
    a = cast(wider, a);
    b = cast(wider, b);
    return cast(narrow, (a + b) / 2);
}

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值