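At the very top of the script, before importing torch (or any other library that initializes CUDA), add:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'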
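Why the placement matters: CUDA reads the device mask when it is first initialized, so setting it after the first CUDA call has no effect. The quick check below is a minimal sketch that sets the mask and then counts the GPUs torch can see. With '0,1', at most two GPUs are visible, and they are renumbered as cuda:0 and cuda:1. On a machine without GPUs, the count is 0.

```python
import os

# Set the mask before torch initializes CUDA (safest: before `import torch`)
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch

# Only physical GPUs 0 and 1 are exposed; inside this process they appear as cuda:0 and cuda:1
n_visible = torch.cuda.device_count()
print(f"visible GPUs: {n_visible}")
```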
Change the device setting from:
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
to:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Change the model creation from:
model = create_model(num_classes=args.num_classes).to(device)
to:
model = create_model(num_classes=args.num_classes)
and then wrap it in DataParallel (device_ids=None means every visible GPU is used; nn is torch.nn):
model = nn.DataParallel(model, device_ids=None).to(device)
Original code:
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
model = create_model(num_classes=args.num_classes).to(device)
Modified code:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model(num_classes=args.num_classes)
model = nn.DataParallel(model, device_ids=None).to(device)
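For a self-contained check of the whole change, here is a minimal sketch of one training step under DataParallel. It assumes a hypothetical create_model (a tiny linear classifier standing in for the project's real one) and a hard-coded num_classes in place of args.num_classes. Inputs are moved to the same device as the model. On multiple GPUs, DataParallel splits each batch along dimension 0 and gathers the outputs back on cuda:0. On a CPU-only machine, it simply calls the wrapped module.

```python
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # must precede CUDA initialization

import torch
import torch.nn as nn


def create_model(num_classes: int) -> nn.Module:
    # Hypothetical stand-in for the project's own create_model()
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))


num_classes = 10  # hypothetical stand-in for args.num_classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = create_model(num_classes=num_classes)
# device_ids=None -> use all visible GPUs; parameters end up on the first one (cuda:0)
model = nn.DataParallel(model, device_ids=None).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Dummy batch on the same device as the model; DataParallel scatters it across GPUs
images = torch.randn(8, 3, 32, 32, device=device)
labels = torch.randint(0, num_classes, (8,), device=device)

optimizer.zero_grad()
logits = model(images)  # outputs are gathered back onto cuda:0 (or stay on CPU)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()

print(logits.shape)
```

The rest of the training loop does not need to change. The underlying network is still reachable as model.module if it has to be accessed directly.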