Mojo 学习 —— 参数化:编译时元编程

Mojo 学习 —— 参数化:编译时元编程

介绍

所谓元编程,即编写可以生成或修改其他代码的代码。

Python 的装饰器、元类就是一种元编程。虽然非常灵活且高效,但由于它们是动态的,会带来更多的运行时开销。

而其他语言有静态或编译时元编程特性,比如 C 预处理器宏和 C++ 模板。这些功能存在一些局限性,且难以使用。

Mojo 为了提供功能强大、易于使用且运行成本为零的元编程。它添加了编译时参数(parameter),通过在编译时解析参数(类似于程序运行时解析的函数参数(argument)),使其在运行时是为常量,提高运行效率。

不同于其他语言,在 Mojoparameterparameter expression 指的是编译时值,而 argumentexpression 指的是运行时值。

为了区分这两个概念,在本章中我们都使用英文来表示。

参数化函数

要定义参数化函数,可以在函数 argument 前使用方括号来添加 parameter。其定义方式类似于 argument:在名称后面跟一个冒号和一个类型(必填)。例如

fn func[param: Dtype](args: Dtype):
    ...

来看一个具体的例子

fn repeat[count: Int](msg: String):
    @unroll
    for i in range(count):
        print(msg)

此处 @unroll 指令可在编译时展开循环。

调用参数化函数时,需要为 parameter 提供值,就像 argument 一样,例如

repeat[3]("Hello")

编译器会在编译时解析 parameter 的值,并为每个唯一的 parameter 创建一个具体版本的 repeat[]() 函数。

在解析了之后的 repeat[3]() 函数大致相当于:

fn repeat_3(msg: String):
    print(msg)
    print(msg)
    print(msg)

注意:

这并不代表编译器生成的实际代码。在解析 parameter 时,Mojo 代码已经转换为 MLIR 的中间表示形式。

如果编译器无法将所有 parameter 值解析为常量值,编译就会失败。

参数化结构体

您还可以为结构体添加 parameter。例如,通用数组类型可能包含如下代码:

struct GenericArray[T: AnyRegType]:
    var data: Pointer[T]
    var size: Int

    fn __init__(inout self, *elements: T):
        self.size = len(elements)
        self.data = Pointer[T].alloc(self.size)
        for i in range(self.size):
            self.data[i] = elements[i]

    fn __del__(owned self):
        self.data.free()

    fn __getitem__(self, i: Int) raises -> T:
        if (i < self.size):
            return self.data[i]
        else:
            raise Error("Out of bounds")

该结构体有一个名为 Tparameter,它是一个占位符,表示要存储在数组中的数据类型,有时也称为类型 parameterT 可以在结构体范围内使用。

T 的类型为 AnyRegType,这是一种元类型,代表任何可通过寄存器的类型。这意味着我们的 GenericArray 可以保存固定大小的数据类型,如整数和浮点数,但不能保存动态分配的数据,如字符串或向量。

注意:

使用大写的 T 作为 parameter 对编译器来说没有任何特殊意义(只是一个名称而已),但在许多语言中,使用简短的名称来表示类型参数是一种惯例,通常默认为 T

与参数化函数一样,使用参数化结构体时也需要传递参数值。在这种情况下,创建 GenericArray 实例时,需要指定要存储的类型,如 IntFloat64Mojo 种的类型是一个有效的编译时值)。

下面是一个使用 GenericArray 的示例:

var array = GenericArray[Int](1, 2, 3, 4)
for i in range(array.size):
    print(array[i], sep=" ", end="")
# 1 2 3 4

参数化结构体可以使用 Self 类型来表示结构体的具体实例。例如,你可以为 GenericArray 添加一个静态工厂方法,其签名如下:

struct GenericArray[T: AnyRegType]:
    ...

    @staticmethod
    fn splat(count: Int, value: T) -> Self:
        ...

在这里,Self 相当于编写 GenericArray[T]。也就是说,你可以像这样调用 splat 方法:

GenericArray[Float64].splat(8, 0)

该方法返回一个 GenericArray[Float64] 的实例。

参数重载

函数和方法的 parameter 也可以重载,重载解析逻辑根据以下规则(按优先顺序排列):

  1. 隐式转换最少
  2. 没有可变的 arguments
  3. 没有可变的 parameters
  4. parameter 数量最少
  5. 不是 @staticmethod

如果应用这些规则后,候选者不止一个,则重载解析失败。例如

