JAX Positioning
- JAX is not a deep learning framework or library, and it was never designed to be one.
- JAX sits at the intersection of scientific computing (Scientific Computing) and function transformations (Function Transformations); deep learning is only a small part of what it offers. Its signature features:
  - Just-in-Time Compilation
  - Automatic Parallelization
  - Automatic Vectorization
  - Automatic Differentiation
- It consists of two main parts:
  - A scientific computing library that mirrors NumPy and can run on GPUs and TPUs
  - The low-level computational tools needed for deep learning
Basic Usage
jit
Just-in-time compilation for acceleration
- Pure functions: every input is passed in as an argument and every result is returned as output; no external state is read or written.
- jax transforms
  Statements are traced into a simple dataflow computation (tracing) that records how values/variables are transformed.
import jax
import jax.numpy as jnp
global_list = []
def f(x):
    global_list.append(x)  # side-effect
    print(x)  # side-effect
    return 2*x*x + 3*x + 3
jaxpr = jax.make_jaxpr(f)
jaxpr(3)
# The function after being traced (transformed) by make_jaxpr:
{ lambda ; a:i32[]. let
    b:i32[] = mul a 2
    c:i32[] = mul b a
    d:i32[] = mul a 3
    e:i32[] = add c d
    f:i32[] = add e 3
  in (f,) }
Note that the side effects (the list append and the print) are not recorded in the jaxpr.
- jit
  Compiles the traced code straight to machine code (static compilation), performed by the XLA (Accelerated Linear Algebra) compiler developed by Google.
- The compiled code is cached. Recompilation is triggered only when the shape of an input changes or a static argument changes.
# An example of the problem with referencing a global variable:
g = 0

def f(x):
    return x + g

jit_f = jax.jit(f)
print("First call: ", jit_f(3.))
Output:
First call: 3.0
If we now change the global variable g to 10 and run the program again:

g = 10.  # Update the global
jit_f(4.)

Output:
DeviceArray(4., dtype=float32, weak_type=True)

The cached trace still uses the old value g = 0, so the result is 4 rather than 14. Only when the input shape changes does jit re-trace the function and pick up the new value:

jit_f(jnp.array([4.]))  # The input shape changed, so the function is re-traced
# output: DeviceArray([14.], dtype=float32)
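The other recompilation trigger mentioned above is a change in a static argument. A minimal sketch (the function and values are illustrative, not from the original text): marking an argument as static with static_argnums makes every new value of it trigger a fresh compilation.

import jax
import jax.numpy as jnp

def scale(x, factor):                      # `factor` will be treated as a static argument
    return x * factor

jit_scale = jax.jit(scale, static_argnums=(1,))
jit_scale(jnp.arange(3.), 2)               # compiles for factor=2
jit_scale(jnp.arange(3.), 2)               # reuses the cached compilation
jit_scale(jnp.arange(3.), 3)               # static argument changed -> recompiles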
- Under jit, Python branch statements on traced values are not supported and loop statements are restricted; it is strongly recommended to use the dedicated control-flow functions below (see the usage sketch after the scan pseudocode).
- cond
def cond(pred, true_fun, false_fun, operand):
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)
- while_loop
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
- fori_loop
body_fun takes two arguments: the loop index i and the intermediate (carried) value.
def fori_loop(start, stop, body_fun, init_val):
    val = init_val
    for i in range(start, stop):
        val = body_fun(i, val)
    return val
- scan
- f: a two-argument state-transition function
- init: the initial value of the carry
- xs: the inputs to scan over
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    # core:
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
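A minimal usage sketch of the real primitives inside a jitted function (the toy functions and values are illustrative, not from the original text):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def controlled(x):
    # branch on a traced value with lax.cond
    y = lax.cond(x > 0, lambda v: 2.0 * v, lambda v: -v, x)
    # run a fixed number of iterations with lax.fori_loop (adds 1.0 five times)
    y = lax.fori_loop(0, 5, lambda i, acc: acc + 1.0, y)
    # cumulative sum over a fixed-length sequence with lax.scan
    total, partials = lax.scan(lambda c, v: (c + v, c + v), jnp.zeros(()), jnp.arange(3.0))
    return y, total, partials

print(controlled(2.0))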
vmap
Automatic vectorization / batching
- Takes a function as input and returns a function.
- Arguments are split automatically. By default every array argument is split along its leading axis. The slices are grouped and passed through the original function; if the arguments cannot be grouped, or the grouped slices cannot be computed, an error is raised. The per-slice results are stacked into the output.

from jax import numpy as jnp
from jax import random, jit, grad, vmap

x = jnp.array([[1, 2, 3], [0, 1, 2]])                        # 2 * 3
y = jnp.array([[1, 2, 3], [0, 0, 0], [1, 2, 3], [0, 1, 2]])  # 4 * 3

def fun(a1, a2):
    return a1 * a2

vmap_fun = vmap(fun)
print(vmap_fun(x, y))

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  arg 0 has shape (2, 3) and axis 0 is to be mapped
  arg 1 has shape (4, 3) and axis 0 is to be mapped
so
  arg 0 has an axis to be mapped of size 2
  arg 1 has an axis to be mapped of size 4
- in_axes specifies the axis along which each input array is sliced; out_axes specifies the axis along which the outputs are stacked.

# Define the function:
f = lambda x, w: jnp.dot(w, x)
# Define batch_x and w:
x_batch = jax.random.normal(jax.random.PRNGKey(55), (4, 5, 3))
w = jax.random.normal(jax.random.PRNGKey(42), (100, 5))

batch_a = jax.vmap(f, in_axes=(0, None), out_axes=0)(x_batch, w)
print(batch_a.shape)
# out: (4, 100, 3)
pmap
- Single machine, multiple cards (GPUs).
- Similar to vmap, except that the slices are dispatched to different GPUs and computed in parallel; any other arguments that are needed are replicated onto every GPU.
- When you want to use the parallel results, fetch them from the different GPUs through the corresponding API.
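A minimal sketch of the idea (illustrative, not from the original text; it simply maps over however many local devices are present, which is 1 on a single-device machine):

import jax
import jax.numpy as jnp

n = jax.local_device_count()                  # number of local devices
xs = jnp.arange(n * 3.).reshape(n, 3)         # leading axis: one slice per device

parallel_square = jax.pmap(lambda x: x ** 2)  # each device squares its own slice
ys = parallel_square(xs)                      # results stay sharded across the devices
host_ys = jax.device_get(ys)                  # gather the per-device results back to the host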
grad
Automatic differentiation
- Takes a function as input and returns a function.
- For multi-argument functions, you can specify which arguments to differentiate with respect to; by default only the first argument is differentiated.

import jax

def f(x, y):
    return 2*x*x + 3*y + 3

x = 10.
y = 5.
jax.grad(f)(x, y)
jax.grad(f, argnums=(0, 1))(x, y)

Output (the argnums=(0, 1) call returns a tuple, which can be indexed):
DeviceArray(40., dtype=float32, weak_type=True)
(DeviceArray(40., dtype=float32, weak_type=True), DeviceArray(3., dtype=float32, weak_type=True))
Pytree
In JAX, a pytree is a tree-like data structure built from Python containers (e.g., list, dict, tuple, OrderedDict, None, namedtuple), such as lists of lists of dicts. Any data not wrapped in a Python container is a leaf (e.g., a number, array, class instance, or string). Pytrees can be nested inside pytrees.
Nested list/dict/tuple structures are commonly used to hold neural network parameters.
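A minimal sketch of working with a pytree of parameters through jax.tree_util (the parameter values are illustrative):

import jax.numpy as jnp
from jax import tree_util

# a nested dict/list pytree of (illustrative) parameters
params = {"layer1": {"W": jnp.ones((2, 3)), "b": jnp.zeros(3)},
          "layer2": [jnp.ones(3)]}

print(tree_util.tree_leaves(params))                     # the leaf arrays, flattened
scaled = tree_util.tree_map(lambda p: 0.1 * p, params)   # apply a function to every leaf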
jax.numpy
- The API is almost identical to NumPy's.

import numpy as np
import jax.numpy as jnp
- The way random numbers are generated is different.

np.random.seed(seed)
np.random.uniform()  # 0.54881350
np.random.uniform()  # 0.71518936

key = jax.random.PRNGKey(seed)   # key: DeviceArray([0, 0], dtype=uint32)
x = jax.random.uniform(key)      # 0.41845703
x = jax.random.uniform(key)      # 0.41845703 (same key, same number)
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey)   # 0.10546897
There is effectively a deterministic mapping between the key and x.
The essence of using random numbers in JAX: never reuse a key, and make good use of jax.random.split().
This is what keeps randomness compatible with JAX's parallelization, reproducibility, and vectorization.
- Arrays are immutable

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10

# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10                                   # raises an error
y = x.at[0].set(10)                         # functional update returns a new array
new_array = old_array.at[1, :].set(1.)      # equivalent of the old index_update
new_array = old_array.at[::2, 3:].add(7.)   # equivalent of the old index_add

Allowing in-place mutation of variables would make program analysis and transformation very difficult.
- Out-of-range indexing does not raise an error; it returns the last element instead (the index is clamped).

jnp.arange(10)[11]
# out: DeviceArray(9, dtype=int32)
Comparing GCN Implementations in JAX and PyTorch
1. Imports
# JAX side
import jax
import jax.numpy as np
from jax import lax, random, jit                      # random numbers / jit
from jax.experimental import stax                     # model-building utilities
from jax.experimental.stax import Relu, LogSoftmax    # activations
from jax.nn.initializers import glorot_normal, glorot_uniform, normal, uniform, zeros
import optax                                          # optimizers
import jax.nn as nn                                   # jax nn library
import scipy.sparse as sp

# PyTorch side (a separate script; note that `nn` would otherwise clash with jax.nn above)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import matplotlib.pyplot as plt
2. Defining the graph convolution
def GraphConvolution(out_dim, bias=False, sparse=False):
    def matmul(A, B, shape):
        if sparse:
            return sp_matmul(A, B, shape)
        else:
            return np.matmul(A, B)

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W_init, b_init = glorot_uniform(), zeros
        W = W_init(k1, (input_shape[-1], out_dim))
        if bias:
            b = b_init(k2, (out_dim,))
        else:
            b = None
        return output_shape, (W, b)

    def apply_fun(params, feature, adj, **kwargs):  # **kwargs so extra args such as rng can be passed through
        W, b = params
        support = np.dot(feature, W)
        out = matmul(adj, support, support.shape[0])
        if bias:
            out += b
        return out

    return init_fun, apply_fun
class GraphConvolution(nn.Module):
    def __init__(self, input_dim, output_dim, use_bias=True):
        """Graph convolution: L * X * \theta
        Args:
        ----------
        input_dim: int
            dimension of the input node features
        output_dim: int
            dimension of the output features
        use_bias : bool, optional
            whether to use a bias term
        """
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(output_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        """The adjacency matrix is sparse, so sparse matrix multiplication is used.
        Args:
        -------
        adjacency: torch.sparse.FloatTensor
            adjacency matrix
        input_feature: torch.Tensor
            input features
        """
        support = torch.mm(input_feature, self.weight)
        output = torch.sparse.mm(adjacency, support)
        if self.use_bias:
            output += self.bias
        return output
Summary: JAX has no equivalent of PyTorch's model class. JAX functions are written as pure functions wherever possible: they keep no internal state, and data is passed in from the outside as arguments.
3. Defining the neural network
def GCN(nhid: int, nclass: int, sparse: bool = False):
    gc1_init, gc1_fun = GraphConvolution(nhid, sparse=sparse)
    gc2_init, gc2_fun = GraphConvolution(nclass, sparse=sparse)
    init_funs = [gc1_init, gc2_init]

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, feature, adj, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2 = random.split(rng, 2)
        x = gc1_fun(params[0], feature, adj, rng=k1)
        x = nn.relu(x)
        x = gc2_fun(params[1], x, adj, rng=k2)
        x = nn.log_softmax(x)
        return x

    return init_fun, apply_fun
class GcnNet(nn.Module):
    """
    A model made of two GraphConvolution layers
    """
    def __init__(self, input_dim=1433):
        super(GcnNet, self).__init__()
        self.gcn1 = GraphConvolution(input_dim, 16)
        self.gcn2 = GraphConvolution(16, 7)

    def forward(self, adjacency, feature):
        h = F.relu(self.gcn1(adjacency, feature))
        logits = self.gcn2(adjacency, h)
        return logits
4. Training
- Model initialization

init_fun, predict_fun = GCN(nhid=hidden, nclass=labels.shape[1], sparse=args.sparse)
_, init_params = init_fun(init_key, input_shape)

# Model initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GcnNet().to(device)
- Defining the loss function

@jit
def loss(params, batch):
    """ The idxes of the batch indicate which nodes are used to compute the loss. """
    inputs, targets, adj, rng, idx = batch
    preds = predict_fun(params, inputs, adj, rng=rng)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    # `optimizers` here refers to the l2_norm helper in jax.example_libraries.optimizers
    l2_loss = 5e-4 * optimizers.l2_norm(params)**2  # tf doesn't use sqrt
    return ce_loss + l2_loss

# The loss is cross-entropy
criterion = nn.CrossEntropyLoss().to(device)
- Defining the optimizer

optimizer = optax.adam(start_learning_rate)
opt_state = optimizer.init(init_params)   # initialize the optimizer state

# The optimizer is Adam
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
- Forward/backward pass

for epoch in range(num_epochs):
    grads = jax.grad(compute_loss)(params, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state)  # how the parameters should be updated
    params = optax.apply_updates(params, updates)

for epoch in range(epochs):
    logits = model(tensor_adjacency, tensor_x)   # forward pass
    loss = criterion(logits, train_y)            # compute the loss
    optimizer.zero_grad()
    loss.backward()                              # backward pass: compute parameter gradients
    optimizer.step()                             # gradient update with the optimizer
In JAX, the forward pass is effectively defined inside the loss function: when updating, the loss function is differentiated, which runs the forward pass. A single update step therefore covers the forward pass, the loss computation, the gradient computation, and the parameter update.
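A minimal sketch of such an update step (illustrative; it assumes the `loss`, `optimizer`, `params`, and `opt_state` defined above):

@jax.jit
def update(params, opt_state, batch):
    # forward pass + loss + gradients in a single call
    loss_val, grads = jax.value_and_grad(loss)(params, batch)
    # turn the raw gradients into parameter updates
    updates, opt_state = optimizer.update(grads, opt_state)
    # apply the updates to obtain the new parameters
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val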
Summary: the biggest difference is that JAX is oriented toward pure-function programming, while PyTorch is object-oriented.
Neural Network Libraries
- Flax - Centered on flexibility and clarity.
- Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
- Objax - Has an object oriented design similar to PyTorch.
- Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
- Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
- Jraph - Lightweight graph neural network library.
- Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
- HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
- Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
References
Should I use JAX in 2022? 16k GitHub stars, but this young tool is not perfect (2022年,我该用JAX吗?)
Implementing Graph Neural Networks with JAX
《深入浅出图神经网络:GNN原理解析》(刘忠雨 李彦霖 周洋)