A bug encountered when training on multiple GPUs with PyTorch's torch.nn.DataParallel
Problem description:
In addition to its original parameters, my network model defines an extra set of custom parameters that must also be trained. When I wrap this model in torch.nn.DataParallel for multi-GPU training, I hit a bug; a sketch of the setup is shown below, followed by the traceback.
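As a rough illustration only (the names SearchModel, alphas, features, etc. are hypothetical and not taken from my actual code), the setup looks roughly like this: a module registers an extra trainable parameter tensor alongside its ordinary layers, uses it in forward, and is then wrapped in DataParallel.

```python
import torch
import torch.nn as nn

class SearchModel(nn.Module):
    """Toy stand-in for the real model: ordinary layers plus an extra
    set of custom trainable parameters (here, `alphas`)."""

    def __init__(self, num_classes=10, num_ops=4):
        super().__init__()
        # "Original" parameters: ordinary layers.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(16, num_classes)
        # Custom parameters that must also be trained,
        # e.g. architecture weights in a search-style model.
        self.alphas = nn.Parameter(1e-3 * torch.randn(num_ops))

    def forward(self, x):
        x = self.features(x).flatten(1)
        logits = self.classifier(x)
        # The custom parameters are also used in the forward pass.
        weights = torch.softmax(self.alphas, dim=-1)
        return logits, weights

model = nn.DataParallel(SearchModel().cuda())   # multi-GPU wrapper
logits, weights = model(torch.randn(8, 3, 32, 32).cuda())
```

Calling the wrapped model's forward, the equivalent of self.model(input) in my training loop, then produces the following traceback: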
Traceback (most recent call last):
  File "train_search.py", line 741, in <module>
    architecture.main()
  File "train_search.py", line 358, in main
    train_acc,loss, error_loss, resource_loss, trainable_filter_number,model_performance = self.train(self.all_epochs, logging)
  File "train_search.py", line 527, in train
    logits, model_property, _ = self.model(input)
  File "/home/dhb/jupyter notebook/distiller/env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dhb/jupyter notebook/distiller/env/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_