import torch
def try_gpu(i=0):
    """Return the ``i``-th CUDA device if it exists, otherwise the CPU device.

    Args:
        i: Zero-based index of the requested GPU.

    Returns:
        ``torch.device(f'cuda:{i}')`` when ``i`` is non-negative and
        ``torch.cuda.device_count()`` reports at least ``i + 1`` devices;
        otherwise ``torch.device('cpu')``.
    """
    # The lower bound matters: without it a negative index always passes the
    # count check and yields an invalid device string such as 'cuda:-1'.
    if 0 <= i < torch.cuda.device_count():
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')
def try_all_gpus():
    """Return a list with every available CUDA device, or ``[cpu]`` if none.

    Returns:
        A list of ``torch.device`` objects, one per device counted by
        ``torch.cuda.device_count()``. When that count is zero the list
        contains the single CPU device instead, so it is never empty.
    """
    gpu_total = torch.cuda.device_count()
    if gpu_total < 1:
        # No CUDA device reported: fall back to the CPU.
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{idx}') for idx in range(gpu_total)]
# Quick demonstration: the default device, the device chosen for GPU index 1,
# and the full device list, printed on one line.
a, b, c = try_gpu(), try_gpu(1), try_all_gpus()
print(a, b, c)
10-09
2209
03-03
4129
11-07
6590