手撸代码
keras有个model.summary() 查看网络结构参数
pytorch没有
一般用print(model)查看网络结构
然后自己手写一个方法统计权重参数量
def CustomCal(net):
res = 0
for i in net:
weight = i.weight.shape
bias = i.bias.shape
tmp=1
for j in weight:
tmp *= j
res+=tmp
res+=bias[0]
return res
利用pytorch-summary
有大佬针对这个问题专门写了一个包去实现类似的功能
pytorch-summary
直接pip install torchsummary 安装
然后导入使用就行了
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
---------------------------</