流形上的最速下降:3. Muon + Stiefel
By 苏剑林 | 2025-08-08 | 24301位读者 |上回说到,当我们把优化对象从向量参数转移到矩阵参数,并选用更适合矩阵的谱范数约束后,Muon优化器便自然而然地出现了。进一步地,我们考虑了给参数加上正交约束后的最速下降方向,这其中又分方阵和非方阵两部分讨论,其中方阵的求解我们在上一篇文章已经完成,但非方阵部分依然悬而未决。
本文的目标,则是把非方阵部分的求解补上,使得正交约束下的优化得以完全解决。
任务信息 #
先简单回顾一下上文《流形上的最速下降:2. Muon + 正交》的结果。我们要求解的目标是
\begin{equation}\newcommand{tr}{\mathop{\text{tr}}}\max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi})=\boldsymbol{I}\end{equation}
其中$\boldsymbol{W},\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}(n \geq m)$,$\Vert\cdot\Vert_2$是谱范数。基于“一阶近似够用”的原则,可以简化成
\begin{equation}\max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}\label{eq:ori-obj}\end{equation}
其中满足$\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$的全体$\boldsymbol{\Phi}$也称为$\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$的“切空间”。在上一篇文章中我们已经求出了通解的形式
\begin{equation}\boldsymbol{\Phi} = \newcommand{msign}{\mathop{\text{msign}}}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\end{equation}
其中$\boldsymbol{X}\in\mathbb{R}^{m\times m}$是一个待定的对称矩阵。
剩下的难题就是给出对称矩阵$\boldsymbol{X}$的计算方法,使得$\boldsymbol{W}^{\top}\boldsymbol{\Phi}$是一个反对称矩阵。一旦完成求解,那么对应的$\boldsymbol{\Phi}$自然是最优解。对于$n=m$,我们已经求得闭式解$\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$;真正困难的是$n > m$的情形,此时亦称为“Stiefel流形”,它正是《Orthogonal manifold》所留下的Open problem。
方程变换 #
说白了,我们现在的任务是求解方程组:
\begin{equation}\boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})+\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0}\label{eq:start}\end{equation}
当$n=m$时,$\boldsymbol{W}^{\top}$可以直接吸收到$\msign$里边。所以求解得以简化,然而$n > m$时并不能做这样的吸收,这也是求解的困难所在。笔者倾向于$n > m$时没有简单的显式解,所以我们来寻求数值算法。
根据定义$\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$,可以写出
\begin{equation}\boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1} = (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1}\end{equation}
其中$\boldsymbol{Q} = ((\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}))^{1/2}$。在这个新记号下,方程组变为
\begin{equation}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} + \boldsymbol{Q}^{-1}(\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X}) = \boldsymbol{0}\end{equation}
同时左乘和右乘$\boldsymbol{Q}$,得到
\begin{equation}\boldsymbol{Q}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X}) + (\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X})\boldsymbol{Q} = \boldsymbol{0}\label{eq:r-x}\end{equation}
其中$\boldsymbol{Q}$又成立
\begin{equation}\boldsymbol{Q} = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\label{eq:r-q}\end{equation}
迭代求解 #
现在笔者的想法是,从某个初值的$\boldsymbol{X}$出发,代入式$\eqref{eq:r-q}$得到$\boldsymbol{Q}$,然后将$\boldsymbol{Q}$代入方程组$\eqref{eq:r-x}$求解出新的$\boldsymbol{X}$,反复迭代,直到收敛。在已知$\msign$的情况下,式$\eqref{eq:r-q}$是可以显式计算的,所以唯一的难度是解方程组$\eqref{eq:r-x}$。
我们可以整理一下式$\eqref{eq:r-x}$:
\begin{equation}\boldsymbol{Q}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{Q} = -2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\label{eq:r-xx}\end{equation}
在给定$\boldsymbol{Q}$的前提下,这其实是关于$\boldsymbol{X}$是线性方程组,名为“连续型Lyapunov方程”,也可以看成是“Sylvester方程”的一个特例。如果我们只用CPU进行计算,Scipy其实已经自带了该方程的求解函数scipy.linalg.solve_continuous_lyapunov,直接调用即可。
至于初值的选择,我们可以考虑方阵时的解$-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$,这样显然是方阵到非方阵的一个自然过渡。我们也可以从方程$\eqref{eq:r-xx}$的另一个等价形式,来观察初值$-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$的合理性:
\begin{equation}\boldsymbol{Q}(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) + (\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}})\boldsymbol{Q} =[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\boldsymbol{Q} -\boldsymbol{Q}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\end{equation}
所以$-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$的精确程度,取决于$[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$与$\boldsymbol{Q}$乘法的可交换程度,它们越接近交换矩阵,那么$-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$就越准确。不过后面的实测结果显示,我们的迭代算法对初值并不是特别敏感,即便以全零矩阵为初值问题也不大。
自己动手 #
刚才我们说到Scipy自带了求解Lyapunov方程函数,因此可以直接调用而无需关心求解过程。但这也仅限于CPU的Scipy,笔者查了一下,Torch和Jax都没有同类函数,所以要用GPU计算的话,只能“自力更生”。
自己编程求解方程$\eqref{eq:r-xx}$的做法有两个。一是按照《矩阵符号函数mcsgn能计算什么?》的思路,用$\newcommand{mcsgn}{\mathop{\text{mcsgn}}}\mcsgn$(不是$\msign$)来求解:
\begin{equation}\boldsymbol{X} = \mcsgn\left(\begin{bmatrix}-\boldsymbol{Q} & -[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \\ \boldsymbol{0} & \boldsymbol{Q}\end{bmatrix}\right)_{[:m,m:]}\end{equation}
二是基于SVD求解,这个方法我们在《msign的导数》里计算$\msign$的梯度时已经用过,这里结合方程$\eqref{eq:r-xx}$再介绍一遍。根据$\boldsymbol{Q}$的定义知它是正定对称的,那么可以特征值分解为$\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$,其中$\boldsymbol{V}$是正交矩阵而$\boldsymbol{\Sigma}=\mathop{\text{diag}}(\sigma_1,\cdots,\sigma_m)$是对角矩阵,代入到式$\eqref{eq:r-xx}$,可整理得
\begin{equation}\boldsymbol{\Sigma}(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) + (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\boldsymbol{\Sigma} = -2\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V}\end{equation}
左端可以表示成$(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\otimes \boldsymbol{S}$,其中$\otimes$是Hadamard积,$\boldsymbol{S}_{i,j} = \sigma_i + \sigma_j$。由此,可以解得
\begin{equation}\boldsymbol{X} = -2\boldsymbol{V}((\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V})\oslash \boldsymbol{S})\boldsymbol{V}^{\top}\end{equation}
其中$\oslash$是Hadamard商。这里比较有趣的地方是,对$\boldsymbol{Q}$做特征值分解,基本等价于对$\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$做SVD,而对$\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$做SVD也可以用来求$\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$,所以只需一遍SVD就可以把$\msign$和方程$\eqref{eq:r-xx}$的解都算出来。
两个思路各有特点。思路一需要先对$m\times m$的矩阵算$\msign$,再对$2m\times 2m$的矩阵算$\mcsgn$,虽然它们都可以用Newton-Schulz迭代来高效计算,但也代价不菲。此外,这里我们还要选择能收敛且精度高的系数(推荐《msign算子的Newton-Schulz迭代(下)》的结果),要不然$\mcsgn$和$\msign$的计算都不收敛,更别说$\boldsymbol{X}$了。
思路二需要用到SVD。虽然SVD的复杂度较高,而且往往要强制使用FP32精度,但在这里的问题上,每一步迭代只需要一次SVD就可以同时求$\msign$和$\boldsymbol{X}$,总体效率也不会太差。如果我们需要正交约束的矩阵参数并不多,那么SVD可能是最简便的选择。
相关结果 #
本文之前,@leloy 在他的博客文章《Heuristic Solutions for Steepest Descent on the Stiefel Manifold》也提出了原始目标$\eqref{eq:ori-obj}$的两种启发式求解方法。这里的“启发式”,指的是在大多数情况下,它能得到一个还不错的解,但无法保证是最优解,这里我们也一起学习下。
第一种方法可以说是纯几何法。首先我们定义投影运算:
\begin{equation}\newcommand{proj}{\mathop{\mathcal{P}}}\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M}) = \boldsymbol{M} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{sym}}\end{equation}
可以验证$\boldsymbol{W}^{\top}\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$一定是反对称矩阵,也就是说$\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$一定在切空间中,所以我们将它视为任意矩阵$\boldsymbol{M}$投影到$\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$的切空间中的投影运算。
我们从梯度$\boldsymbol{G}$出发,$\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$肯定是在切空间了,但我们知道Muon的更新量一定是个正交矩阵(满秩时),而$\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M})$不一定正交,所以我们可以通过$\msign$来寻找与之最邻近的正交矩阵,即$\msign(\proj\nolimits_{\boldsymbol{W}}(\boldsymbol{M}))$。然而$\msign$之后又不一定在切空间了,我们又可以将它投影到切空间去,然后又寻找最近正交矩阵,反复迭代:
\begin{equation}\boldsymbol{\Phi} = (\msign\circ\proj\nolimits_{\boldsymbol{W}}\circ\cdots\circ\msign\circ\proj\nolimits_{\boldsymbol{W}})(\boldsymbol{M})\end{equation}
这便是 @leloy 的第一种思路,交替投影到切空间和正交空间直到收敛,可以说相当直观。而且在比较随机的情况下,它跟最优解也非常接近,甚至能精确到小数点后4位,以至于笔者一开始认为它就是精确解。不过后来经过搜索,发现了它跟最优解偏差足够大的case,才确认了这只是巧合,并非最优解。
第二种方法可以称为线搜索。具体来说,当$n > m$时,我们可以考虑将$\boldsymbol{W}$补成标准的$n\times n$的正交矩阵$[\boldsymbol{W},\overline{\boldsymbol{W}}]$,然后将所求$\boldsymbol{\Phi}$分解为$\boldsymbol{W}^{\top}\boldsymbol{\Phi}$和$\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$两部份。接着 @leloy 做了一个贪心近似,先求$\boldsymbol{W}^{\top}\boldsymbol{\Phi}$的最优解,然后再求$\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$的最优解,两者之间再引入一个线搜索来提高准确度。
这样一套操作下来,确实能得到一个近似程度还不错的解,并且它一定在切空间内且满足正交性。求解过程需要计算谱范数、$\msign$和Cholesky分解,细节大家自行看作者的文章了。此外,当$m=2$时,理论上它是能搜出最优解的,这是因为$2\times 2$的反对称矩阵只有一个自由参数,而线搜索刚好是一个自由度。
测试一下 #
下面在Numpy中实测上面几个方法,其中主要目的是验证方法本身的正确性,所以我们直接用奇异值分解和特征值分解来实现$\msign$和$\mcsgn$。
import numpy as np
import scipy as sp
def mcsgn(x):
"""特征值分解精确计算mcsgn
"""
s, v = np.linalg.eig(x)
return v @ np.diag(np.sign(s)) @ np.linalg.inv(v)
def msign(g):
"""奇异值分解精确计算msign
"""
u, s, vh = np.linalg.svd(g, full_matrices=False)
return u @ np.diag(np.sign(s)) @ vh
def sym(x):
"""对称化
"""
return (x + x.T) * 0.5
def skew(x):
"""反对称化
"""
return (x - x.T) * 0.5
def proj(g, w):
"""投影到正交的切空间
"""
return g - w @ sym(w.T @ g)
def jianlin_by_mcsgn(g, w, steps=20):
"""通过mcsgn来构建本文的迭代
"""
n, m = g.shape
x = -sym(w.T @ g)
for i in range(1, steps + 1):
phi = msign(z := g + w @ x)
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
if i == steps:
return phi
q = z.T @ phi
x = mcsgn(np.block([[-q, -sym(q @ w.T @ g)], [np.zeros_like(q), q]]))[:m, m:]
# x = -2 * sp.linalg.solve_continuous_lyapunov(q, sym(q @ w.T @ g))
def jianlin_by_svd(g, w, steps=20):
"""通过svd来构建本文的迭代
"""
x = -sym(w.T @ g)
for i in range(1, steps + 1):
u, s, vh = np.linalg.svd(z := g + w @ x, full_matrices=False)
phi = (u * np.sign(s)) @ vh
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
if i == steps:
return phi
x = -2 * vh.T @ (vh @ sym(z.T @ phi @ w.T @ g) @ vh.T / (s + s[:, None])) @ vh
def leloy_v1(g, w, steps=20):
"""交替投影到切空间和正交空间
"""
phi = g
for i in range(1, steps + 1):
phi = msign(proj(phi, w))
print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
return phi
def leloy_v2(g, w, steps=20):
"""分部贪心求解 + 线搜索(形式经过笔者的简化)
"""
n, m = g.shape
taus = np.linspace(0, 1, steps + 2)[1:-1]
p_max, tau_opt, phi_opt = 0, 0, None
for tau in taus:
b = (b := skew(w.T @ g)) * tau / max(np.linalg.norm(b, ord=2), 1e-8)
r = np.linalg.cholesky(np.eye(m) - b.T @ b)
c = msign((np.eye(n) - w @ w.T) @ g @ r) @ r
phi = w @ b + c
print('tau:', tau, ', inner product:', p := (phi * g).sum())
if p > p_max:
p_max, tau_opt, phi_opt = p, tau, phi
print('best inner product:', p_max, ', tau:', tau_opt)
return phi_opt
w = np.array([[ 0.69453734, -0.26590866, -0.44721806, 0.2753041 ],
[-0.11738148, -0.5588003 , -0.17580748, 0.3218624 ],
[-0.4515288 , -0.23489913, -0.26683152, -0.25739142],
[ 0.02392521, 0.02664689, 0.48423648, 0.6193399 ],
[ 0.45194831, -0.25206333, 0.27654836, -0.60242337],
[ 0.21197332, -0.09174792, 0.24521762, -0.08484317],
[-0.15496767, -0.26446804, -0.34942415, -0.01877318],
[-0.16181251, -0.6474956 , 0.45243263, -0.01776086]])
g = np.array([[-17.85745 , -10.758921 , -2.9583392 , 6.245008 ],
[-28.883093 , 19.772121 , 8.086545 , -21.564013 ],
[ -1.6274693 , -14.96859 , 3.4465332 , 3.1070817 ],
[ -7.8890743 , 1.5304767 , -8.949573 , 9.579629 ],
[ 2.246596 , 14.46572 , 12.8451 , -2.7370298 ],
[ -0.9496974 , 6.9879804 , 2.849277 , 1.1148484 ],
[ -8.115278 , -18.054405 , -0.19287404, 7.0389237 ],
[-15.062008 , -15.02901 , 2.9083247 , 21.706533 ]])
phi1 = jianlin_by_mcsgn(g, w, steps=100)
phi2 = jianlin_by_svd(g, w, steps=100)
phi3 = leloy_v1(g, w, steps=100)
phi4 = leloy_v2(g, w, steps=100)
assert np.allclose(phi1, phi2)
w = np.linalg.qr(np.random.randn(100, 50))[0]
g = np.random.randn(100, 50)
phi1 = jianlin_by_mcsgn(g, w, steps=10)
phi2 = jianlin_by_svd(g, w, steps=10)
phi3 = leloy_v1(g, w, steps=10)
phi4 = leloy_v2(g, w, steps=10)
assert np.allclose(phi1, phi2)对于代码中给出的第一组$\boldsymbol{W},\boldsymbol{G}$,笔者方法求得的最优$\tr(\boldsymbol{G}^{\top} \boldsymbol{\Phi})$大致是$90$,并且$\mcsgn$和SVD的结果是完全一样的;而 @leloy 的第一种方法求得的结果是大致是$70$,第二种方法求得的结果大致是$80$,都跟最优解有一定差距。
不过,第一组$\boldsymbol{W},\boldsymbol{G}$只是为了显示出三个方法差距特意搜出来的极端例子,如果我们更换相对随机的数值,那么其实本文的解法和 @leloy 的第一种解法会很接近,并且迭代步数也可以少很多(5~10步),此时 @leloy 的第二种解法跟最优解差距更大。读者可以自行构建一些例子测试。
拓展思考 #
关于原始问题$\eqref{eq:ori-obj}$的求解,这里就暂告一段落了。接下来补充讨论几个可能有疑惑的细节问题。
首先,为了方便描述,笔者前面给出的迭代求解过程有一个隐含假设,那就是$\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$自始至终都是满秩的(秩为$m$),要不然矩阵$\boldsymbol{S}$就会有零分量,$\oslash\boldsymbol{S}$就不大好操作。但这个困难不是本质的,因为方程$\eqref{eq:start}$必然会有解,所以遇到分母为零时,分子必然也为零,于是我们只需要简单将$\boldsymbol{S}$的零分量换成一个小正数,就能得到正确的结果。
从数值计算的角度看,我们也很少有机会能遇到绝对等于零的奇异值,所以不需要太担心这个问题,默认$\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$满秩就好。在这个默认假设之下,回缩操作会变得很简单,因为
\begin{equation}(\boldsymbol{W} - \eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta\boldsymbol{\Phi}) = \boldsymbol{W}^{\top} \boldsymbol{W} - \eta(\boldsymbol{W}^{\top} \boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2 \boldsymbol{\Phi}^{\top}\boldsymbol{\Phi}\end{equation}
根据Stiefel流形的定义,右端第一项是$\boldsymbol{I}$,根据切空间的条件,第二项是$\boldsymbol{0}$,最后是满秩时$\msign$出来的也是一个Stiefel流形的矩阵,所以第三项是$\eta^2 \boldsymbol{I}$,总的结果是$(1+\eta^2)\boldsymbol{I}$,只需要除以$\sqrt{1+\eta^2}$就可以实现回缩:
\begin{equation}\boldsymbol{W}\quad\leftarrow\quad\frac{\boldsymbol{W} - \eta\boldsymbol{\Phi}}{\sqrt{1+\eta^2}}\end{equation}
看到这里,不知道笔者有没有发现,这里其实有一个更深刻的问题:不管是相对简单的正交流形,还是相对复杂的Stiefel流形,我们应该使用何种精度计算?要知道“正交”是一个精确的定量约束,$\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$包含$m(m+1)/2$个等式约束,可以预见在低精度下用上式进行迭代,久而久之肯定会严重偏离正交的,更不用说求解$\boldsymbol{\Phi}$过程中的误差了。
因此,笔者认为,除非我们定期给参数施加正交化操作(即$\boldsymbol{W}\leftarrow\msign(\boldsymbol{W})$)来将它拉回到正交流形上,否则求解过程的计算精度起码要FP32起步。考虑到通常要加正交约束的参数并不会很多,所以一般来说这也不算太大的代价。
文章小结 #
这篇文章将上一篇文章的“Muon + 正交流形”推广到了更一般的“Muon + Stiefel流形”,主要发现是一个求解对应更新量的迭代算法。
转载到请包括本文地址:https://spaces.ac.cn/archives/11221
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 08, 2025). 《流形上的最速下降:3. Muon + Stiefel 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/11221
@online{kexuefm-11221,
title={流形上的最速下降:3. Muon + Stiefel},
author={苏剑林},
year={2025},
month={Aug},
url={\url{https://spaces.ac.cn/archives/11221}},
}










August 19th, 2025
发现这个问题和我3年前讨论的一个问题有些类似(https://arxiv.org/abs/2110.03898),都是在stiefel流形上进行优化,只不过当时改造的优化器就是比较常见的SGD、Adam这种。有一个比较好奇的点,就是模型中对参数添加正交性约束能否带来额外的收益?比如能否通过参数正交化来增强对模型计算本质的理解,或者通过某些变换实现推理效率更优?当年讨论的问题是实际的量子多体问题,本身演化算子有这种类似正交性的约束,但在深度神经网络中这种约束似乎并不常见。
应该说是“需要”正交化后才来考虑看本文或相关论文。
至于哪里会需要正交约束,这是独立的课题了,感觉上也不是太朴素的事情,只能说某些case下确实可以考虑,尤其是一些需要防止退化/坍缩的场景,比如学习codebook,或者以前GAN的判别器,等等。
如果说正交约束、或者更一般的isometric变换可以保持hidden_state的归一性,那么对权重的isometric约束在一定程度上起到了norm的作用?比如说对wq, wk权重在head_dim轴上进行isometric约束,是不是效果和q/k norm差不多。
很好的想法,这个做法对于压缩maxlogit来说是没有问题的,但如果单纯为了压缩maxlogit,可能有点大材小用了。因为控制maxlogit理论上只需要控制weight的spectral norm,这可以通过简单地spectral normalize实现,spectral normalize的成本比多计算几次msign要低得多。
September 27th, 2025
[...]Non-Riemannian geometry in machine learning. Thomas Flynn’s paper from 2017 on duality structure gradient descent characterizes the neural network weight space as a Finsler manifold, meaning a manifol[...]