![v2-6784169c7bc831055bfe33adc636164a_1440w.jpg?source=172ae18b](http://img-03.proxy.5ce.com/view/image?&type=2&guid=641cd0d6-be2f-eb11-8da9-e4434bdf6706&url=https://pic3.zhimg.com/v2-6784169c7bc831055bfe33adc636164a_1440w.jpg?source=172ae18b)
上一篇:将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(3)
本篇的重点放在与 BertModel 向上相关的几个类和函数,如下图所示:
![v2-c2e3e030771fc1748d3828417293c5af_b.jpg](http://img-01.proxy.5ce.com/view/image?&type=2&guid=641cd0d6-be2f-eb11-8da9-e4434bdf6706&url=https://pic4.zhimg.com/v2-c2e3e030771fc1748d3828417293c5af_b.jpg)
BertModel 继承自 BertPreTrainedModel,而 BertPreTrainedModel 关联了 BertConfig 和 load_tf_weights_in_bert() 方法。
BertPreTrainedModel 的作用是:该抽象类负责处理权重初始化,并提供一个用于下载和加载预训练模型的简单接口,代码如下:
class BertPreTrainedModel(PreTrainedModel):
    """Abstract class that handles weights initialization and provides
    a simple interface for downloading and loading pretrained models.
    """
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place.

        Linear and embedding weights are drawn from a normal distribution
        whose std is ``self.config.initializer_range``; layer-norm modules
        get bias 0 and weight 1; linear biases (when present) are zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Slightly different from the TF version, which uses
            # truncated_normal for initialization
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
BertConfig 的作用是:用来存储 BertModel 的配置,代码如下:
class BertConfig(PretrainedConfig):
    """Configuration object that stores the settings of a BERT model.

    The first constructor argument is either an ``int`` (the vocabulary
    size, in which case the remaining keyword values are stored as
    attributes) or a ``str`` path to a JSON file whose top-level keys are
    copied onto the instance as attributes.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        """Build the configuration from explicit values or from a JSON file.

        Raises:
            ValueError: if the first argument is neither a ``str`` path
                nor an ``int`` vocabulary size.
        """
        super(BertConfig, self).__init__(**kwargs)
        source = vocab_size_or_config_json_file

        # On Python 2 a path may also arrive as a ``unicode`` object.
        is_path = isinstance(source, str) or (
            sys.version_info[0] == 2 and isinstance(source, unicode))
        if is_path:
            with open(source, "r", encoding='utf-8') as reader:
                loaded = json.load(reader)
            # Every key of the JSON document becomes an instance attribute.
            for attr_name, attr_value in loaded.items():
                self.__dict__[attr_name] = attr_value
            return

        if not isinstance(source, int):
            raise ValueError("First argument must be either a vocabulary size (int)"
                             " or the path to a pretrained model config file (str)")

        self.vocab_size = source
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
load_tf_weights_in_bert() 函数的作用:在 PyTorch 模型中加载 TensorFlow checkpoints。代码如下:
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Load the weights of a TensorFlow checkpoint into a PyTorch model.

    Each TF variable name is split on ``/`` and the resulting components
    are followed as attribute lookups starting from ``model``; the array
    stored in the checkpoint is then assigned to the parameter reached.

    Args:
        model: the PyTorch model whose parameters are overwritten in place.
        config: unused by this function; kept for interface compatibility.
        tf_checkpoint_path: path to the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with its weights replaced.

    Raises:
        ImportError: if ``re``, ``numpy`` or ``tensorflow`` cannot be
            imported (an installation hint is logged before re-raising).
        AssertionError: if a checkpoint array and the matching PyTorch
            parameter have different shapes; both shapes are appended to
            the exception args.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # Components such as "layer_3" carry an index: split them into
            # the attribute name and the numeric position. The patterns
            # need `\d` (a digit); a bare `d` would only match the letter
            # "d" and indexed scopes would never be recognized.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE(review): this `continue` only skips the current
                    # path component, not the whole variable, although the
                    # log message prints the full name - confirm intended.
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # `m_name` is now the last component of the variable name.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # Kernels are transposed before the shape check below.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
BertModel 向上相关的类和方法就是上述内容。下一篇解析 BertModel 向下相关的三个类,分别是 BertEmbeddings、BertEncoder、BertPooler,以及相关的 4 个函数。