具体原因:
windows下不支持函数 torch.cuda.set_device(args.gpu)
,在linux下支持。
因此需要替换这行代码。
如下:
# torch.cuda.set_device(args.gpu)
# model = model.cuda(args.gpu)
model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
具体原因:
windows下不支持函数 torch.cuda.set_device(args.gpu)
,在linux下支持。
因此需要替换这行代码。
如下:
# torch.cuda.set_device(args.gpu)
# model = model.cuda(args.gpu)
model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()