对于生成扩散模型来说,一个很关键的问题是生成过程的方差应该怎么选择,因为不同的方差会明显影响生成效果。

《生成扩散模型漫谈(二):DDPM = 自回归式VAE》我们提到,DDPM分别假设数据服从两种特殊分布推出了两个可用的结果;《生成扩散模型漫谈(四):DDIM = 高观点DDPM》中的DDIM则调整了生成过程,将方差变为超参数,甚至允许零方差生成,但方差为0的DDIM的生成效果普遍差于方差非0的DDPM;而《生成扩散模型漫谈(五):一般框架之SDE篇》显示前、反向SDE的方差应该是一致的,但这原则上在Δt0时才成立;《Improved Denoising Diffusion Probabilistic Models》则提出将它视为可训练参数来学习,但会增加训练难度。

所以,生成过程的方差究竟该怎么设置呢?今年的两篇论文《Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models》《Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models》算是给这个问题提供了比较完美的答案。接下来我们一起欣赏一下它们的结果。

不确定性 #

事实上,这两篇论文出自同一团队,作者也基本相同。第一篇论文(简称Analytic-DPM)下面简称在DDIM的基础上,推导了无条件方差的一个解析解;第二篇论文(简称Extended-Analytic-DPM)则弱化了第一篇论文的假设,并提出了有条件方差的优化方法。本文首先介绍第一篇论文的结果。

《生成扩散模型漫谈(四):DDIM = 高观点DDPM》中,我们推导了对于给定的p(xt|x0)=N(xt;ˉαtx0,ˉβ2tI),对应的p(xt1|xt,x0)的一般解为
p(xt1|xt,x0)=N(xt1;ˉβ2t1σ2tˉβtxt+γtx0,σ2tI)
其中γt=ˉαt1ˉαtˉβ2t1σ2tˉβtσt就是可调的标准差参数。在DDIM中,接下来的处理流程是:用ˉμ(xt)来估计x0,然后认为
p(xt1|xt)p(xt1|xt,x0=ˉμ(xt))
然而,从贝叶斯的角度来看,这个处理是非常不妥的,因为从xt预测x0不可能完全准确,它带有一定的不确定性,因此我们应该用概率分布而非确定性的函数来描述它。事实上,严格地有
p(xt1|xt)=p(xt1|xt,x0)p(x0|xt)dx0
精确的p(x0|xt)通常是没法获得的,但这里只要一个粗糙的近似,因此我们用正态分布N(x0;ˉμ(xt),ˉσ2tI)去逼近它(如何逼近我们稍后再讨论)。有了这个近似分布后,我们可以写出
xt1=ˉβ2t1σ2tˉβtxt+γtx0+σtε1ˉβ2t1σ2tˉβtxt+γt(ˉμ(xt)+ˉσtε2)+σtε1=(ˉβ2t1σ2tˉβtxt+γtˉμ(xt))+(σtε1+γtˉσtε2)σ2t+γ2tˉσ2tε
其中ε1,ε2,εN(0,I)。可以看到,p(xt1|xt)更加接近均值为ˉβ2t1σ2tˉβtxt+γtˉμ(xt)、协方差为(σ2t+γ2tˉσ2t)I的正态分布,其中均值跟以往的结果是一致的,不同的是方差多出了γ2tˉσ2t这一项,因此即便σt=0,对应的方差也不为0。多出来的这一项,就是第一篇论文所提的最优方差的修正项。

均值优化 #

现在我们来讨论如何用N(x0;ˉμ(xt),ˉσ2tI)去逼近真实的p(x0|xt),说白了就是求出p(x0|xt)的均值和协方差。

