【PyTorch】argparse + os.environ 设置pytorch网络使用的显卡

指定使用的显卡编号

os.environ("CUDA_VISIBLE_DEVICES")='2,3,4'

设置环境变量CUDA_VISIBLE_DEVICES为’2,3,4’,这个时候对于系统来说只有编号2,3,4的显卡是可见的(从0开始)
通过torch.cuda.device_count()获取显卡数量的时候显示的是3,即只能看见这三张显卡

在使用pytorch时,如果需要在gpu上对某些数据进行操作,一般的流程是:

# 获取设备
device = torch.device('cuda:0') # 使用可见gpu中的第0个
# device = torch.device('cpu') 如果使用cpu设备

# 将数据送到设备中
data.to(device)

在上面获取设备的时候,第0个设备获取的其实是编号为2的显卡,同理 torch.device('cuda:2')获取的就是编号为4的显卡。

指定显卡对数据进行操作的方法大概就是这样,也有一些其他的方法,比如

  • 在命令行前加:CUDA_VISIBLE_DEVICES=2 python ...

这里不做赘述

指定显卡不生效

通过 os.environ 指定显卡的时候,经常遇见指定了3号卡训练,但最后还是在0号卡训练的,这个问题最可能的原因是指定显卡和使用torch的顺序问题。

可能的一种错误方式为:

import torch
print(torch,cuda.device_count())

import os
os.environ["CUDA_VISIBLE_DEVICES"]='2'

device = torch.device('cuda:0')
data = torch.tensor([1,2,3])
data.to(device)

在第一次使用torch的时候,系统就会去获取当前可用的设备了,这个时候还没有指定可见的显卡,代码就会加载所有可用的显卡,程序会输出设备数为8,且将data放在0号卡

如果在第一次使用torch前指定显卡,就可以达到想要的效果

import torch

import os
os.environ["CUDA_VISIBLE_DEVICES"]='2'

print(torch,cuda.device_count())

device = torch.device('cuda:0')
data = torch.tensor([1,2,3])
data.to(device)

这个时候程序输出设备数为1,即2号显卡,且data也会放在2号显卡上

所以如果你通过environ指定显卡无法生效,可以尝试看一下自己的代码,是不是在指定显卡之前已经使用了torch的某些函数或功能。

使用argparse在命令行指定显卡

正如上面提到的,可以直接在python命令之前加 CUDA_VISIBLE_DEVICES=1,2 来指定显卡
当然可以使用获取参数时最经常使用的argparse来获取指定的设备号

实现也很简单

import argparse
import os
import torch

parser = argparse.ArgumentParser(description='...')
parser.add_argument('-d', '--device', type=str)
args = parser.parse_args()
if args.device:
	os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# code about torch: 将torch相关的代码写在指定显卡的代码后面

只需要保证torch的使用在指定显卡之后即可,如果你为了保险,将 import torch 也放在指定显卡的代码后面也是可以的。

在使用的时候直接在命令行指定设备即可:

python train.py --device 2
  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
当代码中出现报错"os.environ[local_rank]"时,可能有以下几个原因: 1. 错误的变量名: local_rank在代码中未定义或拼写错误。请确保代码中已经正确定义了local_rank变量。 2. 变量未设置: local_rank没有在os.environ设置os.environ是一个字典,用于存储环境变量。如果在代码中使用os.environ[local_rank],但是local_rank没有在os.environ设置,会导致报错。请确保在使用os.environ[local_rank]之前,已经正确地设置了local_rank环境变量。 3. 环境变量不存在: local_rank是一个不存在的环境变量。如果在代码中使用os.environ[local_rank],但是local_rank并没有在当前的环境变量中设置,会导致报错。请确保local_rank环境变量已经存在且被正确设置。 4. 引用方式错误:在代码中,local_rank应该被作为字符串来引用,即使用'local_rank'而不是local_rank。请确保在使用os.environ['local_rank']时,使用了正确的引用方式。 总结来说,当出现"os.environ[local_rank]"报错时,需要检查代码中是否正确定义了local_rank变量,是否在os.environ中正确设置了local_rank环境变量,以及是否以正确的方式引用了local_rank。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [Pytorch 分布式训练(DP/DDP)](https://blog.csdn.net/ytusdc/article/details/122091284)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [Python基于os.environ从windows获取环境变量](https://download.csdn.net/download/weixin_38698149/12851183)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

KyrieLiu52

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值