【Swin-Unet】官方代码预训练权重加载函数load_from()详解

最近在用Swin-Unet改实验,正好看到官方的issue里也有人提问这个问题,就顺便学习了一下。如果解释的有问题欢迎大家指正!!!

总结:实际上由于SwinUnet是一个encoder-decoder对称的结构,因此加载权重时,作者并没有像通常那样仅仅加载encoder部分而不加载decoder部分,而是同时将encoder的权重对称地加载到了decoder上(除了swin_unet.layers_up.1/2/3.upsample)
Swin-Unet

详细注释看下面:

    def load_from(self, args):
        pretrained_path = args.pretrain_ckpt   #config.MODEL.PRETRAIN_CKPT
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            # print("pretrained_dict.keys()==",pretrained_dict.keys())  #dict_keys(['model'])
            #用dict.keys()输出字典元素所有的键

            if "model" not in pretrained_dict:  #正常情况下都有,不执行这里
                print("---start load pretrained modle by splitting---")
                pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        # print("delete key:{}".format(k))
                        del pretrained_dict[k]
                msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False)
                # print(msg)
                return
            
            pretrained_dict = pretrained_dict['model']
            print("---start load pretrained modle of SwinTransformer encoder---")
            # print("pretrained_dict['model']==",pretrained_dict.keys()) 
            # 此时pretrained_dict包含swin_transformer的全部权重keys-------------------------------
            # 对于tiny其中的[2,2,6,2]encoder,有layers.0.blocks.0.到layers.2.blocks.5.到layers.3.blocks.2.
            model_dict = self.swin_unet.state_dict()
            # 此时model_dict包含swin_unet的全部权重-------------------------------
            # 接下来应该是一一对应赋值
            # print("model_dict==", model_dict.keys())

            full_dict = copy.deepcopy(pretrained_dict)
            #copy.deepcopy() 深拷贝=寻常意义的复制。
            # 将被复制对象完全再复制一遍作为独立的新个体单独存在。
            # 改变原有被复制对象不会对已经复制出来的新对象产生影响。
            # 此时full_dict包含swin_transformer的全部权重keys-------------------------------
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    # print("k==",k,k[7:8])
                    # k == layers.0.blocks.0.norm1.weight,包含预训练权重的全部
                    # k[7:8] == layers之后的number
                    current_layer_num = 3-int(k[7:8])
                    # print("current_layer_num==",current_layer_num)
                    # current_layer_num现有layer=原有layer总数4-现有layer
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    # 将(tiny)encoder[2,2,6,2]对称地映射出一个decoder权重
                    # print("current_k==",current_k)
                    full_dict.update({current_k:v})
                    # 将映射出的decoder,和原本的encoder的所有权重都存在full_dict里
            # print("full_dict(all)==", full_dict.keys())       

            for k in list(full_dict.keys()):
            # 遍历full_dict,如果和model_dict(swin_unet)的key相同,就判断尺寸是否匹配,如果不匹配就删除
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        # print("delete:'{}'; pretrain_shape:'{}'; swinunet_shape:'{}'".format(k,v.shape,model_dict[k].shape))
                        del full_dict[k]
            # print("full_dict(after del)==", full_dict.keys())       
            msg = self.swin_unet.load_state_dict(full_dict, strict=False) 
            # 正式执行加载权重。strict=False 表示忽略不匹配的网络层参数
            # 除了.layers_up.1/2/3.upsample没加载预训练权重,都加载了
            # print(msg)
        else:
            print("none pretrain_ckpt to load")
  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值