我们知道,Scaled Dot-Product Attention的Scale因子是1d,其中dq,k的维度。这个Scale因子的一般解释是:如果不除以d,那么初始的Attention就会很接近one hot分布,这会造成梯度消失,导致模型训练不起来。然而,可以证明的是,当Scale等于0时同样也会有梯度消失问题,这也就是说Scale太大太小都不行。

那么多大的Scale才适合呢?1d是最佳的Scale了吗?本文试图从梯度角度来回答这个问题。

已有结果 #

《浅谈Transformer的初始化、参数化与标准化》中,我们已经推导过标准的Scale因子1d,推导的思路很简单,假设初始阶段q,kRd都采样自“均值为0、方差为1”的分布,那么可以算得
Var[qk]=d
于是我们将qk除以d,将Attention Score的方差变为1。也就是说,之前的推导纯粹是基于“均值为0、方差为1”就会更好信仰来得到的结果,但没有解释让Attention Score的方差为1,也没有评估1d是否真的就解决了梯度消失问题。

当然,从已有的实验来看,1d至少一定程度上是缓解了这个问题,但这毕竟是实验结果,我们还是希望能从理论上知道“一定程度”究竟是多少。

计算梯度 #

既然涉及到了梯度,那么最好的办法就是把梯度算出来,然后定一个优化目标。设pi=eαsi/Zi{1,2,...,n}Z=ieαsi是归一化因子,那么可以直接算得:
pisj={α(pip2i),i=jαpipj,ij
或者可以简写成pi/sj=α(piδi,jpipj)。很明显,当α0时梯度为0;当α时,pi之中只有一个1、其余都是0(假设si中只有唯一的最大值),梯度也是0。

为了更有利于优化,我们应该选取α使得梯度尽可能最大化。为此,我们以L1范数作为梯度大小的度量:
12ps1=12i,j|pisj|=12iα(pip2i)+12ijαpipj=α(1ip2i)
从最后的结果不难猜到,之所以选择L1而不是其他的根本原因是因为L1范数的计算结果足够简单。值得指出的是,这里出现了ip2i,它本质上就是我们在《如何度量数据的稀疏程度?》介绍过的“Rényi熵”,跟信息熵类似,它也是不确定性的一种度量。

有了优化目标后,我们就可以着手进行最大化了。注意pi的定义里边也包含α,所以这是一个关于α复杂的非线性目标,看上去求解析解是不可能的,但我们可以针对一些特殊例子求近似解。

正态分布 #

首先,我们可以接着前面的结果来做,当我们通过除以d使得Attention Score的均值为0、方差为1后,我们就可以近似假设siN(0,1),然后再求α的最优解,如果α=1,那么就意味着原来的1d就是最优的Scale比例了,否则αd才是最佳的Scale比例。

我们用期望去估计求和
ip2i=ie2αsi(ieαsi)2=1nie2αsin(1nieαsi)2Es[e2αs]n(Es[eαs])2
对于服从标准正态分布的s,我们有
Es[eαs]=12πes2/2eαsds=eα2/2
代入上式,然后代入式(3),得到
α(1ip2i)α(1eα2n)
最后的近似,虽然已经足够简化了,但其实也不容易求出最大值来。不过无妨,我们可以遍历一些n,然后数值求解出取最大值时的α,这样我们就大致能看到αn的关系了,Mathematica的参考代码如下:

(*定义函数*)
f[a_, n_] := a*(1 - Exp[a^2]/n)
(*找到函数的最大点对应的a*)
FindArg[n_] := 
 Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 0.84*Log[nRange]^0.5}, 
 DataRange -> {40, 20000}, AxesLabel -> {"n", "a"}, 
 PlotLegends -> {Row[{"a", Superscript["", "*"]}], 
   TraditionalForm[HoldForm[0.84*Sqrt[Log[n]]]]}]

经过拟合,笔者发现一定范围内最优点αn大致满足α0.84logn的关系,所以也已经将对应的近似函数一并画在一起:

标准正态分布的最优alpha与n关系

标准正态分布的最优alpha与n关系

可以看到,在相当大的一个范围内,α的最优值都在23之间,所以折中一下的话,盲取2.5d作为Attention的Scale因子理论上更有利于优化。

余弦分布 #

