Efficient Transformer,泛指一切致力于降低Transformer的二次复杂度的工作,开始特指针对Attention的改进,后来更一般的思路,如傅里叶变换、线性RNN等,也被归入这个范畴。不得不说,为了降低Transformer的二次复杂度,各路大牛可谓是“八仙过海,各显神通”,各种神奇的思路“百花齐放”,笔者也从中学习到了不少理论知识。然而,尽管Efficient Transformer在理论上是精彩的,但实际上该领域一直都是不愠不火的状态,并没有实际表现十分出色的模型,在LLM火爆的今天,甚至已经逐渐淡出了大家的视野,也淡出了笔者的兴趣范围。

不过,最近有一篇论文《Transformer-VQ: Linear-Time Transformers via Vector Quantization》,却让笔者为之拍案叫绝。作者非常高明地洞察到,只需要对标准Attention的Key做一下VQ(Vector Quantize),复杂度就会自动降低为线性!这种线性化思路保留了标准Attention的形式,是标准Attention到线性Attention的一个完美过渡,同时最大程度上保留了标准Attention的能力。

高效难题 #

说起来,本站也算是比较早关注Efficient Transformer相关工作了,最早可以追溯到2019年解读Sparse Transformer的一篇博客《为节约而生:从标准Attention到稀疏Attention》。此后,陆续写的关于Efficient Transformer的其他博文还有

《线性Attention的探索:Attention必须有个Softmax吗?》

《Performer:用随机投影将Attention的复杂度线性化》

《Nyströmformer:基于矩阵分解的线性化Attention方案》

《Transformer升级之路:3、从Performer到线性Attention》

《线性Transformer应该不是你要等的那个模型》

《FLASH:可能是近来最有意思的高效Transformer设计》

《Google新作试图“复活”RNN:RNN能否再次辉煌?》

然而,正如本文开头所说,尽管Efficient Transformer已有不少工作,也曾被大家寄予厚望,但实际上该领域一直都没什么能“出圈”的作品,这其中的原因可能是:

1、不少Efficient Transformer的提速以牺牲效果为代价;

2、很多Efficient Transformer的复杂度降低仅仅是理论上的,实际使用提升不明显;

3、有些Efficient Transformer难以用来训练Causal LM,所以在LLM流行的今天就没有了用武之地;

4、Flash Attention的出现表明即便是标准的Transformer仍有很大的提速空间。

VQ一下 #

那么,Transformer-VQ为何又具备的“出圈”潜力?

简单来说,Transformer-VQ就是对Attention的Key向量序列进行了“聚类”,并用所属类的类别中心近似原向量,然后Attention的复杂度就变成线性了。也就是说,Transformer-VQ仅仅改变了Key的形似,其余部分(理论上)完全不变,所以这是一种对Attention改动非常小的线性化方案,也能非常清楚体现出线性化后损失的精度在哪里(即用类别中心近似原向量的差距)。

铺垫得有点多了,现在我们正式介绍Transformer-VQ。首先,我们假设Q,KRn×dk,VRn×dv,标准Attention就是
softmax(QK)V


简单起见,这里省略了scale factor。Transformer-VQ改为
softmax(QˆK)V,ˆK=VQ(K,C)

其中CRc×dk是训练参数,也是VQ的编码表(Codebook)。对了,这里的“VQ”就是指VQ-VAE中的VQ,不了解的读者可以移步参考《VQ-VAE的简明介绍:量子化自编码器》《简单得令人尴尬的FSQ:“四舍五入”超越了VQ-VAE》,这里不重复介绍了。总之,经过VQ之后,最直接的表现就是K的每个向量都变成了C中与之最相近的那个,这意味着ˆK的每个向量都是C的向量之一,用数学的语言就是说KRn×dk变成了ˆKCn

Encoder #

当然,直接按照式(2)去实现Transformer-VQ的话,复杂度还是二次的,但由于ˆK的每个向量都是C的向量之一,所以我们可以先算exp(QC),然后从中“挑出exp(QˆK)对应的结果,而由于C的大小是固定的,所以关键运算QC的复杂度是线性的,这就是Transformer-VQ能线性化的原理(我们不妨称为“挑出”技巧)。

