通过前几篇文章的推导和计算,我们可以发现,第一篇《MuP之上:1. 好模型的三个特征》所提的三个稳定性指标通常可以分为“参数稳定性”和“增量稳定性”两部分,而在《MuP之上:2. 线性层与最速下降》《MuP之上:3. 特殊情况特殊处理》中,我们演示了将增量稳定性与最速下降结合来获得新的更新规则(优化器)的过程.

然而,对于参数稳定性,我们之前只是停留在初始化上。这篇文章的任务,正是探讨如何在整个训练过程中维持参数的稳定性,将理论的实践补充完整。

问题背景 #

《MuP之上:2. 线性层与最速下降》为例,三个稳定性指标分别是:
\begin{align}
&\text{前向稳定性:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt]
&\text{依赖稳定性:}\quad\max_{\Vert\boldsymbol{x}_1\Vert_{RMS}=\Vert\boldsymbol{x}_2\Vert_{RMS}=1} \Vert \boldsymbol{x}_1\boldsymbol{W} - \boldsymbol{x}_2\boldsymbol{W}\Vert_{RMS} = 2\sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt]
&\text{更新稳定性:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\Delta\boldsymbol{W}\Vert_2
\end{align}
其中$\boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}$是线性层的参数。我们希望这三个指标都是$\Theta(1)$,那么就是希望参数及其增量分别满足$\Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}})$和$\Vert\Delta\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}})$。在《MuP之上:3. 特殊情况特殊处理》中我们对Embedding、LM Head等层做了计算,结论也是类似的,只不过对应的范数有所不同。

增量条件我们作为稳定性指标,基于“稳中求快”的最速下降原则,来推导理论最优的更新规则,比如线性层对应的是Muon优化器:
\begin{equation}\newcommand{argmin}{\mathop{\text{argmin}}}\newcommand{tr}{\mathop{\text{tr}}}\newcommand{msign}{\mathop{\text{msign}}}\argmin_{\Vert\Delta\boldsymbol{W}\Vert_2\leq\eta\sqrt{\frac{d_{out}}{d_{in}}}} \tr(\boldsymbol{G}^{\top}\Delta\boldsymbol{W}) \qquad \Rightarrow \qquad \Delta\boldsymbol{W} = -\eta\sqrt{\frac{d_{out}}{d_{in}}}\msign(\boldsymbol{G})\end{equation}
而对于参数稳定性部分,我们之前只是要求参数的初始化满足$\Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}})$,如何保证模型在整个训练过程中都维持同样的参数稳定性,尚不得而知。

初步思考 #

如何保证$\boldsymbol{W}$能保持$\Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}})$呢?更一般地,给定一个参数$\boldsymbol{\omega}$,它可能是向量、矩阵甚至高阶张量,然后再给定一个范数$\Vert\cdot\Vert$,它通常是由前向稳定性或依赖稳定性诱导出来的指标,最后指定一个标度$\tau$,问:如何让$\boldsymbol{\omega}$能在训练过程中维持$\Vert\boldsymbol{\omega}\Vert=\Theta(\tau)$呢?

一个朴素的想法是直接让$\Vert\boldsymbol{\omega}\Vert=\tau$($\tau$也可以换成它的常数倍,但这不影响接下来的讨论),最简单的实现是每一步优化后通过归一化将范数重新缩放成$\tau$(例如HyperballNemotron-Flash)。还有一种思路是直接用归一化去重参数原本的模型,即$\boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})$改为$\boldsymbol{f}(\boldsymbol{x};\tau\boldsymbol{\omega}/\Vert\boldsymbol{\omega}\Vert)$,这样理论上也能起到类似效果。

更进一步的做法是结合最速下降思想来调整更新规则,正如文章《流形上的最速下降:1. SGD + 超球面》《流形上的最速下降:4. Muon + 谱球面》和论文《Controlled LLM Training on Spectral Sphere》所讨论的。这样做从方法上看更为优雅,但实践上更加复杂,通常需要求解一个非线性方程才能获得准确的更新量。

