nn.Module是Pytorch提供的神经网络类,在类中实现了网络各层的定义及向前计算与反向传播机制。
实现步骤:
1.继承nn.Module
2.在初始化中定义模型结构与参数
3.在函数forward()中编写网络前向过程
from torch import nn
class Perception(nn.Module):
def __init__(self,in_dim,hid_dim,out_dim):
super(Perception,self).__init__()
self.layer1=Linear(in_dim,hid_dim)
self.layer2=Linear(hid_dim,out_dim)
def forward(self,x):
x=self.layer1(x)
y=torch.sigmoid(x)#激活函数
y=self.layer2(y)
y=torch.sigmoid(y)
return y
对比使用nn.Sequential():
当模型中只是简单的前馈网络(上一层的输出直接作为下一层的输入)时,可使用nn.nn.Sequential()快速搭建
from torch import nn
class Perception(nn.Module):
def __init__(self,in_dim,hid_dim,out_dim):
super(Perception,self).__init__()
self.layer=nn.Sequential(nn.Linear(in_dim,hid_dim),nn.Sigmoid(),nn.Linear(hid_dim,out_dim),nn.Sigmoid())
def forward(self,x):
y=self.layer(x)
return y
参考书籍:《深度学习之PyTorch物理检测实战》