初探muP:超参数的跨模型尺度迁移规律
By 苏剑林 | 2025-03-13 | 15974位读者 |众所周知,完整训练一次大型LLM的成本是昂贵的,这就决定了我们不可能直接在大型LLM上反复测试超参数。一个很自然的想法是希望可以在同结构的小模型上仔细搜索超参数,找到最优组合后直接迁移到大模型上。尽管这个想法很朴素,但要实现它并不平凡,它需要我们了解常见的超参数与模型尺度之间的缩放规律,而muP正是这个想法的一个实践。
muP,有时也写μP,全名是Maximal Update Parametrization,出自论文《Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer》,随着LLM训练的普及,它逐渐已经成为了科学炼丹的事实标配之一。
方法大意 #
在接入主题之前,必须先吐槽一下muP原论文写得实在太过晦涩,并且结论的表达也不够清晰,平白增加了不少理解难度,所以接下来笔者尽量以一种(自认为)简明扼要的方式来复现muP的结论。
先说结论,muP主要研究超参数跨模型尺度的迁移规律。这里有几个关键词:
1、超参数,目前主要指学习率;
2、模型尺度,目前主要是模型宽度;
3、这里的核心是“迁移”。
请注意,muP不研究什么是最优的超参数,只研究最优超参数随着模型尺度的变化规律,所以我们需要在某个小模型上搜索最优的超参数组合,然后迁移到大模型上,这就是muP的使用场景和使用方法。
推导muP的原理是让模型的前向传播、反向传播、损失增量和特征变化都不随模型尺度的变化而发生明显变化:
1、具体做法是分析初始化的数量级,然后认为结论可以代表后续优化的规律;
2、说白了就是假设做好初始化,后面就会自动沿着正确的轨迹走(好的开始是成功的一大半?);
3、当然也可以给这个假设讲大数定律或中心极限定理的故事,但个人认为非必须。
前向传播 #
我们从前向传播开始讨论,因为这是相对简单且成熟的部分。首先,考虑线性层Y=XW,其中X∈Rb×din,W∈Rdin×dout。我们用RMS(Root Mean Square)来作为矩阵尺度的指标,例如
RMS(W)=√1dindoutdin∑i=1dout∑j=1W2i,j
我们知道,要让初始化阶段X的RMS跟Y的RMS大致相等(简称“稳定”),那么W要用:
LeCun初始化:“均值为0、方差为1/din”的随机初始化。
这已经算是深度学习的基础结论之一,所以不再展开推导,还不大了解的读者可以参考以往的《从几何视角来理解模型参数的初始化策略》、《浅谈Transformer的初始化、参数化与标准化》等博文。
接着,我们考虑非线性层Y=ϕ(XW),其中ϕ是Element-wise的激活函数。如果还是要维持X的RMS跟Y的RMS近似相等,那么结果会稍有不同,比如relu激活时我们得到
Kaiming初始化:“均值为0、方差为2/din”的随机初始化。
容易看出,Kaiming初始化跟LeCun初始化相比,只是方差相差一个(跟模型尺度无关的)常数2,可以证明其他激活函数的结果也类似。所以我们可以下一个结论:
fan_in初始化:要保证前向传播的稳定性,那么应该要用“均值为0、方差正比于1/din”的随机初始化。
这个结论也可以理解为“激活函数的影响是模型尺度无关的”,所以如果我们只想分析模型尺度的效应,那么可以忽略(Element-wise的)激活函数的存在,由LeCun初始化直接得到缩放规律∝1/din。
反向传播 #
现在我们继续分析反向传播(梯度),注意这里约定变量及其梯度具有相同的shape,那么可以算得
∂L∂W=X⊤(∂L∂Y⊗ϕ′(XW))∂L∂X=(∂L∂Y⊗ϕ′(XW))W⊤
第一个公式是当前层内参数的梯度,第二个公式则是该层往前传播的梯度,⊗是Hadamard积,ϕ′是ϕ的导函数。
注意到一个事实:我们常用的激活函数,其导数都可以被一个(尺度无关的)常数给Bound住,所以至少在数量级上我们可以写出
∂L∂W=X⊤(∂L∂Y⊗ϕ′(XW))∼X⊤∂L∂Y∂L∂X=(∂L∂Y⊗ϕ′(XW))W⊤∼∂L∂YW⊤
我们先来看第二个公式,跟Y=XW相比,它右端乘的矩阵变成了W⊤,那么按照上一节的结论,如果要保持反向传播的RMS稳定性,那么W的初始化就应该是:
fan_out初始化:“均值为0、方差为1/dout”的随机初始化。
当din≠dout时,前向传播和反向传播的要求就出现冲突,这时候有人提了一个折中策略:
Xavier初始化:“均值为0、方差为2/(din+dout)”的随机初始化。
这也叫“fan_avg初始化”,因为就是将din和dout简单代数平均了一下,其他平均方式也可以考虑,参考《初始化方法中非方阵的维度平均策略思考》。Xavier初始化看上去同时兼顾了前向和反向,但也可以说两者都没兼顾,更好的办法是设计模型让大部分参数都是方阵,如后面讨论的模型簇(???)。
损失增量 #
有了前向传播和反向传播的铺垫,我们就可以尝试分析损失函数的增量了。考虑W→W+ΔW时损失函数的变化量
ΔL=L(W+ΔW)−L(W)≈⟨∂L∂W,ΔW⟩F
这里的⟨⋅,⋅⟩F是Frobenius内积,即把矩阵展平成向量后算向量内积。考虑梯度下降ΔW=−η∂L∂W,这里η自然是学习率,结合式(4),我们有
ΔL≈−η‖∂L∂W‖2F∼−η‖X⊤∂L∂Y‖2F
事实上,这个式子已经告诉了我们同一个学习率η不能跨模型尺度使用的原因:
1、X⊤∂L∂Y是一个din×dout的矩阵;
2、‖X⊤∂L∂Y‖2F是din×dout个数的平方和;
3、X⊤∂L∂Y正好是前向和反向的乘积;
4、如果前向和反向都稳定,那么X⊤∂L∂Y每个元素都是O(1);
5、所以‖X⊤∂L∂Y‖2F就是O(dindout)。
第4点可能要多加评述一下。X⊤是一个din×b矩阵,∂L∂Y是一个b×dout矩阵,两者相乘就是dindout个b维向量对做内积,内积是b项求和,而损失L通常是对样本求平均(即包含了除以b操作),所以如果X⊤和∂L∂Y都是尺度无关的,那么它们乘起来基本也是尺度无关的【即RMS都是O(1)】。
最后的结论表明,如果我们直接将小模型的学习率用于大模型,那么对于足够大的模型,它的每一步损失增量就会随着参数尺度(即dindout)的变大而爆炸,这意味着没法复制小模型的收敛过程,甚至可能因为步子迈得太大导致无法收敛。
此时大家可能想到的一个做法是让η∝1/(dindout)来缩放ΔL,事实上这个想法已经跟上了muP的思路,但实际场景中由于前面说的前向和反向的不兼容性,导致第4点“如果前向和反向都稳定,那么X⊤∂L∂Y每个元素就是O(1)”不能总是成立,所以实际情况更为复杂一些,
模型假设 #
现在让我们考虑一个更接近实践的场景。我们的任务是训练一个Rdin↦Rdout的模型,其中din,dout是数据决定的,不可改变。开头我们就说了,muP旨在研究超参数随着模型尺度的缩放规律,所以一切固定不变的量,都相当于是常数或者说O(1),比如初始化方差为1/din,等价于说初始化方差为O(1)。
我们可以改变的是模型的架构、参数量等部分,但muP主要考虑宽度的规律,所以我们把模型的架构定一下。这里主要考虑的模型簇是:
Yin=XWinYout=NN(Yin,Θ)Z=YoutWout
其中:
1、X∈Rb×din(带上了batch size);
2、Win∈Rdin×d,Wout∈Rd×dout;
3、NN是任意Rd↦Rd的神经网络;
4、这里d其实就是我们常说的hidden size;
5、我们可以随意调大d,来提升模型的参数量和潜力;
6、muP就是想研究超参数关于d的变化规律。
更具体一点,这里我们考虑的NN是K层MLP:
Y0=YinYk+1=ϕ(YkWk+1)Yout=YK
这里Θ={W1,W2,⋯,WK},Wk∈Rd×d,即都是d×d的方阵,全都用fan_in初始化(等价地,也是fan_out初始化)。
补充一下,这里约定所有参数矩阵都是d×d方阵,纯粹是为了简化分析,并不是强制要求。因为这里真正的目的是假设NN的参数里没有尺度无关的形状,比如不允许d×64这样的形状,因为64是一个常数,但d×4d这样的形状是允许的,因为你不管fan_in、fan_out或fan_avg初始化,方差都是正比于1/d。
组装起来 #
确立后具体模型后,我们就可以把前面的结论都组装起来了。要更新的参数分为Win,Θ,Wout三部分,分别求梯度:
∂L∂Wout=Y⊤out∂L∂Z∂L∂Wk=∂Yout∂Wk⋅∂L∂Yout=∂Yout∂Wk⋅(∂L∂ZW⊤out)∂L∂Win=X⊤∂L∂Yin=X⊤(∂Yout∂Yin⋅∂L∂Yout)=X⊤(∂Yout∂Yin⋅(∂L∂ZW⊤out))
这里的⋅运算需要稍微解释一下:Yin,Yout都是一个矩阵,所以∂Yout∂Yin原则上是一个四阶张量,链式法则∂Yout∂Yin⋅∂L∂Yout实际是高阶张量的乘法,但这里不打算展开介绍了,所以简单用一个⋅代替,读者只需要知道它是矩阵乘法的一般推广就行。
现在来观察规律:
1、三个式子都有∂L∂Z;
2、后两式都有W⊤out;
3、Wk里都是方阵,∂Yout∂Yin和∂Yout∂Wk都是稳定的【RMS是O(1)】;
4、如果Win也用fan_in初始化,那么Yout也是稳定的;
5、要想∂L∂ZW⊤out稳定,那么初始化方差是1/dout,但dout是尺度无关的,相当于常数。
这样一来:
1、∂L∂Wout的RMS是O(1),‖∂L∂Wout‖2F是d×dout个数平方和,所以大小是O(d×dout),别忘了dout是常数,所以实际上就是O(d),于是为了得到O(1)的ΔL,它的学习率要满足ηout∝1/d;
2、‖∂L∂Wk‖2F是d2个数求和,∂Yout∂Wk和∂L∂Z的RMS都是O(1),我们直接将Wout的初始化方差设为∝1/d2,那么∂L∂Wk的RMS就是O(1/d),平方求和后就正好是O(1),因此学习率不用变化;
3、此时∂L∂Win的RMS也是O(1/d),但‖∂L∂Win‖2F只是din×d个数平方和,所以结果是O(1/d)的,为了得到O(1)的ΔL,学习率反而需要放大d倍来抵消这个影响,即ηin∝d。
特征变化 #
以上结果是没有问题的,但仔细思考我们会发现推导过程的一个问题:上面的第2、3点,都建立在“我们直接将Wout的初始化方差设为∝1/d2”这个设置上,然而这个设置目前来说并没有直接的依据。如果不对此进一步解释,那么推导过程还是不够完备的。
事实上,单看ΔL=O(1)这个要求的话,确实是无法排除其他选择的可能性的,比如Wout的初始化方差设为∝1/d,此时∂L∂Wk的RMS是O(1/√d),平方求和后是O(d),那么只要学习率η∝1/d同样可以实现ΔL=O(1)。因此,为了解释“Wout的初始化方差设为∝1/d2”的必要性,那么就需要引入新的条件。
损失函数L是模型的一个宏观指标,或者说外部指标,单看它的变化已经不足以解释全部结果了,那么就需要细化到模型内部了。具体来说,我们希望模型每一层的输出(通常也称为特征,有时也称激活值)变化量也具有尺度不变性。比如线性层Yk=Yk−1Wk,参数Wk→Wk+ΔWk带来的输出变化是
ΔYk=Yk−1(Wk+ΔWk)−Yk−1Wk=Yk−1ΔWk
注意Yk−1∈Rb×d,ΔWk∈Rd×d,所以Yk−1ΔWk就是b×d个d维向量对的内积。注意这里ΔWk是精心设计的更新量,它不大可能跟初始化那样跟Yk−1是独立的,所以“d维向量对的内积”更有可能是O(d)(d维内积共有d项求和),因此如果ΔYk−1的RMS是O(1),那么可以认为ΔYk的RMS将是O(d×RMS(ΔWk))。
于是,为了让ΔYk的RMS是O(1),我们得到了对ΔWk的一个额外要求:
RMS(ΔWk)=O(1/d)
结合ΔWk=−η∂L∂Wk和ΔL=O(1),我们就可以得到“Wout的初始化方差设为∝1/d2”的结果。
(注:这一节依赖于 @Chenyu Zheng 的指点,非常感谢!)
Adam版本 #
以上就是SGD的muP,对于Adam,我们通常用SignSGD近似做数量级分析:
1、ΔW=−ηsign(∂L∂W);
2、ΔL≈−η|∂L∂W|1;
3、这里的|⋅|1指每个元素取绝对值然后求和。
关于SignSGD近似本身,读者还可以参考《当Batch Size增大时,学习率该如何随之变化?》、《Adam的epsilon如何影响学习率的Scaling Law?》等文章,这里也不展开讨论了。总而言之,SignSGD是分析Adam相关缩放规律时一个常用的近似方式。
现在可以模仿SGD的过程进行分析:
1、∂L∂Wout的RMS是O(1),|∂L∂Wout|1是d×dout个数求和,大小是O(d×dout)=O(d),所以它的学习率要满足ηout∝1/d来抵消尺度影响;
2、|∂L∂Wk|1是d2个数求和,∂Yout∂Wk和∂L∂Z的RMS都是O(1),我们将Wout的初始方差设为∝1/d2,那么∂L∂Wk的RMS就是O(1/d),d2个数求和后是O(d),所以学习率按照ηk∝1/d变换来抵消尺度影响;
3、此时∂L∂Win的RMS也是O(1/d),但|∂L∂Win|1只是din×d个数求和,所以它已经是O(1),从而学习率不用随尺度改变。
(注:读者可以自行检查一下式(14)是满足的。)
Muon版本 #
接下来自然少不了Muon的分析。对于Muon本身,我们已经在《Muon优化器赏析:从向量到矩阵的本质跨越》、《Muon续集:为什么我们选择尝试Muon?》做了详细介绍,这里不再重复。跟Adam用SignSGD类似,我们用MSignSGD来近似Muon:
1、ΔW=−ηmsign(∂L∂W);
2、ΔL≈−η‖∂L∂W‖∗(证明见《Muon优化器赏析:从向量到矩阵的本质跨越》);
3、这里的‖⋅‖∗指Nuclear范数,是矩阵的所有奇异值之和;
4、Nuclear范数并不好算,但F范数好算,它等于矩阵的所有奇异值的平方和的平方根;
5、我们用F范数作为Nuclear范数近似,因此ΔL≈−η‖∂L∂W‖∗≈−η‖∂L∂W‖F;
6、F范数又等于矩阵的所有元素的平方和的平方根。
那么可以开始分析过程:
1、∂L∂Wout的RMS是O(1),所以‖∂L∂Wout‖∗大小是O(√d×dout)=O(√d),要消除尺度的影响,那么它的学习率要满足ηout∝1/√d;
2、‖∂L∂Wk‖F是d2个数的平方和的平方根,∂Yout∂Wk和∂L∂Z的RMS都是O(1),我们将Wout的初始方差设为∝1/d2,那么∂L∂Wk的RMS就是O(1/d),平方和后再平方根,结果是O(1),所以学习率不用变;
3、此时∂L∂Win的RMS也是O(1/d),但‖∂L∂Win‖F只是din×d个数的平方和平方根,所以它是O(1/√d)的,学习率反而需要放大√d倍来抵消这个影响,即ηin∝√d。
(注:这里Muon的结论是对的,但它不满足条件(14),因为式(14)要细说的话还依赖于一个更新量是Element-wise的假设,而Muon不符合这个假设,所以实际上不可用。这里没有仔细展开相关讨论,而是直接沿用了“Wout的初始化方差设为∝1/d2”的结论,回避了式(14)。)
结论汇总 #
将上述结论汇总在一起是:
Win方差Win学习率Wk方差Wk学习率Wout方差Wout学习率SGD1/dind1/d11/d21/dAdam1/din11/d1/d1/d21/dMuon1/din√d1/d11/d21/√d
这里的Wk指的是除Win,Wout外的所有参数,还有要强调的是,这里的关系都是“正比于”而不是“等于”。另外实践中可以根据具体需求稍作变化,比如实际我们用Muon时,Win和Wout的优化通常不用Muon而是用Adam,这将导致两个变化:
1、ηout∝1/d;
2、ηin不变。
如果结合我们在《Muon is Scalable for LLM Training》所提的Adujst LR的话,那么学习率要多乘一个max,n\times m是参数矩阵的形状,我们已经假设了\text{NN}部分的参数总等比例缩放,所以\sqrt{\max(n, m)}\propto \sqrt{d}。因此,如果要抵消Adujst LR带来的尺度影响,那么就需要
3、\eta_k\propto 1/\sqrt{d} 。
文章小结 #
本文以尽可能简明清晰的方式介绍了muP(Maximal Update Parametrization),这是旨在研究超参数跨模型尺度的迁移规律的工作。基于muP,我们可以在小模型上以相对较小的成本仔细搜索超参数(这里主要是学习率和初始化),然后迁移到大模型上,降低大模型的炼丹成本。
客观来讲,这里的介绍和分析还比较初步,比如没有考虑Bias项、没有评估结论在MLP以外架构的通用性、也没有仔细考虑Normalization和残差的作用等。没有考虑Bias项这个单纯是偷懒,权当留给读者的习题了;至于不同架构下的muP,一般分析起来比较麻烦,但由于神经网络的相似性,结论大致上是相同的,我们可以不加证明地用着。个人认为比较关键的改进点是Normalization和残差的影响,尤其是Normalization,它使得不依赖特殊的初始化就可以稳定前向传播,带来了更大的自由度和可能性。
当然,这些都留给后续分析了。
转载到请包括本文地址:https://spaces.ac.cn/archives/10770
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 13, 2025). 《初探muP:超参数的跨模型尺度迁移规律 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10770
@online{kexuefm-10770,
title={初探muP:超参数的跨模型尺度迁移规律},
author={苏剑林},
year={2025},
month={Mar},
url={\url{https://spaces.ac.cn/archives/10770}},
}
March 14th, 2025
建议看Tensor Programs IVb (arxiv:2308.01814),里面有提到关于优化器的问题。
需要特别注意:优化器默认为element-wise,但Muon很明显不满足这个条件,不能按照此文章的理论来分析。
(这篇文章的核心其实在脚注里,如果认为Adam eps为0,可以极大简化这一理论)
Muon需要按照Spectral Condition(arxiv:2310.17813)来分析,原版Muon在此分析下学习率在扩宽模型时保持不变。
感谢推荐。确实我没有完整读过Tensor Programs系列文章,因为对我的学习背景来说读起来还是太困难了,所以想了另外一个稍微不同的理解路线(即从\Delta\mathcal{L}入手)。
在个人的推导方式下,能够成功复现SGD和Adam的muP结论,并没有遇到“优化器需要是Element-wise”的障碍,所以也一并推到了Muon上。如果您说原本的muP有这个假设(我不确定),那可能是因为它是从\Delta\boldsymbol{y}入手分析造成的困难?毕竟\mathcal{L}只是一个标量,\boldsymbol{y}至少是一个向量,分析后者可能需要更多的简化假设。
https://arxiv.org/pdf/2308.01814, Section 2.1,第12页。确实是这个假设。
我觉得greg yang在TP4b里的假设和苏神这里的推导没有冲突。TP4b中element-wise update的假设是至关重要的,因为如果超出这个范畴,网络的前传和反传过程就很可能不能被TP理论限定的三个算子表示,这会导致后续的所有的理论结果都不再可靠。
但苏神的推导并不依赖TP的语言(即基于TP的算子来表征分析优化过程,并且推导无穷宽极限),所以自然也不需要这个假设。苏神相当于用更简单的一套语言分析了mup的条件,至少目前和element-wise没有关系。
当然,这种分析方式是否对任意架构和优化器都成立,还是不清楚的。
感谢两位 @plc|comment-27121 @Chenyu Zheng|comment-27142
事实上,我几乎没读过Tensor Programs系列其他篇文章,所以不大了解Tensor Programs一贯下来的假设和推导。不过后面我还是打算从Grey Yang最新的 https://arxiv.org/abs/2310.17813 的谱条件缩放入手了,毕竟这样仔细算梯度还是太麻烦了(已经逐渐崩溃~
March 20th, 2025
苏老师,我还是没看明白SGD的\frac{\partial Y_{out}}{\partial Y_{in}}和\frac{\partial Y_{out}}{\partial W_{k}}为什么它们的RMS是O(1)??如何估算出来的?能否大概写一下证明的思路?
前向和反向传播的稳定性。非要逐步证明的话,那就用链式法则写出每一步的梯度,然后证明每一步往回传都是稳定的(尺度无关的)。
March 23rd, 2025
苏神,我这几天仔细拜读了一下,收获很大。但我也发现一个地方有问题。主要是类似“我们直接将W_{out}的初始化方差设为1/d^2”的语句,其实是恰好设对的,而不是从本文的三个条件推出来的。
我们以最简单的SGD中的W_k为例子。为了使\Delta L = O(1),我们在保证前传反传稳定的情况下,完全可以设置其它的W_{out}方差和学习率。比如,W_{out}的初始化方差设为1/d(前传和反传稳定),此时\frac{\partial L}{\partial W_k}的RMS为1/\sqrt{d},平方求和后是d,所以对应地我们把学习率设为1/d。
在这个时候,我们来简单观察一下第k层feature的RMS的变化。
\Delta W_k Y_{k-1} = -\eta \frac{\partial L}{\partial W_k} Y_k = O(d \times d^{-1} d^{-0.5}) = O(1/\sqrt{d})。
当宽度很大的时候,我们发现中间层的feature不再变化,这正是mup(所有层feature最大更新)所不期望的。作为对比,我们可以看看本文W_{out}的初始化方差设为1/d^2为什么合理:
\Delta W_k Y_{k-1} = -\eta \frac{\partial L}{\partial W_k} Y_k = O(d \times 1 d^{-1}) = O(1)。
我们可以看到,此时所有隐藏层可以进行最大更新,这正是mup的要求。
所以,本文的\Delta L = O(\Delta Z)条件只是确保了output能够最大更新且不爆炸,但是这个条件不能够保证中间层的feature也得到最大更新。只有将中间层的feature更新\Delta W_k Y_{k-1}加入作为第4个条件,才能正确地推导出W_{out}的初始化方差设为1/d^2是唯一正确的。
不过还是感谢苏神,之前看TP4原文虽然都把证明过了一遍,但一直没能尝试去建立起非常直观的、直觉的、简单的推导过程。最近几天对照苏神好好探究了一下,感觉自己终于发现了梯度传播视角下,mup的简单原子条件了。
感谢指点!事实上我在写博客的时候已经意识到你说的问题了,你补全了本文的不足。我参考你的意见,把它补充到正文里边了。再次感谢!
不过顺便说,我已经“叛逃”到 https://arxiv.org/abs/2310.17813 这套方法了hhh,因为即便是本文这套简化版思路,对于一些复杂case也太难算了。
March 24th, 2025
确实太复杂了,我之后也学一下谱的视角
March 27th, 2025
每次看到苏神的文章都要仔细揣摩一遍,然后一脸蒙蔽地离开。