haiku装饰器@hk.without_apply_rng、@hk.transform介绍

文章介绍了如何使用Haiku库中的@hk.without_apply_rng装饰器来防止被装饰函数内的Haiku模块受外部随机数生成器影响,以及@hk.transform装饰器将函数转换为可训练模型的过程。通过实例展示了如何在保持参数稳定的情况下创建和初始化Haiku模块。
摘要由CSDN通过智能技术生成

使用@hk.without_apply_rng 装饰器时,被装饰的函数或方法内创建的 Haiku 模块将不会受到外部随机数生成器的影响,即不会应用到 Haiku 模块的参数初始化。这可以在某些情况下很有用,例如当你希望保持特定参数不变,而不受到外部的 RNG 影响。

使用 @hk.transform 装饰器时,被装饰的函数或方法将被包装成一个 Haiku 模块,并且可以通过调用 init 方法进行参数初始化。被装饰的函数通常接受输入参数并返回输出,而 Haiku 模块内部则定义了可训练的参数。这个装饰器提供了一种方便的方式来创建可训练的神经网络模型。
 

import haiku as hk
import jax
import jax.numpy as jnp


# 定义一个使用 RNG 的 Haiku 模块
def random_module(x):
    return hk.Linear(10)(x)

# 定义一个简单的 Haiku 模块
# @hk.without_apply_rng 装饰器:声明一个函数或方法在其内部的 Haiku模块创建时不应用随机数生成器(RNG)
# @hk.transform 装饰器:将一个函数或方法转换为 Haiku 模块。
@hk.without_apply_rng
@hk.transform
def deterministic_function(x):
    # 在该函数内创建的 Haiku 模块不受外部 RNG 影响
    return random_module(x)

# 初始化模块的参数
# rng = jax.random.PRNGKey(4)

key_seq = hk.PRNGSequence(38)
print(key_seq) # <haiku._src.base.PRNGSequence object at 0x1210ab490>
# 每一次执行next(key_seq),生成不同的key_seq
print("next(key_seq)")
print(next(key_seq))
print(next(key_seq))
params = deterministic_function.init(next(key_seq), x)

# 调用模块
# 不需要 rng 参数
output_without_rng = deterministic_function.apply(params, x)
print("Output without RNG:", output_without_rng)

#random_module  = hk.transform(random_module)
#params_1 = random_module.init(next(key_seq), x)
#output_with_rng = random_module.apply(params_1, next(key_seq), x)
#print("Output with RNG:", output_with_rng)

参考:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=haiku.without_apply_rng#haiku.without_apply_rng

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=haiku.transform#haiku.transform

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值