本文翻译自:https://pytorch.org/tutorials/beginner/nn_tutorial.html 这是第一部分:如何用纯Python构造一个神经网络。
Pytorch提供了几个设计得非常棒的模块和类,比如 torch.nn,torch.optim,Dataset 以及 DataLoader,来帮助你设计和训练神经网络。为了充分利用他们来解决你的问题,你需要明白他们具体是做什么的。为了帮助大家理解这些内容,我们首先基于MNIST数据集,不用以上提到的模块和类来训练一个基础的神经网络,只用到基本的PyTorch tensor 函数。然后我们会逐渐地使用来自torch.nn,torch.optim,Dataset 以及 DataLoader的功能。展示每个模块具体的功能,他的运作过程。这样来使得代码逐渐简洁和灵活。
MNIST数据集
我们使用经典的MNIST数据集。这个数据集由手写数字0~9的黑白照片组成。
我们使用pathlib包来解决文件路径问题(这是python3的一个标准包),我们使用requests函数来下载数据。我们只导入我们需要用到的模块,所以你可以清晰地看到每一步具体用了什么。
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
if not (PATH / FILENAME).exists():
content = requests.get(URL + FILENAME).content
(PATH / FILENAME).open("wb").write(content)
数据现在是numpy array格式,用pickle存储起来了。这是Python特有的串行数据格式。
import pickle
import gzip
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
((x_tr