pytorch checkpoint_将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(2)

v2-8d23d398fd9a04f618454111d763b6c5_1440w.jpg?source=172ae18b

第一篇:将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(1)

上一篇,我们分析了 convert_pytorch_checkpoint_to_tf.py 文件中 main()参数解析,本篇,我们从模型加载入手。

model = BertModel.from_pretrained(
    pretrained_model_name_or_path=args.model_name,
    state_dict=torch.load(args.pytorch_model_path), 
    cache_dir=args.cache_dir)

BertModel 调用的 from_pretrained() 方法是其父类 PreTrainedModel 中的方法。该方法的作用是:从预训练模型的配置文件实例化 PyTorch 版的预训练模型

【说明】:以下仅给出和 PyTorch->tf 相关的函数。

PreTrainedModel 类有 4 个全局变量,一个初始化函数、一个 from_pretrained() 函数

class PreTrainedModel(nn.Module):
    # 4个全局变量
    config_class = None
    pretrained_model_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    # 初始化函数
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        # 判断 config 是否是继承自 PreTrainedConfig 类
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(self.__class__.__name__, self.__class__.__name__))
            # Save config in model
            self.config = config

【小贴士】

python的 lambda 表达式(匿名函数):https://www.runoob.com/python3/python3-function.html

lambda和map() 函数:https://mp.weixin.qq.com/s/GDC3GeTPXspInK_1DPyuVA

isinstance() 函数:https://www.runoob.com/python/python-func-isinstance.html

raise 语句抛出一个特定异常:https://www.runoob.com/python3/python3-errors-execptions.html

python 常用的内建属性:https://blog.csdn.net/qq_26442553/article/details/82464682

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    config = kwargs.pop('config', None)
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
 
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值