尽管Transformer类的模型已经攻占了NLP的多数领域,但诸如LSTM、GRU之类的RNN模型依然在某些场景下有它的独特价值,所以RNN依然是值得我们好好学习的模型。而对于RNN梯度的相关分析,则是一个从优化角度思考分析模型的优秀例子,值得大家仔细琢磨理解。君不见,诸如“LSTM为什么能解决梯度消失/爆炸”等问题依然是目前流行的面试题之一...

经典的LSTM

经典的LSTM

关于此类问题,已有不少网友做出过回答,然而笔者查找了一些文章(包括知乎上的部分回答、专栏以及经典的英文博客),发现没有找到比较好的答案:有些推导记号本身就混乱不堪,有些论述过程没有突出重点,整体而言感觉不够清晰自洽。为此,笔者也尝试给出自己的理解,供大家参考。

RNN及其梯度 #

RNN的统一定义为
ht=f(xt,ht1;θ)
其中ht是每一步的输出,它由当前输入xt和前一时刻输出ht1共同决定,而θ则是可训练参数。在做最基本的分析时,我们可以假设ht,xt,θ都是一维的,这可以让我们获得最直观的理解,并且其结果对高维情形仍有参考价值。之所以要考虑梯度,是因为我们目前主流的优化器还是梯度下降及其变种,因此要求我们定义的模型有一个比较合理的梯度。我们可以求得:
dhtdθ=htht1dht1dθ+htθ
可以看到,其实RNN的梯度也是一个RNN,当前时刻梯度dhtdθ是前一时刻梯度dht1dθ与当前运算梯度htθ的函数。同时,从上式我们就可以看出,其实梯度消失或者梯度爆炸现象几乎是必然存在的:当|htht1|<1时,意味着历史的梯度信息是衰减的,因此步数多了梯度必然消失(好比limn0.9n0);当|htht1|>1,因为这历史的梯度信息逐步增强,因此步数多了梯度必然爆炸(好比limn1.1n)。总不可能一直|htht1|=1吧?当然,也有可能有些时刻大于1,有些时刻小于1,最终稳定在1附近,但这样概率很小,需要很精巧地设计模型才行。

所以步数多了,梯度消失或爆炸几乎都是不可避免的,我们只能对于有限的步数去缓解这个问题。

消失还是爆炸? #

说到这里,我们还没说清楚一个问题:什么是RNN的梯度消失/爆炸?梯度爆炸好理解,就是梯度数值发散,甚至慢慢就NaN了;那梯度消失就是梯度变成零吗?并不是,我们刚刚说梯度消失是|htht1|一直小于1,历史梯度不断衰减,但不意味着总的梯度就为0了,具体来说,一直迭代下去,我们有
dhtdθ=htht1dht1dθ+htθ=htθ+htht1ht1θ+htht1ht1ht2ht2θ+
显然,其实只要htθ不为0,那么总梯度为0的概率其实是很小的;但是一直迭代下去的话,那么h1θ这一项前面的稀疏就是t1项的连乘htht1ht1ht2h2h1,如果它们的绝对值都小于1,那么结果就会趋于0,这样一来,dhtdθ几乎就没有包含最初的梯度h1θ的信息了,这才是RNN中梯度消失的含义:距离当前时间步越长,那么其反馈的梯度信号越不显著,最后可能完全没有起作用,这就意味着RNN对长距离语义的捕捉能力失效了。

说白了,你优化过程都跟长距离的反馈没关系,怎么能保证学习出来的模型能有效捕捉长距离呢?

几个数学公式 #

上面的文字都是一般性的分析,接下来我们具体RNN具体分析。不过在此之前,我们需要回顾几条数学公式,后面的推导中我们将多次运用到这几条公式:
tanhx=2σ(2x)1σ(x)=12(tanhx2+1)(tanhx)=1tanh2xσ(x)=σ(x)(1σ(x))
其中σ(x)=1/(1+ex)是sigmoid函数。这几条公式其实就是说了这么一件事:tanhxσ(x)基本上是等价的,它们的导数均可以用它们自身来表示。

简单RNN分析 #

