This is the second post in the BERT hands-on series; it uses BERT for named entity recognition (a sequence labeling task).
1. Preparation
1.1 Environment
- python 3.7
- pytorch 1.3
- transformers 2.3 (installation tutorial)
1.2 Data
- Data (Baidu Netdisk): https://pan.baidu.com/s/1spwmV3_07U0HA9mlde2wMg (extraction code: reic)
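The training code below assumes roughly the following file layout (sketched from how the files are read in the code; the field names match the code, but the example values are made up):

train.json: [{"text": "张三在北京工作", "labels": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O"]}, ...]
text.json: [{"text": "张三在北京工作", "entities": [["张三", "PER"], ["北京", "LOC"]]}, ...]
label.json: [id2label, label2id], e.g. [{"0": "O", "1": "B-PER", ...}, {"O": 0, "B-PER": 1, ...}]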
2. Hands-on
2.1 Training code
import json
import logging
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from tqdm import tqdm
from transformers import (AdamW, BertConfig, BertForTokenClassification,
                          BertTokenizer, get_linear_schedule_with_warmup)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

lr = 5e-5
max_length = 256
batch_size = 8
epoches = 20
cuda = True
# cuda = False
max_grad_norm = 1
warmup_steps = 3000
train_steps = 60000
train_dataset_file_path = './data/names/train.json'
eval_dataset_file_path = './data/names/text.json'
tokenizer = BertTokenizer('./bert_model/vocab.txt')
with open('./data/names/label.json', mode='r', encoding='utf8') as f:
id2label, label2id = json.load(f)
# Build the attention mask for a padded id sequence: 1 for real tokens, 0 for padding.
def get_atten_mask(tokens_ids, pad_index=0):
return list(map(lambda x: 1 if x != pad_index else 0, tokens_ids))
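# A quick sanity check of the helper above (example ids are made up; 101 and 102
# are the usual [CLS]/[SEP] ids in the BERT vocab):
# get_atten_mask([101, 2769, 4263, 102, 0, 0]) -> [1, 1, 1, 1, 0, 0]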
class NerDataSet(Dataset):
def __init__(self, file_path):
token_ids = []
token_attn_mask = []
token_seg_type = []
labels = []
with open(file_path, mode='r', encoding='utf8') as f:
data_set = json.load(f)
        # data_set = data_set[:5]  # leftover debug switch: uncomment to train on a tiny subset
for data in data_set:
text = data['text']
tmp_token_ids = tokenizer.encode(text, max_length=max_length, pad_to_max_length=True)
if len(text) < max_length - 2:
tmp_labels = [label2id['O']] + [label2id[item] for item in data['labels']] + [label2id['O']] * (
max_length - len(data['labels']) - 1)
else:
tmp_labels = [label2id['O']] + [label2id[item] for item in data['labels']][:max_length - 2] + [
label2id['O']]
tmp_token_attn_mask = get_atten_mask(tmp_token_ids)
tmp_seg_type = tokenizer.create_token_type_ids_from_sequences(tmp_token_ids[1:-1])
token_ids.append(tmp_token_ids)
token_attn_mask.append(tmp_token_attn_mask)
token_seg_type.append(tmp_seg_type)
labels.append(tmp_labels)
self.TOKEN_IDS = torch.from_numpy(np.array(token_ids)).long()
self.TOKEN_ATTN_MASK = torch.from_numpy(np.array(token_attn_mask)).long()
self.TOKEN_SEG_TYPE = torch.from_numpy(np.array(token_seg_type)).long()
self.LABELS = torch.from_numpy(np.array(labels)).long()
def __len__(self):
return self.LABELS.shape[0]
def __getitem__(self, item):
return self.TOKEN_IDS[item], self.TOKEN_SEG_TYPE[item], \
self.TOKEN_ATTN_MASK[item], self.LABELS[item]
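# A quick shape check of the dataset (hypothetical usage; shapes assume max_length = 256):
# ds = NerDataSet(train_dataset_file_path)
# ids, seg, mask, labels = ds[0]
# print(ids.shape, labels.shape)  # torch.Size([256]) torch.Size([256])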
def train(train_dataset, model: BertForTokenClassification, scheduler, optimizer: AdamW, batch_size=batch_size,
device=None):
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
model.train()
tr_loss = 0.0
tr_acc = 0
global_step = 0
if cuda:
torch.cuda.empty_cache()
for step, batch in tqdm(enumerate(train_loader)):
# print(step)
inputs = {
'input_ids': batch[0].to(device),
'token_type_ids': batch[1].to(device),
'attention_mask': batch[2].to(device),
'labels': batch[3].to(device)
}
outputs = model(**inputs)
loss = outputs[0]
# print(loss)
logits = outputs[1].view(-1, len(label2id))
tr_loss += loss.item()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # step the LR scheduler after the optimizer update
model.zero_grad()
        # token-level accuracy (padding positions are included in the count)
_, pred = logits.max(1)
number_corr = (pred == batch[-1].to(device).view(-1)).long().sum().item()
tr_acc += number_corr
global_step += 1
return tr_loss / global_step, tr_acc / (len(train_dataset) * max_length)
# Wrap an (entity_text, entity_type) pair so that predicted and gold entities
# can be compared and de-duplicated as sets during evaluation.
class NER(tuple):
def __init__(self, ner):
self.ner = ner
def __hash__(self):
return self.ner.__hash__()
def __eq__(self, other):
return self.ner == other
def get_entities(text_list, label_list):
# text = ''.join(text_list)
result_ent = []
buf_ent = []
ner_clas = ''
for i, item in enumerate(label_list):
item = str(item)
item = item.strip()
if item == 'O':
if len(buf_ent) > 0:
result_ent.append((''.join(buf_ent), ner_clas))
buf_ent = []
continue
pre_item, ner_item = item.split('-')
if pre_item == 'B':
if len(buf_ent) > 0:
result_ent.append((''.join(buf_ent), ner_clas))
buf_ent = []
buf_ent.append(text_list[i])
ner_clas = ner_item
else:
if ner_item == ner_clas:
buf_ent.append(text_list[i])
else:
                logger.warning('ner error')
    if len(buf_ent) > 0:
        # flush the last entity if the sentence ends inside one
        result_ent.append((''.join(buf_ent), ner_clas))
    return result_ent
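# Example of the BIO decoding above (made-up sentence and tags):
# get_entities(list('张三在北京'), ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC'])
# -> [('张三', 'PER'), ('北京', 'LOC')]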
def predict_func(text, model, device=None):
text = text.strip()
token_ids = tokenizer.encode(text, max_length=max_length, pad_to_max_length=True)
token_attn_mask = get_atten_mask(token_ids)
seq_type_ids = tokenizer.create_token_type_ids_from_sequences(token_ids[1:-1])
token_ids = torch.from_numpy(np.array(token_ids)).unsqueeze(0).long()
token_attn_mask = torch.from_numpy(np.array(token_attn_mask)).unsqueeze(0).long()
seq_type_ids = torch.from_numpy(np.array(seq_type_ids)).unsqueeze(0).long()
inputs = {
'input_ids': token_ids.to(device),
'token_type_ids': seq_type_ids.to(device),
'attention_mask': token_attn_mask.to(device),
}
output = model(**inputs)[0]
output = output.squeeze()
output = output[1:len(text) + 1, :]
_, output = output.max(1)
label_list = list(output.cpu().numpy())
return get_entities(list(text), [id2label[str(item)] for item in label_list])
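# Example call (hypothetical input; at pure inference time you would also wrap this
# in model.eval() / torch.no_grad()):
# predict_func('张三在北京工作', model, device)
# -> [('张三', 'PER'), ('北京', 'LOC')]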
def evalate(model: BertForTokenClassification, device=None):
    with open(eval_dataset_file_path, mode='r', encoding='utf8') as f:
test_data = json.load(f)
    # X = correctly predicted entities, Y = predicted entities, Z = gold entities,
    # accumulated over the whole eval set (1e-10 avoids division by zero).
    X, Y, Z = 1e-10, 1e-10, 1e-10
f1, precision, recall = 0.0, 0.0, 0.0
result_list = []
pbar = tqdm()
    for data in test_data:  # progress and running metrics are reported via pbar above
predict_entities = predict_func(data['text'], model, device)
predict_entities = [NER((item[0], item[1])) for item in predict_entities]
entities = [NER((item[0], item[1])) for item in data['entities']]
R = set(predict_entities)
T = set(entities)
X += len(R & T)
Y += len(R)
Z += len(T)
f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
pbar.update()
pbar.set_description('f1: %.5f, precision: %.5f, recall: %.5f' %
(f1, precision, recall))
s = {
'text': data['text'],
'ent_list': list(T),
'ent_list_pred': list(R),
'new': list(R - T),
'lack': list(T - R),
}
result_list.append(s)
with open('./predict.json', mode='w', encoding='utf8') as f:
json.dump(result_list, f, indent=4, ensure_ascii=False)
pbar.close()
return f1, precision, recall
def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
if __name__ == '__main__':
    config = BertConfig.from_pretrained('./bert_model/bert_config.json')
    config.num_labels = len(label2id)  # make sure the classification head matches the tag set
device = torch.device('cuda' if cuda else 'cpu')
model = BertForTokenClassification.from_pretrained('./bert_model/pytorch_model.bin', config=config).to(device)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},  # decay everything except bias and LayerNorm weights
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, train_steps)
logger.info('create train dataset')
train_dataset = NerDataSet(train_dataset_file_path)
# logger.info('create eval dataset')
# eval_dataset = NerDataSet(eval_dataset_file_path)
eval_best_f1 = 0.0
    for e in range(1, epoches + 1):
start_time = time.time()
train_loss, train_acc = train(train_dataset, model, scheduler, optimizer, batch_size, device)
# eval_acc = evalate(eval_dataset, model, batch_size, device)
eval_result = evalate(model, device)
end_time = time.time()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
logger.info('Epoch: {:02} | Time: {}m {}s'.format(e, epoch_mins, epoch_secs))
logger.info(
'Train Loss: {:.6f} | Eval f1: {:.6f} | Eval Pre: {:.6f} | Eval Rec: {:.6f}'.format(train_loss,
eval_result[0],
eval_result[1],
eval_result[2]))
if eval_result[0] > eval_best_f1:
eval_best_f1 = eval_result[0]
torch.save(model.state_dict(), './models/model_{}'.format(e))
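To reuse a saved checkpoint for inference later, the state dict can be loaded back into a freshly built model and passed to predict_func. A minimal sketch, reusing the tokenizer and label maps defined at the top of the script and assuming (for example) that the epoch-20 checkpoint turned out best:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = BertConfig.from_pretrained('./bert_model/bert_config.json')
config.num_labels = len(label2id)
model = BertForTokenClassification(config)
model.load_state_dict(torch.load('./models/model_20', map_location='cpu'))
model.to(device)
model.eval()
with torch.no_grad():
    print(predict_func('张三在北京工作', model, device))  # made-up example sentence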
3. Results
- Final results on the validation set: F1: 0.9247, Precision: 0.925, Recall: 0.924