指定 GPU
torch.cuda.set_device(3)
该语句指定为 GPU3(从0开始)后,之后的一系列操作都是在 GPU3 上执行,如
# "cuda" 是指定 GPU3
torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 是限制的 GPU3 的显存
desired_memory_fraction = 0.5 # 50% 显存
torch.cuda.set_per_process_memory_fraction(desired_memory_fraction)
也可以用
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tensor = torch.empty(..., device=device)
单独指定其他设备。
限制使用显存
# 指定之后所有操作在 GPU3 上执行
torch.cuda.set_device(3)
# 限制 GPU3 显存使用50%
desired_memory_fraction = 0.5 # 50% 显存
torch.cuda.set_per_process_memory_fraction(desired_memory_fraction)
# 获取当前GPU上的总显存容量
total_memory = torch.cuda.get_device_properties(3).total_memory
# 指定使用 GPU3
tmp_tensor = torch.empty(int(total_memory * 0.4999), dtype=torch.int8, device="cuda") # 此处 cuda 即指 GPU3
# 获取当前已分配的显存,计算可用显存
allocated_memory = torch.cuda.memory_allocated()
available_memory = total_memory - allocated_memory
# 打印结果
print(f"Total GPU Memory: {total_memory / (1024**3):.2f} GB")
print(f"Allocated GPU Memory: {allocated_memory / (1024**3):.2f} GB")
print(f"Available GPU Memory: {available_memory / (1024**3):.2f} GB")
此时占用了50%的显存,而将0.4999改为0.5会爆显存,可能是受浮点数精度影响。