MoE环游记:3、换个思路来分配
By 苏剑林 | 2025-03-05 | 35242位读者 |这篇文章我们继续探讨MoE的负载均衡问题。在上一篇文章《MoE环游记:2、不患寡而患不均》中,我们主要讨论了通过Aux Loss来促进负载均衡的思路。Aux Loss固然简单直观,但它也有一个明显的缺点——权重不好调——调低了无法促进均衡,调高了容易损害LM Loss,所以业界一直有寻找替代方案的尝试。
本文要分享的是名为“Loss-Free”的方案,由DeepSeek在《Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts》提出。和DeepSeek众多耀眼的开源作品相比,这篇论文也许不算起眼,但在笔者看来,它潜在的学术影响力可能远超其他工作,因为所提方法不仅简单有效,而且极具普适性,堪称经典。
方法大意 #
面对负载不均衡,Aux Loss的应对思路是通过额外的损失引导Router给出均衡的打分,而Loss-Free的想法则是换个新的分配思路,即不改变Router现有打分结果,而是改变argtopkρ这个分配方式。
其实这个方向此前也有过一些努力。比如2021年Facebook提出了BASE Layer,将Expert的分配视为线性指派问题,即以负载均衡为约束条件,求在该约束之下Router总打分尽可能高的分配结果,这可以用匈牙利算法等来解决。但该方案需要知道全体Token的打分,所以对于自回归式LLM来说,它只适用于训练,推理还是只能用argtopkρ,训练推理存在不一致性,并且由于目前求解算法的限制,它只适用于k=1的场景。
相比之下,Loss-Free的做法非常简单且有效,它留意到一个事实,即我们总可以引入一个偏置项b,使得argtopkρ+b的分配是均衡的,所以它将MoE的形式改为
y=∑i∈argtopkρρiei→y=∑i∈argtopkρ+bρiei
这里的b是输入无关的向量,由训练过程确定下来,训练完后它就保持不变,因此推理阶段也可以用,换言之训练和推理具有一致的形式。注意乘以ei的还是ρi而不是ρi+bi,也就是说b仅仅参与分配过程而不参与MoE的前向计算,所以我们对b或ρ+b的正负性都没有特殊要求。
手搓梯度 #
怎么训练b呢?我们知道,b的优化方向自然是促进负载均衡,为此按照上一篇的记号,我们先定义f=[f1,f2,⋯,fn]:
fi={1/k,i∈argtopkρ+b0,i∉argtopkρ+b
以及F=E[f],这里的F自然就是在b偏置下Expert当前的负载分布了。借着我们定义均匀分布为Q=(1/n,1/n,⋯,1/n),那么负载均衡就相当于最小化
Laux=12‖F−Q‖2=12n∑i=1(Fi−1/n)2
这个目标是不可导的,但有了上一篇的经验,我们知道STE(Straight-Through Estimator)可以解决这个问题。STE的关键是找一个可导且跟F具有同增减趋势的量作为F的光滑近似,这里我们的优化参数只有b,而它正好具有我们期望的性质(增大bi,i被选中的概率就更高,那么Fi就更大),所以答案就呼之欲出了:
Laux=12‖b+sg[F−b]−Q‖2=12n∑i=1(bi+sg[Fi−bi]−1/n)2
它的梯度是
∇bLaux=12∇b‖b+sg[F−b]−Q‖2=F−Q
所以用梯度下降(SGD)来更新b就是
b←b−α(F−Q)
这里α是b的学习率。不过Loss-Free最终选择的更新规则略有不同,它选择的是符号梯度下降(SignSGD):
b←b−αsign(F−Q)
这个结果其实也很好理解,就是如果Fi比1/n大,那么就调小一点bi,否则就增大一点bi。
改良版本 #
除了加sign的符号梯度下降外,笔者发现直接对F−Q做RMS Norm(即Normalized SGD),在相同的α下往往能达到更好的均衡效果:
b←b−αF−QRMS(F−Q)
这里的RMS是“Root Mean Square”,定义为
RMS(F−Q)=√1nn∑i=1(Fi−Qi)2
不难看出,加sign后的sign(F−Q)和加RMS Norm后的F−QRMS(F−Q),它们的RMS都是1,因此它们俩尺度上是大致相同的,所以我们可以使用相同的α。
简单来说,sign的问题在于不论Fi与目标Qi的远近都使用同样的更新幅度,这导致原本就已经跟Qi比较接近的Fi反而容易偏离原本已经达到的均衡,从而产生震荡;而RMS Norm则保留了Fi−Qi之间的相对大小,更新幅度更加自适应一些,理论上更有助于促进均衡,实测效果也多是它更好。
一脉相承 #
原论文在介绍Loss-Free时,并没有上述Aux Loss的推导过程,而是直接给出式(7)的更新规则,给人的感觉是给b“手搓”了梯度sign(F−Q),这也是它Loss-Free这个名字的来源。
然而,从本文给出的推导可以看出,更新规则(7)也完全可以从Aux Loss视角得到,两者是一脉相承的。看起来Loss-Free最直接的好处是不用调Aux Loss权重了,但它实际上也有个学习率参数α要调,尽管原论文已经帮我们搜好α=0.001这个默认值,但不可否认这个超参数是存在的。
在笔者看来,Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。其中最关键一步,是留意到“一个偏置项足以达到负载均衡”这一事实,然后就让Aux Loss只优化新引入的偏置b,而LM Loss则优化剩余参数,让Aux Loss对LM Loss的负面作用降到最低。
相比之下,常规的Aux Loss方案需要全体参数来促进负载均衡,而LM Loss优化的也是全体参数,两者的优化方向可能并不完全兼容,因此想找到一个最优的平衡点相对来说就更为困难。所以,Loss-Free基于“一个偏置项足以达到负载均衡”将两个Loss的优化参数隔离开来,是负载均衡问题的一个绝妙的解决办法。
相关细节 #
尽管Loss-Free已经足够简单明了,但是在使用的时候还要稍微注意一些细节。
首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式(7)来更新b。这是因为b的更新依赖于全体Token的统计信息F,先更新b再更新模型其余参数的话,原则上会有泄漏未来信息的风险。虽然直观看来就一个向量b泄漏不了多少信息,但这个风险终归是存在的,因此要尽量去规避它。
其次,刚才我们说原论文已经调好α=0.001,但这个结果可能跟原论文用Sigmoid作为Router ρ激活函数的选择是绑定的。原因也不难想,经过Sigmoid后,每个ρi相对比较独立,并且都在(0,1)内,α=0.001相当于说每一步的更新幅度约为千分之一,如果换Softmax、ReLU或者其他激活函数,那么就可能需要重调α了。
针对这个问题,笔者建议的做法是解耦Gate和Bias所用的激活函数,即
y=∑i∈argtopkρ+bρiei→y=∑i∈argtopkρ(σ)+bρ(h)iei
其中ρ(σ)=σ(xW(R)),ρ(h)=h(xW(R)),σ(⋅)是Sigmoid函数,h(⋅)是任意单调且值域非负的函数,说白了就是加上b的是Sigmoid激活的打分,这样我们就可以复用α=0.001,至于乘上Expert的Gate,我们可以用其他激活函数,只要它的单调性跟Sigmoid一致就行。
此外,由于更新规则(7)加了sign函数,因此有可能训出绝对值大于1的bi,整体绝对值还可能越来越大,这些都是正常的,对模型效果不会有影响。实际上b有一个冗余的自由度,因为全体bi都加上同一个常数后,argtopkρ+b的结果不变。这个额外的自由度我们可以用来做其他好玩的事情(且听下回分解)。
延伸思考 #
除了MoE的负载均衡之外,Loss-Free的思想还可以应用到很多类似问题,比如VQ-VQE的编码表坍缩(Codebook Collapse),就可以用同样思路解决,而且相比之前介绍的“旋转技巧”、“线性变换技巧”显得更自然和普适。事实上,本文开篇的评价“Loss-Free潜在的学术影响力可能远超其他工作”,正是基于Loss-Free的普适性考虑的。
抛开具体的应用背景,从数学上来看,Loss-Free的贡献可以理解为给出了用梯度下降来求解指派问题的方法。一个经典的线性指派问题可以表示为:
min
其中c_{i,j}是给定的成本函数,f是\{1,2,\cdots,n\}到自身的双射。放到本文的背景下,c_{i,j}不就相当于n个Token、n个Expert的打分,所求f不就是一个负载均衡的分配方案?求解此类问题的一般想法是在满足约束条件的空间里搜索尽可能优的解,而Loss-Free则反过来,先构建一个最优但不一定满足约束条件的解:
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j}\end{equation}
这个解在分数上肯定是最优的,但不一定满足双射的条件,这里不满足双射就等价于负载不均衡。于是我们引入偏置
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j} + b_j\end{equation}
b_j初始化为零,然后根据式\eqref{eq:aux-loss-free}来更新,更新规则说白了就是哪个j出现出现次数多,那减少相应的b_j,反之增加,直到出现双射为止。
文章小结 #
本文介绍了MoE负载均衡问题的Loss-Free方法,它由DeepSeek提出,其核心在于通过引入一个简单的偏置项来实现负载均衡。本文进一步思考了它与Aux Loss的联系,以及它在类似数学问题上的应用潜力。
转载到请包括本文地址:https://spaces.ac.cn/archives/10757
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 05, 2025). 《MoE环游记:3、换个思路来分配 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10757
@online{kexuefm-10757,
title={MoE环游记:3、换个思路来分配},
author={苏剑林},
year={2025},
month={Mar},
url={\url{https://spaces.ac.cn/archives/10757}},
}
March 23rd, 2025
苏老师,我想不明白为什么要用 sign 函数来完成 b 的更新?文中的公式(6)不好吗?
原论文提到公式(6) "slightly improves load balance but does not show improvement in model performance."
\text{sign}容易控制更新幅度呀,不加\text{sign}你不大好感知每步改变了多少。
问下苏老师能这么理解吗:公式(6)(7)(8)的更新梯度对应不同的损失函数,(6)式对应的是L2范数平方也就是公式(3)的\frac{1}{2} \Vert\boldsymbol{F} - \boldsymbol{Q}\Vert_2^2,(7)式对应的L1范数\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert_1,然后(8)式对应的是\sqrt{n} \Vert\boldsymbol{F} - \boldsymbol{Q}\Vert_2。 (p.s.有个typo 想纠正下, 公式(10)前面的"结构”应该是"解构"?)
谢谢,typo是“解耦”,感谢指出。关于(6),(7),(8)的等价理解,确实如此。
March 25th, 2025
看完三篇,个人感觉MOE的功能跟multi head attention很像。所以如果attention层做的足够好, 那还需要Router去处理attention的output吗?或者how does MoE routing interact with attention? 搜索到很多MoE-Attention Hybrid的文章,貌似现在的大模型都没有采纳这样的做法,还是选择替换掉FFN。
你是说将MoE的思想用在Attention heads还是啥来着?现在MoE是用在FFN上,这是一个主流方向;用在Attention heads上也有研究,目前没成为主流,但MoE兴起也没多久,我觉得静待百花齐放呗。
March 25th, 2025
苏神,top_k操作造成的router梯度计算不准的问题,是不是某种程度上用负载均衡loss缓解了?所以我们优化负载均衡,也是在优化router分类本身的准确度?
只能说负载均衡加大batch size可以很好缓解这个问题,但实际上最优分布是不是均匀分布还不好说。
March 31st, 2025
苏神您好,请问我在官方的推理代码里看到了bias这个参数,但是由于官方没有训练代码,我现在对这个Aux Loss free 的 bias如何更新产生了疑问。它通过直接训练就可以达到更新效果嘛。如何通过梯度下降让b更新呢
这个更新算法在loss free论文里面有写的,bias的更新是在每个batch在bp更新后单独进行的,是一个独立于loss bp的步骤,所以叫loss free.
你仔细看一下本文,它就不是用梯度下降更新的,它是额外“手搓”的梯度,手写的更新规则来更新。
April 5th, 2025
在公式(5)中,\begin{equation}\nabla_{\boldsymbol{b}}\mathcal{L}_{\text{aux}} = \frac{1}{2}\nabla_{\boldsymbol{b}}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \boldsymbol{F} - \boldsymbol{Q}\end{equation}
这个公式是怎么化简得到的呢,有点没看懂?
问题来源:\nabla_b 是对 b 求导,sg 是stop gradient,从结果看似乎把 b 拆进去,变成了 \nabla_b \Vert b + F - b + Q \Vert^2 = \nabla_b \Vert F + (b-b) - Q\Vert^2 ,并且认为 F 就是 b,这样的公式 gradient 才是 F - Q,那为什么还要有 sg[F-b],而不是直接写成 \Vert b + F - b + Q \Vert^2
\begin{aligned} \nabla_b \mathcal{L}_{aux} &= \frac{1}{2} \nabla_b \sum_{i=1}^n (b + sg[F_i - b] - Q_i)^2 \\ &= \sum_{i=1}^n (b + sg[F_i - b] - Q_i) \nabla_b (b + sg[F_i - b] - Q_i) \\ &= \sum_{i=1}^n (F_i - Q_i) \nabla_b (b - 1/n) \\ &= \sum_{i=1}^n (F_i - Q_i) *1\\ &= F - Q. \end{aligned}
理解了,感谢大佬~
@lyc|comment-27340已经给出了推导,感谢。这里补充一句:关键就是\text{sg}[]的梯度为零,其他方面的计算规则是完全不变(比如前向计算、链式法则)。
April 8th, 2025
如果让aux loss对router的输入stop gradient,也能做到不影响整体的PPL吧
我们还实验过这样的思路,发现balance没有问题,但是loss并不如正版的。直觉上就是把balance的压力都加给router weights,反而让router压力过大,没法对效果有很好的贡献(别忘了router同时作为expert的gate)
April 17th, 2025
苏神,这里的b训练过程通过手搓梯度更新,是一个同专家数同维度的1*N向量。那么推理时,每层FFN专家对应一个定值的b是吗?
是的
April 19th, 2025
这样做,router选择的就不是模长最大的方向了吗?
确实不是。可以理解为全体token一起统筹安排,在满足均匀的前提下才去尽量选择模长最大的。