import torch.nn as nn


class LeNet(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, strides):
        super(LeNet, self).__init__()
        self.conv = nn.Sequential(
            # 3x32x32 -> 6x28x28, k = 5, s = 1
            nn.Conv2d(in_channels=input_channels[0], out_channels=output_channels[0],
                      kernel_size=kernel_size[0], stride=strides[0]),
            nn.ReLU(),
            # 6x28x28 -> 6x14x14, k = 2, s = 2
            nn.MaxPool2d(kernel_size=kernel_size[1], stride=strides[1]),
            # 6x14x14 -> 16x10x10, k = 5, s = 1
            nn.Conv2d(in_channels=input_channels[1], out_channels=output_channels[1],
                      kernel_size=kernel_size[2], stride=strides[2]),
            nn.ReLU(),
            # 16x10x10 -> 16x5x5, k = 2, s = 2
            nn.MaxPool2d(kernel_size=kernel_size[3], stride=strides[3])
        )
        self.fc = nn.Sequential(
            # 16*5*5 = 400 -> 120
            nn.Linear(in_features=input_channels[2], out_features=output_channels[2], bias=True),
            nn.ReLU(),
            # 120 -> 84
            nn.Linear(in_features=input_channels[3], out_features=output_channels[3], bias=True),
            nn.ReLU(),
            # 84 -> 10; Softmax only on the final layer
            nn.Linear(in_features=input_channels[4], out_features=output_channels[4], bias=True),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.conv(x)                # convolutional feature extractor
        x = x.reshape(x.size(0), -1)    # flatten to (batch, 16*5*5)
        x = self.fc(x)                  # fully connected classifier head
        return x
The cause was that every fully connected layer used a SoftMax activation function. The fix is to replace the activation of the intermediate fully connected layers with ReLU (as in the code above), so that only the final fully connected layer uses SoftMax.
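To make the constructor arguments concrete, here is a minimal usage sketch; the parameter lists are assumptions inferred from the shape comments above (a 3x32x32 input, as in CIFAR-10), not values given in the original code.

import torch

# Hypothetical parameter lists, chosen to reproduce the commented shapes.
model = LeNet(
    input_channels=[3, 6, 16 * 5 * 5, 120, 84],   # in_channels / in_features per layer
    output_channels=[6, 16, 120, 84, 10],         # out_channels / out_features per layer
    kernel_size=[5, 2, 5, 2],                     # conv1, pool1, conv2, pool2
    strides=[1, 2, 1, 2],
)

x = torch.randn(4, 3, 32, 32)   # a batch of 4 random 3x32x32 images
out = model(x)
print(out.shape)                # torch.Size([4, 10])
print(out.sum(dim=1))           # each row sums to 1 because of the final Softmax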