问题描述
python 运行tensor的index_select()函数时,报错信息如下
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
解决办法
(1)定位错误出现代码行,如下
inputs = inputs.index_select(3, torch.arange(inputs.size(3) - 1, -1, -1)
(2)查看参数是否都在GPU/CPU上。
print(inputs.is_cuda,torch.arange(inputs.size(3) - 1, -1, -1).is_cuda)
(3)发现inputs在GPU,torch.arange(inputs.size(3) - 1, -1, -1)在cpu。
用以下代码替换错误代码即可。
inputs = inputs.index_select(3, torch.arange(inputs.size(3) - 1, -1, -1).to('cuda'))