When learning about neural networks you constantly run into the error that mat1 and mat2 cannot be multiplied. It almost always means that the input some layer expects does not match the output the previous layer produces. So how do you check this once the network structure has been defined?
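As a point of reference, here is a minimal sketch (my own, not from the original post) that reproduces the error with a deliberately mismatched nn.Linear layer:

import torch
import torch.nn as nn

layer = nn.Linear(100, 10)  # the layer expects 100 input features
x = torch.rand(1, 64)       # ...but the previous layer produced 64
layer(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x64 and 100x10)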
Take AlexNet as an example; the code is as follows:
import torch
import torch.nn as nn

# Two ways to implement AlexNet
class AlexNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 96, 11, 4, 1)
        self.maxpool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
        self.flatten = nn.Flatten(1)
        self.linear1 = nn.Linear(6400, 4096)
        self.linear2 = nn.Linear(4096, 4096)
        self.linear3 = nn.Linear(4096, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()

    def forward(self, input):
        h1 = self.relu(self.conv1(input))
        h1 = self.maxpool1(h1)
        h2 = self.relu(self.conv2(h1))
        h2 = self.maxpool1(h2)
        h3 = self.relu(self.conv3(h2))
        h4 = self.relu(self.conv4(h3))
        h5 = self.relu(self.conv5(h4))
        h5 = self.maxpool1(h5)
        h5 = self.flatten(h5)
        h6 = self.relu(self.linear1(h5))
        h7 = self.relu(self.linear2(h6))
        h7 = self.dropout(h7)
        h8 = self.linear3(h7)
        return h8

net = nn.Sequential(
    nn.Conv2d(1, 96, 11, 4, 1), nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(96, 256, 5, padding=2), nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(256, 384, 3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, 3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Flatten(1),
    nn.Linear(6400, 4096), nn.ReLU(),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 10), nn.ReLU(),
    nn.Dropout())
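Where does the 6400 in nn.Linear(6400, 4096) come from? Each Conv2d or MaxPool2d layer maps a spatial size n to floor((n + 2*padding - kernel) / stride) + 1. A quick sketch (a helper of my own, not part of the original code) traces a 224x224 input through the convolutional part:

def conv_out(n, kernel, stride=1, padding=0):
    # Conv2d / MaxPool2d output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * padding - kernel) // stride + 1

h = conv_out(224, 11, 4, 1)  # conv1   -> 54
h = conv_out(h, 3, 2)        # maxpool -> 26
h = conv_out(h, 5, 1, 2)     # conv2   -> 26
h = conv_out(h, 3, 2)        # maxpool -> 12
h = conv_out(h, 3, 1, 1)     # conv3, conv4, conv5 keep the size -> 12
h = conv_out(h, 3, 2)        # maxpool -> 5
print(256 * h * h)           # 256 channels * 5 * 5 = 6400, the input to linear1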
Method 1
The first method comes from Dive into Deep Learning (动手学深度学习): if the network is defined with nn.Sequential(), you can write a loop that feeds a dummy input through the layers one by one, printing the output shape at each step:
X = torch.rand((1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
The output looks like this:
Conv2d output shape: torch.Size([1, 96, 54, 54])
ReLU output shape: torch.Size([1, 96, 54, 54])
MaxPool2d output shape: torch.Size([1, 96, 26, 26])
Conv2d output shape: torch.Size([1, 256, 26, 26])
ReLU output shape: torch.Size([1, 256, 26, 26])
MaxPool2d output shape: torch.Size([1, 256, 12, 12])
Conv2d output shape: torch.Size([1, 384, 12, 12])
ReLU output shape: torch.Size([1, 384, 12, 12])
Conv2d output shape: torch.Size([1, 384, 12, 12])
ReLU output shape: torch.Size([1, 384, 12, 12])
Conv2d output shape: torch.Size([1, 256, 12, 12])
ReLU output shape: torch.Size([1, 256, 12, 12])
MaxPool2d output shape: torch.Size([1, 256, 5, 5])
Flatten output shape: torch.Size([1, 6400])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 4096])
ReLU output shape: torch.Size([1, 4096])
Linear output shape: torch.Size([1, 10])
ReLU output shape: torch.Size([1, 10])
Dropout output shape: torch.Size([1, 10])
But if you apply this method to a network defined as a class, it raises an error: a plain nn.Module is not iterable, so for layer in net fails with a TypeError.
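A quick illustration of the failure, assuming the AlexNet class defined above:

net = AlexNet()
X = torch.rand((1, 1, 224, 224))
for layer in net:  # TypeError: 'AlexNet' object is not iterable
    X = layer(X)

You might be tempted to iterate net.children() instead, but that only yields submodules in the order they were registered in __init__, not the order forward() actually applies them (relu and maxpool1 are reused several times here), so the printed shapes would be misleading.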
Hence method 2.
Method 2
Use the torchinfo package, which can print a summary of any network. The steps:
Install torchinfo (the package is only about 20 KB):
pip install torchinfo
Import the summary function from the package:
from torchinfo import summary
Pass in the network to inspect and the shape of the input (the leading dimension can be a batch size):
net = AlexNet()
summary(net, input_size=(1, 1, 224, 224))  # summary() prints the table itself; no extra print() needed
The result is shown below: it gives not only each layer's output shape but also its parameter count.
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
AlexNet                                  [1, 10]                   --
├─Conv2d: 1-1                            [1, 96, 54, 54]           11,712
├─ReLU: 1-2                              [1, 96, 54, 54]           --
├─MaxPool2d: 1-3                         [1, 96, 26, 26]           --
├─Conv2d: 1-4                            [1, 256, 26, 26]          614,656
├─ReLU: 1-5                              [1, 256, 26, 26]          --
├─MaxPool2d: 1-6                         [1, 256, 12, 12]          --
├─Conv2d: 1-7                            [1, 384, 12, 12]          885,120
├─ReLU: 1-8                              [1, 384, 12, 12]          --
├─Conv2d: 1-9                            [1, 384, 12, 12]          1,327,488
├─ReLU: 1-10                             [1, 384, 12, 12]          --
├─Conv2d: 1-11                           [1, 256, 12, 12]          884,992
├─ReLU: 1-12                             [1, 256, 12, 12]          --
├─MaxPool2d: 1-13                        [1, 256, 5, 5]            --
├─Flatten: 1-14                          [1, 6400]                 --
├─Linear: 1-15                           [1, 4096]                 26,218,496
├─ReLU: 1-16                             [1, 4096]                 --
├─Linear: 1-17                           [1, 4096]                 16,781,312
├─ReLU: 1-18                             [1, 4096]                 --
├─Dropout: 1-19                          [1, 4096]                 --
├─Linear: 1-20                           [1, 10]                   40,970
==========================================================================================
...
Forward/backward pass size (MB): 4.87
Params size (MB): 187.06
Estimated Total Size (MB): 192.13
In summary, either method lets you print the output shape of every layer in a network. Forward hooks can do the same thing even more flexibly, though they are not as simple as the two methods above.
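For completeness, here is a minimal forward-hook sketch (my own, assuming the AlexNet class from above). register_forward_hook calls the hook as hook(module, input, output) after each module finishes its forward pass:

def register_shape_hooks(model):
    # attach a hook to every leaf module that prints its output shape
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inputs, output, name=name):
                print(f'{name} ({mod.__class__.__name__}) output shape: {tuple(output.shape)}')
            handles.append(module.register_forward_hook(hook))
    return handles

net = AlexNet()
handles = register_shape_hooks(net)
net(torch.rand(1, 1, 224, 224))  # every leaf module prints as it runs
for h in handles:
    h.remove()  # detach the hooks once you are done debugging

Note one subtlety: reused modules such as relu and maxpool1 print once per invocation under the same name, which is exactly the flexibility hooks offer over iterating submodules.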