BERT Pytorch版本 源码解析(一)

BERT Pytorch版本 源码解析(一)

 

一、BERT安装方式

pip install pytorch-pretrained-bert

二、BertPreTrainModel: 

  • 一个用于获取预训练好权重的抽象类,一个用于下载和载入预训练模型的简单接口

1、初始化函数(def __init__(self, config, *inputs, **kwargs)):

def __init__(self, config, *inputs, **kwargs):
    super(BertPreTrainedModel, self).__init__()
    if not isinstance(config, BertConfig):
        raise ValueError(
            "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
            "To create a model from a Google pretrained model use "
            "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                self.__class__.__name__, self.__class__.__name__
            ))
    self.config = config

初始化函数主要是用于传入BertConfig的一个对象,这样可以获得Bert模型所需的模型参数,例如hidden_size等

2、最重要的from_pretrained函数: def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)

pretrained_model_name_or_path: either:
    - a str with the name of a pre-trained model to load selected in the list of:
        . `bert-base-uncased`
        . `bert-large-uncased`
        . `bert-base-cased`
        . `bert-large-cased`
        . `bert-base-multilingual-uncased`
        . `bert-base-multilingual-cased`
        . `bert-base-chinese`
    - a path or url to a pretrained model archive containing:
        . `bert_config.json` a configuration file for the model
        . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
    - a path or url to a pretrained model archive containing:
        . `bert_config.json` a configuration file for the model
        . `model.chkpt` a TensorFlow checkpoint

看一下pretrained_model_name_or_path 这个参数,这个参数可以是两种,一种是你需要下载的预训练的BERT模型类别名称,另一种是你已经下好的BERT预训练模型的路径。

这就是为什么有一些博客上加载预训练模型是直接from_pretrain('bert-base-uncased'),而有一些上面写的是bert模型的路径,这里个人建议是把一些预训练模型下好,然后放到一个固定的文件夹下面可以避免重复下载,每次用的时候直接调用就好了。

 

if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
    archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
    archive_file = pretrained_model_name_or_path

这部分代码就是解析你传入的pretrained_model_name_or_path参数是一个模型名称还是一个模型路径,首先是进行判断是否是模型名称,不是的话默认为下载好的模型路径。

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}

这个就是一个Map,用于将模型名称转换成相对应的URL,所以想预下载的同志们直接在这里copy 一下URL就可以下载了,并不需要去找一下百度云哦,毕竟百度云下载的东西也未必是真的有用(小编就被坑了,很难受 QAQ)。

try:
    resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
    logger.error(
        "Model name '{}' was not found in model name list ({}). "
        "We assumed '{}' was a path or url but couldn't find any file "
        "associated to this path or url.".format(
            pretrained_model_name_or_path,
            ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
            archive_file))
    return None



'''
cached_path函数的内部实现
'''
def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

这部分代码就是对于你传入的文件路径或者是转换成的URL做一个处理,如果是URL的话就进行下载的操作,如果是一个本地文件的话就进行文件路径检查以及返回文件路径。

 

三、BertModel

1、BertModel 大概是实战中最应该掌握的模块。初始化函数如下:

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

可以看出初始化 BertModel 的时候是需要传一个config的,这个config就是BertConfig的一个对象,那么我们在项目中要运用一些参数预训练好的模型来进行建模时应该怎么操作呢?

self.bert = BertModel.from_pretrained(model_path)
self.hidden_size = self.bert.config.hidden_size

这个写法是预先下载好了 bert 的预训练模型的写法,将你自己下好的预训练模型的路径传进去就好了,如果没有下载过可以看一下 BertPreTrainModel 部分的解释,建议是将你的bert模型下载好保存到一个固定的文件夹中,以后要用到的时候直接调用就好了,BertModel可以加载几种的预训练模型(包括中文的bert)按需下载就好了。到现在我们就加载好了一个预训练好的bert模型。这种方式去加载的bert模型如果要查看相关的参数配置信息,只需要如上述第二行的方式即可获取。

注意:BERT的参数太大,不要在笔记本上面测试代码能不能跑,很容易就死机了。。。

BertModel的内部运行方式解析:

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output

首先是参数传入,一般来说对于 基本的BertModel 以及之后的另外一些模型都是传入 input_ids, token_type_ids, attention_mask三个参数,下面解释一下三个参数的含义。

input_ids: 如果是用BertTokenizer 进行分词的,那么会自动生成对应的 tokens to ids 的函数,将你的句子直接扔进这个函数就可以得到一个用词表index描述的句子。用 Batch 的方式去训练的记得将input_ids进行padding操作。

token_type_ids: BertModel 每一次最多允许两个句子输入模型中,所以你的 token_type_ids 只能是0或者1。

attention_mask: Encoder 做 attention的时候需要利用这个部分进行对padding的无用信息进行舍去,不进行attention操作。

  • 10
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 20
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值