三元组抽取任务,基于“半指针-半标注”结构
文章介绍:https://kexue.fm/archives/7161
数据集:http://ai.baidu.com/broad/download?dataset=sked
最优f1=0.82198
代码来源 bert4keras examples
苏神鼓励大家进行小改动后发文章出来哟。
Baidu Research Open-Access Dataset（ai.baidu.com）：这是一个中文三元组抽取数据集，样本格式示例如下：
{
"text": "《新駌鸯蝴蝶梦》是黄安的音乐作品,收录在《流金十载全记录》专辑中",
"spo_list": [
{
"subject": "新駌鸯蝴蝶梦",
"predicate": "所属专辑",
"object": "流金十载全记录",
"subject_type": "歌曲",
"object_type": "音乐专辑"
},
{
"subject": "新駌鸯蝴蝶梦",
"predicate": "歌手",
"object": "黄安",
"subject_type": "歌曲",
"object_type": "人物"
}
]
}
安装bert4keras
pip install git+https://www.github.com/bojone/bert4keras.git
训练代码如下
import json
import codecs
import numpy as np
import tensorflow as tf
from bert4keras.backend import keras, set_gelu, K
from bert4keras.layers import LayerNormalization
from bert4keras.tokenizer import Tokenizer
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam, ExponentialMovingAverage
from bert4keras.snippets import sequence_padding, DataGenerator
from keras.layers import *
from keras.models import Model
from tqdm import tqdm
# Training hyper-parameters (maxlen presumably caps the tokenized
# sequence length fed to BERT -- confirm against the data generator).
maxlen, batch_size = 128, 64

# Pretrained BERT checkpoint files, all kept under the local 'wwm/' dir.
config_path, checkpoint_path, dict_path = (
    'wwm/bert_config.json',
    'wwm/bert_model.ckpt',
    'wwm/vocab.txt',
)
def load_data(filename):
    """Load a JSON-lines triple-extraction file into (s, p, o) records.

    Each non-blank line of *filename* must be a JSON object with a
    ``text`` string and a ``spo_list`` of dicts carrying ``subject``,
    ``predicate`` and ``object`` keys; any other keys (e.g.
    ``subject_type`` / ``object_type``) are ignored.

    Args:
        filename: path to a UTF-8 encoded JSON-lines file.

    Returns:
        A list with one dict per non-blank line, in file order:
        ``{'text': str, 'spo_list': [(subject, predicate, object), ...]}``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks one of the required keys.
    """
    data = []
    # Built-in open() only treats \n / \r as line breaks. codecs.open()'s
    # line iterator also splits on U+2028, U+2029, U+0085, etc., which can
    # appear raw inside non-ASCII-escaped text and would cut a JSON record
    # in half.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at EOF) instead
            # of letting json.loads raise on them.
            if not line.strip():
                continue
            record = json.loads(line)
            data.append({
                'text': record['text'],
                'spo_list': [
                    (spo['subject'], spo['predicate'], spo['object'])
                    for spo in record['spo_list']
                ],
            })
    return data
# Load the training split and the dev split (used here as validation data).
train_data, valid_data = (
    load_data('kg_huge/train_data.json'),
    load_data('kg_huge/dev_data.json'),
)
predicate2id, id2predicate = {}, {}
with codecs.open('kg_huge/all_50_schemas') as f:
for l in f:
l = json.loads(l)
if l['predicate'] not in predicate2id:
id2predicate[len(predicate2id)] = l['pre