标题在网上看了很多相关代码,不知道为什么特别难跑通,这段时间刚好自己做了个数据集测试一下。
三方库
用的库是kashgari,一个专业做文本分类和序列标注的库,其中内置了BertEmbedding,GPT2Embedding,WordEmbedding等特征提取类,可以利用与训练模型进行迁移学习。
以及内置了CNN_LSTM_Model
BiLSTM_Model
BiLSTM_CRF_Model
BiGRU_Model
BiGRU_CRF_Model
深度学习模型供选择。
对于我这种新手来说十分友好。
准备工作
- Python 3.6 环境
- BERT-Base, Chinese 中文模型
- tensorflow 1.14
- kashgari 1.1.5
读取数据
这里的语料需要处理成如下图所示格式,最小粒度字符级的格式。用BIO标注。这里不多赘述。
train_path = './train.txt'
test_path = './test.txt'
def read_file(file_path: str,
text_index: int = 0,
label_index: int = 1):
"""
根据文件路径读取训练数据、测试数据以及验证数据的text和label
"""
data_x, data_y = [], []
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.read().splitlines()
x, y = [], []
for line in lines:
rows = line.split(' ')
if len(rows) == 1:
data_x.append(x)
data_y.append(y)
x = []
y = []
else:
x.append(rows[0])
y.append(rows[1])
return data_x, data_y
train_x, train_y = read_file(train_path)
valid_x, valid_y = read_file(test_path)
创建BERT embedding
import kashgari
from kashgari.embeddings import BERTEmbedding
bert_embed = BERTEmbedding('../publish',
task=kashgari.LABELING,
sequence_length=128)
在这里加载本地的预训练模型。
训练
from kashgari.tasks.labeling import BiLSTM_CRF_Model
# 还可以选择 `CNN_LSTM_Model`, `BiLSTM_Model`, `BiGRU_Model` 或 `BiGRU_CRF_Model`
model = BiLSTM_CRF_Model(bert_embed)
model.fit(train_x,
train_y,
x_validate=valid_x,
y_validate=valid_y,
epochs=5,
batch_size=512)
笔记本cpu跑的较慢,暂时看不到验证集的准确率。一个epochs要跑很久。就不展示结果了。
参考文章
https://github.com/BrikerMan/Kashgari