设$\boldsymbol{P}\in\mathbb{R}^{n\times n}$是一个特征值都是非负实数的$n$阶方阵,本文来讨论它的平方根$\boldsymbol{P}^{1/2}$和逆平方根$\boldsymbol{P}^{-1/2}$的计算。

基本概念 #

矩阵$\boldsymbol{P}$的平方根,指的是满足$\boldsymbol{X}^2=\boldsymbol{P}$的矩阵$\boldsymbol{X}$。我们知道正数都有两个平方根,因此不难想象矩阵平方根一般也不唯一。不过,“算术平方根”是唯一的,一个正数的算术平方根是正的那个平方根,类似地,我们将$\boldsymbol{P}$的特征值全是非负数的那个平方根称为算术平方根。本文要求的矩阵平方根,默认都是指算术平方根。

本文的计算依赖于我们在《矩阵符号函数mcsgn能计算什么?》讨论过的矩阵符号函数:
\begin{equation}\newcommand{mcsgn}{\mathop{\text{mcsgn}}}\mcsgn(\boldsymbol{M}) = (\boldsymbol{M}^2)^{-1/2}\boldsymbol{M}= \boldsymbol{M}(\boldsymbol{M}^2)^{-1/2}
\end{equation}
简单来说,它就是把任意矩阵$\boldsymbol{M}\in\mathbb{R}^{n\times n}$的特征值变成对应的符号函数的新矩阵。假设$\boldsymbol{M}$的特征值都是实数,那么$\mcsgn$可以通过Newton-Schulz迭代高效计算:
\begin{equation}\newcommand{tr}{\mathop{\text{tr}}}\boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\sqrt{\tr(\boldsymbol{M}^2)}},\qquad \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^3 + c_{t+1}\boldsymbol{X}_t^5\end{equation}
其中$\frac{\boldsymbol{M}}{\sqrt{\tr(\boldsymbol{M}^2)}}$是为了将$\boldsymbol{X}_0$的特征值都缩放到$[-1,1]$内,而$a_t,b_t,c_t$是《msign算子的Newton-Schulz迭代(下)》所推导的系数:
\begin{array}{c|ccc}
\hline
t & a\times 1.01 & b\times 1.01^3 & c\times 1.01^5 \\
\hline
\quad 1\quad & 8.28721 & -23.5959 & 17.3004 \\
2 & 4.10706 & -2.94785 & 0.544843 \\
3 & 3.94869 & -2.9089 & 0.551819 \\
4 & 3.31842 & -2.48849 & 0.510049 \\
5 & 2.30065 & -1.6689 & 0.418807 \\
6 & 1.8913 & 1.268 & 0.376804 \\
7 & 1.875 & -1.25 & 0.375 \\
8 & 1.875 & -1.25 & 0.375 \\
\hline
\end{array}
实际上,当$\boldsymbol{M}$的特征值全都是实数时,$\mcsgn$的计算原理跟另一种矩阵符号函数$\newcommand{msign}{\mathop{\text{msign}}}\msign$是相通的。

计算原理 #

接下来计算的出发点是恒等式
\begin{equation}\mcsgn\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0}\end{bmatrix}\right)=\begin{bmatrix}\boldsymbol{0} & \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2} \\ \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2} & \boldsymbol{0}\end{bmatrix}\label{eq:core}\end{equation}
直接代入$\mcsgn$的定义就可以验证上式成立(注:$\boldsymbol{A},\boldsymbol{B}$未必是方阵)。接下来我们要确定什么情况下左端求$\mcsgn$的矩阵的特征值全都是实数。设$\lambda$是它的一个非零特征值,那么
\begin{equation}0=\det\left(\lambda\boldsymbol{I} - \begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0} \end{bmatrix}\right) = \det\left(\begin{bmatrix}\lambda\boldsymbol{I} & -\boldsymbol{A} \\ -\boldsymbol{B} & \lambda\boldsymbol{I} \end{bmatrix}\right) = \det(\lambda^2 \boldsymbol{I} - \boldsymbol{A}\boldsymbol{B})\end{equation}
即$\lambda^2$是矩阵$\boldsymbol{A}\boldsymbol{B}$的特征值。这意味着上述分块矩阵的全体特征值都是实数,当且仅当$\boldsymbol{A}\boldsymbol{B}$的全体特征值非负。

直接对原矩阵进行迭代自然是可以的,但会比较浪费计算,我们可以利用它的反对角线结构来降低计算量。因为
\begin{equation}
\begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\
\boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^3 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})\boldsymbol{Y} \\
\boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z}) & \boldsymbol{0}\end{bmatrix},\quad
\begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\
\boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^5 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})^2\boldsymbol{Y} \\
\boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z})^2 & \boldsymbol{0}\end{bmatrix} \\
\end{equation}
我们可以得到迭代
\begin{gather}
\boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \label{eq:r1} \\[6pt]
\boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2}
\end{gather}
那么$\boldsymbol{Y}_t\to \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2},\boldsymbol{Z}_t\to \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2}$。特别地,将上面两式相乘可以得到$\boldsymbol{Y}_t\boldsymbol{Z}_t$的递归
\begin{equation}\boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t\label{eq:r3}\end{equation}

