1、模型参数加载:如和确定 “model.load_state_dict(torch.load("weight_path"),strict=False)” 加载到了参数?(strict 表示是否需要参数结构和model完全一样)
load_state_dict 会返回两个参数 一个是miss,一个是unexcepted,第一个参数表示:model中没有得到参数加载的部分(例如 model有100个模块,参数中只有20个模块的参数,那么miss就是剩下80个没有得到参数加载的模块),第二个参数表示哪些参数是model不能加载的(参数中有个模块叫做 A,这个模块model中没有,那么model就不能加载这个模块的参数,这个A就是unexpected)
所以如果想要确定 model.load_state_dict 是否真的给模型赋予了有效的参数 可以检查 miss中的模块和 model中的模块是不是一样,一样就说明没有加载参数,不一样就说明加载到了参数。具体来说就是打印model中的所有模块 print( model.state_dict() )。 打印miss ,print( miss ),然后对比输出是否一样,如果一样说明 model 未加载到任何参数,如果不一样,说明加载到了部分参数。
2、如何保留模型中部分模块的参数:
mm_state_dict = OrderedDict() state_dict = unet.state_dict() for key in state_dict: if "motion_module" in key: mm_state_dict[key] = state_dict[key] torch.save(mm_state_dict, “xxx.pth”)
3、参数预训练模型和目标模型都有但是参数对不上(例如形状不同)
state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))
src_state_dict = state_dict['net']
target_state_dict = model.state_dict()
skip_keys = []
for k in src_state_dict.keys():
if k not in target_state_dict:
continue
if src_state_dict[k].size() != target_state_dict[k].size():
skip_keys.append(k)
for k in skip_keys:
del src_state_dict[k]
missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)