MoE环游记:3、换个思路来分配
By 苏剑林 | 2025-03-05 | 14633位读者 |这篇文章我们继续探讨MoE的负载均衡问题。在上一篇文章《MoE环游记:2、不患寡而患不均》中,我们主要讨论了通过Aux Loss来促进负载均衡的思路。Aux Loss固然简单直观,但它也有一个明显的缺点——权重不好调——调低了无法促进均衡,调高了容易损害LM Loss,所以业界一直有寻找替代方案的尝试。
本文要分享的是名为“Loss-Free”的方案,由DeepSeek在《Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts》提出。和DeepSeek众多耀眼的开源作品相比,这篇论文也许不算起眼,但在笔者看来,它潜在的学术影响力可能远超其他工作,因为所提方法不仅简单有效,而且极具普适性,堪称经典。
方法大意 #
面对负载不均衡,Aux Loss的应对思路是通过额外的损失引导Router给出均衡的打分,而Loss-Free的想法则是换个新的分配思路,即不改变Router现有打分结果,而是改变argtopkρ这个分配方式。
其实这个方向此前也有过一些努力。比如2021年Facebook提出了BASE Layer,将Expert的分配视为线性指派问题,即以负载均衡为约束条件,求在该约束之下Router总打分尽可能高的分配结果,这可以用匈牙利算法等来解决。但该方案需要知道全体Token的打分,所以对于自回归式LLM来说,它只适用于训练,推理还是只能用argtopkρ,训练推理存在不一致性,并且由于目前求解算法的限制,它只适用于k=1的场景。
相比之下,Loss-Free的做法非常简单且有效,它留意到一个事实,即我们总可以引入一个偏置项b,使得argtopkρ+b的分配是均衡的,所以它将MoE的形式改为
y=∑i∈argtopkρρiei→y=∑i∈argtopkρ+bρiei
这里的b是输入无关的向量,由训练过程确定下来,训练完后它就保持不变,因此推理阶段也可以用,换言之训练和推理具有一致的形式。注意乘以ei的还是ρi而不是ρi+bi,也就是说b仅仅参与分配过程而不参与MoE的前向计算,所以我们对b或ρ+b的正负性都没有特殊要求。
手搓梯度 #
怎么训练b呢?我们知道,b的优化方向自然是促进负载均衡,为此按照上一篇的记号,我们先定义f=[f1,f2,⋯,fn]:
fi={1/k,i∈argtopkρ+b0,i∉argtopkρ+b
以及F=E[f],这里的F自然就是在b偏置下Expert当前的负载分布了。借着我们定义均匀分布为Q=(1/n,1/n,⋯,1/n),那么负载均衡就相当于最小化
Laux=12‖
这个目标是不可导的,但有了上一篇的经验,我们知道STE(Straight-Through Estimator)可以解决这个问题。STE的关键是找一个可导且跟\boldsymbol{F}具有同增减趋势的量作为\boldsymbol{F}的光滑近似,这里我们的优化参数只有\boldsymbol{b},而它正好具有我们期望的性质(增大b_i,i被选中的概率就更高,那么F_i就更大),所以答案就呼之欲出了:
\begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (b_i + \text{sg}[F_i - b_i] - 1/n)^2\end{equation}
它的梯度是
\begin{equation}\nabla_{\boldsymbol{b}}\mathcal{L}_{\text{aux}} = \frac{1}{2}\nabla_{\boldsymbol{b}}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \boldsymbol{F} - \boldsymbol{Q}\end{equation}
所以用梯度下降(SGD)来更新\boldsymbol{b}就是
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha (\boldsymbol{F} - \boldsymbol{Q})\end{equation}
这里\alpha是\boldsymbol{b}的学习率。不过Loss-Free最终选择的更新规则略有不同,它选择的是符号梯度下降(SignSGD):
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha \mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q})\label{eq:aux-loss-free}\end{equation}
这个结果其实也很好理解,就是如果F_i比1/n大,那么就调小一点b_i,否则就增大一点b_i。
一脉相承 #
原论文在介绍Loss-Free时,并没有上述Aux Loss的推导过程,而是直接给出式\eqref{eq:aux-loss-free}的更新规则,给人的感觉是给\boldsymbol{b}“手搓”了梯度\mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q}),这也是它Loss-Free这个名字的来源。
然而,从本文给出的推导可以看出,更新规则\eqref{eq:aux-loss-free}也完全可以从Aux Loss视角得到,两者是一脉相承的。看起来Loss-Free最直接的好处是不用调Aux Loss权重了,但它实际上也有个学习率参数\alpha要调,尽管原论文已经帮我们搜好\alpha=0.001这个默认值,但不可否认这个超参数是存在的。
在笔者看来,Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。其中最关键一步,是留意到“一个偏置项足以达到负载均衡”这一事实,然后就让Aux Loss只优化新引入的偏置\boldsymbol{b},而LM Loss则优化剩余参数,让Aux Loss对LM Loss的负面作用降到最低。
相比之下,常规的Aux Loss方案需要全体参数来促进负载均衡,而LM Loss优化的也是全体参数,两者的优化方向可能并不完全兼容,因此想找到一个最优的平衡点相对来说就更为困难。所以,Loss-Free基于“一个偏置项足以达到负载均衡”将两个Loss的优化参数隔离开来,是负载均衡问题的一个绝妙的解决办法。
相关细节 #
尽管Loss-Free已经足够简单明了,但是在使用的时候还要稍微注意一些细节。
首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式\eqref{eq:aux-loss-free}来更新\boldsymbol{b}。这是因为\boldsymbol{b}的更新依赖于全体Token的统计信息\boldsymbol{F},先更新\boldsymbol{b}再更新模型其余参数的话,原则上会有泄漏未来信息的风险。虽然直观看来就一个向量\boldsymbol{b}泄漏不了多少信息,但这个风险终归是存在的,因此要尽量去规避它。
其次,刚才我们说原论文已经调好\alpha=0.001,但这个结果可能跟原论文用Sigmoid作为Router \boldsymbol{\rho}激活函数的选择是绑定的。原因也不难想,经过Sigmoid后,每个\rho_i相对比较独立,并且都在(0,1)内,\alpha=0.001相当于说每一步的更新幅度约为千分之一,如果换Softmax、ReLU或者其他激活函数,那么就可能需要重调\alpha了。
针对这个问题,笔者建议的做法是结构Gate和Bias所用的激活函数,即
\begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}^{(\sigma)} + \boldsymbol{b}} \rho_i^{(h)} \boldsymbol{e}_i\end{equation}
其中\boldsymbol{\rho}^{(\sigma)} = \sigma(\boldsymbol{x}\boldsymbol{W}^{(R)}), \boldsymbol{\rho}^{(h)} = h(\boldsymbol{x}\boldsymbol{W}^{(R)}),\sigma(\cdot)是Sigmoid函数,h(\cdot)是任意单调且值域非负的函数,说白了就是加上\boldsymbol{b}的是Sigmoid激活的打分,这样我们就可以复用\alpha=0.001,至于乘上Expert的Gate,我们可以用其他激活函数,只要它的单调性跟Sigmoid一致就行。
此外,由于更新规则\eqref{eq:aux-loss-free}加了\text{sign}函数,因此有可能训出绝对值大于1的b_i,整体绝对值还可能越来越大,这些都是正常的,对模型效果不会有影响。实际上\boldsymbol{b}有一个冗余的自由度,因为全体b_i都加上同一个常数后,\mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}的结果不变。这个额外的自由度我们可以用来做其他好玩的事情(且听下回分解)。
延伸思考 #
除了MoE的负载均衡之外,Loss-Free的思想还可以应用到很多类似问题,比如VQ-VQE的编码表坍缩(Codebook Collapse),就可以用同样思路解决,而且相比之前介绍的“旋转技巧”、“线性变换技巧”显得更自然和普适。事实上,本文开篇的评价“Loss-Free潜在的学术影响力可能远超其他工作”,正是基于Loss-Free的普适性考虑的。
抛开具体的应用背景,从数学上来看,Loss-Free的贡献可以理解为给出了用梯度下降来求解指派问题的方法。一个经典的线性指派问题可以表示为:
\begin{equation}\min_f \sum_{i=1}^n c_{i, f(i)}\end{equation}
其中c_{i,j}是给定的成本函数,f是\{1,2,\cdots,n\}到自身的双射。放到本文的背景下,c_{i,j}不就相当于n个Token、n个Expert的打分,所求f不就是一个负载均衡的分配方案?求解此类问题的一般想法是在满足约束条件的空间里搜索尽可能优的解,而Loss-Free则反过来,先构建一个最优但不一定满足约束条件的解:
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j}\end{equation}
这个解在分数上肯定是最优的,但不一定满足双射的条件,这里不满足双射就等价于负载不均衡。于是我们引入偏置
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j} + b_j\end{equation}
b_j初始化为零,然后根据式\eqref{eq:aux-loss-free}来更新,更新规则说白了就是哪个j出现出现次数多,那减少相应的b_j,反之增加,直到出现双射为止。
文章小结 #
本文介绍了MoE负载均衡问题的Loss-Free方法,它由DeepSeek提出,其核心在于通过引入一个简单的偏置项来实现负载均衡。本文进一步思考了它与Aux Loss的联系,以及它在类似数学问题上的应用潜力。
转载到请包括本文地址:https://spaces.ac.cn/archives/10757
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 05, 2025). 《MoE环游记:3、换个思路来分配 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10757
@online{kexuefm-10757,
title={MoE环游记:3、换个思路来分配},
author={苏剑林},
year={2025},
month={Mar},
url={\url{https://spaces.ac.cn/archives/10757}},
}
March 23rd, 2025
苏老师,我想不明白为什么要用 sign 函数来完成 b 的更新?文中的公式(6)不好吗?
原论文提到公式(6) "slightly improves load balance but does not show improvement in model performance."
\text{sign}容易控制更新幅度呀,不加\text{sign}你不大好感知每步改变了多少。
March 25th, 2025
看完三篇,个人感觉MOE的功能跟multi head attention很像。所以如果attention层做的足够好, 那还需要Router去处理attention的output吗?或者how does MoE routing interact with attention? 搜索到很多MoE-Attention Hybrid的文章,貌似现在的大模型都没有采纳这样的做法,还是选择替换掉FFN。
March 25th, 2025
苏神,top_k操作造成的router梯度计算不准的问题,是不是某种程度上用负载均衡loss缓解了?所以我们优化负载均衡,也是在优化router分类本身的准确度?