在Linux中使用显卡训练网络时,一般会通过device id来确定使用的显卡。我们从GitHub上获取的源码中的device id和我们本地的device id肯定不一致,所以训练时一定要注意device id修改。
以下示例:
源码:
model = nn.DataParallel(
model.cuda(), device_ids=[0,1]
源码中使用了id为0和1 的显卡进行训练。
本地训练报错:
AssertionError: Invalid device id
本地显卡指示:
CUDA Device count: 1
本地只有一个显卡,代码中带入了2个id,这时候肯定会报错。修改代码如下:
model = nn.DataParallel(
model.cuda(), device_ids=[0]
注意:
在使用多显卡进行训练时,一定要注意显卡id设置。如遇问题可以参考: