import pynvml
# Select the GPU that currently has the most free memory, as reported by NVML.
#
# Module-level results:
#   deviceCount      -- number of GPUs NVML can see
#   largest_free_idx -- NVML index of the GPU with the most free memory
#   largest_free_mem -- free memory on that GPU, in MB (bytes / 1024 / 1024)
#   idx_to_gpu_id    -- mapping from NVML index to its string GPU id
#   gpu_id           -- string id of the selected GPU
pynvml.nvmlInit()  # NVML must be initialised before any query.
try:
    deviceCount = pynvml.nvmlDeviceGetCount()  # Number of GPUs in this machine.
    largest_free_mem = 0
    largest_free_idx = 0
    for i in range(deviceCount):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        # Strict '>' means the lowest index wins when several GPUs tie.
        if info.free > largest_free_mem:
            largest_free_mem = info.free
            largest_free_idx = i
finally:
    # Always release NVML, even if one of the queries above raised.
    pynvml.nvmlShutdown()

if deviceCount == 0:
    # Fail with a clear message instead of an opaque KeyError on the lookup below.
    raise RuntimeError('No GPU devices were reported by NVML')

largest_free_mem = largest_free_mem / 1024. / 1024.  # Convert to MB

# Map every NVML device index to its string GPU id.
idx_to_gpu_id = {i: '{}'.format(i) for i in range(deviceCount)}
gpu_id = idx_to_gpu_id[largest_free_idx]
print(gpu_id, largest_free_mem)