作为铺垫,我们先考虑双向注意力的Encoder情形。由于
softmax(QK)V=exp(QK)Vexp(QK)1n×1


这里1n×1指的是n×1大小的全1矩阵,分母可以视为分子的一个特殊形式,所以我们只需要考虑分子exp(QK)V。由于ˆK的每个向量都是C中之一,所以我们可以构建一个one hot矩阵Δ{0,1}n×c,其中Δi{0,1}c是一个one hot向量,如果1所在的维度为j,那么ˆKi=Cj,于是ˆK=ΔC

于是对于Transformer-VQ来说有
exp(QˆK)V=exp(QCΔ)V=exp(QC)ΔV=exp(QC)(ΔV)


很明显,这里最关键的地方就是第二个等号!对于one hot矩阵Δ,右乘以它的转置可以从exp中分离出来,这就是原理中的“挑出”技巧的数学表述。分离出来之后,由于矩阵乘法结合律,Δ可以先跟V相乘,得到一个c×dv的矩阵,而exp(QC)是一个n×c的矩阵,乘以ΔV就得到一个n×dv的矩阵,总的理论复杂度是O(ncdk+ncdv+ncdv)=O(n)

最后,根据式(3),将exp(QˆK)V的结果代入去,就可以计算完整的Attention结果(可能还要加一些避免溢出的细节),整个过程可以在线性复杂度内完成。

Decoder #

现在我们来考虑单向注意力的Decoder,这是训练生成模型的关键,也是当前LLM的基础。有了Encoder的铺垫后,Decoder理解起来也就没那么困难了。假设Qi,ˆKjR1×dk,VjR1×dv是向量序列Q,ˆK,V的行向量之一,那么对于Decoder的分子有
Oi=jiexp(QiˆKj)Vj=jiexp(QiCΔj)Vj=jiexp(QiC)ΔjVj=exp(QiC)jiΔjVj


如果c×dv不大,那么最后的式子可以直接用cumsum算子完成,不过一般情况下,尤其是Multi-Heaad时,为了节省显存,通常是跟《线性Attention的探索:Attention必须有个Softmax吗?》中的“自回归生成”一节一样,转为RNN来递归计算,即设Ui=jiΔjVjRc×dv,那么
Oi=exp(QiC)Ui,Ui=Ui1+ΔiVi

在推理阶段这样step by step递归计算自然是没问题,但训练阶段step by step的话可能会比较慢,我们可以改为block by block来加速:不失一般性,设n=lml代表block_size,m代表block数目,block切片[il:(i+1)l]简写为[i],那么
O[i]=exp(Q[i]ˆK[i]+M)V[i]+j<iexp(Q[i]ˆK[j])V[j]=exp(Q[i]ˆK[i]+M)V[i]+j<iexp(Q[i]CΔ[j])V[j]=exp(Q[i]ˆK[i]+M)V[i]+exp(Q[i]C)j<iΔ[j]V[j]

其中M{,0}l×l是下三角的Attention Mask,即当ijMi,j=0,否则Mi,j=。于是记Ui=j<iΔ[j]V[j]后,我们有
O[i]=exp(Q[i]ˆK[i]+M)V[i]+exp(Q[i]C)Ui1,Ui=Ui1+Δ[i]V[i]

这样我们就将递归步数减少为m了,可以在保证线性效率的同时,更充分发挥硬件的并行能力。用同样的方式也可以计算分母,最后相除得到完整的Attention结果

局域增强 #

就这样完了?并不是,如果仅仅是这样的话,Transformer-VQ可能跟以往基于矩阵分解的Kernelized Attention如Performer并没有太多区别。当序列长度n远大于编码表大小c时,由抽屉原理我们知道部分编码向量必然会反复出现,甚至可以合理猜测所有编码向量应该会均匀分布在整个序列中。这样一来,邻近token的Attention就会跟远处某些token的Attention一样,也就是说模型无法区分远近,这本质上就是所有Kernelized Attention都存在的低秩问题。

已有的经验告诉我们,对于语言模型来说,相对于远处的token的来说邻近的token往往更为重要,所以一个好的语言模型架构应该具有区分远近的能力。为此,Transformer-VQ选择在QˆK之后,加上一个Sliding Window形状的Attention Bias(记为B),来对邻近token进行加权,如下图:

