pytorch多卡单卡模型加载
\quad
模型的保存和加载参照 pytorch模型保存及加载详解
\quad
多卡保存的时候,在model的state_dict()参数多了一个"module."的前缀,其他的参数保存的时候单卡多卡保存并没有区别。因此在模型相互加载之前把这个处理好这个前缀就可以了。
一、 ⇒ l o a d t o \xRightarrow{load \ to\ } load to 单卡
def strip_prefix(self, state_dict, prefix='module.'):
if not all(key.startswith(prefix) for key in state_dict.keys()):
return state_dict
stripped_state_dict = {}
for key in list(state_dict.keys()):
stripped_state_dict[key.replace(prefix, '')] = state_dict.pop(key)
return stripped_state_dict
二、 ⇒ l o a d t o \xRightarrow{load \ to\ } load to 多卡
def add_prefix(self, state_dict, prefix='module.'):
if all(key.startswith(prefix) for key in state_dict.keys()):
return state_dict
stripped_state_dict = {}
for key in list(state_dict.keys()):
key2 = prefix + key
stripped_state_dict[key2] = state_dict.pop(key)
return stripped_state_dict
三、使用
checkpoint = torch.load(pretrain)
if multi_gpu is not None:
model.load_state_dict(self.add_prefix(checkpoint['state_dict']))
else:
model.load_state_dict(self.strip_prefix(checkpoint['state_dict']))
optimizer.load_state_dict(checkpoint['optimizer'])