重温SSM(二):HiPPO的一些遗留问题
By 苏剑林 | 2024-06-05 | 26399位读者 |书接上文,在上一篇文章《重温SSM(一):线性系统和HiPPO矩阵》中,我们详细讨论了HiPPO逼近框架其HiPPO矩阵的推导,其原理是通过正交函数基来动态地逼近一个实时更新的函数,其投影系数的动力学正好是一个线性系统,而如果以正交多项式为基,那么线性系统的核心矩阵我们可以解析地求解出来,该矩阵就称为HiPPO矩阵。
当然,上一篇文章侧重于HiPPO矩阵的推导,并没有对它的性质做进一步分析,此外诸如“如何离散化以应用于实际数据”、“除了多项式基外其他基是否也可以解析求解”等问题也没有详细讨论到。接下来我们将补充探讨相关问题。
离散格式 #
假设读者已经阅读并理解上一篇文章的内容,那么这里我们就不再进行过多的铺垫。在上一篇文章中,我们推导出了两类线性ODE系统,分别是:
HiPPO-LegT:x′(t)=Ax(t)+Bu(t)HiPPO-LegS:x′(t)=Atx(t)+Btu(t)
其中A,B是与时间t无关的常数矩阵,HiPPO矩阵主要指矩阵A。在这一节中,我们讨论这两个ODE的离散化。
输入转换 #
在实际场景中,输入的数据点是离散的序列u0,u1,u2,⋯,uk,⋯,比如流式输入的音频信号、文本向量等,我们希望用如上的ODE系统来实时记忆这些离散点。为此,我们先定义
u(t)=uk,如果t∈[kϵ,(k+1)ϵ)
其中ϵ就是离散化的步长。该定义也就是说在区间[kϵ,(k+1)ϵ)内,u(t)是一个常数函数,其值等于uk。很明显这样定义出来的u(t)无损原本uk序列的信息,因此记忆u(t)就相当于记忆uk序列。
从uk变换到u(t),可以使得输入信号重新变回连续区间上的函数,方便后面进行积分等运算,此外在离散化的区间内保持为常数,也能够简化离散化后的格式。
LegT版本 #
我们先以LegT型ODE(1)为例,将它两端积分
x(t+ϵ)−x(t)=A∫t+ϵtx(s)ds+B∫t+ϵtu(s)ds
其中t=kϵ。根据u(t)的定义,它在[t,t+ϵ)区间内恒为uk,于是u(s)的积分可以直接算出来:
x(t+ϵ)−x(t)=A∫t+ϵtx(s)ds+ϵBuk
接下来的结果,就取决于我们如何近似x(s)的积分了。假如我们认为在[t,t+ϵ)区间内x(s)近似恒等于x(t),那么就得到前向欧拉格式
x(t+ϵ)−x(t)=ϵAx(t)+ϵBuk⇒x(t+ϵ)=(I+ϵA)x(t)+ϵBuk
我们认为在[t,t+ϵ)区间内x(s)近似恒等于x(t+ϵ),那么就得到后向欧拉格式
x(t+ϵ)−x(t)=ϵAx(t+ϵ)+ϵBuk⇒x(t+ϵ)=(I−ϵA)−1(x(t)+ϵBuk)
前后向欧拉都具有相同的理论精度,但后向通常会有更好的数值稳定性。如果要更准确一些,那么认为在[t,t+ϵ)区间内x(s)近似恒等于12[x(t)+x(t+ϵ)],那么得到双线性形式:
x(t+ϵ)−x(t)=12ϵA[x(t)+x(t+ϵ)]+ϵBuk⇓x(t+ϵ)=(I−ϵA/2)−1[(I+ϵA/2)x(t)+ϵBuk]
这也等价于先用前向欧拉走半步,再用后向欧拉走半步。更一般地,我们还可以认为在[t,t+ϵ)区间内x(s)近似恒等于αx(t)+(1−α)x(t+ϵ),其中α∈[0,1],这就不进一步展开了。事实上,我们也可以完全不做近似,因为结合式(1)以及在区间[t,t+ϵ)中u(s)是常数uk,我们完全可以用“常数变易法”来精确求解出来,结果是
x(t+ϵ)=eϵAx(t)+A−1(eϵA−I)Buk
这里的矩阵指数按照级数来定义,可以参考《恒等式 det(exp(A)) = exp(Tr(A)) 赏析》。
LegS版本 #
现在轮到LegS型ODE了,它的思路跟LegT型基本一致,结果也大同小异。首先将式(2)两端积分得到
x(t+ϵ)−x(t)=A∫t+ϵtx(s)sds+B∫t+ϵtu(s)sds
根据u(t)定义,第二项积分的u(s)在[t,t+ϵ)恒为uk,所以它相当于1/s的积分,可以直接积分出来得lnt+ϵt,当然直接换为一阶近似ϵt也无妨,因为本身uk到u(t)的变换有很大自由度,这点误差无所谓。至于第一项积分,我们直接采用精度更高的中点近似,得到
x(t+ϵ)−x(t)=12ϵA(x(t)t+x(t+ϵ)t+ϵ)+ϵtBuk⇓x(t+ϵ)=(I−ϵA2(t+ϵ))−1[(I+ϵA2t)x(t)+ϵtBuk]
事实上,式(2)也可以精确求解,只需要留意到它等价于
Ax(t)+Bu(t)=tx′(t)=ddlntx(t)
这意味着只需要做变量代换τ=lnt,那么LegS型ODE就可以转化为LegT型ODE:
ddτx(eτ)=Ax(eτ)+Bu(eτ)
利用式(9)得到(由于变量代换,时间间隔由ϵ变成ln(t+ϵ)−lnt)
x(t+ϵ)=e(ln(t+ϵ)−lnt)Ax(t)+A−1(e(ln(t+ϵ)−lnt)A−I)Buk
然而,上式虽然是精确解,但不如同为精确解的式(9)好用,因为式(9)的指数矩阵部分是eϵA,跟时间t无关,所以一次性计算完就可以了。但上式中t在矩阵指数里边,意味着在迭代过程中需要反复计算矩阵指数,对计算并不友好,所以LegS型ODE我们一般只会用式(11)来离散化。
优良性质 #
接下来,LegS是我们的重点关注对象。重点关注LegS的原因并不难猜,因为从推导的假设来看,它是目前求解出来的唯一一个能够记忆整个历史的ODE系统,这对于很多场景如多轮对话来说至关重要。此外,它还有其他的一些比较良好且实用的性质。
尺度等变 #
比如,LegS的离散化格式(11)是步长无关的,我们只需要将t=kϵ代入里边,并记x(kϵ)=xk,就可以发现
xk+1=(I−A2(k+1))−1[(I+A2k)xk+1kBuk]
步长ϵ被自动地消去了,从而自然地减少了一个需要调的超参数,这对于炼丹人士显然是一个好消息。注意步长无关是LegS型ODE的一个固有性质,它跟具体的离散化方式并无直接关系,比如精确解(14)同样是步长无关的:
xk+1=e(ln(k+1)−lnk)Axk+A−1(e(ln(k+1)−lnk)A−I)Buk
其背后的原因,在于LegS型ODE满足“时间尺度等变性(Timescale equivariance)”——如果我们设t=λτ代入LegS型ODE,将得到
Ax(ατ)+Bu(ατ)=(ατ)×dd(ατ)x(ατ)=τddτx(ατ)
这意味着,当我们将u(t)换成u(αt)时,LegS的ODE形式并没有变化,而对应的解则是x(t)换成了x(αt)。这个性质的直接后果就是:当我们选择更大的步长时,递归格式不需要发生变化,因为结果xk的步长也会自动放大,这就是LegS型ODE离散化与步长无关的本质原因。
长尾衰减 #
LegS型ODE的另一个优良性质是,它关于历史信号的记忆是多项式衰减(Polynomial decay)的,这比常规RNN的指数衰减更缓慢,从而理论上能记忆更长的历史,更不容易梯度消失。为了理解这一点,我们可以从精确解(16)出发,从式(16)可以看到,每递归一步,历史信息的衰减效应可以用矩阵指数e(ln(k+1)−lnk)A来描述,那么从第m步递归到第n步,总的衰减效应是
n−1∏k=me(ln(k+1)−lnk)A=e(lnn−lnm)A
回顾HiPPO-LegS中A的形式:
An,k=−{√(2n+1)(2k+1),k<nn+1,k=n0,k>n
从定义可以看出,A是一个下三角阵,其对角线元素为−1,−2,−3,⋯。我们知道,三角阵的对角线元素正好是它的特征值(参考Triangular matrix),由此可以看到一个d×d大小的A矩阵,有d个不同的特征值−1,−2,⋯,−d,这说明A矩阵是可对角化的,即存在可逆矩阵P,使得A=P−1ΛP,其中Λ=diag(−1,−2,⋯,−d),于是我们有
e(lnn−lnm)A=e(lnn−lnm)P−1ΛP=P−1e(lnn−lnm)ΛP=P−1diag(e−(lnn−lnm),e−2(lnn−lnm),⋯,e−d(lnn−lnm))P=P−1diag(mn,m2n2,⋯,mdnd)P
可见,最终的衰减函数是1/n的1,2,⋯,d次函数的线性组合,所以LegS型ODE关于历史记忆至多是多项式衰减的,比指数衰减更加长尾,因此理论上有更好的记忆力。
计算高效 #
最后,我们指出HiPPO-LegS的A矩阵是计算高效(Computational efficiency)的。具体来说,直接按照矩阵乘法的朴素实现的话,一个d×d的矩阵乘以d×1的列向量,需要做d2次乘法,但LegS的A矩阵与向量相乘则可以降低到O(d)次,更进一步地,我们还可以证明离散化后的(11)也可以在O(d)完成。
为了理解这一点,我们首先将HiPPO-LegS的A矩阵等价地改写成
An,k={nδn,k−√2n+1√2k+1,k≤n0,k>n
对于向量v=[v0,v1,⋯,vd−1],我们有
(Av)n=n∑k=0An,kvk=n∑k=0(nδn,k−√2n+1√2k+1)vk=nvn−√2n+1n∑k=0√2k+1vk
这包含三种运算,第一项的nvn是向量[0,1,2,⋯,d−1]与v做逐位相乘运算,第二项的√2k+1vk则是向量[1,√3,√5,⋯,√2d−1]与v做逐位相乘,然后n∑k=0就是cumsum运算,最后乘以√2n+1就是再逐位相乘向量[1,√3,√5,⋯,√2d−1],每一步都可以在O(d)内完成,因此总的复杂度是O(d)的。
我们再来看(11),它包含两步“矩阵-向量”乘法运算,一是(I+λA)v,λ是任意实数,刚才我们已经证明了Av是计算高效的,自然(I+λA)v也是;二是(I−λA)−1v,接下来我们将证明它也是计算高效的。这只需要留意到求z=(I−λA)−1v等价于解方程v=(I−λA)z,利用上面给出的Av表达式,我们可以得到
vn=zn−λ(nzn−√2n+1n∑k=0√2k+1zk)
记Sn=n∑k=0√2k+1zk,那么zn=Sn−Sn−1√2n+1,代入上式得
vn=Sn−Sn−1√2n+1−λ(nSn−Sn−1√2n+1−√2n+1Sn)
整理得
Sn=1−λn1+λn+λSn−1+√2n+11+λn+λvn
这是一个标量的递归式,可以完全串行地计算,也可以利用Prefix Sum的相关算法并行计算(参考这里),计算复杂度为O(d)或者O(dlogd),总之相比O(d2)都会更加高效。
傅立叶基 #
最后,我们以傅立叶基的一个推导收尾。在上一篇文章中,我们以傅立叶级数来引出了线性系统,但只推导了邻近窗口形式的结果,而后面的勒让德多项式基我们则推导了邻近窗口和完整区间两个版本(即LegT和LegS)。那么傅立叶基究竟能不能推导一个跟LegS相当的版本呢?其中会面临什么困难呢?下面我们对此进行探讨。
同样地,相关铺垫我们不再重复,按照上一节的记号,傅立叶基的系数为
cn(T)=∫10u(t≤T(s))e−2iπnsds
跟LegS一样,为了记忆整个[0,T]区间的信号,我们需要一个[0,1]↦[0,T]的映射,为此选取最简单的t≤T(s)=sT,代入后两边求导得到
ddTcn(T)=∫10u′(sT)se−2iπnsds
分部积分得到
ddTcn(T)=1T∫10se−2iπnsdu(sT)=1Tu(sT)se−2iπns|s=1s=0−1T∫10u(sT)d(se−2iπns)=1Tu(T)−1T∫10u(sT)e−2iπnsds+2iπnT∫10u(sT)se−2iπnsds=1Tu(T)−1Tcn(T)+2iπnT∫10u(sT)se−2iπnsds
上一篇文章我们提到,HiPPO选取勒让德多项式为基的重要原因之一是(s+1)p′n(t)可以分解为p0(t),p1(t),⋯,pn(t)的线性组合,而傅里叶基的se−2iπns则不能做到这一点。但事实上,如果允许误差的话,这个断论是不成立的,因为我们同样可以将s分解为傅里叶级数:
s=12+i2π∑k≠01ke2iπks
这里的求和有无限项,如果要截断为有限项的话,就会产生误差,但我们可以先不纠结这一点,直接往上代入得到
2iπnT∫10u(sT)se−2iπnsds=2iπnT∫10u(sT)(12+i2π∑k≠01ke2iπks)e−2iπnsds=iπnT∫10u(sT)e−2iπnsds−1T∑k≠0nk∫10u(sT)e−2iπ(n−k)sds=iπnTcn(T)−1T∑k≠0nkcn−k(T)=iπnTcn(T)−1T∑k≠nnn−kck(T)
这样一来
ddTcn(T)=1Tu(T)+iπn−1Tcn(T)−1T∑k≠nnn−kck(T)
所以可以写出
x′(t)=Atx(t)+Btu(t)An,k={−nn−k,k≠niπn−1,k=nBn=1
实际使用的时候,我们只需要截断|n|,|k|≤N,就可以得到一个(2N+1)×(2N+1)的矩阵。截断带来的误差其实是无所谓的,因为我们在推导HiPPO-LegT的时候同样引入了有限级数近似,那会我们同样也没考虑误差,或者反过来讲,对于特定的任务,我们会选择适当的规模(即N的大小),而这个“适当”的含义之一,就是截断带来误差对于该任务是可以忽略的。
对大多数人来说,傅立叶基的这个推导可能还更容易理解一些,因为勒让德多项式对很多读者来说都比较陌生,尤其是LegT、LegS推导过程中用到的几个恒等式,而对于傅立叶级数大多数读者应该或多或少都有所了解。不过,从结果上来看,傅立叶基的这个结果可能不如LegS实用,一来它引入了复数,这增加了实现的复杂度,二来它推导出的A矩阵不像LegS那样是个相对较淡的下三角阵,因此理论分析起来也更为复杂。所以,大家权当它是一道深化对HiPPO的理解的练习题就好。
文章小结 #
在这篇文章中,我们补充探讨了上一篇文章介绍的HiPPO的一些遗留问题,其中包括如何对ODE进行离散化、LegS型ODE的一些优良性质,以及利用傅立叶基记忆整个历史区间的结果推导(即LegS的傅立叶版本),以求获得对HiPPO的更全面理解。
转载到请包括本文地址:https://spaces.ac.cn/archives/10137
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 05, 2024). 《重温SSM(二):HiPPO的一些遗留问题 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10137
@online{kexuefm-10137,
title={重温SSM(二):HiPPO的一些遗留问题},
author={苏剑林},
year={2024},
month={Jun},
url={\url{https://spaces.ac.cn/archives/10137}},
}
June 6th, 2024
作者君有个小笔误:方程(8)第二个等式的右边x(t)的系数矩阵是加法:
x(t+ϵ)−x(t)=12ϵA[x(t)+x(t+ϵ)]+ϵBfk⇓x(t+ϵ)=(I−ϵA/2)−1[(I+ϵA/2)x(t)+ϵBfk]
感谢指出,已经更正~
July 24th, 2024
感谢苏佬分享! 请教一下,根据式 (20),如何理解 "最终的衰减函数是1/n
的1,2,⋯,d次函数的线性组合" 呢? 我猜测是 x(t) 的第一个元素以 m/n 衰减,第二个元素以 (m/n)^2 衰减,越靠后的元素衰减越快,可以这样理解吗?
不一定啊,就是P−1diag(mn,m2n2,⋯,mdnd)P的每一个元素,实际上是m/n的不超过d次的多项式,x(t)的每个元素,其衰减方式也是m/n的不超过d次的多项式,具体多少次算了才知道。
我也不太理解这个部分,可以麻烦苏老师更加详细的解释一下嘛
将P看成常数矩阵,那么P−1diag(mn,m2n2,⋯,mdnd)P就是mn,m2n2,⋯,mdnd的线性组合啊(矩阵乘法就是线性运算),这还能怎么详细?
September 11th, 2024
苏神,Timescale equivariance这个词应该翻译成不变性还是等变性?我们的模型在中文语境下其实分不清这两个词,是普遍中文翻译的问题吗。
我理解一个是f(x)=f(T(x))(不变性),一个是T(f(x))=f(T(x))(等变性);
这里u(t)的步长选择不同t=αt改变了x(t)的步长选择应该是不变性还是等变性?
谢谢,参考你的建议改过来了。不过就我的观点看来,其实“等变”还是“不变”都无所谓,因为你单说“等变”还是“不变”,大家都不理解是什么意思,只有后面附以相应的数学说明,才能把它的真正含义解析清楚。