Soft-Masked-Bert网络细节解读

大家好,我是隔壁小王。

Soft-Masked-Bert是复旦大学和字节跳动联合发布的在bert基础上针对文本纠正的网络模型,这里对其细节进行一个梳理。

考虑到我另外一个网络中有讲过bert的细节,因此这里姑且把bert作为一个黑盒,详细介绍下smbert相比与bert改动的部分。

首先上图:

别看它这个图挺唬人,其实改动非常简单,该网络主要加入的是一个错别字的检测网络部分也就是图中的Detection Nerwork。

假设输入的句长是128,embedding后的维度是768,batchsize就定为16好了,那么bert的embedding部分不会变,依旧是token_embedding+position_embedding+segment_embedding,得到的维度是(16,128,768),接下来将这些输入接入到一个双向的GRU里,输出是(16,128,1536)。此时接一个全链接(1536,768)再变回(16,128,768)。此时得到的结果就是图中的pi,也就是说该tensor表示的是当前我这个字是是否是错别字的可能性,当然此处只是检测,该改成啥它还不管。

接下来会有一个e-mask这样的embedding与上述embedding想加,实际上就是第103个字符“[MASK]”的embedding值,维度是(128, 768),接下来按照这个方法计算就可以了:

这里要注意的是pi在计算前会经过一个sigmoid,换句话说,当这个自被认为是错别字时,pi就接近1,否则则接近0。

最终得出的 ei' 就是一个与bert输入同维度的embeddings:(16,128,768)。接下来的事就都是跟bert一模一样的事了。

最后,还有个残差计算,bert的12层transformer block的结果(16,128,768)要与最开始的输入embedding: ei(16,128,768)进行想加,结果也是(16,128,768),然后接全链接和softmax就可以了。

本人的复现源码如下:

https://github.com/whgaara/pytorch-soft-masked-bert​github.com

最后说说测试效果:

因为bert的预训练模型本身就很强大,其实很多基于bert改动的网络在预训练的基础上进行finetune后的结果都不会太差。再考虑到训练的速度,本人没用使用任何预训练模型,只是随机找了一些古诗进行了训练,默认16个epoch,测试集就是将这些古诗随机位置替换一个随机的字,用训练好的模型进行纠正,咱们看看结果如何:

最后说说测试效果:

因为bert的预训练模型本身就很强大,其实很多基于bert改动的网络在预训练的基础上进行finetune后的结果都不会太差。再考虑到训练的速度,本人没用使用任何预训练模型,只是随机找了一些古诗进行了训练,默认16个epoch,测试集就是将这些古诗随机位置替换一个随机的字,用训练好的模型进行纠正,咱们看看结果如何:

Bert epoch0:

EP_0 mask loss:0.15134941041469574

EP:0 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep0

top1纠正正确率:0.81

top5纠正正确率:0.91

Bert epoch8:

EP_train:8: 100%|| 170/170 [00:53<00:00, 3.20it/s]

EP_8 mask loss:0.010850409045815468

EP:8 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep8

top1纠正正确率:0.94

top5纠正正确率:0.98

Bert epoch15:

EP_train:15: 100%|| 170/170 [00:53<00:00, 3.21it/s]

EP_15 mask loss:0.002929957117885351

EP:15 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep15

top1纠正正确率:0.97

top5纠正正确率:0.99

soft-masked-bert epoch0:

EP_0 mask loss:0.11019379645586014

EP:0 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep0

top1纠正正确率:0.88

top5纠正正确率:0.91

soft-masked-bert epoch8:

EP_train:8: 100%|| 170/170 [01:01<00:00, 2.74it/s]

EP_8 mask loss:0.011160945519804955

EP:8 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep8

top1纠正正确率:0.93

top5纠正正确率:0.98

soft-masked-bert epoch15:

EP_train:15: 100%|| 170/170 [01:01<00:00, 2.75it/s]

EP_15 mask loss:0.0014254981651902199

EP:15 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep15

top1纠正正确率:0.94

top5纠正正确率:0.98

总结下:smbert收敛快一点,但是结果没有bert好,速度也会慢一点。令我疑惑的是smbert最后加的残差网络,ei是带有错字信息的输入内容,好不容易纠错完的结果最后再加一个带有错误信息的原始输入,不是很懂。上面的smbert就是我将残差去掉以后的结果,如果将残差加上,正确率还要降1个点,我想这正论证了我的想法。当然,git上的代码没有任何部分缺失的。

  • 10
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值