前言
JAX是继TensorFlow之后,谷歌近年推出的一个较新的深度学习框架,具有计算速度快、支持大规模GPU集群等优点。JAX的生态近年来不断丰富,在JAX工具库方面已经包括有Flax、Equinox、Keras等不少项目。
Penzai(“盆栽”)[1][2]是近期由来自谷歌DeepMind的作者推出的一款新的JAX工具库。Penzai可以用来对已创建的JAX模型进行可视化展示、反向工程、消融分析和改进,也能用来创建新的JAX模型。
由Penzai创建的模型以及模型的可视化展示是什么样子的呢?接下来【算AI】小编就通过示例来初步介绍一下。
示例一:模型的创建及可视化展示
【算AI】小编首先使用Penzai创建了一个简单的MLP(多层感知器),然后使用Penzai在Jupyter中输出了该MLP模型的结构。该模型结构的输出如下。
点击输出结果中的右箭头,可以展开更多的内容,例如参数、所包含的层等,如下图所示。
更多的功能和示例
Penzai的官方文档[2]中介绍了Penzai的更多功能和示例,例如:
- 如何基于Penzai和LoRA修改已有的模型;
- 如何基于Penzai从零开始构建Gemma 7B模型;
- Penzai自带的神经网络库,类似于Flax、Haiku、Keras、Equinox等包含的神经网络库,用于搭建、编辑神经网络模型;
- 更多的交互式的、彩色的模型和数据可视化功能,等等。
以下再通过示例初步介绍一下Penzai对于多维数组的可视化展示。
示例二:N维数组的可视化展示
Penzai的可视化功能可以用来展示任意维度的数组。例如,下图是一个二维数组的可视化展示:
以下是一个四维数组的可视化展示:
在上图中,3大行和4大列分别用来表示第1和第2维度,每个5乘6的小方格以及其中的颜色用来表示第3和第4维度、以及该四维数组中每个元素的值。
上述的四维数组也可以这样来展示:
在上图中,3个横向间距较宽的分组表示第1维度,每4个横向间距较小的列表示第2维度,每个5乘6的小方格以及其中的颜色依然表示第3和第4维度、以及该四维数组中每个元素的值。
Penzai的安装
安装Penzai的过程比较简单。首先安装JAX,同时需要确保Python的版本至少是3.10,然后执行pip install penzai就可以了。
官方的Getting Started文档中的示例代码遗漏了一行import jax命令,运行前自己加上就可以。
其它信息
Penzai目前还是0.1.x版本,未来在功能、接口等方面可能会有变化。另外,尽管Penzai由谷歌员工创建,但Penzai并非谷歌的官方产品。
Penzai的授权协议采用的是Apache 2.0。
参考资料
[1] https://github.com/google-deepmind/penzai
[2] https://penzai.readthedocs.io/en/stable/