使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

本文通过实例展示了如何在JAX框架中,借助PyTorch的数据处理工具,构建并训练一个简单的MLP。介绍了JAX的自动微分和JIT特性,以及如何利用vmap和grad进行模型优化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

本文例程部分主要参考官方文档。

JAX简介

JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本,JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征子集进行区分,包括循环、分支、递归和闭包语句进行自动求导,也可以求三阶导数(三阶导数是由原函数导数的导数的导数。 所谓三阶导数,即原函数导数的导数的导数,将原函数进行三次求导)。通过 grad ,JAX 支持反向模式和正向模式的求导,而且这两种模式可以任意组合成任何顺序,具有一定灵活性。

另一个特点是基于 XLA 的 JIT 即时编译,大大提高速度。

需要注意的是,JAX 仅提供计算时的优化,相当于是一个支持自动微分和 JIT 编译的 NumPy。也就是说,数据处理 Dataloader 等其他框架都会提供的 utils 功能这里是没有的。所幸 JAX 可以比较好的支持 PyTorch、 TensorFlow 等主流框架的数据读取。本文就将基于 PyTorch 的数据读取工具和 JAX 框架来训练一个简单的神经网络

以下是国内优秀的机器学习框架 OneFlow 同名公司的创始人袁进辉老师在知乎上的一个评价:

如果说tensorflow 是主打lazy, 偏functional 的思想,但实现的臃肿面目可憎;pytorch 则主打eager, 偏imperative 编程,但内核简单,可视为支持gpu的numpy, 加上一个autograd。JAX 像是这俩框架的混合体,取了tensorflow的functional和PyTorch的精简,即支持gpu的 numpy, 具有autograd功能,非常追求函数式编程的思想,强调无状态,immutable,加上JIT修饰符后就是lazy,可以使用xla对计算流程进行静态分析和优化。当然JAX不带jit也可像pytorch那种命令式编程和eager执行。

JAX有可能和PyTorch竞争。

安装

安装可以通过源码编译,也可以直接 pip。源码编译详见[官方文档: Building from source][2],对于官方没有提供预编译包的 cuda-cudnn 版本组合,只能通过自己源码构建。pip的方式比较简单,在 github 仓库的 README 文档中就有介绍。要注意,不同于 PyTorch 等框架,JAX 不会再 pip 安装中绑定 CUDA 或 cuDNN 进行安装,若未安装,需要自己先手动安装。仅使用 CPU 的版本也有支持。

笔者是 CUDA11.1,CUDNN 8.2,安装如下:

pip install --upgrade pip
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

前面已经提到过,本文会借用 PyTorch 的数据处理工具,因此 torch 和 torchvision 也是必不可少的(已经安装的可跳过):

pip install torch torchvision

构建简单的神经网络训练

框架安装完毕,我们正式开始。接下来我们使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算,用 PyTorch 的数据加载 API 来加载图像和标签。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

超参数

# 本函数用来随机初始化网络权重
def random_layer_params(m, n, key, scale=1e-2):
	w_key, b_key = random.split(key)
	return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))

# 初始化各个全连接层
def init_network_params(sizes, key):
	keys = random.split(key, len(sizes))
	return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

自动分批次预测

对于小批量,我们稍后将使用 JAX 的 vmap 函数来自动处理,而不会降低性能。我们现在先准备一个单张图像推理预测函数:

from jax.scipy.special import logsumexp

def relu(x):
	return jnp.maximum(0, x)

# 对单张图像进行推理的函数
def predict(params, image):
	activations = image
	for w, b in params[: -1]:
		outputs = jnp.dot(w, activations) + b
		activations = relu(outputs)
	
	final_w, final_b = params[-1]
	logits = jnp.dot(final_w, activations) + final_b
	return logits - logsumexp(logits)

这个函数应该只能用来处理单张图像推理预测,而不能批量处理,我们简单测试一下,对于单张:

random_flattened_images = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_images)
print(preds.shape)

输出:

(10,)

对于批次:

random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
	preds = predict(params, random_flattened_images)
except TypeError:
	print('Invalid shapes!')

输出:

Invalid shapes!

现在我们使用 vmap 来使它能够处理批量数据:

# 用 vmap 来实现一个批量版本
batched_predict = vmap(predict, in_axes=(None, 0))

# batched_predict 的调用与 predict 相同
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

输出:

(10, 10)

现在,我们已经做好了准备工作,接下来就是要定义一个神经网络并且进行训练了,我们已经构建了的自动批处理版本的 predict 函数,并且将在损失函数中也使用它。我们将使用 grad 来得到损失关于神经网络参数的导数。而且,这一切都可以用 jit 进行加速。

实用工具函数和损失函数

def one_hot(x, k, dtype=jnp.float32):
	"""构建一个 x 的 k 维 one-hot 编码."""
	return jnp.array(x[:, None] == jnp.arange(k), dtype)


def accuracy(params, images, targets):
	target_class = jnp.argmax(targets, axis=1)
	predicted_class =  jnp.argmax(batched_predict(params, images), axis=1)
	return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
	preds = batched_predict(params, images)
	return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
	grads = grad(loss)(params, x, y)
	return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

使用 PyTorch 进行数据读取

JAX 是一个专注于程序转换和支持加速的 NumPy,对于数据的读取,已经有很多优秀的工具了,这里我们就直接用 PyTorch 的 API。我们会做一个小的 shim 来使得它能够支持 NumPy 数组。

import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
	if isinstance(batch[0], np.ndarray):
		return np.stack(batch)
	elif isinstance(batch[0], (tuple, list)):
		transposed = zip(*batch)
		return [numpy_collate(samples) for samples in transposed]
	else:
		return np.array(batch)

