Using a pretrained model saved as an .npz file in PyTorch

Introduction

Lately I have been reading transformer-related papers. In AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, the authors release models pretrained on a very large dataset, but the weights are distributed as .npz files. This post introduces the npz format and shows how to load those parameters into a PyTorch model for training.

1. What is an npz file?

An .npz file is NumPy's archive format: a single file that stores one or more named arrays.
It is used as follows:

from numpy import load

# Path to the downloaded pretrained ViT-B/16 weights
data = load('D:/CsdnProject/TransUNet-main/ViT-B_16.npz')
lst = data.files          # names of all arrays stored in the archive
for item in lst:
    print(item)           # the parameter name
    print(data[item])     # the parameter values

The archive is opened with numpy.load and its contents are printed (the path above is where the pretrained weights were downloaded to).
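For context, an .npz archive is simply what numpy.savez (or numpy.savez_compressed) produces: several named arrays bundled into one zip file. A minimal sketch, with made-up file and array names:

import numpy as np

# Save two named arrays into a single archive (illustrative names)
np.savez('demo_weights.npz',
         layer0_kernel=np.ones((3, 3)),
         layer0_bias=np.zeros(3))

# Load them back; the keys match the keyword names used when saving
data = np.load('demo_weights.npz')
print(data.files)                    # ['layer0_kernel', 'layer0_bias']
print(data['layer0_kernel'].shape)   # (3, 3)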

2. How do we load the pretrained model into PyTorch?

2.1 Contents of the pretrained npz file

First, each entry in the npz archive carries two things: the name of the corresponding module in the model, and the parameters saved for that module.
Using the same code as in section 1:

from numpy import load

data = load('D:/CsdnProject/TransUNet-main/ViT-B_16.npz')
lst = data.files
for item in lst:
    print(item)        # the parameter name
    print(data[item])  # the parameter values

The printed output (parameter values omitted) is:

C:\Users\sg\Anaconda3\envs\pytorch\python.exe D:/CsdnProject/TransUNet-main/TestNpz.py
Transformer/encoder_norm/bias
Transformer/encoder_norm/scale
Transformer/encoderblock_0/LayerNorm_0/bias
Transformer/encoderblock_0/LayerNorm_0/scale
Transformer/encoderblock_0/LayerNorm_2/bias
Transformer/encoderblock_0/LayerNorm_2/scale
Transformer/encoderblock_0/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_0/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_0/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_0/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_0/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_1/LayerNorm_0/bias
Transformer/encoderblock_1/LayerNorm_0/scale
Transformer/encoderblock_1/LayerNorm_2/bias
Transformer/encoderblock_1/LayerNorm_2/scale
Transformer/encoderblock_1/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_1/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_1/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_1/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_1/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_10/LayerNorm_0/bias
Transformer/encoderblock_10/LayerNorm_0/scale
Transformer/encoderblock_10/LayerNorm_2/bias
Transformer/encoderblock_10/LayerNorm_2/scale
Transformer/encoderblock_10/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_10/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_10/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_10/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_10/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_11/LayerNorm_0/bias
Transformer/encoderblock_11/LayerNorm_0/scale
Transformer/encoderblock_11/LayerNorm_2/bias
Transformer/encoderblock_11/LayerNorm_2/scale
Transformer/encoderblock_11/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_11/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_11/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_11/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_11/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_2/LayerNorm_0/bias
Transformer/encoderblock_2/LayerNorm_0/scale
Transformer/encoderblock_2/LayerNorm_2/bias
Transformer/encoderblock_2/LayerNorm_2/scale
Transformer/encoderblock_2/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_2/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_2/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_2/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_2/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_3/LayerNorm_0/bias
Transformer/encoderblock_3/LayerNorm_0/scale
Transformer/encoderblock_3/LayerNorm_2/bias
Transformer/encoderblock_3/LayerNorm_2/scale
Transformer/encoderblock_3/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_3/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_3/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_3/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_3/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_4/LayerNorm_0/bias
Transformer/encoderblock_4/LayerNorm_0/scale
Transformer/encoderblock_4/LayerNorm_2/bias
Transformer/encoderblock_4/LayerNorm_2/scale
Transformer/encoderblock_4/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_4/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_4/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_4/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_4/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_5/LayerNorm_0/bias
Transformer/encoderblock_5/LayerNorm_0/scale
Transformer/encoderblock_5/LayerNorm_2/bias
Transformer/encoderblock_5/LayerNorm_2/scale
Transformer/encoderblock_5/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_5/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_5/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_5/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_5/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_6/LayerNorm_0/bias
Transformer/encoderblock_6/LayerNorm_0/scale
Transformer/encoderblock_6/LayerNorm_2/bias
Transformer/encoderblock_6/LayerNorm_2/scale
Transformer/encoderblock_6/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_6/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_6/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_6/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_6/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_7/LayerNorm_0/bias
Transformer/encoderblock_7/LayerNorm_0/scale
Transformer/encoderblock_7/LayerNorm_2/bias
Transformer/encoderblock_7/LayerNorm_2/scale
Transformer/encoderblock_7/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_7/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_7/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_7/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_7/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_8/LayerNorm_0/bias
Transformer/encoderblock_8/LayerNorm_0/scale
Transformer/encoderblock_8/LayerNorm_2/bias
Transformer/encoderblock_8/LayerNorm_2/scale
Transformer/encoderblock_8/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_8/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_8/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_8/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_8/MultiHeadDotProductAttention_1/value/kernel
Transformer/encoderblock_9/LayerNorm_0/bias
Transformer/encoderblock_9/LayerNorm_0/scale
Transformer/encoderblock_9/LayerNorm_2/bias
Transformer/encoderblock_9/LayerNorm_2/scale
Transformer/encoderblock_9/MlpBlock_3/Dense_0/bias
Transformer/encoderblock_9/MlpBlock_3/Dense_0/kernel
Transformer/encoderblock_9/MlpBlock_3/Dense_1/bias
Transformer/encoderblock_9/MlpBlock_3/Dense_1/kernel
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/key/bias
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/key/kernel
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/out/bias
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/out/kernel
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/query/bias
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/query/kernel
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/value/bias
Transformer/encoderblock_9/MultiHeadDotProductAttention_1/value/kernel
Transformer/posembed_input/pos_embedding
cls
embedding/bias
embedding/kernel
head/bias
head/kernel
pre_logits/bias
pre_logits/kernel

Process finished with exit code 0
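If you only need an overview of the archive, it is usually enough to print each key together with its array shape instead of the full values; a small sketch:

from numpy import load

data = load('D:/CsdnProject/TransUNet-main/ViT-B_16.npz')
for name in data.files:
    # Print the parameter name and the shape of its array only
    print(name, data[name].shape)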

2.2 How to use it

So we need to copy these entries one by one into the model we define, which means our model definition must match the pretrained architecture.
For example, I use ViT-B/16, so my encoder has 12 blocks (see the paper for details); after printing my network, the two structures are indeed the same.

In the VisionTransformer class, define a load_from function that copies the pretrained weights into the model:
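The code below relies on a helper np2th that converts a NumPy array into a torch tensor; for convolution kernels it also transposes from the JAX/Flax HWIO layout to PyTorch's OIHW layout. In the TransUNet codebase it looks roughly like this (reproduced from memory, so treat it as a sketch):

import torch

def np2th(weights, conv=False):
    # For conv kernels, convert HWIO (H, W, in, out) to OIHW (out, in, H, W)
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)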

    def load_from(self, weights):
        with torch.no_grad():

            res_weight = weights
            # Copy parameters into the patch-embedding block
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            # Copy parameters into the final encoder norm
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
            # Copy the position embeddings
            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            print("reached checkpoint 1")
            if posemb.size() == posemb_new.size():
                # This branch is True, so the branches below are skipped and execution
                # jumps straight to the loop further down; the problem turned out to be in that loop.
                self.transformer.embeddings.position_embeddings.copy_(posemb)
                print("reached checkpoint 2")
            elif posemb.size()[1]-1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
                print("reached checkpoint 3")
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                print("reached checkpoint 4")
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    print("reached checkpoint 5")
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
                print("reached checkpoint 6")
            print("reached checkpoint 7")

            # Encoder whole
            # Copy weights and biases into every block of the encoder
            for bname, block in self.transformer.encoder.named_children():
                print(bname)
                print(block)
                for uname, unit in block.named_children():
                    print(uname)
                    print(unit)
                    unit.load_from(weights, n_block=uname)

            print("reached checkpoint after encoder")

            # For a hybrid (ResNet + ViT) backbone, also load the ResNet weights
            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)

Inside that for loop, each block calls a load_from function of its own; this function comes from the Block class and is one of its methods:
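This method looks up entries in the npz archive using a set of module-level path constants (pjoin is os.path.join imported under another name). In the TransUNet code these constants look roughly like the following; the exact strings are reproduced from memory, but they mirror the key names printed in section 2.1:

from os.path import join as pjoin

# Relative paths of the parameter groups inside one encoder block
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"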

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}/"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
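
With both load_from methods in place, loading the pretrained weights comes down to reading the npz file and handing it to the model. A minimal usage sketch (the constructor arguments are illustrative, use your own config):

import numpy as np

# Build the model first (arguments are illustrative)
model = VisionTransformer(config, img_size=224, num_classes=num_classes)

# np.load returns a dict-like NpzFile whose keys are the names listed in 2.1
weights = np.load('D:/CsdnProject/TransUNet-main/ViT-B_16.npz')
model.load_from(weights)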

3. Pitfalls

While running this I got a KeyError, meaning the requested key could not be found in the archive:

KeyError: 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query\\kernel is not a file in the archive'

After a good deal of searching and debugging, I found that the cause was an incorrectly built path inside the Block class's load_from function.

query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()

This line joins the three path pieces ROOT, ATTENTION_Q and "kernel" into one key. pjoin is os.path.join, which on Windows uses a backslash as separator, so the constructed key ends in query\kernel instead of query/kernel; it matches nothing in the archive, hence the error.
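One way to fix it is to join the pieces with forward slashes regardless of the operating system, for example by importing posixpath.join instead of os.path.join; a small sketch of the idea:

# os.path.join inserts '\' on Windows; posixpath.join always uses '/'
from posixpath import join as pjoin

ROOT = "Transformer/encoderblock_0/"
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
print(pjoin(ROOT, ATTENTION_Q, "kernel"))
# -> Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel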
