本文基于苏剑林(bojone)开源的BERT baseline,讲解短文本匹配任务的实现,其GitHub地址是:https://github.com/bojone/oppo-text-match/blob/main/baseline.py
赛题地址:
https://tianchi.aliyun.com/competition/entrance/531851
数据探索
下载好相关数据之后,我们先看一下数据是什么样的:
# Peek at the first six lines of the training file to see its format.
path = '/content/drive/MyDrive/oppo-text-match/baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
with open(path, 'r', encoding='utf-8') as f:
    # Iterate the file object lazily instead of calling readlines(): only the
    # first six lines are printed, so there is no need to load every row
    # into memory first.
    for i, line in enumerate(f):
        print(line.split('\t'))
        if i == 5:
            break
结果:
['1 2 3 4 5 6 7', '8 9 10 4 11', '0\n']
['12 13 14 15', '12 15 11 16', '0\n']
['17 18 12 19 20 21 22 23 24', '12 23 25 6 26 27 19', '1\n']
['28 29 30 31 11', '32 33 34 30 31', '1\n']
['29 35 36 29', '29 37 36 29', '1\n']
['38 23 39 9 40', '12 19 41 42 23 43 12 23 44 41 42 19', '0\n']
数据都经过了脱敏处理:每个字都被替换成了一个数字ID,字与字之间用空格分隔;每行依次是text1、text2和标签(0/1),用制表符分隔。
统计一下text1+text2的长度:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import pandas as pd

# Competition data files, relative to the current working directory.
train_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
test_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_testA_20210228.tsv'
def cal_len_dis(path):
    """Return the token length of text1 + text2 for every sample in a TSV file.

    Each line holds tab-separated fields. The first two fields are
    space-separated token ids; any further field (the label in the training
    file) is ignored. The length is the number of ids in text1 plus the
    number of ids in text2. Counting raw characters of the concatenated
    strings would include spaces and the digits of multi-digit ids, and
    would merge the last id of text1 with the first id of text2.

    Args:
        path: path to a train/test ``.tsv`` file.

    Returns:
        A list with one int per line that has at least two fields; blank or
        malformed lines are skipped.
    """
    len_list = []
    with open(path, 'r', encoding='utf-8') as f:
        # Stream the file instead of readlines() to avoid a full in-memory copy.
        for line in f:
            # Only drop the line terminator, so an empty text1 column is
            # not swallowed by stripping a leading tab.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) < 2:
                continue
            len_list.append(len(fields[0].split()) + len(fields[1].split()))
    return len_list
def get_len_detail(data):
    """Summarise *data* with ``DataFrame.describe()``.

    Returns a one-column DataFrame indexed by count, mean, std, min,
    25%, 50%, 75% and max.
    """
    return pd.DataFrame(data).describe()
# Configure matplotlib so that Chinese labels and the minus sign render.
# The SimHei font path below is machine-specific. If the file does not
# exist, fall back to matplotlib's default font (Chinese glyphs may then
# show as boxes) instead of failing later when the figure is drawn/saved.
import os

_FONT_PATH = r'/data02/gob/project/text-match/simhei.ttf'
font = FontProperties(fname=_FONT_PATH if os.path.exists(_FONT_PATH) else None)
matplotlib.rcParams['axes.unicode_minus'] = False  # draw "-" correctly with non-default fonts
def draw_hist(data, bins=40, save_path='len_hist.png'):
    """Plot a histogram of sentence-pair lengths, save it, then show it.

    Args:
        data: sequence of lengths (e.g. the list returned by ``cal_len_dis``).
        bins: number of histogram bins (default 40).
        save_path: file the figure is saved to (default ``'len_hist.png'``).

    The axis labels and title are Chinese and use the module-level ``font``.
    """
    # Start a fresh figure so that repeated calls do not draw on top of the
    # previous histogram (and save both into the same image).
    plt.figure()
    plt.hist(data, bins=bins, facecolor="blue", edgecolor="black", alpha=0.7)
    plt.xlabel("长度", fontproperties=font)  # x-axis label: "length"
    plt.ylabel("数量", fontproperties=font)  # y-axis label: "count"
    plt.title("句子长度统计", fontproperties=font)  # title: "sentence length statistics"
    plt.savefig(save_path)
    plt.show()
if __name__ == '__main__':
    # Compute per-sample lengths on the training set, print the summary
    # statistics, then plot the length histogram.
    lengths = cal_len_dis(train_path)
    print(get_len_detail(lengths))
    draw_hist(lengths)
相关统计量(注意:上面的代码把text1和text2两个原始字符串直接拼接后按字符数统计,空格和多位数字ID的每一位都计入了长度,因此下面的数值并不等于token个数):
count 100000.0000
mean 46.8328
std 17.317
min 12
25% 35
50% 43
75% 55
max 279
baseline中值得注意的一些代码
# Truncate the sentence pair so that len(a) + len(b) <= maxlen.
# NOTE(review): per the bert4keras implementation (verify against the
# installed version), this repeatedly pops one token -- at index -1, i.e.
# from the end -- from whichever of a / b is currently longest, modifying
# the lists in place; it does not only truncate a.
from bert4keras.snippets import truncate_sequences
truncate_sequences(maxlen, -1, a, b)
这个函数用于截断过长的句子对:当len(a)+len(b)>maxlen时,每次从当前较长的那个序列中删除一个token(参数-1表示从末尾删除),直到总长度不超过maxlen。注意它原地修改列表a和b,并不是只截断句子a。
def random_mask(text_ids):
    """Apply BERT-style random masking to a list of token ids.

    Each position is chosen as a prediction target with probability 0.15.
    Of the chosen positions, 80% have their input replaced by id 4 (the
    mask slot), 10% keep their id, and 10% get a random id drawn from
    ``[7, 7 + len(tokens))``. Positions that are not chosen keep their id
    and get target 0, which the loss ignores.

    Returns:
        ``(input_ids, output_ids)``: two lists as long as *text_ids*.
        ``output_ids`` holds the original id at chosen positions, else 0.
    """
    mask_id, no_target = 4, 0
    masked, targets = [], []
    draws = np.random.random(len(text_ids))
    for token_id, p in zip(text_ids, draws):
        if p >= 0.15:
            # Not selected: keep the token, nothing to predict here.
            masked.append(token_id)
            targets.append(no_target)
            continue
        if p < 0.15 * 0.8:
            masked.append(mask_id)
        elif p < 0.15 * 0.9:
            masked.append(token_id)
        else:
            masked.append(np.random.choice(len(tokens)) + 7)
        targets.append(token_id)
    return masked, targets
这个函数按BERT的方式做随机mask:每个token以15%的概率被选为预测目标,其中80%替换为[MASK](id 4),10%保持不变,10%替换为随机token;未被选中的位置输出为0,不参与损失计算。
def sample_convert(text1, text2, label, random=False):
    """Convert one ``(text1, text2, label)`` sample into MLM-format inputs.

    Tokens are remapped through the global ``tokens`` dict (unknown -> 1).
    The sequence layout is ``[2] + text1 + [3] + text2 + [3]`` with all
    segment ids 0. The label is encoded as target id ``label + 5`` at
    position 0.

    When *random* is true, the two sentences are swapped with probability
    0.5 and each is passed through ``random_mask``. Otherwise no token
    targets are produced (all zeros).

    Returns:
        ``(token_ids, segment_ids, output_ids)`` as equal-length lists.
    """
    ids_a = [tokens.get(t, 1) for t in text1]
    ids_b = [tokens.get(t, 1) for t in text2]
    if not random:
        tgt_a = [0] * len(ids_a)
        tgt_b = [0] * len(ids_b)
    else:
        if np.random.random() < 0.5:
            ids_a, ids_b = ids_b, ids_a
        ids_a, tgt_a = random_mask(ids_a)
        ids_b, tgt_b = random_mask(ids_b)
    token_ids = [2, *ids_a, 3, *ids_b, 3]
    output_ids = [label + 5, *tgt_a, 0, *tgt_b, 0]
    segment_ids = [0 for _ in token_ids]
    return token_ids, segment_ids, output_ids
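用于把单个样本转换为BERT的输入:token_ids的格式为[2] + text1 + [3] + text2 + [3](2、3相当于[CLS]和[SEP]);训练时(random=True)以50%的概率交换两句的顺序,并对两句分别做随机mask;标签以label+5的形式放在output_ids的第一个位置,作为MLM的预测目标。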
class data_generator(DataGenerator):
    """Batch generator yielding ``([token_ids, segment_ids], output_ids)``.

    Every sample is converted with ``sample_convert``. Once
    ``self.batch_size`` samples have been collected, or the last sample is
    reached, the three lists are padded with ``sequence_padding`` and
    yielded as one batch.
    """
    def __iter__(self, random=False):
        tok_batch, seg_batch, out_batch = [], [], []
        for is_end, (text1, text2, label) in self.sample(random):
            tok, seg, out = sample_convert(text1, text2, label, random)
            tok_batch.append(tok)
            seg_batch.append(seg)
            out_batch.append(out)
            if not (is_end or len(tok_batch) == self.batch_size):
                continue
            inputs = [sequence_padding(tok_batch), sequence_padding(seg_batch)]
            yield inputs, sequence_padding(out_batch)
            tok_batch, seg_batch, out_batch = [], [], []
用于将多个样本组装成batch,并对每个batch中的序列做padding。
# Load the pre-trained model (with its MLM head).
# keep_tokens picks, in order, which rows of the original vocabulary are
# kept (bert4keras semantics -- verify), so new id i starts from original
# row keep_tokens[i]:
#   new ids 0-6 are the special slots used by sample_convert/random_mask:
#     0 = padding / no target, 1 = unknown token, 2 and 3 = sentence
#     delimiters, 4 = mask token, 5 and 6 = the two label tokens;
#     they map to original ids 0, 100, 101, 102, 103, 100, 100
#     (presumably [PAD], [UNK], [CLS], [SEP], [MASK], [UNK], [UNK] in the
#     Chinese BERT vocab -- confirm against vocab.txt);
#   new ids 7 onward are the competition tokens, taken from the first
#   len(tokens) entries of the global keep_tokens list.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_mlm=True,
    keep_tokens=[0, 100, 101, 102, 103, 100, 100] + keep_tokens[:len(tokens)]
)
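这段代码用于加载带MLM头的预训练模型,并通过keep_tokens精简词表:前7个位置是特殊token(0为padding,1为未登录词,2、3相当于[CLS]和[SEP],4为[MASK],5、6对应两个标签),之后是比赛数据中出现的token。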
def masked_crossentropy(y_true, y_pred):
    """Sparse cross-entropy averaged only over positions that have a target.

    Args:
        y_true: integer target ids, reshaped to its first two dims
            (presumably (batch, seq_len) -- TODO confirm). A target of 0
            means "no target": output_ids uses 0 for unmasked tokens and
            separators.
        y_pred: per-position probability distributions over the vocabulary.

    Returns:
        A (1, 1) tensor with the mean loss over the positions whose target
        is > 0.5.
    """
    y_true = K.reshape(y_true, K.shape(y_true)[:2])
    y_mask = K.cast(K.greater(y_true, 0.5), K.floatx())
    loss = K.sparse_categorical_crossentropy(y_true, y_pred)
    # Guard the denominator. A batch in which every output_ids entry is 0
    # would otherwise divide by zero and produce NaN. When at least one
    # target exists the sum is >= 1, so the result is unchanged.
    loss = K.sum(loss * y_mask) / K.maximum(K.sum(y_mask), K.epsilon())
    return loss[None, None]
计算损失:只在y_true>0.5的位置(即第一个位置的标签以及被mask的token)上计算交叉熵,并取平均。
# Train with the masked MLM loss defined above, using Adam with learning rate 1e-5.
model.compile(loss=masked_crossentropy, optimizer=Adam(1e-5))
model.summary()  # print the model's layer / parameter overview
编译模型:损失函数为上面定义的masked_crossentropy,优化器为学习率1e-5的Adam,并打印模型结构。
# Wrap the train / valid / test splits in batching generators of the same batch size.
train_generator, valid_generator, test_generator = (
    data_generator(split, batch_size)
    for split in (train_data, valid_data, test_data)
)
def evaluate(data):
    """Return the ROC-AUC of the model on a labelled batch generator.

    The score for each sample is the predicted probability at position 0
    of label id 6, renormalised over the two label ids 5 and 6. The true
    label is recovered as ``output_ids[:, 0] - 5``.
    """
    labels, scores = [], []
    for inputs, targets in data:
        label_probs = model.predict(inputs)[:, 0, 5:7]
        scores.extend(label_probs[:, 1] / (label_probs.sum(axis=1) + 1e-8))
        labels.extend(targets[:, 0] - 5)
    return roc_auc_score(labels, scores)
class Evaluator(keras.callbacks.Callback):
    """Evaluate on the validation set after each epoch and keep the best weights.

    After every epoch the validation AUC (``evaluate(valid_generator)``) is
    computed. Whenever it beats the best score so far, the weights are
    saved to ``best_model.weights``. Both scores are printed each epoch.
    """
    def __init__(self):
        # Let the Keras Callback base class set up its own attributes
        # (e.g. self.model) before adding ours.
        super(Evaluator, self).__init__()
        self.best_val_score = 0.

    def on_epoch_end(self, epoch, logs=None):
        val_score = evaluate(valid_generator)
        if val_score > self.best_val_score:
            self.best_val_score = val_score
            model.save_weights('best_model.weights')
        print(
            u'val_score: %.5f, best_val_score: %.5f\n' %
            (val_score, self.best_val_score)
        )
def predict_to_file(out_file):
    """Write one predicted P(label = 1) per line for every test sample.

    Scores are produced in ``test_generator`` order, each formatted with
    ``'%f'``. The score is the probability at position 0 of label id 6,
    renormalised over the two label ids 5 and 6.

    Args:
        out_file: path of the output text file (overwritten).
    """
    # The context manager closes the file even if prediction raises.
    with open(out_file, 'w') as f:
        for x_true, _ in tqdm(test_generator):
            y_pred = model.predict(x_true)[:, 0, 5:7]
            y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)
            f.writelines('%f\n' % p for p in y_pred)
if __name__ != '__main__':
    # Imported as a module: reuse the best checkpoint saved during training.
    model.load_weights('best_model.weights')
else:
    # Run as a script: train, validating and checkpointing after each epoch.
    evaluator = Evaluator()
    model.fit(
        train_generator.forfit(),
        epochs=100,
        steps_per_epoch=len(train_generator),
        callbacks=[evaluator],
    )
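最后构建各数据集的生成器,并定义以下几部分:
- 评估函数:计算验证集AUC,预测分数为两个标签token的概率归一化后标签1的概率;
- 回调:每个epoch结束后保存验证集上的最优权重;
- 测试集预测函数。

主函数中训练模型;作为模块导入时则直接加载最优权重。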