Mojo 学习 —— 参数化:编译时元编程
文章目录
介绍
所谓元编程,即编写可以生成或修改其他代码的代码。
Python
的装饰器、元类就是一种元编程。虽然非常灵活且高效,但由于它们是动态的,会带来更多的运行时开销。
而其他语言有静态或编译时元编程特性,比如 C
预处理器宏和 C++
模板。这些功能存在一些局限性,且难以使用。
Mojo
为了提供功能强大、易于使用且运行成本为零的元编程。它添加了编译时参数(parameter
),通过在编译时解析参数(类似于程序运行时解析的函数参数(argument
)),使其在运行时是为常量,提高运行效率。
不同于其他语言,在 Mojo
种 parameter
和 parameter expression
指的是编译时值,而 argument
和 expression
指的是运行时值。
为了区分这两个概念,在本章中我们都使用英文来表示。
参数化函数
要定义参数化函数,可以在函数 argument
前使用方括号来添加 parameter
。其定义方式类似于 argument
:在名称后面跟一个冒号和一个类型(必填)。例如
fn func[param: Dtype](args: Dtype):
...
来看一个具体的例子
fn repeat[count: Int](msg: String):
@unroll
for i in range(count):
print(msg)
此处
@unroll
指令可在编译时展开循环。
调用参数化函数时,需要为 parameter
提供值,就像 argument
一样,例如
repeat[3]("Hello")
编译器会在编译时解析 parameter
的值,并为每个唯一的 parameter
创建一个具体版本的 repeat[]()
函数。
在解析了之后的 repeat[3]()
函数大致相当于:
fn repeat_3(msg: String):
print(msg)
print(msg)
print(msg)
注意:
这并不代表编译器生成的实际代码。在解析
parameter
时,Mojo
代码已经转换为MLIR
的中间表示形式。
如果编译器无法将所有 parameter
值解析为常量值,编译就会失败。
参数化结构体
您还可以为结构体添加 parameter
。例如,通用数组类型可能包含如下代码:
struct GenericArray[T: AnyRegType]:
var data: Pointer[T]
var size: Int
fn __init__(inout self, *elements: T):
self.size = len(elements)
self.data = Pointer[T].alloc(self.size)
for i in range(self.size):
self.data[i] = elements[i]
fn __del__(owned self):
self.data.free()
fn __getitem__(self, i: Int) raises -> T:
if (i < self.size):
return self.data[i]
else:
raise Error("Out of bounds")
该结构体有一个名为 T
的 parameter
,它是一个占位符,表示要存储在数组中的数据类型,有时也称为类型 parameter
。T
可以在结构体范围内使用。
T
的类型为 AnyRegType
,这是一种元类型,代表任何可通过寄存器的类型。这意味着我们的 GenericArray
可以保存固定大小的数据类型,如整数和浮点数,但不能保存动态分配的数据,如字符串或向量。
注意:
使用大写的
T
作为parameter
对编译器来说没有任何特殊意义(只是一个名称而已),但在许多语言中,使用简短的名称来表示类型参数是一种惯例,通常默认为T
。
与参数化函数一样,使用参数化结构体时也需要传递参数值。在这种情况下,创建 GenericArray
实例时,需要指定要存储的类型,如 Int
或 Float64
(Mojo
种的类型是一个有效的编译时值)。
下面是一个使用 GenericArray
的示例:
var array = GenericArray[Int](1, 2, 3, 4)
for i in range(array.size):
print(array[i], sep=" ", end="")
# 1 2 3 4
参数化结构体可以使用 Self
类型来表示结构体的具体实例。例如,你可以为 GenericArray
添加一个静态工厂方法,其签名如下:
struct GenericArray[T: AnyRegType]:
...
@staticmethod
fn splat(count: Int, value: T) -> Self:
...
在这里,Self
相当于编写 GenericArray[T]
。也就是说,你可以像这样调用 splat
方法:
GenericArray[Float64].splat(8, 0)
该方法返回一个 GenericArray[Float64]
的实例。
参数重载
函数和方法的 parameter
也可以重载,重载解析逻辑根据以下规则(按优先顺序排列):
- 隐式转换最少
- 没有可变的
arguments
- 没有可变的
parameters
parameter
数量最少- 不是
@staticmethod
如果应用这些规则后,候选者不止一个,则重载解析失败。例如
@register_passable("trivial")
struct MyInt:
"""A type that is implicitly convertible to `Int`."""
var value: Int
@always_inline("nodebug")
fn __init__(_a: Int) -> Self:
return Self {value: _a}
fn foo[x: MyInt, a: Int]():
print("foo[x: MyInt, a: Int]()")
fn foo[x: MyInt, y: MyInt]():
print("foo[x: MyInt, y: MyInt]()")
fn bar[a: Int](b: Int):
print("bar[a: Int](b: Int)")
fn bar[a: Int](*b: Int):
print("bar[a: Int](*b: Int)")
fn bar[*a: Int](b: Int):
print("bar[*a: Int](b: Int)")
fn parameter_overloads[a: Int, b: Int, x: MyInt]():
foo[x, a]()
bar[a](b)
bar[a, a, a](b)
struct MyStruct:
fn __init__(inout self):
pass
fn foo(inout self):
print("calling instance menthod")
@staticmethod
fn foo():
print("calling static method")
fn test_static_overload():
var a = MyStruct()
a.foo()
fn main():
parameter_overloads[1, 2, MyInt(3)]()
test_static_overload()
foo[x: MyInt, a: Int]()
bar[a: Int](b: Int)
bar[*a: Int](b: Int)
calling instance menthod
使用参数化类型和函数
通过给方括号中的 parameter
传递值,可以实例化参数类型和函数。
例如,对于 SIMD
类型
struct SIMD[type: DType, size: Int]: ...
type
指定了数据类型,size
指定 SIMD
向量的长度(必须是 2
的幂次):
# Make a vector of 4 floats.
var small_vec = SIMD[DType.float32, 4](1.0, 2.0, 3.0, 4.0)
# Make a big vector containing 1.0 in float16 format.
var big_vec = SIMD[DType.float16, 32].splat(1.0)
# Do some math and convert the elements to float32.
var bigger_vec = (big_vec+big_vec).cast[DType.float32]()
# You can write types out explicitly if you want of course.
var bigger_vec2 : SIMD[DType.float32, 32] = bigger_vec
print('small_vec type:', small_vec.element_type, 'length:', len(small_vec))
print('bigger_vec2 type:', bigger_vec2.element_type, 'length:', len(bigger_vec2))
small_vec type: float32 length: 4
bigger_vec2 type: float32 length: 32
需要注意的是,cast
方法还需要一个参数来指定你想要转换的类型(即需要一个目标类型)。
我们还可以使用 parameter
参数化来定义针对 SIMD
但与类型和宽度无关的算法:
from math import sqrt
fn rsqrt[dt: DType, width: Int](x: SIMD[dt, width]) -> SIMD[dt, width]:
return 1 / sqrt(x)
var v = SIMD[DType.float16, 4](42)
print(rsqrt(v))
x
参数实际上是基于函数 rsqrt
的 parameter
创建的 SIMD
类型。运行时程序可以使用编译时参数的值,但编译时参数表达式不能使用运行时值。
Mojo
编译器会根据传入的参数 x
值推断其 parameter
,就像明确写入 rsqrt[dt, width](x)
一样。
还有就是,rsqrt
选择将第二个 parameter
命名为 width
,而 SIMD
类型将其命名为 size
,也没有问题。
可选和关键字参数
parameter
类似于 argument
,也支持可选或关键字形式。例如
fn speak[a: Int = 3, msg: StringLiteral = "woof"]():
print(msg, a)
fn use_defaults() raises:
speak() # prints 'woof 3'
speak[5]() # prints 'woof 5'
speak[7, "meow"]() # prints 'meow 7'
speak[msg="baaa"]() # prints 'baaa 3'
speak
函数的两个 parametr
都有默认值,可以两个都不指定、指定一个或以关键字的形式指定 parameter
的值
当调用函数时,Mojo
可以推断出参数值。也就是说,它可以从 argument
的类型反推出 parameter
的类型。例如:
@value
struct Bar[v: Int]:
pass
fn foo[a: Int = 3, msg: StringLiteral = "woof"](bar: Bar[a]):
print(msg, a)
fn use_inferred():
foo(Bar[9]()) # prints 'woof 9'
虽然 foo
函数的 parameter
:a
, 具有默认值,但 Mojo
使用了从 argument
:bar
中推断出的 a
的值(实际上 a 的默认值永远不会被使用,它是由 bar
来决定的)。
我们还可以在 struct
中使用可选和关键字 parameter
。例如
struct KwParamStruct[greeting: String = "Hello", name: String = "🔥mojo🔥"]:
fn __init__(inout self):
print(greeting, name)
fn use_kw_params():
var a = KwParamStruct[]() # prints 'Hello 🔥mojo🔥'
var b = KwParamStruct[name="World"]() # prints 'Hello World'
var c = KwParamStruct[greeting="Hola"]() # prints 'Hola 🔥mojo🔥'
注意:
Mojo
支持仅位置和仅关键字parameter
,其规则与argument
相同。
可变参数
Mojo
还支持可变 parameter
,与可变 argument
类似:
struct MyTensor[*dimensions: Int]:
pass
目前,可变 parameter
有一些可变 argument
没有的限制:
- 可变
parameter
必须是同质的,即所有值必须是同一类型。 parameter
的类型必须是寄存器可传递类型。parameter
的值不会自动转换为VariadicList
,因此需要手动创建。例如
fn sum_params[*values: Int]() -> Int:
alias list = VariadicList(values)
var sum = 0
for v in list:
sum += v
return sum
目前还不支持可变关键字 parameter
(例如 **kwparams
)。
参数表达式
参数表达式(parameter
)可以是任何 Mojo
代码表达式,可以把 parameter
当做是函数的 argument
一样使用(与运行时代码具有相同的语法和类型),支持函数调用以及算术操作。
例如,您可能想定义一个辅助函数来连接两个 SIMD
向量:
fn concat[ty: DType, len1: Int, len2: Int](
lhs: SIMD[ty, len1], rhs: SIMD[ty, len2]) -> SIMD[ty, len1+len2]:
var result = SIMD[ty, len1 + len2]()
for i in range(len1):
result[i] = SIMD[ty, 1](lhs[i])
for j in range(len2):
result[len1 + j] = SIMD[ty, 1](rhs[j])
return result
var a = SIMD[DType.float32, 2](1, 2)
var x = concat[DType.float32, 2, 2](a, a)
print('result type:', x.element_type, 'length:', len(x))
result type: float32 length: 4
请注意,结果长度是输入向量长度的总和,可以用简单的 +
运算来表示。
编译时编程
虽然简单的表达式很有用,但有时你想编写更复杂的编译时逻辑,例如循环和递归等。
下面是一个 “tree reduction” 算法示例,递归求和:
fn slice[ty: DType, new_size: Int, size: Int](
x: SIMD[ty, size], offset: Int) -> SIMD[ty, new_size]:
var result = SIMD[ty, new_size]()
for i in range(new_size):
result[i] = SIMD[ty, 1](x[i + offset])
return result
fn reduce_add[ty: DType, size: Int](x: SIMD[ty, size]) -> Int:
@parameter
if size == 1:
return int(x[0])
elif size == 2:
return int(x[0]) + int(x[1])
alias half_size = size // 2
var lhs = slice[ty, half_size, size](x, 0)
var rhs = slice[ty, half_size, size](x, half_size)
return reduce_add[ty, half_size](lhs + rhs)
var x = SIMD[DType.index, 4](1, 2, 3, 4)
print(x)
print("Elements sum:", reduce_add(x))
它利用 @parameter
装饰器创建了一个参数 if
条件,确保只有 if
语句的有效分支被编译到程序中。
类型就是参数表达式
不仅类型中可以使用参数表达式,类型注释也可以是任意表达式。Mojo
中的类型有一种特殊的元类型,允许定义类型参数算法和函数。
例如,我们可以创建一个支持任意寄存器可传递元素类型的简化数组(通过 AnyRegType
参数):
struct Array[T: AnyRegType]:
var data: Pointer[T]
var size: Int
fn __init__(inout self, size: Int, value: T):
self.size = size
self.data = Pointer[T].alloc(self.size)
for i in range(self.size):
self.data[i] = value
fn __getitem__(self, i: Int) -> T:
return self.data[i]
fn __del__(owned self):
self.data.free()
var v = Array[Float32](4, 3.14)
print(v[0], v[1], v[2], v[3])
请注意,T
即被用作 value
argument
的类型,也被用于 __getitem__
方法的返回类型。
还有许多其他情况可以从更高级的参数使用中获益。例如,你可以并行执行闭包 N
次
fn parallelize[func: fn (Int) -> None](num_work_items: Int):
for i in range(num_work_items):
func(i)
另一个重要的例子是可变泛型,在这种情况下,算法或数据结构可能需要通过异构类型列表(如元组)来定义。
现在,Mojo
还不完全支持这种情况,需要使用 MLIR
来编写。未来,将可以使用纯 Mojo
中实现。
别名(alias)
var
定义的是运行时值,而有时候我们需要一种定义编译时临时值的方法。Mojo
提供了 alias
来进行声明。
例如,DType
结构使用 alias
实现了一个简单的枚举(实际的 DType
实现细节略有不同):
struct DType:
var value : UI8
alias invalid = DType(0)
alias bool = DType(1)
alias int8 = DType(2)
alias uint8 = DType(3)
alias int16 = DType(4)
alias int16 = DType(5)
...
alias float32 = DType(15)
这样,就可以自然地将 DType.float32
用作参数表达式(也可用作运行时值)。
因为类型是编译时表达式,所以可以定义类型的别名:
alias Float16 = SIMD[DType.float16, 1]
alias UInt8 = SIMD[DType.uint8, 1]
var x : Float16 # Float16 works like a "typedef"
与 var
变量一样,别名也服从作用域规则,因此可以在函数中定义局部别名。
类型绑定
为所有 parameter
指定参数类型被称为完全绑定类型。也就是说,它的所有参数都与值绑定。
前面的例子中,你只能实例化一个完全绑定的类型。不过,在某些情况下可以不绑定或部分绑定。
例如,可以使用别名部分绑定类型,用来创建一个需要较少参数的新类型:
alias Bytes = SIMD[DType.uint8, _]
var b = Bytes[8]()
在这里,Bytes
是 SIMD
字节向量的类型别名。参数列表中的下划线 _
表示第二个参数(宽度)是非绑定的。可以在后面使用 Bytes
时再指定宽度参数。
例如,给定以下类型
struct MyType[s: String, i: Int, i2: Int, b: Bool = True]:
pass
它可以以下列形式出现在代码中:
- 完全绑定,指定了所有
parameter
:
MyType["Hello", 3, 4, True]
- 部分绑定,已指定部分但非全部
parameter
:
MyType["Hola", _, _, True]
- 不绑定,未指定任何
parameter
:
MyType[_, _, _, _]
您还可以使用表达式 *_
在参数列表末尾取消绑定任意数量的位置参数。例如
# These two types are equivalent
MyType["Hello", *_]
MyType["Hello", _, _, _]
当使用 _
或 *_
表达式显式为一个 parameter
解绑时,必须为该 parameter
指定一个值才能使用该类型。原始类型声明中的默认值将被忽略。
部分绑定和未绑定参数类型可在某些情况下使用,接着看下面的例子
省略参数
Mojo
还支持不绑定 parameter
的另一种格式,即在表达式中省略参数:
# Partially bound
MyType["Hi there"]
# Unbound
MyType
这种格式与上述显式解绑语法不同,省略参数的默认值会立即绑定。例如,以下表达式是等价的:
MyType["Hi there"]
# equivalent to
MyType["Hi there", _, _, True] # Uses the default value for `b`
目前支持这种格式是为了向后兼容。预计在未来淘汰这种格式,而使用显式解除绑定语法。
自动参数化
Mojo
支持函数的 “自动” 参数化,如果函数的 argument
类型是部分绑定或不绑定类型,未绑定的 parameter
会自动添加为函数的 parameter
。举个例子:
fn print_params(vec: SIMD[*_]):
print(vec.type)
print(vec.size)
var v = SIMD[DType.float64, 4](1.0, 2.0, 3.0, 4.0)
print_params(v)
在上例中,print_params
函数已自动参数化。vec
是一个 SIMD[*_]
类型的 argument
。这是一个未绑定参数的类型,也就是说,它没有为该类型指定任何参数值。
Mojo
将 vec
的不绑定的 parameter
视为函数的隐式 parameter
。相当于下面的代码:
fn print_params[t: DType, s: Int](vec: SIMD[t, s]):
print(vec.type)
print(vec.size)
调用 print_params
时,必须向它传递 SIMD
类型的具体实例,即一个指定了所有 parameter
的实例,如 SIMD[DType.float64, 4]
。Mojo
编译器会根据输入 argument
推导出 parameter
的值。
对于手动参数化函数,可以通过名称访问输入参数(t
和 s
)。而对于自动参数化函数,则可以将 parameter
作为 argument
的属性来访问(例如 vec.type
)。
这种访问类型输入 parameter
的功能并不是自动参数化函数所特有的,可以在任何地方使用它。
您可以将参数化类型的输入 parameter
作为类型本身的属性来访问:
fn on_type():
print(SIMD[DType.float32, 2].size) # prints 2
或作为该类型实例的属性:
fn on_instance():
var x = SIMD[DType.int32, 2](4, 8)
print(x.type) # prints int32
您甚至可以在函数签名中使用这种语法,根据 argument
的 parameter
定义函数的 argument
和返回类型。
例如,如果您想让函数接受两个类型和大小相同的 SIMD
向量,可以这样:
fn interleave(v1: SIMD, v2: __type_of(v1)) -> SIMD[v1.type, v1.size*2]:
var result = SIMD[v1.type, v1.size*2]()
for i in range(v1.size):
result[i*2] = SIMD[v1.type, 1](v1[i])
result[i*2+1] = SIMD[v1.type, 1](v2[i])
return result
var a = SIMD[DType.int16, 4](1, 2, 3, 4)
var b = SIMD[DType.int16, 4](0, 0, 0, 0)
var c = interleave(a, b)
print(c)
如果只想匹配 argument
的类型,可以使用神奇的 __type_of
。在这种情况下,它比写等价的 SIMD[v1.type, v1.size]
更方便、更紧凑。
部分绑定类型自动参数化
Mojo
还支持部分绑定类型的自动参数化。例如,假设我们有一个带有三个 parameter
的 Fudge
结构:
@value
struct Fudge[sugar: Int, cream: Int, chocolate: Int = 7](Stringable):
fn __str__(self) -> String:
var values = StaticIntTuple[3](sugar, cream, chocolate)
return str("Fudge") + values
我们可以定义一个绑定一个 parameter
(部分绑定)的函数:
fn eat(f: Fudge[5, *_]):
print("Ate " + str(f))
eat
函数接收一个 Fudge
结构,其中第一个 parameter
(sugar
)的值绑定为 5
。第二和第三个参数(cream
和 chocolate
)未绑定。
未绑定的 cream
和 chocolate
成为 eat
函数的隐式输入 parameter
。相当于
fn eat[cr: Int, ch: Int](f: Fudge[5, cr, ch]):
print("Ate " + str(f))
我们可以通过传入一个绑定了cream
和 chocolate
的实例来调用函数:
eat(Fudge[5, 5, 7]())
eat(Fudge[5, 8, 9]())
但是如果 sugar
的值不等于 5,那么编译会失败,因为它与参数类型不匹配:
eat(Fudge[12, 5, 7]())
# ERROR: invalid call to 'eat': argument #0 cannot be converted from 'Fudge[12, 5, 7]' to 'Fudge[5, 5, 7]'
您还可以更自由地指定未绑定的 parameter
。例如,只绑定 cream
。
fn devour(f: Fudge[_, 6, _]):
print(str("Devoured ") + str(f))
相当于下面的代码
fn devour[su: Int, ch: Int](f: Fudge[su, 6, ch]):
print(str("Devoured ") + str(f))
您也可以通过关键字指定 parameter
,或混合使用位置和关键字 parameter
fn devour(f: Fudge[_, chocolate=_, cream=6]):
print(str("Devoured ") + str(f))
所有三个版本的 devour
函数都可以使用如下方式调用:
devour(Fudge[3, 6, 9]())
devour(Fudge[4, 6, 8]())
省略参数
您也可以通过省略 parameter
来指定不绑定或部分绑定类型:例如:
fn nibble(f: Fudge[5]):
print("Ate " + str(f))
在这里,Fudge[5]
的工作原理与 Fudge[5, *_]
相同,只是在处理带默认值的参数时不同。
Fudge[5]
并不会丢弃 chocolate
的默认值,而是立即绑定默认值,使其等价于 Fudge[5, _, 7]
。
如果部使用 chocolate
的默认值将会报错
nibble(Fudge[5, 5, 9]())
# ERROR: invalid call to 'eat': argument #0 cannot be converted from 'Fudge[5, 5, 9]' to 'Fudge[5, 5, 7]'
对与省略不绑定参数的方式最终将会被弃用,推荐使用
_
和*_
来显式地解绑参数。