Window Attention Bias示意图

Window Attention Bias示意图

从最后一个图可以看出,如果将Window大小直接设为block大小l,即i<j或者ijlBi,j=0,那么在分block计算时,矩阵B顶多影响最邻近的两个block,再远的block依旧可以用“挑出”技巧来线性化。为了便于下面的推导,我们记B[i,j]=B[il:(i+1)l,jl:(j+1)l],那么
O[i]=exp(Q[i]ˆK[i]+B[i,i])V[i]+exp(Q[i]ˆK[i1]+B[i,i1])V[i1]+j<i1exp(Q[i]ˆK[j])V[j]=exp(Q[i]ˆK[i]+B[i,i])V[i]+exp(Q[i]ˆK[i1]+B[i,i1])V[i1]+j<i1exp(Q[i]CΔ[j])V[j]=exp(Q[i]ˆK[i]+B[i,i])V[i]+exp(Q[i]ˆK[i1]+B[i,i1])V[i1]+exp(Q[i]C)j<i1Δ[j]V[j]


所以很明显,有(约定V[1],U[1],U[2]都是全零矩阵)
O[i]=exp(Q[i]ˆK[i]+B[i,i])V[i]+exp(Q[i]ˆK[i1]+B[i,i1])V[i1]+exp(Q[i]C)Ui2Ui=Ui1+Δ[i]V[i]

笔者认为,B的引入是Transformer-VQ是跟其他Kernelized Attention拉开差距的关键,为了减少参数量且支持变长生成,我们约束B的非零部分为“Toeplitz矩阵”,即Bi,jij的函数,此时B就相当于加性相对位置编码。除了这种做法外,也可以考虑换为笔者之前提出的ReRoPE,它是旋转位置编码的窗口版,跟B具有同样的相对位置编码形状。

梯度回传 #

等等,我们好像忘记了点什么。了解VQ-VAE的读者都知道,“ˆK的每个向量都是C的向量之一”只是前向传播的表现,反向传播用的可是原始的K,这意味着即便不同位置的ˆKj等于同一个Ck,但它们的梯度却不相等,这叫做STE(Straight-Through Estimator)。由于STE的存在,“挑出”技巧理论上仅可用于推理阶段,训练阶段是无法线性化的。

没有其他办法了吗?确实如此,如果我们坚持要获得精确的梯度结果,那么并没有线性化效率的方案。然而,考虑到VQ的梯度本身就是近似的,所以Attention获取精确的梯度似乎也没多大必要。于是作者想了个折衷的方案:依然是按照式(10)进行递归计算,仅在前两项使用STE(Key序列可以获得梯度),而Ui1的梯度直接停掉(stop_gradient算子)。这样我们就保持了模型的线性性,同时也已经保留了最重要的梯度(邻近的两个block),算是一个比较合理的近似方案。从这一点来看,Transformer-VQ跟Transformer-XL很像,Transformer-XL在递归的同时也停掉了历史窗口的梯度,即历史窗口可以参与递归计算,不传递梯度。

解决了梯度回传问题之后,在自回归交叉熵损失的基础上,再上VQ带来的用来更新编码表的辅助loss,就得到完整的训练目标了。当然,对于编码表的更新,Transformer-VQ采用了直接滑动平均的方案,所以只补充了Key的辅助loss,这些细节读者在熟悉VQ-VAE之后,稍微看一下原论文就理解了。

实验结果 #

这一节我们来看一下原论文的实验结果。作者已经将代码开源如下:

值得指出的是,作者做VQ的基础架构并不是常规的MHA(Multi-Head Attention),而是笔者一直很推崇的GAU(Gated Attention Unit)+Softmax,Transformer-VQ更准确的命名应该是“GAU-VQ”,不了解GAU的读者可以参考《FLASH:可能是近来最有意思的高效Transformer设计》《听说Attention与Softmax更配哦~》。简单来说,GAU本身比MHA有着更高的效率,配合上VQ技巧后,就更加“如虎添翼”了。

实验方面,作者做了语言模型(ENWIK8、PG-19)和图像生成(IMAGENET64),所有的实验中的编码表大小都是c=512。模型最大参数量为1.3B,虽然比不上主流的大模型参数量,但其实对于科研来说不算小了。实验结果总体来说算得上优异:

