源代码
from torchsummary import summary
from vggnet import VGGNet
myNet = VGGNet()
summary(myNet, (3, 28, 28))
# print(myNet)
运行报错信息
重点
说明我们应该把模型放到cuda上面去运行
修改
主要修改位置
myNet = VGGNet() =>myNet = VGGNet().cuda()
整体
from torchsummary import summary
from vggnet import VGGNet
myNet = VGGNet().cuda()
summary(myNet, (3, 28, 28))
# print(myNet)
运行成功,解决!
特别注意点!!!
summary(myNet, (3, 28, 28))这里的(3, 28, 28)里面的3是深度,一定要和定义的一样
比如:
这里我们输入的深度是1
所以应该是summary(myNet, (1, 28, 28))