对于均值ˉμ(xt)来说,它依赖于xt,所以需要一个模型来拟合它,而训练模型就需要损失函数。利用
Ex[x]=argminμEx[xμ2]
我们得到
ˉμ(xt)=Ex0p(x0|xt)[x0]=argminμEx0p(x0|xt)[x0μ2]=argminμ(xt)Extp(xt)Ex0p(x0|xt)[x0μ(xt)2]=argminμ(xt)Ex0˜p(x0)Extp(xt|x0)[x0μ(xt)2]
这就是训练ˉμ(xt)所用的损失函数。如果像之前一样引入参数化
ˉμ(xt)=1ˉαt(xtˉβtϵθ(xt,t))
就可以得到DDPM训练所用的损失函数形式εϵθ(ˉαtx0+ˉβtε,t)2了。关于均值优化的结果是跟以往一致的,没有什么改动。

方差估计1 #

类似地,根据定义,协方差矩阵应该是
Σ(xt)=Ex0p(x0|xt)[(x0ˉμ(xt))(x0ˉμ(xt))]=Ex0p(x0|xt)[((x0μ)(ˉμ(xt)μ))((x0μ)(ˉμ(xt)μ))]=Ex0p(x0|xt)[(x0μ0)(x0μ0)](ˉμ(xt)μ0)(ˉμ(xt)μ0)
其中μ0可以是任意常向量,这对应于协方差的平移不变性。

上式估计的是完整的协方差矩阵,但并不是我们想要的,因为目前我们是想要用N(x0;ˉμ(xt),ˉσ2tI)去逼近p(x0|xt),其中设计的协方差矩阵为ˉσ2tI,它有两个特点:

1、跟xt无关:为了消除对xt的依赖,我们对全体xt求平均,即Σt=Extp(xt)[Σ(xt)]

2、单位阵的倍数:这意味着我们只用考虑对角线部分,并且对对角线元素取平均,即ˉσ2t=Tr(Σt)/d,其中d=dim(x)

于是我们有
ˉσ2t=Extp(xt)Ex0p(x0|xt)[x0μ02d]Extp(xt)[ˉμ(xt)μ02d]=1dEx0˜p(x0)Extp(xt|x0)[x0μ02]1dExtp(xt)[ˉμ(xt)μ02]=1dEx0˜p(x0)[x0μ02]1dExtp(xt)[ˉμ(xt)μ02]
这是笔者给出的关于ˉσ2t的一个解析形式,在ˉμ(xt)完成训练的情况下,可以通过采样一批x0xt来近似计算上式。

特别地,如果取μ0=Ex0˜p(x0)[x0],那么刚好可以写成
ˉσ2t=Var[x0]1dExtp(xt)[ˉμ(xt)μ02]
这里的Var[x0]是全体训练数据x0的像素级方差。如果x0的每个像素值都在[a,b]区间内,那么它的方差显然不会超过(ba2)2,从而有不等式
ˉσ2tVar[x0](ba2)2

方差估计2 #

刚才的解是笔者给出的、认为比较直观的一个解,Analytic-DPM原论文则给出了一个略有不同的解,但笔者认为相对来说没那么直观。通过代入式(7),我们可以得到:
Σ(xt)=Ex0p(x0|xt)[(x0ˉμ(xt))(x0ˉμ(xt))]=Ex0p(x0|xt)[((x0xtˉαt)+ˉβtˉαtϵθ(xt,t))((x0xtˉαt)+ˉβtˉαtϵθ(xt,t))]=Ex0p(x0|xt)[(x0xtˉαt)(x0xtˉαt)]ˉβ2tˉα2tϵθ(xt,t)ϵθ(xt,t)=1ˉα2tEx0p(x0|xt)[(xtˉαtx0)(xtˉαtx0)]ˉβ2tˉα2tϵθ(xt,t)ϵθ(xt,t)
此时如果两端对xtp(xt)求平均,我们有
Extp(xt)Ex0p(x0|xt)[(xtˉαtx0)(xtˉαtx0)]=Ex0˜p(x0)Extp(xt|x0)[(xtˉαtx0)(xtˉαtx0)]
别忘了p(xt|x0)=N(xt;ˉαtx0,ˉβ2tI),所以ˉαtx0实际上就是p(xt|x0)的均值,那么Extp(xt|x0)[(xtˉαtx0)(xtˉαtx0)]实际上是在求p(xt|x0)的均值的协方差矩阵,结果显然就是ˉβ2tI,所以
Ex0˜p(x0)Extp(xt|x0)[(xtˉαtx0)(xtˉαtx0)]=Ex0˜p(x0)[ˉβ2tI]=ˉβ2tI
那么
Σt=Extp(xt)[Σ(xt)]=ˉβ2tˉα2t(IExtp(xt)[ϵθ(xt,t)ϵθ(xt,t)])
两边取迹然后除以d,得到
ˉσ2t=ˉβ2tˉα2t(11dExtp(xt)[ϵθ(xt,t)2])ˉβ2tˉα2t
这就得到了另一个估计和上界,这就是Analytic-DPM的原始结果。

