论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

https://arxiv.org/pdf/2004.01655.pdf

对齐交叉熵的非自回归机器翻译(未开源)

 

非自回归机器翻译模型对整个目标序列进行并行预测,极大地提高了解码速度。然而由于模型中缺少自回归因子,难以对词序建模,且交叉熵损失会严重影响词序的微小变化,因此作者提出使用 aligned cross entropy 对齐交叉熵 AXE 作为非自回归模型的损失函数,AXE 使用一个可微的动态规划分配损失,寻找目标和预测 token 之间可能的最佳单调对齐。AXE-based CMLMs 可显著提高主要 WMT 的基准性能,为最新的 sota 非自回归模型。

非自回归机器翻译模型并行预测每个单词以显著提高解码速度,但这是以牺牲性能为代价的,因为当模型不能根据以前的预测建立时,对词序进行建模比较困难。一系列半自回归模型表明,权衡速度与精度可以优化与有限形式的自回归。作者提出一种新的非自回归机器翻译的训练损失,减轻了词序错误的惩罚,并且在不修改模型或解码算法的情况下显著提高了性能。

 

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

现有的模型(自回归和非自回归模型)通常使用交叉熵损失进行训练,交叉熵是一个严格的损失函数,预测不在位置的单词都会受到惩罚,即使是编辑距离很小的输出序列(图1)。自回归模型会避免这种惩罚,因为单词是根据句子 perfix 生成的,而非自回归模型无法获知这个信息。非自回归模型应该更多地关注根本错误 root errors(如缺失单词),同时允许对级联错误(正确的单词放在错误的地方)给予更多的信任。

对齐交叉熵 AXE 是一种新的目标函数,基于 token 标签序列和 token 分布预测序列之间的对齐来计算交叉熵损失,采用动态规划的方法寻找单调对齐使交叉熵损失最小化。通过忽略绝对位置、关注相对顺序和词法匹配,为非自回归模型提供更准确的训练。使用矩阵运算实现 AXE,作为损失函数训练条件遮蔽语言模型(CMLM; Ghazvininejad et al., 2019)用于机器翻译。与交叉熵相比,AXE 只略微增加了训练时间,且不需要改变并行 argmax 解码。实验表明,AXE 显著提高了 CMLMs 的性能,在 WMT'14 EN-DE 中提高5个 BLEU score,同时具有相同的解码速度。AXE CMLMs 显著优于最先进的非自回归模型,如 FlowSeq(Ma et al., 2019)、CRF-based semi-autoregressive model with bigram LM decoding(Sun et al., 2019)。AXE 可以使模型在预测时更有信心,减少了多模态性 multimodality。

Aligned Cross Entropy

设 Y 为含 n 个 tokens 的目标序列 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation,P 为含 m 个 tokens 的预测序列 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation,目标是在 Y 和 P 之间找到一个单调对齐以最小化交叉熵损失,从而将惩罚集中在词法错误(预测错误的 token)上,而不是位置错误(在错误的地方预测正确的 token)上。

对齐函数 α 将目标位置映射到预测位置,即 α : {1, …, n} → {1, ..., m},假设这个对齐是单调的,i ≤ j iff α(i) ≤ α(j) (iff: if and only if 当且仅当,充要条件),定义 conditional AXE loss 为:

        论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

 

损失函数的第一项是 Y 和 P 之间对齐部分的交叉熵,第二项是对未对齐部分预测的惩罚。ε 是词汇表中特殊概率分布的 “blank” token,但不会出现在最终的输出字符串中。最后的 AXE loss 是最小化所有可能的单调对齐的条件损失:

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

​​​

Dynamic Programming

求 Y、P 中任意一个前序 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 的最优对齐,最大得分是整个序列对齐 i = n, j = m。首先定义 (n + 1) * (m + 1) 的矩阵 A,论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 表示将 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 对齐到 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 的最小损失,初始化 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 为0,然后使用每个单元 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 的三种运算的局部最小值来填充矩阵:Align, Skip Prediction, Skip Target,表1描述了每种运算作及其更新公式。一旦矩阵被填满,论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 将包含最佳对齐的交叉熵损失。算法1为 AXE 动态规划的一个简单实现。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

由第二个公式可知最优对齐可以是多对一的,多个目标位置可以映射到单个预测,计算方式是对齐第一个映射的 token 并跳过其余的目标 tokens。为了防止跳过太多目标 token,使用一个超参数 δ 惩罚 skip target 项,公式二中的 δ 为 1。消融实验表明,更高的 δ 值的性能更好。

算法 1 的时间复杂度为 O(n*m),矩阵 A 的更新可以在 GPU 上并行化。遍历每个反对角 anti-diagonal 而不是每个单元格,并行计算反对角线上的所有值。即首先计算 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation,然后计算 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 等等。反对角的数量是 n + m + 1,可得到 O(n + m) 的时间复杂度。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

 

