语言模型输出端共享Embedding的重新探索
By 苏剑林 | 2023-07-20 | 35945位读者 |预训练刚兴起时,在语言模型的输出端重用Embedding权重是很常见的操作,比如BERT、第一版的T5、早期的GPT,都使用了这个操作,这是因为当模型主干部分不大且词表很大时,Embedding层的参数量很可观,如果输出端再新增一个独立的同样大小的权重矩阵的话,会导致显存消耗的激增。不过随着模型参数规模的增大,Embedding层的占比相对变小了,加之《Rethinking embedding coupling in pre-trained language models》等研究表明共享Embedding可能会有些负面影响,所以现在共享Embedding的做法已经越来越少了。
本文旨在分析在共享Embedding权重时可能遇到的问题,并探索如何更有效地进行初始化和参数化。尽管共享Embedding看起来已经“过时”,但这依然不失为一道有趣的研究题目。
共享权重 #
在语言模型的输出端重用Embedding权重的做法,英文称之为“Tied Embeddings”或者“Coupled Embeddings”,其思想主要是Embedding矩阵跟输出端转换到logits的投影矩阵大小是相同的(只差个转置),并且由于这个参数矩阵比较大,所以为了避免不必要的浪费,干脆共用同一个权重,如下图所示:
共享Embedding最直接的后果可能是——它会导致预训练的初始损失非常大。这是因为我们通常会使用类似DeepNorm的技术来降低训练难度,它们都是将模型的残差分支初始化得接近于零。换言之,模型在初始阶段近似于一个恒等函数,这使得初始模型相当于共享Embedding的2-gram模型。接下来我们将推导这样的2-gram模型损失大的原因,以及分析一些解决方案。
准备工作 #
在正式开始推导之前,我们需要准备一些基础结论。
首先,要明确的是,我们主要对初始阶段的结果进行分析,此时的权重都是从某个“均值为0、方差为σ2”的分布中独立同分布地采样出来的,这允许我们通过期望来估计某些求和结果。比如对于w=(w1,w2,⋯,wd),我们有
E[‖w‖2]=E[∑iw2i]=∑iE[w2i]=dσ2
因此可以取‖w‖≈√dσ。那么误差有多大呢?我们可以通过它的方差来感知。为此,我们先求它的二阶矩:
E[‖w‖4]=E[(∑iw2i)2]=E[∑iw4i+∑i,j|i≠jw2iw2j]=∑iE[w4i]+∑i,j|i≠jE[w2i]E[w2j]=dE[w4]+d(d−1)σ4
如果采样分布是正态分布,那么可以直接算出E[w4]=3σ4,所以
Var[‖w‖2]=E[‖w‖4]−E[‖w‖2]2=2dσ4
这个方差大小也代表着‖w‖≈√dσ的近似程度,也就是说原本的采样方差σ2越小,那么近似程度越高。特别地,常见的采样方差是1/d(对应‖w‖≈1,即单位向量),那么代入上式得到2/d,意味着维度越高近似程度越高。此外,如果采样分布不是正态分布,可以另外重新计算E[w4],或者直接将正态分布的结果作为参考结果,反正都只是一个估算罢了。
如果v=(v1,v2,⋯,vd)是另一个独立同分布向量,那么我们可以用同样的方法估计内积,结果是
E[w⋅v]=E[∑iwivi]=∑iE[wi]E[vi]=0
以及
E[(w⋅v)2]=E[(∑iwivi)2]=E[∑iw2iv2i+∑i,j|i≠jwiviwjvj]=∑iE[w2i]E[w2j]+∑i,j|i≠jE[wi]E[vi]E[wj]E[vj]=dσ4
同样地,取σ2=1/d的话,那么方差是1/d3,维度越高近似程度越高。以上两个结果可以说是《n维空间下两个随机向量的夹角分布》、《让人惊叹的Johnson-Lindenstrauss引理:理论篇》中的结论的统计版本。
损失分析 #
对语言模型来说,最终要输出一个逐token的n元分布,这里n是词表大小。假设我们直接输出均匀分布,也就是每个token的概率都是1/n,那么不难计算交叉熵损失将会是logn。这也就意味着,合理的初始化不应该使得初始损失明显超过logn,因为logn代表了最朴素的均匀分布,明显超过logn等价于说远远不如均匀分布,就好比是故意犯错,并不合理。
那么,为什么共享Embedding会出现这种情况呢?假设初始Embedding是{w1,w2,⋯,wn},前面已经说了,初始阶段残差分支接近于零,所以输入输入token i,模型输出就是经过Normalization之后的Embedding wi。常见的Normalization就是Layer Norm或者RMS Norm,由于初始化分布是零均值的,所以Layer Norm跟RMS Norm大致等价,因此输出是
wi‖wi‖/√d=wiσ
接下来重用Embedding,内积然后Softmax,所建立的分布实质是
p(j|i)=ewi⋅wj/σ∑kewi⋅wk/σ
对应的损失函数就是
−logp(j|i)=log∑kewi⋅wk/σ−wi⋅wj/σ
语言模型任务是为了预测下一个token,而我们知道自然句子中叠词的比例很小,所以基本上可以认为j≠i,那么根据结果(4)就有wi⋅wj≈0。所以,初始损失函数是
−logp(j|i)≈log∑kewi⋅wk/σ=log(ewi⋅wi/σ+∑k|k≠iewi⋅wk/σ)≈log(edσ+(n−1))
后面的≈再次用到了式(1)和式(4)。常见的初始化方差σ2,或者是一个常数,或者是1/d(此时edσ=e√d),不管是哪一种,当d较大时,都导致edσ占主导,于是损失将会是logedσ=dσ级别,这很容易就超过了均匀分布的logn。
一些对策 #
根据上述推导结果,我们就可以针对性地设计一些对策了。比较直接的方案是调整初始化,根据式(9),我们只需要让edσ=n,那么初始损失就是变成logn级别的,也就是说初始化的标准差要改为σ=(logn)/d。
一般来说,我们会希望参数的初始化方差尽量大一些,这样梯度相对来说没那么容易下溢,而σ=(logn)/d有时候会显得过小了。为此,我们可以换一种思路:很明显,式(9)之所以会偏大,是因为出现了ewi⋅wi/σ,由于两个wi相同,它们内积变成了模长,从而变得很大,如果能让它们不同,那么就不会出现这一个占主导的项了。
为此,最简单的方法自然是干脆不共享Embedding,此时是ewi⋅vi/σ而不是ewi⋅wi/σ,用(4)而不是(1)作为近似,于是式(9)渐近于logn。如果还想保留共享Embedding,我们可以在最后的Normalization之后,再接一个正交初始化的投影层,这样ewi⋅wi/σ变成了e(wiP)⋅wi/σ,根据Johnson-Lindenstrauss引理,经过随机投影的向量近似于独立向量了,所以也近似于不共享的情况,这其实就是BERT的解决办法。特别地,这个投影层还可以一般化地加上bias和激活函数。
如果一丁点额外参数都不想引入,那么可以考虑在Normalization之后“打乱”wi的各个维度,比如
S[w]=w[d/2:]∘w[:d/2]
这里的∘是拼接操作,那么S[wi]和wi也接近正交了,内积自然也约等于0。这相当于(在初始阶段)将原来的n×d的Embedding矩阵劈开为两个n×(d/2)的矩阵然后构建不共享Embedding的2-gram模型。另外,我们还可以考虑其他打乱操作,比如ShuffleNet中的先reshape,然后transpose再reshape回来。
在笔者的实验中,直接改初始化标准差为σ=(logn)/d收敛速度是最慢的,其余方法收敛速度差不多,至于最终效果,所有方法似乎都差不多。
文章小结 #
本文重温了语言模型输出端共享Embedding权重的操作,推导了直接重用Embedding来投影输出可能会导致损失过大的可能性,并探讨了一些解决办法。
转载到请包括本文地址:https://spaces.ac.cn/archives/9698
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jul. 20, 2023). 《语言模型输出端共享Embedding的重新探索 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9698
@online{kexuefm-9698,
title={语言模型输出端共享Embedding的重新探索},
author={苏剑林},
year={2023},
month={Jul},
url={\url{https://spaces.ac.cn/archives/9698}},
}
July 21st, 2023
[...]Read More [...]
请问苏神这里说的”至于最终效果,所有方法似乎都差不多。“
和原始共享embedding有提升么,有多大提升呢?
llama就没有共享,不共享效果应该会更好吧,毕竟增加了参数
我个人的实验没提升。
July 21st, 2023
输出端重用Embedding权重的话,形式上是原Embedding的转置; 但是如果将embedding理解为一种翻译器,我更想把它处理为原Embedding的伪逆,相当于一种回译。
初始阶段Embedding是零均值独立同分布初始化的,这样得到的矩阵接近正交矩阵,所以它的转置实质上就接近它的伪逆(的若干倍)。所以恰恰相反,共享Embedding出现loss很大的原因,正好是因为输出初始化成了伪逆,原因文章也分析了,我们的任务是语言模型,叠词很少,所以伪逆(或者说回译)初始化反而导致了初始化全错,比随机蒙还不如。
叠词,精辟!
July 31st, 2023
按照文中的分析,bert在transformer结构后面接了一个hidden_size x hidden_size的linear层,理论上可以规避这一问题?
是的,本文已经说了,bert的做法已经规避了这个问题。
January 24th, 2024
苏神,这篇文章的分类类别是不是少了信息时代,在信息时代往下看找不到这篇。
归到了数学研究类别了~我新增一下吧
February 28th, 2025
[...]ALBERT: A Lite BERT for Self-supervised Learning of Language Representations[C]// International Conference on Learning Representations Distributed Representations of Words and Phrases and their Compos[...]