a. Using `.cuda()` (requires a CUDA-capable GPU; no CPU fallback):
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]='0,1,2,3'
net = net.cuda()
input = input.cuda()
output = output.cuda()
b. Using `.to(device)` (uses the GPU when available, otherwise falls back to the CPU):
# Pick the first CUDA GPU if one is available, otherwise the CPU.
# The device type must be spelled "cuda". A typo here only surfaces on
# machines that actually have a GPU, because the CPU branch is taken otherwise.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

# Move the model and both tensors to the selected device. Results are
# reassigned because Tensor.to returns a new tensor rather than moving in
# place. (nn.Module.to moves in place and returns self, so reassigning is
# harmless there; net is presumably an nn.Module, so verify against the caller.)
net = net.to(device)
input = input.to(device)
output = output.to(device)