然而,我们真的应该将一个参数的某种范数严格控制成某个值吗?直观来想,参数的范数应该由训练过程自行决定,我们顶多为其设置一个先验范围。尽管有一些工作表明,设置得当的情况下,固定参数范数为预设值并不会影响效果,但这仍然会破坏原本的训练动力学,可能需要花上更多的努力去理解和适应它。

因此,本文要提出的观点是,我们只需保证$\Vert\boldsymbol{\omega}\Vert = \mathcal{O}(\tau)$,具体来说,就是设法保证每一步都满足$\Vert\boldsymbol{\omega}\Vert \leq \tau$,至于具体是什么值,是否能保证达到$\Theta(\tau)$,则交由训练算法自己决定,不做进一步干预。

事后裁剪 #

接下来的问题自然是:如何实现$\Vert\boldsymbol{\omega}\Vert \leq \tau$呢?更具体地说,假设$\boldsymbol{\omega}$原本的更新规则是
\begin{equation}\boldsymbol{\omega}_t = \boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\label{eq:base-update}\end{equation}
那么要如何修改,才能保证$\boldsymbol{\omega}_t$始终满足$\Vert\boldsymbol{\omega}_t\Vert\leq\tau$呢?方法当然有很多,比如上一节提到的归一化其实也算是一种方案。既然如此,我们希望从中选出对优化过程影响最小的方案,即给定参数$\boldsymbol{\omega}$和范数$\Vert\cdot\Vert$,我们希望以最小的改动,将它的范数变得不超过$\tau$,形式定义为
\begin{equation}\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq\tau}} = \argmin_{\Vert\tilde{\boldsymbol{\omega}}\Vert\leq\tau} \Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS}\label{eq:nclip}\end{equation}
熟悉凸优化的读者应该容易看出,这其实就是将$\boldsymbol{\omega}$投影到某种范数半径不超过$\tau$的超球内的投影运算。这里最关键的地方在于,我们希望达到范数不超过$\tau$的目标,但又希望对原始参数$\boldsymbol{\omega}$的影响降低到最小,所以最小化差异指标$\Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS}$,由此诱导出特定的投影或者说裁剪操作。

至于$\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq\tau}}$如何计算,需要具体范数具体分析,我们等会再展开。有了这个运算,我们可以考虑的一种方案是每一步更新之后,通过这个运算将参数范数进行截断,即将式$\eqref{eq:base-update}$改为
\begin{equation}\boldsymbol{\omega}_t = \color{skyblue}{\lfloor}\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq\tau}}\end{equation}
这种方案我们暂且称为“事后裁剪(Post Clip)”,特点是简单直观,但可能会给人一种“不光滑”的感觉。这个不难理解,假设我们初始化的半径小于$\tau$,然后开始训练的时候参数半径在缓慢增加,达到$\tau$之后,裁剪“突然”就触发了,这个过程虽然连续,但不光滑,类似$\max(x,0)$函数。

事前衰减 #

如果介意这种不光滑性,那么可以考虑模仿权重衰减,将惩罚分摊到每一步更新上。仍然从更新规则$\eqref{eq:base-update}$出发,假设$\boldsymbol{\phi}_t$满足$\Vert\boldsymbol{\phi}_t\Vert\leq\tau$,那么按照三角不等式有$\Vert\boldsymbol{\omega}_t\Vert = \Vert\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\Vert\leq \Vert\boldsymbol{\omega}_{t-1}\Vert + \eta \tau$,也就是说极端情况下范数每步增加$\eta\tau$,长期累积之下便会“失控”。

