预训练刚兴起时,在语言模型的输出端重用Embedding权重是很常见的操作,比如BERT、第一版的T5、早期的GPT,都使用了这个操作,这是因为当模型主干部分不大且词表很大时,Embedding层的参数量很可观,如果输出端再新增一个独立的同样大小的权重矩阵的话,会导致显存消耗的激增。不过随着模型参数规模的增大,Embedding层的占比相对变小了,加之《Rethinking embedding coupling in pre-trained language models》等研究表明共享Embedding可能会有些负面影响,所以现在共享Embedding的做法已经越来越少了。

本文旨在分析在共享Embedding权重时可能遇到的问题,并探索如何更有效地进行初始化和参数化。尽管共享Embedding看起来已经“过时”,但这依然不失为一道有趣的研究题目。

共享权重 #

在语言模型的输出端重用Embedding权重的做法,英文称之为“Tied Embeddings”或者“Coupled Embeddings”,其思想主要是Embedding矩阵跟输出端转换到logits的投影矩阵大小是相同的(只差个转置),并且由于这个参数矩阵比较大,所以为了避免不必要的浪费,干脆共用同一个权重,如下图所示:

共享Embedding权重的Transformer示意图

共享Embedding权重的Transformer示意图

共享Embedding最直接的后果可能是——它会导致预训练的初始损失非常大。这是因为我们通常会使用类似DeepNorm的技术来降低训练难度,它们都是将模型的残差分支初始化得接近于零。换言之,模型在初始阶段近似于一个恒等函数,这使得初始模型相当于共享Embedding的2-gram模型。接下来我们将推导这样的2-gram模型损失大的原因,以及分析一些解决方案。

准备工作 #

在正式开始推导之前,我们需要准备一些基础结论。

首先,要明确的是,我们主要对初始阶段的结果进行分析,此时的权重都是从某个“均值为0、方差为σ2”的分布中独立同分布地采样出来的,这允许我们通过期望来估计某些求和结果。比如对于w=(w1,w2,,wd),我们有
E[w2]=E[iw2i]=iE[w2i]=dσ2


因此可以取wdσ。那么误差有多大呢?我们可以通过它的方差来感知。为此,我们先求它的二阶矩:
E[w4]=E[(iw2i)2]=E[iw4i+i,j|ijw2iw2j]=iE[w4i]+i,j|ijE[w2i]E[w2j]=dE[w4]+d(d1)σ4

如果采样分布是正态分布,那么可以直接算出E[w4]=3σ4,所以
Var[w2]=E[w4]E[w2]2=2dσ4

这个方差大小也代表着wdσ的近似程度,也就是说原本的采样方差σ2越小,那么近似程度越高。特别地,常见的采样方差是1/d(对应w1,即单位向量),那么代入上式得到2/d,意味着维度越高近似程度越高。此外,如果采样分布不是正态分布,可以另外重新计算E[w4],或者直接将正态分布的结果作为参考结果,反正都只是一个估算罢了。

如果v=(v1,v2,,vd)是另一个独立同分布向量,那么我们可以用同样的方法估计内积,结果是
E[wv]=E[iwivi]=iE[wi]E[vi]=0


以及
E[(wv)2]=E[(iwivi)2]=E[iw2iv2i+i,j|ijwiviwjvj]=iE[w2i]E[w2j]+i,j|ijE[wi]E[vi]E[wj]E[vj]=dσ4

同样地,取σ2=1/d的话,那么方差是1/d3,维度越高近似程度越高。以上两个结果可以说是《n维空间下两个随机向量的夹角分布》《让人惊叹的Johnson-Lindenstrauss引理:理论篇》中的结论的统计版本。

损失分析 #

对语言模型来说,最终要输出一个逐token的n元分布,这里n是词表大小。假设我们直接输出均匀分布,也就是每个token的概率都是1/n,那么不难计算交叉熵损失将会是logn。这也就意味着,合理的初始化不应该使得初始损失明显超过logn,因为logn代表了最朴素的均匀分布,明显超过logn等价于说远远不如均匀分布,就好比是故意犯错,并不合理。

