torch.summary是pytorch的一个包,可以打印模型的每一层组成,参数量,总的参数量。
使用summary函数,首先安装torchsummary包:
pip install torchsummary
然后导入包:
from torchsummary import summary
接着在运行时进行计算:
summary(model, input_size, batch_size, device)
遇到的一些bug:
1:出现这种错误是因为输入维度高,即不需要指定batchsize,默认为-1.
2:出现这种情况是因为输入格式错误,当有多个输入要用 [ ],将输入括起来。在官方文档中有案例。
注意:当多个输入没有报错,他运行的的结果也是错误的。
他的inputsize和total非常大,需要对torchsummary源码进行修改,
原始未修改的:
修改后的:
修改过程参考:修改流程
首先:加上如下代码找到源码位置
import torchsummary
print(torchsummary.__file__)
然后:根据目录一步一步寻找,在_init_.py同级目录下的torchsummary.py文件中
/home/xh/.local/lib/python3.7/site-packages/torchsummary/__init__.py
最后:将第一百行:
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
替换为:
total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
修改完后注意保存。
3:当出现这种情况是因为没有指定device,加上device='cpu'就是正确的。