解决方法:
load_state_dict(torch.load('net.pth')在前,增加
model = nn.DataParallel(model)
就可以了。
比如
net = NET()
net.cuda()
net = nn.DataParallel(net)
net.load_state_dict(torch.load('net.pth')
解决方法:
load_state_dict(torch.load('net.pth')在前,增加
model = nn.DataParallel(model)
就可以了。
比如
net = NET()
net.cuda()
net = nn.DataParallel(net)
net.load_state_dict(torch.load('net.pth')