# -*- coding: utf-8 -*-
"""
Created on Sun Apr 28 10:20:04 2019
@author: wumingshi
"""
#import contextlib
import json
import os
from enum import Enum
from termcolor import colored
import sys
import modeling
import logging
import tensorflow as tf
import argparse
from tokenization import FullTokenizer, validate_case_matches_checkpoint
# Import-time side effect: pin the CUDA_VISIBLE_DEVICES environment variable
# to '0' for this process.
# NOTE(review): presumably meant to restrict TensorFlow to the first GPU —
# confirm this hard-coded device index is intended on multi-GPU hosts.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def set_logger(context, verbose=False):
    """Build a console logger named ``context``.

    On Windows (``os.name == 'nt'``) the call is delegated to ``NTLogger``
    (defined outside this block) instead of the standard ``logging`` setup.

    Args:
        context: Logger name. It is also embedded verbatim in the prefix of
            every formatted log line.
        verbose: When true, the logger and its handler emit ``DEBUG``
            records; otherwise ``INFO`` and above.

    Returns:
        The configured ``logging.Logger``, or whatever ``NTLogger`` returns
        on Windows.
    """
    if os.name == 'nt':  # for Windows
        return NTLogger(context, verbose)

    level = logging.DEBUG if verbose else logging.INFO

    logger = logging.getLogger(context)
    logger.setLevel(level)

    # ``context`` becomes part of a %-style format string, so a literal '%'
    # in the name must be doubled or record formatting would fail.
    escaped_context = context.replace('%', '%%')
    formatter = logging.Formatter(
        '%(levelname)-.1s:' + escaped_context
        + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s',
        datefmt='%m-%d %H:%M:%S')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)

    # Drop handlers left over from earlier calls with the same name so that
    # repeated calls do not duplicate every message.
    logger.handlers = []
    logger.addHandler(console_handler)
    return logger
class InputFeatures(object):
    """A single set of features of data.

    Plain container: each constructor argument is stored unchanged on the
    instance under an attribute of the same name.

    Args:
        input_ids: Stored as ``self.input_ids``.
        input_mask: Stored as ``self.input_mask``.
        segment_ids: Stored as ``self.segment_ids``.
        label_id: Stored as ``self.label_id``.
        is_real_example: Stored as ``self.is_real_example``; defaults to
            ``True``.
    """

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 label_id,
                 is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.is_real_example = is_real_example

    def __repr__(self):
        """Return a debugging representation listing every stored field."""
        return (
            "{}(input_ids={!r}, input_mask={!r}, segment_ids={!r}, "
            "label_id={!r}, is_real_example={!r})"
        ).format(type(self).__name__, self.input_ids, self.input_mask,
                 self.segment_ids, self.label_id, self.is_real_example)
def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer):
    """Convert a set of `InputExample`s to a list of `InputFeatures`.

    Each example is passed, together with its index, to
    ``convert_single_example``; the results are collected in input order.
    Progress is logged for every 10000th example (including the first).

    Args:
        examples: Sized sequence of examples to convert.
        label_list: Forwarded unchanged to ``convert_single_example``.
        max_seq_length: Forwarded unchanged to ``convert_single_example``.
        tokenizer: Forwarded unchanged to ``convert_single_example``.

    Returns:
        A list with one converted feature per input example.
    """
    features = []
    for ex_index, example in enumerate(examples):
        if ex_index % 10000 == 0:
            # Pass the values as logging arguments so the message is only
            # formatted when the record is actually emitted.
            tf.logging.info("Writing example %d of %d",
                            ex_index, len(examples))
        features.append(
            convert_single_example(ex_index, example, label_list,
                                   max_seq_length, tokenizer))
    return features
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son
基于 BERT 的文本表征向量(embedding)模型由 ckpt 转成 pb
最新推荐文章于 2022-11-06 09:52:17 发布
该博客介绍了如何将基于 BERT 的文本表征模型从 ckpt 格式转换为 pb 格式,重点在于模型的预训练和微调(fine-tuning)。在预训练后,可通过模型的 [CLS] 标记获取整段文本的向量,或者在序列任务中利用模型获取每个 token 的向量。
摘要由CSDN通过智能技术生成