jax
是一个用于高性能数值计算的库,它是为了加速在 GPU、TPU 等加速硬件上执行的科学计算任务而设计的。jnp.tile
是 jax
库中的一个函数,功能类似于 NumPy 中的 numpy.tile
,用于在给定的方向上重复数组的元素。
import jax.numpy as jnp
arr = jnp.array([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
# arr 第一个维度重复二次,第二个维度重复三次
tiled_arr = jnp.tile(arr, [2,3])
print(tiled_arr)
import numpy as np
arr1 = np.array([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
# arr1 第一个维度重复二次,第二个维度重复三次
result = np.tile(arr1, [2,3])
print(result)