在查看了https://github.com/chriskhanhtran/bert-extractive-summarization.git关于摘要提取的内容之后,说说对于bert_extract_summarization其中的理解。
主要内容是根据一篇长文本,对其中的关键内容进行抽取,得到了文本摘要的内容。因为作者只提供了训练好的模型,所以只是使用作者的模型带入去理解。
01 预处理:将一篇文本中的多个句子用特殊符号([SEP]、[CLS])合并。将文本转换为id。并标记mask值。因为文本中有多个句子用特殊符号连接,因为CLS在bert模型中往往被认为是代表整句话的意思。所以extract_summarization模型的主要任务就要对cls进行预测。所以将cls标注的位置都保存记录下来。并对cls标注的数据,进行mask的填充。记录cls标注数据的shape(batch,cls_sel_len)。input的shape(batch,seq_len)
02 模型实现:将预处理后的句子,送入bert模型。得到最后一层的输出输出output (batch,seq_len,d_model)。然后根据之前处理的cls,根据矩阵相乘,找到输出中都是cls标记位置的向量矩阵cls_out_put (batch,cls_sel_len,d_model)。
03 将得到的cls标记位置向量矩阵,做为输入。选用transformer的encoder阶段,计算位置向量、多头注意力、和ffn以及残差层和正则化层。得到输出output (batch,cls_sel_len,d_model)。
04 将得到的输出经过非线性从得到最后维度为1, output (batch,cls_sel_len,1)。然后降维,经过softmax得到,最后各自cls位置的概率。然后根据概率选出,对应排名前几位置对应的内容,若为输出的摘要内容。