神经网络由对数据进行操作的层/模块组成。torch.nn提供了构建自己的神经网络所需的所有构建模块。 PyTorch中的每个模块都子类化nn.Module。 神经网络本身是由其他模块(layers)组成的模块。 这种嵌套结构允许轻松地构建和管理复杂的体系结构。
接下来的示例基于FashionMNIST数据集构建一个分类神经网络:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
判断训练的设备
在GPU可用的情况下使用GPU,torch.cuda.is_available()
>>> import torch
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> print(f"Using {device} device")
Using cuda device
定义神经网络的类
神经网络都继承自nn.Module,并且通过__init__函数来初始化神经网络中的层。此外,每一个nn.Module的子类都需要实现输入数据的forward方法,代表着神经网络的前向传播过程。
神经网络中的层
torch.nn.Flatten
在pytorch > 1.1.0之上支持。
Flatten层将每个2D 28x28图像转换为784个像素值的连续数组(保持小批量维度(dim=0))。
>>> input_image = torch.rand(3, 28, 28)
>>> flatten = nn.Flatten()
>>> flat_image = flatten(input_image)
>>> print(flat_image.size())
torch.Size([3, 784])
torch.nn.linear
CLASS torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
实现线性转换的层 y=xAT+b, 参数如下:
- in_features: 每个输入样本的size,可以看作是上一层的神经元个数,也是输入的特征数
- out_features: 每个输出样本的size,可以看作是下一层的神经元个数,也是输出的特征数
- bias:网络层是否有偏置,默认存在,且维度为[out_features ],若bias=False,则该网络层不会学习偏置
学习到的参数:
Linear.weight 学习到的线性变化参数w
Linear.bias有偏置情况下学习到的参数b
例子:
>>> import torch
>>> from torch import nn
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> input
tensor([[-0.9986, -0.3266, 1.5516, ..., -0.5078, 1.1729, -0.1926],
[-1.1963, 1.3566, 0.2359, ..., -0.9832, 0.1610, -0.3316],
[-1.3699, -0.1855, 0.4309, ..., -0.9447, 2.2891, -1.2312],
...,
[-0.5096, -0.5692, -1.0941, ..., -0.2755, 0.1662, -1.2074],
[ 1.6223, -1.5476, 1.6454, ..., 0.0478, -0.3027, -0.9315],
[-0.0033, -0.3463, 0.3735, ..., -1.4873, 0.2510, 1.4466]])
>>> output = m(input)
>>>> output.shape
torch.Size([128, 30])
输入的数据为128个样本,每个样本是一个20个数的向量。神经网络的线性层输入为20 ,输出为30。则output也是128个样本,每个样本是30个数的向量。这是一个增维的过程,即特征数增多。
>>> m.weight.shape
torch.Size([30, 20])
>>> m.bias.shape
torch.Size([30])
torch.nn.ReLU
非线性激活创建了模型输入和输出之间的复杂映射。 它们在线性变换后应用于引入非线性,帮助神经网络学习各种各样的现象。
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
torch.nn.Sequential
nn.Sequential 是modules的顺序容器,模块将按照在构造函数中传递的顺序被添加到容器中。 另外,模块的OrderedDict也可以被传入。 Sequential的forward()方法接受任何输入并将其转发到它包含的第一个模块。 然后,它将输出按顺序“链”到每个后续模块的输入,最后返回上一个模块的输出。