CNN中经常需要考虑到网络的参数量与计算量等问题,具体的计算方法为:
其中,K是卷积核的大小,Cin核Cout表示输入与输出的通道数,H与W表示特征图的大小。
此外,可以通过python中的stat模块计算与验证,代码:
import torch
import torch.nn as nn
from torchstat import stat
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Sequential()
self.conv.add_module("conv", nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, bias=False, padding=1))
self.conv.add_module("bn", nn.BatchNorm2d(64))
# self.conv.add_module("")
self.fc = nn.Linear(64*224*224,100, bias=True)
def forward(self, x):
x = self.conv(x)
x = x.view(-1, 64*224*224)
x = self.fc(x)
return x
model = Net()
stat(model, (3, 224, 224))
输出为:
注意:torchstst的版本需要为0.0.6,安装方法: pip install torchstat==0.0.6