@register_passable("trivial")
struct MyInt:
    """A type that is implicitly convertible to `Int`."""
    var value: Int

    @always_inline("nodebug")
    fn __init__(_a: Int) -> Self:
        return Self {value: _a}

fn foo[x: MyInt, a: Int]():
    print("foo[x: MyInt, a: Int]()")

fn foo[x: MyInt, y: MyInt]():
    print("foo[x: MyInt, y: MyInt]()")

fn bar[a: Int](b: Int):
    print("bar[a: Int](b: Int)")

fn bar[a: Int](*b: Int):
    print("bar[a: Int](*b: Int)")

fn bar[*a: Int](b: Int):
    print("bar[*a: Int](b: Int)")

fn parameter_overloads[a: Int, b: Int, x: MyInt]():
    foo[x, a]()
    bar[a](b)
    bar[a, a, a](b)

struct MyStruct:
    fn __init__(inout self):
        pass

    fn foo(inout self):
        print("calling instance menthod")

    @staticmethod
    fn foo():
        print("calling static method")

fn test_static_overload():
    var a = MyStruct()
    a.foo()

fn main():
    parameter_overloads[1, 2, MyInt(3)]()
    test_static_overload()
foo[x: MyInt, a: Int]()
bar[a: Int](b: Int)
bar[*a: Int](b: Int)
calling instance menthod

使用参数化类型和函数

通过给方括号中的 parameter 传递值,可以实例化参数类型和函数。

例如,对于 SIMD 类型

struct SIMD[type: DType, size: Int]: ...

type 指定了数据类型,size 指定 SIMD 向量的长度(必须是 2 的幂次):

# Make a vector of 4 floats.
var small_vec = SIMD[DType.float32, 4](1.0, 2.0, 3.0, 4.0)
# Make a big vector containing 1.0 in float16 format.
var big_vec = SIMD[DType.float16, 32].splat(1.0)
# Do some math and convert the elements to float32.
var bigger_vec = (big_vec+big_vec).cast[DType.float32]()
# You can write types out explicitly if you want of course.
var bigger_vec2 : SIMD[DType.float32, 32] = bigger_vec

print('small_vec type:', small_vec.element_type, 'length:', len(small_vec))
print('bigger_vec2 type:', bigger_vec2.element_type, 'length:', len(bigger_vec2))
small_vec type: float32 length: 4
bigger_vec2 type: float32 length: 32

需要注意的是,cast 方法还需要一个参数来指定你想要转换的类型(即需要一个目标类型)。

我们还可以使用 parameter 参数化来定义针对 SIMD 但与类型和宽度无关的算法:

from math import sqrt

fn rsqrt[dt: DType, width: Int](x: SIMD[dt, width]) -> SIMD[dt, width]:
    return 1 / sqrt(x)

var v = SIMD[DType.float16, 4](42)
print(rsqrt(v))

x 参数实际上是基于函数 rsqrtparameter 创建的 SIMD 类型。运行时程序可以使用编译时参数的值,但编译时参数表达式不能使用运行时值。

Mojo 编译器会根据传入的参数 x 值推断其 parameter,就像明确写入 rsqrt[dt, width](x) 一样。

还有就是,rsqrt 选择将第二个 parameter 命名为 width,而 SIMD 类型将其命名为 size,也没有问题。

可选和关键字参数

parameter 类似于 argument,也支持可选或关键字形式。例如

fn speak[a: Int = 3, msg: StringLiteral = "woof"]():
    print(msg, a)

fn use_defaults() raises:
    speak()             # prints 'woof 3'
    speak[5]()          # prints 'woof 5'
    speak[7, "meow"]()  # prints 'meow 7'
    speak[msg="baaa"]() # prints 'baaa 3'

speak 函数的两个 parametr 都有默认值,可以两个都不指定、指定一个或以关键字的形式指定 parameter 的值

当调用函数时,Mojo 可以推断出参数值。也就是说,它可以从 argument 的类型反推出 parameter 的类型。例如:

@value
struct Bar[v: Int]:
    pass

fn foo[a: Int = 3, msg: StringLiteral = "woof"](bar: Bar[a]):
    print(msg, a)

fn use_inferred():
    foo(Bar[9]())  # prints 'woof 9'

虽然 foo 函数的 parametera, 具有默认值,但 Mojo 使用了从 argumentbar 中推断出的 a 的值(实际上 a 的默认值永远不会被使用,它是由 bar 来决定的)。