实验结果 #

原论文的实验结果显示,Analytic-DPM所做的方差修正,主要在生成扩散步数较少时会有比较明显的提升,所以它对扩散模型的加速比较有意义:

Analytic-DPM主要在扩散步数较少时会有比较明显的效果提升

Analytic-DPM主要在扩散步数较少时会有比较明显的效果提升

笔者也在之前自己实现的代码上尝试了Analytic-DPM的修正,参考代码为:

当扩散步数为10时,DDPM与Analytic-DDPM的效果对比如下图:

DDPM在扩散步数为10时的生成结果

DDPM在扩散步数为10时的生成结果

Analytic-DDPM在扩散步数为10时的生成结果

Analytic-DDPM在扩散步数为10时的生成结果

可以看到,在扩散步数较小时,DDPM的生成效果比较光滑,有点“重度磨皮”的感觉,相比之下Analytic-DDPM的结果显得更真实一些,但是也带来了额外的噪点。从评价指标来说,Analytic-DDPM要更好一些。

吹毛求疵 #

至此,我们已经完成了Analytic-DPM的介绍,推导过程略带有一些技巧性,但不算太复杂,至少思路上还是很明朗的。如果读者觉得还是很难懂,那不妨再去看看原论文在附录中用7页纸、13个引理完成的推导,想必看到之后就觉得本文的推导是多么友好了哈哈~

诚然,从首先得到这个方差的解析解来说,我为原作者们的洞察力而折服,但不得不说的是,从“事后诸葛亮”的角度来说,Analytic-DPM在推导和结果上都走了一些的“弯路”,显得“太绕”、”太巧“,从而感觉不到什么启发性。其中,一个最明显的特点是,原论文的结果都用了xtlogp(xt)来表达,这就带来了三个问题:一来使得推导过程特别不直观,难以理解“怎么想到的”;二来要求读者额外了解得分匹配的相关结果,增加了理解难度;最后落到实践时,xtlogp(xt)又要用回ˉμ(xt)ϵθ(xt,t)来表示,多绕一道。

本文推导的出发点是,我们是在估计正态分布的参数,对于正态分布来说,矩估计与最大似然估计相同,因此直接去估算相应的均值方差即可。结果上,没必要强行在形式上去跟xtlogp(xt)、得分匹配对齐,因为很明显Analytic-DPM的baseline模型是DDIM,DDIM本身就没有以得分匹配为出发点,增加与得分匹配的联系,于理论和实验都无益。直接跟ˉμ(xt)ϵθ(xt,t)对齐,形式上更加直观,而且更容易跟实验形式进行转换。

文章小结 #

本文分享了论文Analytic-DPM中的扩散模型最优方差估计结果,它给出了直接可用的最优方差估计的解析式,使得我们不需要重新训练就可以直接应用它来改进生成效果。笔者用自己的思路简化了原论文的推导,并进行了简单的实验验证。

转载到请包括本文地址:https://spaces.ac.cn/archives/9245

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Aug. 12, 2022). 《生成扩散模型漫谈(七):最优扩散方差估计(上) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9245

@online{kexuefm-9245,
        title={生成扩散模型漫谈(七):最优扩散方差估计(上)},
        author={苏剑林},
        year={2022},
        month={Aug},
        url={\url{https://spaces.ac.cn/archives/9245}},
}