首先登场的是比较原始的简单RNN(有时候我们确实直接称它为SimpleRNN),它的公式为:
ht=tanh(Wxt+Uht1+b)
其中W,U,b是待优化参数。看到这里很自然就能提出第一个疑问:为什么激活函数用tanh而不是更流行的relu?这是个好问题,我们很快就会回答它。

从上面的讨论中我们已经知道,梯度消失还是爆炸主要取决于|htht1|,所以我们计算
htht1=(1h2t)U
由于我们无法确定U的范围,因此|htht1|可能小于1也可能大于1,梯度消失/爆炸的风险是存在的。但有意思的是,如果|U|很大,那么相应地ht就会很接近1或-1,这样(1h2t)U反而会小,事实上可以严格证明:如果固定ht10,那么(1h2t)U作为U的函数是有界的,也就是说不管U等于什么,它都不超过一个固定的常数。

这样一来,我们就能回答为什么激活函数要用tanh了,因为激活函数用tanh后,对应的梯度htht1是有界的,虽然这个界未必是1,但一个有界的量不超过1的概率总高于无界的量,因此梯度爆炸的风险更低。相比之下,如果用relu激活的话,它在正半轴的导数恒为1,此时htht1=U是无界的,梯度爆炸风险更高。

所以,RNN用tanh而不是relu的主要目的就是缓解梯度爆炸风险。当然,这个缓解是相对的,用了tanh依然有爆炸的可能性。事实上,处理梯度爆炸的最根本方法是参数裁剪或梯度裁剪,换句话说我人为地把U给裁剪到[1,1]内,那不就可以保证梯度不爆了吗?当然,又有读者会问,既然裁剪可以解决问题,那么是不是可以用relu了?确实是这样子,配合良好的初始化方法和参数/梯度裁剪方案,relu版的RNN也可以训练好,但是我们还是愿意用tanh,这还是因为它对应的htht1有界,要裁剪也不用裁剪得太厉害,模型的拟合能力可能会更好。

LSTM的结果 #

当然,裁剪的方式虽然也能work,但终究是无奈之举,况且裁剪也只能解决梯度爆炸问题,解决不了梯度消失,如果能从模型设计上解决这个问题,那么自然是最好的。传说中的LSTM就是这样的一种设计,真相是否如此?我们马上来分析一下。

LSTM的更新公式比较复杂,它是:
ft=σ(Wfxt+Ufht1+bf)it=σ(Wixt+Uiht1+bi)ot=σ(Woxt+Uoht1+bo)ˆct=tanh(Wcxt+Ucht1+bc)ct=ftct1+itˆctht=ottanh(ct)
我们可以像上面一样计算htht1,但从ht=ottanh(ct)可以看出分析ct就等价于分析ht,而计算ctct1显得更加简单一些,因此我们往这个方向走。

同样地,我们先只关心1维的情形,这时候根据求导公式,我们有
ctct1=ft+ct1ftct1+ˆctitct1+itˆctct1
右端第一项ft,也就是我们所说的“遗忘门”,从下面的论述我们可以知道一般情况下其余三项都是次要项,因此ft是“主项”,由于ft在0~1之间,因此就意味着梯度爆炸的风险将会很小,至于会不会梯度消失,取决于ft是否接近于1。但非常碰巧的是,这里有个相当自洽的结论:如果我们的任务比较依赖于历史信息,那么ft就会接近于1,这时候历史的梯度信息也正好不容易消失;如果ft很接近于0,那么就说明我们的任务不依赖于历史信息,这时候就算梯度消失也无妨了。

所以,现在的关键就是看“其余三项都是次要项”这个结论能否成立。后面的三项都是“一项乘以另一项的偏导”的形式,而且求偏导的项都是σtanh激活,前面在回顾数学公式的时候说了σtanh基本上是等价的,因此后面三项是类似的,分析了其中一项就相当于分析了其余两项。以第二项为例,代入ht1=ot1tanh(ct1),可以算得
ct1ftct1=ft(1ft)ot1(1tanh2ct1)ct1Uf
注意到ft,1ft,ot1,都是在0~1之间,也可以证明|(1tanh2ct1)ct1|<0.45,因此它也在-1~1之间。所以ct1ftct1就相当于1个Uf乘上4个门,结果会变得更加小,所以只要初始化不是很糟糕,那么它都会被压缩得相当小,因此占不到主导作用。跟简单RNN的梯度(6)相比,它多出了3个门,所以这个变化说白点就是:1个门我压不垮你,多来几个门还不行么?

