Original code:
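import torch.nn as nn
net = ModelMain(config, is_training=is_training)
net.train(is_training)        # set train/eval mode
net = nn.DataParallel(net)    # wrapping adds a "module." prefix to every parameter name
net = net.cuda()              # move the model to the GPU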
Error raised when loading the checkpoint:
Error(s) in loading state_dict for ModelMain
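Cause:
This is the CUDA build of torch, and removing the following two lines triggers the error:
net = nn.DataParallel(net)
net = net.cuda()
The checkpoint was most likely saved from the DataParallel-wrapped model, so every key in it starts with "module."; without the wrapper the model's keys lack that prefix, and load_state_dict reports the mismatch.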
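If you want to drop the DataParallel wrapper anyway (for example to run on a single GPU or the CPU), one common workaround is to rename the checkpoint keys before loading. The sketch below assumes the checkpoint is a plain state_dict whose keys carry the "module." prefix; the helper name adapt_state_dict_keys is hypothetical and not part of PyTorch.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

# Prefix that nn.DataParallel puts in front of every parameter/buffer name,
# because it stores the wrapped model as its `module` attribute.
PREFIX = "module."


def adapt_state_dict_keys(state_dict: Mapping[str, V], wrapped: bool) -> Dict[str, V]:
    """Hypothetical helper: return a copy of state_dict whose keys fit a model
    that is (wrapped=True) or is not (wrapped=False) wrapped in nn.DataParallel."""
    fixed: Dict[str, V] = {}
    for key, value in state_dict.items():
        has_prefix = key.startswith(PREFIX)
        if wrapped and not has_prefix:
            key = PREFIX + key          # plain checkpoint -> DataParallel model
        elif not wrapped and has_prefix:
            key = key[len(PREFIX):]     # DataParallel checkpoint -> plain model
        fixed[key] = value              # insertion order is preserved
    return fixed


if __name__ == "__main__":
    ckpt = {"module.conv.weight": 1, "module.fc.bias": 2}
    print(list(adapt_state_dict_keys(ckpt, wrapped=False)))  # ['conv.weight', 'fc.bias']
```

With torch, usage would look like `net.load_state_dict(adapt_state_dict_keys(torch.load(checkpoint_path, map_location="cpu"), wrapped=False))`, where checkpoint_path is a placeholder. An alternative is to avoid the prefix at save time by saving `net.module.state_dict()` instead of `net.state_dict()`.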