python的jax包的常用操作
本文参考官方文档
1.jax.random包
PRNGKey
>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key
DeviceArray([0, 0], dtype=uint32)
根据传入参数,生成两个无符号32位整数(不用管具体的细节,理解后面的使用即可),我们通常其称为key
>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)
key可以用于任何jax的随机数生成
>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)
如果key不变,结果不变
如果你需要新的随机数,你可以使用jax.random.split()
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
DeviceArray(0.10536897, dtype=float32)
jax.random.gumbel
#指定形状和数据类型的gumbel采样,key为刚才的key,返回的是个数组类型
jax.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)
2.jax.experimental.stax包
jax.experimental.stax是一个小型但灵活的神经网络规范库,可以快速的生成指定的网络层
from jax.experimental import stax
layers=[]#存放生成的网络层
#生成out_dim个神经元的全连接层
stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)
#以rate概率的dropout层
stax.Dropout(rate, mode='train')
#生成一般的的卷积层,out_chan为卷积核数量,filter_shape为卷积核shape,strides为步长,padding表示填充操作,
stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
#生成Relu层
stax.Relu()
#我们可以用layers来存放结果
layer.append(300,(1,5),padding="same",strides=(1,1))
layer.append(stax.Dense(8))
layer.append(stax.Relu)
stax.serial
#返回将网络层组合起来的结果,是个(init_fun, apply_fun) pair,表示给定层序列的连续组成。layers是刚才定义的list,里面放刚才定义的层即可
stax.serial(*layers)
#我们甚至可以用新的list来接受stax.serial(*layers)返回值进行套娃
layers=[stax.serial(*layers)]
layers.append(stax.Dense(16))
layers.append(stax.Relu)
3.jax.numpy包
这个包方法和numpy类似
#创建一个数组
jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)
#对x进行深度为k的one_hot编码
jax.numpy.eye(k)[x]
#返回将对数组进行排序的索引,a为数组,axis为指定维度
jax.numpy.argsort(a, axis=- 1, kind='quicksort', order=None)
#计算指定维度的均值
jax.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)
#计算指定数组维度的和
jax.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)
#重复一个数组的元素
jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)