Softmax后传:寻找Top-K的光滑近似
By 苏剑林 | 2024-09-19 | 37651位读者 |Softmax,顾名思义是“soft的max”,是max算子(准确来说是\text{argmax})的光滑近似,它通过指数归一化将任意向量\boldsymbol{x}\in\mathbb{R}^n转化为分量非负且和为1的新向量,并允许我们通过温度参数来调节它与\text{argmax}(的one hot形式)的近似程度。除了指数归一化外,我们此前在《通向概率分布之路:盘点Softmax及其替代品》也介绍过其他一些能实现相同效果的方案。
我们知道,最大值通常又称Top-1,它的光滑近似方案看起来已经相当成熟,那读者有没有思考过,一般的Top-k的光滑近似又是怎么样的呢?下面让我们一起来探讨一下这个问题。
问题描述 #
设向量\boldsymbol{x}=(x_1,x_2,\cdots,x_n)\in\mathbb{R}^n,简单起见我们假设它们两两不相等,即i\neq j \Leftrightarrow x_i\neq x_j。记\Omega_k(\boldsymbol{x})为\boldsymbol{x}最大的k个分量的下标集合,即|\Omega_k(\boldsymbol{x})|=k以及\forall i\in \Omega_k(\boldsymbol{x}), j \not\in \Omega_k(\boldsymbol{x})\Rightarrow x_i > x_j。我们定义Top-k算子\mathcal{T}_k为\mathbb{R}^n\mapsto\{0,1\}^n的映射:
\begin{equation}
[\mathcal{T}_k(\boldsymbol{x})]_i = \left\{\begin{aligned}1,\,\, i\in \Omega_k(\boldsymbol{x}) \\ 0,\,\, i \not\in \Omega_k(\boldsymbol{x})\end{aligned}\right.
\end{equation}
说白了,如果x_i属于最大的k个元素之一,那么对应的位置变成1,否则变成0,最终结果是一个Multi-Hot向量,比如\mathcal{T}_2([3,2,1,4]) = [1,0,0,1]。
从\boldsymbol{x}到\mathcal{T}_k(\boldsymbol{x})实际上是一种硬指派(Hard Assignment)运算,它本质上是不连续的,没有保留关于\boldsymbol{x}的有效梯度,因此无法集成到模型中进行端到端训练。为了解决这个问题,我们就需要构造一个能够提供有效梯度信息的\mathcal{T}_k(\boldsymbol{x})的光滑近似——在有些文献中也称为“可微Top-k算子(Differentiable Top-k Operator)”。
具体来说,我们定义集合
\begin{equation}\Delta_k^{n-1} = \left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\left|\, p_1,p_2,\cdots,p_n\in[0,1],\sum_{i=1}^n p_i = k\right.\right\}\end{equation}
那么我们要做的事情就是构建\mathbb{R}^n\mapsto \Delta_k^{n-1}的一个映射\mathcal{ST}_k(\boldsymbol{x}),并尽量满足如下性质
\begin{align}&{\color{red}{单调性}}:\quad [\mathcal{ST}_k(\boldsymbol{x})]_i \geq [\mathcal{ST}_k(\boldsymbol{x})]_j \,\,\Leftrightarrow\,\, x_i \geq x_j \\[8pt]
&{\color{red}{不变性}}:\quad \mathcal{ST}_k(\boldsymbol{x}) = \mathcal{ST}_k(\boldsymbol{x} + c),\,\,\forall c\in\mathbb{R} \\[8pt]
&{\color{red}{趋近性}}:\quad \lim_{\tau\to 0^+}\mathcal{ST}_k(\boldsymbol{x}/\tau) = \mathcal{T}_k(\boldsymbol{x}) \\
\end{align}
可以验证作为\mathcal{ST}_1(\boldsymbol{x})的Softmax是满足如上性质的,所以提出上述性质实际就是希望构建出来的\mathcal{ST}_k(\boldsymbol{x})能够成为Softmax的自然推广。当然,构建Top-k的光滑近似本身就比Top-1的要难,所以如果遇到困难的话,不必要严格遵循以上性质,只要能表明所构建的映射确实具备\mathcal{T}_k(\boldsymbol{x})的光滑近似的特性就行。
迭代构造 #
事实上,笔者很早之前就关注过该问题,首次讨论于2019年的文章《函数光滑化杂谈:不可导函数的可导逼近》中,当时称之为\text{soft-}k\text{-max},并给出过一个迭代构造方案:
输入为\boldsymbol{x},初始化\boldsymbol{p}^{(0)}为全0向量;
执行\boldsymbol{x} = \boldsymbol{x} - \min(\boldsymbol{x})(保证所有元素非负)对于i=1,2,\dots,k,执行:
\boldsymbol{y} = (1 - \boldsymbol{p}^{(i-1)})\otimes\boldsymbol{x};
\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{y})返回\boldsymbol{p}^{(k)}。
其实这个迭代的构造思路很简单,我们可以先将\text{softmax}(\boldsymbol{y})替换为\mathcal{T}_1(\boldsymbol{y})来理解,此时算法流程就是先保证分量非负,然后识别出Top-1,再将Top-1置零(最大值变成了最小值),接着识别出剩下的Top-1,依此类推,最终的\boldsymbol{p}_k正好就是\mathcal{T}_k(\boldsymbol{x})。而\text{softmax}(\boldsymbol{y})作为\mathcal{T}_1(\boldsymbol{y})的光滑近似,在迭代过程中使用\text{softmax}(\boldsymbol{y})自然也就得到\mathcal{T}_k(\boldsymbol{x})的光滑近似了。
无独有偶,笔者发现在Stack Exchange上的提问《Is there something like softmax but for top k values?》中有回复者提出过一个相同思路的方案,它先定义了加权Softmax:
\begin{equation}[\text{softmax}(\boldsymbol{x};\boldsymbol{w})]_i = \frac{w_i e^{x_i}}{\sum\limits_{i=1}^n w_i e^{x_i}}\end{equation}
然后构建的迭代过程为
输入为\boldsymbol{x},初始化\boldsymbol{p}^{(0)}为全0向量;
对于i=1,2,\dots,k,执行:
\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{x}; 1 - \boldsymbol{p}^{(i-1)})返回\boldsymbol{p}^{(k)}。
这跟笔者所提的迭代过程在思路上是完全一样的,只不过笔者把1 - \boldsymbol{p}_{i-1}乘到了\boldsymbol{x}上而它则乘到了e^{\boldsymbol{x}}上,借助e^{\boldsymbol{x}}本身的非负性简化了流程。然而,这个迭代实际上是错误的,它不符合“趋近性”,比如当k=2时,代入\boldsymbol{x}/\tau取\tau\to 0^+的极限并不是Multi-Hot向量,而是最大值变为1.5、次大值变为0.5、其余都变为0的向量,这是因为1-p_{\max}跟e^{-x_{\max}}大致上是同阶的,1-p_{\max}乘到e^{x_{\max}}上并不能完全消除最大值。
作为梯度 #
迭代构造全凭经验,可能会隐藏一些难以发现的问题,比如看起来更简单的加权Softmax迭代实际上不合理。由于没有更贴近本质的原理指导,这些方案往往也难以理论分析,比如笔者构造的迭代,虽然测起来没有问题,但我们很难证明\boldsymbol{p}_k分量都在[0,1]内,也很难判断它是否满足单调性。
所以,我们希望得到一个更高观点的原理去指导和设计这个光滑近似。就在前几天,笔者突然意识到一个关键的事实
\begin{equation}\mathcal{T}_k(\boldsymbol{x}) = \nabla_{\boldsymbol{x}} \sum_{i\in\Omega_k(\boldsymbol{x})} x_i\end{equation}
也就是说,最大的k个分量的和,它的梯度正好是\mathcal{T}_k(\boldsymbol{x})。所以我们似乎可以改为找\sum\limits_{i\in\Omega_k(\boldsymbol{x})} x_i的光滑近似,然后求梯度就得到\mathcal{T}_k(\boldsymbol{x})的光滑近似了,前者是一个标量,它的光滑近似找起来更容易一些,比如利用恒等式
\begin{equation}\sum_{i\in\Omega_k(\boldsymbol{x})} x_i = \max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k})\end{equation}
即遍历所有k个分量之和取最大值,这样一来,问题变成了找\max的光滑近似,这个我们早已解决(参考《寻求一个光滑的最大值函数》),答案是\text{logsumexp}:
\begin{equation}\max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k})\approx \log\sum_{i_1 < \cdots < i_k} e^{x_{i_1} + \cdots + x_{i_k}}\triangleq \log Z_k\end{equation}
对它求梯度,我们就得到\mathcal{ST}_k(\boldsymbol{x})的一个形式:
\begin{equation}[\mathcal{ST}_k(\boldsymbol{x})]_i = \frac{\sum\limits_{i_2 < \cdots < i_k} e^{x_i+x_{i_2} + \cdots + x_{i_k}}}{\sum\limits_{i_1 < \cdots < i_k} e^{x_{i_1} +x_{i_2}+ \cdots + x_{i_k}}}\triangleq \frac{Z_{k,i}}{Z_k}\label{eq:k-max-grad}\end{equation}
分母是所有k分量和的指数和,分子则是所有包含x_i的k分量和的指数和。根据这个形式,我们可以轻松证明
\begin{equation}0 < [\mathcal{ST}_k(\boldsymbol{x})]_i < 1,\quad \sum_{i=1}^n [\mathcal{ST}_k(\boldsymbol{x})]_i = k\end{equation}
所以这样定义的\mathcal{ST}_k(\boldsymbol{x})确实是属于\Delta_k^{n-1}的。事实上,我们还可以证明它同时满足单调性、不变性和趋近性,并且\mathcal{ST}_1(\boldsymbol{x})它就是Softmax,这些特性显示它确实是Softmax对于Top-k算子的自然推广,我们暂且称之为“GradTopK(Gradient-guided Soft Top-k operator)”。
不过还没到庆祝时刻,因为式\eqref{eq:k-max-grad}的数值计算问题还没解决。如果直接按照式\eqref{eq:k-max-grad}来计算的话,分母就涉及到C_n^k项指数求和,计算量非常可观,所以必须找到一个高效的计算方法。我们已经分别记了分子、分母为Z_{k,i},Z_k,可以观察到分子Z_{k,i}满足递归式
\begin{equation}Z_{k,i} = e^{x_i}(Z_{k-1} - Z_{k-1,i})\end{equation}
结合Z_{k,i}对i求和等于kZ_k的事实,我们可以构建一个递归计算过程:
\begin{equation}\begin{aligned}
\log Z_{k,i} =&\, x_i + \log(e^{\log Z_{k-1}} - e^{\log Z_{k-1,i}}) \\
\log Z_k =&\, \left(\log\sum_{i=1}^n e^{\log Z_{k,i}}\right) - \log k \\
\end{aligned}\end{equation}
其中\log Z_{1,i} = x_i,为了减少溢出风险,我们对两端都取了对数。现在只需要迭代k步就可以完成\mathcal{ST}_k(\boldsymbol{x})的计算,效率上是可以接受的。然而,即便做了对数化处理,上述递归也只能对小方差的\boldsymbol{x}或者比较小的k算一下,反之\log Z_{k-1}与最大的\log Z_{k-1,i}就会相当接近,当数值上无法区分时就会出现\log 0的Bug,个人认为这是这种递归转化的根本困难。
一个非常简陋的参考实现:
import numpy as np
def GradTopK(x, k):
for i in range(1, k + 1):
logZs = x if i == 1 else x + logZ + np.log(1 - np.exp(logZs - logZ))
logZ = np.logaddexp.reduce(logZs) - np.log(i)
return np.exp(logZs - logZ)
k, x = 10, np.random.randn(100)
GradTopK(x, k)
待定常数 #
上一节通过梯度来构建Top-k光滑近似的思路,确实能给人一种高屋建瓴的美感,但可能也会有读者觉得它过于抽象,缺乏一种由表及里的直观感,同时对于大方差的\boldsymbol{x}或者比较大的k的数值不稳定性,也让我们难以完全满意现有结果。所以,接下来我们将探究一种自下而上的构建思路。
方法大意 #
该思路来源于Stack Exchange上的另一个帖子《Differentiable top-k function》的回复。设f(x)是任意\mathbb{R}\mapsto [0,1]的、光滑的、单调递增的函数,并且满足\lim\limits_{x\to\infty}f(x) = 1,\lim\limits_{x\to-\infty}f(x) = 0。看上去条件很多,但实际上这种函数构造起来没有什么难度,比如经典的Sigmoid函数\sigma(x)=1/(1+e^{-x}),还有\text{clip}(x,0,1)、\min(1, e^x)等。接着我们考虑
\begin{equation}f(\boldsymbol{x}) = [f(x_1),f(x_2),\cdots,f(x_n)]\end{equation}
f(\boldsymbol{x})跟我们想要的\mathcal{ST}_k(\boldsymbol{x})差多远呢?每个分量都在[0,1]内肯定是满足了,但是分量之和等于k无法保证,所以我们引入一个跟\boldsymbol{x}相关的待定常数\lambda(x)来保证这一点:
\begin{equation}\mathcal{ST}_k(\boldsymbol{x}) \triangleq f(\boldsymbol{x} - \lambda(\boldsymbol{x})),\quad \sum_{i=1}^n f(x_i - \lambda(\boldsymbol{x})) = k\end{equation}
也就是通过分量之和为k来反解出\lambda(\boldsymbol{x}),我们可以称之为“ThreTopK(Threshold-adjusted Soft Top-k operator)”,如果读者已经阅读过《通向概率分布之路:盘点Softmax及其替代品》,就会发现这个做法跟Sparsemax、Entmax-\alpha是一样的。
ThreTopK会是我们理想的\mathcal{ST}_k(\boldsymbol{x})吗?还真是!首先我们假设了f的单调性,所以单调性是满足的,其次f(\boldsymbol{x} - \lambda(\boldsymbol{x}))=f(\boldsymbol{x}+c - (c+\lambda(\boldsymbol{x}))),也就是常数可以收纳到\lambda(\boldsymbol{x})里边,所以不变性也是满足的。最后,当\tau\to 0^+时,我们可以找到一个适当的阈值\lambda(\boldsymbol{x}/\tau),使得\boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau)最大的k个分量趋于\infty,剩下的分量趋于-\infty,从而f(\boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau))就等于\mathcal{T}_k(\boldsymbol{x}),也就是说满足趋近性。
解析求解 #
既然已经证明了ThreTopK的理论优越性,那么接下来要解决的就是\lambda(\boldsymbol{x})的计算了,这个大部份情况下都只能诉诸数值计算方法,不过对于f(x)=\min(1, e^x),我们可以求得一个解析解。
求解的思路跟之前的Sparsemax是一样的。不失一般性,假设\boldsymbol{x}的分量已经从大到小排好顺序,即x_1 > x_2 > \cdots > x_n,接着假设我们已经知道x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1},那么此时
\begin{equation}k = \sum_{i=1}^n \min(1, e^{x_i - \lambda(\boldsymbol{x})}) = m + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})}\end{equation}
由此解得
\begin{equation}\lambda(\boldsymbol{x})=\log\left(\sum_{i=m+1}^n e^{x_i}\right) - \log(k-m)\end{equation}
由此我们可以看出,当k=1时,m只能取0,此时可以发现ThreTopK正好就是Softmax;当k > 1时,我们无法事先确定m的值,所以只能遍历m=0,1,\cdots,k-1,根据上式计算\lambda(\boldsymbol{x}),寻找满足x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}的\lambda(\boldsymbol{x})。下面同样是一个非常简陋的参考实现:
import numpy as np
def ThreTopK(x, k):
x_sort = np.sort(x)
x_lamb = np.logaddexp.accumulate(x_sort)[-k:] - np.log(np.arange(k) + 1)
x_sort_shift = np.pad(x_sort[-k:][1:], (0, 1), constant_values=np.inf)
lamb = x_lamb[(x_lamb <= x_sort_shift) & (x_lamb >= x_sort[-k:])]
return np.clip(np.exp(x - lamb), 0, 1)
k, x = 10, np.random.randn(100)
ThreTopK(x, k)
通用结果 #
从原理和代码都可以看出,f(x)=\min(1, e^x)时的ThreTopK几乎不会出现数值稳定性问题,并且k=1时能退化为Softmax,这些都是它的优势。然而,\min(1, e^x)实际上也算不上完全光滑(除了当k=1时\min不起作用),它在x=0处是不可导的。如果介意这一点,那么我们就需要选择处处可导的f(x),比如\sigma(x)。
下面我们以f(x)=\sigma(x)为例,此时我们无法求出\lambda(\boldsymbol{x})的解析解,不过由于\sigma(x)的单调递增性,所以函数
\begin{equation}F(\lambda)\triangleq \sum_{i=1}^n \sigma(x_i - \lambda)\end{equation}
关于\lambda是单调递减的,因此F(\lambda(\boldsymbol{x}))=k的数值求解并不困难,二分法、牛顿法均可。以二分法为例,不难看出\lambda(\boldsymbol{x})\in[x_{\min} - \sigma^{-1}(k/n), x_{\max} - \sigma^{-1}(k/n)],这里\sigma^{-1}是\sigma的逆函数,从这个初始区间出发,逐步二分到指定精度即可:
import numpy as np
def sigmoid(x):
y = np.exp(-np.abs(x))
return np.where(x >= 0, 1, y) / (1 + y)
def sigmoid_inv(x):
return np.log(x / (1 - x))
def ThreTopK(x, k, epsilon=1e-4):
low = x.min() - sigmoid_inv(k / len(x))
high = x.max() - sigmoid_inv(k / len(x))
while high - low > epsilon:
lamb = (low + high) / 2
Z = sigmoid(x - lamb).sum()
low, high = (low, lamb) if Z < k else (lamb, high)
return sigmoid(x - lamb)
k, x = 10, np.random.randn(100)
ThreTopK(x, k)
因此,\lambda(\boldsymbol{x})的数值计算并没有太大困难,真正的困难是当我们用数值方法去计算\lambda(\boldsymbol{x})时,往往会丢失\lambda(\boldsymbol{x})关于\boldsymbol{x}的梯度,从而影响端到端的训练。针对这个问题,我们可以手动把\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})算出来,然后自定义反向传播过程。具体来说,我们对
\begin{equation}\sum_{i=1}^n \sigma(x_i - \lambda(\boldsymbol{x})) = k\end{equation}
两边求某个x_j的偏导数,得到
\begin{equation}\sigma'(x_j - \lambda(\boldsymbol{x}))-\sum_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))\frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = 0\end{equation}
那么
\begin{equation}\frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = \frac{\sigma'(x_j - \lambda(\boldsymbol{x}))}{\sum\limits_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))}\end{equation}
其中\sigma'是\sigma的导函数。现在我们有了\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})的表达式,它的每一项都是可计算的(\lambda(\boldsymbol{x})也已经由数值方法求出),我们可以直接指定它作为反向传播的结果。一个比较简单且通用的实现方法是利用\text{stop_gradient}(下面简称\text{sg})技巧,即在实现模型时将\lambda(\boldsymbol{x})替换为
\begin{equation}\boldsymbol{x}\cdot\text{sg}[\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})] + \text{sg}[\lambda(\boldsymbol{x}) - \boldsymbol{x}\cdot\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})]\end{equation}
其中\cdot是向量的内积。这样一来,前向传播时等价于\text{sg}不存在,所以结果是\lambda(\boldsymbol{x}),反向传播时被\text{sg}的部份梯度都是零,所以梯度就是给定的\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x}),这样我们就自定义了\lambda(\boldsymbol{x})的梯度,跟\lambda(\boldsymbol{x})如何计算得到的无关。
二者兼之 #
现在我们看到,f(x)=\min(1,e^x)有解析解,但它并非全局光滑;f(x)=\sigma(x)倒是足够光滑了,但它求解起来比较复杂。有没有兼顾两者优点的选择呢?还真有,笔者发现下述的f(x)是全局光滑并且\lambda(\boldsymbol{x})可以解析求解的:
\begin{equation}f(x) = \left\{\begin{aligned}1 - e^{-x}/2,\quad x\geq 0 \\ e^x / 2,\quad x < 0\end{aligned}\right.\end{equation}
也可以写成f(x) = (1 - e^{-|x|})\text{sign}(x)/2+1/2。可以验证f(x)其实也是一个S型函数,虽然它是分段函数,但它本身及其导函数在x=0处都是连续的,因此足够光滑了。
求解思路跟之前一样,不失一般性设x_1 > x_2 > \cdots > x_n,假设已经知道x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1},那么
\begin{equation}\begin{aligned}
k =&\, \sum_{i=1}^m (1 - e^{-(x_i - \lambda(\boldsymbol{x}))}/2) + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})}/2 \\
=&\, m - \frac{1}{2}e^{\lambda(\boldsymbol{x})}\sum_{i=1}^m e^{-x_i} + \frac{1}{2}e^{-\lambda(\boldsymbol{x})}\sum_{i=m+1}^n e^{x_i}
\end{aligned}\end{equation}
由此解得
\begin{equation}\lambda(\boldsymbol{x})=\log\sum_{i=m+1}^n e^{x_i} - \log\left(\sqrt{(k-m)^2 + \left(\sum_{i=1}^m e^{-x_i}\right)\left(\sum_{i=m+1}^n e^{x_i}\right)}+(k-m)\right)\end{equation}
然后遍历m=0,1,\cdots,n-1,寻找满足x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}的\lambda(\boldsymbol{x})即可。读者还可以尝试证明一下,当k=1时此f(x)下的ThreTopK也正好退化为Softmax。
参考实现:
import numpy as np
def ThreTopK(x, k):
x_sort = np.sort(x)
lse1 = np.logaddexp.accumulate(x_sort)
lse2 = np.pad(np.logaddexp.accumulate(-x_sort[::-1])[::-1], (0, 1), constant_values=-np.inf)[1:]
m = np.arange(len(x) - 1, -1, -1)
x_lamb = lse1 - np.log(np.sqrt((k - m)**2 + np.exp(lse1 + lse2)) + (k - m))
x_sort_shift = np.pad(x_sort[1:], (0, 1), constant_values=np.inf)
lamb = x_lamb[(x_lamb <= x_sort_shift) & (x_lamb >= x_sort)]
return (1 - np.exp(-np.abs(x - lamb))) * np.sign(x - lamb) * 0.5 + 0.5
k, x = 10, np.random.randn(100)
ThreTopK(x, k)
文章小结 #
本文探讨了Top-k算子的光滑近似问题,它是Softmax等Top1的光滑近似的一般推广,提出了迭代构造、梯度指引、待定常数三种构造思路,并分析了它们的优缺点。
转载到请包括本文地址:https://spaces.ac.cn/archives/10373
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Sep. 19, 2024). 《Softmax后传:寻找Top-K的光滑近似 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10373
@online{kexuefm-10373,
title={Softmax后传:寻找Top-K的光滑近似},
author={苏剑林},
year={2024},
month={Sep},
url={\url{https://spaces.ac.cn/archives/10373}},
}
September 19th, 2024
对于ThreTopK中f(x)=\min(1, e^x)的情形,为什么不直接固定取\lambda使得m=k-1,这样f(x)就是k-1个1和一个n-k+1维的Softmax的直积。
怎么取m=k-1?取m=k-1就无法保证求和为k了啊
也就是取\lambda 使得 x_{k-1}\geq \lambda >x_k,则为了保证求和为k,有\lambda(x)=\log\sum_{i=k}^n e^{x_i}。此时f(x)的前k-1项为1,而第k\sim n项构成了Softmax,和为1,总的和为k。这似乎也是个简单粗暴的TopK光滑近似呀?
你这个就是Top-(k-1)个直接变为1,剩下的做Softmax,这是属于你自己构造的一个Top-k的“伪”光滑近似,不属于f(x)=\min(1,e^x)的ThreTopK解,因为你这样算出来的\lambda(\boldsymbol{x}),无法保证对于前k-1个x_i,都有x_i > \lambda(\boldsymbol{x})。
至于为什么是“伪”光滑近似,都直接将Top-(k-1)硬截断为1了,明显的不连续操作。
我明白我的出发点错在那里了:x_{k-1}\geq \lambda(x)和\lambda(x)=\log\sum_{i=k}^n e^{x_i}并不一定同时成立。但是核心问题似乎仍没被解决:为什么不直接使用前k-1项为1,k\sim n项为Softmax的f(x)作为TopK的光滑近似?它也满足“单调性、不变性和趋近性”。
因为不够光滑。如果直接将Top-k/Top-(k-1)选出来做硬截断的操作都允许,那么就没有光滑近似这个问题了。
感谢解答~
September 21st, 2024
感谢分享,确实ThreTopK是个比较巧妙的方法。我们这边之前也在研究相关问题,在分析了很多不同的Differentiable top-k 后最后引用了Stack Exchange这个方法。目前这个算子在做些改动后在剪枝领域效果十分不错,感兴趣的话可以看下我们的paper Smart(https://arxiv.org/abs/2403.19969),当然目前也发现当前算子在深度学习应用过程确实还有一些限制,目前在整理相关问题,之后也会整理成flow work 的paper 分享给大家,并且我们也发现类似的算子在混合量化等领域效果依旧出色。
学习了,期待后续工作。Stack Exchange这个思路确实精彩,果然是高手在民间。
September 24th, 2024
三种方法里,第一种缺点是有数值稳定性问题,第二种缺点是不完全光滑的,那第三种方法有缺点吗?
大概清楚了,第三种方法要自己写反向传播函数
没有解析解,要自己写反向传播,当k=1时无法退化为Softmax,这些都可以说是缺点。
September 27th, 2024
這個東西看起來可以用於 MoE。
是有这个猜测,所以最近在尝试补充一些sparse training的内容。
之前有人從最佳傳輸問題的角度提出了一個 Differentiable Top-k [1],或許可以考慮用 Wasserstein distance 來做一個?
[1] Xie, Y., Dai, H., Chen, M., Dai, B., Zhao, T., Zha, H., Wei, W., & Pfister, T. (2020). Differentiable Top-k with Optimal Transport. In Advances in Neural Information Processing Systems (pp. 20520–20531). Curran Associates, Inc.
看到了,一眼看上去不是很简洁,还没认真看它的必要性。不过从@GD|comment-25270的反馈来看,最后一种方案貌似也够用了。
September 27th, 2024
Softmax 的输入和输出都是向量,它应该看作是 one-hot 的 soft 形式。至于 max 的光滑近似,还是回到 \log\sum\exp 吧。
括号已经注明了,准确来说是\text{argmax}的光滑近似(更准确来说是\text{onehot}(\text{argmax}),但一般不对两者做区分了)
November 8th, 2024
请问有没有top k sum的光滑近似的参考呀?
那不就是\sum\limits_i \big([\mathcal{ST}_k(\boldsymbol{x})]_i \times x_i\big)?
January 19th, 2025
STk: A Scalable Module for Solving Top-k Problems
学习了一下,似乎有点复杂。
这是top-k sum的近似算法,核心思想 \sum_{j=1}^k e_{[k]}= argmin_\lambda \sum_{i=1}^n [e_i - \lambda]_+ + k\lambda 承自Ogryczak 2003年的一篇文章。
Ogryczak, Wlodzimierz, and Arie Tamir. "Minimizing the sum of the k largest functions in linear time." Information Processing Letters 85.3 (2003): 117-122.
这个恒等变换确实挺有意思的,值得学习一番。
March 27th, 2025
苏神,最后一种方法使用了x_sort = np.sort(x),如果在torch框架下,torch.sort也没有梯度,是否需要给sort一个梯度?
sort是有梯度的,argsort才没梯度。最后一个我尝试过用在MoE上了,效果尚可(不过我是jax)。