PyTorch: freezing some of a model's weights while training the rest

The snippet below first restores weights from a checkpoint, then freezes everything outside the '_isface' layers: their parameters stop receiving gradients, and the backbone BatchNorm layers are put into eval mode so their running statistics stop updating.

print("# --------------- restore ---------------#")
if (is_restore):
    print("restore from ckpt=", restore_ckpt)
    checkpoint = torch.load(restore_ckpt)  # ,encoding="latin1")
    print("=> resuming from checkpoint ")
    print(restore_ckpt)
    print("==> checkpoint [epoch={}, iteration={}]".format(checkpoint["epoch"], checkpoint["iteration"]))
    args.restore_epoch_iter = "Epoch=[%d]_Iter=[%d]" % (checkpoint["epoch"], checkpoint["iteration"])

    restore_dir = os.path.dirname(restore_ckpt)
    restore_main_dir = os.path.dirname(restore_dir)

    # ckpt = checkpoint["state_dict"]
    # ---- Note: the approach below misses the BN layers' running_mean/running_var, ----#
    # ---- because those buffers are not part of named_parameters() ----#
    # if is_face and "isface" not in restore_main_dir:
    #     for name, param in model.named_parameters():
    #         if name in ckpt.keys():
    #             tmp = param.data.clone()
    #             param.data = ckpt[name]
    #             diff = torch.sum(torch.abs(tmp - param.data))
    #             print("%s " % name, diff)
    # else:
    #     model.load_state_dict(checkpoint["state_dict"])
    if is_face:
        state_dict = checkpoint["state_dict"]
        own_state = model.state_dict()
        print("------------------------------------ restore list --------------------------------")
        for name, param in state_dict.items():
            if "_isface" in name:
                continue  # '_isface' layers are trained from scratch, do not restore them
            if name not in own_state:
                continue  # skip checkpoint entries the current model does not have
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)
            print("restore :", name)
    else:
        model.load_state_dict(checkpoint["state_dict"])
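
# --- Aside (a hedged illustration, not part of the original script): why the
# --- state_dict() route above is needed. BN buffers such as running_mean live
# --- in state_dict() but not in named_parameters(), so the commented-out
# --- named_parameters() loop would silently skip them. bn_demo is a throwaway
# --- name used only for this check.
bn_demo = nn.BatchNorm2d(8)
param_names = {n for n, _ in bn_demo.named_parameters()}      # {'weight', 'bias'}
buffer_names = set(bn_demo.state_dict().keys()) - param_names
print(buffer_names)  # {'running_mean', 'running_var', 'num_batches_tracked'}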

if is_face:
    print("#---------------------setting isface------------------------#")
    # for idx, m in enumerate(model.modules()):
    for name, m in model.named_modules():
        # freeze BN: https://blog.51cto.com/u_16213585/8592957
        # model.modules() recursively walks all submodules and their submodules:
        # https://blog.csdn.net/weixin_48076759/article/details/131048684
        print(name, m)  # verbose: prints every module in the model
        # put every BatchNorm outside the '_isface' head into eval mode so its
        # running_mean/running_var stop updating
        if isinstance(m, nn.BatchNorm2d) and '_isface' not in name:
            # print("Bn eval")
            m.eval()

    # stop gradients for everything except the '_isface' layers
    for name, param in model.named_parameters():
        if "_isface" not in name:
            param.requires_grad = False
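
# Caveat (a sketch, assuming the usual training loop; freeze_backbone_bn is a
# helper name introduced here for illustration): model.train() later in the
# loop flips every module, including the BatchNorm layers frozen above, back
# to training mode. Re-apply eval() to the frozen BN layers after each
# model.train() call.
def freeze_backbone_bn(model):
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d) and '_isface' not in name:
            m.eval()  # keep running_mean/running_var fixed

# Typical use inside the epoch loop:
#   model.train()              # resets ALL modules to train mode
#   freeze_backbone_bn(model)  # re-freeze the backbone BN statistics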
            
print("------------------------------------ train list -----------------------------------")
for name,param in model.named_parameters():
    print(name.ljust(50),str(param.shape).ljust(30), param.requires_grad)
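
Finally, since requires_grad=False already stops gradients for the frozen weights, the optimizer only needs the remaining trainable parameters. A minimal sketch; the optimizer type and hyper-parameters below are placeholders, not values from the original script:

# Only the '_isface' parameters still require gradients at this point.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)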