那么,为什么共享Embedding会出现这种情况呢?假设初始Embedding是{w1,w2,,wn},前面已经说了,初始阶段残差分支接近于零,所以输入输入token i,模型输出就是经过Normalization之后的Embedding wi。常见的Normalization就是Layer Norm或者RMS Norm,由于初始化分布是零均值的,所以Layer Norm跟RMS Norm大致等价,因此输出是
wiwi/d=wiσ


接下来重用Embedding,内积然后Softmax,所建立的分布实质是
p(j|i)=ewiwj/σkewiwk/σ

对应的损失函数就是
logp(j|i)=logkewiwk/σwiwj/σ

语言模型任务是为了预测下一个token,而我们知道自然句子中叠词的比例很小,所以基本上可以认为ji,那么根据结果(4)就有wiwj0。所以,初始损失函数是
logp(j|i)logkewiwk/σ=log(ewiwi/σ+k|kiewiwk/σ)log(edσ+(n1))

后面的再次用到了式(1)和式(4)。常见的初始化方差σ2,或者是一个常数,或者是1/d(此时edσ=ed),不管是哪一种,当d较大时,都导致edσ占主导,于是损失将会是logedσ=dσ级别,这很容易就超过了均匀分布的logn

一些对策 #

根据上述推导结果,我们就可以针对性地设计一些对策了。比较直接的方案是调整初始化,根据式(9),我们只需要让edσ=n,那么初始损失就是变成logn级别的,也就是说初始化的标准差要改为σ=(logn)/d

一般来说,我们会希望参数的初始化方差尽量大一些,这样梯度相对来说没那么容易下溢,而σ=(logn)/d有时候会显得过小了。为此,我们可以换一种思路:很明显,式(9)之所以会偏大,是因为出现了ewiwi/σ,由于两个wi相同,它们内积变成了模长,从而变得很大,如果能让它们不同,那么就不会出现这一个占主导的项了。

为此,最简单的方法自然是干脆不共享Embedding,此时是ewivi/σ而不是ewiwi/σ,用(4)而不是(1)作为近似,于是式(9)渐近于logn。如果还想保留共享Embedding,我们可以在最后的Normalization之后,再接一个正交初始化的投影层,这样ewiwi/σ变成了e(wiP)wi/σ,根据Johnson-Lindenstrauss引理,经过随机投影的向量近似于独立向量了,所以也近似于不共享的情况,这其实就是BERT的解决办法。特别地,这个投影层还可以一般化地加上bias和激活函数。

如果一丁点额外参数都不想引入,那么可以考虑在Normalization之后“打乱”wi的各个维度,比如
S[w]=w[d/2:]w[:d/2]


这里的是拼接操作,那么S[wi]wi也接近正交了,内积自然也约等于0。这相当于(在初始阶段)将原来的n×d的Embedding矩阵劈开为两个n×(d/2)的矩阵然后构建不共享Embedding的2-gram模型。另外,我们还可以考虑其他打乱操作,比如ShuffleNet中的先reshape,然后transpose再reshape回来。

在笔者的实验中,直接改初始化标准差为σ=(logn)/d收敛速度是最慢的,其余方法收敛速度差不多,至于最终效果,所有方法似乎都差不多。

文章小结 #

本文重温了语言模型输出端共享Embedding权重的操作,推导了直接重用Embedding来投影输出可能会导致损失过大的可能性,并探讨了一些解决办法。

转载到请包括本文地址:https://spaces.ac.cn/archives/9698

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Jul. 20, 2023). 《语言模型输出端共享Embedding的重新探索 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9698

@online{kexuefm-9698,
        title={语言模型输出端共享Embedding的重新探索},
        author={苏剑林},
        year={2023},
        month={Jul},
        url={\url{https://spaces.ac.cn/archives/9698}},
}