jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究。
我今天就来试一试到底多快。我在同一台bu带gpu的机器上进行试验
首先我们得安装jax
pip install jax jaxlib
先试一下原生的numpy
import numpy as np
import time
x = np.random.random([5000, 5000]).astype(np.float32)
try:
st=time.time()
y=np.matmul(x, x)
except Exception:
print("erro")
print(time.time()-st)
print(y)
运行结果:
[root@node opt]# python np.py
4.424036026000977
[[1236.3004 1240.3048 1211.4501 ... 1225.7804 1237.1368 1235.1566]
[1235.5778 1246.7327 1208.7142 ... 1238.117 1232.439 1226.5779]
[1235.0111 1244.4628 1211.5264 ... 1238.5541 1246.9045 1244.6909]
...
[1229.7677 1238.8345 1210.4467 ... 1219.8604 1234.0862 1220.1482]
[1231.9464 1251.9636 1212.1384 ... 1235.8513 1236.8677 1240.5355]
[1254.0636 1265.74 1241.6528 ... 1245.015 1259.153 1247.0613]]
再来试一下jax带的numpy
import jax.numpy as np
from jax import random
import time
x = random.uniform(random.PRNGKey(0), [5000, 5000])
st=time.time()
try:
y=np.matmul(x, x)
except Exception:
print("erro")
print(time.time()-st)
print(y)
结果:
[root@node opt]# python jax_np.py
/opt/AN/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
0.013895750045776367
[[1261.0647 1244.5797 1237.2269 ... 1264.7208 1246.0367 1260.5391]
[1256.1 1239.737 1237.5562 ... 1257.1333 1243.5856 1243.5979]
[1261.2687 1239.5006 1250.6697 ... 1259.8387 1250.6825 1248.5712]
...
[1265.9805 1230.9077 1244.4961 ... 1264.2374 1241.5995 1244.9274]
[1262.9971 1253.961 1256.2424 ... 1266.3489 1255.1581 1274.1865]
[1273.3524 1252.4921 1261.0496 ... 1273.2394 1272.829 1267.7483]]
我们可以看到,没有jax的numpy运行了差不多4.4秒,而带了jax的numpy直接才0.014,速度基本上提升了30倍。也太快了
同样地,jax下面还有jax.scipy等代替原生的scipy