随着ChatGPT及其平替的火热,各种参数高效(Parameter-Efficient)的微调方法也“水涨船高”,其中最流行的方案之一就是本文的主角LoRA了,它出自论文《LoRA: Low-Rank Adaptation of Large Language Models》。LoRA方法上比较简单直接,而且也有不少现成实现,不管是理解还是使用都很容易上手,所以本身也没太多值得细写的地方了。

然而,直接实现LoRA需要修改网络结构,这略微麻烦了些,同时LoRA给笔者的感觉是很像之前的优化器AdaFactor,所以笔者的问题是:能否从优化器角度来分析和实现LoRA呢?本文就围绕此主题展开讨论。

方法简介 #

以往的一些结果(比如《Exploring Aniversal Intrinsic Task Subspace via Prompt Tuning》)显示,尽管预训练模型的参数量很大,但每个下游任务对应的本征维度(Intrinsic Dimension)并不大,换句话说,理论上我们可以微调非常小的参数量,就能在下游任务取得不错的效果。

LoRA借鉴了上述结果,提出对于预训练的参数矩阵W0Rn×m,我们不去直接微调W_0,而是对增量做低秩分解假设:
\begin{equation}W = W_0 + A B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
其中A,B之一用全零初始化,W_0固定不变,优化器只优化A,B。由于本征维度很小的结论,所以r我们可以取得很小,常见的是r=8,极端情况下我们甚至可以取1。所以说,LoRA是一种参数高效的微调方法,至少被优化的参数量大大降低了。

用MathJax直接画了个示意图:
\style{display: inline-block; width: 24ex; padding: 10ex 0; border: 1px solid #6C8EBF; background-color: #DAE8FC}{W_0\in\mathbb{R}^{n\times m}} \quad + \quad \style{display: inline-block; width: 8ex; padding: 10ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{A\in\mathbb{R}^{n\times r}}\quad\times\quad \style{display: inline-block; width: 24ex; padding: 3ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{B\in\mathbb{R}^{r\times m}}

梯度分析 #

正如《Ladder Side-Tuning:预训练模型的“过墙梯”》所提到的,很多参数高效的微调实际上只是降低了显存需求,并没有降低计算量。那么LoRA是否例外呢?它在显存和计算量方面的效率如何呢?下面我们来分析一下。

首先,我们知道训练模型所消耗的显存来源包括模型参数模型梯度模型激活值优化器状态四部份,LoRA通过低秩分解降低了模型参数量,那么梯度和优化器状态也会随之降低,因此节省的显存是很明显的。那它能否节省计算量呢?

这取决于LoRA的实现方式,不同的实现方式计算梯度的复杂度不一样。LoRA的两种等效实现如下:
\begin{align}Y =&\, XW = X(W_0 + AB) \label{eq:lora-1}\\[5pt] Y =&\, XW_0 + XAB = XW_0 + ZB \label{eq:lora-2}\end{align}
其中X\in\mathbb{R}^{b\times n}是模型输入,Z=XA\in\mathbb{R}^{b\times r}是中间输出。针对实现\eqref{eq:lora-1},我们有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} B^{\top} = \left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right) B^{\top},\quad \frac{\partial \mathcal{L}}{\partial B} = A^{\top}\frac{\partial \mathcal{L}}{\partial W} = A^{\top}\left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right)\label{eq:grad-1}\end{equation}
\mathcal{L}是损失函数。很明显,这种实现导致的后果是需要算完整梯度\frac{\partial \mathcal{L}}{\partial W}\in\mathbb{R}^{n\times m},然后才能算A,B的梯度,这意味着它比不LoRA还慢,也费显存。对于实现\eqref{eq:lora-2},我们则有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = X^{\top}\frac{\partial \mathcal{L}}{\partial Z} = X^{\top}\left(\frac{\partial \mathcal{L}}{\partial Y} B^{\top}\right),\quad \frac{\partial \mathcal{L}}{\partial B} = Z^{\top}\frac{\partial \mathcal{L}}{\partial Y} = (XA)^{\top}\frac{\partial \mathcal{L}}{\partial Y}\label{eq:grad-2}\end{equation}
此时的Z,\frac{\partial \mathcal{L}}{\partial Z}\in\mathbb{R}^{b\times r},相比完整的梯度显然省了不少,计算复杂度也明显降低。所以,LoRA想要节省显存和计算最大化,关键是按照\eqref{eq:lora-2}而不是\eqref{eq:lora-1}来实现。

(注:关于矩阵计算梯度,我们可以根据链式法则和输出形状来“凑”,比如\frac{\partial \mathcal{L}}{\partial A},根据链式法则我们知道它必然是\frac{\partial \mathcal{L}}{\partial W}B以某种方式相乘,我们约定\frac{\partial \mathcal{L}}{\partial A}的形状跟A一致,即n\times r,想要用\frac{\partial \mathcal{L}}{\partial W}B凑出一个n\times r的结果来,那就只有\frac{\partial \mathcal{L}}{\partial W} B^{\top}了。)

其他原因 #

除了低秩分解带来的好处外,如下几点也是LoRA能节省显存和提速的原因:

1、只更新了部分参数:比如LoRA原论文就选择只更新Self Attention的参数,实际使用时我们还可以选择只更新部分层的参数;

2、减少了通信时间:由于更新的参数量变少了,所以(尤其是多卡训练时)要传输的数据量也变少了,从而减少了传输时间;

3、采用了各种低精度加速技术,如FP16、FP8或者INT8量化等。

