# Expose only physical GPUs 1-6 to this process. This must run before CUDA is
# initialized; inside the process the visible GPUs are renumbered cuda:0..cuda:5.
os.environ['CUDA_VISIBLE_DEVICES'] = "1,2,3,4,5,6"
use_cuda = torch.cuda.is_available()
# Reuse the flag computed above instead of querying CUDA a second time.
# "cuda:0" is the first *visible* device, i.e. physical GPU 1.
conf.device = torch.device("cuda:0" if use_cuda else "cpu")
self.model = resnet50()
if use_cuda:
    # DataParallel without explicit device_ids uses every visible GPU
    # (physical 1-6 here); the wrapped module's parameters must live on the
    # first of them, which is conf.device (cuda:0).
    self.model = torch.nn.DataParallel(self.model).to(conf.device)
else:
    # CUDA is unavailable: fall back to the CPU (this is NOT a single-GPU path).
    self.model = self.model.to(conf.device)
# Move the head in both branches so it sits on the same device as the
# backbone's outputs; if it were moved only in the CPU branch, the GPU path
# would feed CUDA tensors into a CPU-resident head.
self.head = self.head.to(conf.device)
然后把输入数据和标签转移到模型所在的设备上（GPU 可用时即变为 CUDA 张量）：
# Move the batch to the device selected in conf.device rather than calling
# .cuda() unconditionally, so the CPU fallback path also works. Labels are
# moved too so they share a device with the model outputs (presumably for the
# loss computation, which is not shown here).
imgs = imgs.to(conf.device)
labels = labels.to(conf.device)