模型的定义
首先,必须继承 nn.Module 类。
其次,在__init__(self)中设置好需要的“组件"(如 conv、 pooling、 Linear、 BatchNorm等)。
最后,在 forward(self, x)中用定义好的“组件”进行组装。
class Net(nn.Module):
def __init__(self): # 初始化,定义组件
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84,