我们还可以在 struct 中使用可选和关键字 parameter。例如

struct KwParamStruct[greeting: String = "Hello", name: String = "🔥mojo🔥"]:
    fn __init__(inout self):
        print(greeting, name)

fn use_kw_params():
    var a = KwParamStruct[]()                 # prints 'Hello 🔥mojo🔥'
    var b = KwParamStruct[name="World"]()     # prints 'Hello World'
    var c = KwParamStruct[greeting="Hola"]()  # prints 'Hola 🔥mojo🔥'

注意

Mojo 支持仅位置和仅关键字 parameter,其规则与 argument 相同。

可变参数

Mojo 还支持可变 parameter,与可变 argument 类似:

struct MyTensor[*dimensions: Int]:
    pass

目前,可变 parameter 有一些可变 argument 没有的限制:

  • 可变 parameter 必须是同质的,即所有值必须是同一类型。
  • parameter 的类型必须是寄存器可传递类型。
  • parameter 的值不会自动转换为 VariadicList,因此需要手动创建。例如
fn sum_params[*values: Int]() -> Int:
    alias list = VariadicList(values)
    var sum = 0
    for v in list:
        sum += v
    return sum

目前还不支持可变关键字 parameter(例如 **kwparams)。

参数表达式

参数表达式(parameter)可以是任何 Mojo 代码表达式,可以把 parameter 当做是函数的 argument 一样使用(与运行时代码具有相同的语法和类型),支持函数调用以及算术操作。

例如,您可能想定义一个辅助函数来连接两个 SIMD 向量:

fn concat[ty: DType, len1: Int, len2: Int](
        lhs: SIMD[ty, len1], rhs: SIMD[ty, len2]) -> SIMD[ty, len1+len2]:

    var result = SIMD[ty, len1 + len2]()
    for i in range(len1):
        result[i] = SIMD[ty, 1](lhs[i])
    for j in range(len2):
        result[len1 + j] = SIMD[ty, 1](rhs[j])
    return result

var a = SIMD[DType.float32, 2](1, 2)
var x = concat[DType.float32, 2, 2](a, a)

print('result type:', x.element_type, 'length:', len(x))
result type: float32 length: 4

请注意,结果长度是输入向量长度的总和,可以用简单的 + 运算来表示。

编译时编程

虽然简单的表达式很有用,但有时你想编写更复杂的编译时逻辑,例如循环和递归等。

下面是一个 “tree reduction” 算法示例,递归求和:

fn slice[ty: DType, new_size: Int, size: Int](
        x: SIMD[ty, size], offset: Int) -> SIMD[ty, new_size]:
    var result = SIMD[ty, new_size]()
    for i in range(new_size):
        result[i] = SIMD[ty, 1](x[i + offset])
    return result

fn reduce_add[ty: DType, size: Int](x: SIMD[ty, size]) -> Int:
    @parameter
    if size == 1:
        return int(x[0])
    elif size == 2:
        return int(x[0]) + int(x[1])

    alias half_size = size // 2
    var lhs = slice[ty, half_size, size](x, 0)
    var rhs = slice[ty, half_size, size](x, half_size)
    return reduce_add[ty, half_size](lhs + rhs)

var x = SIMD[DType.index, 4](1, 2, 3, 4)
print(x)
print("Elements sum:", reduce_add(x))

它利用 @parameter 装饰器创建了一个参数 if 条件,确保只有 if 语句的有效分支被编译到程序中。

类型就是参数表达式

不仅类型中可以使用参数表达式,类型注释也可以是任意表达式。Mojo 中的类型有一种特殊的元类型,允许定义类型参数算法和函数。

例如,我们可以创建一个支持任意寄存器可传递元素类型的简化数组(通过 AnyRegType 参数):

struct Array[T: AnyRegType]:
    var data: Pointer[T]
    var size: Int

    fn __init__(inout self, size: Int, value: T):
        self.size = size
        self.data = Pointer[T].alloc(self.size)
        for i in range(self.size):
            self.data[i] = value

    fn __getitem__(self, i: Int) -> T:
        return self.data[i]

    fn __del__(owned self):
        self.data.free()

var v = Array[Float32](4, 3.14)
print(v[0], v[1], v[2], v[3])

请注意,T 即被用作 value argument 的类型,也被用于 __getitem__ 方法的返回类型。

还有许多其他情况可以从更高级的参数使用中获益。例如,你可以并行执行闭包 N

