模型训练--

1、模型参数加载:如和确定 “model.load_state_dict(torch.load("weight_path"),strict=False)” 加载到了参数?(strict 表示是否需要参数结构和model完全一样)

load_state_dict 会返回两个参数 一个是miss,一个是unexcepted,第一个参数表示:model中没有得到参数加载的部分(例如 model有100个模块,参数中只有20个模块的参数,那么miss就是剩下80个没有得到参数加载的模块),第二个参数表示哪些参数是model不能加载的(参数中有个模块叫做 A,这个模块model中没有,那么model就不能加载这个模块的参数,这个A就是unexpected)

所以如果想要确定 model.load_state_dict 是否真的给模型赋予了有效的参数 可以检查 miss中的模块和 model中的模块是不是一样,一样就说明没有加载参数,不一样就说明加载到了参数。具体来说就是打印model中的所有模块 print( model.state_dict() )。 打印miss ,print( miss ),然后对比输出是否一样,如果一样说明 model 未加载到任何参数,如果不一样,说明加载到了部分参数。

2、如何保留模型中部分模块的参数:

mm_state_dict = OrderedDict()
state_dict = unet.state_dict()

for key in state_dict:
    if "motion_module" in key:
        mm_state_dict[key] = state_dict[key]

torch.save(mm_state_dict, “xxx.pth”)

3、参数预训练模型和目标模型都有但是参数对不上(例如形状不同)

state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage.cuda(device))

src_state_dict = state_dict['net']

target_state_dict = model.state_dict()

skip_keys = []

for k in src_state_dict.keys():

       if k not in target_state_dict:

              continue

        if src_state_dict[k].size() != target_state_dict[k].size():

               skip_keys.append(k)

 for k in skip_keys:

        del src_state_dict[k]

missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)

3部分参考:神经网络load_state_dict()进阶使用_木木海2000的博客-CSDN博客

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值