一、继承Module类来构造模型
Module类是nn
模块里提供的一个模型构造类(专门用来构造模型的),是所有神经网络模块的爸爸,我们可以继承它来定义我们想要的模型。
import torch
from torch import nn
class MLP(nn.Module):
def __init__(self, **kwargs):
# 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
# 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
super(MLP, self).__init__(**kwargs)
self.hidden = nn.Linear(784, 256) # 隐藏层
self.act = nn.ReLU()
self.output = nn.Linear(256, 10) # 输出层
# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
上面这个MLP就是继承了Module类,并且重写了Module类的__init__方法和forward方法。对于一个神经网络模型来说,确定这个模型有哪些层以及这些层之间的数据如何流动,那这个模型就确定了。所以MLP的__init__方法声明MLP模型有哪些类,MLP的forward方法规定数据如何在层之间流动。
部分初学者可能会想,那我还继承Module类干嘛,直接下面这样不行吗