前面我们用两篇文章《重温SSM(一):线性系统和HiPPO矩阵》《重温SSM(二):HiPPO的一些遗留问题》介绍了HiPPO的思想和推导——通过正交函数基对持续更新的函数进行实时逼近,其拟合系数的动力学正好可以表示为一个线性ODE系统,并且对于特定的基底以及逼近方式,我们可以将线性系统的关键矩阵精确地算出来。此外,我们还讨论了HiPPO的离散化和相关性质等问题,这些内容奠定了后续的SSM工作的理论基础。

接下来,我们将介绍HiPPO的后续应用篇《Efficiently Modeling Long Sequences with Structured State Spaces》(简称S4),它利用HiPPO的推导结果作为序列建模的基本工具,并从新的视角探讨了高效的计算和训练方式,最后在不少长序列建模任务上验证了它的有效性,可谓SSM乃至RNN复兴的代表作之一。

基本框架 #

S4使用的序列建模框架,是如下的线性ODE系统:
x(t)=Ax(t)+Bu(t)y(t)=Cx(t)+Du(t)


这里u,y,DR;xRd;ARd×d;B,CRd×1是转置共轭运算,如果是实矩阵的话,那就是单纯的转置。由于完整的模型通常还会带有残差结构,最后一项Du(t)可以整合到残差里边,所以我们可以直接假设D=0来稍微简化一下形式,但不会降低模型的能力。

该系统具备相似不变性,如果˜AA的相似矩阵,即A=P1˜AP,那么代入整理得
Px(t)=˜APx(t)+PBu(t)y(t)=((P1)C)Px(t)


Px(t)视为一个整体替换原来的x(t),那么新系统的变化是(A,B,C)(˜A,PB,(P1)C),但输出完全不改变。这意味着如果存在A的某个相似矩阵˜A使得计算更加简单,那么可以完全转到˜A中分析而不改变结果,这就是后面一系列分析的核心思路。

