BERT PyTorch Source Code Walkthrough (Part 1)
I. Installing BERT
pip install pytorch-pretrained-bert
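II. BertPreTrainedModel:
An abstract class for obtaining pretrained weights; it provides a simple interface for downloading and loading pretrained models.
1. The initialization function (def __init__(self, config, *inputs, **kwargs)):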
def __init__(self, config, *inputs, **kwargs):
    super(BertPreTrainedModel, self).__init__()
    if not isinstance(config, BertConfig):
        raise ValueError(
            "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
            "To create a model from a Google pretrained model use "
            "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                self.__class__.__name__, self.__class__.__name__
            ))
    self.config = config
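The initialization function takes a BertConfig object, which gives the BERT model the hyperparameters it needs, such as hidden_size. If config is not a BertConfig instance, it raises a ValueError whose message points the caller to from_pretrained.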
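The type check above can be tried out without installing torch. The following is only a minimal sketch: ToyConfig, ToyPreTrainedModel, ToyEncoder and try_build are hypothetical stand-ins invented for illustration, not classes from pytorch-pretrained-bert. It shows the same pattern: the base class rejects anything that is not a config object (for example a model-name string) and stores the config so subclasses can read hyperparameters from it.

```python
class ToyConfig:
    """Hypothetical stand-in for BertConfig; it only holds hyperparameters."""

    def __init__(self, hidden_size=768, num_hidden_layers=12):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers


class ToyPreTrainedModel:
    """Hypothetical stand-in for BertPreTrainedModel (no torch dependency)."""

    def __init__(self, config, *inputs, **kwargs):
        # Reject anything that is not a config object, e.g. a model-name string.
        if not isinstance(config, ToyConfig):
            raise ValueError(
                "Parameter config in `{0}(config)` should be an instance of "
                "class `ToyConfig`.".format(self.__class__.__name__)
            )
        # Keep the config so subclasses can read hyperparameters from it.
        self.config = config


class ToyEncoder(ToyPreTrainedModel):
    """Subclass that reads its size from the stored config."""

    def __init__(self, config):
        super().__init__(config)
        self.hidden_size = self.config.hidden_size


def try_build(arg):
    """Return the hidden size on success, or the error message on failure."""
    try:
        return ToyEncoder(arg).hidden_size
    except ValueError as err:
        return str(err)
```

Calling try_build(ToyConfig(hidden_size=256)) succeeds, while try_build("bert-base-uncased") returns the error text. That second case mirrors the common mistake the original error message guards against: passing a model name to the constructor instead of calling from_pretrained.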
2. The most important function, from_pretrained: def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
pretrained_model_name_or_path: either:
    - a str with the name of a pre-trained model to load selected in the list of:
        . `bert-base-uncased`
        . `bert-large-uncased`
        . `bert-base-cased`
        . `bert-large-cased`
        . `bert-base-multilingual-uncased`
        . `bert-base-multilingual-cased`
        . `bert-base-chinese`
    - a path or url to a pretrained model archive containing:
        . `bert_config.json` a configuration file for the model
        . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
    - a path or url to a pretrained model archive containing:
        . `bert_config.json` a configuration file for the model