PG-19的实验结果

PG-19的实验结果

IMAGENET64的实验结果

IMAGENET64的实验结果

最后,让人惊奇的是,Transformer-VQ的作者只有一个,并且身份是“Independent Researcher”。

发散思考 #

笔者发现,从Transformer-VQ出发,可以联系到非常多的研究主题,这也是为什么笔者如此欣赏它的原因之一。

首先,再次为作者惊人的洞察力点赞,“只需VQ一下Key,Transformer的复杂度就会变成线性”这个发现实在太美妙了,它实现了标准Attention到线性Attention的自然过渡,并且可以通过加Attention Bias的方式让它比很多的Kernelized Attention都有效。然后,通过VQ进行“聚类”的方式,也比LinformerNyströmformer等更为高明,因为它防止了未来信息的泄漏,可以自然地用来做Causal的语言模型。

我们知道,VQ本质上也是将序列转为离散id的运算,这跟Tokenizer的作用是非常相似的。从这个角度来看,Transformer-VQ跟MegaByte等模型一样,都是将Tokenizer内置在模型之中,并且相比MegaByte,VQ这一操作跟我们传统意义上的Tokenizer更为相似、直观。所以,Transformer-VQ实际上非常适合用来训练直接以Bytes输入的“No Tokenizer”模型,事实上,上述ENWIK8实验就是Bytes输入,Transformer-VQ效果明显优于MegaByte。

相比近来出的RetNet,Transformer-VQ没有显式的远程衰减,所以Long Context能力有可能会更好,同时由于Key经过了VQ,都是有限集合之一,所以不会出现没有学过的Key,因此长度外推能力大概率也会更好。虽然Transformer-VQ的基础架构GAU只是Single-Head的,但它在递归过程中模型记忆状态大小是ΔiViRc×dv,在默认的设置中,这比Multi-Head的RetNet还大(RetNet的记忆状态大小是nd2k,默认设置下dv=2ndk),因此,记忆容量理论上是足够的。

由于上一篇文章刚好写了《简单得令人尴尬的FSQ:“四舍五入”超越了VQ-VAE》,可能会有读者想知道可否用更简单的FSQ取代VQ?笔者认为比较难,原因其实在上一篇文章给出了:第一,c=512还属于VQ优于FSQ的编码数量范围,所以换FSQ大概率会掉效果;第二,由于每层Attention的Key都要被VQ,所以平均来说VQ的Encoder和Decoder都不强,这种情况VQ近似精度更高,FSQ更适合Decoder和Decoder都足够强的场景;第三,Transformer-VQ需要用的是Key被VQ之后的中心向量而不是id,而FSQ则直接得到id,反而不容易恢复为近似的中心向量。

除此之外,用VQ而不是FSQ,使得Transformer-VQ有希望从现有的预训练模型如LLAMA2中微调过来,而不单单是从零训练。因为VQ具有鲜明的几何意义,跟K-Means有诸多相通之处,我们可以从现有预训练模型出发,选取一些样本计算出Key,对Key进行K-Means得到中心向量作为编码表的初始化,然后在原模型基础上加上VQ进行微调。不过Transformer-VQ不大好适配RoPE,所以要如前面所说,RoPE的模型要换成ReRoPE再VQ比较好,此时就可以不用加Bias了。

总之,在笔者眼中,Transformer-VQ在众多Efficient Transformer工作中,是非常独特、出色而又潜力深厚的之一。

文章小结 #

本文介绍了一个名为Transformer-VQ的Efficient Transformer方案,它基于“只需VQ一下Key,Transformer的复杂度就会变成线性”的观察结果进行展开,个人认为是一种非常独特且亮眼的线性化思路,实验结果也很优异。它既可以理解为一种更高明的线性Attention/RNN模型,也可以理解为一个带有“可训练的Tokenizer”的Attention模型。

转载到请包括本文地址:https://spaces.ac.cn/archives/9844

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Nov. 09, 2023). 《VQ一下Key,Transformer的复杂度就变成线性了 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9844

@online{kexuefm-9844,
        title={VQ一下Key,Transformer的复杂度就变成线性了},
        author={苏剑林},
        year={2023},
        month={Nov},
        url={\url{https://spaces.ac.cn/archives/9844}},
}