从梯度最大化看Attention的Scale操作
By 苏剑林 | 2023-10-22 | 64241位读者 |我们知道,Scaled Dot-Product Attention的Scale因子是$\frac{1}{\sqrt{d}}$,其中$d$是$\boldsymbol{q},\boldsymbol{k}$的维度。这个Scale因子的一般解释是:如果不除以$\sqrt{d}$,那么初始的Attention就会很接近one hot分布,这会造成梯度消失,导致模型训练不起来。然而,可以证明的是,当Scale等于0时同样也会有梯度消失问题,这也就是说Scale太大太小都不行。
那么多大的Scale才适合呢?$\frac{1}{\sqrt{d}}$是最佳的Scale了吗?本文试图从梯度角度来回答这个问题。
已有结果 #
在《浅谈Transformer的初始化、参数化与标准化》中,我们已经推导过标准的Scale因子$\frac{1}{\sqrt{d}}$,推导的思路很简单,假设初始阶段$\boldsymbol{q},\boldsymbol{k}\in\mathbb{R}^d$都采样自“均值为0、方差为1”的分布,那么可以算得
\begin{equation}\mathbb{V}ar[\boldsymbol{q}\cdot\boldsymbol{k}] = d\end{equation}
于是我们将$\boldsymbol{q}\cdot\boldsymbol{k}$除以$\sqrt{d}$,将Attention Score的方差变为1。也就是说,之前的推导纯粹是基于“均值为0、方差为1”就会更好的信仰来得到的结果,但没有解释让Attention Score的方差为1,也没有评估$\frac{1}{\sqrt{d}}$是否真的就解决了梯度消失问题。
当然,从已有的实验来看,$\frac{1}{\sqrt{d}}$至少一定程度上是缓解了这个问题,但这毕竟是实验结果,我们还是希望能从理论上知道“一定程度”究竟是多少。
计算梯度 #
既然涉及到了梯度,那么最好的办法就是把梯度算出来,然后定一个优化目标。设$p_i = e^{\alpha s_i}/Z$,$i \in \{1,2,...,n\}$,$Z=\sum_i e^{\alpha s_i}$是归一化因子,那么可以直接算得:
\begin{equation}\frac{\partial p_i}{\partial s_j} = \left\{\begin{aligned}
\alpha(p_i - p_i^2),&\quad i=j\\
-\alpha p_i p_j,&\quad i\neq j
\end{aligned}\right.\end{equation}
或者可以简写成$\partial p_i/\partial s_j = \alpha(p_i\delta_{i,j} - p_i p_j)$。很明显,当$\alpha\to 0$时梯度为0;当$\alpha\to\infty$时,$p_i$之中只有一个1、其余都是0(假设$s_i$中只有唯一的最大值),梯度也是0。
为了更有利于优化,我们应该选取$\alpha$使得梯度尽可能最大化。为此,我们以L1范数作为梯度大小的度量:
\begin{equation}\frac{1}{2}\left\Vert\frac{\partial p}{\partial s}\right\Vert_1=\frac{1}{2}\sum_{i,j}\left|\frac{\partial p_i}{\partial s_j}\right|=\frac{1}{2}\sum_i \alpha(p_i - p_i^2) + \frac{1}{2}\sum_{i\neq j} \alpha p_i p_j = \alpha\left(1 - \sum_i p_i^2\right)\label{eq:target}\end{equation}
从最后的结果不难猜到,之所以选择L1而不是其他的根本原因是因为L1范数的计算结果足够简单。值得指出的是,这里出现了$\sum_i p_i^2$,它本质上就是我们在《如何度量数据的稀疏程度?》介绍过的“Rényi熵”,跟信息熵类似,它也是不确定性的一种度量。
有了优化目标后,我们就可以着手进行最大化了。注意$p_i$的定义里边也包含$\alpha$,所以这是一个关于$\alpha$复杂的非线性目标,看上去求解析解是不可能的,但我们可以针对一些特殊例子求近似解。
正态分布 #
首先,我们可以接着前面的结果来做,当我们通过除以$\sqrt{d}$使得Attention Score的均值为0、方差为1后,我们就可以近似假设$s_i\sim\mathcal{N}(0,1)$,然后再求$\alpha$的最优解,如果$\alpha=1$,那么就意味着原来的$\frac{1}{\sqrt{d}}$就是最优的Scale比例了,否则$\frac{\alpha}{\sqrt{d}}$才是最佳的Scale比例。
我们用期望去估计求和
\begin{equation}\sum_i p_i^2 = \frac{\sum_i e^{2\alpha s_i}}{\left(\sum_i e^{\alpha s_i}\right)^2} = \frac{\frac{1}{n}\sum_i e^{2\alpha s_i}}{n\left(\frac{1}{n}\sum_i e^{\alpha s_i}\right)^2} \approx \frac{\mathbb{E}_s[e^{2\alpha s}]}{n\left(\mathbb{E}_s[e^{\alpha s}]\right)^2}\label{eq:approx}\end{equation}
对于服从标准正态分布的$s$,我们有
\begin{equation}\mathbb{E}_s[e^{\alpha s}] = \int \frac{1}{\sqrt{2\pi}}e^{-s^2/2}e^{\alpha s} ds = e^{\alpha^2 / 2}\label{eq:normal}\end{equation}
代入上式,然后代入式$\eqref{eq:target}$,得到
\begin{equation}\alpha\left(1 - \sum_i p_i^2\right)\approx\alpha\left(1 - \frac{e^{\alpha^2}}{n}\right)\end{equation}
最后的近似,虽然已经足够简化了,但其实也不容易求出最大值来。不过无妨,我们可以遍历一些$n$,然后数值求解出取最大值时的$\alpha^*$,这样我们就大致能看到$\alpha^*$与$n$的关系了,Mathematica的参考代码如下:
(*定义函数*)
f[a_, n_] := a*(1 - Exp[a^2]/n)
(*找到函数的最大点对应的a*)
FindArg[n_] :=
Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 0.84*Log[nRange]^0.5},
DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
PlotLegends -> {Row[{"a", Superscript["", "*"]}],
TraditionalForm[HoldForm[0.84*Sqrt[Log[n]]]]}]
经过拟合,笔者发现一定范围内最优点$\alpha^*$与$n$大致满足$\alpha\approx 0.84\sqrt{\log n}$的关系,所以也已经将对应的近似函数一并画在一起:
可以看到,在相当大的一个范围内,$\alpha^*$的最优值都在$2\sim 3$之间,所以折中一下的话,盲取$\frac{2.5}{\sqrt{d}}$作为Attention的Scale因子理论上更有利于优化。
余弦分布 #
现在我们考虑另一个不那么常见的例子:当我们对$\boldsymbol{q},\boldsymbol{k}$都做$l_2$归一化变成单位向量后,它们的内积就变成了夹角余弦,即$s_i$近似服从$d$维空间中的两个随机向量的夹角余弦分布。这个分布可能有些读者并不熟悉,但之前我们在《n维空间下两个随机向量的夹角分布》已经探讨过,它的概率密度具有形式
\begin{equation}p(s)\propto (1-s^2)^{(d-3)/2}\end{equation}
看上去并不复杂,但事实上这个形式比正态分布难处理得多,主要是$\mathbb{E}_s[e^{\alpha s}]$已经不像式$\eqref{eq:normal}$那样可以用初等函数表达出来了,不过对于Mathematica数值求解来说问题不大。跟上一节同样的思路,近似式$\eqref{eq:approx}$也同样适用,先数值求解最大值,然后再拟合,结果如下(图中$d=128$,$\alpha^*$跟$d$相关):
可以看到,$\alpha^*$与$3.5\log n$拟合得也不错(换一个$d$的话,$3.5$这个系数会变化)。可以看到,在一个相当大的范围内,$\alpha^*$都是$25\sim 35$之间,所以如果用$\cos$值作为Attention Score的话,就需要乘以一个$25\sim 35$之间的Scale,才能使得模型比较容易训下去。这同时也解释了为什么我们在用$\cos$值构建Softmax分布(比如AM-Softmax、SimCSE等)时,需要在$\cos$之后乘上一个30左右的Scale了,因为不乘是很难训得动模型的。
对于不同的$d$和$n$,读者可以自行修改下面的代码计算最优$\alpha$:
(*定义函数*)
h[a_] :=
Integrate[Exp[a*s]*(1 - s^2)^((d - 3)/2), {s, -1, 1},
Assumptions -> {d > 10}]
g[a_] = h[a]/h[0] // FullSimplify;
f[a_, n_] := a (1 - g[2*a]/g[a]^2/n) /. {d -> 128}
(*找到函数的最大点对应的a*)
FindArg[n_] :=
Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 3.5*Log[nRange]},
DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
PlotLegends -> {Row[{"a", Superscript["", "*"]}],
TraditionalForm[HoldForm[3.5*Log[n]]]}]
相关思考 #
本文的标题和结果,尤其是余弦分布中$\alpha$近似正比于$\log n$的结果,很容易让我们联想到另一篇讨论Attention Scale的文章《从熵不变性看Attention的Scale操作》。事实上,两篇文章的联系确实存在,本文的优化目标$\eqref{eq:target}$出现了“Rényi熵”,而“熵不变性”的熵指的是香侬信息熵,两者的性质很大程度上是一致的。最大化式$\eqref{eq:target}$使得它进入了一个“缓变”的区域,这意味着“Rényi熵”关于$n$的变化是很慢的,也意味着信息熵关于$n$的变化是很慢的,这就约等于熵不变性。
此外,对于双向Attention(Encoder)来说,假设训练样本长度相同,那么$n$就是一个常数,我们可以根据$n$算得相应的最优$\alpha$,然后固定在模型中即可;但是对于单向Attention(Decoder)来说,每个token的$n$实际上都不一样(位置id加1),所以理论上无法做到对所有token都最大化式$\eqref{eq:target}$,不过由于$\alpha^*$关于$n$的变化较慢,所以取一个差不多的值就行了,比如可以取$n=L_{\max} / 2$,这样对大部分token的梯度都比较友好了。
文章小结 #
本文从梯度的角度探讨了Attention Scale因子的选择问题。众所周知,关于这个Scale因子的“标准答案”是$\frac{1}{\sqrt{d}}$,但其推导过程中并没有讨论到它的最优性问题,所以笔者定义了一个Softmax梯度的优化目标,从最大化该目标的角度探讨了Scale因子的最优值。相关结果既可以用来改进Attention的Scale因子,也可以用来解释$\cos$相似度的对比学习的温度参数。
转载到请包括本文地址:https://spaces.ac.cn/archives/9812
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 22, 2023). 《从梯度最大化看Attention的Scale操作 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9812
@online{kexuefm-9812,
title={从梯度最大化看Attention的Scale操作},
author={苏剑林},
year={2023},
month={Oct},
url={\url{https://spaces.ac.cn/archives/9812}},
}
October 22nd, 2023
[...]Read More [...]
November 7th, 2023
虽然推导看懂了,但没有看懂行文描述。请问scalar =1/根号d 一定是一个小于1的分数。那么缩小怎么会是避免梯度消失的手段呢?这里是不是应该写成梯度爆炸呢?
这里说初始attention会是one-hot,有点糊涂。既然作者已经写明qk初始分布在N(0,d)怎么会是one-hot呢?
相信我,你根本没看懂推导。
1、为什么“缩小就不能是避免梯度消失的手段”呢?它犯了哪条法律了吗?还是有其他显然不成立的理由?
2、$\boldsymbol{q}\cdot\boldsymbol{k}$之后还有softmax。
November 24th, 2023
太屌了~
January 31st, 2024
激活函数各种改进,这个softmax的梯度对输入方差这么敏感,有没有啥改进的进展呢?比如在attention中换一个概率归一化方案,或者把scale因子作为参数来学习?
方差敏感不也说明单纯改变方差就可以控制稀疏程度了,这不也是好事?
但这梯度也太敏感了,每个都要重新调整,比如带前向mask的那长度都从1到n连续变化的
我试试了,softmax的雅可比矩阵最大元素的值是0.25,如果我们使用4xsoftmax(x)就可以把最大的梯度拉到1,或者我试了另一个平替sotmax到函数,a=power(0.25x,6),y=a/sum(a),用这个函数平替,softmax,雅可比矩阵的梯度的最大值比较接近1,使用0.25这个scale就可以
个人认为稀疏性是本质,参考:https://kexue.fm/archives/9889 。所以单纯把attention的梯度整得很好看可能是不够的...
这个函数的特点是最优s因子不随n的变化而变化,比较稳定,适用于attention单项mask这种长度连续变化的情况
s是scale
保证梯度的最优性是一回事,但保证稀疏程度的上限更重要。当然,你用的$x^6$大概率也是能保证稀疏性的。
但我的疑问是为啥偏偏是1?也就是说最大值是1不是小于1的常数有什么特别的好处吗?不用担心梯度消失?但这里问题可能有些复杂,attention之间的梯度不是完全串联(链式法则)的。
从熵或者https://kexue.fm/archives/9889 中给出的度量稀疏性的指标都是可以满足的
a=power(x,6),y=a/sum(a),这个函数的梯度大多数时候都是比softmax要大,所以要loss下降要快,我单纯在mlp+softmax上做的实验
之前有个错误,power(x,6)这个函数前向时对输入的缩放无感,反向梯度对输入的缩放也无感(这个结论应该是适用于所有f(ax)=f(x)的函数),跟rmsnorm一样,所以0.25这个缩放因子没意义!所以长度的变化也无法影响其梯度!
根据您https://spaces.ac.cn/archives/6992,我不知道pow(x,6)和exp这两个函数的Lipschitz常数大小是不是明显不同?
【attention之间的梯度不是完全串联(链式法则)的】这个没明白
1、一般情况下都要加个epsilon防止除零错误,加了就不是缩放不变的了~
2、“不是链式法则”指的是attention的梯度不像是rnn那样满足连乘的链式法则,只要每个小于1就不会梯度爆炸,所以我不大理解将梯度弄到接近于1的意图是什么。
January 31st, 2024
公式3中为啥把非对角元p_i x p_j都给丢掉了?
在$\sum_i p_i = 1$的约束之下,最后一个等号就是恒等式,不是“丢掉”,是直接推出来的。如果只是直接丢掉的话,右端还要多乘一个$1/2$。
这相当于是非对角元素之和 = 对角元素之和 这个是不是推错了 非对角元素是对称的,所以所有非对角元素之和=2倍下三角元素(除了对角元素)的和,
这个结果怎么=对角元素之和的呢?我推了下,没理解怎么会想等呢?
另外,我使用数值方法,画了下最优的scale因子和长度n之间的曲线关系,发现没有明显的log(n)规律!我用的不是f1范数而是看雅可比矩阵中最大的元素作为雅可比矩阵的度量
最优的缩放因子,大概都在1~5之间
$$\sum_{i\neq j} p_i p_j = \sum_{i,j} p_i p_j - \sum_i p_i^2 = \left(\sum_i p_i\right)\left(\sum_j p_j\right) - \sum_i p_i^2 = 1 - \sum_i p_i^2$$
很基础的推导...
至于数值结果,论文的代码都是可复现的,毕竟这都是理论性质,不需要炼丹,没什么随机性或者误差。但我不确定其他度量是否能得到相同的结果。
您这个基础推导中sum(p_i)sum(p_j)=sum(p_ixP_j)包含着整体=部分之和的思想,同样思想的还有一维的牛顿莱布尼茨公式二维stokes公式三维的高斯定理,以及其微分流形下的形式,他们就是场论中maxswell方程组中的描述,概率上也有一样的思想E[xy]=E[x]E[y](x,y独立时)不知道还有没有包含同样思想的结论,能一下子都串起来,太有趣了
这也没那么神秘,就是满足分配律情况下都有$\left(\sum\limits_i x_i\right)\left(\sum\limits_j y_j\right) = \sum\limits_{i,j} x_i y_j$,说白了就是括号可以省掉。
January 31st, 2024
@苏剑林|comment-23648
原来如此,厉害,原来softmax的雅可比矩阵有这个性质
February 14th, 2024
你好, 对于余弦分布, 是文章中的拟合结果是 scale = 30 / sqrt(d) or scale = 30 / d?
To be clear:
q = RMSNorm(q)
k = RMSNorm(k)
d = k.shape[-1]
scale1 = 30 / math.sqrt(d)
scale2 = 30 / d
logit = matmul(q, k^T)
# logit = scale1 * logit or logit = scale2 * logit?
由于$p(s)$跟$d$的关系比较复杂,它没法简单整合到$s$中,所以$\alpha^*$与$d$的关系没有这么简单。
不过当$d$足够大时,利用近似$(1-s^2)^{(d-3)/2}\approx \exp(-s^2(d-3)/2)=\exp(-(s\sqrt{d-3})^2/2)$,可以得出在一定范围内,$\alpha^*$是反比$\sqrt{d-3}$的。
April 26th, 2024
苏神有个问题请教一下,在计算梯度的推到中,说了这么一句话“为了更利于优化,我们应该选取\alpha使得梯度尽可能最大化”,这里不是很理解,反向传播过程当中不应该是尽量使得梯度保持稳定吗,最大化的话不会造成梯度爆炸吗
事后分析来看,这个梯度是有上界的,就算取到最大值,也没有梯度爆炸的风险。
June 12th, 2024
(4)式中n在transformer中的物理含义是什么?还是说n只是单纯地在推导中引入了,表示的是$s_i$的取值个数?
哦,$n$是$s_i$的总个数。前面有个笔误,现在更正了($i\in\{1,2,\cdots,n\}$)