近年来,线性RNN由于其可并行训练以及常数推理成本等特性,吸引了一定研究人员的关注(例如笔者之前写的《Google新作试图“复活”RNN:RNN能否再次辉煌?》),这让RNN在Transformer遍地开花的潮流中仍有“一席之地”。然而,目前看来这“一席之地”只属于线性RNN,因为非线性RNN无法高效地并行训练,所以在架构之争中是“心有余而力不足”。

不过,一篇名为《Parallelizing Non-Linear Sequential Models over the Sequence Length》的论文有不同的看法,它提出了一种迭代算法,宣传可以实现非线性RNN的并行训练!真有如此神奇?接下来我们一探究竟。

求不动点 #

原论文对其方法做了非常一般的介绍,而且其侧重点是PDE和ODE,这里我们直接从RNN入手。考虑常见的简单非线性RNN:
xt=tanh(Axt1+ut)
由于tanh的存在,它只能串行计算。现在我们在两边都减去Axt1
xtAxt1=tanh(Axt1+ut)Axt1
当然,这改变不了它是非线性RNN的实质。然而我们可以发现,假如右端的xt1换成像ut那样的给定向量,那么这就是一个线性RNN了,根据《Google新作试图“复活”RNN:RNN能否再次辉煌?》的结果,它是可以并行计算的。此时,敏捷的读者可能已经猜到后面的步骤了——迭代求解!

首先,将上述RNN更改成
x(n)tAx(n)t1=tanh(Ax(n1)t1+ut)Ax(n1)t1
从给定x(0)t出发,反复迭代上式,理想情况下,它会收敛于一个不动点xt,这就是原来非线性RNN的计算结果。当然,理论上通过式(3)迭代的总计算量是比直接通过式(1)递归计算要大的,但由于每一步迭代都是可并行的线性RNN,并且如果收敛速度比较快时迭代步数不需要太多,那么总的耗时通常都会快于直接非线性RNN递归(尤其是序列长度很大时)。

简化形式 #

事实上,非线性RNN之所以慢,无法并行计算还是次要的,最关键是它包含了大量的非element-wise运算,比如式(1)tanh里边的矩阵运算Axt1;而线性RNN之所以快,除了它允许并行训练之外,更关键的是它能通过对角化来将矩阵乘法变换为element-wise的乘法——对于element-wise乘法来说,即便是串行计算也不会太慢。

当我们通过式(3)将非线性RNN转为线性RNN的迭代之后,同样享受线性RNN可对角化的“待遇”,从而提高计算速度。具体来说,在复数域中将A对角化为PΛP1,那么式(3)变为
x(n)tPΛP1x(n)t1=tanh(PΛP1x(n1)t1+ut)PΛP1x(n1)t1
两端都左乘P1
P1x(n)tΛP1x(n)t1=P1tanh(PΛP1x(n1)t1+ut)ΛP1x(n1)t1
yt=P1xt,那么上式可以简化为
y(n)tΛy(n)t1=P1tanh(PΛy(n1)t1+ut)Λy(n1)t1
由于RNN之后一般都还要接个投影层,所以xt=PytP原则上可以合并到外接的投影层里边,也就是说,上式理论上具备跟原来的(1)具备同等的表达能力,但由于Λ是对角阵,递归的计算量会明显降低。上式还出现了逆矩阵P1,不单计算量大,而且不利于优化,所以我们可以干脆将P1PΛ换成两个不相关的参数矩阵:
y(n)tΛy(n)t1=Ptanh(Qy(n1)t1+ut)Λy(n1)t1
只要初始化是PQ=Λ就行。

摄动思想 #

假定x(0)t=0,那么式(3)其实就是将原本的非线性RNN就分解为一系列线性RNN:
x(1)tAx(1)t1=tanh(ut)x(2)tAx(2)t1=tanh(Ax(1)t1+ut)Ax(1)t1x(n)tAx(n)t1=tanh(Ax(n1)t1+ut)Ax(n1)t1
而假设xt1,ut都是小量,那么对式(1)右端利用tanhxx得到:
xt=tanh(Axt1+ut)Axt1+utAxt1+tanh(ut)
这正好是(8)中的第一个方程,因此如果假设成立,那么x(1)t或许已经足够接近理想的xt,后面的每一步迭代都在快速逼近它。从这里我们可以看出,“两边同时减去Axt1”是关键之处,这使得(3)的第一步迭代就接近于原本非线性RNN的一阶线性近似,这可以提高收敛速度,也是数学物理中的经典操作,名曰“摄动”。

加快收敛 #

