脑洞大开:非线性RNN居然也可以并行计算?
By 苏剑林 | 2023-09-26 | 65109位读者 |近年来,线性RNN由于其可并行训练以及常数推理成本等特性,吸引了一定研究人员的关注(例如笔者之前写的《Google新作试图“复活”RNN:RNN能否再次辉煌?》),这让RNN在Transformer遍地开花的潮流中仍有“一席之地”。然而,目前看来这“一席之地”只属于线性RNN,因为非线性RNN无法高效地并行训练,所以在架构之争中是“心有余而力不足”。
不过,一篇名为《Parallelizing Non-Linear Sequential Models over the Sequence Length》的论文有不同的看法,它提出了一种迭代算法,宣传可以实现非线性RNN的并行训练!真有如此神奇?接下来我们一探究竟。
求不动点 #
原论文对其方法做了非常一般的介绍,而且其侧重点是PDE和ODE,这里我们直接从RNN入手。考虑常见的简单非线性RNN:
xt=tanh(Axt−1+ut)
由于\tanh的存在,它只能串行计算。现在我们在两边都减去Ax_{t-1}:
\begin{equation}x_t - Ax_{t-1} = \tanh(Ax_{t-1} + u_t) - Ax_{t-1}\end{equation}
当然,这改变不了它是非线性RNN的实质。然而我们可以发现,假如右端的x_{t-1}换成像u_t那样的给定向量,那么这就是一个线性RNN了,根据《Google新作试图“复活”RNN:RNN能否再次辉煌?》的结果,它是可以并行计算的。此时,敏捷的读者可能已经猜到后面的步骤了——迭代求解!
首先,将上述RNN更改成
\begin{equation}x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)}\label{eq:rnn-iter}\end{equation}
从给定x_t^{(0)}出发,反复迭代上式,理想情况下,它会收敛于一个不动点x_t^*,这就是原来非线性RNN的计算结果。当然,理论上通过式\eqref{eq:rnn-iter}迭代的总计算量是比直接通过式\eqref{eq:rnn}递归计算要大的,但由于每一步迭代都是可并行的线性RNN,并且如果收敛速度比较快时迭代步数不需要太多,那么总的耗时通常都会快于直接非线性RNN递归(尤其是序列长度很大时)。
简化形式 #
事实上,非线性RNN之所以慢,无法并行计算还是次要的,最关键是它包含了大量的非element-wise运算,比如式\eqref{eq:rnn}的\tanh里边的矩阵运算Ax_{t-1};而线性RNN之所以快,除了它允许并行训练之外,更关键的是它能通过对角化来将矩阵乘法变换为element-wise的乘法——对于element-wise乘法来说,即便是串行计算也不会太慢。
当我们通过式\eqref{eq:rnn-iter}将非线性RNN转为线性RNN的迭代之后,同样享受线性RNN可对角化的“待遇”,从而提高计算速度。具体来说,在复数域中将A对角化为P\Lambda P^{-1},那么式\eqref{eq:rnn-iter}变为
\begin{equation}x_t^{(n)} - P\Lambda P^{-1} x_{t-1}^{(n)} = \tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - P\Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}
两端都左乘P^{-1}:
\begin{equation}P^{-1} x_t^{(n)} - \Lambda P^{-1} x_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - \Lambda P^{-1} x_{t-1}^{(n-1)}\end{equation}
令y_t = P^{-1} x_t,那么上式可以简化为
\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}
由于RNN之后一般都还要接个投影层,所以x_t = P y_t的P原则上可以合并到外接的投影层里边,也就是说,上式理论上具备跟原来的\eqref{eq:rnn}具备同等的表达能力,但由于\Lambda是对角阵,递归的计算量会明显降低。上式还出现了逆矩阵P^{-1},不单计算量大,而且不利于优化,所以我们可以干脆将P^{-1}和P\Lambda换成两个不相关的参数矩阵:
\begin{equation}y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P\tanh(Q y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)}\end{equation}
只要初始化是PQ=\Lambda就行。
摄动思想 #
假定x_t^{(0)}=0,那么式\eqref{eq:rnn-iter}其实就是将原本的非线性RNN就分解为一系列线性RNN:
\begin{equation}\begin{array}{c}
x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t)\\
x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)} \\
\vdots \\
x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)} \\
\vdots \\
\end{array}\label{eq:rnns}\end{equation}
而假设x_{t-1},u_t都是小量,那么对式\eqref{eq:rnn}右端利用\tanh x \approx x得到:
\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx Ax_{t-1} + u_t \approx Ax_{t-1} + \tanh(u_t)\label{eq:rnn-approx}\end{equation}
这正好是\eqref{eq:rnns}中的第一个方程,因此如果假设成立,那么x_t^{(1)}或许已经足够接近理想的x_t^*,后面的每一步迭代都在快速逼近它。从这里我们可以看出,“两边同时减去Ax_{t-1}”是关键之处,这使得\eqref{eq:rnn-iter}的第一步迭代就接近于原本非线性RNN的一阶线性近似,这可以提高收敛速度,也是数学物理中的经典操作,名曰“摄动”。
加快收敛 #
根据摄动法的思想,提高收敛速度的关键就是提高近似展开的精度,比如较为简单的改进是只假设x_{t-1}是小量,那么根据一阶泰勒展开有(将u_t视为列向量,这里的\circ是Hadamard积分)
\begin{equation}x_t = \tanh(Ax_{t-1} + u_t) \approx \tanh(u_t) + (\text{sech}^2 u_t\circ A)x_{t-1}\end{equation}
于是改进的结果就是式\eqref{eq:rnn-iter}变为
\begin{equation}x_t^{(n)} - A_t x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t x_{t-1}^{(n-1)}\label{eq:iter-plus1}\end{equation}
其中A_t = \text{sech}^2 u_t\circ A。更精细的改进是在每一步迭代时,都在前一步迭代结果的基础上进行展开:
\begin{equation}\begin{aligned}
x_t =&\, \tanh(Ax_{t-1} + u_t) \\
\approx&\, \tanh(Ax_{t-1}^{(n-1)} + u_t) + (\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A)(x_{t-1} - x_{t-1}^{(n-1)})
\end{aligned}\end{equation}
于是式\eqref{eq:rnn-iter}变为
\begin{equation}x_t^{(n)} - A_t^{(n)} x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t^{(n)} x_{t-1}^{(n-1)}\label{eq:iter-plus2}\end{equation}
其中A_t^{(n)}=\text{sech}^2 (Ax_{t-1}^{(n-1)} + u_t)\circ A。最后的这个迭代格式,实际上就是求方程数值解的“牛顿法”,它具有二次收敛速度。
何必收敛 #
理论上来说,\eqref{eq:iter-plus1}、\eqref{eq:iter-plus2}两个改进确实能提高收敛速度,然而它们使得每一步线性递归的矩阵A变得跟t甚至n有关了,这其实会大大增加并行的复杂度,也不能利用“简化形式”一节的对角化技巧来加速。另一方面,如果保持\eqref{eq:rnn-iter}这样的迭代格式,虽然有诸多效率上的好处,但收敛方面确实无法得到很好的保障。
难道这两者的矛盾就无法调和了吗?事实上,按照笔者的观点,最直接的做法是“别去管它”——借助非线性RNN导出了\eqref{eq:rnn-iter}后,就忘记原本的非线性RNN,将式\eqref{eq:rnn-iter}作为基本模型。也就是说,何必忧虑式\eqref{eq:rnn-iter}会不会收敛到原来的非线性RNN?直接将它作为新的出发点不好吗?梯度下降学到怎样的结果就是怎样的结果,如果梯度下降学到的结果是不收敛到原来的非线性RNN,那么就意味着不收敛到原来的RNN是更适合的。
抛开这一层思维束缚后,其实很多问题会变得豁然开朗起来。首先,即便是式\eqref{eq:iter-plus2}在理论上拥有非常好的收敛速度,但也是有条件的,而且在深度学习的背景下,要保证这些条件会显得很奢侈。换言之,即便是式\eqref{eq:iter-plus2}的收敛性也没有绝对保证,所以何必“五十步笑百步”去苛责式\eqref{eq:rnn-iter}?其次,将式\eqref{eq:rnn-iter}视为新的出发点后,我们可以将它单纯地理解为线性RNN的一种新用法,或者说解决线性RNN缺陷(比如线性RNN不是图灵完备的)的一个思路,这样操作性更强。
总的来说,不去管它的收敛性,似乎更能打破思维僵局,探索更一般的结果。
一般情形 #
前面的“长篇大论”,都只围绕着简单的非线性RNN也就是式\eqref{eq:rnn}进行讨论,对于更常用的LSTM、GRU,结果又如何呢?
以GRU为例,它原本的形式为
\begin{equation}\begin{aligned} z_{t} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1} + b_{z} \right) \\
r_{t} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1} + b_{r} \right) \\
\hat{h}_t & = \tanh \left( W_{h} x_{t} + U_{h} (r_t \circ h_{t - 1}) + b_{c} \right)\\
h_{t} & = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \end{aligned}\end{equation}
初始阶段,所有门控都可以近似视为\frac{1}{2},那么模仿式\eqref{eq:rnn-approx}有
\begin{equation}\begin{aligned}
h_{t} &\, = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \\
&\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \hat{h}_t \\
&\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \left(\tanh ( W_{h} x_{t} + b_{c} ) + \frac{1}{2}U_{h} h_{t - 1}\right) \\
&\, = \frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)h_{t - 1} + \frac{1}{2} \tanh ( W_{h} x_{t} + b_{c} ) \\
\end{aligned}\end{equation}
所以可以选取A=\frac{1}{2} \left(I + \frac{1}{2}U_{h}\right),将GRU改写为迭代
\begin{equation}\begin{aligned} z_{t}^{(n)} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1}^{(n-1)} + b_{z} \right) \\
r_{t}^{(n)} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1}^{(n-1)} + b_{r} \right) \\
\hat{h}_t^{(n)} & = \tanh \left( W_{h} x_{t} + U_{h} (r_t^{(n)} \circ h_{t - 1}^{(n-1)}) + b_{c} \right)\\
h_{t}^{(n)} & = Ah_{t-1}^{(n)} - Ah_{t-1}^{(n - 1)} + \left(1 - z_{t}^{(n)}\right) \circ h_{t - 1}^{(n-1)} + z_{t}^{(n)} \circ \hat{h}_t^{(n)} \end{aligned}\end{equation}
总的来说,这种将非线性RNN变为线性RNN迭代的转换,从实践的角度来看,就是以非线性RNN为引,导出一种多层线性RNN的参数共享和组合方法,迭代了几次,那么就有几层线性RNN的计算量。这样自然而言就引发了一个思考:除非可以证明GRU、LSTM等非线性RNN有绝对的优势,否则直接叠加几层“线性RNN+MLP”不好吗?
文章小结 #
本文简单探讨了非线性RNN的并行计算问题——通过数学物理中的“摄动”思想,我们可以将非线性RNN转化为线性RNN的迭代,从而利用线性RNN的可并行性来实现非线性RNN的并行。
转载到请包括本文地址:https://spaces.ac.cn/archives/9783
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Sep. 26, 2023). 《脑洞大开:非线性RNN居然也可以并行计算? 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9783
@online{kexuefm-9783,
title={脑洞大开:非线性RNN居然也可以并行计算?},
author={苏剑林},
year={2023},
month={Sep},
url={\url{https://spaces.ac.cn/archives/9783}},
}
September 26th, 2023
Google 最近有提出 TSMixer (https://openreview.net/forum?id=wbpxTuXgm0)。
扫了一眼,似乎不是特别有意思。
September 26th, 2023
最近這幾年有一個主張是透過 dynamical system 把 RNN 和 ODE/PDE/SDE 聯繫在一起。
您好,具体哪些 dynamical system? 方便列一些吗
这个联系比较自然吧。就本文提到的论文而言,它主要就是做ODE/PDE的,看上去做RNN只是顺便。
September 26th, 2023
[...]Read More [...]
September 29th, 2023
爲甚麼 公式 (7) 的中只需要初始化保證 \mathbf{P}\mathbf{Q} = \mathbf{\Lambda} 就可以了? Neural Networks 在訓練的過程中,weight tensor 不是會一直改變的嗎?\mathbf{P}\mathbf{Q} = \mathbf{\Lambda} 在訓練過程中可能會無法保證也沒關係?
如果训练过程不保证\boldsymbol{P}\boldsymbol{Q} = \boldsymbol{\Lambda},可能是因为优化器觉得\boldsymbol{P}\boldsymbol{Q} \neq \boldsymbol{\Lambda}更好,所以并没有什么关系。
October 10th, 2023
hello,从感受野的角度来看,迭代的过程,x_t要看到x_0,不是整个不动点迭代要跑t个step吗,这和原来的rnn比,为什么能快呢?
可以用 prefix sum 來並行運算 RNN,這樣就比原先衹能迭代運算 RNN 要快很多。
线性RNN是可以并行的(参考 https://kexue.fm/archives/9554 ),理想情况下复杂度是\mathcal{O}(\log t),所以总的复杂度再乘以迭代步数,因此时间上有可能少于直接计算非线性RNN。
October 18th, 2023
这个(16)是怎么推导出来的,有大神能解释一下细节吗
所有的中间变量都用上一步的迭代结果计算,然后用两端减去一阶近似的方式去构造最终的迭代格式,保证^{(n-1)}替换为^{(n)}时,退化为原本的非线性RNN。
November 12th, 2023
令 \mathbf{P} 爲 learnable orthogonal matrix 即可,比如可以用 expRNN 裡提出的方法。
这样计算复杂度又更大了?
還有一個辦法,在公式 (3) 中,用 \mathbf{A} = \mathbf{F}\mathbf{\Lambda}\mathbf{F}^{\ast} 替代 \mathbf{A} = \mathbf{P}\mathbf{\Lambda}\mathbf{P}^{-1},其中 \mathbf{F} 是 DFT 矩陣。從 signal processing 的角度來看,\mathbf{P}^{-1} 和 \mathbf{P} 的作用是對 \mathbf{x}_{t-1} 做 discrete orthogonal transform 和 inverse transform。用 DFT 矩陣替代之後,可以用 FFT 來加速運算。
这样的话感觉干脆换用基于FFT的模型得了(捂脸)~
說不定 FFT based Linear RNN 確實可以
還真有這樣的模型(https://openreview.net/forum?id=IxmWsm4xrua)
这个就是广义的FFT呀,并不是反过来用FFT来加速RNN
不會。我最近在試類似的方法,時間複雜度和空間複雜度僅比O(NlogN)大一點點。
December 7th, 2023
有个不太懂的是,这种技术可以用于RNN的并行推理吗?还是只能用于并行训练
理论上可以,但实际上有点难,因为无法提前预知结果长度。
September 2nd, 2024
想问一下以(8)式为例,从step=n-1到step=n,右式的x_{t-1}^{n-1}如何得到呢?前一个式子得到的是x_{t}^{n-1}-Ax_{t-1}^{n-1},如果想要单独求出x_{t-1}^{n-1}是不是还需要类似于递推或者开一个大数组从x_{1}^{n-1}开始往后一个一个计算+存储?
如果是这样的话,对序列的初始值x_1,(9)式有:x_1^{0}=tanh(u_1)(因为x_0 \triangleq 0),那就意味着x_1是固定且无损的,不会随着迭代而更新,导致后面所有的x_{t}^{n}都无法更新了?
第一步,x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t),这就是个线性RNN,可以并行求出x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots;
第二步,x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)},其中所有的x_1^{(1)},x_2^{(1)},x_3^{(1)},\cdots已经在上一步求出,因此也可以并行求出x_1^{(2)},x_2^{(2)},x_3^{(2)},\cdots;
依此类推~
嗯嗯,这个逻辑乍一看很显然,但是感觉细想还是不太对。n迭代过程中x^{(n)}能够产生更新的原因还是因为真值在随着n的迭代的过程中在跟着t进行传递吧,这样理论上n=t才可以收敛?
第一步x_1^{(1)}=tanh(u_1)已经是真值x_1了,而且这个真值每一步迭代都是这样的形式,不会改变(x_1是首项,即x_0 \triangleq 0,因此x_1的递推式和后面的不同);第二步这个真值根据x_2^{(2)}-Ax_1^{(2)}=tanh(Ax_1^{(1)}+u_2)-Ax_1^{(1)},由于x_1^{(1)}=x_1^{(2)},所以x_2^{(2)}=tanh(Ax_1^{(1)}+u_2)=tanh(Ax_1+u_2),也即正确的x_1在迭代第二步n=2的时候传递到了x_2处,使得x_2也获得了真值。
之后类似,整体的迭代的收敛性其实是因为真值在随着n根据t来逐步往后传的,如果真值传不到(n \ll t_{length}),误差会很大吧?
所以你的意思是想要表达这样做的误差可能很大,以及你为什么觉得误差大的一些看法?
如果这样的话,那这个问题确实无解了,要严格保证它收敛于原始非线性RNN还是要静心调试的,我更倾向于将它当成一种复杂线性RNN的构建方式,事后就当它线性RNN来用,保持训练和预测的一致性。