方法1.使用summary
test.py正确代码如下:
import torch
from torchsummary import summary
from nets.yolo4 import YoloBody
if __name__ == "__main__":
# 需要使用device来指定网络在GPU还是CPU运行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YoloBody(3,80,backbone="mobilenetv2",phi=0).to(device)
summary(model, input_size=(3, 416, 416))
输出网络结构和参数量,以及可训练参数量
方法2.使用torchstat
遇到的问题:RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
问题分析:输入在cpu,模型指定在gpu,应该一致
解决方法:都放到cpu上
test.py正确代码如下:
import torch
from torchsummary impor