生成扩散模型漫谈(二十六):基于恒等式的蒸馏(下)
By 苏剑林 | 2024-11-22 | 27084位读者 |继续回到我们的扩散系列。在《生成扩散模型漫谈(二十五):基于恒等式的蒸馏(上)》中,我们介绍了SiD(Score identity Distillation),这是一种不需要真实数据、也不需要从教师模型采样的扩散模型蒸馏方案,其形式类似GAN,但有着比GAN更好的训练稳定性。
SiD的核心是通过恒等变换来为学生模型构建更好的损失函数,这一点是开创性的,同时也遗留了一些问题。比如,SiD对损失函数的恒等变换是不完全的,如果完全变换会如何?如何从理论上解释SiD引入的λ的必要性?上个月放出的《Flow Generator Matching》(简称FGM)成功从更本质的梯度角度解释了λ=0.5的选择,而受到FGM启发,笔者则进一步发现了λ=1的一种解释。
接下来我们将详细介绍SiD的上述理论进展。
思想回顾 #
根据上一篇文章的介绍,我们知道SiD实现蒸馏的思想是“相近的分布,它们训练出来的去噪模型也是相近的”,用公式表示就是
教师扩散模型:φ∗=argminφEx0∼˜p(x0),ε∼N(0,I)[‖ϵφ(xt,t)−ε‖2]学生扩散模型:ψ∗=argminψEz,ε∼N(0,I)[‖ϵψ(x(g)t,t)−ε‖2]学生生成模型:θ∗=argminθEz,ε∼N(0,I)[‖ϵφ∗(x(g)t,t)−ϵψ∗(x(g)t,t)‖2]⏟L1
这里记号比较多,我们逐一解释。第一个损失函数就是我们要蒸馏的扩散模型的训练目标,其中xt=ˉαtx0+ˉβtε代表加噪样本,ˉαt,ˉβt是noise schedule,x0是训练样本;第二个损失函数则是用学生模型生成的数据来训练的扩散模型,其中x(g)t=ˉαtgθ(z)+ˉβtε,这里的gθ(z)代表学生模型的生成样本,也记为x(g)0;第三个损失函数,则是试图通过拉近真实数据和学生数据所训练的扩散模型的差距,来训练学生生成模型(生成器)。
这里的教师模型是可以提前训练好的,而两个学生模型的训练只需要教师模型本身,并不需要用到训练教师模型的数据,所以作为一种蒸馏方式来看SiD是data-free的;两个学生模型则是类似GAN那样的交替训练,逐步提高生成器的生成质量。就笔者所阅读过的文献来看,这种训练思想最早出自论文《Learning Generative Models using Denoising Density Estimators》,我们在《从去噪自编码器到生成模型》也有过相关介绍。
然而,尽管看上去没什么毛病,但实际情况是式(2)和式(3)的交替训练非常容易崩溃,以至于几乎不能出效果。这是因为理论和实践上的两个gap:
1、理论上要求先求出式(2)的最优解,然后才去优化式(3),但实际上从训练成本考虑,我们并没有将它训练到最优就去优化式(3)了;
2、理论上ψ∗随θ而变,即应该写成ψ∗(θ),从而在优化式(3)时应该多出一项ψ∗(θ)对θ的梯度,但实际上在优化式(3)时我们都只当ψ∗是常数。
第1个问题其实还好,因为随着训练的推进ψ总能慢慢逼近理论最优的ψ∗,但第2个问题非常困难且本质,可以说GAN的训练不稳定性同样也有这个问题的“功劳”。而SiD和FGM的核心贡献,正是试图解决第2个问题。
恒等变换 #
SiD的想法是通过恒等变换来减少生成器损失函数(3)对ψ∗的依赖,从而弱化第2个问题。这一想法确实是开创性的,后面已经有不少工作围绕着SiD展开,包括下面要介绍的FGM也算是其中之一。
恒等变换的核心,是如下恒等式:
Ex0∼˜p(x0),ε∼N(0,I)[⟨f(xt,t),ϵφ∗(xt,t)⟩]=Ex0∼˜p(x0),ε∼N(0,I)[⟨f(xt,t),ε⟩]
简单来说就是ϵφ∗(xt,t)可以替换成ε。这里的ϵφ∗(xt,t)是式(1)的理论最优解,而f(xt,t)是任意只依赖于xt和t的向量函数。注意“只依赖于xt和t”是恒等式成立的必要条件,一旦f掺杂了独立的x0或ε,那么恒等式就未必成立了,所以应用该恒等式之前需要仔细检查这一点。
上一篇文章我们已经给出了该恒等式的证明,不过现在看来那个证明显得有点迂回,这里给出一个更直接点的证明:
证明:将目标(1)等价地改写成
φ∗=argminφExt∼p(xt)[Eε∼p(ε|xt)[‖ϵφ(xt,t)−ε‖2]]
根据E[x]=argminμEx[‖μ−x‖2](不熟悉可以求导证一下),我们可以得出上式的理论最优解是
ϵφ∗(xt,t)=Eε∼p(ε|xt)[ε]
所以
Ex0∼˜p(x0),ε∼N(0,I)[⟨f(xt,t),ϵφ∗(xt,t)⟩]=Ext∼p(xt)[⟨f(xt,t),ϵφ∗(xt,t)⟩]=Ext∼p(xt)[⟨f(xt,t),Eε∼p(ε|xt)[ε]⟩]=Ext∼p(xt),ε∼p(ε|xt)[⟨f(xt,t),ε⟩]=Ex0∼˜p(x0),ε∼N(0,I)[⟨f(xt,t),ε⟩]
证毕。证明过程的“必经之路”是第一个等号,这需要用到“f(xt,t)只依赖于xt和t”这个条件。
恒等式(4)的关键是ϵφ∗(xt,t)的最优性,而目标(1)和(2)形式是一样的,所以同样的结论也适用于ϵψ∗(xt,t),利用它我们就可以将(3)变换成
Ez,ε∼N(0,I)[‖ϵφ∗(x(g)t,t)−ϵψ∗(x(g)t,t)‖2]=Ez,ε∼N(0,I)[⟨ϵφ∗(x(g)t,t)−ϵψ∗(x(g)t,t),ϵφ∗(x(g)t,t)−ϵψ∗(x(g)t,t)⏟可以替换为ε⟩]=Ez,ε∼N(0,I)[⟨ϵφ∗(x(g)t,t)−ϵψ∗(x(g)t,t),ϵφ∗(x(g)t,t)−ε⟩]≜
最后的形式就是SiD所提的生成器损失函数\mathcal{L}_2,它是SiD成功训练的关键,我们可以理解为它通过恒等变换提前预估了\boldsymbol{\psi}^*的值,同时弱化了对\boldsymbol{\psi}^*的依赖,从而以它为损失函数训练生成器比\mathcal{L}_1有着更好的效果。
SiD的遗留问题是:
1、\mathcal{L}_2的恒等变换并不彻底,将\mathcal{L}_2展开会发现里边还有一项\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle],这一项的\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)同样可以替换为\boldsymbol{\varepsilon},那么问题就是完整的变换即下式会是一个比\mathcal{L}_2更好的选择吗?
\begin{equation}\mathcal{L}_3 = \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}\left[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}\Vert^2 - 2\langle\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle + \langle \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\boldsymbol{\varepsilon}\rangle\right]\label{eq:gloss-3}\end{equation}2、实际上SiD最终用的损失不是\mathcal{L}_2也不是\mathcal{L}_1,而是\mathcal{L}_2 - \lambda\mathcal{L}_1,其中\lambda > 0,并且实验发现\lambda的最优值在1附近,某些任务甚至在\lambda=1.2表现最好,这是非常让人困惑的,因为\mathcal{L}_1,\mathcal{L}_2是理论上相等的,所以\lambda > 1似乎在反向优化\mathcal{L}_1?这不就跟出发点相反了?显然这迫切需要一个理论解释。
直面梯度 #
再来回顾一下,我们面临的根本困难是:理论上\boldsymbol{\psi}^*是\boldsymbol{\theta}的函数,所以我们在求\nabla_{\boldsymbol{\theta}} \mathcal{L}_1或\nabla_{\boldsymbol{\theta}} \mathcal{L}_2时,需要想办法求\nabla_{\boldsymbol{\theta}}\boldsymbol{\psi}^*,但实践中我们至多可以得到\mathcal{L}_i^{\color{skyblue}{(\text{sg})}} \triangleq \mathcal{L}_i|_{\boldsymbol{\psi}^* \to \color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}},其中\color{skyblue}{\text{sg}}是stop gradient的意思,即无法获取\boldsymbol{\psi}^*关于\boldsymbol{\theta}的梯度,所以不论\mathcal{L}_1,\mathcal{L}_2,\mathcal{L}_3,它们在实践中的梯度都是有偏的。
这时候就轮到FGM登场了,它的想法更贴近本质:损失\mathcal{L}_1,\mathcal{L}_2,\mathcal{L}_3都只关注到了损失层面的相等性,但对于优化器来说我们需要的是梯度层面的相等,所以我们需要想办法找一个新的损失函数\mathcal{L}_4,使得它满足
\begin{equation}\nabla_{\boldsymbol{\theta}}\mathcal{L}_4(\boldsymbol{\theta}, \color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]})= \nabla_{\boldsymbol{\theta}}\mathcal{L}_{1/2/3}(\boldsymbol{\theta}, \boldsymbol{\psi}^*)\end{equation}
即\nabla_{\boldsymbol{\theta}}\mathcal{L}_4^{\color{skyblue}{(\text{sg})}} = \nabla_{\boldsymbol{\theta}}\mathcal{L}_{1/2/3},那么以\mathcal{L}_4为损失函数时,就可以实现无偏的优化效果了。
FGM的推导同样基于恒等式\eqref{eq:id},不过它的原始推导有点繁琐,对于本文来说可以直接从\mathcal{L}_3即式\eqref{eq:gloss-3}出发,它跟\boldsymbol{\psi}^*相关的项就只剩下\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle],我们直接把它的梯度算出来,方法将“先恒等变换后求梯度”和“先求梯度后恒等变换”分别应用于\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2]操作一遍,对比它们的结果。
先恒等变换后求梯度:
\begin{equation}\begin{aligned}
&\,\nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2] \\[5pt]
=&\, \nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] = \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] \\[5pt]
=&\, \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] + \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t),\boldsymbol{\varepsilon}\rangle]
\end{aligned}\label{eq:g-grad-1}\end{equation}
先求梯度后恒等变换:
\begin{equation}\begin{aligned}
&\,\nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2] \\[8pt]
=&\, \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\nabla_{\boldsymbol{\theta}}\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2] = 2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t), \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle] \\[8pt]
=&\, 2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t), \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle] + \underbrace{2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t), \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle]}_{\text{可以应用式}\eqref{eq:id}} \\[5pt]
=&\, 2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t), \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle] + 2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t), \boldsymbol{\varepsilon}\rangle]
\end{aligned}\label{eq:g-grad-2}\end{equation}
这里要注意第三个等号,只有\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t)这一项才可以应用恒等式\eqref{eq:id},因为\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t)的\boldsymbol{x}_t^{(g)}要对\boldsymbol{\theta}求梯度,求完梯度后就不一定是\boldsymbol{x}_t^{(g)}的函数了,所以不满足应用式\eqref{eq:id}的条件。
现在对于\nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2]我们有两个结果,将式\eqref{eq:g-grad-1}乘以2然后减去式\eqref{eq:g-grad-2}得到
\begin{equation}\begin{aligned}
&\,\nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] = \nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\Vert^2] = \eqref{eq:g-grad-1}\times 2 - \eqref{eq:g-grad-2} \\[5pt]
=&\,2 \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] - 2\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle\nabla_{\boldsymbol{\theta}}\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t), \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)\rangle] \\[5pt]
=&\,2 \nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle] - \nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\Vert\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t)\Vert^2] \\[5pt]
=&\,\nabla_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[2\langle \boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle - \Vert\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t)\Vert^2]
\end{aligned}\end{equation}
留意最后被求梯度的式子,它所有的\boldsymbol{\psi}^*都被加上了\color{skyblue}{\text{sg}},说明我们不用设法求它关于\boldsymbol{\theta}的梯度了,但它的梯度等于\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}[\langle \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle]的准确梯度,所以用它来替换掉\mathcal{L}_3的对应项,我们就得到了\mathcal{L}_4:
\begin{equation}\mathcal{L}_4^{\color{skyblue}{(\text{sg})}} = \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})}\left[\Vert\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}\Vert^2 - 2\langle\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle + 2\langle \boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t),\boldsymbol{\varepsilon}\rangle - \Vert\boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t)\Vert^2\right]\end{equation}
这就是FGM的最终结果,它只依赖于\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]},但成立\nabla_{\boldsymbol{\theta}}\mathcal{L}_4^{\color{skyblue}{(\text{sg})}}=\nabla_{\boldsymbol{\theta}}\mathcal{L}_{1/2/3}。再仔细观察一下,就会发现成立\mathcal{L}_4^{\color{skyblue}{(\text{sg})}}=2\mathcal{L}_2^{\color{skyblue}{(\text{sg})}}-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}=2(\mathcal{L}_2^{\color{skyblue}{(\text{sg})}}-0.5\times \mathcal{L}_1^{\color{skyblue}{(\text{sg})}}),所以FGM相当于从梯度角度肯定了SiD的\lambda=0.5的选择。
顺便说一下,FGM原论文的描述是在ODE式扩散框架(flow matching)内进行的,但正如笔者在上一篇文章所说,不管是SiD还是FGM,它实际并没有用到扩散模型的迭代生成过程,而是只用到了扩散模型所训练的去噪模型,所以不管是ODE、SDE还是DDPM框架都只是表象,它的去噪模型才是本质,所以本文可以接着上一篇SiD的记号来介绍FGM。
广义散度 #
FGM已经成功地求出了最本质的梯度,但这只能解释SiD的\lambda=0.5,这意味着如果我们需要解释其他\lambda值的可行性,就必须修改出发点了。为此,我们回到原点,反思一下生成器的目标\eqref{eq:gloss-1}。
熟悉扩散模型的读者应该都知道,式\eqref{eq:tloss}的理论最优解还可以写成\boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\boldsymbol{x}_t,t)=-\bar{\beta}_t\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t),同理式\eqref{eq:dloss}的最优解则是\boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\boldsymbol{x}_t^{(g)},t)=-\bar{\beta}_t\nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}),这里的p(\boldsymbol{x}_t)、p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)})分别是真实数据、生成器数据加噪的分布,如果不了解这个结果,可以参考《生成扩散模型漫谈(五):一般框架之SDE篇》、《生成扩散模型漫谈(十八):得分匹配 = 条件得分匹配》等介绍。
将这两个理论最优解代回式\eqref{eq:gloss-1},我们会发现生成器实际上在试图最小化Fisher散度:
\begin{equation}\begin{aligned}
\mathcal{F}(p, p_{\boldsymbol{\theta}}) =&\, \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})} \left[\Vert \nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) - \nabla_{\boldsymbol{x}_t^{(g)}}\log p(\boldsymbol{x}_t^{(g)})\Vert^2\right] \\
=&\, \int p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) \left\Vert \nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) - \nabla_{\boldsymbol{x}_t^{(g)}}\log p(\boldsymbol{x}_t^{(g)})\right\Vert^2 d\boldsymbol{x}_t^{(g)}
\end{aligned}\end{equation}
我们要反思的事情,就是Fisher散度的合理性和改进点。可以看到,Fisher散度里边p_{\boldsymbol{\theta}}出现了两次,现在我们来请读者思考一个问题:这两处p_{\boldsymbol{\theta}}中哪一处更重要呢?
答案是第二处。为了理解这个事实,我们不妨考虑两种情况:1、固定第一处p_{\boldsymbol{\theta}},只优化第二处p_{\boldsymbol{\theta}};2、固定第二处p_{\boldsymbol{\theta}},只优化第一处p_{\boldsymbol{\theta}}。它们的结果有什么区别呢?第一种情况大概率不会有什么变化,即依然能学到p_{\boldsymbol{\theta}}=p,事实上由于Fisher散度带有\Vert\Vert^2,所以下面更一般的结论几乎是显然成立的:
只要r(\boldsymbol{x})是一个处处不为零的分布,那么p(\boldsymbol{x})=q(\boldsymbol{x})依然是如下广义Fisher散度的理论最优解: \begin{equation}\mathcal{F}(p,q|r) = \int r(\boldsymbol{x}) \Vert \nabla_{\boldsymbol{x}} p(\boldsymbol{x}) - \nabla_{\boldsymbol{x}} q(\boldsymbol{x})\Vert^2 d\boldsymbol{x}\end{equation}
说简单点,就是第一处p_{\boldsymbol{\theta}}根本不重要,换成其他分布都行,单靠\Vert\Vert^2就能保证两个分布相等。但第二种情况就不一样了,固定第二处p_{\boldsymbol{\theta}}只优化第一处p_{\boldsymbol{\theta}}的理论最优解是
\begin{equation}p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) = \delta(\boldsymbol{x}_t^{(g)} - \boldsymbol{x}_t^*),\quad \boldsymbol{x}_t^* = \mathop{\text{argmin}}_{\boldsymbol{x}_t^{(g)}} \,\left\Vert \nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) - \nabla_{\boldsymbol{x}_t^{(g)}}\log p(\boldsymbol{x}_t^{(g)})\right\Vert^2\end{equation}
其中\delta是狄拉克delta分布,即模型只需要生成让\Vert\Vert^2最小的那个样本,就可以让损失最小,这说白了就是模式坍缩(Mode Collapse)!所以,Fisher散度中的第一处p_{\boldsymbol{\theta}}的作用不单单是次要的,甚至还可能是负面的。
这启发我们,当我们使用基于梯度的优化器来训练模型时,第一处p_{\boldsymbol{\theta}}的梯度干脆不要还会更好,即下述形式的Fisher散度是一个更好的选择
\begin{equation}\begin{aligned}
\mathcal{F}^+(p, p_{\boldsymbol{\theta}}) =&\, \int p_{\color{skyblue}{\text{sg}[}\boldsymbol{\theta}\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)}) \left\Vert \nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\boldsymbol{x}_t^{(g)}) - \nabla_{\boldsymbol{x}_t^{(g)}}\log p(\boldsymbol{x}_t^{(g)})\right\Vert^2 d\boldsymbol{x}_t^{(g)} \\[5pt]
=&\, \mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})} \left[\Vert \nabla_{\boldsymbol{x}_t^{(g)}}\log p_{\boldsymbol{\theta}}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]}) - \nabla_{\boldsymbol{x}_t^{(g)}}\log p(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]})\Vert^2\right] \\[5pt]
\propto&\, \underbrace{\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})} \left[\Vert \boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t) - \boldsymbol{\epsilon}_{\boldsymbol{\psi}^*}(\color{skyblue}{\text{sg}[}\boldsymbol{x}_t^{(g)}\color{skyblue}{]},t)\Vert^2\right]}_{\mathcal{L}_5}
\end{aligned}\end{equation}
也就是说,这里的\mathcal{L}_5极有可能会是一个比\mathcal{L}_1更好的出发点,它数值上跟\mathcal{L}_1是相等的,但少了一部分梯度:
\begin{equation}\nabla_{\boldsymbol{\theta}}\mathcal{L}_5 = \nabla_{\boldsymbol{\theta}}\mathcal{L}_1 - \nabla_{\boldsymbol{\theta}}\underbrace{\mathbb{E}_{\boldsymbol{z},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})} \left[\Vert \boldsymbol{\epsilon}_{\boldsymbol{\varphi}^*}(\boldsymbol{x}_t^{(g)},t) - \boldsymbol{\epsilon}_{\color{skyblue}{\text{sg}[}\boldsymbol{\psi}^*\color{skyblue}{]}}(\boldsymbol{x}_t^{(g)},t)\Vert^2\right]}_{\text{刚好是}\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}}\end{equation}
其中\nabla_{\boldsymbol{\theta}}\mathcal{L}_1已经由FGM算出来了,它等于\nabla_{\boldsymbol{\theta}}(2\mathcal{L}_2^{\color{skyblue}{(\text{sg})}}-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}),因此以\mathcal{L}_5为出发点,我们实践中的损失函数是2\mathcal{L}_2^{\color{skyblue}{(\text{sg})}}-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}=2(\mathcal{L}_2^{\color{skyblue}{(\text{sg})}}-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}),这就解释了\lambda=1的选择。至于\lambda稍大于1的选择,则更为极端一些,它相当于在\mathcal{L}_5的基础上将-\mathcal{L}_1^{\color{skyblue}{(\text{sg})}}作为额外的惩罚项,进一步降低模式坍缩的风险,当然这里真就是单纯的惩罚项,所以权重就不能太大了,根据SiD的实验结果,\lambda=1.5的时候已经开始训崩了。
顺便说一下,FGM之前作者还有个作品《One-Step Diffusion Distillation through Score Implicit Matching》,里边也提出了类似的对第一处p_{\boldsymbol{\theta}}改为p_{\color{skyblue}{\text{sg}[}\boldsymbol{\theta}\color{skyblue}{]}}的做法,但没有明确地从Fisher散度的原始形式出发讨论该操作的合理性,稍欠完整。
文章小结 #
本文介绍了SiD(Score identity Distillation)的后续理论进展,主要内容是从梯度视角解释了SiD中的\lambda参数设置,核心部分是由FGM(Flow Generator Matching)发现的准确估计SiD梯度的巧妙思路,这肯定了\lambda=0.5的选择,在此基础上,笔者拓展了Fisher散度的概念,从而解释了\lambda=1的取值。
转载到请包括本文地址:https://spaces.ac.cn/archives/10567
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Nov. 22, 2024). 《生成扩散模型漫谈(二十六):基于恒等式的蒸馏(下) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/10567
@online{kexuefm-10567,
title={生成扩散模型漫谈(二十六):基于恒等式的蒸馏(下)},
author={苏剑林},
year={2024},
month={Nov},
url={\url{https://spaces.ac.cn/archives/10567}},
}
November 23rd, 2024
感谢作者的分享,在广义散度的部分,我理解这和FGM团队更早的这篇SIM[1]的思路是类似的?
[1] Luo, Weijian, et al. "One-Step Diffusion Distillation through Score Implicit Matching." (NeurIPS 2024) (https://arxiv.org/pdf/2410.16794)
嗯嗯,最后的\mathcal{F}^+(p, p_{\boldsymbol{\theta}})形式上差不多,不过两者的出发点和侧重点都不一样。我把它补充到正文吧~
November 23rd, 2024
哇,下终于出了
November 23rd, 2024
苏博士,你也太高产啦,仰慕!
November 30th, 2024
不管是FGM还是Score Implicit Matching(SIM),都受到了SiD的启发。某种程度上看,SIM和FGM给SiD的偏经验性的损失函数给了一个理论解释,并且将L2 distance直接推广到了general score-based divergence下进行讨论。感谢苏老师分享,博客写的非常清楚,点赞!
欢迎罗老师莅临指导
December 1st, 2024
能不能列个清单推荐下,要读哪些书,学什么课程才能跟着推你的公式?
近几年已经没有系统读过什么数学书了,都是边学边补的,基础的话就把本科的数学分析、线性代数和统计概率学透一点就差不多了。