近期在公司实习遇到一个问题,训练时,采用的是分布式的GPU训练的模型,上线需要cpu版本的,因此测试时,模型载入出错,需要转成CPU版。转换方法如下:
model = torch.load(model_path)
d = collections.OrderedDict()
for key, value in model.state_dict().items():
tmp = key[7:]
d[tmp] = value
model.load_state_dict(d)
分布式GPU训练的模型的key值会多了module.,将它去掉,重新载入模型即可。