安装工具包
pip install torchstat
Module
from torchstat import stat
import torchvision.models as models
model = models.resnet18()
stat(model, (3, 224, 224))
Debug
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
原因分析
错误内容就在类型不匹配,根据报错内容看出输入类型为CPU 类型,而权重类型为GPU类型
解决方案
既然网络参数是GPU类型,那解决方法就是将输入类型转变为GPU类型,需要使用到cuda,没有cuda就解决不了。
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
inputs = inputs.to(device)
# 方法一:将input这个tensor转换成了CUDA 类型
inputs = inputs.cuda()
# 方法二:将input这个tensor转换成了CUDA 类型