特别地,S4将矩阵A选取为HiPPO-LegS矩阵,即
An,k={(2n+1)(2k+1),k<nn+1,k=n0,k>n


这个选择的特别之处在于,我们此前推导LegS所满足的ODE是x(t)=Atx(t)+Btu(t)的形式,而LegT的ODE才是x(t)=Ax(t)+Bu(t)的形式,所以现在就是说LegT的ODE搭配了LegS的A矩阵,因此首先要问的问题是:这样的组合会带来什么影响呢?比如它对历史的记忆是否跟LegS一样依然是完整的、平权的?

指数衰减 #

答案是否定的——S4所选取的ODE系统,关于历史的记忆是指数衰减的,我们可以从两个角度理解这一点。

第一个角度是从《重温SSM(二):HiPPO的一些遗留问题》讨论过的变换出发,将LegS型ODE可以等价地写成:
Ax(t)+Bu(t)=tx(t)=ddlntx(t)


所以设τ=lnt就可以将LegS型ODE变成时间变量为τ的LegT型ODE,也就是S4所用的ODE。我们知道,LegS会平等对待每一处历史,但这前提是输入为u(t)=u(eτ),但S4的ODE相当于输入直接改为u(τ),此时对τ做均匀离散化的话,结果就是每一处的权重不相等——假设t[0,T],用概率密度的写法就是dt/T=ρ(τ)dτ,即ρ(τ)=eτ/T,即权重是τ的指数函数,越新的历史权重越大。

第二个角度则需要多一点线性代数知识。同样在《重温SSM(二):HiPPO的一些遗留问题》我们说过HiPPO-LegS的矩阵A理论上是可以对角化的,并且其特征值为[1,2,3,],于是存在可逆矩阵P使得A=P1ΛP,其中Λ=diag(1,2,,d),根据相似不变性,原系统等价于新系统
x(t)=Λx(t)+PBu(t)y(t)=CP1x(t)


离散化后(以前向欧拉为例):
x(t+ϵ)=(I+ϵΛ)Px(t)+ϵPBu(t)

这里的I+ϵΛ是每个分量都小于1的对角线矩阵,也就意味着每迭代一步,就将历史信息乘以一个小于1的数,多步叠加后,就呈现出指数衰减的效应。

离散格式 #

虽然指数衰减看上去没有LegS平等对待每一处历史那么优雅,但实际上没有免费的午餐,对于固定大小的记忆状态x(t),在记忆区间越来越大时,LegS平等对待每一处历史的做法反而会导致每一处历史都比较模糊,对于符合“近大远小”的场景反而得不偿失。此外,S4型ODE右端没有显式地出现时间t,这也有助于提供训练效率。

对S4型ODE的记忆性质心中有数之后,我们就可以着手下一步操作了。为了处理实际中的离散序列,我们首先要进行离散化,在上一篇文章中,我们给出了两种精度较高的离散格式,一种是双线性形式
xk+1=(IϵA/2)1[(I+ϵA/2)xk+ϵBuk]


它具有二阶的精度,S4采用的就是这个离散化格式,也是本文接下来所探讨的格式。另一种是基于精确求解常输入的ODE,得到
xk+1=eϵAxk+A1(eϵAI)Buk

作者后面的作品包括Mamba都是用这个格式,此时一般都要假设A为对角矩阵,因为对于LegS的矩阵A,矩阵指数算起来并不友好。

现在我们记:
ˉA=(IϵA/2)1(I+ϵA/2),ˉB=ϵ(IϵA/2)1B,ˉC=C


那么就得到线性RNN:
xk+1=ˉAxk+ˉBukyk+1=ˉCxk+1

其中ϵ>0是离散化步长,是人为选择的超参数。

卷积运算 #

在上一篇文章中,我们还提到了HiPPO-LegS的矩阵A具备计算高效的特点,具体表现为AˉA跟向量x相乘,存在计算复杂度为O(d)而不是一般的O(d2)的高效算法,但这仅仅意味着式(10)递归计算时比一般的RNN高效,而如果想要进行高效训练的话,单纯递归是不够的,需要探究并行计算方法。

线性RNN的并行计算有两种思路:一种是在《Google新作试图“复活”RNN:RNN能否再次辉煌?》介绍过的视为Prefix Sum问题,直接用Upper/Lower、Odd/Even、Ladner-Fischer等Associative Scan算法进行计算,论文可参考《Prefix Sums and Their Applications》;另一种是转化为矩阵序列和向量序列的卷积运算,利用快速傅里叶变换(FFT)来加速,这是S4的思路。但不管哪一种,它们面临共同的瓶颈:幂矩阵ˉAk的计算。

具体来说,我们一般会设初始状态x0为0,那么就可以写出:
y1=ˉCˉBu0y2=ˉC(ˉAx0+ˉBu1)=ˉCˉAˉBu0+ˉCˉBu1y3=ˉC(ˉAx1+ˉBu2)=ˉCˉA2Bu0+ˉCˉABu1+ˉCˉBu2yL=ˉC(ˉAxL1+ˉBuL1)=L1k=0ˉCˉAkˉBuLk=ˉK<Lu<L


其中代表卷积运算,而
ˉKk=ˉCˉAkˉB,ˉK<L=(ˉK0,ˉK1,,ˉKL1),u<L=(u0,u1,,uL1)

注意根据目前的约定,ˉCˉAkˉBuk都是标量,所以有ˉK<L,u<LRL。我们知道,卷积运算可以通过(离散)傅立叶变换转换为频域的乘法运算,然后再逆变换回来,它的复杂度为O(LlogL)L是序列长度。虽然复杂度看上去比直接递归的O(L)要大,但是傅立叶变换是可以并行的,所以实际上计算速度要更快。

所以,现在问题是如何高效地计算卷积核ˉK<L,它需要计算幂矩阵ˉAk,按定义计算的话复杂度还是相当大的。当然,如果只是计算ˉAk那倒不是什么问题,因为A是一个常数矩阵,给定ϵˉA也是常数矩阵,不管它的幂多难算,都可以提前算好存起来。然而,ˉAk只是中间步骤,我们还要算ˉCˉAkˉB,而S4将ˉC,ˉB视为训练参数,所以ˉCˉAkˉB没法提前算好,就是提前算好ˉAk效率还是不大够。

生成函数 #

在进一步分析之前,我们先来插入一个生成函数的概念,这是后面的高效计算的基础步骤之一。同时,对于不大了解卷积运算和离散傅立叶变换的读者,这也可以作为一个科普步骤,从中我们可以大致了解到傅立叶变换加速卷积运算的基本原理。

对于给定序列a=(a0,a1,a2,),它的生成函数就是将每个分量当成幂级数的系数来构建幂级数:
G(z|a)=k=0akzk


如果有两个序列a=(a0,a1,a2,)b=(b0,b1,b2,),那么它们生成函数的乘积:
G(z|a)G(z|b)=(k=0akzk)(l=0blzl)=k=0l=0akblzk+l=l=0(lk=0akblk)zl

留意到了没有?G(z|a)G(z|b)的第l项系数(即zl1的系数),正好是a<l=(a0,,al1)b<l=(b0,,bl1)的卷积运算。如果我们有快速计算生成函数以及快速提取生成函数某一项系数的方法,那么就可以将卷积运算转换为生成函数,做普通乘法之后然后再提取相应的系数。

离散傅立叶变换(Discrete Fourier Transform,DFT)正是这样的一种构建生成函数的思路。首先注意到,如果我们只需要对a,b的不超过前L项做卷积运算,那么生成函数的求和上限不一定非得到正无穷,求和上限改为L1也是可以的。针对这种需求,DFT没有对所有z来计算生成函数,而是选取了特定的z=e2iπl/L,l=0,1,2,,L1进行计算:
ˆal=L1k=0ak(e2iπl/L)k=L1k=0ake2iπkl/L


提取系数的逆变换(Inverse DFT,IDFT)则是
ak=1LL1l=0ˆale2iπkl/L

DFT和IDFT我们都可以通过快速傅里叶变换(Fast Fourier Transform,FFT)进行高效计算,大部分数值计算框架都已内置了相应函数,所以DFT和IDFT的计算在效率上没有问题。但要注意,如果用DFT来计算卷积的话,需要稍微微调一下,因为e2iπl/L是周期函数,我们没法区分e2iπl/Le2iπ(l+L)/L,而当我们将两个L项求和的DFT相乘时,结果会出现lLe2iπkl/L项,它会跟e2iπk(lL)/L项混合,从而做IDFT时实则得到的是两项的系数相加,这样作为卷积结果来说是不正确的。

解决这个问题的方法是将e2iπl/LL改为2L(但求和还是L项求和),也就是增大它的周期,使得乘积结果都是单个周期内,即将DFT的定义改为
ˆal=L1k=0akeiπkl/L


不过现成的FFT函数基本上都不支持单独调整周期,而是默认周期就是数组长度,所以等价的处理方式是在(a0,a1,,aL1)后面拼接L个零再做常规的DFT,得到乘积后做IDFT,最后只取前L个结果。

从幂到逆 #

对于卷积核ˉK,我们有
G(z|ˉK)=k=0ˉCˉAkˉBzk=ˉC(IzˉA)1ˉB


可以发现,生成函数不仅可以加速卷积的计算,它还将原本的幂矩阵ˉAk的计算转化为逆矩阵(IzˉA)1的计算。

什么样的矩阵ˉA,它对应的(IzˉA)1比较容易计算呢?首先对角阵肯定没问题,如果ˉA是对角阵,那么IzˉA也是对角阵,对角阵的逆直接将对角线元素都取逆即可。其次,如果ˉA可以对角化为ˉΛ,即ˉA=P1ˉΛP,那么(IzˉA)1同样容易计算,因为
(IzˉA)1=(P1(IzˉΛ)P)1=P1(IzˉΛ)1P

ˉA能不能对角化呢?这取决于A能不能对角化。如果A=P1ΛP,根据相似不变性,我们可以完全转到A=Λ的新系统去计算,而根据定义新的ˉA为:
ˉA=(IϵA/2)1(I+ϵA/2)=(IϵΛ/2)1(I+ϵΛ/2)


显然是一个对角阵。

那么A可以对角化吗?答案是理论上可以,实际上不行。理论上可以,是因为从理论上来说,几乎所有矩阵在复数域内都可以对角化,并且在上一篇文章已经给出了LegS的A特征值为[1,2,3,],也就是连对角化后的对角矩阵我们都知道长什么样了。实际上不行,是指对数值计算来说很难,因为数值计算要考虑精度、内存、时间等,只要三者之一超出了限度或容忍度,那么理论可行的算法在实际中就不成立。

对于A矩阵,实际上不行的主要原因是对角化A所需要的矩阵P存在数值不稳定问题,说白了也是计算机精度有限导致的。对于这一点,原论文直接不加解释地给出了矩阵P的解析解,然后进行验证,这显然不利于读者理解。下面笔者从特征向量计算的角度,给出另一个理解思路。

特征向量 #

A的对角化等价于A的对角化,因为A的特征值全是负数,所以简单起见我们转而考虑A的对角化,它有d个不同的特征值λ=1,2,,d,对角化它所需的矩阵就是其特征向量的堆叠,所以求P本质上是求特征向量。而对于已知特征值的矩阵,求解特征向量的直接方法是求解方程Av=λv

上一篇文章中“计算高效”那一节,我们已经给出了Av的第n个分量的计算结果:
(Av)n=nvn2n+1nk=02k+1vk


所以Av=λv意味着
2n+1nk=02k+1vknvn=λvn

Sn=nk=02k+1vk,那么2n+1vn=SnSn1,稍加整理得
Sn1=λn1λ+nSn

注意Av=λv是一个不定方程,我们有一些灵活调整的自由度(即特征向量不是唯一的),由于n最大是d1,我们可以设Sd1=1,然后递归地往回推,直到λn1=0得到Sλ1=0,此后n<λ1都有Sn=0,而对于n>λ1,则有
Sn=(1)dn1(dλ)!(n+λ)!(d+λ1)!(nλ+1)!

由于我们是想要证明P的数值不稳定性,那么观察一个特征向量即可,我们取n=λ=d/3(如果d不是3的倍数,简单取个整即可,结论不变),那么
|Sd/3|=(2d3)!(2d3)!(4d31)!O(d24d/3)

最后的可以由String公式得到。由该结果我们可以看到,对于d/3这个特征值,从Sd1Sd/3存在一个指数级别的衰减过程(反之则爆炸),那么特征向量的分量vd1vd/3也存在类似的衰减,在浮点数的有限精度内,是很难精确处理这样的特征向量的。所以,直接对角化A的矩阵P存在数值上的不稳定性。

对角低秩 #

除了对角阵外,当ˉA可以低秩分解时,同样可以降低(IzˉA)1的计算难度。这是因为我们有如下的Woodbury恒等式:
(IUV)1=k=0(UV)k=I+U(k=0(VU)k)V=I+U(IVU)1V


这里U,VRd×r,推导过程利用了(UV)k=U(VU)k1V。如果dr,那么理论上(IVU)1的计算量就比(IUV)1少得多,因此可以加速计算。特别地,如果r=1,那么(IVU)1就是一个标量的倒数,计算起来最简单。

然而,我们知道A是一个下三角阵,且对角线元素没有一个是零,那么它就一定是满秩矩阵。再结合上一节的结论,也就是说A即不低秩,对角化又存在实践上的困难,所以这都不适用,还有什么办法呢?有!利用上面的Woodbury恒等式,我们可以推出它更一般的版本:
(MUV)1=(M(I(M1U)V))1=(I(M1U)V)1M1=(I+M1U(IVM1U)1V)M1=M1+M1U(IVM1U)1VM1


这个结果告诉我们,如果M的逆比较容易算,那么它加/减一个低秩矩阵的逆也容易算。那什么样的矩阵逆比较容易算呢?又回到上一节的答案——对角矩阵。所以,我们可以想办法将A或者ˉA往“对角+低秩”的形式上凑。

事实上,仔细观察就会发现,A矩阵本身就有“对角+低秩”的影子。在上一篇文章中,我们将A的定义等价地改写为:
An,k={nδn,k2n+12k+1,kn0,k>n


其中nδn,k实质就是对角矩阵diag(0,1,2,),而2n+12k+1则可以重写为低秩矩阵形式vv,其中v=[1,3,5,]Rd×1,换句话说,如果没有k>n,An,k=0的规定,那么A本身就是对角矩阵减去低秩矩阵的形式了。

点睛之笔 #

虽然有了下三角阵的约束后,这个规律就不再适用了,但我们可以充分利用原本就有的vv结构,来辅助构建新的可对角化矩阵。但不得不说,这个技巧相当机智,堪称点睛之笔,让人惊叹,再次为原作者点赞。具体来说,我们考虑A+12vv
(A+12vv)n,k={nδn,k122n+12k+1,kn122n+12k+1,k>n


这个新矩阵的对角线元素正好是12I,我们再加上12I,就得到
(A+12vv+12I)n,k={122n+12k+1,k<n0,k=n122n+12k+1,k>n

重点来了,可以看到这是一个反对称矩阵,所以它一定可以(在复数域中)对角化!于是我们就将A分解为了可对角化矩阵与低秩矩阵之和!可能有读者质疑,原本A就一定是可对角化矩阵,但还是有数值稳定性问题,难道这个反对称矩阵的对角化不用担心数值稳定性问题吗?重点的重点来了,反对称矩阵不单单一定可以对角化,它一定可以被正交矩阵(复数域叫做酉矩阵)对角化!酉矩阵一般数值稳定性都非常好,所以不用担心这个问题,这也就是为什么我们不直接对角化A,而绕一圈来构建反对称矩阵的原因。

现在我们得到,存在对角矩阵Λ和酉矩阵U,使得A+12vv+12I=UΛU,从而
A=UΛU12I12vv=U(Λ12I12(Uv)(Uv))U


抛开脚手架,我们发现最终的结论可以简化为“A同构于对角阵减去秩1矩阵”:存在酉矩阵U、对角矩阵Λ、列向量u,v,使得:
A=U(Λuv)U

注意“对角+低秩”的矩阵乘以向量是计算高效的,比如
(Λuv)x=Λxu(vx)

Λx相当于将Λ当成向量与x逐位相乘,而u(vx)则是vx先做内积,然后得到一个标量乘以向量u,这些都可以在O(d)内完成。

最后冲刺 #

有了A=U(Λuv)U,再次根据相似不变性,我们接下来的所有计算都可以转到A=Λuv中进行,所以下面均设A=Λuv。首先,对于ˉA
ˉA=(Iϵ(Λuv)/2)1(I+ϵ(Λuv)/2)


留意到Iϵ(Λuv)/2=ϵ2(D+uv),其中D=2ϵIΛ是对角阵,于是利用Woodbury恒等式得到:
(Iϵ(Λuv)/2)1=2ϵ(D+uv)1=2ϵ[D1D1u(I+vD1u)1vD1]

仔细观察,这同样是“对角+低秩”的形式,再乘以(I+ϵ(Λuv)/2)后就能完成ˉA的计算,最终结果是两个“对角+低秩”矩阵的相乘,意味着它同样具有计算高效的特点,这个结果可以在递归推理中用到。

最后是并行训练所需要的卷积核,我们已经将它转化为生成函数(18),现在我们就可以来完成它的计算了。首先通过类似“通分”的操作可以证明:
G(z|ˉK)=ˉC(IˉAz)1ˉB=ˉC(I(IϵA/2)1(I+ϵA/2)z)1ˉB=ˉC[(IϵA/2)1((IϵA/2)(I+ϵA/2)z)]1ˉB=ˉC[(IϵA/2)(I+ϵA/2)z]1(IϵA/2)ˉB=ˉC[(IϵA/2)(I+ϵA/2)z]1Bϵ=ˉC[(1z)I(1+z)ϵA/2]1Bϵ=21+zˉC[2ϵ1z1+zIA]1B


于是代入A=Λuv得到
G(z|ˉK)=21+zˉC[2ϵ1z1+zI(Λuv)]1B=21+zˉC(Rz+uv)1B

这里Rz=2ϵ1z1+zIΛ是个对角阵,于是再次利用Woodbury恒等式就可以完成计算:
G(z|ˉK)=21+zˉC[R1zR1zu(I+vR1zu)1vR1z]B

这是关于z的标量函数。不过要注意一个细节,傅立叶变换所需要的实际是“截断生成函数”:
GL(z|ˉK)=L1k=0ˉCˉAkˉBzk=ˉC(IzLˉAL)(IzˉA)1ˉB

也就相当于G(z|ˉK)ˉC要换成ˉC(IzLˉAL),这里L是提前选定的最大训练长度。接下来,我们只需要代入z=e2iπl/L,l=0,1,2,,L1进行计算,结果就是ˉK的DFT,然后IDFT就得到ˉK了,这个过程还可以转化为Cauchy核问题加速一下,但个人认为不是太核心,就不展开讨论了。最后的最后,还有一个技巧,就是对于z=e2iπl/LzL=1,此时只是相当于将ˉC要换成ˉC(IˉAL),而S4将ˉC当成训练参数,所以我们可以直接将ˉC(IˉAL)当成训练参数,事后再从中解出ˉC用于推理,这样训练时就可以避免计算ˉAL了。

这里看上去我们也可以代入z=eiπl/L直接计算卷积所用的ˉK的DFT,而不是迂回地先IDFT得到ˉK,然后拼接零再DFT,但问题是此时zL=(1)l是一个不定值,我们没法将ˉC(IzLˉAL)看成单个训练参数,这会导致在训练过程中需要计算ˉAL,计算量比较大(当然,如果训练过程中ˉA是完全固定的,那么可以提前算出来,视情况而定)。

草草收尾 #

经过一通艰难的“长篇大论”,我们总算把S4中比较关键的数学细节都捋了一遍,希望能够对有兴趣了解S4的读者有所帮助。可以看到,S4是对HiPPO的进一步补充和完成,它的关键一笔是提出了A等价于“对角+低秩”的矩阵形式,为剩余部分的分析奠定了基础。因为一开始A是分段定义的形式,而不是矩阵运算形式,这样的定义不利于应用现有的线性代数工具进行一般化分析。

由于HiPPO的推导是基于u(t)是一维函数进行的,所以到目前为止,S4的uk也都还是标量。那么S4怎么处理向量序列输入呢?非常暴力,它直接对每个分量独立地应用一遍前述线性RNN,每个RNN使用不同的ϵ,B,C参数,然后将结果拼接起来,这个做法直到作者最新的Mamaba依然还被应用。当然,也有简化的做法,直接在单个RNN中处理向量输入,只需要相应地将B,C改为矩阵就行,这就是S5(作者不是Albert Gu了),这种做法可以理解为单纯借用了S4的线性RNN形式以及HiPPO的矩阵A,而抛开了HiPPO的其他细枝末节,也取得了不错的效果。

让人啼笑皆非的是,S4提出了诸多精妙的数学技巧来简化和加速A的计算,结果从《Diagonal State Spaces are as Effective as Structured State Spaces》开始,原作者的后续工作包括Mamba基本上都抛弃了这部分内容,而是直接假设A为对角矩阵,这样RNN部分就跟《Google新作试图“复活”RNN:RNN能否再次辉煌?》介绍的LRU大同小异了。因此,从当前最新的SSM及线性RNN的角度看,S4及HiPPO系列工作某种意义上来说已经是“过时”了。很多讲解Mamba的文章从HiPPO、S4开始说起,从事后来说可谓是“大可不必”了。

当然,对于笔者来说,花那么长的篇幅去学习HiPPO和S4,并不是简单为了理解或使用最新的SSM和RNN模型,而是通过学习HiPPO背后的假设和推导,了解线性系统的记忆方式和瓶颈,为将来构建新模型、新方法积累更多的思路。此外,HiPPO和S4中诸多精妙的数学技巧也让人赏心悦目,并且也不失为提升数学能力的相当不错的练习题。

文章小结 #

本文介绍了HiPPO的后续之作S4,它的关键之处是提出了“对角矩阵+低秩矩阵”的分解,从而实现了HiPPO矩阵的高效并行计算,本文主要对其中比较困难的数学细节做了介绍和推导。

转载到请包括本文地址:https://spaces.ac.cn/archives/10162

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Jun. 20, 2024). 《重温SSM(三):HiPPO的高效计算(S4) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10162

@online{kexuefm-10162,
        title={重温SSM(三):HiPPO的高效计算(S4)},
        author={苏剑林},
        year={2024},
        month={Jun},
        url={\url{https://spaces.ac.cn/archives/10162}},
}