os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3" #可用的卡号
device_ids = [1,2,3] #显卡id
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") #从显卡1开始载入模型
textnet = TextNet(code_length=512)
textnet =nn.DataParallel(textnet,device_ids=device_ids) #三块卡一起跑
textnet.to(device)
解决不在一张卡的问题
device = fusion_vector.device #查看当前在哪张卡
seed=seed.to(device)