fn parallelize[func: fn (Int) -> None](num_work_items: Int):
    for i in range(num_work_items):
        func(i)

另一个重要的例子是可变泛型,在这种情况下,算法或数据结构可能需要通过异构类型列表(如元组)来定义。

现在,Mojo 还不完全支持这种情况,需要使用 MLIR 来编写。未来,将可以使用纯 Mojo 中实现。

别名(alias)

var 定义的是运行时值,而有时候我们需要一种定义编译时临时值的方法。Mojo 提供了 alias 来进行声明。

例如,DType 结构使用 alias 实现了一个简单的枚举(实际的 DType 实现细节略有不同):

struct DType:
    var value : UI8
    alias invalid = DType(0)
    alias bool = DType(1)
    alias int8 = DType(2)
    alias uint8 = DType(3)
    alias int16 = DType(4)
    alias int16 = DType(5)
    ...
    alias float32 = DType(15)

这样,就可以自然地将 DType.float32 用作参数表达式(也可用作运行时值)。

因为类型是编译时表达式,所以可以定义类型的别名:

alias Float16 = SIMD[DType.float16, 1]
alias UInt8 = SIMD[DType.uint8, 1]

var x : Float16   # Float16 works like a "typedef"

var 变量一样,别名也服从作用域规则,因此可以在函数中定义局部别名。

类型绑定

为所有 parameter 指定参数类型被称为完全绑定类型。也就是说,它的所有参数都与值绑定。

前面的例子中,你只能实例化一个完全绑定的类型。不过,在某些情况下可以不绑定或部分绑定。

例如,可以使用别名部分绑定类型,用来创建一个需要较少参数的新类型:

alias Bytes = SIMD[DType.uint8, _]
var b = Bytes[8]()

在这里,BytesSIMD 字节向量的类型别名。参数列表中的下划线 _ 表示第二个参数(宽度)是非绑定的。可以在后面使用 Bytes 时再指定宽度参数。

例如,给定以下类型

struct MyType[s: String, i: Int, i2: Int, b: Bool = True]:
    pass

它可以以下列形式出现在代码中:

  • 完全绑定,指定了所有 parameter
MyType["Hello", 3, 4, True]
  • 部分绑定,已指定部分但非全部 parameter
MyType["Hola", _, _, True]
  • 不绑定,未指定任何 parameter
MyType[_, _, _, _]

您还可以使用表达式 *_ 在参数列表末尾取消绑定任意数量的位置参数。例如

# These two types are equivalent
MyType["Hello", *_]
MyType["Hello", _, _, _]

当使用 _*_ 表达式显式为一个 parameter 解绑时,必须为该 parameter 指定一个值才能使用该类型。原始类型声明中的默认值将被忽略。

部分绑定和未绑定参数类型可在某些情况下使用,接着看下面的例子

省略参数

Mojo 还支持不绑定 parameter 的另一种格式,即在表达式中省略参数:

# Partially bound
MyType["Hi there"]
# Unbound
MyType

这种格式与上述显式解绑语法不同,省略参数的默认值会立即绑定。例如,以下表达式是等价的:

MyType["Hi there"]
# equivalent to
MyType["Hi there", _, _, True] # Uses the default value for `b`

目前支持这种格式是为了向后兼容。预计在未来淘汰这种格式,而使用显式解除绑定语法。

自动参数化

Mojo 支持函数的 “自动” 参数化,如果函数的 argument 类型是部分绑定或不绑定类型,未绑定的 parameter 会自动添加为函数的 parameter。举个例子:

fn print_params(vec: SIMD[*_]):
    print(vec.type)
    print(vec.size)

var v = SIMD[DType.float64, 4](1.0, 2.0, 3.0, 4.0)
print_params(v)

在上例中,print_params 函数已自动参数化。vec 是一个 SIMD[*_] 类型的 argument。这是一个未绑定参数的类型,也就是说,它没有为该类型指定任何参数值。

Mojovec 的不绑定的 parameter 视为函数的隐式 parameter。相当于下面的代码:

fn print_params[t: DType, s: Int](vec: SIMD[t, s]):
    print(vec.type)
    print(vec.size)

调用 print_params 时,必须向它传递 SIMD 类型的具体实例,即一个指定了所有 parameter 的实例,如 SIMD[DType.float64, 4]Mojo 编译器会根据输入 argument 推导出 parameter 的值。

