【小白学习笔记】测试mobilenetv3网络如下错误:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
在测试打印网络参数情况时,用如下片段。发现是由于电脑有GPU,数据需要送入cuda()进行运算;如果电脑用CPU则正常。
if __name__ == '__main__':
#net = MobileNetV3_Small().train().cuda() #GPU训练
net = MobileNetV3_Small() #CPU训练
summary(net, (3, 224, 224))