为了预防这种现象,我们可以在$- \eta \boldsymbol{\phi}_t$之前,对$\boldsymbol{\omega}_{t-1}$做一点预处理,让它范数变小,刚好能抵消更新带来的增长。按照权重衰减的经验,我们可以考虑的是
\begin{equation}\boldsymbol{\omega}_t = \color{skyblue}{\lfloor}\boldsymbol{\omega}_{t-1}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq (1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert}} - \eta \boldsymbol{\phi}_t\label{eq:pre-decay}\end{equation}
也就是说,先设法将$\boldsymbol{\omega}_{t-1}$的范数降低到原来的$1-\eta$倍,然后再进行更新,这样一来
\begin{equation}\Vert\boldsymbol{\omega}_t\Vert \leq (1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert + \eta \tau \leq \max(\Vert\boldsymbol{\omega}_{t-1}\Vert,\tau)\end{equation}
那么一直传递下去有$\Vert\boldsymbol{\omega}_t\Vert \leq \max(\Vert\boldsymbol{\omega}_{t-1}\Vert,\tau) \leq \cdots \leq \max(\Vert\boldsymbol{\omega}_0\Vert,\tau)$,即只要初始化满足$\Vert\boldsymbol{\omega}_0\Vert\leq\tau$,那么整个更新链自动满足$\Vert\boldsymbol{\omega}_t\Vert\leq \tau$。这个结论跟具体哪种范数无关,它只依赖于范数的三角不等式。而降低范数的最小改动的操作,正是式$\eqref{eq:nclip}$定义的裁剪算子,所以用它来降低范数就顺理成章了。

这种方案我们称为“事前衰减(Pre Decay)”,跟“事后裁剪”的不同点在于,后者的阈值是静态的$\tau$,所以裁剪不一定会触发,但前者的阈值是动态的$(1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert$,并且裁剪一定会触发,这个过程更加平滑,所以我们称之为衰减而不是裁剪,它是权重衰减的一般推广。

简单例子 #

到目前为止,我们建立了约束参数范数的一般框架,分“事后裁剪”和“事前衰减”两种方案,其中的核心运算是$\eqref{eq:nclip}$定义的裁剪算子$\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq\tau}}$,目前只有形式化的定义,实际情况下需要具体范数具体计算。

这一节我们先算一个简单例子,它选择的范数是$\Vert\cdot\Vert_{RMS}$,对向量来说它等价于L2范数,对矩阵来说它等价于F范数。不难得到
\begin{equation}\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_{RMS}\leq\tau}} = \argmin_{\Vert\tilde{\boldsymbol{\omega}}\Vert_{RMS}\leq\tau} \Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS} = \min\left(1,\,\frac{\tau}{\Vert\boldsymbol{\omega}\Vert_{RMS}}\right)\boldsymbol{\omega}\end{equation}
证明就留给读者了(如果实在想不出来可以问Kimi)。特别地,代入$\tau = (1 - \eta) \Vert\omega\Vert_{RMS}$得
\begin{equation}\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_{RMS}\leq (1 - \eta) \Vert\omega\Vert_{RMS}}} = \min\left(1,\,\frac{(1 - \eta) \Vert\omega\Vert_{RMS}}{\Vert\boldsymbol{\omega}\Vert_{RMS}}\right)\boldsymbol{\omega} = (1-\eta)\boldsymbol{\omega}\end{equation}
然后代入到式$\eqref{eq:pre-decay}$得
\begin{equation}\boldsymbol{\omega}_t = (1-\eta)\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\end{equation}
容易看出,这就是常规的权重衰减(Weight Decay)。换言之,RMS范数下的Pre Decay就是我们常用的权重衰减,它是保持RMS范数(等价地,向量的L2范数或矩阵的F范数)约束下,对原始参数改动最Minimal的Pre Decay方案。

奇异裁剪 #

