问题:最近在实现bert pytorch版本的过程中遇到一个不能使用多GPU的问题,然而github原始版本是可以使用的,修改过程中的一些改动导致使用多GPU时会报如下的错误: arguments are located on different GPUs
定位完问题的位置后一开始以为是cuda设置的问题,后来发现问题出在tranformer模块这里。旧代码在transformer的12个layer建立时采用了简单的list来存储然后用add_module的方法建立模型。但是这样的写法在多GPU的情况下好像是有问题的。以下是修改前后的代码对比:
修改前通过list和add_module方法建立
修改为nn.ModuleList方法建立
出错原因详解:
ModuleList和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters。而在我们出错的代码中可以看见我们的子module是用普通的list存储的,这种写法的子module不能被主module所识别,所以其参数未加入到主module的参数中去,自然会报第一张图中的arguments are located on different GPUs。