求平方根 #

现在正式进入平方根的计算。由于假设了$\boldsymbol{P}$的特征值非负,我们总可以通过除以$\tr(\boldsymbol{P})$进一步将它的特征值压缩到$0\sim 1$之间,因此不失一般性,我们假设$\boldsymbol{P}$的特征值都在$[0,1]$内,这样就可以直接用Newton-Schulz迭代计算$\mcsgn$了。

将$\boldsymbol{A}=\boldsymbol{P},\boldsymbol{B}=\boldsymbol{I}$代入到式$\eqref{eq:core}$,可以得到
\begin{equation}\mcsgn\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{P} \\ \boldsymbol{I} & \boldsymbol{0}\end{bmatrix}\right)=\begin{bmatrix}\boldsymbol{0} & \boldsymbol{P}^{1/2} \\ \boldsymbol{P}^{-1/2} & \boldsymbol{0}\end{bmatrix}\end{equation}
非常神奇,理论上通过只需要$\mcsgn$一次,就可以把平方根和逆平方根都求出来,也就是按照式$\eqref{eq:r1}$和$\eqref{eq:r2}$迭代,我们就可以同时完成两个任务!

然而,实际上没有这么理想。如果$\boldsymbol{P}$有非常接近于0的奇异值,那么$\boldsymbol{P}^{-1/2}$是会数值爆炸的(相当于出现了$1/\sqrt{0}$),但$\boldsymbol{P}^{1/2}$并不会,所以假设我们只关心$\boldsymbol{P}^{1/2}$的值,同时计算$\boldsymbol{P}^{1/2},\boldsymbol{P}^{-1/2}$反而会增加数值不稳定性。这时候更好的办法是通过式$\eqref{eq:r1}$和$\eqref{eq:r3}$来迭代,只计算$\boldsymbol{P}^{1/2}$:
\begin{gather}
\boldsymbol{Y}_0 = \boldsymbol{P}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \\[6pt]
\boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \\[6pt]
\lim_{t\to\infty} \boldsymbol{Y}_t = \boldsymbol{P}^{1/2}\notag
\end{gather}
由于$\boldsymbol{Z}_t$的极限是$\boldsymbol{P}^{-1/2}$,所以$\boldsymbol{Y}_t\boldsymbol{Z}_t$的极限是$\boldsymbol{I}$,因此迭代$\boldsymbol{Y}_t\boldsymbol{Z}_t$更不容易出现数值风险。参考代码如下:

import numpy as np

def abc(steps):
    coefs = [
        (8.287212018145622, -23.59588651909882, 17.300387312530923),
        (4.107059111542197, -2.9478499167379084, 0.54484310829266),
        (3.9486908534822938, -2.908902115962947, 0.5518191394370131),
        (3.3184196573706055, -2.488488024314878, 0.5100489401237208),
        (2.3006520199548186, -1.6689039845747518, 0.4188073119525678),
        (1.8913014077874002, -1.2679958271945908, 0.37680408948524996),
        (1.875, -1.25, 0.375)
    ]
    for a, b, c in coefs[:steps] + max(steps - 7, 0) * coefs[-1:]:
        yield a / 1.01, b / 1.01**3, c / 1.01**5

def msqrt(P, steps=6):
    Y = YZ = P / (t := np.trace(P))
    I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Y, YZ = W @ Y, W @ W @ YZ
    return Y * t**0.5

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
np.abs(msqrt(P) @ msqrt(P) - P).mean()  # ~= 2e-4

逆平方根 #

如果我们必须要显式地求出逆平方根$\boldsymbol{P}^{-1/2}$,那么就没什么好办法了,该爆炸的始终都会爆炸,这时候不管是式$\eqref{eq:r2},\eqref{eq:r1}$组合还是式$\eqref{eq:r2},\eqref{eq:r3}$组合,效果都应该差不多,不过后者应该会相对稳定一点:
\begin{gather}
\boldsymbol{Z}_0 = \boldsymbol{I}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\label{eq:r2-rsqrt} \\[6pt]
\boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t\label{eq:r3-rsqrt} \\[6pt]
\lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{P}^{-1/2}\notag
\end{gather}

参考代码如下:

def mrsqrt(P, steps=6):
    YZ = P / (t := np.trace(P))
    Z = I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ
    return Z / t**0.5

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
np.abs(mrsqrt(P) @ mrsqrt(P) @ P - np.eye(d)).mean()  # ~= 5e-4

矩阵相乘 #