现在开始进入到本文的“主节目”——矩阵参数与Muon。这里我们用回记号$\boldsymbol{W}$,将Muon原本的更新规则写成
\begin{equation}\boldsymbol{W}_t = \boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t,\quad \boldsymbol{\Phi}_t=\frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}\msign(\boldsymbol{G}_t)\end{equation}
记$\tau = \frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}$,那么有$\Vert\boldsymbol{\Phi}_t\Vert_2=\tau$,想让$\boldsymbol{W}_t$满足$\Vert\boldsymbol{W}_t\Vert_2\leq\tau$的两种方案是:
\begin{align}
\text{Post Clip:}\quad\boldsymbol{W}_t =&\, \color{skyblue}{\lfloor}\boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_2\leq\tau}} \\[5pt]
\text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \color{skyblue}{\lfloor}\boldsymbol{W}_{t-1}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_2\leq(1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2}} - \eta\lambda\boldsymbol{\Phi}_t \\
\end{align}
接下来的任务是要计算$\color{skyblue}{\lfloor}\boldsymbol{W}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_2\leq\tau}}$,根据RMS和F范数的等价性,它又等于
\begin{equation}\color{skyblue}{\lfloor}\boldsymbol{W}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_2\leq\tau}} = \argmin_{\Vert\tilde{\boldsymbol{W}}\Vert_2\leq\tau} \Vert\boldsymbol{W} - \tilde{\boldsymbol{W}}\Vert_F\label{eq:mclip-loss}\end{equation}
这个问题的最优解对部分读者来说应该不陌生,它是我们在《高阶MuP:更简明但更高明的谱条件缩放》提到过的“奇异值裁剪(Singular Value Clipping,SVC)”,在《通过msign来计算奇异值裁剪mclip(上)》《通过msign来计算奇异值裁剪mclip(下)》中则称之为$\newcommand{mclip}{\mathop{\text{mclip}}}\mclip$:
\begin{equation}\color{skyblue}{\lfloor}\boldsymbol{W}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert_2\leq\tau}} = \mclip(\boldsymbol{W};\tau) = \boldsymbol{U}\min(\boldsymbol{\Sigma},\tau)\boldsymbol{V}^{\top}\label{eq:2-to-mclip}\end{equation}
其中$\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$是$\boldsymbol{W}$的SVD,$\min(\boldsymbol{\Sigma},\tau)$是将奇异值阶段到不超过$\tau$,证明过程我们下节再演示。有了这个记号,两种方案可以分别写成
\begin{align}
\text{Post Clip:}\quad\boldsymbol{W}_t =&\, \mclip(\boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t;\tau) \\[5pt]
\text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \mclip(\boldsymbol{W}_{t-1};(1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2) - \eta\lambda\boldsymbol{\Phi}_t \\
\end{align}

推导过程 #

这一节我们来证明结论$\eqref{eq:2-to-mclip}$。设$\boldsymbol{W}$的SVD是$\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$,其中$\boldsymbol{U}\in\mathbb{R}^{d_{in}\times d_{in}}$、$\boldsymbol{\Sigma}\in\mathbb{R}^{d_{in}\times d_{out}}$、$\boldsymbol{V}\in\mathbb{R}^{d_{out}\times d_{out}}$,那么
\begin{equation}\Vert\boldsymbol{W} - \tilde{\boldsymbol{W}}\Vert_F = \Vert\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} - \tilde{\boldsymbol{W}}\Vert_F = \Vert\boldsymbol{U}(\boldsymbol{\Sigma} - \boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V})\boldsymbol{V}^{\top}\Vert_F = \Vert\boldsymbol{\Sigma} - \boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V}\Vert_F\end{equation}
最后一个等号是因为正交矩阵不改变F范数。同时,正交矩阵也不改变谱范数,所以设$\tilde{\boldsymbol{\Sigma}}=\boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V}$,目标$\eqref{eq:mclip-loss}$可以等价地简化成
\begin{equation}\argmin_{\Vert\tilde{\boldsymbol{\Sigma}}\Vert_2\leq\tau} \Vert\boldsymbol{\Sigma} - \tilde{\boldsymbol{\Sigma}}\Vert_F\end{equation}
注意,这里的$\boldsymbol{\Sigma}$是对角阵,设其对角线元素为$\sigma_1,\sigma_2,\cdots \geq 0$,但$\tilde{\boldsymbol{\Sigma}}$暂时是不确定的,证明时要设它为一般矩阵。用分量写法得
\begin{equation}\Vert\boldsymbol{\Sigma} - \tilde{\boldsymbol{\Sigma}}\Vert_F^2 = \sum_i \sigma_i^2 + \sum_{i,j} \tilde{\Sigma}_{i,j}^2 - 2\sum_i \sigma_i \tilde{\Sigma}_{i,i} \geq \sum_i \sigma_i^2 + \sum_i (\tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i})\end{equation}
逐项来看,$\tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i}$只不过是关于$\tilde{\Sigma}_{i,i}$的一元二次函数,最小值在$\tilde{\Sigma}_{i,i}=\sigma_i$取到。但我们还有约束$\Vert\tilde{\boldsymbol{\Sigma}}\Vert_2\leq\tau$,由于谱范数大于等于矩阵任意元素的绝对值,所以至少有约束$\tilde{\Sigma}_{i,i}\leq\tau$,在该约束下,$\tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i}$的最小值在$\tilde{\Sigma}_{i,i}^* = \min(\sigma_i,\tau)$取到。