对于手动参数化函数,可以通过名称访问输入参数(ts)。而对于自动参数化函数,则可以将 parameter 作为 argument 的属性来访问(例如 vec.type)。

这种访问类型输入 parameter 的功能并不是自动参数化函数所特有的,可以在任何地方使用它。

您可以将参数化类型的输入 parameter 作为类型本身的属性来访问:

fn on_type():
    print(SIMD[DType.float32, 2].size) # prints 2

或作为该类型实例的属性:

fn on_instance():
    var x = SIMD[DType.int32, 2](4, 8)
    print(x.type) # prints int32

您甚至可以在函数签名中使用这种语法,根据 argumentparameter 定义函数的 argument 和返回类型。

例如,如果您想让函数接受两个类型和大小相同的 SIMD 向量,可以这样:

fn interleave(v1: SIMD, v2: __type_of(v1)) -> SIMD[v1.type, v1.size*2]:
    var result = SIMD[v1.type, v1.size*2]()
    for i in range(v1.size):
        result[i*2] = SIMD[v1.type, 1](v1[i])
        result[i*2+1] = SIMD[v1.type, 1](v2[i])
    return result

var a = SIMD[DType.int16, 4](1, 2, 3, 4)
var b = SIMD[DType.int16, 4](0, 0, 0, 0)
var c = interleave(a, b)
print(c)

如果只想匹配 argument 的类型,可以使用神奇的 __type_of。在这种情况下,它比写等价的 SIMD[v1.type, v1.size] 更方便、更紧凑。

部分绑定类型自动参数化

Mojo 还支持部分绑定类型的自动参数化。例如,假设我们有一个带有三个 parameterFudge 结构:

@value
struct Fudge[sugar: Int, cream: Int, chocolate: Int = 7](Stringable):
    fn __str__(self) -> String:
        var values = StaticIntTuple[3](sugar, cream, chocolate)
        return str("Fudge") + values

我们可以定义一个绑定一个 parameter(部分绑定)的函数:

fn eat(f: Fudge[5, *_]):
    print("Ate " + str(f))

eat 函数接收一个 Fudge 结构,其中第一个 parametersugar)的值绑定为 5。第二和第三个参数(creamchocolate)未绑定。

未绑定的 creamchocolate 成为 eat 函数的隐式输入 parameter。相当于

fn eat[cr: Int, ch: Int](f: Fudge[5, cr, ch]):
    print("Ate " + str(f))

我们可以通过传入一个绑定了creamchocolate 的实例来调用函数:

eat(Fudge[5, 5, 7]())
eat(Fudge[5, 8, 9]())

但是如果 sugar 的值不等于 5,那么编译会失败,因为它与参数类型不匹配:

eat(Fudge[12, 5, 7]()) 
# ERROR: invalid call to 'eat': argument #0 cannot be converted from 'Fudge[12, 5, 7]' to 'Fudge[5, 5, 7]'

您还可以更自由地指定未绑定的 parameter。例如,只绑定 cream

fn devour(f: Fudge[_, 6, _]):
    print(str("Devoured ") + str(f))

相当于下面的代码

fn devour[su: Int, ch: Int](f: Fudge[su, 6, ch]):
        print(str("Devoured ") + str(f))

您也可以通过关键字指定 parameter,或混合使用位置和关键字 parameter

fn devour(f: Fudge[_, chocolate=_, cream=6]):
    print(str("Devoured ") + str(f))

所有三个版本的 devour 函数都可以使用如下方式调用:

devour(Fudge[3, 6, 9]())
devour(Fudge[4, 6, 8]())

省略参数

您也可以通过省略 parameter 来指定不绑定或部分绑定类型:例如:

fn nibble(f: Fudge[5]):
    print("Ate " + str(f))

在这里,Fudge[5] 的工作原理与 Fudge[5, *_] 相同,只是在处理带默认值的参数时不同。

Fudge[5] 并不会丢弃 chocolate 的默认值,而是立即绑定默认值,使其等价于 Fudge[5, _, 7]

如果部使用 chocolate 的默认值将会报错

nibble(Fudge[5, 5, 9]())
# ERROR: invalid call to 'eat': argument #0 cannot be converted from 'Fudge[5, 5, 9]' to 'Fudge[5, 5, 7]'

对与省略不绑定参数的方式最终将会被弃用,推荐使用 _*_ 来显式地解绑参数。

  • 33
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

名本无名

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值