根据摄动法的思想,提高收敛速度的关键就是提高近似展开的精度,比如较为简单的改进是只假设xt1是小量,那么根据一阶泰勒展开有(将ut视为列向量,这里的是Hadamard积分)
xt=tanh(Axt1+ut)tanh(ut)+(sech2utA)xt1
于是改进的结果就是式(3)变为
x(n)tAtx(n)t1=tanh(Ax(n1)t1+ut)Atx(n1)t1
其中At=sech2utA。更精细的改进是在每一步迭代时,都在前一步迭代结果的基础上进行展开:
xt=tanh(Axt1+ut)tanh(Ax(n1)t1+ut)+(sech2(Ax(n1)t1+ut)A)(xt1x(n1)t1)
于是式(3)变为
x(n)tA(n)tx(n)t1=tanh(Ax(n1)t1+ut)A(n)tx(n1)t1
其中A(n)t=sech2(Ax(n1)t1+ut)A。最后的这个迭代格式,实际上就是求方程数值解的“牛顿法”,它具有二次收敛速度。

何必收敛 #

理论上来说,(11)(13)两个改进确实能提高收敛速度,然而它们使得每一步线性递归的矩阵A变得跟t甚至n有关了,这其实会大大增加并行的复杂度,也不能利用“简化形式”一节的对角化技巧来加速。另一方面,如果保持(3)这样的迭代格式,虽然有诸多效率上的好处,但收敛方面确实无法得到很好的保障。

难道这两者的矛盾就无法调和了吗?事实上,按照笔者的观点,最直接的做法是“别去管它”——借助非线性RNN导出了(3)后,就忘记原本的非线性RNN,将式(3)作为基本模型。也就是说,何必忧虑式(3)会不会收敛到原来的非线性RNN?直接将它作为新的出发点不好吗?梯度下降学到怎样的结果就是怎样的结果,如果梯度下降学到的结果是不收敛到原来的非线性RNN,那么就意味着不收敛到原来的RNN是更适合的。

抛开这一层思维束缚后,其实很多问题会变得豁然开朗起来。首先,即便是式(13)在理论上拥有非常好的收敛速度,但也是有条件的,而且在深度学习的背景下,要保证这些条件会显得很奢侈。换言之,即便是式(13)的收敛性也没有绝对保证,所以何必“五十步笑百步”去苛责式(3)?其次,将式(3)视为新的出发点后,我们可以将它单纯地理解为线性RNN的一种新用法,或者说解决线性RNN缺陷(比如线性RNN不是图灵完备的)的一个思路,这样操作性更强。

总的来说,不去管它的收敛性,似乎更能打破思维僵局,探索更一般的结果。

一般情形 #

前面的“长篇大论”,都只围绕着简单的非线性RNN也就是式(1)进行讨论,对于更常用的LSTM、GRU,结果又如何呢?

以GRU为例,它原本的形式为
zt=σ(Wzxt+Uzht1+bz)rt=σ(Wrxt+Urht1+br)ˆht=tanh(Whxt+Uh(rtht1)+bc)ht=(1zt)ht1+ztˆht
初始阶段,所有门控都可以近似视为12,那么模仿式(9)
ht=(1zt)ht1+ztˆht12ht1+12ˆht12ht1+12(tanh(Whxt+bc)+12Uhht1)=12(I+12Uh)ht1+12tanh(Whxt+bc)
所以可以选取A=12(I+12Uh),将GRU改写为迭代
z(n)t=σ(Wzxt+Uzh(n1)t1+bz)r(n)t=σ(Wrxt+Urh(n1)t1+br)ˆh(n)t=tanh(Whxt+Uh(r(n)th(n1)t1)+bc)h(n)t=Ah(n)t1Ah(n1)t1+(1z(n)t)h(n1)t1+z(n)tˆh(n)t

总的来说,这种将非线性RNN变为线性RNN迭代的转换,从实践的角度来看,就是以非线性RNN为引,导出一种多层线性RNN的参数共享和组合方法,迭代了几次,那么就有几层线性RNN的计算量。这样自然而言就引发了一个思考:除非可以证明GRU、LSTM等非线性RNN有绝对的优势,否则直接叠加几层“线性RNN+MLP”不好吗?

文章小结 #

本文简单探讨了非线性RNN的并行计算问题——通过数学物理中的“摄动”思想,我们可以将非线性RNN转化为线性RNN的迭代,从而利用线性RNN的可并行性来实现非线性RNN的并行。

转载到请包括本文地址:https://spaces.ac.cn/archives/9783

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Sep. 26, 2023). 《脑洞大开:非线性RNN居然也可以并行计算? 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9783

@online{kexuefm-9783,
        title={脑洞大开:非线性RNN居然也可以并行计算?},
        author={苏剑林},
        year={2023},
        month={Sep},
        url={\url{https://spaces.ac.cn/archives/9783}},
}