torch多GPU加速

在具体使用pytorch框架进行训练的时候,发现实验室的服务器是多GPU服务器,因此需要在训练过程中,将网络参数都放入多GPU中进行训练。

   正文开始:

   涉及的代码为torch.nn.DataParallel,而且官方推荐使用nn.DataParallel而不是使用multiprocessing。官方代码文档如下:nn.DataParallel   教程文档如下:tutorial

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
该函数实现了在module级别上的数据并行使用,注意batch size要大于GPU的数量。

参数 : module:需要多GPU训练的网络模型

device_ids: GPU的编号(默认全部GPU)

output_device:(默认是device_ids[0])

dim:tensors被分散的维度,默认是0

在代码文档中使用方法为:

net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
out = net(input)
具体实际操作中,加入其他代码会更好一点。

重点:
自己的代码及解释如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
定义device,其中需要注意的是“cuda:0”代表起始的device_id为0,如果直接是“cuda”,同样默认是从0开始。可以根据实际需要修改起始位置,如“cuda:1”。

model = Model()
if torch.cuda.device_count() > 1:
  model = nn.DataParallel(model,device_ids=[0,1,2])
 
model.to(device)
这里注意,如果是单GPU,直接model.to(device)就可以在单个GPU上训练,但如果是多个GPU就需要用到nn.DataParallel函数,然后在进行一次to(device)。

需要注意:device_ids的起始编号要与之前定义的device中的“cuda:0”相一致,不然会报错。

如果不定义device_ids,如model = nn.DataParallel(model),默认使用全部GPU。定义了device_ids就可以使用指定的GPU,但一定要注意与一开始定义device对应。

   通过以上代码,就可以实现网络的多GPU训练。

以下是训练过程中遇到其他的bug:

1、在实际训练过程中,如果要用到网络中的子模块,需要注意:

在单GPU中,可以使用以下代码

model = Net()
out = model.fc(input)
但是在DataParallel中,需要修改为如下:

model = Net()
model = nn.DataParallel(model)
out = model.module.fc(input)
其实,将并行后的网络打印出来就会发现需要加上“module”,千万注意是module,而不是model。这样就可以调用并行网络中定义的网络层。

2、在服务器上运行matplotlib库时,因为没有图形界面而报错,具体报错情况没有截图,但是如果使用远程服务器训练,而且没有图形界面。在使用matplotlib绘图时,建议添加以下代码:

import matplotlib.pyplot as plt
plt.switch_backend('agg')
而且训练时不要使用plt.show(),只是保存图片就行,之后在下载就可以。

其他遇到的bug,会不定时更新~
 ———————————————— 
版权声明:本文为CSDN博主「daydayjump」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/daydayjump/article/details/81158777

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值