jax和jaxlib是一起的,所以我们可以通过jax或者jaxlib去判断GPU是否用。
jax判断:
import jax
print(jax.devices())
jaxlib判断:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
jax和jaxlib是一起的,所以我们可以通过jax或者jaxlib去判断GPU是否用。
jax判断:
import jax
print(jax.devices())
jaxlib判断:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)