为什么 JAX 数组是不可变的,而 NumPy 数组不是呢?

2 篇文章 0 订阅

JAX的设计旨在实现函数式编程的概念,这样可以更容易地理解程序的行为,并避免由可变状态导致的常见错误。JAX数组的不可变性也是基于这个理念。因为在函数式编程中,函数不会修改它们的输入,而是创建新的对象作为输出。因此,在JAX中,所有的数组变换都是实现为纯函数,接受一个输入数组并返回一个新的数组。这样可以确保原始数组保持不变,从而提高代码的可读性和可维护性。

# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10

# ---------------------------------------------------------------------------
# TypeError                                 Traceback (most recent call last)
# <ipython-input-7-6b90817377fe> in <module>()
#       1 # JAX: immutable arrays
#       2 x = jnp.arange(10)
# ----> 3 x[0] = 10

# TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

JAX提供了一种索引更新语法,用于更新单个元素,并返回更新后的副本:

y = x.at[0].set(10)
print(x)
print(y)

# [0 1 2 3 4 5 6 7 8 9]
# [10  1  2  3  4  5  6  7  8  9]

与此相反,NumPy数组默认是可变的,这意味着您可以直接修改数组的内容,而不创建新的数组。虽然这样可能很方便,但它也会使程序的行为更难以理解,并导致错误。例如,在多线程环境中使用可变的NumPy数组可能导致竞态条件等问题。

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)

# [10  1  2  3  4  5  6  7  8  9]

但是,有时候必须就地修改数组以提高性能,特别是在处理大型数据集时。为此,JAX提供了一些方法来必要时就地修改数组,例如使用jax.ops.index_update函数或jax.lax模块。这些操作在JAX代码中不常见,通常仅在需要优化性能时使用。它们仍然是纯函数,因为它们返回一个新的数组而不是修改原始数组。
总之,JAX的设计遵循函数式编程的理念,并通过不可变性来确保原始数组保持不变,以提高代码的可读性和可维护性。虽然JAX提供了一些方法来就地修改数组,但这些操作仍然是纯函数,通常仅在需要优化性能时使用。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值