剩下两项的结论也是类似的:
ˆctitct1=it(1it)ot1(1tanh2ct1)ˆctUiitˆctct1=(1ˆc2t)ot1(1tanh2ct1)itUc

所以,后面三项的梯度带有更多的“门”,一般而言乘起来后会被压缩的更厉害,因此占主导的项还是ftft在0~1之间这个特性决定了它梯度爆炸的风险很小,同时ft表明了模型对历史信息的依赖性,也正好是历史梯度的保留程度,两者相互自洽,所以LSTM也能较好地缓解梯度消失问题。因此,LSTM同时较好地缓解了梯度消失/爆炸问题,现在我们训练LSTM时,多数情况下只需要直接调用Adam等自适应学习率优化器,不需要人为对梯度做什么调整了。

当然,这些结果都是“概论”,你非要构造一个会梯度消失/爆炸的LSTM来,那也是能构造出来的。此外,就算LSTM能缓解这两个问题,也是在一定步数内,如果你的序列很长,比如几千上万步,那么该消失的还会消失。毕竟单靠一个向量,也缓存不了那么多信息啊~

顺便看看GRU #

在文章结束之前,我们顺便对LSTM的强力竞争对手GRU也做一个分析。GRU的运算过程为:
zt=σ(Wzxt+Uzht1+bz)rt=σ(Wrxt+Urht1+br)ˆht=tanh(Whxt+Uh(rtht1)+bc)ht=(1zt)ht1+ztˆht
还有个更极端的版本是将rt,zt合成一个:
rt=σ(Wrxt+Urht1+br)ˆht=tanh(Whxt+Uh(rtht1)+bc)ht=(1rt)ht1+rtˆht
不管是哪一个,我们发现它在算ˆht的时候,ht1都是先乘个rt变成rtht1,不知道读者是否困惑过这一点?直接用ht1不是更简洁更符合直觉吗?

首先我们观察到,而h0一般全零初始化,ˆht则因为tanh激活,因此结果必然在-1~1之间,所以作为ht1ˆht的加权平均的ht也一直保持在-1~1之间,因此ht本身就有类似门的作用。这跟LSTM的ct不一样,理论上ct是有可能发散的。了解到这一点后,我们再去求导:
htht1=1ztzt(1zt)ht1Uz+zt(1zt)ˆhtUz+(1ˆh2t)rt(1+(1rt)ht1Ur)ztUh
其实结果跟LSTM的类似,主导项应该是1zt,但剩下的项比LSTM对应的项少了1个门,因此它们的量级可能更大,相对于LSTM的梯度其实更不稳定,特别是rtht1这步操作,虽然给最后一项引入了多一个门rt,但也同时引入了多一项1+(1rt)ht1Ur,是好是歹很难说。总体相对而言,感觉GRU应该会更不稳定,比LSTM更依赖于好的初始化方式

针对上述分析结果,个人认为如果沿用GRU的思想,又需要简化LSTM并且保持LSTM对梯度的友好性,更好的做法是把rtht1放到最后:
zt=σ(Wzxt+Uzht1+bz)rt=σ(Wrxt+Urht1+br)ˆct=tanh(Whxt+Uhht1+bc)ct=(1zt)ct1+ztˆctht=rtct
当然,这样需要多缓存一个变量,带来额外的显存消耗了。

文章总结概述 #

本文讨论了RNN的梯度消失/爆炸问题,主要是从梯度函数的有界性、门控数目的多少来较为明确地讨论RNN、LSTM、GRU等模型的梯度流情况,以确定其中梯度消失/爆炸风险的大小。本文属于闭门造车之作,如有错漏,请读者海涵并斧正。

转载到请包括本文地址:https://spaces.ac.cn/archives/7888

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Nov. 13, 2020). 《也来谈谈RNN的梯度消失/爆炸问题 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/7888

@online{kexuefm-7888,
        title={也来谈谈RNN的梯度消失/爆炸问题},
        author={苏剑林},
        year={2020},
        month={Nov},
        url={\url{https://spaces.ac.cn/archives/7888}},
}