本篇整理 Transformer 架构,及在 Transformer 基础上衍生出来的 BERT 模型,最后给出
相应的应用案例。
1.Transformer的架构
Transformer 网络架构架构由 Ashish Vaswani 等人在 Attention Is All You Need 一文中提出,并用于机器翻译任务,和以往网络架构有所区别的是,该网络架构中,编码器和解码器没有采用 RNN 或 CNN 等网络架构,而是采用完全依赖于注意力机制的架构。网络架构如下所示:
该网络架构中引入了多头注意力机制,该机制的网络架构如下所示:
这里有必要对多头注意力机制进行一定的解释。假设输入数据的 batch size 为B B B ,输入数据的最大长度为F F F ,输出数据的最大长度为T T T ,共有N N N 个注意力头,每个注意力头的输出维度为H H H ,则输入/输出数据中每个词的 Embedding 的维度为 E = N × H E=N×H E = N × H ,且注意力头中每个头对应的W Q , W K , W V \boldsymbol{W}^{Q}, \boldsymbol{W}^{K} ,\boldsymbol{W}^{V} W Q , W K , W V 矩阵均属于R E × H \mathbb{R}^{E \times H} R E × H 。考虑到编码器和解码器涉及三个注意过程,且输入有所不同,这里分别来看。
2.编码器自注意力
考虑输入数据为X ∈ R B × F × E \mathbf{X} \in \mathbb{R}^{B \times F \times E} X ∈ R B × F × E ,对输入数据应用如下线性变换:Q = X W Q , ( W Q ∈ R E × H ⇒ Q ∈ R B × F × H ) K = X W K , ( W K ∈ R E × H ⇒ K ∈ R B × F × H ) V = X W V , ( W V ∈ R E × H ⇒ V ∈ R B × F × H )
\begin{aligned}
&\mathbf{Q}=\mathbf{X} \mathbf{W}^{Q}, \quad\left(\mathbf{W}^{Q} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{Q} \in \mathbb{R}^{B \times F \times H}\right)\\
&\mathbf{K}=\mathbf{X} \mathbf{W}^{K}, \quad\left(\mathbf{W}^{K} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{K} \in \mathbb{R}^{B \times F \times H}\right)\\
&\mathbf{V}=\mathbf{X} \boldsymbol{W}^{V}, \quad\left(\boldsymbol{W}^{V} \in \mathbb{R}^{E \times H} \Rightarrow \mathbf{V} \in \mathbb{R}^{B \times F \times H}\right)
\end{aligned}
Q = X W Q , ( W Q ∈ R E × H ⇒ Q ∈ R B × F × H ) K = X W K , ( W K ∈ R E × H ⇒ K ∈ R B × F × H ) V = X W V , ( W V ∈ R E × H ⇒ V ∈ R B × F × H )
在上述变换基础上进行如下计算,得到输入中每个词和自身及其他词之间的关系权重S = softmax ( Q K ⊤ H )
\mathbf{S}=\operatorname{softmax}\left(\frac{\mathbf{Q K}^{\top}}{\sqrt{H}}\right)
S = s o f t m a x ( H Q K ⊤ )
上述变换 K T K^T K T 表示对张量的最内部矩阵进行转置,因此K ⊤ ∈ R B × H × F \mathbf{K}^{\top} \in \mathbb{R}^{B \times H \times F} K ⊤ ∈ R B × H × F ,Q K ⊤ \mathrm{QK}^{\top} Q K ⊤ 表示相同维度下张量 Q 和张量 K T K^T K T 最内部矩阵执行矩阵乘法运算 (即 numpy.matmul 运算),因此有S ∈ R B × F × F \mathbf{S} \in \mathbb{R}^{B \times F \times F} S ∈ R B × F × F ,该张量表示输入数据中每个词和自身及其他词的关系权重,每一行的得分之和为 1,即∀ i , j np.sum ( S [ i , j , : ] ) = 1
\forall i, j \quad \operatorname{np.sum}(\mathbf{S}[i, j,:])=1
∀ i , j n p . s u m ( S [ i , j , : ] ) = 1
基于该得分即可得到,每个词在当前上下文下的新的向量表示,公式如下:x h = S V ⇒ X h ∈ R B × F × H
\mathbf{x}^{h}=\mathbf{S V} \quad \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times H}
x h = S V ⇒ X h ∈ R B × F × H
考虑到 Transformer 采用了 N 个注意力头,因此最终产生了集合大小为 N 的注意力集合{ X h 1 , … , X h N } \left\{\mathbf{X}^{h_{1}}, \ldots, \mathbf{X}^{h_{N}}\right\} { X h 1 , … , X h N } ,将该集合中中的所有张量按照最后一个维度进行拼接,并采用矩阵W O ∈ R E × E \boldsymbol{W}^{O} \in \mathbb{R}^{E \times E} W O ∈ R E × E 进行变换,得到最终生成的自注意力输入数据,公式如下:X a = numpy.concatenate ( ( X h 1 , … , X h N ) , axis = − 1 ) W O
\mathbf{X}^{a}=\text { numpy.concatenate }\left(\left(\mathbf{X}^{\mathrm{h}_{1}}, \ldots, \mathbf{X}^{\mathrm{h}_{\mathrm{N}}}\right), \text { axis }=-1\right) \boldsymbol{W}^{O}
X a = numpy.concatenate ( ( X h 1 , … , X h N ) , axis = − 1 ) W O
因此有X a ∈ R B × F × E \mathbf{X}^{a} \in \mathbb{R}^{B \times F \times E} X a ∈ R B × F × E 。
考虑到多头注意力可以并行运算,为了充分发挥向量化计算并行效率,实际实现中往往采用如下表示方案:X par = reshape ( X , to shape = [ B × F , N × H ] ) Q p a r = X p a r W Q p a r ( W Q p a r ∈ R ( N × H ) × ( N × H ) ⇒ Q p a r ∈ R ( B × F ) × ( N × H ) ) K p a r = X p a r W K p a r ( W K p a r ∈ R ( N × H ) × ( N × H ) ⇒ K p a r ∈ R ( B × F ) × ( N × H ) ) V p a r = X p a r W V p a r ( W V p a r ∈ R ( N × H ) × ( N × H ) ⇒ V p a r ∈ R ( B × F ) × ( N × H ) )
\begin{aligned}
&\mathbf{X}^{\text {par }}=\text { reshape }(\mathbf{X}, \text { to shape }=[B \times F, N \times H])\\
&\begin{array}{ll}
{\mathbf{Q}^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{Q^{p a r}}} & {\left(\boldsymbol{W}^{Q^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{Q}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)} \\
{\mathbf{K}^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{K^{p a r}}} & {\left(\boldsymbol{W}^{K^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{K}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)} \\
{\text { V }^{p a r}=\mathbf{X}^{p a r} \boldsymbol{W}^{V^{p a r}}} & {\left(\boldsymbol{W}^{V^{p a r}} \in \mathbb{R}^{(N \times H) \times(N \times H)} \Rightarrow \mathbf{V}^{p a r} \in \mathbb{R}^{(B \times F) \times(N \times H)}\right)}
\end{array}
\end{aligned}
X par = reshape ( X , to shape = [ B × F , N × H ] ) Q p a r = X p a r W Q p a r K p a r = X p a r W K p a r V p a r = X p a r W V p a r ( W Q p a r ∈ R ( N × H ) × ( N × H ) ⇒ Q p a r ∈ R ( B × F ) × ( N × H ) ) ( W K p a r ∈ R ( N × H ) × ( N × H ) ⇒ K p a r ∈ R ( B × F ) × ( N × H ) ) ( W V p a r ∈ R ( N × H ) × ( N × H ) ⇒ V p a r ∈ R ( B × F ) × ( N × H ) )
在上述并行计算基础上通过如下计算得到词和自身及其他词的关系权值:v p a r = numpy.reshape (Vpar, ( B , F , N , H ) ) v’art = numpy. transpose (V p a r , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ V p a r t ∈ R B × N × F × H X h = S p a r v p a r t ⇒ X h ∈ R B × N × F × H X h = numpy. transpose ( X h , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ X h ∈ R B × F × N × H X h = numpy.reshape ( X h , F , N × H ) ) ⇒ X h ∈ R B × F × E
\begin{array}{l}
{\left.\mathbf{v}^{p a r}=\text { numpy.reshape (Vpar, }(B, F, N, H)\right)} \\
{ \text { v'art }\left.=\text { numpy. transpose (V }^{p a r}, \text { axes }=[0,2,1,3]\right) \Rightarrow \mathbf{V}^{p a r^{t}} \in \mathbb{R}^{B \times N \times F \times H}} \\
{\mathbf{X}^{h}=\mathbf{S}^{p a r} \mathbf{v}^{p a r^{t}} \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times N \times F \times H}} \\
{\mathbf{X}^{h}=\text { numpy. transpose }\left(\mathbf{X}^{h}, \text { axes }=[0,2,1,3]\right) \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times N \times H}} \\
{\left.\mathbf{X}^{h}=\text { numpy.reshape }\left(\mathbf{X}^{h}, F, N \times H\right)\right) \Rightarrow \mathbf{X}^{h} \in \mathbb{R}^{B \times F \times E}}
\end{array}
v p a r = numpy.reshape (Vpar, ( B , F , N , H ) ) v’art = numpy. transpose (V p a r , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ V p a r t ∈ R B × N × F × H X h = S p a r v p a r t ⇒ X h ∈ R B × N × F × H X h = numpy. transpose ( X h , axes = [ 0 , 2 , 1 , 3 ] ) ⇒ X h ∈ R B × F × N × H X h = numpy.reshape ( X h , F , N × H ) ) ⇒ X h ∈ R B × F × E
3.解码器自注意力
解码器的自注意力和编码器的自注意力基本完全一致,需要注意的是解码过程是one by one的生成过程,因此输出数据中的每个词在进行自注意力的过程时,仅可以看到当前输出位置的所有前驱词的信息,因此需要对输出数据中的词进行掩码操作,该操作即对应上面的左图上的掩码操作。该掩码操作相当于执行如下操作:A = Q K ⊤ + M S p a r = softmax ( A H ) ⇒ S p a r ∈ R B × N × T × T
\begin{aligned}
\mathbf{A} &=\mathbf{Q K}^{\top}+\mathbf{M} \\
\mathbf{S}^{p a r} &=\operatorname{softmax}\left(\frac{\mathbf{A}}{\sqrt{H}}\right) \Rightarrow \mathbf{S}^{p a r} \in \mathbb{R}^{B \times N \times T \times T}
\end{aligned}
A S p a r = Q K ⊤ + M = s o f t m a x ( H A ) ⇒ S p a r ∈ R B × N × T × T
其中M ∈ R 1 × 1 × T × T \mathbf{M} \in \mathbb{R}^{1 \times 1 \times T \times T} M ∈ R 1 × 1 × T × T 为掩码,其最内部矩阵为方阵,该方阵主对角线及以下元素均为 0,主对角线以上元素为− ∞ -\infty − ∞ 。譬如 T = 5 时,最内部方阵内容如下:M = [ 0 − ∞ − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ 0 0 0 − ∞ − ∞ 0 0 0 0 − ∞ 0 0 0 0 0 ]
\boldsymbol{M}=\left[\begin{array}{ccccc}
{0} & {-\infty} & {-\infty} & {-\infty} & {-\infty} \\
{0} & {0} & {-\infty} & {-\infty} & {-\infty} \\
{0} & {0} & {0} & {-\infty} & {-\infty} \\
{0} & {0} & {0} & {0} & {-\infty} \\
{0} & {0} & {0} & {0} & {0}
\end{array}\right]
M = ⎣ ⎢ ⎢ ⎢ ⎢ ⎡ 0 0 0 0 0 − ∞ 0 0 0 0 − ∞ − ∞ 0 0 0 − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ − ∞ 0 ⎦ ⎥ ⎥ ⎥ ⎥ ⎤
其余操作和编码器自注意力机制一致,唯一不同的是此时需要向上面那样将输入数据换成 Y ∈ R N × T × E \mathbf{Y} \in \mathbb{R}^{N \times T \times E} Y ∈ R N × T × E ,因此所有的 F F F 均需换成 T T T 。
4.编码解码注意力
编码解码注意力和自注意力类似,唯一不同的是计算 Q, K, V 使用的数据有所区别,计算
Q 时采用 Y,计算 K 和 V 时采用 X,因此有:
X p a r = \mathbf{X}^{\mathrm{par}}= X p a r = reshape ( X , to shape = [ B × F , N × H ] ) (\mathbf{X}, \text { to shape }=[B \times F, N \times H]) ( X , to shape = [ B × F , N × H ] ) Y par = \mathbf{Y}^{\text {par }}= Y par = reshape ( Y , to shape = [ B × T , N × H ] ) (\mathbf{Y}, \text { to shape }=[B \times T, N \times H]) ( Y , to shape = [ B × T , N × H ] )
Q ende-par = Y par W Q ende-par ( W Q ende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{Q}^{\text {ende-par }} =\mathbf{Y}^{\text {par }} \boldsymbol{W}^{Q \text { ende-par }}\left(\boldsymbol{W}^{Q \text { ende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) Q ende-par = Y par W Q ende-par ( W Q ende-par ∈ R ( N × H ) × ( N × H ) )
K ende-par = X par W K coldeper ( W K ende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{K}^{\text {ende-par }} = \mathbf{X}^{\text {par }} \boldsymbol{W}^{K^{\text {coldeper }}} \quad\left(\boldsymbol{W}^{K \text { ende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) K ende-par = X par W K coldeper ( W K ende-par ∈ R ( N × H ) × ( N × H ) ) V ende-par = X par W Vpar ( W Vende-par ∈ R ( N × H ) × ( N × H ) ) \mathbf{V}^{\text {ende-par }}= \mathbf{X}^{\text {par }} \boldsymbol{W}^{\text {Vpar }} \quad\left(\boldsymbol{W}^{\text {Vende-par }} \in \mathbb{R}^{(N \times H) \times(N \times H)}\right) V ende-par = X par W Vpar ( W Vende-par ∈ R ( N × H ) × ( N × H ) )
因此有:S ende-par = softmax ( Q ende-par t K ende-par t ⊤ H ) ⇒ S ende-par ∈ R B × N × T × F
\mathbf{S}^{\text {ende-par}}=\operatorname{softmax}\left(\frac{\mathbf{Q}^{\text {ende-par}^{t}} \mathbf{K}^{\text {ende-par}^{t \top}}}{\sqrt{H}}\right) \Rightarrow \mathbf{S}^{\text {ende-par}} \in \mathbb{R}^{B \times N \times T \times F}
S ende-par = s o f t m a x ( H Q ende-par t K ende-par t ⊤ ) ⇒ S ende-par ∈ R B × N × T × F
y ende-h = S ende-par V ende-par ⇒ Y ende-h ∈ R B × N × T × H
\mathbf{y}^{\text {ende-h}}=\mathbf{S}^{\text {ende-par}} \mathbf{V}^{\text {ende-par}} \Rightarrow \mathbf{Y}^{\text {ende-h}} \in \mathbb{R}^{B \times N \times T \times H}
y ende-h = S ende-par V ende-par ⇒ Y ende-h ∈ R B × N × T × H
其余计算过程和编码器自注意力机制类似。