基于Amos优化器思想推导出来的一些“炼丹策略”
By 苏剑林 | 2022-11-22 | 38059位读者 |如果将训练模型比喻为“炼丹”,那么“炼丹炉”显然就是优化器了。据传AdamW优化器是当前训练神经网络最快的方案,这一点笔者也没有一一对比过,具体情况如何不得而知,不过目前做预训练时多数都用AdamW或其变种LAMB倒是真的。然而,正如有了炼丹炉也未必能炼出好丹,即便我们确定了选择AdamW优化器,依然有很多问题还没有确定的答案,比如:
1、学习率如何适应不同初始化和参数化?
2、权重衰减率该怎么调?
3、学习率应该用什么变化策略?
4、能不能降低优化器的显存占用?
尽管在实际应用时,我们大多数情况下都可以直接套用前人已经调好的参数和策略,但缺乏比较系统的调参指引,始终会让我们在“炼丹”之时感觉没有底气。在这篇文章中,我们基于Google最近提出的Amos优化器的思路,给出一些参考结果。
基础回顾 #
Amos优化器出自Google最近的论文《Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale》,它对上述几个问题都推导了比较完整的推导,并通过实验证实了它的有效性。然而,原论文的推导实在是不好读,各种记号和估计都过于随意,给人很“凌乱”感觉。不过好在Amos的思想还不算复杂,我们可以借用一下。
在开始推导之前,我们不妨先回顾一下对于上述几个问题,现有的解决方案是怎样的。
首先,第一个问题,大家可能不大理解“初始化”和“参数化”分别是什么含义,其实这就是模型权重的两种设置方式,常见的就是一个n×n的矩阵,一般用“均值为0、方差为1/n”的方式初始化,详细介绍可以参考笔者之前《从几何视角来理解模型参数的初始化策略》、《浅谈Transformer的初始化、参数化与标准化》。从“方差为1/n”我们就可以看到,不同参数有着不同的尺度(或者说数量级),如果我们用同一个学习率更新所有参数,那么就会导致每个参数的更新幅度不一样。这个问题笔者觉得比较优雅的解决方案就是LAMB优化器,它每次更新的模长直接取决于参数本身的模长,学习率只是用来描述相对更新量的大小。
至于权重衰减率问题,至少在预训练领域,笔者观察到的是都是沿用最早的选择0.01,没有发现去调整该参数的工作。而对于学习率变化策略,大家都知道应该要将学习率慢慢降到零,但具体应该选用什么什么下降策略,暂时也没有太多的理论指导,多数结果也只是实验总结出来的。最后,关于节省显存问题,比较经典的工作就是AdaFactor优化器,笔者之前在《AdaFactor优化器浅析(附开源实现)》也有过介绍。降低优化器显存占用的主要就两个思路,一是去掉动量,二是对二阶矩做低秩分解,Amos本质上也是沿用了这两个思路。
问题设置 #
本文主要关心开头的前三个问题,希望能够推导出一些“即插即用”的结果。首先,我们将优化器的更新规则简写成:
θt+1=θt−αtut
其实θt,θt+1分别代表t,t+1时刻的参数值,ut代表t时刻的更新向量(依赖于任务和数据),而标量αt>0(向量的每个元素都大于0)代表t时刻的学习率。
自AdamW起,主流优化器都倾向于把权重衰减(Weight Decay)项从ut中独立出来,即
θt+1=θt−(αtut+ρtθt)
其中ρt>0是权重衰减率。本文的主要任务,就是希望能解决αt和ρt该怎么设置的问题。
权重衰减 #
我们知道,权重衰减也好,L2正则也好,它本身是跟训练目标无关的,它只是一个辅助项,目的是提高模型的泛化能力。既然是辅助,那么一个基本的要求就是它不应该“喧宾夺主”,为此,我们不妨加入一个限制:
O(α2t)=O(ρt)
也就是说,在整个更新过程中,权重衰减带来的更新量始终要比目标相关的更新量高一阶,由于αt,ρt基本上都是小于1的,所以更高阶意味着更小。
设优化的参数终点是θ∗,我们记εt=θt−θ∗,根据更新规则可以得到
‖εt+1‖2=‖θt+1−θ∗‖2=‖θt−(αtut+ρtθt)−θ∗‖2≈‖εt‖2−2αtut⋅εt+(α2t‖ut‖2−2ρtθt⋅εt)
最后的近似只保留了不超过O(α2t)的项。
很明显,‖εt‖是当前结果与终点的距离,它自然是越小越好,因此我们自然也希望每一步的更新都能缩小这个距离,即‖εt+1‖<‖εt‖。而我们看式(4),−2αtut⋅εt可正可负,如果它为负就有助于实现‖εt+1‖<‖εt‖,但是α2t‖ut‖2必然是正的,它是不利于实现‖εt+1‖<‖εt‖,不过在引入权重衰减后,多出了一项−2ρtθt⋅εt,如果这一项能抵消掉α2t‖ut‖2的负面作用,那么权重衰减的引入就不仅能增强泛化能力,还有利于模型收敛了。
可行分析 #
所以,接下来的事情,我们就是要考察
α2t‖ut‖2=2ρtθt⋅εt
的可行性。所谓可行性,就是θt⋅εt能否大于0,只有它大于0,左右两端才有可能相等。利用εt的定义我们得到θt=εt+θ∗,于是
θt⋅εt=(εt+θ∗)⋅εt=‖εt‖2+θ∗⋅εt
注意θ∗是我们的目标,是一个固定的点,而εt是当前时刻与目标的差异向量,两者一般来说没什么必然的相关性,于是我们可以近似认为它们是高维空间中两个随机向量。根据《n维空间下两个随机向量的夹角分布》,我们知道高维空间中两个随机向量几乎都是垂直的,于是θ∗⋅εt≈0,即θt⋅εt≈‖εt‖2。当然,如果不放心,还可以引入一个参数q:
θt⋅εt≈q‖εt‖2
此时式(5)就变成了
α2t‖ut‖2≈2ρtq‖εt‖2
两端都大于0,因此式(5)是有可能成立的。
渐近估计 #
如果式(5)成立,那么式(4)就简化为了‖εt+1‖2≈‖εt‖2−2αtut⋅εt=‖εt‖2−2αt‖ut‖‖εt‖cos(ut,εt)
我们说了ut代表的是任务相关的更新量,平均来说它必然是有利于任务的(否则原来的优化器就是有缺陷的了),所以平均来说应该有cos(ut,εt)>0。这里我们进一步假设,存在一个p>0,使得cos(ut,εt)∼p,于是我们有
‖εt+1‖2≈‖εt‖2−2αtp‖ut‖‖εt‖
根据近似(8)我们有αt‖ut‖‖εt‖≈√2ρtq‖εt‖2,代入上式得到
‖εt+1‖2≈‖εt‖2(1−2p√2ρtq)≈‖εt‖2exp(−2p√2ρtq)
一步一步往前递推,可以得到
‖εt‖2≈‖ε0‖2exp(−2t−1∑i=1p√2ρiq)
可以看出右端的指数必然是单调递减的,它是一个衰减函数。现在我们再看近似(8),它有两个参数αt和ρt要调,但只有一个(近似)等式。为了使αt和ρt能够同等程度地衰减,我们设2ρtq≈λ2‖εt‖2,于是解得
αt≈λ‖εt‖2‖ut‖≈λ‖ε0‖2‖ut‖exp(−2t−1∑i=1p√2ρiq)ρt≈λ2‖εt‖22q≈λ2‖ε0‖22qexp(−2t−1∑i=1p√2ρiq)
这就是本文推出的αt,ρt的变化规律。当然,变化规律是有了,可是还有四个参数λ,‖ε0‖,p,q要确定,其中q相对来说比较简单,直接设q=1问题也不大,但即便这样还有三个参数要确定。
尺度预判 #
根据定义,‖ε0‖=‖θ0−θ∗‖,也就是初始化参数与目标参数的距离,可以理解为参数的变化尺度,它有几种不同的情况。
第一种,参数是矩阵乘法核,比如全连接层、卷积层的kernel矩阵,它们的初始化一般是“均值为0、方差为σ2”(σ取决于shape)的随机初始化,这样如果θ∈Rk,那么我们就可以估算出‖θ0‖2≈kσ2。另外,这类参数有一个特点,就是在合理的初始化下,训练完成后参数的均值方差也不会有太大变化,至少量级是一致的,因此也可以认为‖θ∗‖2≈kσ2,而因为初始化是随机的,所以θ0⋅θ∗≈0,因此
‖ε0‖2=‖θ0−θ∗‖2=‖θ0‖2+‖θ∗‖2−2θ0⋅θ∗≈2kσ2
第二种,参数是加性偏置项,比如全连接层、卷积层的bias向量,以及Normalization层的β向量,这些参数一般是“全零初始化”,所以‖ε0‖2=‖θ∗‖2,如果我们根据经验预测训练好的模型偏置项都在±σ附近,那么也可以估计出‖θ∗‖2≈kσ2,Amos原论文取了σ=0.5。最后还有Normalization层的γ向量,它一般是“全1初始化”,训练完成后也是在1附近,不妨假设误差为±σ,那么也可以估算出‖θ∗‖2≈kσ2。这里的k都是指向量维度。
可以看出,‖ε0‖2的结果都有一个共性,那就是都可以写成kσ2,其中σ是我们对参数变化尺度的一个预判。乘性矩阵的σ可以直接取初始化的标准差,加性偏置或者γ向量可以直接简单地取σ=0.5,或者有其他特殊参数的再做特殊处理。
分离尺度 #
现在我们来看完整的更新量,根据式(13),有
αtut≈λ‖ε0‖2×ut‖ut‖×exp(−2t−1∑i=1p√2ρiq)
其中ut‖ut‖是一个单位向量,控制更新方向,exp部分是一个衰减项,我们可以先不管它,所以更新量的模长由λ‖ε0‖2控制。
回到文章开头的第一个问题“学习率如何适应不同初始化和参数化?”,很明显,直观想法应该就是变化尺度大的参数每一步的更新量应该更大,或者直接简单地正比于变化尺度,而变化尺度我们刚才估计了,可以用‖ε0‖来描述,所以我们认为应该有λ‖ε0‖2=α0‖ε0‖,其中α0是全局的初始学习率。反过来解得λ=α0/‖ε0‖,代入式(13)得到
αt≈α0‖ε0‖‖ut‖exp(−2t−1∑i=1p√2ρiq),ρt≈α202qexp(−2t−1∑i=1p√2ρiq)
其中α0代表了每一步的相对更新幅度(全局学习率),这一步没啥推导空间了,一般取10−3左右就行,如果任务简单也可以取到10−2;‖ε0‖在上一节已经做了估计,大概是√kσ,σ代表参数平均变化尺度,不同参数不一样,我们正是通过它把参数尺度显式地分离了出来,从而达到了自适应参数尺度的效果(更新量正比σ)。特别地,如果将上式的‖ε0‖换成‖θt‖,那么就是LAMB优化器。从这里也可以看出,如果θ的初始化均值不是0(像γ向量),用‖θt‖替代‖ε0‖是会有问题的,所以LAMB的做法是直接不对这些参数的更新量进行变换(即保留原来的更新规则)。
解析近似 #
其实目前的结果已经适合编程实现了,只是参数p不好调罢了。为了进一步看出参数p是怎么影响衰减函数的,我们可以进一步求出ρt的解析近似!
在式(16)的ρt两边乘以2q,然后两边开平方,得到
将指数的求和t−1∑i=1p√2ρiq记为St,那么上式就对应差分方程
St−St−1p≈α0exp(−St−1)⇒St+1−St≈α0pexp(−St)
此时衰减函数就是exp(−2St)。为了求渐近近似,我们用导数代替差分(参考《差分方程的摄动法》),得到
dStdt≈α0pexp(−St)
这是个简单的微分方程,可以解得(结合S0=0)
exp(−2St)≈1(α0pt+1)2
这就是衰减函数的显式解,表明超参数应该按照步数的平方反比衰减,代入式(16)后的完整结果是
αt≈α0‖ε0‖‖ut‖1(α0pt+1)2,ρt≈α202q1(α0pt+1)2
这个显式解不但能让编程实现更方便,还使得p的含义更为清晰。比如我们希望学习率在T步后就降低为原来的一半,那么就有(α0pT+1)2=2,从中解得
α0p=√2−1T
至于T应该是多少,这依赖于任务难度和数据量,也没有太大推导空间了。
动态收敛 #
上述讨论的假设是存在常数p>0,使得cos(ut,εt)∼p,这可以理解为模型按照固定的速度收敛,这在实际中很难成立,更常见的是越接近训练的后期,收敛速度相对来说越慢。为此,我们可以进一步假设p是步数t的函数pt,这样一来,前面的推导大体上还是成立,只不过相应的常数p要换成带下标的pi:
√2ρtq≈α0exp(−t−1∑i=1pi√2ρiq)
重复上一节的推导,我们得到
St−St−1pt≈α0exp(−St−1)⇒St+1−St≈α0ptexp(−St)
近似的微分方程就是
dStdt≈α0ptexp(−St)
积分的结果是
exp(−St)≈1α0∫t0pτdτ+1
但现在多了一个pt需要确定。为了降低调参成本,我们不妨假设收敛的下降速度跟‖εt‖的下降速度一致,而根据式(12),‖εt‖的衰减函数就是exp(−St),所以我们设pt=p0exp(−St),代入上式得到
exp(−St)≈1α0p0∫t0exp(−Sτ)dτ+1
这本质就是一个简单的微分方程,容易解得
exp(−2St)≈12α0p0t+1
代入式(16)之后,得到
αt≈α0‖ε0‖‖ut‖12α0p0t+1,ρt≈α202q12α0p0t+1
单看衰减策略,这正好是“逆时间衰减(Inverse Time Decay)”,也是学习率的常见衰减策略之一。理论上来说,这个结果在假设上比前面的式(20)更为合理。
文章小结 #
本文借鉴了Amos优化器的思路,推导了一些关于学习率和权重衰减率的结果(20)、(28),这些结果可以即插即用地应用到现有优化器中,能一定程度上简化调参难度。
转载到请包括本文地址:https://spaces.ac.cn/archives/9344
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Nov. 22, 2022). 《基于Amos优化器思想推导出来的一些“炼丹策略” 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9344
@online{kexuefm-9344,
title={基于Amos优化器思想推导出来的一些“炼丹策略”},
author={苏剑林},
year={2022},
month={Nov},
url={\url{https://spaces.ac.cn/archives/9344}},
}
November 23rd, 2022
苏神您好,今天面试遇到了个问题,什么样的函数能被用作激活函数,您有什么高见嘛
没什么高见。理论上非线性的都可以,连浮点误差都可以,参考 https://kexue.fm/archives/4647
February 20th, 2023
2ρtq≈λ‖ϵt‖2 这个为什么会这么设呀,感觉是根据后面的内容才强行设的这个值
文中有解释,是“为了使αt和ρt能够同等程度地衰减”,也就是衰减速度是一致的。至于你如果要问为什么要“使αt和ρt能够同等程度地衰减”,那就只能说这是一种直觉了。
意思是学习率和衰减率都和误差呈二次关系
August 19th, 2023
非常感谢您的文章!对我帮助很大,我有以下几个问题。
请问AdamW中的weight_decay应该怎么设置,pytorch默认是1e-2, 一些论文中也有说1e-4比较好。
1. 根据这篇论文似乎要设置为初始学习率的平方?我目前在训练的是ViT-H级别的图像生成模型,学习率固定为1e-4,所以weight decay应该设置为1e-8吗
2. 我在训练ViT时,AdamW lr=1e-4,weight decay=0,出现了梯度爆炸的情况,用grad norm=1.0能在梯度爆炸时裁剪避免参数更新,的确是有作用。我想问如果我使用weight decay=1e-4,是否能让梯度爆炸不出现
1、我看LLM的训练,不是0.01就是0.1吧,比如llama2就是0.1,似乎没看到更小的;
2、这个应该没法保证吧。
December 4th, 2023
很棒,但是不太理解。
优化过程是非凸的,一开始假设
‖εt+1‖<‖εt‖
更好?这样做合理吗?
但我们需要一个约等于的关系而不单单是不等关系。