问题1:网上下载下来的代码,复现过程中出现AssertionError: Invalid device id
解决:源代码中调用了多个GPU进行加速运算。
if torch.cuda.is_available() and ngpu > 1:
model = nn.DataParallel(model, device_ids=list(range(ngpu)))
令device_ids =[0]
if torch.cuda.is_available():
model = nn.DataParallel(model, device_ids=[0])
未完待续!