JAX的设计旨在实现函数式编程的概念,这样可以更容易地理解程序的行为,并避免由可变状态导致的常见错误。JAX数组的不可变性也是基于这个理念。因为在函数式编程中,函数不会修改它们的输入,而是创建新的对象作为输出。因此,在JAX中,所有的数组变换都是实现为纯函数,接受一个输入数组并返回一个新的数组。这样可以确保原始数组保持不变,从而提高代码的可读性和可维护性。
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
# ---------------------------------------------------------------------------
# TypeError Traceback (most recent call last)
# <ipython-input-7-6b90817377fe> in <module>()
# 1 # JAX: immutable arrays
# 2 x = jnp.arange(10)
# ----> 3 x[0] = 10
# TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
JAX提供了一种索引更新语法,用于更新单个元素,并返回更新后的副本:
y = x.at[0].set(10)
print(x)
print(y)
# [0 1 2 3 4 5 6 7 8 9]
# [10 1 2 3 4 5 6 7 8 9]
与此相反,NumPy数组默认是可变的,这意味着您可以直接修改数组的内容,而不创建新的数组。虽然这样可能很方便,但它也会使程序的行为更难以理解,并导致错误。例如,在多线程环境中使用可变的NumPy数组可能导致竞态条件等问题。
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
# [10 1 2 3 4 5 6 7 8 9]
但是,有时候必须就地修改数组以提高性能,特别是在处理大型数据集时。为此,JAX提供了一些方法来必要时就地修改数组,例如使用jax.ops.index_update
函数或jax.lax
模块。这些操作在JAX代码中不常见,通常仅在需要优化性能时使用。它们仍然是纯函数,因为它们返回一个新的数组而不是修改原始数组。
总之,JAX的设计遵循函数式编程的理念,并通过不可变性来确保原始数组保持不变,以提高代码的可读性和可维护性。虽然JAX提供了一些方法来就地修改数组,但这些操作仍然是纯函数,通常仅在需要优化性能时使用。