可以用在多GPU的服务器上,明确使用的是哪块GPU,可以将信息打印出来,返回的是device的列表。
具体用到的pytorch函数是torch.cuda.device_count(),返回可得到的GPU数量。
接下来是具体函数代码,分为两部分,一个主函数get_proper_device明确使用GPU还是CPU,另外一个函数get_proper_cuda_device是明确具体使用哪块GPU。
# 获取 GPU信息
def get_proper_cuda_device(device, verbose=True):
if not isinstance(device, list):
device = [device]
count = torch.cuda.device_count()
if verbose:
print("[Builder]: Found {} gpu".format(count))
for i in range(len(device)):
d = device[i]
d_id = None
if isinstance(d, str):
# 正则表达式,查看是否存在“cuda:0”这种形式。
if re.search("cuda:[\d]+", d):
d_id = int(d[5:])
elif isinstance(d, int):
d_id = d
if d_id is None:
raise ValueError("[Builder]: Wrong cuda id {}".format(d))
if