torch.nn.Module是所有神经网络模块的基本类,所有神经网络都需要继承实现它。
这是官网给出的一个示例,首先继承父类nn.Module并初始化,然后重载前向传播foward方法。
下面操作一个简单的例子:
class Model(nn.Module):
def __int__(self):
super().__init__()
def forward(self, x):
return x + 1
model = Model()
output = model(1)
print(output)
结果比较简单:
forward函数实现了__call__方法,因此直接传入参数便调用了forward方法