学习目标:
尽管pytorch可用来实现神经网络的传播,但如果要完成深度网络的搭建和训练,仍然比较麻烦。故pytorch提供了更高模块化的接口torch.nn。
- nn.Module类
- 损失函数
- 优化器nn.optim
nn.Module:
nn.Module是pytorch提供的神经网络类,在该类中实现了网络各层定义及前向和反向传播。使用中,只要继承该类,并初始化中定义模型结构和参数,在函数forward()中编写网络前向过程即可。以建立一个两个全连接层的的感知机为例:
import torch #导入torch模块
from torch import nn.Module #导入神经网络类模块
class Linear(nn.Module): #建立全连接子类,继承nn.Module
def __init__(self, in_dim, out_dim): #构造函数,输入、输出层参数
super().__init__() #super调用父类nn.Module的构造函数
'''
pytorch中的Parameter函数可以对某个张量进行参数化。
它可以将不可训练的张量转化为可训练的参数类型。
同时将转化后的张量绑定到模型可训练参数的列表中,当更新模型的参数时一并将其更新。
'''
self.w = nn.Parameter(torch.randn(in_dim, out_dim))
self.b = nn.Parameter(torch.randn(out_dim)) #使用Parameter函数构造要学习的参数
def forward(self, x): #实现 y = wx + b 前向传播
x = x.matmul(self, w) #使用Tensor.matmul实现矩阵相乘
y = x + self.b.expand_as(x) #使用Tensor.expand_as()保证矩阵形状一致
return y
class Perception(nn.Module): #构建感知机类,同样继承nn.Module,并调用Linear的子Module
def __init__(self, in_dim, hid_dim, out_dim):
super().__init__()
self.layer1 = Linear(in_dim, hid_dim) #定义第一层(输入节点——中间节点)全连接
self.layer2 = Linear(hid_dim, out_dim) #定义第二层(中间节点——输出节点)全连接
def forwad(self, x):
x = self.layer1(x) #x送入第一层全连接
y = torch.sigmoid(x) #使用激活函数得到y
y = self.layer2(y) #y送入第二层全连接
y = torch.sigmod(y) #使用激活函数
return y
编写完网络模型模块后,可以使用Perception调用该模块;
nn.Parameter函数:
- 在类的__init__()中需要定义网络学习参数,在此使用nn.Parameter()函数定义全连接中的w和b,使得w和b张量转化为可训练的参数类型,默认可求导
- pytorch中的Parameter函数可以对某个张量进行参数化。
它可以将不可训练的张量转化为可训练的参数类型。
同时将转化后的张量绑定到模型可训练参数的列表中,当更新模型的参数时一并将其更新。
forward()函数:
- forward()函数用来进行网络的前向传播,需要传入相应的tensor。
nn.functional库:
- 在pytorch中,还有一个nn.functional库。同样提供了很多网络层与函数功能。
- nn.functional定义的网络层不能自动学习参数,需要用nn.Parameter封装,主要针对激活层、BN层。
nn.Sequential()模块:
- 对于简单的上一层输出作为下一层输入的前馈网络,可以使用nn.Sequential()模块快速搭建模型。
from torch import nn
class Perception(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_dim, hid_dim),
nn.Sigmoid(),
nn.Linear(hid_dim, out_dim),
nn.Sigmoid()
)
def forward(self, x):
y = self.layer(x)
return y
- 实际上,nn.Linear(in_dim, out_dim, bias)函数是构造好的,输入参数为神经元个数和输出神经元个数,bias偏置默认为True。