天池oppo-text-match比赛-苏剑林baseline代码解读

本文根据苏剑林的基于bert的baseline进行短文本匹配的讲解,其github地址是:https://github.com/bojone/oppo-text-match/blob/main/baseline.py

赛题地址:

https://tianchi.aliyun.com/competition/entrance/531851

数据探索

下载好相关数据之后,我们先看一下数据是什么样的:

path = '/content/drive/MyDrive/oppo-text-match/baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
with open(path,'r',encoding='utf-8') as f:
  lines = f.readlines()
  for i,line in enumerate(lines):
    print(line.split('\t'))
    if i == 5:
      break

结果:

['1 2 3 4 5 6 7', '8 9 10 4 11', '0\n']
['12 13 14 15', '12 15 11 16', '0\n']
['17 18 12 19 20 21 22 23 24', '12 23 25 6 26 27 19', '1\n']
['28 29 30 31 11', '32 33 34 30 31', '1\n']
['29 35 36 29', '29 37 36 29', '1\n']
['38 23 39 9 40', '12 19 41 42 23 43 12 23 44 41 42 19', '0\n']

数据都是脱敏的,也就是字都用数字来表示了。
统计一下text1+text2的长度:

import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
import pandas as pd
from

train_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
test_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_testA_20210228.tsv'

def cal_len_dis(path):
    with open(path,'r',encoding='utf-8') as f:
        lines = f.readlines()
        len_list = []
        for line in lines:
            line = line.strip().split('\t')
            len_list.append(len(line[0]+line[1]))
    return len_list

def get_len_detail(data):
    df = pd.DataFrame(data)
    res = df.describe()
    return res

# 设置matplotlib正常显示中文和负号
font = FontProperties(fname=r'/data02/gob/project/text-match/simhei.ttf')
matplotlib.rcParams['axes.unicode_minus']=False     # 正常显示负号

def draw_hist(data):
    plt.hist(data, bins=40, facecolor="blue", edgecolor="black", alpha=0.7)
    # 显示横轴标签
    plt.xlabel("长度",fontproperties=font)
    # 显示纵轴标签
    plt.ylabel("数量",fontproperties=font)
    # 显示图标题
    plt.title("句子长度统计",fontproperties=font)
    plt.savefig('len_hist.png')
    plt.show()


if __name__ == '__main__':
    len_list = cal_len_dis(train_path)
    res = get_len_detail(len_list)
    print(res)
    draw_hist(len_list)

在这里插入图片描述
相关统计量:
count 100000.0000
mean 46.8328
std 17.317
min 12
25% 35
50% 43
75% 55
max 279

baseline中值得注意的一些代码

from bert4keras.snippets import truncate_sequences
truncate_sequences(maxlen, -1, a, b)

这个函数用于截断超过最大长度的句子,如果len(a+b)>maxlen,则对句子a进行截断。

def random_mask(text_ids):
    """随机mask
    """
    input_ids, output_ids = [], []
    rands = np.random.random(len(text_ids))
    for r, i in zip(rands, text_ids):
        if r < 0.15 * 0.8:
            input_ids.append(4)
            output_ids.append(i)
        elif r < 0.15 * 0.9:
            input_ids.append(i)
            output_ids.append(i)
        elif r < 0.15:
            input_ids.append(np.random.choice(len(tokens)) + 7)
            output_ids.append(i)
        else:
            input_ids.append(i)
            output_ids.append(0)
    return input_ids, output_ids

这个函数用于随机将一些词mask掉。

def sample_convert(text1, text2, label, random=False):
    """转换为MLM格式
    """
    text1_ids = [tokens.get(t, 1) for t in text1]
    text2_ids = [tokens.get(t, 1) for t in text2]
    if random:
        if np.random.random() < 0.5:
            text1_ids, text2_ids = text2_ids, text1_ids
        text1_ids, out1_ids = random_mask(text1_ids)
        text2_ids, out2_ids = random_mask(text2_ids)
    else:
        out1_ids = [0] * len(text1_ids)
        out2_ids = [0] * len(text2_ids)
    token_ids = [2] + text1_ids + [3] + text2_ids + [3]
    segment_ids = [0] * len(token_ids)
    output_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]
    return token_ids, segment_ids, output_ids

用于转换单个样本为bert的输入。

class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []
        for is_end, (text1, text2, label) in self.sample(random):
            token_ids, segment_ids, output_ids = sample_convert(
                text1, text2, label, random
            )
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_output_ids.append(output_ids)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_output_ids = sequence_padding(batch_output_ids)
                yield [batch_token_ids, batch_segment_ids], batch_output_ids
                batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []

用于将多个样本制作成batch的格式。

# 加载预训练模型
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_mlm=True,
    keep_tokens=[0, 100, 101, 102, 103, 100, 100] + keep_tokens[:len(tokens)]
)

这个函数用于加载预训练的模型

def masked_crossentropy(y_true, y_pred):
    """mask掉非预测部分
    """
    y_true = K.reshape(y_true, K.shape(y_true)[:2])
    y_mask = K.cast(K.greater(y_true, 0.5), K.floatx())
    loss = K.sparse_categorical_crossentropy(y_true, y_pred)
    loss = K.sum(loss * y_mask) / K.sum(y_mask)
    return loss[None, None]

计算损失。

model.compile(loss=masked_crossentropy, optimizer=Adam(1e-5))
model.summary()

定义优化器和损失。

# 转换数据集
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)
test_generator = data_generator(test_data, batch_size)


def evaluate(data):
    """线下评测函数
    """
    Y_true, Y_pred = [], []
    for x_true, y_true in data:
        y_pred = model.predict(x_true)[:, 0, 5:7]
        y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)
        y_true = y_true[:, 0] - 5
        Y_pred.extend(y_pred)
        Y_true.extend(y_true)
    return roc_auc_score(Y_true, Y_pred)


class Evaluator(keras.callbacks.Callback):
    """评估与保存
    """
    def __init__(self):
        self.best_val_score = 0.

    def on_epoch_end(self, epoch, logs=None):
        val_score = evaluate(valid_generator)
        if val_score > self.best_val_score:
            self.best_val_score = val_score
            model.save_weights('best_model.weights')
        print(
            u'val_score: %.5f, best_val_score: %.5f\n' %
            (val_score, self.best_val_score)
        )


def predict_to_file(out_file):
    """预测结果到文件
    """
    F = open(out_file, 'w')
    for x_true, _ in tqdm(test_generator):
        y_pred = model.predict(x_true)[:, 0, 5:7]
        y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)
        for p in y_pred:
            F.write('%f\n' % p)
    F.close()


if __name__ == '__main__':

    evaluator = Evaluator()

    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=100,
        callbacks=[evaluator]
    )

else:

    model.load_weights('best_model.weights')

加载数据以及评估等,最后在主函数中调用。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

xiximayou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值