1、用torchsummary打印pytorch模型参数信息,于是编写以下代码
from torchsummary import summary
summary(UNET, input_size=(3, 256, 256))
然后出现以下错误:
TypeError: apply() missing 1 required positional argument: 'fn'
2、修改之后
summary(UNET(nn.Module), input_size=(3, 256, 256))
出现以下错误:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
3、然后修改如下解决问题
①summary(UNET(nn.Module).cuda(), input_size=(3, 256, 256))
②device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNET(nn.Module).to(device)
summary(model, input_size=(3, 256, 256))