脑洞大开:非线性RNN居然也可以并行计算?
By 苏剑林 | 2023-09-26 | 64748位读者 |近年来,线性RNN由于其可并行训练以及常数推理成本等特性,吸引了一定研究人员的关注(例如笔者之前写的《Google新作试图“复活”RNN:RNN能否再次辉煌?》),这让RNN在Transformer遍地开花的潮流中仍有“一席之地”。然而,目前看来这“一席之地”只属于线性RNN,因为非线性RNN无法高效地并行训练,所以在架构之争中是“心有余而力不足”。
不过,一篇名为《Parallelizing Non-Linear Sequential Models over the Sequence Length》的论文有不同的看法,它提出了一种迭代算法,宣传可以实现非线性RNN的并行训练!真有如此神奇?接下来我们一探究竟。
求不动点 #
原论文对其方法做了非常一般的介绍,而且其侧重点是PDE和ODE,这里我们直接从RNN入手。考虑常见的简单非线性RNN:
xt=tanh(Axt−1+ut)
由于tanh的存在,它只能串行计算。现在我们在两边都减去Axt−1:
xt−Axt−1=tanh(Axt−1+ut)−Axt−1
当然,这改变不了它是非线性RNN的实质。然而我们可以发现,假如右端的xt−1换成像ut那样的给定向量,那么这就是一个线性RNN了,根据《Google新作试图“复活”RNN:RNN能否再次辉煌?》的结果,它是可以并行计算的。此时,敏捷的读者可能已经猜到后面的步骤了——迭代求解!
首先,将上述RNN更改成
x(n)t−Ax(n)t−1=tanh(Ax(n−1)t−1+ut)−Ax(n−1)t−1
从给定x(0)t出发,反复迭代上式,理想情况下,它会收敛于一个不动点x∗t,这就是原来非线性RNN的计算结果。当然,理论上通过式(3)迭代的总计算量是比直接通过式(1)递归计算要大的,但由于每一步迭代都是可并行的线性RNN,并且如果收敛速度比较快时迭代步数不需要太多,那么总的耗时通常都会快于直接非线性RNN递归(尤其是序列长度很大时)。
简化形式 #
事实上,非线性RNN之所以慢,无法并行计算还是次要的,最关键是它包含了大量的非element-wise运算,比如式(1)的tanh里边的矩阵运算Axt−1;而线性RNN之所以快,除了它允许并行训练之外,更关键的是它能通过对角化来将矩阵乘法变换为element-wise的乘法——对于element-wise乘法来说,即便是串行计算也不会太慢。
当我们通过式(3)将非线性RNN转为线性RNN的迭代之后,同样享受线性RNN可对角化的“待遇”,从而提高计算速度。具体来说,在复数域中将A对角化为PΛP−1,那么式(3)变为
x(n)t−PΛP−1x(n)t−1=tanh(PΛP−1x(n−1)t−1+ut)−PΛP−1x(n−1)t−1
两端都左乘P−1:
P−1x(n)t−ΛP−1x(n)t−1=P−1tanh(PΛP−1x(n−1)t−1+ut)−ΛP−1x(n−1)t−1
令yt=P−1xt,那么上式可以简化为
y(n)t−Λy(n)t−1=P−1tanh(PΛy(n−1)t−1+ut)−Λy(n−1)t−1
由于RNN之后一般都还要接个投影层,所以xt=Pyt的P原则上可以合并到外接的投影层里边,也就是说,上式理论上具备跟原来的(1)具备同等的表达能力,但由于Λ是对角阵,递归的计算量会明显降低。上式还出现了逆矩阵P−1,不单计算量大,而且不利于优化,所以我们可以干脆将P−1和PΛ换成两个不相关的参数矩阵:
y(n)t−Λy(n)t−1=Ptanh(Qy(n−1)t−1+ut)−Λy(n−1)t−1
只要初始化是PQ=Λ就行。
摄动思想 #
假定x(0)t=0,那么式(3)其实就是将原本的非线性RNN就分解为一系列线性RNN:
x(1)t−Ax(1)t−1=tanh(ut)x(2)t−Ax(2)t−1=tanh(Ax(1)t−1+ut)−Ax(1)t−1⋮x(n)t−Ax(n)t−1=tanh(Ax(n−1)t−1+ut)−Ax(n−1)t−1⋮
而假设xt−1,ut都是小量,那么对式(1)右端利用tanhx≈x得到:
xt=tanh(Axt−1+ut)≈Axt−1+ut≈Axt−1+tanh(ut)
这正好是(8)中的第一个方程,因此如果假设成立,那么x(1)t或许已经足够接近理想的x∗t,后面的每一步迭代都在快速逼近它。从这里我们可以看出,“两边同时减去Axt−1”是关键之处,这使得(3)的第一步迭代就接近于原本非线性RNN的一阶线性近似,这可以提高收敛速度,也是数学物理中的经典操作,名曰“摄动”。
加快收敛 #
根据摄动法的思想,提高收敛速度的关键就是提高近似展开的精度,比如较为简单的改进是只假设xt−1是小量,那么根据一阶泰勒展开有(将ut视为列向量,这里的∘是Hadamard积分)
xt=tanh(Axt−1+ut)≈tanh(ut)+(sech2ut∘A)xt−1
于是改进的结果就是式(3)变为
x(n)t−Atx(n)t−1=tanh(Ax(n−1)t−1+ut)−Atx(n−1)t−1
其中At=sech2ut∘A。更精细的改进是在每一步迭代时,都在前一步迭代结果的基础上进行展开:
xt=tanh(Axt−1+ut)≈tanh(Ax(n−1)t−1+ut)+(sech2(Ax(n−1)t−1+ut)∘A)(xt−1−x(n−1)t−1)
于是式(3)变为
x(n)t−A(n)tx(n)t−1=tanh(Ax(n−1)t−1+ut)−A(n)tx(n−1)t−1
其中A(n)t=sech2(Ax(n−1)t−1+ut)∘A。最后的这个迭代格式,实际上就是求方程数值解的“牛顿法”,它具有二次收敛速度。
何必收敛 #
理论上来说,(11)、(13)两个改进确实能提高收敛速度,然而它们使得每一步线性递归的矩阵A变得跟t甚至n有关了,这其实会大大增加并行的复杂度,也不能利用“简化形式”一节的对角化技巧来加速。另一方面,如果保持(3)这样的迭代格式,虽然有诸多效率上的好处,但收敛方面确实无法得到很好的保障。
难道这两者的矛盾就无法调和了吗?事实上,按照笔者的观点,最直接的做法是“别去管它”——借助非线性RNN导出了(3)后,就忘记原本的非线性RNN,将式(3)作为基本模型。也就是说,何必忧虑式(3)会不会收敛到原来的非线性RNN?直接将它作为新的出发点不好吗?梯度下降学到怎样的结果就是怎样的结果,如果梯度下降学到的结果是不收敛到原来的非线性RNN,那么就意味着不收敛到原来的RNN是更适合的。
抛开这一层思维束缚后,其实很多问题会变得豁然开朗起来。首先,即便是式(13)在理论上拥有非常好的收敛速度,但也是有条件的,而且在深度学习的背景下,要保证这些条件会显得很奢侈。换言之,即便是式(13)的收敛性也没有绝对保证,所以何必“五十步笑百步”去苛责式(3)?其次,将式(3)视为新的出发点后,我们可以将它单纯地理解为线性RNN的一种新用法,或者说解决线性RNN缺陷(比如线性RNN不是图灵完备的)的一个思路,这样操作性更强。
总的来说,不去管它的收敛性,似乎更能打破思维僵局,探索更一般的结果。
一般情形 #
前面的“长篇大论”,都只围绕着简单的非线性RNN也就是式(1)进行讨论,对于更常用的LSTM、GRU,结果又如何呢?
以GRU为例,它原本的形式为
zt=σ(Wzxt+Uzht−1+bz)rt=σ(Wrxt+Urht−1+br)ˆht=tanh(Whxt+Uh(rt∘ht−1)+bc)ht=(1−zt)∘ht−1+zt∘ˆht
初始阶段,所有门控都可以近似视为12,那么模仿式(9)有
ht=(1−zt)∘ht−1+zt∘ˆht≈12ht−1+12ˆht≈12ht−1+12(tanh(Whxt+bc)+12Uhht−1)=12(I+12Uh)ht−1+12tanh(Whxt+bc)
所以可以选取A=12(I+12Uh),将GRU改写为迭代
z(n)t=σ(Wzxt+Uzh(n−1)t−1+bz)r(n)t=σ(Wrxt+Urh(n−1)t−1+br)ˆh(n)t=tanh(Whxt+Uh(r(n)t∘h(n−1)t−1)+bc)h(n)t=Ah(n)t−1−Ah(n−1)t−1+(1−z(n)t)∘h(n−1)t−1+z(n)t∘ˆh(n)t
总的来说,这种将非线性RNN变为线性RNN迭代的转换,从实践的角度来看,就是以非线性RNN为引,导出一种多层线性RNN的参数共享和组合方法,迭代了几次,那么就有几层线性RNN的计算量。这样自然而言就引发了一个思考:除非可以证明GRU、LSTM等非线性RNN有绝对的优势,否则直接叠加几层“线性RNN+MLP”不好吗?
文章小结 #
本文简单探讨了非线性RNN的并行计算问题——通过数学物理中的“摄动”思想,我们可以将非线性RNN转化为线性RNN的迭代,从而利用线性RNN的可并行性来实现非线性RNN的并行。
转载到请包括本文地址:https://spaces.ac.cn/archives/9783
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Sep. 26, 2023). 《脑洞大开:非线性RNN居然也可以并行计算? 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9783
@online{kexuefm-9783,
title={脑洞大开:非线性RNN居然也可以并行计算?},
author={苏剑林},
year={2023},
month={Sep},
url={\url{https://spaces.ac.cn/archives/9783}},
}
September 26th, 2023
Google 最近有提出 TSMixer (https://openreview.net/forum?id=wbpxTuXgm0)。
扫了一眼,似乎不是特别有意思。
September 26th, 2023
最近這幾年有一個主張是透過 dynamical system 把 RNN 和 ODE/PDE/SDE 聯繫在一起。
您好,具体哪些 dynamical system? 方便列一些吗
这个联系比较自然吧。就本文提到的论文而言,它主要就是做ODE/PDE的,看上去做RNN只是顺便。
September 26th, 2023
[...]Read More [...]
September 29th, 2023
爲甚麼 公式 (7) 的中只需要初始化保證 PQ=Λ 就可以了? Neural Networks 在訓練的過程中,weight tensor 不是會一直改變的嗎?PQ=Λ 在訓練過程中可能會無法保證也沒關係?
如果训练过程不保证\boldsymbol{P}\boldsymbol{Q} = \boldsymbol{\Lambda},可能是因为优化器觉得\boldsymbol{P}\boldsymbol{Q} \neq \boldsymbol{\Lambda}更好,所以并没有什么关系。
October 10th, 2023
hello,从感受野的角度来看,迭代的过程,x_t要看到x_0,不是整个不动点迭代要跑t个step吗,这和原来的rnn比,为什么能快呢?
可以用 prefix sum 來並行運算 RNN,這樣就比原先衹能迭代運算 RNN 要快很多。
线性RNN是可以并行的(参考 https://kexue.fm/archives/9554 ),理想情况下复杂度是\mathcal{O}(\log t),所以总的复杂度再乘以迭代步数,因此时间上有可能少于直接计算非线性RNN。
October 18th, 2023
这个(16)是怎么推导出来的,有大神能解释一下细节吗
所有的中间变量都用上一步的迭代结果计算,然后用两端减去一阶近似的方式去构造最终的迭代格式,保证^{(n-1)}替换为^{(n)}时,退化为原本的非线性RNN。
November 12th, 2023
令 \mathbf{P} 爲 learnable orthogonal matrix 即可,比如可以用 expRNN 裡提出的方法。
这样计算复杂度又更大了?
還有一個辦法,在公式 (3) 中,用 \mathbf{A} = \mathbf{F}\mathbf{\Lambda}\mathbf{F}^{\ast} 替代 \mathbf{A} = \mathbf{P}\mathbf{\Lambda}\mathbf{P}^{-1},其中 \mathbf{F} 是 DFT 矩陣。從 signal processing 的角度來看,\mathbf{P}^{-1} 和 \mathbf{P} 的作用是對 \mathbf{x}_{t-1} 做 discrete orthogonal transform 和 inverse transform。用 DFT 矩陣替代之後,可以用 FFT 來加速運算。
这样的话感觉干脆换用基于FFT的模型得了(捂脸)~
說不定 FFT based Linear RNN 確實可以
還真有這樣的模型(https://openreview.net/forum?id=IxmWsm4xrua)
这个就是广义的FFT呀,并不是反过来用FFT来加速RNN
不會。我最近在試類似的方法,時間複雜度和空間複雜度僅比O(NlogN)大一點點。
December 7th, 2023
有个不太懂的是,这种技术可以用于RNN的并行推理吗?还是只能用于并行训练
理论上可以,但实际上有点难,因为无法提前预知结果长度。
September 2nd, 2024
想问一下以(8)式为例,从step=n-1到step=n,右式的x_{t-1}^{n-1}如何得到呢?前一个式子得到的是x_{t}^{n-1}-Ax_{t-1}^{n-1},如果想要单独求出x_{t-1}^{n-1}是不是还需要类似于递推或者开一个大数组从x_{1}^{n-1}开始往后一个一个计算+存储?
如果是这样的话,对序列的初始值x_1,(9)式有:x_1^{0}=tanh(u_1)(因为x_0 \triangleq 0),那就意味着x_1是固定且无损的,不会随着迭代而更新,导致后面所有的x_{t}^{n}都无法更新了?
第一步,x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t),这就是个线性RNN,可以并行求出x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots;
第二步,x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)},其中所有的x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots已经在上一步求出,因此也可以并行求出x_1^{(2)},x_2^{(2)},x_3^{(2)},\cdots;
依此类推~
嗯嗯,这个逻辑乍一看很显然,但是感觉细想还是不太对。n迭代过程中x^{(n)}能够产生更新的原因还是因为真值在随着n的迭代的过程中在跟着t进行传递吧,这样理论上n=t才可以收敛?
第一步x_1^{(1)}=tanh(u_1)已经是真值x_1了,而且这个真值每一步迭代都是这样的形式,不会改变(x_1是首项,即x_0 \triangleq 0,因此x_1的递推式和后面的不同);第二步这个真值根据x_2^{(2)}-Ax_1^{(2)}=tanh(Ax_1^{(1)}+u_2)-Ax_1^{(1)},由于x_1^{(1)}=x_1^{(2)},所以x_2^{(2)}=tanh(Ax_1^{(1)}+u_2)=tanh(Ax_1+u_2),也即正确的x_1在迭代第二步n=2的时候传递到了x_2处,使得x_2也获得了真值。
之后类似,整体的迭代的收敛性其实是因为真值在随着n根据t来逐步往后传的,如果真值传不到(n \ll t_{length}),误差会很大吧?
所以你的意思是想要表达这样做的误差可能很大,以及你为什么觉得误差大的一些看法?
如果这样的话,那这个问题确实无解了,要严格保证它收敛于原始非线性RNN还是要静心调试的,我更倾向于将它当成一种复杂线性RNN的构建方式,事后就当它线性RNN来用,保持训练和预测的一致性。