考虑到让所有等号同时成立,我们得到$\tilde{\Sigma}_{i,j}^*=0(i\neq j)$,由此可见$\tilde{\boldsymbol{\Sigma}}^*$也是一个对角阵,它正好可以简写成$\tilde{\boldsymbol{\Sigma}}^*=\min(\boldsymbol{\Sigma},\tau)$,这又对应于$\tilde{\boldsymbol{W}}^*=\boldsymbol{U}\min(\boldsymbol{\Sigma},\tau)\boldsymbol{V}^{\top}$,至此结论$\eqref{eq:2-to-mclip}$得证。

裁剪主项 #

那么,$\mclip$该如何高效计算呢?每步训练都执行一次SVD显然过于昂贵了。在文章《通过msign来计算奇异值裁剪mclip(上)》《通过msign来计算奇异值裁剪mclip(下)》中,我们其实系统探讨过这个问题,当时的提法是借助$\msign$来实现,但需要2~3次$\msign$,代价不菲。例如下篇发现的一个恒等式是
\begin{equation}\mclip(\boldsymbol{W};\tau)
=\frac{1}{2}\Bigl\{\boldsymbol{W}+\tau\msign(\boldsymbol{W})-(\tau\boldsymbol{I}-\boldsymbol{W}\msign(\boldsymbol{W})^{\top})\msign(\tau\msign(\boldsymbol{W})-\boldsymbol{W})\Bigr\}\end{equation}
它需要两次$\msign$,由于参数的计算往往是在FP32下进行的,所以执行两次$\msign$还是比较昂贵的,因此还不算特别可行。

这里我们主要考虑在《基于流式幂迭代的Muon实现:5. 延伸》讨论过的逐项裁剪思路。具体来说,$\mclip$是将所有大于$\tau$的奇异值都变成$\tau$,那么必要的操作是将主奇异值变成$\tau$(如果它大于$\tau$),裁剪掉主奇异值后,如果还存在大于$\tau$的奇异值,那么其中的最大者就会变成新的主奇异值。所以,只要反复“裁剪主奇异值到$\tau$”,就可以实现$\mclip$。

由于主奇异值和主奇异向量可以通过幂迭代高效求出(记为$\mathop{\text{SVD1}}$),所以主奇异值裁剪可以认为是高效的。进一步,我们假设训练足够平缓,每一步可以只执行一次主奇异值裁剪,这样也能近似地实现同样的效果。基于这个策略,两种限制奇异值的方案可以进一步写成
\begin{align}
\text{Post Clip:}\quad\boldsymbol{W}_t =&\, \tilde{\boldsymbol{W}}_t - \max(\sigma_1 - \tau, 0) \boldsymbol{u}_1 \boldsymbol{v}_1^{\top},\quad\sigma_1, \boldsymbol{u}_1, \boldsymbol{v}_1 = \mathop{\text{SVD1}}(\tilde{\boldsymbol{W}}_t),\quad\tilde{\boldsymbol{W}}_t = \boldsymbol{W}_{t-1} - \eta \boldsymbol{\Phi}_t \\[5pt]
\text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \lambda\eta\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} - \eta \boldsymbol{\Phi}_t,\quad\sigma_1, \boldsymbol{u}_1, \boldsymbol{v}_1 = \mathop{\text{SVD1}}(\boldsymbol{W}_{t-1})
\end{align}
其中“Pre Decay”版,正是《从谱范数梯度到新式权重衰减的思考》引入的谱权重衰减,时隔一年多,我们从另一条途径得到了相同的结果;至于“Post Clip”版,@@_arohan_ 曾在X上提过,当时称为“Wion”。实践中,由于每步只裁剪一个奇异值,所以可能会存在一些比较“上进”的矩阵,其谱范数明显偏离设定的阈值,这是正常的,在LR Decay阶段会慢慢降下来。

其他细节 #

