PyTorch是一个开源的Python机器学习库,被广泛应用于深度学习领域。本文将介绍如何使用PyTorch实现一个简单的神经网络,并在MNIST数据集上进行训练和测试。
- 环境准备
首先需要安装PyTorch和相关的依赖库。可以通过以下命令安装PyTorch:
pip install torch torchvision
- 数据集准备
我们将使用MNIST手写数字数据集,这是一个非常经典的数据集,包含60,000个训练样本和10,000个测试样本。可以使用以下代码下载数据集:关注v❤公众H:Ai技术星球 回复(123)必领pytorch深度学习资料
- 数据预处理
在使用数据集之前,需要对其进行预处理。在本例中,我们将把每个数字图像缩放到28x28大小,并将像素值归一化到0到1之间。可以使用以下代码完成预处理:
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset.transform = trans