ReFormer论文解读(THE EFFICIENT TRANSFORMER)
文章目录
Reformer要解决的问题
- attention的内存、计算复杂度是文本长度L的平方复杂度即O(L* L)(self-attention每个位置都要看整句的其他每个位置), 这在超长文本时(比如文章)是不可接受的。传统transformer一般是按512长度分块,这样损失了块与块之间的互信息。
- 原生transformer训练是需要的内存是层数的倍数(因为反向传播是需要存储每层的结果来求误差的梯度)。
- feed-forward层的维度一般远大于模型的维度(一般是两倍),这样feed-forward层反向传播空间复杂度要求大很多。
Reformer怎么解决以上三个问题
- 基于局部近似hash的近似attention把复杂度从O(L * L)降低到O(L * log(L))
- 可逆网络解决训练的内存是层数倍数的问题(可逆网络由于后面层可以推出前面层,所以,只用保存最后一层即可)。
- 可逆网络分块计算,复杂度可以从feed-forwad维度降到模型维度
Reformer时间、空间复杂度汇总
- 以下图非常重要,我先直接把每个模块空间、时间复杂度拿出来单独看。之后所有工作,都是围绕这个时间、空间复杂度展开。
- 第一列中,Transformer是原生的, Reversible Transformer就是论文引入的可逆Transformer(接下来详细说), Chunked Reversible Transformer就是可逆网络分块处理(接下来详细说),LSH Transformer就是文中引入的局部近似Hash(接下来详细说), Reformer就是上述三个汇总。
- 参数解释如下,b是batch size, l是输入文本长度, dff是feed forwd层的维度, dmodel是模型的维度, nh是multi-head的数量,nl是层数, c是分块的数量,nr是模型hash的次数。先给出这个图,来个直观的感受。接下来具体说每个模块怎么实现的。
我们接下来详解以上三个改进
一. hash近似Attention
-
先看看经典transformer架构,Attention原生公式如下
- 当Self attention是Q K都是输入文本自身,所以Q每个位置会看K的每个位置。所以复杂度是O(L * L)。由于有这个softmax存在,所以,实际上,只有Q*K很大(由于是内积,QK近似就会比较大)的点才会启到作用,比较小的点,就会比较接近于0,不起作用。这个就暗示这个矩阵是稀疏的,也是近似attention优化的依据。很多工作,都是建立在稀疏矩阵的优化上的。
- 上述Attention公式,如果用稀疏矩阵表示的话,就是如下图,Pi表示i需要关注的集合,m(j, Pi)表示一个mask:如果j在Pi里就是减去一个无穷大数,否则减去0。z(i,Pi): 表示归一化函数,代替了softmax函数。外层是一个指数函数exp,为了把乘除法都转换成加减法。
- 当Self attention是Q K都是输入文本自身,所以Q每个位置会看K的每个位置。所以复杂度是O(L * L)。由于有这个softmax存在,所以,实际上,只有Q*K很大(由于是内积,QK近似就会比较大)的点才会启到作用,比较小的点,就会比较接近于0,不起作用。这个就暗示这个矩阵是稀疏的,也是近似attention优化的依据。很多工作,都是建立在稀疏矩阵的优化上的。
-
先引用一个局部近似的Hash函数。函数如下图所示,x,y通过映射到球上的两个点,随便旋转球,如图中上面,如果x,y 离的比较远,容易分到不同的坐标轴分区里,如果x,y近似就是离的比较近,就会常常落到同一个坐标轴分区内里面。通过这样的的一个局部hash方法,就可以把近似的Q分到相同的桶内。
-
有了这个hash有什么用呢?如下图,展示了近似attention的过程。
-
图中左边展示了LSH attention的详细过程。
- Hash:首先,用hash分桶,图中用颜色展示了不同的桶。复杂度是O(L)
- 按桶排序排序O(L* LOG(L)),相同桶放到一起,形成第二行
- 分块:由于分桶是随机的,这就有可能所以的Q都到了一个桶,桶内的Q每个位置都看其他位置,那复杂度还是O(L*L),为了避免这种情况,需要分块.
- 计算attention: 如下图中,相同块内的互相attention(如前4个小方块),不同块如果第一个位置属于前面的块,那么就需要和前面块的都关注(如第五个小方块)。公式表达出来就如图,。Pi表示第i位置需要关注的位置j的集合, Si表示i位置排序后的集合,Sj表示j位置排序后的集合, Si Sj都是按照块的大小(m)分割的,那么Pi的j的下标集合可能在上一块也可能就在当前块中。
- 复杂度分析:这步的总体计算复杂度O(c * L), c为块的长度。c是定制,所以是线性复杂度。但是由于论文中c=128,所以,c * L 复杂度是远大于排序的L * log(L), 是主要的耗时。这也就是Reformer时间复杂度nr * l * c的由来。在超长文本中相比于L * L是有优势的 。如果L 小于1248,Reformer更慢了。
-
图中右边展示了一个qk例子, 表示qk的块具体是如何分的。
- 右图a: 黑点就是假设Q和K点乘比较大的点。从图中可以看出是一个稀疏的矩阵。
- 右图b:展示的是qk不同的情况,横坐标就是hash分桶排序后Q, 纵坐标就是K hash分桶排序后的k,颜色不同代表Hash后的相同的桶内也就是q1 q2 q4 k1近似, q3 q6 k2 k6近似,q5 k3 k4 k5近似。桶内互相关注,所以有蓝红黄三个分块。
- 图c和图b近似,图c就是self-attention的情况,qk相同,所以都是正方形的。
- 图d表示的就是分块的过程,为了怕点都分到同一个桶中,强制按照2格来分成了三个块。
-
既然是Hash,当然就有可能,hash分错的情况,相当于是漏网之鱼。论文提出多轮hash就可以解决此问题。用公式表达出来就是, Pi表示i需要算attention的下标j的集合,h(qi)表示qi的hash值。经过Nrounds轮并集就近似算全了。作者实验发现8轮hash,就能和原始的attention的结果相似,如下图16轮hash和8轮的接近。
-
速度评测如下图可以看出,文本总长度一定,随着batch里面的文本长度增长,LSH时间是平的,而原生的是明显增大。这个图中可以看出文本小于1024长度的时候,reformer的时间复杂度效果并不明显。这是因为reformer是nr * c * L, 原生是L* L, nr = 8, c = 128是,nr * c * L = L * L 的。
-
以上的hash优化性能的原理,是建立在Q 和K比较相似的基础之上,如果QK完全不相同,分不到一个桶内,那不就变成了没有attention了。因此,论文中提出Q K共享参数,那么QK就会非常接近,V的参数是单独的。实验发现在enwik8, imagenet64中效果和原生transformer差不多。
-
二. 可逆网络
- 由于训练时,误差反向传播时,需要保存每一层的输入输出,所以内存需要nl的倍数,一般层数比较大,gpu内存就会容纳不下。论文引入可逆网络。只用保存最后一层,当反向传播是,直接根据后面一层反推出前面一层的输出即可。可逆网络公式如下
- 如上公式,我们可以看出,如果知道输入X(直接分块成X1, X2),可以直接求出Y1 (attention的输出), Y2(feed forward的输出), 反过来,如果知道Y1, Y2也能直接反推出X1 , X2, 因为X2 = Y2 - FeedForwad(Y1),知道了X2有可以直接算出X1, X1= Y1 - Attention(X2)。所以训练的时候只有保存最后一层的输出Y,就可以轻易的求出输入X啦,也就是前一层的输出。这样就把空间复杂度从多层变成了单层。如下图中,空间复杂度中nl去掉了。
- 实验中,可以发现可逆网络和一般的transformer多个step之后非常接近。
三. 可逆网络分块
- 由于feed_forward层的维度远大于模型的维度。dff一般是4K。但是feed forward层又是和位置无关的,所以,可以分成c个块,每个块内单独计算。
- 如下图所示,这样复杂度就从b * L * dff降成了 b * L * dmodel。