刚刚,Keras 3.0正式发布!经过5个月的公开Beta测试,深度学习框架Keras 3.0终于面向所有开发者推出。全新的Keras3对Keras代码库进行了完全重写,可以在JAX、TensorFlow和PyTorch上运行,能够解锁全新大模型训练和部署的新功能。
重磅消息:我们刚刚发布了 Keras 3.0!
在 JAX、TensorFlow 和 PyTorch 上运行 Keras使用 XLA 编译更快地训练通过新的 Keras 分发 API 解锁任意数量的设备和主机的训练运行
它现在在 PyPI 上上线
Keras 3.0新特性
1、始终为模型获得最佳性能
它能够动态选择为模型提供最佳性能的后端,而无需对代码进行任何更改,这意味着开发者可以以最高效率进行训练和服务。
2、为模型解锁生态系统可选性
可以将Keras 3模型与PyTorch生态系统包,全系列TensorFlow部署和生产工具(如TF-Serving,TF.js和TFLite)以及JAX大规模TPU训练基础架构一起使用。使用 Keras 3 API 编写一个 model.py ,即可访问 ML 世界提供的一切。
3、利用JAX的大规模模型并行性和数据并行性
通过它,可以在任意模型尺度和聚类尺度上轻松实现模型并行、数据并行以及两者的组合。由于它能将模型定义、训练逻辑和分片配置相互分离,因此使分发工作流易于开发和维护。
4、使用来自任何来源的数据管道
Keras 3 / fit() / evaluate() predict() 例程与 tf.data.Dataset 对象、PyTorch DataLoader 对象、NumPy 数组、Pandas 数据帧兼容——无论你使用什么后端。
可以在 PyTorch DataLoader 上训练 Keras 3 + TensorFlow 模型,也可以在tf.data.Dataset上训练Keras 3 + PyTorch模型。预训练模型现在,开发者即可开始使用Keras 3的各种预训练模型。所有40个Keras应用程序模型( keras.applications 命名空间)在所有后端都可用。
KerasCV和KerasNLP中的大量预训练模型也适用于所有后端。其中包括:BERT、 OPT、Whisper、 T5、 Stable Diffusion、 YOLOv8跨框架开发Keras 3能够让开发者创建在任何框架中都相同的组件(如任意自定义层或预训练模型),它允许访问适用于所有后端的 keras.ops 命名空间。Keras 3包含NumPy API的完整实现,——不是「类似 NumPy」,而是真正意义上的 NumPy API,具有相同的函数和参数。比如 ops.matmul、ops.sum、ops.stack、ops.einsum 等函数。
Keras 3还包含NumPy中没有的,一组特定于神经网络的函数,例如 ops.softmax, ops.binary_crossentropy, ops.conv等。另外,只要开发者使用的运算,全部来自于keras.ops ,那么自定义的层、损失函数、优化器就可以跨越JAX、PyTorch和TensorFlow,使用相同的代码。开发者只需要维护一个组件实现,就可以在所有框架中使用它。