如果想要裁剪得更精准,我们也可以用幂迭代同时求Top-$k$个奇异值和奇异向量,每步至多裁剪$k$个奇异值,代价是幂迭代的L2 Normalize要换成QR分解,而QR分解也有一些加速手段,相关原理可以参考流式幂迭代系列文章,如《基于流式幂迭代的Muon实现:1. 初识》

除了线性层矩阵的谱范数外,在《MuP之上:3. 特殊情况特殊处理》中我们遇到了另外一些层的不同范数,比如Embedding、LM Head分别对应最大行、列RMS,而RMS Norm层的gamma参数,对应的是最大绝对值,也称为向量的无穷范数。

幸运的是,这些范数下的裁剪算子$\color{skyblue}{\lfloor}\boldsymbol{\omega}\color{skyblue}{\rfloor}_{\color{skyblue}{\Vert\cdot\Vert\leq\tau}}$都比较好算。比如Embedding层的范数是行RMS的最大者,裁剪算子就是将每行向量的RMS都裁剪到不超过$\tau$;LM Head同理,只不过行换成列;至于gamma参数就更简单了,直接就是主项裁剪$\mathop{\text{clip}}(\boldsymbol{\gamma};-\tau,\tau) = \max(\min(\boldsymbol{\gamma},\tau),-\tau)$。

这些结论都很直观,而且证明也比较简单,我们就不展开了,权当是留给读者的练习题。

必要保证 #

可能有读者疑问,非要做得这么复杂吗?直接像《Training Deep Learning Models with Norm-Constrained LMOs》那样用普通的权重衰减不行吗?比如
\begin{equation}\boldsymbol{W}_t = (1-\eta\lambda)\boldsymbol{W}_{t-1} - \eta\sqrt{\frac{d_{out}}{d_{in}}}\msign(\boldsymbol{G}_t)\label{eq:muon-wd}\end{equation}
同样也可以将谱范数限制成$\tau = \frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}$内,为什么不用这种简单的形式呢?

答案是:避免过度干预。从定义$\eqref{eq:nclip}$可以看成,我们定义的裁剪算子是在实现同样效果的前提下对原始参数改动最小的操作,对于谱范数来说,直接乘以$1-\eta\lambda$虽然也能将$\boldsymbol{W}_{t-1}$的谱范数变得不超过$(1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2$,但它既然不同于最小改动的$\mclip$,那么它必然存在某种程度上的“过度干预”。

过度干预的后果有两种:要不为了保证效果,选取较小的$\lambda$,此时$\tau$过大,即无法保证谱范数载我们期望的范围内;要不为了保证控制谱范数,选取较大的$\lambda$,但这样就会明显降低效果。比如$d_{in}=d_{out}$的情况下希望谱范数不超过$5$,那么根据公式$\lambda=0.2$,对于式$\eqref{eq:muon-wd}$的Muon,0.2的权重衰减系数是极大的(一般数值是0.01左右)。

注意我们多次强调“保证”,这是很关键的。假设我们用权重衰减,系数设为0.01,理论上谱范数最多能达到100,但在小模型上做实验发现可能5都不到,这是很常见的。然而,小模型安全不代表大模型安全,之前我们就说过,大模型很强,强到可以放大任何细微的Bug,如果理论上界是100,小模型没机会达到,但大模型真的是可能达到的。

所以,让参数的关键范数在理论上保持有一个合理的界是非常必要的,这也是“稳中求快”原则中“稳”的体现。而式$\eqref{eq:nclip}$定义的裁剪算子,则是保证有界的最“轻量”运算,换言之它可能是在保证同样界限的前提下对效果损失最小的操作。

文章小结 #

本文基于最小改动思想,提出了在训练过程中维持参数稳定性的一般框架,包含Post Clip与Pre Decay两种方案。在谱范数下,它们进一步可以演化成奇异值裁剪与谱权重衰减。这些操作旨在保证参数关键范数有界的同时,尽可能降低对训练动力学的干预。

转载到请包括本文地址:https://spaces.ac.cn/archives/11729

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Apr. 24, 2026). 《MuP之上:4. 坚守参数的稳定性 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/11729

@online{kexuefm-11729,
        title={MuP之上:4. 坚守参数的稳定性},
        author={苏剑林},
        year={2026},
        month={Apr},
        url={\url{https://spaces.ac.cn/archives/11729}},
}