pytorch checkpoint_将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(1)

v2-4877bd333aad75a21f69907251b140ba_1440w.jpg?source=172ae18b

说明1:该专栏没有规划有多少个系列文章,而是根据每次文章内容难易程度文章较佳阅读时长决定最终文章篇幅。

说明2:该专栏是对 huggingface 中多个模型的源码解析,源项目地址:

huggingface/transformers​github.com
v2-aebd429ea055e3fc710c7f283998a07a_ipico.jpg

Transformers 文档:

Transformers - transformers 2.1.0 documentation​huggingface.co

一、源码阅读的三个原则和一个能力

  • 三个原则
    • IPO原则:完成一件事情必须要经历输入(Input)—处理(Process)—输出(Output)三个步骤,在IPO原则基础上,再对将要完成的目标添加需要的辅助文件。
    • AD原则:根据上述 IPO 原则,找到输入,跳过处理,返回输出,画出概要(Abstract);完成概要设计后,跳到详细(Detail)代码中。若在拆解代码过程时,碰到其他 IPO 过程,依样画葫芦即可。
    • BO原则:在 IPO 原则和 AD 原则的基础上,读懂代码,思考代码逻辑,获得不同代码模块的功能。以 Business-Oriented (业务导向),在业务目标下,将各个独立的、不同的功能组装拼接用以支撑业务。
  • 一个能力
    • 想象能力--不是天马行空般地胡乱猜想,是在原作者设计的基础上,深入思考:
      • 作者为何要如此设计?
      • 实现这个功能的基本设计流程思路是怎样的?
      • 按照功能需要,我们自己设计一个流程,往往我们仅会考虑正常流程,而忽略掉异常流程。完成设计后,对比源码,看看作者在可能出现的问题上又做了哪些针对性的预防措施?

二、各个类之间的关系以及解析策略

v2-09abac2b215e63fe2d79a847cc0927ce_b.jpg

转换过程中主要涉及以上十六个类,其中 nn.Module 是 PyTorch 提供的,暂且不谈。剩下的十五个类,我们以 BertModel 为中心,向上和向下分析相结合来完成解析:

  • 中心BertModel
  • 向上
    • 两个通用工具类:PretrainedConfigPreTrainedModel
    • BertModel 的直接父类:BertPreTrainedModel,以及它需要关联的 BertConfigBertLayerNorm
  • 向下
    • BertModel 关联的三个类:BertEmbeddingsBertEncoderBertPooler
    • BertEncoder 关联的一个类:BertLayer
    • BertLayer 关联的三个类:BertAttentionBertIntermediateBertOutput
    • BertAttention 关联的两个类:BertSelfAttentionBertSelfOutput

三、关键函数之间调用关系的时序图

仅给出关键函数之间的调用关系,从而使得逻辑更加清晰,中间存在的其他函数之间的调用关系,可以从源码中获取。

v2-4a08651502ca68c04f27a24ad0f19b81_b.jpg

四、convert_bert_pytorch_checkpoint_to_original_tf.py 文件解析

我们从 convert_bert_pytorch_checkpoint_to_original_tf.py 文件中的 main() 函数开始,结合 IPO 和 AD 原则来阅读源码。
main() 函数有参数解析、模型加载、模型转换三个功能(AD原则中 A(Abstract)的体现)。

1. 参数解析

    • 参数:模型名称、pytorch版模型的缓存地址、pytorch 版模型的存放地址、tf 版模型的输出地址
parser 
【思考1】:在没有源码的情况下,实现 pytorch->tf 模型的转换,那么我们首先会有哪些什么想法呢?
1. 根据 IPO 原则,会先确定输入、输出,暂时忽略处理部分。
输入:pytorch 版模型
输出:tensorflow 版模型
2. 根据 AD 原则,对输入、输出环节进行细化。
输入:
离线下载模型,并存放到具体路径下(pytorch_model_path),需要调用模型的时候,直接加载。
在线下载模型。如果是第一次下载,并在给定的目录中没有找到需要调用的模型,那么,在线下载,下载完成之后并存放到指定路径下(cache_dir)。
输出:
一个包含 tf 模型的具体路径(tf_cache_dir),同时输出应该包含4个文件(tf版模型的特殊之处)。
【源码1】:在完成上面的思考之后,我们瞅瞅看实际源码中的 IPO 过程是怎样的?
目标:将 pytorch 版的模型转成 tf 版的模型
输入:pytorch 版的模型名称【model_name】、pytorch 版模型的存放路径【pytorch_model_path】、pytorch 版模型的缓存路径【cache_dir】
处理:将 pytorch 版的模型 转换成 tf 版的模型
输出:tf 版的模型输出路径【tf_cache_dir】
【校正1】:遗漏了 model_name 

【小贴士】

  • argparse 核心内容(英文):
https://docs.python.org/3.6/howto/argparse.html​docs.python.org
  • argparse 核心内容(中文):
李小伟:参数解析 argparse 核心内容​zhuanlan.zhihu.com
v2-8c902637370b7e29a16dd5bac538b302_120x160.jpg

2. 模型加载

    • 参数:预训练模型名称或路径、pytorch版模型、pytorch版模型的输入地址
model = BertModel.from_pretrained(
    pretrained_model_name_or_path=args.model_name, 
    state_dict=torch.load(args.pytorch_model_path), 
    cache_dir=args.cache_dir)

3. 将 pytorch 版的模型 转换成 tf 版的模型

    • 参数:模型、checkpoint的输出地址、预训练模型名称
convert_pytorch_checkpoint_to_tf
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值