不过在大多数时候,求$\boldsymbol{P}^{-1/2}$只是一个中间步骤,求完之后通常还要跟另一个矩阵做乘法。设矩阵$\boldsymbol{G}\in\mathbb{R}^{m\times n}$,我们需要计算$\boldsymbol{G}\boldsymbol{P}^{-1/2}$,如果我们能将$\boldsymbol{G}\boldsymbol{P}^{-1/2}$作为一个整体的迭代对象,那么相比单独求出$\boldsymbol{P}^{-1/2}$然后再执行矩阵乘法,往往有更好的数值稳定性。

让我们仔细观察式$\eqref{eq:r2-rsqrt}$和$\eqref{eq:r3-rsqrt}$,不难看出,当我们将$\boldsymbol{Y}_t\boldsymbol{Z}_t$视为一个整体时,它的迭代式$\eqref{eq:r3-rsqrt}$其实是独立于$\boldsymbol{Z}_t$的,所以$\boldsymbol{Z}_t$的式$\eqref{eq:r2-rsqrt}$本质上就只是一个线性递归!我们在它左边乘以一个矩阵,并不改变迭代形式,只需要修改一下初始值,于是得到
\begin{gather}
\boldsymbol{Z}_0 = \boldsymbol{G}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2-final} \\[6pt]
\boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \label{eq:r3-final}\\[6pt]
\lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{G}\boldsymbol{P}^{-1/2}\notag
\end{gather}

参考代码:

import scipy as sp

def matmul_mrsqrt(G, P, steps=6):
    YZ = P / (t := np.trace(P))
    Z, I = G, np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ
    return Z / t**0.5

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = matmul_mrsqrt(G, P)
np.abs(X @ sp.linalg.sqrtm(P) - G).mean()  # ~= 1e-4

现在,让我们回过头来看求平方根的算法,不难看出它其实就是$\boldsymbol{G}=\boldsymbol{P}$时本节迭代的另一个等价写法,即$\boldsymbol{P}^{1/2}=\boldsymbol{P}\boldsymbol{P}^{-1/2}$。所以,虽然我们看上去式分开了三节来讨论了三个迭代,但它们本质上都是最后一个迭代的特例!

终极推广 #

最后,我们还可以将它推广到$\boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}$的计算,其中$\boldsymbol{Q}\in\mathbb{R}^{m\times m}$是另一个特征值非负的矩阵,结果如下:
\begin{gather}
\boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{Q}_0 = \boldsymbol{Q},\quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt]
\boldsymbol{G}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)\boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) \\[6pt]
\boldsymbol{Q}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)^2\boldsymbol{Q}_t \\[6pt]
\boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^2\boldsymbol{P}_t \\[6pt]
\lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}\notag
\end{gather}

参考代码:

def mrsqrt_matmul_mrsqrt(Q, G, P, steps=6):
    Q = Q / (t1 := np.trace(Q))
    P = P / (t2 := np.trace(P))
    I1, I2 = np.eye(Q.shape[0]), np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W1 = a * I1 + b * Q + c * Q @ Q
        W2 = a * I2 + b * P + c * P @ P
        G, Q, P = W1 @ G @ W2, W1 @ W1 @ Q, W2 @ W2 @ P
    return G / (t1 * t2) **0.5

d = 100
Q = (x := np.random.randn(2 * d, 2 * d) / (2 * d)**0.5) @ x.T
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = mrsqrt_matmul_mrsqrt(Q, G, P)
np.abs(sp.linalg.sqrtm(Q) @ X @ sp.linalg.sqrtm(P) - G).mean()  # ~= 2e-3

请读者根据前几节的结果自行完成证明。

对于Shampoo优化器,我们需要求$\boldsymbol{Q}^{-1/4}\boldsymbol{G}\boldsymbol{P}^{-1/4}$,目前看来比较可行的方案是分别先求出$\boldsymbol{Q}^{1/2}$和$\boldsymbol{P}^{1/2}$,然后代入上述迭代中求$(\boldsymbol{Q}^{1/2})^{-1/2}\boldsymbol{G}(\boldsymbol{P}^{1/2})^{-1/2}$。看上去计算量比较大,但实际上在Optimizer的Update阶段,算力往往不是瓶颈,只要算法可以充分并行,那么时间并不会明显增加,刚好$\boldsymbol{Q}^{1/2}$和$\boldsymbol{P}^{1/2}$的计算可以并行,迭代过程中两个W1、W2也可以并行,因此应该还能接受。

当然,比Muon慢是肯定的,毕竟Shampoo复杂度增加了这么多,总不能一点代价都不用付出(后续见《矩阵r次方根和逆r次方根的高效计算》)。

文章小结 #

本文提出了将矩阵的平方根和逆平方根转化为$\mcsgn$形式,利用它的Newton-Schulz迭代来实现高效计算的过程。

转载到请包括本文地址:https://spaces.ac.cn/archives/11158

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Jul. 19, 2025). 《矩阵平方根和逆平方根的高效计算 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/11158

@online{kexuefm-11158,
        title={矩阵平方根和逆平方根的高效计算},
        author={苏剑林},
        year={2025},
        month={Jul},
        url={\url{https://spaces.ac.cn/archives/11158}},
}