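When loading a pretrained BERT/ALBERT model, I have always been curious about how its parameters get filled into the model skeleton step by step. I also wanted a clearer way to see whether the parameters were filled in correctly, to guard against pretrained weights that silently fail to load. So I single-stepped through the BERT model-loading code in a debugger, and this post briefly describes what each part of that code does.
Note: the root directory below is the root of the albert-pytorch project, taken from that repo on GitHub.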
The call path across files while the model is being loaded:
/run_classifier.py(387) AlbertForSequenceClassification.from_pretrained() ->
/model/modeling_utils.py(191) from_pretrained() ->
/model/modeling_utils.py(363) load() ->
/model/modeling_utils.py(347) load() -> # a recursive function: the "prefix" argument changes on each recursive call and controls which parameters get loaded
/model/modeling_utils.py(347) module._load_from_state_dict() ->
{$TORCH_HOME}/nn/modules/module.py(703) _load_from_state_dict()
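The recursion over "prefix" in the call path above can be sketched as follows. This is a minimal sketch and not the repo's exact code: `TinyEncoder` and `load_recursively` are hypothetical names made up for illustration, and `_load_from_state_dict` is a private PyTorch method whose behavior may differ between versions.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical two-level model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.classifier = nn.Linear(4, 2)


def load_recursively(model, state_dict):
    """Copy state_dict into model one module at a time, recording each prefix."""
    missing_keys, unexpected_keys, error_msgs = [], [], []
    visited_prefixes = []

    def load(module, prefix=""):
        visited_prefixes.append(prefix)
        # Each module copies only its own direct parameters and buffers;
        # the key looked up in state_dict is prefix + local parameter name.
        module._load_from_state_dict(
            state_dict, prefix, {}, True,
            missing_keys, unexpected_keys, error_msgs,
        )
        # Recurse into children, extending the prefix with the child's name.
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model)
    return visited_prefixes, missing_keys, error_msgs


torch.manual_seed(0)
source, target = TinyEncoder(), TinyEncoder()
prefixes, missing, errors = load_recursively(target, source.state_dict())
print(prefixes)  # ['', 'embeddings.', 'classifier.']
```

Each recursive call handles one module, so the prefix seen inside `_load_from_state_dict()` tells you which layer is currently being filled.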
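The key part is inside the function _load_from_state_dict(), specifically the for-loop at lines 742 to 769. When a parameter is loaded successfully, line 762, param.copy_(input_param), is executed (the for-loop is reproduced below):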
    742  for name, param in local_state.items():
    743      key = prefix + name
    744      if key in state_dict:
    745          input_param = state_dict[key]
    746
    747          # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
    748          if len(param.shape) == 0 and len(input_param.shape) == 1:
    749              input_param = input_param[0]
    750
    751          if input_param.shape != param.shape:
    752              # local shape should match the one in checkpoint
    753              error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
    754                                'the shape in current model is {}.'
    755                                .format(key, input_param.shape, param.shape))
    756              continue
    757
    758          if isinstance(input_param, Parameter):
    759              # backwards compatibility for serialized parameters
    760              input_param = input_param.data
    761          try:
    762              param.copy_(input_param)
    763          except Exception:
    764              error_msgs.append('While copying the parameter named "{}", '
    765                                'whose dimensions in the model are {} and '
    766                                'whose dimensions in the checkpoint are {}.'
    767                                .format(key, param.size(), input_param.size()))
    768      elif strict:
    769          missing_keys.append(key)
Here param is a torch.Tensor, so let us read the documentation of torch.Tensor.copy_():
    Copies the elements from src into self tensor and returns self. The src tensor must be broadcastable with the self tensor. It may be of a different data type or reside on a different device.
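Put simply, src (here input_param) is copied into self. Note that self in this docstring is the tensor that copy_() is called on, namely param, which refers to the parameter held by the nn.Module currently being loaded; it is not the module itself. The return value is self (that is, param after the copy), not input_param.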
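The behavior described in the docstring can be checked on two small tensors. This is a minimal sketch: the shapes and dtypes are arbitrary, and the variable names merely echo those in the loop above.

```python
import torch

# Stands in for the model's own parameter tensor (float32).
param = torch.zeros(2, 3)
# Stands in for the checkpoint tensor; deliberately a different dtype.
input_param = torch.arange(6, dtype=torch.float64).reshape(2, 3)

returned = param.copy_(input_param)

print(returned is param)                                 # True: copy_ returns self
print(param.dtype)                                       # torch.float32: self keeps its dtype
print(torch.equal(param, input_param.to(param.dtype)))   # True: values were copied
print(param.data_ptr() == input_param.data_ptr())        # False: storage is not shared
```

The copy is in place, so anything that already holds a reference to param (such as the module's attribute) sees the new values without being reassigned.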
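Take the parameter bert.embeddings.word_embeddings.weight as an example. The module that owns this param (the module-level "self" seen in the debugger inside _load_from_state_dict()) is Embedding(21128, 128, padding_idx=0), i.e. the embedding layer. As the source of class torch.nn.Embedding shows (attached below), this layer has attributes such as num_embeddings and embedding_dim. Here they are 21128 and 128 respectively: the former is the vocabulary size, the latter is ALBERT's embedding size.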
    class Embedding(Module):
        def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                     max_norm=None, norm_type=2., scale_grad_by_freq=False,
                     sparse=False, _weight=None):
            super(Embedding, self).__init__()
            self.num_embeddings = num_embeddings
            self.embedding_dim = embedding_dim
            if padding_idx is not None:
                if padding_idx > 0:
                    assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
                elif padding_idx < 0:
                    assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                    padding_idx = self.num_embeddings + padding_idx
            self.padding_idx = padding_idx
            self.max_norm = max_norm
            self.norm_type = norm_type
            self.scale_grad_by_freq = scale_grad_by_freq
            if _weight is None:
                self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
                self.reset_parameters()
            else:
                assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                    'Shape of weight does not match num_embeddings and embedding_dim'
                self.weight = Parameter(_weight)
            self.sparse = sparse
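Note the `if _weight is None:` branch (lines 18 to 20 of the listing). Before param.copy_(input_param) has run, the Embedding layer's weights do not yet hold the pretrained values; they are only randomly initialized. The call to self.reset_parameters() on line 20 fills the tensor allocated with the given dimensions with random numbers drawn from the normal distribution N(0, 1). (The values printed below are much smaller in magnitude than typical N(0, 1) samples, presumably because the model's own weight initialization has already re-initialized the layer with a smaller standard deviation by this point.) Let us print self.weight first: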
ipdb> self.weight
Parameter containing:
tensor([[-0.0072, -0.0040, 0.0490, ..., -0.0219, 0.0050, -0.0293],
[ 0.0405, 0.0166, -0.0039, ..., -0.0099, -0.0004, -0.0137],
[ 0.0111, -0.0048, 0.0283, ..., 0.0047, -0.0072, 0.0130],
...,
[ 0.0209, -0.0084, -0.0283, ..., 0.0367, 0.0080, -0.0220],
[ 0.0584, 0.0286, 0.0028, ..., -0.0016, 0.0436, 0.0071],
[ 0.0238, -0.0204, 0.0172, ..., -0.0435, -0.0267, 0.0099]],
requires_grad=True)
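For comparison with the values printed above, here is a minimal sketch of what a freshly constructed nn.Embedding holds on its own, before any model-level initialization or checkpoint loading. The sizes are arbitrary illustration values, not the model's real ones.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Arbitrary sizes chosen for illustration only.
layer = nn.Embedding(1000, 64, padding_idx=0)

weight = layer.weight.detach()
# Statistics over all rows except the padding row.
mean = weight[1:].mean().item()
std = weight[1:].std().item()
padding_row_is_zero = bool((weight[0] == 0).all())

print(mean, std)             # roughly 0 and 1 for N(0, 1) initialization
print(padding_row_is_zero)   # the padding_idx row is zero-filled
```

If the standard deviation here comes out near 1 while the weights seen in the debugger are on the order of 0.01, that would suggest some later initialization step, rather than reset_parameters() alone, produced the values in the debugger.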
After param.copy_(input_param) has run, check whether self.weight has been changed to the contents of bert.embeddings.word_embeddings.weight:
ipdb> self.weight
Parameter containing:
tensor([[ 0.0722, 0.0224, 0.1045, ..., 0.0800, 0.0776, -0.0483],
[ 0.0779, 0.0606, 0.0891, ..., 0.0628, 0.0831, -0.0924],
[ 0.0891, 0.0782, 0.0731, ..., 0.0609, 0.1201, -0.0561],
...,
[ 0.0159, 0.0438, 0.1095, ..., 0.0802, 0.0773, -0.0790],
[ 0.0664, 0.0513, 0.1075, ..., 0.0682, 0.0776, -0.0842],
[ 0.0135, 0.0239, 0.1113, ..., 0.0646, 0.0756, -0.0632]],
requires_grad=True)
This is exactly the value stored under bert.embeddings.word_embeddings.weight in the checkpoint (input_param):
ipdb> (input_param == self.weight).numpy().all()
True
This shows that param.copy_(input_param) on line 762 is the statement that writes input_param into self.weight.
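The same element-wise check can be applied to every parameter at once instead of one layer in the debugger. The following is a minimal sketch under stated assumptions: `verify_loaded` is a hypothetical helper, not part of albert-pytorch or PyTorch, and it assumes the checkpoint keys match the model's state_dict() keys exactly.

```python
import torch
from torch import nn


def verify_loaded(model, state_dict):
    """Return (missing, mismatched) key lists by comparing a model to a checkpoint dict."""
    missing, mismatched = [], []
    for key, tensor in model.state_dict().items():
        if key not in state_dict:
            # The checkpoint has no entry for this parameter or buffer.
            missing.append(key)
            continue
        loaded = state_dict[key]
        # Compare on CPU in the model's dtype; a shape difference is also a mismatch.
        if tensor.shape != loaded.shape or not torch.equal(
            tensor.cpu(), loaded.cpu().to(tensor.dtype)
        ):
            mismatched.append(key)
    return missing, mismatched


torch.manual_seed(0)
checkpoint = nn.Linear(4, 2).state_dict()   # stands in for the pretrained weights
model = nn.Linear(4, 2)                     # freshly initialized, not yet loaded

before = verify_loaded(model, checkpoint)
model.load_state_dict(checkpoint)
after = verify_loaded(model, checkpoint)
print(before)  # keys whose values still differ from the checkpoint
print(after)   # ([], [])
```

An empty result after loading means every tensor in the model equals its checkpoint counterpart; any key listed as missing or mismatched points at a layer that kept its random initialization.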