在测试或者预测时,Transformer里decoder为什么还需要seq mask?

 

在测试或者预测时,Transformer里decoder为什么还需要seq mask?

这个sublayer里会用一个三角矩阵来做mask。在training的时候,这个mask是为了保证causality,即把将来的数据mask掉,这都比较好理解。但是在做testing的时候,为什么还要继续使用这个mask了?

如在http://nlp.seas.harvard.edu/2018/04/03/attention.html#batches-and-masking 里

在测试或者预测时,Transformer里decoder为什么还需要seq mask?

在testing 或 inferring的时候,我们是一个接一个地预测下一个字,假如我们已经产生了<s>, 字1,字2,现在需要预测字3,我们是把<s>, 字1,字2一块输入到decoder里去,理论上说,一个3X3的mask三角矩阵在此时已经不需要了。不知道大家是咋理解的?

============================================

我自己的理解是这样的。

 

Transformer在训练的时候是并行执行的,所以在decoder的第一个sublayer里需要seq mask,其目的就是为了在预测未来数据时把这些未来的数据屏蔽掉,防止数据泄露。如果我们非要去串行执行training,seq mask其实就不需要了。比如说我们用transformer做NMT,训练数据里有一个sample是I love China -->我爱中国。利用串行的思维来想,在训练过程中,我们会

1. 把I love China输入到encoder里去,利用top encoder最终输出的tensor (size: 1X3X512,假设我们采用的embedding长度为512,而且batch size = 1)作为decoder里每一层用到的k和v;

2. 将<s>作为decoder的输入,将decoder最终的输出和‘我’做cross entropy计算error。

3. 将<s>,我作为decoder的输入,将decoder最终:输出的最后一个prob. vector和‘爱’做cross entropy计算error。

4. 将<s>,我,爱 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘中’做cross entropy计算error。

5. 将<s>,我,爱,中 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘国’做cross entropy计算error。

6. 将<s>,我,爱,中,国 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和</s>做cross entropy计算error。

2-6里都可以不用seq mask。

 

而在transformer实际的training过程中,我们是并行地将2-6在一步中完成,即

7:将<s>,我,爱,中,国 作为decoder的输入,将decoder最终输出的5个prob. vector和我,爱,中,国,</s>分别做cross entropy计算error。

比如要想在7中计算第一个prob. vector的整个过程中,都不用到‘我’及其后面字的信息,就必需seq mask。对所有位置的输入,情况都是如此。

但是,仔细想想,7虽然包括了2-6,不过有一点区别。比如对3来说,我们是可以不用seq mask的,这时 <s>所对应的encoder output是会利用'我'里的信息的;而在并行时,seq mask是必需的,这时<s>所对应的encoder output是不会利用'我'里的信息的。

如此一来,我们可以看到,在transformer训练时,由于是并行计算,decoder的第i个输入只能用到i,i-1,..., 0这些位置上输入的信息;当训练完成后,在实际预测过程中,虽然理论上decoder的第i个输入可以用到所有位置上输入的信息,但是由于模型在训练过程中是按照前述方式训练的,所以继续使用seq mask会和训练方式匹配,得到更好的预测结果。

 

我感觉从理论上看,按照串行方式1-6来训练并且不用seq mask,我们可以把信息用得更足一些,似乎可能模型的效果会好一点,但是计算效率比transformer的并行训练差太多,最终综合来看应该还是并行的综合效果好。

不知道理解得对不对。