bert-tokenization代码学习

# coding=utf-8

# Copyright 2018 The Google AI Language Team Authors.

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

"""Tokenization classes."""

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import collections

import re

import unicodedata

import six

import tensorflow as tf


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name.

    The casing has to be passed in by the user and there is no explicit check
    as to whether it matches the checkpoint. The casing information probably
    should have been stored in the bert_config.json file, but it's not, so it
    is detected heuristically from the name of the directory that directly
    contains `bert_model.ckpt` and compared against lists of known model
    names. Unknown directory names and paths that do not contain
    `<dir>/bert_model.ckpt` are accepted silently.

    Args:
      do_lower_case: Truthy if the caller is configured to lowercase input.
      init_checkpoint: Checkpoint path, e.g.
        `/models/uncased_L-12_H-768_A-12/bert_model.ckpt`. A falsy value
        (`None` or an empty string) skips the check.

    Raises:
      ValueError: If the directory name is a known lowercased model while
        `do_lower_case` is falsy, or a known cased model while
        `do_lower_case` is truthy.
    """
    if not init_checkpoint:
        return

    # Raw string with an escaped dot: only a literal "bert_model.ckpt" should
    # count as a checkpoint file name. An unescaped "." would match any
    # character (e.g. "bert_modelXckpt") and could trigger a bogus error.
    m = re.match(r"^.*?([A-Za-z0-9_-]+)/bert_model\.ckpt", init_checkpoint)
    if m is None:
        return

    model_name = m.group(1)

    # Released model directory names, grouped by the casing they were
    # pre-trained with.
    lower_models = (
        "uncased_L-24_H-1024_A-16",
        "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12",
        "chinese_L-12_H-768_A-12",
    )
    cased_models = (
        "cased_L-12_H-768_A-12",
        "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12",
    )

    # The two mismatch conditions are mutually exclusive (one needs a falsy
    # flag, the other a truthy one), so a single if/elif chain is enough.
    if model_name in lower_models and not do_lower_case:
        actual_flag, case_name, opposite_flag = "False", "lowercased", "True"
    elif model_name in cased_models and do_lower_case:
        actual_flag, case_name, opposite_flag = "True", "cased", "False"
    else:
        # Either the configuration is consistent or the model is unknown.
        return

    raise ValueError(
        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
        "However, `%s` seems to be a %s model, so you "
        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
        "how the model was pre-training. If this error is wrong, please "
        "just comment out this check." % (actual_flag, init_checkpoint,
                                          model_name, case_name, opposite_flag))


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not alrea
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值