PyTorch BERT Text Classification: A Step-by-Step Tutorial
This article mainly relies on Hugging Face's transformers library; for more detailed explanations, consult the official documentation.
Defining the model
Defining the model mainly means defining the tokenizer, the config and the model. The quick-and-simple route is Hugging Face's Auto classes (AutoModel and friends). Here cache_dir is the directory the pretrained files are downloaded to, and the config is where you set the parameters the model will need later: for example, the model I use below is BertForSequenceClassification, which needs a parameter giving the number of labels to predict, so I add num_labels=3 to the config.
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification

# cache_dir is where the pretrained files are downloaded to
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', cache_dir='./cache_down')
config = BertConfig.from_pretrained('bert-base-chinese', cache_dir='./cache_down', num_labels=3)
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', cache_dir='./cache_down', config=config)
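If you would rather not hard-code the BERT classes, the Auto classes mentioned above work the same way; a minimal sketch (same checkpoint, cache_dir and num_labels assumed):
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

# The Auto* classes resolve to the matching BERT classes from the checkpoint name
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese', cache_dir='./cache_down')
config = AutoConfig.from_pretrained('bert-base-chinese', cache_dir='./cache_down', num_labels=3)
model = AutoModelForSequenceClassification.from_pretrained('bert-base-chinese', cache_dir='./cache_down', config=config)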
If the pretrained weights have already been downloaded, you can use the approach below. Note that the three downloaded files must be named config.json, pytorch_model.bin and vocab.txt, otherwise from_pretrained will not find them.
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
config = BertConfig.from_pretrained(args.model_name_or_path, num_labels=3)
model = BertForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
from_pretrained is quite powerful: it can also load TensorFlow pretrained .ckpt checkpoint files, as long as you change one argument:
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
config = BertConfig.from_json_file('bert_config.json')
model = BertForSequenceClassification.from_pretrained('bert_model.ckpt', from_tf=True, config=config)
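Whichever way the weights are loaded, a quick forward pass confirms that the tokenizer and the 3-label classification head fit together. A small sanity-check sketch (the sentence pair below is made up for illustration):
import torch

# Encode a made-up sentence pair and run one forward pass
inputs = tokenizer.encode_plus('今天天气很好', '适合出去玩', return_tensors='pt')
model.eval()
with torch.no_grad():
    logits = model(**inputs)[0]  # shape (1, 3): one score per label
print(logits.shape)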
Reading the data
With TensorFlow, the usual approach is to adapt the data-reading class DataProcessor, change how the data is read inside it, and then just run run_classifier.py (roughly speaking; I have never run BERT with TensorFlow, so experts please skip ahead). The transformers examples also use such a reading class. I split it apart here to make it easier to understand.
- Use a class to store each example (a short usage example follows the class definition)
class InputExample(object):
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
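For instance, one row of a three-class sentence-pair dataset could be wrapped like this (the text and label here are made up):
example = InputExample(guid=0, text_a='今天天气很好', text_b='适合出去玩', label=2)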
- Write a data-reading function that loads the json/excel/csv file (the version below reads an excel file; a call example follows the function)
import pandas as pd

def read_examples(input_file, is_training, sep=','):
    # The data here lives in an excel file; use pd.read_csv / pd.read_json for csv or json
    df = pd.read_excel(input_file)
    examples = []
    for val in df[['idx', 'text_a', 'text_b', 'label']].values:
        examples.append(InputExample(guid=val[0], text_a=val[1], text_b=val[2], label=val[3]))
    return examples
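A usage sketch, assuming the training set is an excel file called train.xlsx with the idx/text_a/text_b/label columns expected by read_examples above (the file name is hypothetical):
train_examples = read_examples('train.xlsx', is_training=True)
print(len(train_examples))
print(train_examples[0].text_a, train_examples[0].label)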
- Write a function that converts the examples into the features BERT needs (input_ids: the id of each token; input_mask: masks with 0s the positions padded onto sentences that are too short; segment_ids: distinguishes the two sentences, as in the NSP task)
import logging

logger = logging.getLogger(__name__)

def convert_examples_to_features(examples, tokenizer, max_seq_length, split_num, is_training):
    features = []
    for example_index, example in enumerate(examples):
        context_tokens = tokenizer.tokenize(example.text_a)
        ending_tokens = tokenizer.tokenize(example.text_b)
        # Split the (possibly long) text_a into split_num roughly equal chunks
        skip_len = len(context_tokens) / split_num
        choices_features = []
        index_begin = 0
        index_end = split_num - 1
        # Log the first training example so the raw tokens can be inspected
        if example_index < 1 and is_training:
            logger.info("** RAW EXAMPLE **")
            logger.info("content: {}".format(context_tokens))
        for i in range(split_num):
            if i != index_end:
                context_tokens_choice = context_tokens[int(i * skip_len):int((i + 1) * skip_len)]
            elif i == index_end:
                # The last chunk is taken from the end of the document
                context_tokens_choice = context_tokens[-int(i * skip_len):]
            # _truncate_seq_pair trims the pair in place; 3 positions are reserved for [CLS] and the two [SEP]
            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3, i == index_end)
            tokens = ["[CLS]"] + ending_tokens + ["[SEP]"] + context_tokens_choice + ["[SEP]"]
            segment_ids = [0] * (len(ending_tokens) + 2) + [1] * (len(context_tokens_choice) + 1)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)
            # Pad input_ids and input_mask with zeros up to max_seq_length
            padding_length = max_seq_length - len(input_ids)
            input_ids += ([0] * padding_length)
            input_mask += ([0] * padding_length)