使用@hk.without_apply_rng 装饰器时,被装饰的函数或方法内创建的 Haiku 模块将不会受到外部随机数生成器的影响,即不会应用到 Haiku 模块的参数初始化。这可以在某些情况下很有用,例如当你希望保持特定参数不变,而不受到外部的 RNG 影响。
使用 @hk.transform
装饰器时,被装饰的函数或方法将被包装成一个 Haiku 模块,并且可以通过调用 init
方法进行参数初始化。被装饰的函数通常接受输入参数并返回输出,而 Haiku 模块内部则定义了可训练的参数。这个装饰器提供了一种方便的方式来创建可训练的神经网络模型。
import haiku as hk
import jax
import jax.numpy as jnp
# 定义一个使用 RNG 的 Haiku 模块
def random_module(x):
return hk.Linear(10)(x)
# 定义一个简单的 Haiku 模块
# @hk.without_apply_rng 装饰器:声明一个函数或方法在其内部的 Haiku模块创建时不应用随机数生成器(RNG)
# @hk.transform 装饰器:将一个函数或方法转换为 Haiku 模块。
@hk.without_apply_rng
@hk.transform
def deterministic_function(x):
# 在该函数内创建的 Haiku 模块不受外部 RNG 影响
return random_module(x)
# 初始化模块的参数
# rng = jax.random.PRNGKey(4)
key_seq = hk.PRNGSequence(38)
print(key_seq) # <haiku._src.base.PRNGSequence object at 0x1210ab490>
# 每一次执行next(key_seq),生成不同的key_seq
print("next(key_seq)")
print(next(key_seq))
print(next(key_seq))
params = deterministic_function.init(next(key_seq), x)
# 调用模块
# 不需要 rng 参数
output_without_rng = deterministic_function.apply(params, x)
print("Output without RNG:", output_without_rng)
#random_module = hk.transform(random_module)
#params_1 = random_module.init(next(key_seq), x)
#output_with_rng = random_module.apply(params_1, next(key_seq), x)
#print("Output with RNG:", output_with_rng)
参考:
https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=haiku.transform#haiku.transform