只会一些python就能理解的pytorch入门
PyTorch提供设计精美的模块和类 torch.nn、 torch.optim , Dataset 和DataLoader 来帮助您创建和训练神经网络。
为了充分利用他们的力量并针对您的问题定制他们,您需要真正了解他们在做什么。
为了加深这一理解,我们将首先在MNIST数据集上训练基本神经网络,而不使用这些模型中的任何功能;最初我们将仅使用最基本的PyTorch张量功能。
然后,我们将一次递增地从torch.nn、torch.optim、DataSet或DataLoader添加一个特性,显示每个部分的确切功能,以及它是如何使代码更简洁或更灵活的。
本教程假设您已经安装了PyTorch,并且熟悉张量运算的基础知识。
(如果您熟悉Numpy数组操作,您会发现这里使用的PyTorch张量操作几乎相同)。
安装MNIST 数据集
我们将使用经典的MNIST数据集,它由手绘数字(介于0和9之间)的黑白图像组成。
我们将使用pathlib处理路径(Python3标准库的一部分),并使用请求下载数据集。我们只会在使用模块时导入它们,这样您就可以准确地看到每个点都在使用什么。
[懂pathlib可以点击跳过](# 正式开始)
为了方便,我先介绍一下pathlib是什么,不会很深入,只讲基本操作。你可能很熟悉os库,不过学会pathlib是一件好事,你可以简单的认为是os的升级版,好处可以自行了解。
pathlib的概念(不想看可以跳过):
pathlib由6个类组成,分别是"具体路径":PosixPath(非windows文件系统的类),Path,WindowsPath(windows文件系统的类),以及“纯路径”:PurePosixPath,PurePath,PureWindowsPath
具体路径是从纯路径继承而来的,具体路径里的父类都是path,windows和posix都继承了path。
也就是说具体路径里的类,Path继承了PurePath,WindowPath继承了Path和PureWindowsPath,PosixPath和WindowsPath类似。
纯路径只有计算操作,没有I/O操作,具体路径带上了I/O操作。
如果你不知道该使用什么类,可以总是信赖Path类,他会在你的在运行代码的平台上自动实例化为一个合适的具体路径类。
基本操作:
以下使用Ipython环境,即在cmd中输入python,或者使用JupyterNotebook等可以直接输入Ipython代码的开发环境。我可能会比较关注WindowsPath
创建path类对象
#输入:
from pathlib import Path #导入Path类
p = Path('.') #实例化Path类对象,之后他会根据你的运行代码平台自动实例化成两个子类之一
#运行后的输出:
WindowsPath('.') #我的开发环境就是Windows
cwd()
返回一个新的代表当前工作目录的路径对象:
#输入:
Path.cwd()
#运行后的输出:
WindowsPath('F:/anaconda/lib/site-packages/keras/layers') #这个路径就是你的运行代码文件的路径,我就是创建在这个奇怪的路径下面
home()
返回一个表示当前用户家目录的新路径对象:
#输入:
Path.home()
#运行后的输出:
WindowsPath('C:/Users/95436')
stat()
返回一个 os.stat_result
对象,其中包含有关此路径的信息
#输入:
from pathlib import Path
p = Path('.') #要注意这里必须是实例化的,不然他没办法返回信息
p.stat()
#运行后的输出:
os.stat_result(st_mode=16895, st_ino=12384898975271472, st_dev=910572341, st_nlink=1, st_uid=0, st_gid=0, st_size=32768, st_atime=1596443383, st_mtime=1596443377, st_ctime=1567087196) #这一段输出的含义可以点击os.stat_result那个超链接观看,就不赘述了
exists()
此路径是否指向一个已存在的文件或目录?返回一个bool
#输入:
Path('.').exists()
#运行后的输出
True
操作符’/’,这个可以直接用来创建子路径
#比如 a="123" b=a/"2" 这样的语法 str/str在 python里是不允许的,但是path类支持
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH
WindowsPath('data/mnist')
mkdir(mode=0o777, parents=False, exist_ok=False)
新建给定路径的目录。如果给出了 mode ,它将与当前进程的 umask
值合并来决定文件模式和访问标志。如果路径已经存在,则抛出 FileExistsError
。
如果 parents 为 true,任何找不到的父目录都会伴随着此路径被创建;它们会以默认权限被创建,而不考虑 mode 设置(模仿 POSIX 的 mkdir -p
命令)。
如果 parents 为 false(默认),则找不到的父级目录会导致 FileNotFoundError
被抛出。
如果 exist_ok 为 false(默认),则在目标已存在的情况下抛出 FileExistsError
。
如果 exist_ok 为 true, 则 FileExistsError
异常将被忽略(和 POSIX mkdir -p
命令行为相同),但是只有在最后一个路径组件不是现存的非目录文件时才生效。
正式开始
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist" #这里实例化了一个Path对象(“data/mnist”)
PATH.mkdir(parents=True, exist_ok=True) #在根目录下创建这一个路径 即为“./data/mnist”,而且允许父目录不存在(就是相对于最后一个/的内容来说,其他都是父目录,不存在一并创建),并且如果已经存在了这个路径,也无所谓。
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
if not (PATH / FILENAME).exists(): #如果这个路径不存在的话,request是个模拟http操作的库,有六种主要方法,可以看看北京理工大学的爬虫教学,里面讲的很仔细。
content = requests.get(URL + FILENAME).content #content是获取内容的方法
(PATH / FILENAME).open("wb").write(content) #接着通过路径打开文件,值得一提的是其实pathlib有内置的写入方法,不过也可以像这里用open再write来写
做完上面的操作之后,在你的文件管理器中应该已经能看到如下的mnist数据文件
这个数据集是numpy array格式的,而且是用一个python特地为序列化数据存储的格式包pickle所存储的。
pickle模块只能在python中使用,python中几乎所有的数据类型(列表,字典,集合,类等)都可以用pickle来序列化
import pickle
import gzip
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1") #从已打开的 file object 文件 中读取封存后的对象,重建其中特定对象的层次结构并返回。
每个图像的大小为28x28,并且存储为长度为784(=28x28)的平坦行(有点不太好描述,大概就是下面这样,就是图像本来应该是二维的,但是数组是一维的)。让我们看一下其中一个;我们需要首先将其重塑为2D。
#这是x_train[0]
array([0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01171875, 0.0703125 , 0.0703125 ,
0.0703125 , 0.4921875 , 0.53125 , 0.68359375, 0.1015625 ,
0.6484375 , 0.99609375, 0.96484375, 0.49609375, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.1171875 , 0.140625 , 0.3671875 , 0.6015625 ,
0.6640625 , 0.98828125, 0.98828125, 0.98828125, 0.98828125,
0.98828125, 0.87890625, 0.671875 , 0.98828125, 0.9453125 ,
0.76171875, 0.25 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.19140625, 0.9296875 ,
0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125,
0.98828125, 0.98828125, 0.98828125, 0.98046875, 0.36328125,
0.3203125 , 0.3203125 , 0.21875 , 0.15234375, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.0703125 , 0.85546875, 0.98828125, 0.98828125,
0.98828125, 0.98828125, 0.98828125, 0.7734375 , 0.7109375 ,
0.96484375, 0.94140625, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.3125 , 0.609375 , 0.41796875, 0.98828125, 0.98828125,
0.80078125, 0.04296875, 0. , 0.16796875, 0.6015625 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.0546875 ,
0.00390625, 0.6015625 , 0.98828125, 0.3515625 , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.54296875,
0.98828125, 0.7421875 , 0.0078125 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.04296875, 0.7421875 , 0.98828125,
0.2734375 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.13671875, 0.94140625, 0.87890625, 0.625 ,
0.421875 , 0.00390625, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.31640625, 0.9375 , 0.98828125, 0.98828125, 0.46484375,
0.09765625, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.17578125,
0.7265625 , 0.98828125, 0.98828125, 0.5859375 , 0.10546875,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.0625 , 0.36328125,
0.984375 , 0.98828125, 0.73046875, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.97265625, 0.98828125,
0.97265625, 0.25 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.1796875 , 0.5078125 ,
0.71484375, 0.98828125, 0.98828125, 0.80859375, 0.0078125 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.15234375,
0.578125 , 0.89453125, 0.98828125, 0.98828125, 0.98828125,
0.9765625 , 0.7109375 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.09375 , 0.4453125 , 0.86328125, 0.98828125, 0.98828125,
0.98828125, 0.98828125, 0.78515625, 0.3046875 , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.08984375, 0.2578125 , 0.83203125, 0.98828125,
0.98828125, 0.98828125, 0.98828125, 0.7734375 , 0.31640625,
0.0078125 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.0703125 , 0.66796875, 0.85546875,
0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.76171875,
0.3125 , 0.03515625, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.21484375, 0.671875 ,
0.8828125 , 0.98828125, 0.98828125, 0.98828125, 0.98828125,
0.953125 , 0.51953125, 0.04296875, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.53125 , 0.98828125, 0.98828125, 0.98828125,
0.828125 , 0.52734375, 0.515625 , 0.0625 , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. ], dtype=float32)
#y_train[0]
5
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray") #reshape可以告诉系统这是个28x28的矩阵,cmap是颜色映射,cmap的属性可以从这里挑选,这不是重点,不过可以看看https://matplotlib.org/tutorials/colors/colormaps.html
print(x_train.shape)#输出维度
(50000, 784)#这就是训练集的维度,有50000个图片,然后每个图片是784的长度
pytorch用的是torch.tensor,而不是numpy的array,所以我们得转换一下我们的数据格式
import torch
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
) #这个写法值得学习一下,map是一个(函数,迭代类型数据...)这样的函数,torch.tensor应该是构造函数,然后后面传入了这几个可迭代数据
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
既然讲到这里,我们就把torch.tensor一并介绍一下吧
torch.``tensor
(data, dtype=None, device=None, requires_grad=False, pin_memory=False) → Tensor
这个方法可以用 data
参数构建一个张量(Tensor)
可以是 list, tuple, NumPy ndarray
, scalar, and other types.
从头开始神经网络
让我们首先使用PyTorch张量操作创建一个模型。我们假设您已经熟悉了神经网络的基础知识。(如果您不是,您可以在Course.fast.ai(https://course.fast.ai).上学习它们。
PyTorch提供了创建随机或零填充张量的方法,我们将使用这些张量为简单的线性模型创建权重和偏差。这些只是正则张量,有一个非常特殊的补充:我们告诉PyTorch它们需要一个梯度。这会导致PyTorch记录对张量所做的所有操作,以便它可以在反向传播期间自动计算梯度!
对于权重,我们在初始化之后设置了REQUIRES_GRAD,因为我们不希望该步骤包含在梯度中。(Note that a trailling _
in PyTorch signifies that the operation is performed in-place.)【pytorch的签名后面如果跟着下划线,证明这是个in-place操作】
这个in-place操作有点复杂,如果想了解可以看看pytorch是怎么想的,大部分情况是不太建议使用的,如果我写出来会有点冗余,感兴趣可以点击这个学习一下https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd,大致上说就是他只在很少的情况下可以减轻内存负担,但有可能会引发错误,这一篇提到自动计算梯度的,pytorch是建议大家get familiar with it的
import math
weights = torch.randn(size=(784,10)) / math.sqrt(784) #randn是一个返回N~(0,1)的方法,这里为什么要除以一个根号784不太清楚,这样的话weights就是一个服从N~(0,1/784)的分布
weights.requires_grad_()
#这个requires_grad是一个属性,而带了下划线的是一个就地操作方法
bias = torch.zeros(10, requires_grad=True)#用0塞满这个tensor
多亏了PyTorch自动计算梯度的能力,我们可以使用任何标准Python函数(或可调用对象)作为模型!
所以让我们只编写一个简单的矩阵乘法和广播加法来创建一个简单的线性模型。
我们还需要一个激活函数,因此我们将编写log_softmax并使用它。
请记住:尽管PyTorch提供了大量预先编写的损失函数、激活函数等,但是您可以使用普通的python轻松编写自己的函数。
PyTorch甚至会自动为您的函数创建快速的GPU或矢量化的CPU代码。
log_softmax就是下图这样的操作,至于为什么多了个unsqueeze函数,是因为如果不做这一步,他会把一维给整合在一起,这样会有一个问题,就是x-的后面这一坨会比x少一个维度,比如x是两维的,后面这个ln出来的结果就是一维的,这样可能会引发很多问题,比如运算结果不对,甚至无法运行
x
i
−
l
n
(
∑
i
=
0
−
1
维
元
素
个
数
−
1
e
x
i
)
xi-ln(\sum^{-1维元素个数-1}_{i=0} e^{x_{i}})
xi−ln(i=0∑−1维元素个数−1exi)
def log_softmax(x):
return x - x.exp().sum(-1).log().unsqueeze(-1)#这个地方又是一个奇怪的语法糖,这里的几个函数都是pytorch的,你可以把exp(x)和x.exp()等价起来。然后因为x是一个二维的tensor,所以这个奇怪的sum的意思是指定在-1维度求和(行,列)-1维度就是按列求和 维度正整数是(0,1,2-----)
def model(xb):
return log_softmax(xb @ weights + bias) #这里就是wx+b @可以用来简单表示矩阵乘法,好像是python3.5+的语法糖
sub
(other, ***, alpha=1) → TensorSubtracts a scalar or tensor from
self
tensor. If bothalpha
and
other
are specified, each element ofother
is scaled byalpha
before being used.When
other
is a tensor, the shape ofother
must be
broadcastable
with the shape of the underlying tensor.这个的意思就是说,如果你是个标量,那无所谓,如果你是用tensor和tensor做减法,那减数必须对于被减数的shape来说是可传播的,你可以点击可broadcastable看看他的意思
简而言之,如果PyTorch操作支持广播,则其张量参数可以自动扩展为相同大小(无需你自己复制数据)。
如果以下规则成立,则两个张量是互相“可广播的”:
1、每个张量至少有一个维度。
2、在迭代维度大小时,从尾部维度开始,维度大小必须是相等、其中一个为1或其中一个不存在三种情况。