class NumpyLoader(data.DataLoader):
	def __init__(self, dataset, batch_size=1,
					shuffle=False, sampler=None,
					batch_sampler=None, num_workers=0,
					pin_memory=False, drop_last=False,
					timeout=0, worker_init_fn=None):
		super(self.__class__, self).__init__(dataset,
					batch_size=batch_size,
					shuffle=shuffle,
					sampler=sampler,
					batch_sampler=batch_sampler,
					collate_fn=numpy_collate,
					num_workers=num_workers,
					pin_memory=pin_memory,
					drop_last=drop_last,
					timeout=timeout,
					worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
	def __call__(self, pic):
		return np.ravel(np.array(pic, dtype=jnp.float32))

接下来借助 PyTorch 的 datasets,定义我们自己的 dataset:

mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

此处应该输出一堆下载 MNIST 数据集的信息,就不贴了。

接下来分别拿到整个训练集和整个测试集,下面会用于测准确率:

train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

开始训练

import time
for epoch in range(num_epochs):
	start_time = time.time()
	for x, y in training_generator:
		y = one_hot(y, n_targets)
		params = update(params, x, y)
	epoch_time = time.time() - start_time

	train_acc = accuracy(params, train_images, train_labels)
	test_acc = accuracy(params, test_images, test_labels)
	print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
	print("Training set accuracy {}".format(train_acc))
	print("Test set accuracy {}".format(test_acc))

输出:

Epoch 0 in 3.29 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9196999669075012
...
Epoch 7 in 1.78 sec
Training set accuracy 0.9736666679382324
Test set accuracy 0.9670999646186829

在本文的过程中,我们已经使用了整个 JAX API:grad 用于自动微分、jit 用于加速、vmap 用于自动矢量化。我们使用 NumPy 来进行我们所有的计算,并从 PyTorch 借用了出色的数据加载器,并在 GPU 上运行了整个过程。

Ref:

https://juejin.cn/post/6994695537316331556

https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

https://jax.readthedocs.io/en/latest/developer.html#building-from-source

### Safetensors 的多框架兼容性与 Transformers 模型加载 Safetensors 是一种高效的文件格式,用于存储深度学习模型的权重。它具有安全性、高性能以及良好的跨框架兼容性特点[^1]。 #### 1. **Safetensors 文件的安全性和高效性** Safetensors 提供了一种安全的方式来保存和读取模型权重,避免了传统 Pickle 方法可能带来的安全隐患。通过其零拷贝特性,可以显著提升模型加载的速度,这对于大规模模型尤为重要。 以下是使用 `safetensors` 加载部分张量的一个典型例子: ```python from safetensors import safe_open tensors = {} with safe_open("model.safetensors", framework="pt", device=0) as f: tensor_slice = f.get_slice("embedding") vocab_size, hidden_dim = tensor_slice.get_shape() tensor = tensor_slice[:vocab_size, :hidden_dim] ``` 上述代码展示了如何仅加载特定的部分张量(如嵌入层),从而节省内存资源并加速初始化过程。 --- #### 2. **Transformers 库中的集成** Hugging Face 的 Transformers 库广泛支持各种主流的大规模预训练模型,并提供了对 `safetensors` 格式的内置支持。这意味着可以直接利用 `.safetensors` 文件作为模型权重的来源,而无需额外转换为其他格式。 要加载基于 `safetensors` 存储的 Transformer 模型,可以通过如下方式完成: ```python from transformers import AutoModelForCausalLM, AutoTokenizer # 初始化 tokenizer 和 model tokenizer = AutoTokenizer.from_pretrained("path/to/model", use_safetensors=True) model = AutoModelForCausalLM.from_pretrained( "path/to/model", torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True ) # 测试推理功能 input_text = "Hello world!" inputs = tokenizer(input_text, return_tensors="pt").to(model.device) outputs = model.generate(**inputs) print(tokenizer.decode(outputs[0])) ``` 在此过程中,参数 `use_safetensors=True` 明确指定了优先尝试从 `.safetensors` 文件中加载权重。如果该路径不存在对应的 `.safetensors` 文件,则会回退到传统的 PyTorch 权重文件(`.bin`)。此设计增强了灵活性,同时也保留了向后兼容的能力[^3]。 --- #### 3. **多框架的支持能力** 为了适应不同的深度学习生态系统需求,Safetensors 被设计为能够轻松适配多个框架环境。例如,在 TensorFlow 或 JAX 中也可以采用类似的逻辑操作这些二进制数据结构。具体而言,开发者只需调整 `framework` 参数即可切换目标运行时平台: ```python import tensorflow as tf from safetensors.tensorflow import load_file weights = load_file("model.safetensors", tf_device="/GPU:0") for name, value in weights.items(): print(f"{name}: {value.shape}") ``` 这段代码片段演示了在 TensorFlow 下加载相同格式的数据流程序列化对象的方法。 此外,对于那些专注于轻量化部署场景的应用场合来说,诸如 vLLM 这样的优化框架同样推荐配合 `safetensors` 实现更优的整体表现[^2]。 --- ### 总结 综上所述,借助于 Safetensors 所具备的强大特性和易用接口,无论是单机实验还是分布式生产环境中都可以便捷地处理复杂的神经网络架构及其关联组件之间的交互关系。同时,由于其天然契合现代机器学习工作流程的设计理念,使得像 HuggingFace Transformers 等工具链得以无缝衔接其中,进一步促进了整个行业的标准化进程与发展步伐。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值