Windows installation (community wheels):
pip install jaxlib==0.3.5 -f https://whls.blob.core.windows.net/unstable/index.html
pip install jaxlib[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
After installing this way, importing jax still raised an error for me.
Linux version:
jax.numpy is NumPy on CPU, GPU, and TPU, with excellent automatic differentiation, built for high-performance machine learning research.
First we need to install JAX:
pip install jax jaxlib
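Once installed, the automatic differentiation mentioned above is a one-liner. A minimal sketch (jax.grad is the real API; the function f is just an illustration):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)  # simple scalar-valued function

grad_f = jax.grad(f)            # gradient function of f
print(grad_f(jnp.arange(3.0)))  # [0. 2. 4.], i.e. 2*x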
The benchmark below is adapted from the CSDN blog post 机器学习加速利器jax,让numpy加速30倍 by 喝粥也会胖的唐僧.
First, try native NumPy:
import numpy as np
import time

# 5000x5000 random matrix in float32
x = np.random.random([5000, 5000]).astype(np.float32)
st = time.time()
y = np.matmul(x, x)  # plain NumPy matrix multiply on the CPU
print(time.time() - st)
print(y)
Output:
[root@node opt]# python np.py
4.424036026000977
[[1236.3004 1240.3048 1211.4501 ... 1225.7804 1237.1368 1235.1566]
 [1235.5778 1246.7327 1208.7142 ... 1238.117 1232.439 1226.5779]
 [1235.0111 1244.4628 1211.5264 ... 1238.5541 1246.9045 1244.6909]
 ...
 [1229.7677 1238.8345 1210.4467 ... 1219.8604 1234.0862 1220.1482]
 [1231.9464 1251.9636 1212.1384 ... 1235.8513 1236.8677 1240.5355]
 [1254.0636 1265.74 1241.6528 ... 1245.015 1259.153 1247.0613]]
Now try the NumPy that ships with JAX:
import jax.numpy as np  # drop-in replacement for numpy
from jax import random
import time

x = random.uniform(random.PRNGKey(0), [5000, 5000])
st = time.time()
y = np.matmul(x, x)
# note: JAX dispatches work asynchronously; call y.block_until_ready()
# before stopping the clock to time the actual computation
print(time.time() - st)
print(y)
[root@node opt]# python jax_np.py
/opt/AN/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
0.013895750045776367
[[1261.0647 1244.5797 1237.2269 ... 1264.7208 1246.0367 1260.5391]
 [1256.1 1239.737 1237.5562 ... 1257.1333 1243.5856 1243.5979]
 [1261.2687 1239.5006 1250.6697 ... 1259.8387 1250.6825 1248.5712]
 ...
 [1265.9805 1230.9077 1244.4961 ... 1264.2374 1241.5995 1244.9274]
 [1262.9971 1253.961 1256.2424 ... 1266.3489 1255.1581 1274.1865]
 [1273.3524 1252.4921 1261.0496 ... 1273.2394 1272.829 1267.7483]]
Plain NumPy took about 4.4 seconds, while jax.numpy returned in about 0.014 seconds, roughly a 300x difference on this machine, and that is still on the CPU, per the warning above. Keep in mind this naive measurement flatters JAX: because dispatch is asynchronous, the timed region measures launching the computation, not finishing it, unless block_until_ready() is called on the result.
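If the multiplication will run many times, jax.jit compiles it once and reuses the compiled kernel. A minimal sketch of a fairer timing (block_until_ready waits for the asynchronous result):

import time
import jax
import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), [5000, 5000])

@jax.jit
def matmul(a, b):
    return jnp.matmul(a, b)

matmul(x, x).block_until_ready()  # first call triggers compilation
st = time.time()
matmul(x, x).block_until_ready()  # later calls reuse the compiled kernel
print(time.time() - st)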
Similarly, JAX also ships jax.scipy and friends as replacements for native SciPy.
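As one example, jax.scipy.special.logsumexp mirrors scipy.special.logsumexp; a minimal sketch:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.arange(5.0)
print(logsumexp(x))  # same result as scipy.special.logsumexp(np.arange(5.0))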
To install a CPU-only version, which might be useful for doing local development on a laptop, you can run
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
For a GPU build, first check which CUDA version is installed:
nvcc --version
Then install a jaxlib wheel that matches it, e.g. for CUDA 11.0:
pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
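To verify the GPU build is actually being used (rather than hitting the CPU-fallback warning seen earlier), list the devices JAX sees:

import jax
print(jax.devices())  # should show GPU devices, not just CPU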
These wheels do not support Windows; a workaround is to replace the JAX calls with PyTorch equivalents.
Fails on Windows: scale = jax.lax.rsqrt(np.maximum(var * fan_in, eps)) * gain
Replacement: scale = torch.rsqrt(np.maximum(var * fan_in, eps)) * gain
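torch.rsqrt computes the element-wise reciprocal square root, matching jax.lax.rsqrt's semantics; a quick check:

import torch
v = torch.tensor([4.0, 16.0])
print(torch.rsqrt(v))  # tensor([0.5000, 0.2500])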
The same substitution works for the table of scaled activations; the gain constants carry over from the JAX version unchanged (the dict name activations is illustrative):

import torch
import torch.nn.functional as torch_F

# each activation is rescaled by a fixed gain constant
activations = {
    'celu': lambda x: torch_F.celu(x) * 1.270926833152771,
    'elu': lambda x: torch_F.elu(x) * 1.2716004848480225,
    'gelu': lambda x: torch_F.gelu(x) * 1.7015043497085571,
    'glu': lambda x: torch_F.glu(x) * 1.8484294414520264,
    'leaky_relu': lambda x: torch_F.leaky_relu(x) * 1.70590341091156,
    'log_sigmoid': lambda x: torch_F.logsigmoid(x) * 1.9193484783172607,
    'log_softmax': lambda x: torch_F.log_softmax(x, dim=-1) * 1.0002083778381348,
    'relu': lambda x: torch_F.relu(x) * 1.7139588594436646,
    'relu6': lambda x: torch_F.relu6(x) * 1.7131484746932983,
    'selu': lambda x: torch_F.selu(x) * 1.0008515119552612,
    'sigmoid': lambda x: torch.sigmoid(x) * 4.803835391998291,
    'silu': lambda x: torch_F.silu(x) * 1.7881293296813965,
    'soft_sign': lambda x: torch_F.softsign(x) * 2.338853120803833,
    'softplus': lambda x: torch_F.softplus(x) * 1.9203323125839233,
    'tanh': lambda x: torch.tanh(x) * 1.5939117670059204,
}
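A hedged usage sketch, assuming the dict is named activations as above:

import torch
x = torch.randn(4)
print(activations['relu'](x))  # scaled ReLU, gain ~ 1.714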