本文是halide编程指南的连载,已同步至公众号
第13章 元组
// 本课程介绍如何编写求多个值的函数.
// 在linux系统, 按如下编译运行:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13
// 在 os x:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_13 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13
// 如果你有halide的源码,也可以在源码的最顶层目录,这样:
// make tutorial_lesson_13_tuples
#include "Halide.h"
#include <algorithm>
#include <stdio.h>
using namespace Halide;
int main(int argc, char **argv) {
//到目前为止,Funcs(如下面的函数)已经为其域中的每个点计算了一个标量值
Func single_valued;
Var x, y;
single_valued(x, y) = x + y;
// 编写返回值集合的Func的一种方法是给返回值添加索引。这就是我们通常处理颜色的方式。例如,下面的Func表示由c索引的每个x,y坐标的三个值的集合.
Func color_image;
Var c;
color_image(x, y, c) = select(c == 0, 245, // Red value
c == 1, 42, // Green value
132); // Blue value
// 由于这种模式经常出现,Halide使用“mux”函数提供了一个syntactic sugar来编写上面的代码,如下所示.
// color_image(x, y, c) = mux(c, {245, 42, 132});
// 这种方法通常是方便的,因为它对这个Func的操作变得容易,并且对集合中的每个项都一视同仁:
Func brighter;
brighter(x, y, c) = color_image(x, y, c) + 10;
// 然而,这种方法也不方便有三个原因.
//
// 1) Func是在一个无限域上定义的,因此该Func的用户可以访问例如color_image(x,y,-17),这不是一个有意义的值,可能表示有bug.
//
// 2) 它需要一个select,如果没有绑定和展开,它会影响性能:
// brighter.bound(c, 0, 3).unroll(c);
//
// 3) 使用此方法,集合中的所有值必须具有相同的类型。虽然上述两个问题只是不方便,但这是一个硬性的限制,无法用这种方式表达某些东西.
// 也可以将值的集合表示为func的集合:
Func func_array[3];
func_array[0](x, y) = x + y;
func_array[1](x, y) = sin(x);
func_array[2](x, y) = cos(y);
// 这种方法避免了上述三个问题,但引入了一个新的烦恼。因为这些函数是独立的,所以很难对它们进行调度,以便它们都在x,y上的一个循环中一起计算.
// 第三种方法是将Func定义为对元组求值,而不是对表达式求值。元组是表达式的固定大小集合。元组中的每个表达式可能有不同的类型。以下函数的计算结果为整数值(x+y)和浮点值(sin(x*y)).
Func multi_valued;
multi_valued(x, y) = Tuple(x + y, sin(x * y));
// 实现一个元组值Func返回一个缓冲区集合。我们称之为实现。它相当于缓冲区对象的std::vector:
{
Realization r = multi_valued.realize(80, 60);
assert(r.size() == 2);
Buffer<int> im0 = r[0];
Buffer<float> im1 = r[1];
assert(im0(30, 40) == 30 + 40);
assert(im1(30, 40) == sinf(30 * 40));
}
// 所有元组元素在同一循环嵌套中的同一域上一起计算,但存储在不同的分配中。上面的C++代码是:
{
int multi_valued_0[80 * 60];
float multi_valued_1[80 * 60];
for (int y = 0; y < 80; y++) {
for (int x = 0; x < 60; x++) {
multi_valued_0[x + 60 * y] = x + y;
multi_valued_1[x + 60 * y] = sinf(x * y);
}
}
}
// 在提前编译时,元组值Func计算为多个不同的输出halide_buffer_t结构体。它们依次出现在函数签名的末尾:
// int multi_valued(...input buffers and params...,
// halide_buffer_t *output_1, halide_buffer_t *output_2);
// 您可以通过向元组构造函数传递多个表达式来构造元组,就像我们上面所做的那样。也许更优雅,您还可以利用C++ 11初始化列表,只需在括号中包含ExpRs即可:
Func multi_valued_2;
multi_valued_2(x, y) = {x + y, sin(x * y)};
// 对多值函数的调用不能视为表达式。以下是语法错误:
// Func consumer;
// consumer(x, y) = multi_valued_2(x, y) + 10;
// 相反,您必须用方括号索引一个元组来检索各个表达式:
Expr integer_part = multi_valued_2(x, y)[0];
Expr floating_part = multi_valued_2(x, y)[1];
Func consumer;
consumer(x, y) = {integer_part + 10, floating_part + 10.0f};
// 元组约化.
{
// 元组在归约中特别有用,因为它们允许归约在其域中运行时保持复杂状态。最简单的例子是argmax.
// 首先,我们创建一个缓冲区来接管argmax.
Func input_func;
input_func(x) = sin(x);
Buffer<float> input = input_func.realize(100);
// 然后我们定义一个二值元组来跟踪最大值的索引和值本身.
Func arg_max;
// 纯定义.
arg_max() = {0, input(0)};
// 更新.
RDom r(1, 99);
Expr old_index = arg_max()[0];
Expr old_max = arg_max()[1];
Expr new_index = select(old_max < input(r), r, old_index);
Expr new_max = max(input(r), old_max);
arg_max() = {new_index, new_max};
// 等效C:
int arg_max_0 = 0;
float arg_max_1 = input(0);
for (int r = 1; r < 100; r++) {
int old_index = arg_max_0;
float old_max = arg_max_1;
int new_index = old_max < input(r) ? r : old_index;
float new_max = std::max(input(r), old_max);
// 在元组更新定义中,所有的加载和计算都是在任何存储之前完成的,因此所有元组元素都是相对于对同一Func的递归调用进行原子更新的.
arg_max_0 = new_index;
arg_max_1 = new_max;
}
// 让我们验证halide和C++找到相同的最大值和索引.
{
Realization r = arg_max.realize();
Buffer<int> r0 = r[0];
Buffer<float> r1 = r[1];
assert(arg_max_0 == r0(0));
assert(arg_max_1 == r1(0));
}
// halide提供argmax和argmin作为内置函数,类似于总和、乘积、最大值和最小值。它们返回一个元组,元组由对应于该值的归约域中的点和值本身组成。对于一个tie,它们返回找到的第一个值。我们将在下一节中使用其中一个.
}
// 用户定义的元组.
{
// 元组也是表示复合对象(如复数)的方便方法。定义一个可以与元组进行转换的对象是用用户定义的类型扩展Halide的类型系统的一种方法.
struct Complex {
Expr real, imag;
// 从元组构建
Complex(Tuple t)
: real(t[0]), imag(t[1]) {
}
// 从一对 Exprs构建
Complex(Expr r, Expr i)
: real(r), imag(i) {
}
// 通过将Func当作元组来构造对它的调用
Complex(FuncRef t)
: Complex(Tuple(t)) {
}
// 转换为元组
operator Tuple() const {
return {real, imag};
}
// 复合加法
Complex operator+(const Complex &other) const {
return {real + other.real, imag + other.imag};
}
// 复数乘法
Complex operator*(const Complex &other) const {
return {real * other.real - imag * other.imag,
real * other.imag + imag * other.real};
}
// 复振幅,效率平方
Expr magnitude_squared() const {
return real * real + imag * imag;
}
// 其他复杂的操作符会在这里。对于这个例子,上面的内容就足够了.
};
// 让我们使用复杂结构来计算一个Mandelbrot集.
Func mandelbrot;
// 函数中x,y坐标对应的初始复值.
Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f);
// 纯定义.
Var t;
mandelbrot(x, y, t) = Complex(0.0f, 0.0f);
// 我们将使用更新定义来执行12个步骤.
RDom r(1, 12);
Complex current = mandelbrot(x, y, r - 1);
// 下面一行使用我们上面定义的复数乘法和加法.
mandelbrot(x, y, r) = current * current + initial;
// 我们将使用另一个元组来计算迭代次数,其中值首先转义半径为4的圆。这可以表示为布尔表达式的argmin—我们希望第一次给定布尔表达式的索引为false(我们认为false小于true)。argmax将返回表达式第一次为真时的索引.
Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f;
Tuple first_escape = argmin(escape_condition);
// 我们只需要索引,不需要值,但是argmin返回这两个值,所以我们将使用方括号索引argmin元组表达式,以获得表示索引的表达式.
Func escape;
escape(x, y) = first_escape[0];
// 实现流水线并以ascii格式打印结果.
Buffer<int> result = escape.realize(61, 25);
const char *code = " .:-~*={}&%#@";
for (int y = 0; y < result.height(); y++) {
for (int x = 0; x < result.width(); x++) {
printf("%c", code[result(x, y)]);
}
printf("\n");
}
}
printf("Success!\n");
return 0;
}
第14章 类型系统
// 这一课更精确地描述halide的类型系统.
// linux,
// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_14 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_14
// os x:// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_14 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_14
// 在源码树,可以运行
// make tutorial_lesson_14_types
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// 此函数用于在本课程结束时演示通用代码.
Expr average(Expr a, Expr b);
int main(int argc, char **argv) {
// 所有表达式都有一个标量类型,所有函数的计算结果都是一个或多个标量类型。Halide中的标量类型是各种位宽度的无符号整数、同一组位宽度的有符号整数、单精度和双精度的浮点数以及不透明句柄(相当于void*)。以下数组包含所有合法类型.
Type valid_halide_types[] = {
UInt(8), UInt(16), UInt(32), UInt(64),
Int(8), Int(16), Int(32), Int(64),
Float(32), Float(64), Handle()};
// 构造和检查类型.
{
// 可以通过编程方式检查halide类型的属性。当编写具有ExPR参数的C++函数时,这是有用的,并且您希望检查它们的类型。:
assert(UInt(8).bits() == 8);
assert(Int(8).is_int());
// 也可以通过编程方式将类型构造为其他类型的函数.
Type t = UInt(8);
t = t.with_bits(t.bits() * 2);
assert(t == UInt(16));
// 或者从C++标量类型构造类型
assert(type_of<float>() == Float(32));
// 类型结构也能够表示向量类型,但这是为Halide的内部使用而保留的。应该使用Func::vectorize对代码进行矢量化,而不是试图直接构造矢量表达式。如果以编程方式操作低halide代码,可能会遇到向量类型,但这是一个高级主题 (查看Func::add_custom_lowering_pass).
// 您可以查询任何卤化物表达式的类型。Expr代表Var是Int(32)类型:
Var x;
assert(Expr(x).type() == Int(32));
// halide中大多数transcendental 函数将输入类型转换为Float(32) 并返回Float(32):
assert(sin(x).type() == Float(32));
// 可以通过如下操作将 Expr从一个类型转换为另一个类型:
assert(cast(UInt(8), x).type() == UInt(8));
// 这也是以C++形式的模板形式出现的.
assert(cast<uint8_t>(x).type() == UInt(8));
// 您还可以查询任何已定义的Func以获取其生成的类型.
Func f1;
f1(x) = cast<uint8_t>(x);
assert(f1.output_types()[0] == UInt(8));
Func f2;
f2(x) = {x, sin(x)};
assert(f2.output_types()[0] == Int(32) &&
f2.output_types()[1] == Float(32));
}
// 类型提升规则.
{
// 当您组合不同类型的表达式(例如使用“+”、“*”等)时,Halide使用类型提升规则系统。这些与C的规则不同。为了演示这些,我们将对每种类型进行一些表达式.
Var x;
Expr u8 = cast<uint8_t>(x);
Expr u16 = cast<uint16_t>(x);
Expr u32 = cast<uint32_t>(x);
Expr u64 = cast<uint64_t>(x);
Expr s8 = cast<int8_t>(x);
Expr s16 = cast<int16_t>(x);
Expr s32 = cast<int32_t>(x);
Expr s64 = cast<int64_t>(x);
Expr f32 = cast<float>(x);
Expr f64 = cast<double>(x);
// 规则如下所示,并按以下顺序应用.
// 1) 对Handle()类型的表达式强制转换或使用算术运算符是错误的.
// 2) 如果类型相同,则不会发生类型转换.
for (Type t : valid_halide_types) {
// 跳过句柄类型.
if (t.is_handle()) continue;
Expr e = cast(t, x);
assert((e + e).type() == e.type());
}
// 3) 如果一个类型是float而另一个不是,那么non-float参数将提升为float(可能导致大整数的精度损失).
assert((u8 + f32).type() == Float(32));
assert((f32 + s64).type() == Float(32));
assert((u16 + f64).type() == Float(64));
assert((f64 + s32).type() == Float(64));
// 4) 如果这两种类型都是float,则较窄的参数将提升为较宽的位宽度.
assert((f64 + f32).type() == Float(64));
// 上面的规则处理所有浮点情况。以下三条规则处理整数情况.
// 5) 如果其中一个参数是C++ int,而另一个是halide::ExPR,则int被强制转换为表达式的类型。.
assert((u32 + 3).type() == UInt(32));
assert((3 + s16).type() == Int(16));
// 如果此规则会导致整数溢出,则Halide将触发错误,例如,取消对以下行的注释将导致此程序以错误终止.
// Expr bad = u8 + 257;
// 6) 如果两种类型都是无符号整数,或者两种类型都是有符号整数,则较窄的参数将提升为较宽的类型.
assert((u32 + u8).type() == UInt(32));
assert((s16 + s64).type() == Int(64));
// 7) 如果一种类型是有符号的,而另一种是无符号的,则两个参数都将提升为一个有符号整数,其宽度为两个位宽度中的较大值.
assert((u8 + s32).type() == Int(32));
assert((u32 + s8).type() == Int(32));
// 注意,在比特宽度相同的情况下,这可能会悄悄地溢出无符号类型.
assert((u32 + s32).type() == Int(32));
// 以这种方式将无符号表达式转换为更宽的有符号类型时,首先将其扩展为更宽的无符号类型(零扩展),然后重新解释为有符号整数。即,将UInt(8)值255转换为Int(32)产生255,而不是-1.
int32_t result32 = evaluate<int>(cast<int32_t>(cast<uint8_t>(255)));
assert(result32 == 255);
// 当使用强制转换运算符将有符号类型显式转换为更宽的无符号类型(类型提升规则不会自动执行此操作)时,首先将其转换为更宽的有符号类型(符号扩展),然后重新解释为无符号整数。即,将Int(8)值-1转换为UInt(16)产生65535,而不是255.
uint16_t result16 = evaluate<uint16_t>(cast<uint16_t>(cast<int8_t>(-1)));
assert(result16 == 65535);
}
// Handle()类型.
{
// 句柄用于表示不透明指针。将type_of应用于任何指针类型将返回Handle()
assert(type_of<void *>() == Handle());
assert(type_of<const char *const **>() == Handle());
// 不管编译目标是什么,句柄始终存储为64位.
assert(Handle().bits() == 64);
// Handle类型的Expr的主要用途是将它通过halide传递给其他外部代码.
}
// 通用代码.
{
// Type在Halide中的主要显式用法是编写由Type参数化的Halide代码。在C++中,你可以用模板来完成这个操作。在halide中,不需要 —— 可以在C++运行时动态地检查和修改类型。下面定义的函数用来平均任意相等数值类型的两个表达式.
Var x;
assert(average(cast<float>(x), 3.0f).type() == Float(32));
assert(average(x, 3).type() == Int(32));
assert(average(cast<uint8_t>(x), cast<uint8_t>(3)).type() == UInt(8));
}
printf("Success!\n");
return 0;}
Expr average(Expr a, Expr b) {
// 类型必须匹配.
assert(a.type() == b.type());
// 对于浮点类型:
if (a.type().is_float()) {
// 由于上面的规则3,“2”将升级为浮点类型.
return (a + b) / 2;
}
// 对于整数类型,我们必须在更宽的类型中计算中间值以避免溢出.
Type narrow = a.type();
Type wider = narrow.with_bits(narrow.bits() * 2);
a = cast(wider, a);
b = cast(wider, b);
return cast(narrow, (a + b) / 2);
}