最近在写论文需要对比网络大小,平时可能用不到,做个笔记。
1、利用工具,torchsummary
输入安装命令:pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
加入清华镜像源快一点
import torch
from torchsummary import summary
from nets.yolo import YoloBody
if __name__ == "__main__":
# 需要使用device来指定网络在GPU还是CPU运行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
m = YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 80, backbone = "ghostnet").to(device)
summary(m, input_size=(3, 416, 416))
然后使用summary(model,input_size)
,运行即可查看网络大小。
但是该工具用在fasterrcnn却运行不了,自己也不知道为什么?于是自己书写代码查看网络总参数。
2、自己书写
其实就两句话 for循环读取参数,sum求和。把求和之后的网络总数打印出来即可。
from nets.yolo import YoloBody
total = sum([param.nelement() for param in YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 80, backbone = "ghostnet").parameters()])
print("Total params: %d" % total)
打印之后和用torchsummary的一样。