jax安装

本文介绍了如何在Windows和Linux上安装JAX及其依赖,并通过对比numpy的运行时间展示了JAX在矩阵乘法运算上的显著加速效果。在Linux环境下,JAX的numpy实现比原生numpy速度快约30倍。此外,还提到了JAX对GPU和TPU的支持以及在没有GPU环境下的CPU安装选项。文章最后提到了JAX与PyTorch的部分操作转换,以及Windows中遇到的问题和解决方法。
摘要由CSDN通过智能技术生成

Windows 安装

pip install jaxlib===0.3.5 -f https://whls.blob.core.windows.net/unstable/index.html

pip install jaxlib[cuda111] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

安装完后引用报错。

linux版本:

https://github.com/google/jax

jax.numpy是CPU、GPU和TPU上的numpy,具有出色的自动差异化功能,可用于高性能机器学习研究

首先我们得安装jax

pip install jax jaxlib

以下参考:机器学习加速利器jax,让numpy加速30倍_喝粥也会胖的唐僧的博客-CSDN博客_jax numpy

先试一下原生的numpy 

import numpy as np

import time

x = np.random.random([5000, 5000]).astype(np.float32)

try:

    st=time.time()

    y=np.matmul(x, x)

except Exception:

    print("erro")

print(time.time()-st)

print(y)

运行结果:  

  1. [root@node opt]# python np.py

  2. 4.424036026000977

  3. [[1236.3004 1240.3048 1211.4501 ... 1225.7804 1237.1368 1235.1566]

  4. [1235.5778 1246.7327 1208.7142 ... 1238.117 1232.439 1226.5779]

  5. [1235.0111 1244.4628 1211.5264 ... 1238.5541 1246.9045 1244.6909]

  6. ...

  7. [1229.7677 1238.8345 1210.4467 ... 1219.8604 1234.0862 1220.1482]

  8. [1231.9464 1251.9636 1212.1384 ... 1235.8513 1236.8677 1240.5355]

  9. [1254.0636 1265.74 1241.6528 ... 1245.015 1259.153 1247.0613]]

再来试一下jax带的numpy


import jax.numpy as np

from jax import random

import time

x = random.uniform(random.PRNGKey(0), [5000, 5000])

st=time.time()

try:

    y=np.matmul(x, x)

except Exception:

    print("erro")

print(time.time()-st)

print(y)
  1. [root@node opt]# python jax_np.py

  2. /opt/AN/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.

  3. warnings.warn('No GPU/TPU found, falling back to CPU.')

  4. 0.013895750045776367

  5. [[1261.0647 1244.5797 1237.2269 ... 1264.7208 1246.0367 1260.5391]

  6. [1256.1 1239.737 1237.5562 ... 1257.1333 1243.5856 1243.5979]

  7. [1261.2687 1239.5006 1250.6697 ... 1259.8387 1250.6825 1248.5712]

  8. ...

  9. [1265.9805 1230.9077 1244.4961 ... 1264.2374 1241.5995 1244.9274]

  10. [1262.9971 1253.961 1256.2424 ... 1266.3489 1255.1581 1274.1865]

  11. [1273.3524 1252.4921 1261.0496 ... 1273.2394 1272.829 1267.7483]]

我们可以看到,没有jax的numpy运行了差不多4.4秒,而带了jax的numpy直接才0.014,速度基本上提升了30倍。也太快了

同样地,jax下面还有jax.scipy等代替原生的scipy

To install a CPU-only version, which might be useful for doing local development on a laptop, you can run

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

nvcc --version

pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

pip install jax

pip install jaxlib

Windows不支持,替换方法:

错误:scale = jax.lax.rsqrt(np.maximum(var * fan_in, eps)) * gain

改之后:scale = torch.rsqrt(np.maximum(var * fan_in, eps)) * gain

    'celu': lambda x: torch_F.celu(x) * 1.270926833152771,
    'elu': lambda x: torch_F.elu(x) * 1.2716004848480225,
    'gelu': lambda x: torch_F.gelu(x) * 1.7015043497085571,
    'glu': lambda x: torch_F.glu(x) * 1.8484294414520264,
    'leaky_relu': lambda x: torch_F.leaky_relu(x) * 1.70590341091156,
    'log_sigmoid': lambda x: torch_F.logsigmoid(x) * 1.9193484783172607,
    'log_softmax': lambda x: torch_F.log_softmax(x) * 1.0002083778381348,
    'relu': lambda x: torch_F.relu(x) * 1.7139588594436646,
    'relu6': lambda x: torch_F.relu6(x) * 1.7131484746932983,
    'selu': lambda x: torch_F.selu(x) * 1.0008515119552612,
    'sigmoid': lambda x: torch_F.sigmoid(x) * 4.803835391998291,
    'silu': lambda x: torch_F.silu(x) * 1.7881293296813965,
    'soft_sign': lambda x: torch_F.softsign(x) * 2.338853120803833,
    'softplus': lambda x: torch_F.softplus(x) * 1.9203323125839233,
    'tanh': lambda x: np.tanh(x) * 1.5939117670059204,

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值