第一篇:将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(1)
上一篇,我们分析了 convert_pytorch_checkpoint_to_tf.py
文件中 main()
的参数解析,本篇,我们从模型加载入手。
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir)
BertModel
调用的 from_pretrained()
方法是其父类 PreTrainedModel
中的方法。该方法的作用是:从预训练模型的配置文件实例化 PyTorch 版的预训练模型。
【说明】:以下仅给出和 PyTorch->tf 相关的函数。
PreTrainedModel
类有 4 个全局变量,一个初始化函数、一个 from_pretrained()
函数
class PreTrainedModel(nn.Module):
# 4个全局变量
config_class = None
pretrained_model_archive_map = {}
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
# 初始化函数
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
# 判断 config 是否是继承自 PreTrainedConfig 类
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(self.__class__.__name__, self.__class__.__name__))
# Save config in model
self.config = config
【小贴士】
python的 lambda 表达式(匿名函数):https://www.runoob.com/python3/python3-function.html
lambda和map() 函数:https://mp.weixin.qq.com/s/GDC3GeTPXspInK_1DPyuVA
isinstance() 函数:https://www.runoob.com/python/python-func-isinstance.html
raise 语句抛出一个特定异常:https://www.runoob.com/python3/python3-errors-execptions.html
python 常用的内建属性:https://blog.csdn.net/qq_26442553/article/details/82464682
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)