记一下多GPU训练LSTM
使用多GPU需要在to.(device)前面加
model = nn.Dataparallel(model)
有错误
lstm只有一个输入x,原代码中多了一个self.hidden_vect_1,去掉后如下
output, hidden = self.lstm(x)
能跑但是有warning
需要在lstm前面加一行
self.lstm.flatten_parameters()
这样就没问题了
总结
在to(device)前面加nn.Dataparallel
model = nn.Datap
原创
2020-09-16 10:24:35 ·
1132 阅读 ·
0 评论