jnp.full() jnp.full(shape, fill_value, dtype=None) shape: 数组的形状; fill_value:用于填充数组的值; dtype: 数组的数据类型,默认为float32. 例如: import jax.numpy as jnp shape = (3, 3) array = jnp.full(shape, 0.1, dtype=jnp.float32) print(array)