图 2 为 AXE 的一个示例,预测总体上是正确的,但对齐偏差会产生常规的交叉熵损失,严重影响前三个预测,即使 P2 和 P3 与 Y1 和 Y2 对齐时是正确的。另一方面,AXE 发现了目标和预测之间的对齐,这使得它可以将惩罚集中在冗余预测和缺失 token 的根本错误上。

 

训练非自回归模型

(1)CMLMs

以源序列 X 和部分可观测的目标序列 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 作为输入,预测被 masked 的目标序列 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation 的概率,模型是一个 encoder-decoder transformer,在训练时 Y 的一个随机子集被 mask,在推断时所有的目标 tokens 都被遮蔽(Y = 论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation),而 Y 的长度(遮蔽 tokens 的数量)未知,为了估计 Y 的长度,引入一个辅助任务基于源序列 X 预测目标序列的长度。

(2)AXE 用于 CMLMs

模型可以产生 blank tokens (ε) 以缩短预测序列的长度,为了解释在推断时可能跳过的 token,解码前将预测长度乘以一个超参数 λ(使用验证集调参)。

(3)调整训练目标以适于 AXE

训练时不需要遮蔽整个序列,可用部分观察序列,因此用三种变体进行实验:

a) Unobserved Input, Predict all

所有 tokens 都被遮蔽,预测所有。虽然 AXE 允许遮蔽 tokens 的数量 m 与目标序列的长度 n 不同,但初步实验表明,设置 m = n 可以产生更好的模型。

b) Partially-Observed Input, Predict all

和 CMLM 训练相同,目标序列的一个随机子集在作为输入传入模型之前被掩蔽,然后在整个序列上应用 AXE。训练时总是设置 m = n,以避免目标序列在遮蔽之后进一步改变。

c) Partially-Observed Input, Predict Masks

AXE 在计算交叉熵时跳过观察到的 tokens,将训练集中在实际任务上。通过将每个观察到的 Yi 设为 Pi(Yi) = 1,即如果第 i 个 token 被观察到,且与同一位置 Pi 的预测对齐,则没有惩罚。消融实验表明,这提供了一个适度但持续的性能提升,因此作者使用这种设置来训练模型。

 

实验

评估 AXE 训练的 CMLMs 在6个标准机器翻译基准上的表现,证明 AXE 在交叉熵训练的 CMLMs 和最近提出的非自回归模型的基础上显著提高了性能。En-Zh 使用 SacreBLEU,其他使用 BLEU。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

使用 AXE 可以显著提高所有 CMLMs 的基准性能,此外将原始数据与知识蒸馏训练进行比较,结果证明了知识蒸馏对于非自回归方法的重要性。

消融实验

表 5 表明使用部分观测的输入效果更好;表 6 表明高 δ可以显著提高性能;表 7 表明使用最佳的长度乘数可提高性能。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

分析

首先分析不同序列长度下交叉熵与 AXE 训练的 CMLMs 的性能,使用 compare-mt (Neubig et al., 2019) 将 WMT'14 EN-DE 和 DE-EN 的测试集根据目标序列长度分桶,计算每个桶的 BLEU。表 8 表明,随着序列长度的增加,交叉熵训练的模型的性能急剧下降,而 AXE 训练的模型的性能相对稳定。一种解释是,序列越长就越有可能观察到预测与目标之间的偏差,AXE 的重新排列缓解了这种情况。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation

计算每个生成的 token 在序列所有位置的概率,并根据与生成位置的相对距离对这些概率进行平均。图 3 为短(< 10 tokens)和长(> 30 tokens)目标序列的平均概率。两种模型对短序列的预测都相当有信心 (a),概率在生成的位置有一个高峰,且随着距离的增加迅速下降。对于较长的句子 (b),交叉熵的峰值较低,预测位置 ±1 的近邻平均概率约为 0.14,几乎是峰值的三分之一;AXE 的概率峰值明显更大,与中心相比,相邻位置的概率可以忽略不计。一种解释是,交叉熵鼓励给近邻一些概率,以便在预测与目标不对齐的情况下“hedge their bets”。由于 AXE 在计算实际损失之前找到了最佳的对齐方式,所以就没有必要将概率分散到近邻了。

AXE 减少了多模态问题 (Gu et al., 2018),由于许多非自回归模型的预测之间的 coordination 极小,一个模型可能同时考虑许多可能的翻译,在这种情况下,模型可能合并两个或多个不同的翻译,并生成不一致的输出,这种输出通常以 token 重复为特征。表 9 显示了交叉熵和 AXE CMLMs 的重复率,用 AXE 代替交叉熵大大减少了多模态性,将重复减少了12倍。

论文阅读——Aligned Cross Entropy for Non-Autoregressive Machine Translation