当然,这三部分原因确实能加快训练速度,但它们并不是LoRA所独有的,事实上几乎都有参数高效方法都具有这些特点。LoRA的突出优点是它的低秩分解很直观,在不少场景下跟全量微调的效果一致,以及在预测阶段可以直接把W_0,A,B合并成单个矩阵从而不增加推理成本。

优化视角 #

梯度\eqref{eq:grad-1}还告诉了我们如何从优化器角度来实现LoRA。优化器可以直接获取到全量梯度\frac{\partial \mathcal{L}}{\partial W},然后我们只需要按照公式\eqref{eq:grad-1}对梯度进行投影,就得到A,B的梯度,接着就可以按照常规的优化器实现A,B的更新了。

假如优化器是SGD,那么就是
\begin{equation}\begin{aligned} A_{t+1} =&\, A_t - \eta\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top},\quad B_{t+1} = B_t - \eta A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\\[5pt] W_{t+1} =&\, W_0 + A_{t+1} B_{t+1} = W_t + (A_{t+1} B_{t+1} - A_t B_t) \end{aligned}\end{equation}
如果是Adam之类的带滑动变量的优化器,则只需要滑动投影后的梯度,因此是降低了优化器的参数量,节省了一定的显存。模型越大,这部分参数所占的显存比例也就越大。

LoRA约定AB之一使用全零初始化,这是为了保证初始状态模型跟预训练一致,但同时也带来了不对称问题(一个全零,一个非全零)。事实上,A,B都使用非全零初始化也是可以的,只需要事先将预训练权重减去A_0 B_0就行了,或者等价地说,将W参数化为
\begin{equation}W = W_0 - A_0 B_0 + A B\end{equation}
这样同时保持了初始状态一致,同时允许A,B都用非全零初始化,增强了对称性。

随机投影 #

如果我们将SGD场景下的更新量A_{t+1} B_{t+1} - A_t B_t展开,结果将是
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right) + \eta^2 \frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\end{equation}
假设\eta^2项是可以忽略的高阶项,那么就剩下
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
从这个角度来看,相比全量微调的SGD,LoRA就是用括号中的结果替代了全量的梯度\frac{\partial \mathcal{L}}{\partial W_t}

简单起见,接下来我们只关心r=1的情形,留意到在上式中,t时刻的投影向量A_t,B_t是依赖于t的,如果我们将它们换成不依赖于t的随机向量(每步训练都重新随机生成),那么会发生什么呢?我们考虑u,v\sim\mathcal{N}(0,1),其中u\in\mathbb{R}^{m\times 1}, v\in\mathbb{R}^{1\times n},那么更新量就变为
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} v^{\top} v + u u^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
可以证明的是
\begin{equation}\mathbb{E}_{u\sim \mathcal{N}(0,1)}[u u^{\top}] = I_{n\times n},\quad \mathbb{E}_{v\sim \mathcal{N}(0,1)}[v^{\top} v] = I_{m\times m}\end{equation}
这里的I_{n\times n},I_{m\times m}分别指n\times n,m\times m的单位矩阵。因此,跟“零阶梯度”类似,在平均意义下,这种每步都重新初始化的LoRA事实上等价于满秩的SGD。然而,真要按照这个方式实现的话,其速度甚至可能比满秩的SGD都要慢,所以它的目的不是提速,而是希望能缓解灾难遗忘问题——通过对单个(batch)样本使用低秩矩阵(而不是满秩)更新量的方式,减少对整个模型权重的影响。当然,这只是猜测,实际效果如何,笔者还没有实验过。

一个变体 #

同样还是先只考虑r=1的情形,LoRA相当于假设了\Delta w_{i,j} = u_i v_j,我们能不能做其他低秩分解假设呢?比如\Delta w_{i,j} = u_i + v_j?写成矩阵形式就是
\begin{equation}W = W_0 + A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B,\qquad A\in\mathbb{R}^{n\times 1},B\in\mathbb{R}^{1\times m}\end{equation}
其中\mathbb{1}_{1\times m},\mathbb{1}_{n\times 1}分别指1\times m,n\times 1的全1矩阵。容易求出它的梯度是:
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} \mathbb{1}_{m\times 1},\quad \frac{\partial \mathcal{L}}{\partial B} = \mathbb{1}_{1\times n}\frac{\partial \mathcal{L}}{\partial W}\end{equation}
其实就是原本梯度的行求和与列求和。相比原版LoRA,这个加性分解有两个优点:1、加比乘计算量更低,梯度形式也更简单;2、AB的秩一定是1,但是A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B的秩可能是2,如果秩代表了模型能力的话,那也就是说同样的参数量,加性的表达能力可能还更强。至于具体效果如何,后面笔者用到LoRA的时候,再做对比实验吧。

那么,加性分解能不能推广到r > 1的情形呢?自然是可以的,但稍微有些技巧。这里约定m,n都能被r整除,那么我们只需要将参数化方式改为
\begin{equation}W = W_0 + A I_{r(1\times m/r)} + I_{r(n/r\times 1)} B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
这里的I_{r(1\times m/r)}I_{r(n/r\times 1)}分别指1\times m/rn/r\times 1的分块矩阵,每一块则是r\times r的单位阵。这个形式说白了,就是分别将AB看成是n/r\times 11\times m/r的分块矩阵,然后套用r=1的思路来操作。

文章小结 #

本文介绍了从梯度角度来理解LoRA,除了基本的介绍外,还包含了笔者的一些猜测和推广,供读者参考。

转载到请包括本文地址:https://spaces.ac.cn/archives/9590

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9590

@online{kexuefm-9590,
        title={梯度视角下的LoRA:简介、分析、猜测及推广},
        author={苏剑林},
        year={2023},
        month={Apr},
        url={\url{https://spaces.ac.cn/archives/9590}},
}