pytorch使用DataParallel加速(包括RNN疑难杂症处理)

DataParallel的基本使用方法很简单,只需设置device_ids即可,如下所示:

device_ids = [0, 1, 2, 3]
model = torch.nn.DataParallel(model, device_ids=device_ids)

device_ids为你要使用的GPU号。如果你未使用DataParallel之前用的便是单GPU进行训练,那么对于数据不需要额外的操作,否则,你需要将模型的输入数据转移到cuda上,如:

# 此处device与device_ids无关,你可以设置device = torch.device("cuda:0")
input = input.to(device)

如果顺利的话,简单的两步就可以实现加速了。

然而,由墨菲定律可得:凡是可能出错的事就一定会出错。常见问题如下。

问题1:如果model里定义了一个函数,如初始化函数init_hidden等,并已实现DataParallel,在train函数里该如何调用?

class Model(nn.Module):
    def __init__(self, ):
        pass

    def forward(self, ):
        pass

    def init_hidden(self, ):
        pass


model = Model()
model = nn.DataParallel(model, device_ids=devic
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值