从梯度最大化看Attention的Scale操作
By 苏剑林 | 2023-10-22 | 82166位读者 |我们知道,Scaled Dot-Product Attention的Scale因子是1√d,其中d是q,k的维度。这个Scale因子的一般解释是:如果不除以√d,那么初始的Attention就会很接近one hot分布,这会造成梯度消失,导致模型训练不起来。然而,可以证明的是,当Scale等于0时同样也会有梯度消失问题,这也就是说Scale太大太小都不行。
那么多大的Scale才适合呢?1√d是最佳的Scale了吗?本文试图从梯度角度来回答这个问题。
已有结果 #
在《浅谈Transformer的初始化、参数化与标准化》中,我们已经推导过标准的Scale因子1√d,推导的思路很简单,假设初始阶段q,k∈Rd都采样自“均值为0、方差为1”的分布,那么可以算得
Var[q⋅k]=d
于是我们将q⋅k除以√d,将Attention Score的方差变为1。也就是说,之前的推导纯粹是基于“均值为0、方差为1”就会更好的信仰来得到的结果,但没有解释让Attention Score的方差为1,也没有评估1√d是否真的就解决了梯度消失问题。
当然,从已有的实验来看,1√d至少一定程度上是缓解了这个问题,但这毕竟是实验结果,我们还是希望能从理论上知道“一定程度”究竟是多少。
计算梯度 #
既然涉及到了梯度,那么最好的办法就是把梯度算出来,然后定一个优化目标。设pi=eαsi/Z,i∈{1,2,...,n},Z=∑ieαsi是归一化因子,那么可以直接算得:
∂pi∂sj={α(pi−p2i),i=j−αpipj,i≠j
或者可以简写成∂pi/∂sj=α(piδi,j−pipj)。很明显,当α→0时梯度为0;当α→∞时,pi之中只有一个1、其余都是0(假设si中只有唯一的最大值),梯度也是0。
为了更有利于优化,我们应该选取α使得梯度尽可能最大化。为此,我们以L1范数作为梯度大小的度量:
12‖∂p∂s‖1=12∑i,j|∂pi∂sj|=12∑iα(pi−p2i)+12∑i≠jαpipj=α(1−∑ip2i)
从最后的结果不难猜到,之所以选择L1而不是其他的根本原因是因为L1范数的计算结果足够简单。值得指出的是,这里出现了∑ip2i,它本质上就是我们在《如何度量数据的稀疏程度?》介绍过的“Rényi熵”,跟信息熵类似,它也是不确定性的一种度量。
有了优化目标后,我们就可以着手进行最大化了。注意pi的定义里边也包含α,所以这是一个关于α复杂的非线性目标,看上去求解析解是不可能的,但我们可以针对一些特殊例子求近似解。
正态分布 #
首先,我们可以接着前面的结果来做,当我们通过除以√d使得Attention Score的均值为0、方差为1后,我们就可以近似假设si∼N(0,1),然后再求α的最优解,如果α=1,那么就意味着原来的1√d就是最优的Scale比例了,否则α√d才是最佳的Scale比例。
我们用期望去估计求和
∑ip2i=∑ie2αsi(∑ieαsi)2=1n∑ie2αsin(1n∑ieαsi)2≈Es[e2αs]n(Es[eαs])2
对于服从标准正态分布的s,我们有
Es[eαs]=∫1√2πe−s2/2eαsds=eα2/2
代入上式,然后代入式(3),得到
α(1−∑ip2i)≈α(1−eα2n)
最后的近似,虽然已经足够简化了,但其实也不容易求出最大值来。不过无妨,我们可以遍历一些n,然后数值求解出取最大值时的α∗,这样我们就大致能看到α∗与n的关系了,Mathematica的参考代码如下:
(*定义函数*)
f[a_, n_] := a*(1 - Exp[a^2]/n)
(*找到函数的最大点对应的a*)
FindArg[n_] :=
Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 0.84*Log[nRange]^0.5},
DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
PlotLegends -> {Row[{"a", Superscript["", "*"]}],
TraditionalForm[HoldForm[0.84*Sqrt[Log[n]]]]}]
经过拟合,笔者发现一定范围内最优点α∗与n大致满足α≈0.84√logn的关系,所以也已经将对应的近似函数一并画在一起:
可以看到,在相当大的一个范围内,α∗的最优值都在2∼3之间,所以折中一下的话,盲取2.5√d作为Attention的Scale因子理论上更有利于优化。
余弦分布 #
现在我们考虑另一个不那么常见的例子:当我们对q,k都做l2归一化变成单位向量后,它们的内积就变成了夹角余弦,即si近似服从d维空间中的两个随机向量的夹角余弦分布。这个分布可能有些读者并不熟悉,但之前我们在《n维空间下两个随机向量的夹角分布》已经探讨过,它的概率密度具有形式
p(s)∝(1−s2)(d−3)/2
看上去并不复杂,但事实上这个形式比正态分布难处理得多,主要是Es[eαs]已经不像式(5)那样可以用初等函数表达出来了,不过对于Mathematica数值求解来说问题不大。跟上一节同样的思路,近似式(4)也同样适用,先数值求解最大值,然后再拟合,结果如下(图中d=128,α∗跟d相关):
可以看到,α∗与3.5logn拟合得也不错(换一个d的话,3.5这个系数会变化)。可以看到,在一个相当大的范围内,α∗都是25∼35之间,所以如果用cos值作为Attention Score的话,就需要乘以一个25∼35之间的Scale,才能使得模型比较容易训下去。这同时也解释了为什么我们在用cos值构建Softmax分布(比如AM-Softmax、SimCSE等)时,需要在cos之后乘上一个30左右的Scale了,因为不乘是很难训得动模型的。
对于不同的d和n,读者可以自行修改下面的代码计算最优α:
(*定义函数*)
h[a_] :=
Integrate[Exp[a*s]*(1 - s^2)^((d - 3)/2), {s, -1, 1},
Assumptions -> {d > 10}]
g[a_] = h[a]/h[0] // FullSimplify;
f[a_, n_] := a (1 - g[2*a]/g[a]^2/n) /. {d -> 128}
(*找到函数的最大点对应的a*)
FindArg[n_] :=
Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 3.5*Log[nRange]},
DataRange -> {40, 20000}, AxesLabel -> {"n", "a"},
PlotLegends -> {Row[{"a", Superscript["", "*"]}],
TraditionalForm[HoldForm[3.5*Log[n]]]}]
相关思考 #
本文的标题和结果,尤其是余弦分布中α近似正比于logn的结果,很容易让我们联想到另一篇讨论Attention Scale的文章《从熵不变性看Attention的Scale操作》。事实上,两篇文章的联系确实存在,本文的优化目标(3)出现了“Rényi熵”,而“熵不变性”的熵指的是香侬信息熵,两者的性质很大程度上是一致的。最大化式(3)使得它进入了一个“缓变”的区域,这意味着“Rényi熵”关于n的变化是很慢的,也意味着信息熵关于n的变化是很慢的,这就约等于熵不变性。
此外,对于双向Attention(Encoder)来说,假设训练样本长度相同,那么n就是一个常数,我们可以根据n算得相应的最优α,然后固定在模型中即可;但是对于单向Attention(Decoder)来说,每个token的n实际上都不一样(位置id加1),所以理论上无法做到对所有token都最大化式(3),不过由于α∗关于n的变化较慢,所以取一个差不多的值就行了,比如可以取n=Lmax,这样对大部分token的梯度都比较友好了。
文章小结 #
本文从梯度的角度探讨了Attention Scale因子的选择问题。众所周知,关于这个Scale因子的“标准答案”是\frac{1}{\sqrt{d}},但其推导过程中并没有讨论到它的最优性问题,所以笔者定义了一个Softmax梯度的优化目标,从最大化该目标的角度探讨了Scale因子的最优值。相关结果既可以用来改进Attention的Scale因子,也可以用来解释\cos相似度的对比学习的温度参数。
转载到请包括本文地址:https://spaces.ac.cn/archives/9812
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 22, 2023). 《从梯度最大化看Attention的Scale操作 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9812
@online{kexuefm-9812,
title={从梯度最大化看Attention的Scale操作},
author={苏剑林},
year={2023},
month={Oct},
url={\url{https://spaces.ac.cn/archives/9812}},
}
January 10th, 2025
公式3里是不是应该减号
哦,没错