本文带来各类奇怪的IT百科知识。

文章目录

Reformer要解决的问题

Reformer怎么解决以上三个问题

Reformer时间、空间复杂度汇总

我们接下来详解以上三个改进

一. hash近似Attention

二. 可逆网络

三. 可逆网络分块

四. 总结一下:Reformer确实在内存、性能优化方面明显改进:提出或借鉴的方法比较有意思。

如有不对地方:欢迎指出:谢谢。

参考链接:

论文:https://arxiv.org/abs/2001.04451

Reformer要解决的问题

attention的内存、计算复杂度是文本长度L的平方复杂度即O(L* L):self-attention每个位置都要看整句的其他每个位置:, 这在超长文本时:比如文章:是不可接受的。传统transformer一般是按512长度分块:这样损失了块与块之间的互信息。

原生transformer训练是需要的内存是层数的倍数:因为反向传播是需要存储每层的结果来求误差的梯度:。

feed-forward层的维度一般远大于模型的维度:一般是两倍::这样feed-forward层反向传播空间复杂度要求大很多。

Reformer怎么解决以上三个问题

基于局部近似hash的近似attention把复杂度从O(L * L)降低到O(L * log(L))

可逆网络解决训练的内存是层数倍数的问题:可逆网络由于后面层可以推出前面层:所以:只用保存最后一层即可:。

可逆网络分块计算:复杂度可以从feed-forwad维度降到模型维度

Reformer时间、空间复杂度汇总

以下图非常重要:我先直接把每个模块空间、时间复杂度拿出来单独看。之后所有工作:都是围绕这个时间、空间复杂度展开。

第一列中:Transformer是原生的: Reversible Transformer就是论文引入的可逆Transformer:接下来详细说:, Chunked Reversible Transformer就是可逆网络分块处理:接下来详细说::LSH Transformer就是文中引入的局部近似Hash:接下来详细说:, Reformer就是上述三个汇总。

参数解释如下:b是batch size, l是输入文本长度: dff是feed forwd层的维度: dmodel是模型的维度: nh是multi-head的数量:nl是层数: c是分块的数量:nr是模型hash的次数。先给出这个图:来个直观的感受。接下来具体说每个模块怎么实现的。

我们接下来详解以上三个改进

一. hash近似Attention

先看看经典transformer架构:Attention原生公式如下

当Self attention是Q K都是输入文本自身:所以Q每个位置会看K的每个位置。所以复杂度是O(L * L)。由于有这个softmax存在:所以:实际上:只有Q*K很大:由于是内积:QK近似就会比较大:的点才会启到作用:比较小的点:就会比较接近于0:不起作用。这个就暗示这个矩阵是稀疏的:也是近似attention优化的依据。很多工作:都是建立在稀疏矩阵的优化上的。

上述Attention公式:如果用稀疏矩阵表示的话:就是如下图:Pi表示i需要关注的集合:m(j, Pi)表示一个mask:如果j在Pi里就是减去一个无穷大数:否则减去0。z(i,Pi): 表示归一化函数:代替了softmax函数。外层是一个指数函数exp:为了把乘除法都转换成加减法。

先引用一个局部近似的Hash函数。函数如下图所示:x,y通过映射到球上的两个点:随便旋转球:如图中上面:如果x,y 离的比较远:容易分到不同的坐标轴分区里:如果x,y近似就是离的比较近:就会常常落到同一个坐标轴分区内里面。通过这样的的一个局部hash方法:就可以把近似的Q分到相同的桶内。

有了这个hash有什么用呢:如下图:展示了近似attention的过程。

图中左边展示了LSH attention的详细过程。

Hash:首先:用hash分桶:图中用颜色展示了不同的桶。复杂度是O:L:

按桶排序排序O:L* LOG:L:::相同桶放到一起:形成第二行

分块:由于分桶是随机的:这就有可能所以的Q都到了一个桶:桶内的Q每个位置都看其他位置:那复杂度还是O:L*L):为了避免这种情况:需要分块.