现在我们考虑另一个不那么常见的例子:当我们对q,k都做l2归一化变成单位向量后,它们的内积就变成了夹角余弦,即si近似服从d维空间中的两个随机向量的夹角余弦分布。这个分布可能有些读者并不熟悉,但之前我们在《n维空间下两个随机向量的夹角分布》已经探讨过,它的概率密度具有形式
p(s)(1s2)(d3)/2

看上去并不复杂,但事实上这个形式比正态分布难处理得多,主要是Es[eαs]已经不像式(5)那样可以用初等函数表达出来了,不过对于Mathematica数值求解来说问题不大。跟上一节同样的思路,近似式(4)也同样适用,先数值求解最大值,然后再拟合,结果如下(图中d=128αd相关):

余弦分布的最优alpha与n关系

余弦分布的最优alpha与n关系

可以看到,α3.5logn拟合得也不错(换一个d的话,3.5这个系数会变化)。可以看到,在一个相当大的范围内,α都是2535之间,所以如果用cos值作为Attention Score的话,就需要乘以一个2535之间的Scale,才能使得模型比较容易训下去。这同时也解释了为什么我们在用cos值构建Softmax分布(比如AM-SoftmaxSimCSE等)时,需要在cos之后乘上一个30左右的Scale了,因为不乘是很难训得动模型的。

对于不同的dn,读者可以自行修改下面的代码计算最优α

(*定义函数*)
h[a_] := 
 Integrate[Exp[a*s]*(1 - s^2)^((d - 3)/2), {s, -1, 1}, 
  Assumptions -> {d > 10}]
g[a_] = h[a]/h[0] // FullSimplify;
f[a_, n_] := a (1 - g[2*a]/g[a]^2/n) /. {d -> 128}
(*找到函数的最大点对应的a*)
FindArg[n_] := 
 Module[{a}, a = a /. Last@NMaximize[{f[a, n], a > 0}, a][[2]]; a]
(*给定n的范围*)
nRange = 40*Range[1, 500];
(*求出每个n对应的a*)
args = FindArg /@ nRange;
(*画出a与n的函数图像*)
ListLinePlot[{args, 3.5*Log[nRange]}, 
 DataRange -> {40, 20000}, AxesLabel -> {"n", "a"}, 
 PlotLegends -> {Row[{"a", Superscript["", "*"]}], 
   TraditionalForm[HoldForm[3.5*Log[n]]]}]

相关思考 #

本文的标题和结果,尤其是余弦分布中α近似正比于logn的结果,很容易让我们联想到另一篇讨论Attention Scale的文章《从熵不变性看Attention的Scale操作》。事实上,两篇文章的联系确实存在,本文的优化目标(3)出现了“Rényi熵”,而“熵不变性”的熵指的是香侬信息熵,两者的性质很大程度上是一致的。最大化式(3)使得它进入了一个“缓变”的区域,这意味着“Rényi熵”关于n的变化是很慢的,也意味着信息熵关于n的变化是很慢的,这就约等于熵不变性。

此外,对于双向Attention(Encoder)来说,假设训练样本长度相同,那么n就是一个常数,我们可以根据n算得相应的最优α,然后固定在模型中即可;但是对于单向Attention(Decoder)来说,每个token的n实际上都不一样(位置id加1),所以理论上无法做到对所有token都最大化式(3),不过由于α关于n的变化较慢,所以取一个差不多的值就行了,比如可以取n=Lmax,这样对大部分token的梯度都比较友好了。

文章小结 #

本文从梯度的角度探讨了Attention Scale因子的选择问题。众所周知,关于这个Scale因子的“标准答案”是\frac{1}{\sqrt{d}},但其推导过程中并没有讨论到它的最优性问题,所以笔者定义了一个Softmax梯度的优化目标,从最大化该目标的角度探讨了Scale因子的最优值。相关结果既可以用来改进Attention的Scale因子,也可以用来解释\cos相似度的对比学习的温度参数。

转载到请包括本文地址:https://spaces.ac.cn/archives/9812

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Oct. 22, 2023). 《从梯度最大化看Attention的Scale操作 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9812

@online{kexuefm-9812,
        title={从梯度最大化看Attention的Scale操作},
        author={苏剑林},
        year={2023},
        month={Oct},
        url={\url{https://spaces.ac.cn/archives/9812}},
}