模型优化漫谈:BERT的初始标准差为什么是0.02?
By 苏剑林 | 2021-11-08 | 90437位读者 |前几天在群里大家讨论到了“Transformer如何解决梯度消失”这个问题,答案有提到残差的,也有提到LN(Layer Norm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而综合的问题,它其实关联到挺多模型细节,比如“BERT为什么要warmup?”、“BERT的初始化标准差为什么是0.02?”、“BERT做MLM预测之前为什么还要多加一层Dense?”,等等。本文就来集中讨论一下这些问题。
梯度消失说的是什么意思? #
在文章《也来谈谈RNN的梯度消失/爆炸问题》中,我们曾讨论过RNN的梯度消失问题。事实上,一般模型的梯度消失现象也是类似,它指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零,而我们主要用的是基于梯度的优化器,所以梯度消失意味着我们没有很好的信号去调整优化前面的层。
换句话说,前面的层也许几乎没有得到更新,一直保持随机初始化的状态;只有比较靠近输出的层才更新得比较好,但这些层的输入是前面没有更新好的层的输出,所以输入质量可能会很糟糕(因为经过了一个近乎随机的变换),因此哪怕后面的层更新好了,总体效果也不好。最终,我们会观察到很反直觉的现象:模型越深,效果越差,哪怕训练集都如此。
解决梯度消失的一个标准方法就是残差链接,正式提出于ResNet中。残差的思想非常简单直接:你不是担心输入的梯度会消失吗?那我直接给它补上一个梯度为常数的项不就行了?最简单地,将模型变成
\begin{equation}y = x + F(x)\end{equation}
这样一来,由于多了一条“直通”路$x$,就算$F(x)$中的$x$梯度消失了,$x$的梯度基本上也能得以保留,从而使得深层模型得到有效的训练。
LN真的能缓解梯度消失? #
然而,在BERT和最初的Transformer里边,使用的是Post Norm设计,它把Norm操作加在了残差之后:
\begin{equation}x_{t+1} = \text{Norm}(x_t + F_t(x_t))\end{equation}
其实具体的Norm方法不大重要,不管是Batch Norm还是Layer Norm,结论都类似。在文章《浅谈Transformer的初始化、参数化与标准化》中,我们已经分析过这种Norm结构,这里再来重复一下。
在初始化阶段,由于所有参数都是随机初始化的,所以我们可以认为$x$与$F(x)$是两个相互独立的随机向量,如果假设它们各自的方差是1,那么$x+F(x)$的方差就是2,而$\text{Norm}$操作负责将方差重新变为1,那么在初始化阶段,$\text{Norm}$操作就相当于“除以$\sqrt{2}$”:
\begin{equation}x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}}\end{equation}
递归下去就是
\begin{equation}\begin{aligned}
x_l =&\, \frac{x_{l-1}}{\sqrt{2}} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\
=&\, \frac{x_{l-2}}{2} + \frac{F_{l-2}(x_{l-2})}{2} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\
=&\, \cdots \\
=&\,\frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \cdots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}}
\end{aligned}\end{equation}
我们知道,残差有利于解决梯度消失,但是在Post Norm中,残差这条通道被严重削弱了,越靠近输入,削弱得越严重,残差“名存实亡”。所以说,在Post Norm的BERT模型中,LN不仅不能缓解梯度消失,它还是梯度消失的“元凶”之一。
那我们为什么还要加LN? #
那么,问题自然就来了:既然LN还加剧了梯度消失,那直接去掉它不好吗?
是可以去掉,但是前面说了,$x+F(x)$的方差就是2了,残差越多方差就越大了,所以还是要加一个Norm操作,我们可以把它加到每个模块的输入,即变为$x+F(\text{Norm}(x))$,最后的总输出再加个$\text{Norm}$就行,这就是Pre Norm结构,这时候每个残差分支是平权的,而不是像Post Norm那样有指数衰减趋势。当然,也有完全不加Norm的,但需要对$F(x)$进行特殊的初始化,让它初始输出更接近于0,比如ReZero、Skip Init、Fixup等,这些在《浅谈Transformer的初始化、参数化与标准化》也都已经介绍过了。
但是,抛开这些改进不说,Post Norm就没有可取之处吗?难道Transformer和BERT开始就带了一个完全失败的设计?
显然不大可能。虽然Post Norm会带来一定的梯度消失问题,但其实它也有其他方面的好处。最明显的是,它稳定了前向传播的数值,并且保持了每个模块的一致性。比如BERT base,我们可以在最后一层接一个Dense来分类,也可以取第6层接一个Dense来分类;但如果你是Pre Norm的话,取出中间层之后,你需要自己接一个LN然后再接Dense,否则越靠后的层方差越大,不利于优化。
其次,梯度消失也不全是“坏处”,其实对于Finetune阶段来说,它反而是好处。在Finetune的时候,我们通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。而梯度消失意味着越靠近输入层,其结果对最终输出的影响越弱,这正好是Finetune时所希望的。所以,预训练好的Post Norm模型,往往比Pre Norm模型有更好的Finetune效果,这我们在《RealFormer:把残差转移到Attention矩阵上面去》也提到过。
我们真的担心梯度消失吗? #
其实,最关键的原因是,在当前的各种自适应优化技术下,我们已经不大担心梯度消失问题了。
这是因为,当前NLP中主流的优化器是Adam及其变种。对于Adam来说,由于包含了动量和二阶矩校正,所以近似来看,它的更新量大致上为
\begin{equation}\Delta \theta = -\eta\frac{\mathbb{E}_t[g_t]}{\sqrt{\mathbb{E}_t[g_t^2]}}\end{equation}
可以看到,分子分母是都是同量纲的,因此分式结果其实就是$\mathcal{O}(1)$的量级,而更新量就是$\mathcal{O}(\eta)$量级。也就是说,理论上只要梯度的绝对值大于随机误差,那么对应的参数都会有常数量级的更新量;这跟SGD不一样,SGD的更新量是正比于梯度的,只要梯度小,更新量也会很小,如果梯度过小,那么参数几乎会没被更新。
所以,Post Norm的残差虽然被严重削弱,但是在base、large级别的模型中,它还不至于削弱到小于随机误差的地步,因此配合Adam等优化器,它还是可以得到有效更新的,也就有可能成功训练了。当然,只是有可能,事实上越深的Post Norm模型确实越难训练,比如要仔细调节学习率和Warmup等。
Warmup是怎样起作用的? #
大家可能已经听说过,Warmup是Transformer训练的关键步骤,没有它可能不收敛,或者收敛到比较糟糕的位置。为什么会这样呢?不是说有了Adam就不怕梯度消失了吗?
要注意的是,Adam解决的是梯度消失带来的参数更新量过小问题,也就是说,不管梯度消失与否,更新量都不会过小。但对于Post Norm结构的模型来说,梯度消失依然存在,只不过它的意义变了。根据泰勒展开式:
\begin{equation}f(x+\Delta x) \approx f(x) + \langle\nabla_x f(x), \Delta x\rangle\end{equation}
也就是说增量$f(x+\Delta x) - f(x)$是正比于梯度的,换句话说,梯度衡量了输出对输入的依赖程度。如果梯度消失,那么意味着模型的输出对输入的依赖变弱了。
Warmup是在训练开始阶段,将学习率从0缓增到指定大小,而不是一开始从指定大小训练。如果不进行Wamrup,那么模型一开始就快速地学习,由于梯度消失,模型对越靠后的层越敏感,也就是越靠后的层学习得越快,然后后面的层是以前面的层的输出为输入的,前面的层根本就没学好,所以后面的层虽然学得快,但却是建立在糟糕的输入基础上的。
很快地,后面的层以糟糕的输入为基础到达了一个糟糕的局部最优点,此时它的学习开始放缓(因为已经到达了它认为的最优点附近),同时反向传播给前面层的梯度信号进一步变弱,这就导致了前面的层的梯度变得不准。但我们说过,Adam的更新量是常数量级的,梯度不准,但更新量依然是数量级,意味着可能就是一个常数量级的随机噪声了,于是学习方向开始不合理,前面的输出开始崩盘,导致后面的层也一并崩盘。
所以,如果Post Norm结构的模型不进行Wamrup,我们能观察到的现象往往是:loss快速收敛到一个常数附近,然后再训练一段时间,loss开始发散,直至NAN。如果进行Wamrup,那么留给模型足够多的时间进行“预热”,在这个过程中,主要是抑制了后面的层的学习速度,并且给了前面的层更多的优化时间,以促进每个层的同步优化。
这里的讨论前提是梯度消失,如果是Pre Norm之类的结果,没有明显的梯度消失现象,那么不加Warmup往往也可以成功训练。
初始标准差为什么是0.02? #
喜欢扣细节的同学会留意到,BERT默认的初始化方法是标准差为0.02的截断正态分布,在《浅谈Transformer的初始化、参数化与标准化》我们也提过,由于是截断正态分布,所以实际标准差会更小,大约是$0.02/1.1368472\approx 0.0176$。这个标准差是大还是小呢?对于Xavier初始化来说,一个$n\times n$的矩阵应该用$1/n$的方差初始化,而BERT base的$n$为768,算出来的标准差是$1/\sqrt{768}\approx 0.0361$。这就意味着,这个初始化标准差是明显偏小的,大约只有常见初始化标准差的一半。
为什么BERT要用偏小的标准差初始化呢?事实上,这还是跟Post Norm设计有关,偏小的标准差会导致函数的输出整体偏小,从而使得Post Norm设计在初始化阶段更接近于恒等函数,从而更利于优化。具体来说,按照前面的假设,如果$x$的方差是1,$F(x)$的方差是$\sigma^2$,那么初始化阶段,$\text{Norm}$操作就相当于除以$\sqrt{1+\sigma^2}$。如果$\sigma$比较小,那么残差中的“直路”权重就越接近于1,那么模型初始阶段就越接近一个恒等函数,就越不容易梯度消失。
正所谓“我们不怕梯度消失,但我们也不希望梯度消失”,简单地将初始化标注差设小一点,就可以使得$\sigma$变小一点,从而在保持Post Norm的同时缓解一下梯度消失,何乐而不为?那能不能设置得更小甚至全零?一般来说初始化过小会丧失多样性,缩小了模型的试错空间,也会带来负面效果。综合来看,缩小到标准的1/2,是一个比较靠谱的选择了。
当然,也确实有人喜欢挑战极限的,最近笔者也看到了一篇文章,试图让整个模型用几乎全零的初始化,还训练出了不错的效果,大家有兴趣可以读读,文章为《ZerO Initialization: Initializing Residual Networks with only Zeros and Ones》。
为什么MLM要多加Dense? #
最后,是关于BERT的MLM模型的一个细节,就是BERT在做MLM的概率预测之前,还要多接一个Dense层和LN层,这是为什么呢?不接不行吗?
之前看到过的答案大致上是觉得,越靠近输出层的,越是依赖任务的(Task-Specified),我们多接一个Dense层,希望这个Dense层是MLM-Specified的,然后下游任务微调的时候就不是MLM-Specified的,所以把它去掉。这个解释看上去有点合理,但总感觉有点玄学,毕竟Task-Specified这种东西不大好定量分析。
这里笔者给出另外一个更具体的解释,事实上它还是跟BERT用了0.02的标准差初始化直接相关。刚才我们说了,这个初始化是偏小的,如果我们不额外加Dense就乘上Embedding预测概率分布,那么得到的分布就过于均匀了(Softmax之前,每个logit都接近于0),于是模型就想着要把数值放大。现在模型有两个选择:第一,放大Embedding层的数值,但是Embedding层的更新是稀疏的,一个个放大太麻烦;第二,就是放大输入,我们知道BERT编码器最后一层是LN,LN最后有个初始化为1的gamma参数,直接将那个参数放大就好。
模型优化使用的是梯度下降,我们知道它会选择最快的路径,显然是第二个选择更快,所以模型会优先走第二条路。这就导致了一个现象:最后一个LN层的gamma值会偏大。如果预测MLM概率分布之前不加一个Dense+LN,那么BERT编码器的最后一层的LN的gamma值会偏大,导致最后一层的方差会比其他层的明显大,显然不够优雅;而多加了一个Dense+LN后,偏大的gamma就转移到了新增的LN上去了,而编码器的每一层则保持了一致性。
事实上,读者可以自己去观察一下BERT每个LN层的gamma值,就会发现确实是最后一个LN层的gamma值是会明显偏大的,这就验证了我们的猜测~
希望大家多多海涵批评斧正 #
本文试图回答了Transformer、BERT的模型优化相关的几个问题,有一些是笔者在自己的预训练工作中发现的结果,有一些则是结合自己的经验所做的直观想象。不管怎样,算是分享一个参考答案吧,如果有不当的地方,请大家海涵,也请各位批评斧正~
转载到请包括本文地址:https://spaces.ac.cn/archives/8747
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Nov. 08, 2021). 《模型优化漫谈:BERT的初始标准差为什么是0.02? 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/8747
@online{kexuefm-8747,
title={模型优化漫谈:BERT的初始标准差为什么是0.02?},
author={苏剑林},
year={2021},
month={Nov},
url={\url{https://spaces.ac.cn/archives/8747}},
}
November 8th, 2021
请问从经验来讲,Warmup一般怎么设置呢
通常是总训练步数的10%,但也不一定,反正能让模型正常收敛就差不多了。
November 8th, 2021
看完就知道深度学习实打实的玄学了
November 8th, 2021
你说的那个群可以让我加一下吗?
入群方式已经在明显的地方摆了几年时间。
November 8th, 2021
最后一层的gamma不是偏小吗?
m = BertForMaskedLM.from_pretrained("bert-base-cased")
for i in range(12): print(i, m.bert.encoder.layer[i].output.LayerNorm.weight.mean().detach())
0 tensor(0.8939)
1 tensor(0.9423)
2 tensor(0.9327)
3 tensor(0.9260)
4 tensor(0.9723)
5 tensor(0.9820)
6 tensor(0.9508)
7 tensor(0.9138)
8 tensor(0.9270)
9 tensor(0.9286)
10 tensor(0.9289)
11 tensor(0.7497)
敢情我说了那么多都是白说了...
对于中文BERT base,结果是:
Embedding-Norm/gamma:0 mean: 0.88696283
Transformer-0-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.82883006
Transformer-0-FeedForward-Norm/gamma:0 mean: 0.9062193
Transformer-1-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.8794548
Transformer-1-FeedForward-Norm/gamma:0 mean: 0.95188266
Transformer-2-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.88048005
Transformer-2-FeedForward-Norm/gamma:0 mean: 0.91880983
Transformer-3-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.8744504
Transformer-3-FeedForward-Norm/gamma:0 mean: 0.94945145
Transformer-4-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.8591768
Transformer-4-FeedForward-Norm/gamma:0 mean: 0.9897261
Transformer-5-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.8541221
Transformer-5-FeedForward-Norm/gamma:0 mean: 0.9948392
Transformer-6-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.86367863
Transformer-6-FeedForward-Norm/gamma:0 mean: 0.96400976
Transformer-7-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.83335274
Transformer-7-FeedForward-Norm/gamma:0 mean: 0.9587536
Transformer-8-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.80936116
Transformer-8-FeedForward-Norm/gamma:0 mean: 0.9643423
Transformer-9-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.81411123
Transformer-9-FeedForward-Norm/gamma:0 mean: 0.96500415
Transformer-10-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.8703678
Transformer-10-FeedForward-Norm/gamma:0 mean: 0.9663982
Transformer-11-MultiHeadSelfAttention-Norm/gamma:0 mean: 0.868598
Transformer-11-FeedForward-Norm/gamma:0 mean: 0.81089574
MLM-Norm/gamma:0 mean: 2.5846736
现在知道什么叫做“MLM的最后一个LN层”了吗?
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
for name, layer in model.named_modules():
if 'LayerNorm' in name:
print(f"Name:{name}\tgamma:{layer.weight.mean()}")
=======Result=========
Name:bert.embeddings.LayerNorm gamma:0.8493084907531738
Name:bert.encoder.layer.0.attention.output.LayerNorm gamma:0.9584805369377136
Name:bert.encoder.layer.0.output.LayerNorm gamma:0.7558717727661133
Name:bert.encoder.layer.1.attention.output.LayerNorm gamma:0.875697672367096
Name:bert.encoder.layer.1.output.LayerNorm gamma:0.8695356249809265
Name:bert.encoder.layer.2.attention.output.LayerNorm gamma:0.867520809173584
Name:bert.encoder.layer.2.output.LayerNorm gamma:0.8513254523277283
Name:bert.encoder.layer.3.attention.output.LayerNorm gamma:0.8626930117607117
Name:bert.encoder.layer.3.output.LayerNorm gamma:0.8110942840576172
Name:bert.encoder.layer.4.attention.output.LayerNorm gamma:0.8386680483818054
Name:bert.encoder.layer.4.output.LayerNorm gamma:0.8397259712219238
Name:bert.encoder.layer.5.attention.output.LayerNorm gamma:0.8466060757637024
Name:bert.encoder.layer.5.output.LayerNorm gamma:0.8322343230247498
Name:bert.encoder.layer.6.attention.output.LayerNorm gamma:0.8275265693664551
Name:bert.encoder.layer.6.output.LayerNorm gamma:0.833625078201294
Name:bert.encoder.layer.7.attention.output.LayerNorm gamma:0.8220710158348083
Name:bert.encoder.layer.7.output.LayerNorm gamma:0.8104994893074036
Name:bert.encoder.layer.8.attention.output.LayerNorm gamma:0.8367538452148438
Name:bert.encoder.layer.8.output.LayerNorm gamma:0.8310534954071045
Name:bert.encoder.layer.9.attention.output.LayerNorm gamma:0.8171906471252441
Name:bert.encoder.layer.9.output.LayerNorm gamma:0.800506591796875
Name:bert.encoder.layer.10.attention.output.LayerNorm gamma:0.8407526016235352
Name:bert.encoder.layer.10.output.LayerNorm gamma:0.8172881007194519
Name:bert.encoder.layer.11.attention.output.LayerNorm gamma:0.8528356552124023
Name:bert.encoder.layer.11.output.LayerNorm gamma:0.633033812046051
Name:cls.predictions.transform.LayerNorm gamma:2.290243625640869
^^^^
苏神的意思是,如果没有最后一个额外的LN,那么最后一层gamma会偏大,而正常的BERT训练的时候是加上了额外的LN的,所以第12层LN是0.633并不大,而prediction层的gamma是2.29.
你要是想验证这个观点,应该是去掉额外的层,重新训练BERT,这样才会看到第12层的 gamma比较大
赞
November 9th, 2021
梯度消失,我理解先是越靠近输出的层,像relu这种激活函数,求导为0-1内,重复相乘导致梯度越来越小。然后反向传播误差就很小,然后前面的层的参数更新很小。
看了这篇文章的描述,“指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零”。
所以哪个是原因?靠近输入还是靠近输出的层梯度越小?
你的意思不也是越靠近输入层的梯度越小吗?梯度计算是又后往前使用链式法则算的,越靠近输入端这条链越长,可以简单理解成乘了越多的0-1之间的数,从而导致梯度越来越小
梯度是反向传播的(链式法则)。从$L+1$层到$L$层的梯度小于1,然后重复相乘,导致从$L$层到$1$层的梯度非常小。所以很明显,是靠近输入的层梯度消失。事实上,直观想一下就知道,输出层最靠近loss,与loss相关性最高,肯定是最不可能梯度消失的。
November 9th, 2021
说起几乎全零的权重,想起之前看到的一个蛮有趣的文章。
就是利用浮点数的不连续,那么在0附近的浮点数可以视作是个激活函数,
就做了一个没加激活函数的小浮点数层,最后居然还在一个小模型上训练成功了
这个我也印象颇深。
能提供下这篇论文的链接么。还有请教一个问题,在关于 warmUp 如何起作用的一节中,有下面的描述:
所以,如果Post Norm结构的模型不进行Wamrup,我们能观察到的现象往往是:loss快速收敛到一个常数附近,然后再训练一段时间,loss开始发散,直至NAN。
为啥loss在发散的过程中,Bp过程不能指导网络重新收敛,而会走入到 NAN 的情况,感觉直观上很难理解这个现象。
1、https://mp.weixin.qq.com/s/PBRzS4Ol_Zst35XKrEpxdw;
2、本文已经描述了这个现象的直观理解;
3、梯度下降为什么反而导致loss发散?这也不是什么太稀奇的现象,主要模型来到了过于崎岖的区域,轻微的扰动都导致严重的变化吧。
November 16th, 2021
关于Warmup这个问题,我还是没太理解。使用Adam作为优化器时,不同层参数的更新量已经被归一化到了$\eta$尺度,应该不会有后面的层更新快,前面的层更新慢的问题了吧?如果真的是前后更新步幅不一致导致模型不稳定,那是否有工作尝试了对前后层使用不同学习率呢?
大家都更新同样的步幅,但是由于梯度消失,前面的层对于更新量更不敏感,所以相对而言还是后面的层先收敛。
你可以看$(6)$式,就算每个参数的$\Delta x$相同,但是$\nabla_x f(x)$不同,如果说$f(x)$快速陷入了局部最优,那么这个过程是由$\nabla_x f(x)$的绝对值比较大的部分所主导的。
这个问题明白了,谢谢苏神解答!
November 19th, 2021
行外人。。我只能观摩观摩
December 14th, 2021
苏神,x + f(x)假定了x, f(x)独立且方差都为1, 那么归一化相当于除以了sqrt(2)=1.414 > 1结果是残差结构加重了梯度消失;
但是实际初始化过程,x, f(x)初始化标准差为0.02, 那么x + f(x)相当于除以sqrt(0.028)< 1,那么是否意味着残差效应增强了?
不是。
如果是Post Norm结构,$x$是Norm之后的结果,所以$x$的方差大致上还是1,然后我们是将$F(x)$的参数的标准差初始化为$0.02$,这不等于$F(x)$的标准差是0.02。
此外,本文不是说了,小的初始化一定程度上确实有利于缓解梯度消失,所以你的“残差效应增强”是不是这个意思?
December 26th, 2021
想請問前輩,關於Warmup有沒有推薦的理論或是分析研究,謝謝。
暂时没留意到,抱歉。