Converting a model whose structure differs from BERT from .ckpt to .bin fails with: AttributeError: 'BertForPreTraining' object has no attribute 'bias'

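If the model has the same structure as BERT, or is another architecture supported by transformers, it can be converted with the conversion script from the official transformers library:
1) Create the script: vim convert.py
2) Run it from the command line:

python convert.py --tf_checkpoint_path /Users/sunrui/Desktop/cbert/bert_model.ckpt --bert_config_file /Users/sunrui/Desktop/cbert/bert_config.json --pytorch_dump_path /Users/sunrui/Desktop/cbert/pytorch_model.bin

Contents of convert.py: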
import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

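For this model, however, the conversion fails with:
AttributeError: 'BertForPreTraining' object has no attribute 'bias'
The cause is that the variable names (keys) stored in the TensorFlow ckpt do not match the names transformers expects for the PyTorch model.
To fix it, print every key in the TF ckpt and compare them with the keys of BertModel.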
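Why does the message mention bias at all? The sketch below is a simplified imitation of how load_tf_weights_in_bert walks a TF variable name one path segment at a time (modeled on the transformers 4.x source; other versions may differ, and the squad branch is omitted). The resolve helper and the toy model are hypothetical. When a segment has no matching attribute, only that segment is skipped, so for a classifier-head name such as loss/output_bias (used here as an assumed example) the pointer is still the top-level model when output_bias is mapped to bias, which fails with the same kind of AttributeError.

```python
import re
from types import SimpleNamespace


def resolve(model, tf_name):
    """Simplified imitation of the name-walking loop in transformers'
    load_tf_weights_in_bert; it is not the library code itself."""
    pointer = model
    for m_name in tf_name.split("/"):
        # "layer_9" -> ["layer", "9", ""]; other segments stay whole
        if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
            scope_names = re.split(r"_(\d+)", m_name)
        else:
            scope_names = [m_name]
        if scope_names[0] in ("kernel", "gamma", "output_weights"):
            pointer = getattr(pointer, "weight")
        elif scope_names[0] in ("output_bias", "beta"):
            pointer = getattr(pointer, "bias")  # no try/except here
        else:
            try:
                pointer = getattr(pointer, scope_names[0])
            except AttributeError:
                # Only this segment is skipped; pointer stays where it was
                continue
        if len(scope_names) >= 2:
            pointer = pointer[int(scope_names[1])]
    return pointer


# Hypothetical toy stand-in for BertForPreTraining: a pooler, no "loss" scope
dense = SimpleNamespace(weight="pooler-weight", bias="pooler-bias")
model = SimpleNamespace(bert=SimpleNamespace(pooler=SimpleNamespace(dense=dense)))

print(resolve(model, "bert/pooler/dense/kernel"))  # pooler-weight
try:
    resolve(model, "loss/output_bias")
except AttributeError as err:
    print("AttributeError:", err)
```

A name that matches the model resolves normally, while loss/output_bias ends up asking the top-level object for bias.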

The keys in the TF ckpt can be printed with the following code:

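import os

import tensorflow as tf

checkpoint_path = os.path.join('./cbert', 'bert_model.ckpt')
# tf.train.load_checkpoint returns a checkpoint reader and works in TF 2.x;
# the older pywrap_tensorflow.NewCheckpointReader may be missing in newer TF releases
reader = tf.train.load_checkpoint(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
    # print each variable name together with its shape
    print('tensor_name: ', key, var_to_shape_map[key])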
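With hundreds of variables, structural problems such as a doubled scope are easier to spot by counting names per leading path prefix. The prefix_counts helper below is a hypothetical sketch in plain Python; the sample names are illustrative, partly taken from the printout discussed in this post and partly made up.

```python
from collections import Counter


def prefix_counts(names, depth=2):
    """Count variable names by their first `depth` path segments."""
    return Counter("/".join(n.strip("/").split("/")[:depth]) for n in names)


# Illustrative names only; in practice pass the keys printed above
names = [
    "bert/bert/encoder/layer_9/attention/self/query/kernel",
    "bert/bert/encoder/layer_9/attention/self/query/kernel/adam_m",
    "/loss/output_weights/adam_v",
    "loss/output_bias",
]
for prefix, count in prefix_counts(names).most_common():
    print(prefix, count)
```

A prefix such as bert/bert or a top-level loss scope stands out immediately in this summary.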

My output contained names like this:
bert/bert/encoder/layer_9/attention/self/query/kernel/adam_m
The bert/bert prefix should be bert/; there is one bert too many.
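The renaming rule itself is simple. As a sketch, the hypothetical helper below rewrites only a leading bert/bert/ and leaves every other name alone; this differs slightly from a plain str.replace, which would also rewrite a bert/bert occurring later in a name. It can be used to preview the new names before touching the checkpoint.

```python
def fix_duplicate_prefix(name, old="bert/bert/", new="bert/"):
    """Replace a duplicated scope only at the start of a variable name."""
    if name.startswith(old):
        return new + name[len(old):]
    return name


print(fix_duplicate_prefix(
    "bert/bert/encoder/layer_9/attention/self/query/kernel/adam_m"))
print(fix_duplicate_prefix("loss/output_bias"))  # unchanged
```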
So I used a rename script to rename the keys in the ckpt. The script is given below; running it from the command line as follows changes bert/bert to bert:

python rename.py --checkpoint_dir /Users/sunrui/Desktop/cbert/bert_model.ckpt  --replace_from bert/bert  --replace_to bert

# rename.py: requires TensorFlow 1.x (tf.Session and tf.contrib do not exist in TF 2.x)
import getopt
import sys

import tensorflow as tf

usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ --replace_from=substr '
             '--replace_to=substr --add_prefix=abc --dry_run')
find_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ --find_str=[\'!\']substr')
comp_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ '
                  '--checkpoint_dir2=path/to/dir/')


def print_usage_str():
    print('Please specify a checkpoint_dir. Usage:')
    print('%s\nor\n%s\nor\n%s' % (usage_str, find_usage_str, comp_usage_str))
    print('Note: checkpoint_dir should be a *DIR*, not a file')


def compare(checkpoint_dir, checkpoint_dir2):
    import difflib
    with tf.Session():
        list1 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir)]
        list2 = [el1 for (el1, el2) in
                 tf.contrib.framework.list_variables(checkpoint_dir2)]
        for k1 in list1:
            if k1 in list2:
                continue
            else:
                print('{} close matches: {}'.format(
                    k1, difflib.get_close_matches(k1, list2)))


def find(checkpoint_dir, find_str):
    with tf.Session():
        negate = find_str.startswith('!')
        if negate:
            find_str = find_str[1:]
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            if negate and find_str not in var_name:
                print('%s missing from %s.' % (find_str, var_name))
            if not negate and find_str in var_name:
                print('Found %s in %s.' % (find_str, var_name))


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            if None not in [replace_from, replace_to]:
                new_name = var_name
                if replace_from in var_name:
                    new_name = new_name.replace(replace_from, replace_to)
                    if add_prefix:
                        new_name = add_prefix + new_name
                    if dry_run:
                        print('%s would be renamed to %s.' % (var_name,
                                                              new_name))
                    else:
                        print('Renaming %s to %s.' % (var_name, new_name))
                # Create the variable, potentially renaming it
                var = tf.Variable(var, name=new_name)

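        if not dry_run:
            # Save the variables. Note: the output prefix is hard-coded and
            # overwrites the original checkpoint; change it to your own path.
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, '/Users/sunrui/Desktop/cbert/bert_model.ckpt')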


def main(argv):
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None

    try:
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=',
                                               'replace_from=', 'replace_to=',
                                               'add_prefix=', 'dry_run',
                                               'find_str=',
                                               'checkpoint_dir2='])
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg

    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])

In addition, the keys printed from my ckpt included entries such as /loss/output_weights/adam_v, which BertModel does not need.
So I made the following change to load_tf_weights_in_bert in anaconda3/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py:

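    for name, array in zip(names, arrays):
        name = name.split("/")
        print(name)  # debug: show each variable name as it is loaded
        # added: skip variables under the top-level "loss" scope (task head)
        if name[0] == "loss":
            continue
        # (the rest of the original loop body follows unchanged)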
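The same filtering decision can be written as a standalone predicate, which is handy for checking in advance which checkpoint variables will be loaded. This is a sketch, not transformers code: should_load is a hypothetical helper, OPTIMIZER_PARTS mirrors the optimizer-related names that load_tf_weights_in_bert already skips in the 4.x source (other versions may differ), and skip_scopes adds the "loss" rule from the patch above.

```python
# Names that transformers' load_tf_weights_in_bert already skips
# (based on the 4.x source; may differ in other versions)
OPTIMIZER_PARTS = {"adam_v", "adam_m", "AdamWeightDecayOptimizer",
                   "AdamWeightDecayOptimizer_1", "global_step"}


def should_load(tf_name, skip_scopes=("loss",)):
    """Return True if a TF variable should be copied into the PyTorch model."""
    parts = tf_name.strip("/").split("/")
    if any(p in OPTIMIZER_PARTS for p in parts):
        return False  # optimizer slot or step counter
    if parts[0] in skip_scopes:
        return False  # task head with no matching module in BertForPreTraining
    return True


for n in ["bert/encoder/layer_9/attention/self/query/kernel",
          "bert/encoder/layer_9/attention/self/query/kernel/adam_m",
          "loss/output_bias"]:
    print(n, should_load(n))
```

Keeping the skip list as a parameter makes it easy to add other task-specific scopes without editing the installed library again.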