BERT、ALBERT模型加载——From pytorch_model.bin

11 篇文章 4 订阅
11 篇文章 3 订阅
我在载入BERT/ALBERT的预训练模型时,总会好奇于它的模型参数到底是怎么一步步被填到模型框架里的,此外我也想更明晰地看到模型参数是否被正确地填入,以防预训练的模型参数没被正确载入。因此对BERT模型的载入代码进行了单步调试,在此简述这部分代码各自的作用。

注:根目录是albert-pytorch项目根目录,来自github该repo

模型文件加载的文件跳转路径:

/run_classifier.py(387) AlbertForSequenceClassification.from_pretrained()->
/model/modeling_utils.py(191) from_pretrained() ->
/model/modeling_utils.py(363) load() ->
/model/modeling_utils.py(347) load() ->  # 这是个递归函数,在一次次递归中"prefix"参数在变化,控制着模型参数的载入;
/model/modeling_utils.py(347) module._load_from_state_dict() -> 
{$TORCH_HOME}/nn/modules/module.py(703) _load_from_state_dict()

重点就在这函数_load_from_state_dict()里面。line742~line769的for-loop。若是成功的模型参数加载,则line762:param.copy_(input_param)就会被执行(这段for-loop代码示例如下)

742	for name, param in local_state.items():
743	    key = prefix + name
744	    if key in state_dict:
745	        input_param = state_dict[key]
746	
747	        # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
748	        if len(param.shape) == 0 and len(input_param.shape) == 1:
749	            input_param = input_param[0]
750	
751	        if input_param.shape != param.shape:
752	            # local shape should match the one in checkpoint
753	            error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
754	                              'the shape in current model is {}.'
755	                              .format(key, input_param.shape, param.shape))
756	            continue
757	
758	        if isinstance(input_param, Parameter):
759	            # backwards compatibility for serialized parameters
760	            input_param = input_param.data
761	        try:
762	            param.copy_(input_param)
763	        except Exception:
764	            error_msgs.append('While copying the parameter named "{}", '
765	                              'whose dimensions in the model are {} and '
766	                              'whose dimensions in the checkpoint are {}.'
767	                              .format(key, param.size(), input_param.size()))
768      elif strict:
769          missing_keys.append(key)

这里param是一个torch.Tensor,让我们读一下torch.Tensor.copy_()函数文档

Copies the elements from src into self tensor and returns self. The
src tensor must be broadcastable with the self tensor. It may be of a
different data type or reside on a different device.

很简单,意思就是,src=input_param会被复制到self当中(当前self就是当前param所在的nn.Module),同时input_param会作为返回值。
以参数bert.embeddings.word_embeddings.weight为例子,此时该param所对应的“self”是Embedding(21128, 128, padding_idx=0),所在层就是Embedding-layer。我们从class torch.nn.Embedding可以看出(下面附source code),该层含有num_embeddings、embedding_dim等属性。它们分别就是21128, 128(前者是vocab词表大小,后者是albert的Embedding size)

class Embedding(Module):
  def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
             max_norm=None, norm_type=2., scale_grad_by_freq=False,
             sparse=False, _weight=None):
    super(Embedding, self).__init__()
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
        elif padding_idx < 0:
            assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
            padding_idx = self.num_embeddings + padding_idx
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    if _weight is None:
        self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.reset_parameters()
    else:
        assert list(_weight.shape) == [num_embeddings, embedding_dim], \
            'Shape of weight does not match num_embeddings and embedding_dim'
        self.weight = Parameter(_weight)
    self.sparse = sparse

注意line 18~20。当还未执行param.copy_(input_param)时,此时Embedding层的参数还是未初始化的,其中line20的self.reset_parameters()内部的操作就是,将前面指定维度生成的Tensor,填以服从N(0, 1)的正态分布的随机数。我们此时先把这个self.weight打出来看看:

ipdb> self.weight        
Parameter containing:
tensor([[-0.0072, -0.0040,  0.0490,  ..., -0.0219,  0.0050, -0.0293],
   [ 0.0405,  0.0166, -0.0039,  ..., -0.0099, -0.0004, -0.0137],
   [ 0.0111, -0.0048,  0.0283,  ...,  0.0047, -0.0072,  0.0130],
   ...,
   [ 0.0209, -0.0084, -0.0283,  ...,  0.0367,  0.0080, -0.0220],
   [ 0.0584,  0.0286,  0.0028,  ..., -0.0016,  0.0436,  0.0071],
   [ 0.0238, -0.0204,  0.0172,  ..., -0.0435, -0.0267,  0.0099]],
  requires_grad=True)

执行过param.copy_(input_param)后,看看self.weight是否被修改成bert.embeddings.word_embeddings.weight的内容了:

ipdb> self.weight        
Parameter containing:
tensor([[ 0.0722,  0.0224,  0.1045,  ...,  0.0800,  0.0776, -0.0483],
   [ 0.0779,  0.0606,  0.0891,  ...,  0.0628,  0.0831, -0.0924],
   [ 0.0891,  0.0782,  0.0731,  ...,  0.0609,  0.1201, -0.0561],
   ...,
   [ 0.0159,  0.0438,  0.1095,  ...,  0.0802,  0.0773, -0.0790],
   [ 0.0664,  0.0513,  0.1075,  ...,  0.0682,  0.0776, -0.0842],
   [ 0.0135,  0.0239,  0.1113,  ...,  0.0646,  0.0756, -0.0632]],
  requires_grad=True)

这恰好就是bert.embeddings.word_embeddings.weight对应的值(input_param):

ipdb> (input_param == self.weight).numpy().all()   
True

这说明了line762的param.copy_(input_param)就是在将input_param更新到self.weight上去。

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_illusion_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值