python的jax包的常用操作

本文介绍了Python库jax的基本操作,包括使用jax.random生成随机数和PRNGKey,利用jax.experimental.stax构建神经网络层,如全连接层、dropout层、卷积层和激活函数,并展示了如何组合这些层。此外,还提到了jax.numpy包,它提供了与numpy类似的数组操作,如创建数组、one_hot编码、排序、平均值计算和求和等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

python的jax包的常用操作

本文参考官方文档

1.jax.random包

PRNGKey

>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key
DeviceArray([0, 0], dtype=uint32)

根据传入参数,生成两个无符号32位整数(不用管具体的细节,理解后面的使用即可),我们通常其称为key

>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)

key可以用于任何jax的随机数生成

>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)

如果key不变,结果不变
如果你需要新的随机数,你可以使用jax.random.split()

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
DeviceArray(0.10536897, dtype=float32)

jax.random.gumbel

Gumbel 分布及应用浅析

#指定形状和数据类型的gumbel采样,key为刚才的key,返回的是个数组类型
jax.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)

2.jax.experimental.stax包

jax.experimental.stax是一个小型但灵活的神经网络规范库,可以快速的生成指定的网络层

from jax.experimental import stax

layers=[]#存放生成的网络层
#生成out_dim个神经元的全连接层
stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)

#以rate概率的dropout层
stax.Dropout(rate, mode='train')

#生成一般的的卷积层,out_chan为卷积核数量,filter_shape为卷积核shape,strides为步长,padding表示填充操作,
stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

#生成Relu层
stax.Relu()

#我们可以用layers来存放结果
layer.append(300,(1,5),padding="same",strides=(1,1))
layer.append(stax.Dense(8))
layer.append(stax.Relu)

stax.serial

#返回将网络层组合起来的结果,是个(init_fun, apply_fun) pair,表示给定层序列的连续组成。layers是刚才定义的list,里面放刚才定义的层即可
stax.serial(*layers)

#我们甚至可以用新的list来接受stax.serial(*layers)返回值进行套娃
layers=[stax.serial(*layers)]
layers.append(stax.Dense(16))
layers.append(stax.Relu)

3.jax.numpy包

这个包方法和numpy类似

#创建一个数组
jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)

#对x进行深度为k的one_hot编码
jax.numpy.eye(k)[x]

#返回将对数组进行排序的索引,a为数组,axis为指定维度
jax.numpy.argsort(a, axis=- 1, kind='quicksort', order=None)

#计算指定维度的均值
jax.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)

#计算指定数组维度的和
jax.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)

#重复一个数组的元素
jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小怪兽会微笑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值