计算attention: 如下图中:相同块内的互相attention:如前4个小方块:,不同块如果第一个位置属于前面的块:那么就需要和前面块的都关注:如第五个小方块:。公式表达出来就如图:。Pi表示第i位置需要关注的位置j的集合, Si表示i位置排序后的集合:Sj表示j位置排序后的集合: Si Sj都是按照块的大小:m:分割的:那么Pi的j的下标集合可能在上一块也可能就在当前块中。

复杂度分析:这步的总体计算复杂度O(c * L), c为块的长度。c是定制:所以是线性复杂度。但是由于论文中c:128:所以:c * L 复杂度是远大于排序的L * log(L), 是主要的耗时。这也就是Reformer时间复杂度nr * l * c的由来。在超长文本中相比于L * L是有优势的 。如果L 小于1248:Reformer更慢了。

图中右边展示了一个qk例子, 表示qk的块具体是如何分的。

右图a: 黑点就是假设Q和K点乘比较大的点。从图中可以看出是一个稀疏的矩阵。

右图b:展示的是qk不同的情况:横坐标就是hash分桶排序后Q, 纵坐标就是K hash分桶排序后的k,颜色不同代表Hash后的相同的桶内也就是q1 q2 q4 k1近似: q3 q6 k2 k6近似:q5 k3 k4 k5近似。桶内互相关注:所以有蓝红黄三个分块。

图c和图b近似:图c就是self-attention的情况:qk相同:所以都是正方形的。

图d表示的就是分块的过程:为了怕点都分到同一个桶中:强制按照2格来分成了三个块。

既然是Hash:当然就有可能:hash分错的情况:相当于是漏网之鱼。论文提出多轮hash就可以解决此问题。用公式表达出来就是: Pi表示i需要算attention的下标j的集合:h(qi)表示qi的hash值。经过Nrounds轮并集就近似算全了。作者实验发现8轮hash:就能和原始的attention的结果相似:如下图16轮hash和8轮的接近。

速度评测如下图可以看出:文本总长度一定:随着batch里面的文本长度增长:LSH时间是平的:而原生的是明显增大。这个图中可以看出文本小于1024长度的时候:reformer的时间复杂度效果并不明显。这是因为reformer是nr * c * L, 原生是L* L: nr : 8, c : 128是:nr * c * L : L * L 的。

以上的hash优化性能的原理:是建立在Q 和K比较相似的基础之上,如果QK完全不相同:分不到一个桶内:那不就变成了没有attention了。因此:论文中提出Q K共享参数:那么QK就会非常接近:V的参数是单独的。实验发现在enwik8, imagenet64中效果和原生transformer差不多。

二. 可逆网络

由于训练时:误差反向传播时:需要保存每一层的输入输出:所以内存需要nl的倍数:一般层数比较大:gpu内存就会容纳不下。论文引入可逆网络。只用保存最后一层:当反向传播是:直接根据后面一层反推出前面一层的输出即可。可逆网络公式如下

如上公式:我们可以看出:如果知道输入X(直接分块成X1, X2),可以直接求出Y1 :attention的输出:, Y2:feed forward的输出:, 反过来:如果知道Y1, Y2也能直接反推出X1 , X2, 因为X2 : Y2 – FeedForwad(Y1):知道了X2有可以直接算出X1: X1: Y1 – Attention(X2)。所以训练的时候只有保存最后一层的输出Y:就可以轻易的求出输入X啦:也就是前一层的输出。这样就把空间复杂度从多层变成了单层。如下图中:空间复杂度中nl去掉了。

实验中:可以发现可逆网络和一般的transformer多个step之后非常接近。

三. 可逆网络分块

由于feed_forward层的维度远大于模型的维度。dff一般是4K。但是feed forward层又是和位置无关的:所以:可以分成c个块:每个块内单独计算。

如下图所示:这样复杂度就从b * L * dff降成了 b * L * dmodel。

四. 总结一下:Reformer确实在内存、性能优化方面明显改进:提出或借鉴的方法比较有意思。

如有不对地方:欢迎指出:谢谢。

参考链接:

论文:https://arxiv.org/abs/2001.04451

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