Penzai:来自DeepMind的JAX模型构建和可视化工具库

前言

JAX是继TensorFlow之后,谷歌近年推出的一个较新的深度学习框架,具有计算速度快、支持大规模GPU集群等优点。JAX的生态近年来不断丰富,在JAX工具库方面已经包括有Flax、Equinox、Keras等不少项目。

Penzai(“盆栽”)[1][2]是近期由来自谷歌DeepMind的作者推出的一款新的JAX工具库。Penzai可以用来对已创建的JAX模型进行可视化展示、反向工程、消融分析和改进,也能用来创建新的JAX模型。

由Penzai创建的模型以及模型的可视化展示是什么样子的呢?接下来【算AI】小编就通过示例来初步介绍一下。

示例一:模型的创建及可视化展示

【算AI】小编首先使用Penzai创建了一个简单的MLP(多层感知器),然后使用Penzai在Jupyter中输出了该MLP模型的结构。该模型结构的输出如下。

点击输出结果中的右箭头,可以展开更多的内容,例如参数、所包含的层等,如下图所示。

更多的功能和示例

Penzai的官方文档[2]中介绍了Penzai的更多功能和示例,例如:

  • 如何基于Penzai和LoRA修改已有的模型;
  • 如何基于Penzai从零开始构建Gemma 7B模型;
  • Penzai自带的神经网络库,类似于Flax、Haiku、Keras、Equinox等包含的神经网络库,用于搭建、编辑神经网络模型;
  • 更多的交互式的、彩色的模型和数据可视化功能,等等。

以下再通过示例初步介绍一下Penzai对于多维数组的可视化展示。

示例二:N维数组的可视化展示

Penzai的可视化功能可以用来展示任意维度的数组。例如,下图是一个二维数组的可视化展示:

以下是一个四维数组的可视化展示:

在上图中,3大行和4大列分别用来表示第1和第2维度,每个5乘6的小方格以及其中的颜色用来表示第3和第4维度、以及该四维数组中每个元素的值。

上述的四维数组也可以这样来展示:

在上图中,3个横向间距较宽的分组表示第1维度,每4个横向间距较小的列表示第2维度,每个5乘6的小方格以及其中的颜色依然表示第3和第4维度、以及该四维数组中每个元素的值。

Penzai的安装

安装Penzai的过程比较简单。首先安装JAX,同时需要确保Python的版本至少是3.10,然后执行pip install penzai就可以了。

官方的Getting Started文档中的示例代码遗漏了一行import jax命令,运行前自己加上就可以。

其它信息

Penzai目前还是0.1.x版本,未来在功能、接口等方面可能会有变化。另外,尽管Penzai由谷歌员工创建,但Penzai并非谷歌的官方产品。

Penzai的授权协议采用的是Apache 2.0。

参考资料

[1] https://github.com/google-deepmind/penzai
[2] https://penzai.readthedocs.io/en/stable/

  • 25
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

算AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值