Soft-Masked-Bert网络细节解读
大家好,我是隔壁小王。
Soft-Masked-Bert是复旦大学和字节跳动联合发布的在bert基础上针对文本纠正的网络模型,这里对其细节进行一个梳理。
考虑到我另外一个网络中有讲过bert的细节,因此这里姑且把bert作为一个黑盒,详细介绍下smbert相比与bert改动的部分。
首先上图:
别看它这个图挺唬人,其实改动非常简单,该网络主要加入的是一个错别字的检测网络部分也就是图中的Detection Nerwork。
假设输入的句长是128,embedding后的维度是768,batchsize就定为16好了,那么bert的embedding部分不会变,依旧是token_embedding+position_embedding+segment_embedding,得到的维度是(16,128,768),接下来将这些输入接入到一个双向的GRU里,输出是(16,128,1536)。此时接一个全链接(1536,768)再变回(16,128,768)。此时得到的结果就是图中的pi,也就是说该tensor表示的是当前我这个字是是否是错别字的可能性,当然此处只是检测,该改成啥它还不管。
接下来会有一个e-mask这样的embedding与上述embedding想加,实际上就是第103个字符“[MASK]”的embedding值,维度是(128, 768),接下来按照这个方法计算就可以了:
这里要注意的是pi在计算前会经过一个sigmoid,换句话说,当这个自被认为是错别字时,pi就接近1,否则则接近0。
最终得出的 ei' 就是一个与bert输入同维度的embeddings:(16,128,768)。接下来的事就都是跟bert一模一样的事了。
最后,还有个残差计算,bert的12层transformer block的结果(16,128,768)要与最开始的输入embedding: ei(16,128,768)进行想加,结果也是(16,128,768),然后接全链接和softmax就可以了。
本人的复现源码如下:
https://github.com/whgaara/pytorch-soft-masked-bertgithub.com
最后说说测试效果:
因为bert的预训练模型本身就很强大,其实很多基于bert改动的网络在预训练的基础上进行finetune后的结果都不会太差。再考虑到训练的速度,本人没用使用任何预训练模型,只是随机找了一些古诗进行了训练,默认16个epoch,测试集就是将这些古诗随机位置替换一个随机的字,用训练好的模型进行纠正,咱们看看结果如何:
最后说说测试效果:
因为bert的预训练模型本身就很强大,其实很多基于bert改动的网络在预训练的基础上进行finetune后的结果都不会太差。再考虑到训练的速度,本人没用使用任何预训练模型,只是随机找了一些古诗进行了训练,默认16个epoch,测试集就是将这些古诗随机位置替换一个随机的字,用训练好的模型进行纠正,咱们看看结果如何:
Bert epoch0:
EP_0 mask loss:0.15134941041469574
EP:0 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep0
top1纠正正确率:0.81
top5纠正正确率:0.91
Bert epoch8:
EP_train:8: 100%|| 170/170 [00:53<00:00, 3.20it/s]
EP_8 mask loss:0.010850409045815468
EP:8 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep8
top1纠正正确率:0.94
top5纠正正确率:0.98
Bert epoch15:
EP_train:15: 100%|| 170/170 [00:53<00:00, 3.21it/s]
EP_15 mask loss:0.002929957117885351
EP:15 Model Saved on:../../checkpoint/finetune/mlm_trained_128.model.ep15
top1纠正正确率:0.97
top5纠正正确率:0.99
soft-masked-bert epoch0:
EP_0 mask loss:0.11019379645586014
EP:0 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep0
top1纠正正确率:0.88
top5纠正正确率:0.91
soft-masked-bert epoch8:
EP_train:8: 100%|| 170/170 [01:01<00:00, 2.74it/s]
EP_8 mask loss:0.011160945519804955
EP:8 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep8
top1纠正正确率:0.93
top5纠正正确率:0.98
soft-masked-bert epoch15:
EP_train:15: 100%|| 170/170 [01:01<00:00, 2.75it/s]
EP_15 mask loss:0.0014254981651902199
EP:15 Model Saved on:../checkpoint/finetune/mlm_trained_128.model.ep15
top1纠正正确率:0.94
top5纠正正确率:0.98
总结下:smbert收敛快一点,但是结果没有bert好,速度也会慢一点。令我疑惑的是smbert最后加的残差网络,ei是带有错字信息的输入内容,好不容易纠错完的结果最后再加一个带有错误信息的原始输入,不是很懂。上面的smbert就是我将残差去掉以后的结果,如果将残差加上,正确率还要降1个点,我想这正论证了我的想法。当然,git上的代码没有任何部分缺失的。