生成扩散模型漫谈(十八):得分匹配 = 条件得分匹配
By 苏剑林 | 2023-02-28 | 27539位读者 |在前面的介绍中,我们多次提及“得分匹配”和“条件得分匹配”,它们是扩散模型、能量模型等经常出现的概念,特别是很多文章直接说扩散模型的训练目标是“得分匹配”,但事实上当前主流的扩散模型如DDPM的训练目标是“条件得分匹配”才对。
那么“得分匹配”与“条件得分匹配”具体是什么关系呢?它们两者是否等价呢?本文详细讨论这个问题。
得分匹配 #
首先,得分匹配(Score Matching)是指训练目标:
\begin{equation}\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right]\label{eq:sm}\end{equation}
其中$\boldsymbol{\theta}$是训练参数。很明显,得分匹配是想学习一个模型$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)$来逼近$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)$,这里的$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)$我们就称为“得分”。
在扩散模型场景,$p_t(\boldsymbol{x}_t)$由下式给出:
\begin{equation}p_t(\boldsymbol{x}_t) = \int p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)p_0(\boldsymbol{x}_0)d\boldsymbol{x}_0 = \mathbb{E}_{\boldsymbol{x}_0\sim p_0(\boldsymbol{x}_0)}\left[p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]\label{eq:pt}\end{equation}
其中$p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)$一般都是已知概率密度解析式的简单分布(如条件正态分布),$p_0(\boldsymbol{x}_0)$也是给定的分布,但一般代表训练数据,也就是说我们只能从$p_0(\boldsymbol{x}_0)$中采样,但不知道$p_0(\boldsymbol{x}_0)$的具体表达式。
根据式$\eqref{eq:pt}$,我们可以推导得
\begin{equation}\begin{aligned}
\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t) =&\, \frac{\nabla_{\boldsymbol{x}_t}p_t(\boldsymbol{x}_t)}{p_t(\boldsymbol{x}_t)} \\
=&\, \frac{\int p_0(\boldsymbol{x}_0)\nabla_{\boldsymbol{x}_t} p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)d\boldsymbol{x}_0}{p_t(\boldsymbol{x}_t)} \\
=&\, \frac{\int p_0(\boldsymbol{x}_0)\nabla_{\boldsymbol{x}_t} p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)d\boldsymbol{x}_0}{\int p_0(\boldsymbol{x}_0) p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)d\boldsymbol{x}_0} \\
=&\, \frac{\mathbb{E}_{\boldsymbol{x}_0\sim p_0(\boldsymbol{x}_0)}\left[\nabla_{\boldsymbol{x}_t}p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}{\mathbb{E}_{\boldsymbol{x}_0\sim p_0(\boldsymbol{x}_0)}\left[p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]} \\
\end{aligned}\label{eq:score-1}\end{equation}
根据我们的假设,$\nabla_{\boldsymbol{x}_t}p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)$和$p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)$都是已知解析式的,因此理论上我们可以通过采样$\boldsymbol{x}_0$来估计$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)$,但由于这里涉及到两个期望的除法,是一个有偏估计(参考《简述无偏估计和有偏估计》),因此需要采样足够多的点才能做出比较准确的估计,因此直接用式$\eqref{eq:sm}$作为训练目标的话,需要较大的batch_size才有比较好的效果。
条件得分 #
实际上,一般扩散模型所用的训练目标是“条件得分匹配(Conditional Score Matching)”:
\begin{equation}\mathbb{E}_{\boldsymbol{x}_0,\boldsymbol{x}_t\sim p_0(\boldsymbol{x}_0)p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right]\end{equation}
根据假设,$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)$是已知解析式的,因此上述目标是直接可用的,每次采样一对$(\boldsymbol{x}_0,\boldsymbol{x}_t)$进行估算。特别地,这是一个无偏估计,这意味着它不是特别依赖于大batch_size,因此是一个比较实用的训练目标。
为了分析“得分匹配”与“条件得分匹配”的关系,我们还需要$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)$的另一个恒等式:
\begin{equation}\begin{aligned}
\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t) =&\, \frac{\nabla_{\boldsymbol{x}_t}p_t(\boldsymbol{x}_t)}{p_t(\boldsymbol{x}_t)} \\
=&\, \frac{\int p_0(\boldsymbol{x}_0)\nabla_{\boldsymbol{x}_t} p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)d\boldsymbol{x}_0}{p_t(\boldsymbol{x}_t)} \\
=&\, \frac{\int p_0(\boldsymbol{x}_0)p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\nabla_{\boldsymbol{x}_t} \log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) d\boldsymbol{x}_0}{p_t(\boldsymbol{x}_t)} \\
=&\, \int p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)\nabla_{\boldsymbol{x}_t} \log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) d\boldsymbol{x}_0 \\
=&\, \mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\nabla_{\boldsymbol{x}_t} \log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right] \\
\end{aligned}\label{eq:score-2}\end{equation}
不等关系 #
首先,我们可以很快速地证明两者之间的第一个结果:条件得分匹配是得分匹配的一个上界。这也就意味着最小化条件得分匹配,某种程度上也在最小化得分匹配。
证明并不困难,之前我们在《生成扩散模型漫谈(十六):W距离 ≤ 得分匹配》就已经证明过:
\begin{equation}\begin{aligned}
&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\
=&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\left\Vert\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right] - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\
\leq &\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\
= &\,\mathbb{E}_{\boldsymbol{x}_0\sim p_0(\boldsymbol{x}_0),\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\
\end{aligned}\end{equation}
第一个等号是因为恒等式$\eqref{eq:score-2}$,第二个不等号则是因为平方平均不等式的推广或者詹森不等式,第三个等号则是贝叶斯公式了。
等价关系 #
前两天,在微信群里大家讨论到得分匹配的时候,有群友指点到:条件得分匹配与得分匹配之差是一个跟优化无关的常数,所以两者实际上是完全等价的!刚听到这个结论的时候笔者也相当惊讶,两者居然还是等价关系,而不单单是上下界关系。不仅如此,笔者尝试证明了一下后,发现证明过程居然也很简单!
首先,关于得分匹配,我们有
\begin{equation}\begin{aligned}
&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\
=&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\color{orange}{\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)\right\Vert^2} + \color{red}{\left\Vert\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2} - \color{green}{2\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\cdot\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)}\right] \\
\end{aligned}\end{equation}
然后,关于条件得分匹配,我们有
\begin{equation}\begin{aligned}
&\,\mathbb{E}_{\boldsymbol{x}_0,\boldsymbol{x}_t\sim p_0(\boldsymbol{x}_0)p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0) - \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2\right] \\[5pt]
=&\,\mathbb{E}_{\boldsymbol{x}_0,\boldsymbol{x}_t\sim p_0(\boldsymbol{x}_0)p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2 + \left\Vert\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2 - 2\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\cdot\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right] \\[5pt]
=&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t),\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2 + \left\Vert\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2 - 2\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\cdot\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right] \\[5pt]
=&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[{\begin{aligned}&\color{orange}{\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right]} + \color{red}{\left\Vert\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2} \\
&\qquad\qquad- \color{green}{2\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\cdot\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}\end{aligned}}\right] \\[5pt]
=&\,\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\color{orange}{\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right]} + \color{red}{\left\Vert\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right\Vert^2} - \color{green}{2\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\cdot\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)}\right] \\
\end{aligned}\end{equation}
两者作差,可以发现结果是
\begin{equation}\mathbb{E}_{\boldsymbol{x}_t\sim p_t(\boldsymbol{x}_t)}\left[\color{orange}{\mathbb{E}_{\boldsymbol{x}_0\sim p_t(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right] - \left\Vert\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)\right\Vert^2}\right]\end{equation}
它跟参数$\boldsymbol{\theta}$无关。所以最小化得分匹配目标,跟最小化条件得分匹配,在理论上是等价的。根据群友介绍,相关结果首次出现在文章《A Connection Between Score Matching and Denoising Autoencoders》中。
既然两者理论上等价,这是不是意味着我们前面的“得分匹配”比“条件得分匹配”需要更大batch_size的说法不成立?并不是。如果还是直接从式$\eqref{eq:score-1}$来估计$\nabla_{\boldsymbol{x}_t}\log p_t(\boldsymbol{x}_t)$然后做得分匹配,那么结果确实还是有偏的,依赖于大batch_size;而当我们将目标$\eqref{eq:sm}$展开进一步化简后,已经逐步将有偏估计转化为无偏估计了,这时候不会太依赖于batch_size。也就是说,虽然两个目标从理论上是等价的,从统计量的角度,属于不同性质的统计量,它们的等价仅仅是在采样样本数趋于无穷时的精确等价。
文章小结 #
本文主要分析“得分匹配”和“条件得分匹配”两个训练目标之间的关联。
转载到请包括本文地址:https://spaces.ac.cn/archives/9509
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Feb. 28, 2023). 《生成扩散模型漫谈(十八):得分匹配 = 条件得分匹配 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9509
@online{kexuefm-9509,
title={生成扩散模型漫谈(十八):得分匹配 = 条件得分匹配},
author={苏剑林},
year={2023},
month={Feb},
url={\url{https://spaces.ac.cn/archives/9509}},
}
March 17th, 2023
确实,理论上来说,得分匹配和条件得分匹配是等价的,但是在实际应用中,两者的差异还是需要注意的。如你所说,对于得分匹配,如果依赖于大batch_size来估计梯度,会导致偏差,而条件得分匹配则不存在这个问题。在实际使用中,需要根据具体情况选择合适的训练目标。
March 19th, 2023
苏老师,可以问下怎么加微信群吗
PC版主页右上角有
March 28th, 2023
一口气读完苏老师的十八篇扩散模型漫谈,对扩散模型的认知站在了更高的视角,我被苏老师深厚的数学物理功底所折服,您写的文章详略得当,注重让读者参与到公式的推导之中,我在阅读的过程中曾多次推导出自己理解不了的等号和变换,作为电院的人工智能专业的本科生,我深深地被数学的魅力感染到了,同时我也越发感受到了自己数学体系的不完善,能看懂这些理论完全是因为苏老师把这些知识“嚼碎了喂给我”,但是并没有系统的学过文章中涉及的重要概念、定理,我想这些会是我所欠缺的,希望苏老师您能指点一下,此系列文章我受益良多!
指点不敢,你希望我提哪方面的建议?
感谢苏老师回复!我的评论中提到我认为我的数学体系不完善,目前我已学过数学分析、解几高代等基础课,我想这是和数学或物理专业有差距的,我想要请教您我该怎么弥补或者另外自学哪几门课才能够真正有人工智能的数学基础(我自认为我的基础和悟性不算特别差,自学入门ml、dl和rl一年从基础到比较前沿的模型基本上很少有特别不能理解的,但是对其中用到的概念和定理却只是浅层的理解,没有去溯源,比如w距离的Lipschitz约束)
September 18th, 2024
您好,非常感谢您的推导和观点分享。
关于最后‘等价关系’推导时,由于作差的二者是优化目标,sθ(xt,t)应该是在优化过程中或者优化完成后对应的网络参数,而根据前面的论述这二者优化目标中无条件得分对应的网络参数易受到batch size的影响,有条件得分不易受到batch size的影响,从而会使得这两个优化目标的sθ(xt,t)的网络参数不一致(不管是在优化中还是优化完成后),应该不能直接相减而消掉,继而推导出二者优化目标等价的结论和后续观点。所以这两个优化目标应该是不等价的。请问这样分析有道理吗?
等价关系说的是理论等价性,可以理解为batch_size无穷大时的等价性,实际情况下batch_size有限,自然是不等价的