经办公室小伙伴介绍,jax包代替numpy可大大提升运算速度
例子
## pip install jax
## pip install jaxlib
from jax import random, jit
import jax.numpy as jnp
x=np.random.randn(2000000)
@jit ###调用jnp前必有
def abc(x):
a=jnp.log10(x)
b=10**a
c=jnp.sum(b)
return b