上周笔者写了《生成扩散模型漫谈(十四):构建ODE的一般步骤(上)》(当时还没有“上”这个后缀),本以为已经窥见了构建ODE扩散模型的一般规律,结果不久后评论区大神 @gaohuazuo 就给出了一个构建格林函数更高效、更直观的方案,让笔者自愧不如。再联想起之前大神之前在《生成扩散模型漫谈(十二):“硬刚”扩散ODE》同样也给出了一个关于扩散ODE的精彩描述(间接启发了上一篇博客的结果),大神的洞察力不得不让人叹服。

经过讨论和思考,笔者发现大神的思路本质上就是一阶偏微分方程的特征线法,通过构造特定的向量场保证初值条件,然后通过求解微分方程保证终值条件,同时保证了初值和终值条件,真的非常巧妙!最后,笔者将自己的收获总结成此文,作为上一篇的后续。

前情回顾 #

简单回顾一下上一篇文章的结果。假设随机变量x0Rd连续地变换成xT,其变化规律服从ODE
dxtdt=ft(xt)


那么对应的t时刻的分布pt(xt)服从“连续性方程”:
tpt(xt)=xt(ft(xt)pt(xt))

u(t,xt)=(pt(xt),ft(xt)pt(xt))Rd+1,那么连续性方程可以简写成
{(t,xt)u(t,xt)=0u1(0,x0)=p0(x0),u1(t,xt)dxt=1

为了求解这个方程,可以用格林函数的思想,即先求解
{(t,xt)G(t,0;xt,x0)=0G1(0,0;xt,x0)=δ(xtx0),G1(t,0;xt,x0)dxt=1

那么
u(t,xt)=G(t,0;xt,x0)p0(x0)dx0=Ex0p0(x0)[G(t,0;xt,x0)]

就是满足约束条件的解之一。

几何直观 #

所谓格林函数,其实思想很简单,它就是说我们先不要着急解决复杂数据生成,我们先假设要生成的数据只有一个点x0,先解决单个数据点的生成。有的读者想这不是很简单吗?直接xT×0+x0就完事了?当然不是这么简单,我们需要的是连续的、渐变的生成,如下图所示,就是t=T上的任意一点xT,都沿着一条光滑轨迹运行到t=0x0上:

格林函数示意图。图中T=1,在t=1处的每个点,都沿着特定的轨迹运行到t=0处的一个点,除了公共点外,轨迹之间无重叠,这些轨迹就是格林函数的场线

格林函数示意图。图中T=1,在t=1处的每个点,都沿着特定的轨迹运行到t=0处的一个点,除了公共点外,轨迹之间无重叠,这些轨迹就是格林函数的场线

而我们的目的,只是构造一个生成模型出来,所以我们原则上并不在乎轨迹的形状如何,只要它们都穿过x0,那么,我们可以人为地选择我们喜欢的、经过x0的一个轨迹簇,记为
φt(xt|x0)=xT


再次强调,这代表着以x0为起点、以xT为终点的一个轨迹簇,轨迹自变量、因变量分别为t,xt,起点x0是固定不变的,终点xT是可以任意变化的,轨迹的形状是无所谓的,我们可以选择直线、抛物线等等。

现在我们对式(6)两边求导,由于xT是可以随意变化的,它相当于微分方程的积分常数,对它求导就等于0,于是我们有
φt(xt|x0)xtdxtdt+φt(xt|x0)t=0dxtdt=(φt(xt|x0)xt)1φt(xt|x0)t


对比式(1),我们就得到
ft(xt|x0)=(φt(xt|x0)xt)1φt(xt|x0)t

这里将原本的记号ft(xt)替换为了ft(xt|x0),以标记轨线具有公共点x0。也就是说,这样构造出来的力场ft(xt|x0)所对应的ODE轨迹,必然是经过x0的,这就保证了格林函数的初值条件。

特征线法 #

既然初值条件有保证了,那么我们不妨要求更多一点:再保证一下终值条件。终值条件也就是希望t=TxT的分布是跟x0无关的简单分布。上一篇文章的求解框架的主要缺点,就是无法直接保证终值分布的简单性,只能通过事后分析来研究。这篇文章的思路则是直接通过设计特定的ft(xt|x0)来保证初值条件,然后就有剩余空间来保证终值条件了。而且,同时保证了初、终值后,在满足连续性方程(2)的前提下,积分条件是自然满足的。

用数学的方式说,我们就是要在给定ft(xt|x0)pT(xT)的前提下,去求解方程(2),这是一个一阶偏微分方程,可以通过“特征线法”求解,其理论介绍可以参考笔者之前写的《一阶偏微分方程的特征线法》。首先,我们将方程(2)等价地改写成
tpt(xt|x0)+xtpt(xt|x0)ft(xt|x0)=pt(xt|x0)xtft(xt|x0)


同前面类似,由于接下来是在给定起点x0进行求解,所以上式将pt(xt)替换为pt(xt|x0),以标记这是起点为x0的解。

特征线法的思路,是先在某条特定的轨迹上考虑偏微分方程的解,这可以将偏微分转化为常微分,降低求解难度。具体来说,我们假设xtt的函数,在方程(1)的轨线上求解。此时由于成立方程(1),将上式左端的ft(xt|x0)替换为dxtdt后,左端正好是pt(xt|x0)的全微分,所以此时有
ddtpt(xt|x0)=pt(xt|x0)xtft(xt|x0)


注意,此时所有的xt应当被替换为对应的t的函数,这理论上可以从轨迹方程(6)解出。替换后,上式的pf都是纯粹t的函数,所以上式只是关于p的一个线性常微分方程,可以解得
pt(xt|x0)=Cexp(Ttxsfs(xs|x0)ds)

代入终值条件pT(xT),得到C=pT(xT),即
pt(xt|x0)=pT(xT)exp(Ttxsfs(xs|x0)ds)

把轨迹方程(6)xT代入,就得到一个只含有t,xt,x0的函数,便是最终要求解的格林函数G1(t,0;xt,x0)了,相应地有G>1(t,0;xt,x0)=pt(xt|x0)ft(xt|x0)

训练目标 #

有了格林函数,我们就可以得到
u1(t,xt)=pt(xt|x0)p0(x0)dx0=pt(xt)u>1(t,xt)=ft(xt|x0)pt(xt|x0)p0(x0)dx0


于是
ft(xt)=u>1(t,xt)u1(t,xt)=ft(xt|x0)pt(xt|x0)p0(x0)pt(xt)dx0=ft(xt|x0)pt(x0|xt)dx0=Ex0pt(x0|xt)[ft(xt|x0)]

根据《生成扩散模型漫谈(五):一般框架之SDE篇》中构建得分匹配目标的方法,可以构建训练目标
Extpt(xt)[Ex0pt(x0|xt)[vθ(xt,t)ft(xt|x0)2]]dxt=Ex0,xtpt(xt|x0)p0(x0)[vθ(xt,t)ft(xt|x0)2]

它跟《Flow Matching for Generative Modeling》所给出的“Conditional Flow Matching”形式上是一致的,后面我们还会看到,该论文的结果都可以从本文的方法推出。训练完成后,就可以通过求解方程dxtdt=vθ(xt,t)来生成样本了。从这个训练目标也可以看出,我们对pt(xt|x0)的要求是易于采样就行了。

一些例子 #

可能前面的抽象结果对大家来说还是不大好理解,接下来我们来给出一些具体例子,以便加深大家对这个框架的直观理解。至于特征线法本身,笔者在《一阶偏微分方程的特征线法》也说过,一开始笔者也觉得特征线法像是“变魔术”一样难以捉摸,按照步骤操作似乎不困难,但总把握不住关键之处,理解它需要一个反复斟酌的思考过程,无法进一步代劳了。

直线轨迹 #

作为最简单的例子,我们假设xT是沿着直线轨迹变为x0,简单起见我们还可以将T设为1,这不会损失一般性,那么xt的方程可以写为
xt=(x1x0)t+x0xtx0t+x0=x1


根据式(8),有
ft(xt|x0)=xtx0t

此时xtft(xt|x0)=dt,根据式(12)就有
pt(xt|x0)=p1(x1)td

代入式(16)中的x1,得到
pt(xt|x0)=p1(xtx0t+x0)td

特别地,如果p1(x1)是标准正态分布,那么上式实则意味着pt(xt|x0)=N(xt;(1t)x0,t2I),这正好是常见的高斯扩散模型之一。这个框架的新结果,是允许我们选择更一般的先验分布p1(x1),比如均匀分布。另外在介绍得分匹配(15)时也已经说了,对pt(xt|x0)我们只需要知道它的采样方式就行了,而上式告诉我们只需要先验分布易于采样就行,因为:
xtpt(xt|x0)xt=(1t)x0+tε,εp1(ε)

效果演示 #

注意,我们假设从x0x1的轨迹是一条直线,这仅仅是对于单点生成的,也就是格林函数解。当通过格林函数叠加出一般分布对应的的力场ft(xt)时,其生成轨迹就不再是直线了。

下图演示了先验分布为均匀分布时多点生成的轨线图:

单点生成

单点生成

两点生成

两点生成

三点生成

三点生成

参考作图代码:

import numpy as np
from scipy.integrate import odeint
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rc('text', usetex=True)
matplotlib.rcParams['text.latex.preamble']=[r"\usepackage{amsmath}"]

prior = lambda x: 0.5 if 2 >= x >= 0 else 0
p = lambda xt, x0, t: prior((xt - x0) / t + x0) / t
f = lambda xt, x0, t: (xt - x0) / t

def f_full(xt, t):
    x0s = [0.5, 0.5, 1.2, 1.7]  # 0.5出现两次,代表其频率是其余的两倍
    fs = np.array([f(xt, x0, t) for x0 in x0s]).reshape(-1)
    ps = np.array([p(xt, x0, t) for x0 in x0s]).reshape(-1)
    return (fs * ps).sum() / (ps.sum() + 1e-8)

for x1 in np.arange(0.01, 1.99, 0.10999/2):
    ts = np.arange(1, 0, -0.001)
    xs = odeint(f_full, x1, ts).reshape(-1)[::-1]
    ts = ts[::-1]
    if abs(xs[0] - 0.5) < 0.1:
        _ = plt.plot(ts, xs, color='skyblue')
    elif abs(xs[0] - 1.2) < 0.1:
        _ = plt.plot(ts, xs, color='orange')
    else:
        _ = plt.plot(ts, xs, color='limegreen')

plt.xlabel('$t$')
plt.ylabel(r'$\boldsymbol{x}$')
plt.show()

一般推广 #

其实上面的结果还可以一般地推广到
xt=μt(x0)+σtx1xtμt(x0)σt=x1


这里的μt(x0)是任意满足μ0(x0)=x0,μ1(x0)=0RdRd函数,σt是任意满足σ0=0,σ1=1的单调递增函数。根据式(8),有
ft(xt|x0)=˙μt(x0)+˙σtσt(xtμt(x0))

这也等价于《Flow Matching for Generative Modeling》中的式(15),此时xtft(xt|x0)=d˙σtσt,根据式(12)就有
pt(xt|x0)=p1(x1)σdt

代入x1,最终结果是
pt(xt|x0)=p1(xtμt(x0)σt)σdt

这是关于线性ODE扩散的一般结果,包含高斯扩散,也允许使用非高斯的先验分布。

再复杂些? #

前面的例子,都是通过x0(的某个变换)与x1的简单线性插值(插值权重纯粹是t的函数)来构建xt的变化轨迹。那么一个很自然的问题就是:可不可以考虑更复杂的轨迹呢?理论上可以,但是更高的复杂度意味着隐含了更多的假设,而我们通常很难检验目标数据是否支持这些假设,因此通常都不考虑更复杂的轨迹了。此外,对于更复杂的轨迹,解析求解的难度通常也更高,不管是理论还是实验,都难以操作下去。

更重要的一点的,我们目前所假设的轨迹,仅仅是单点生成的轨迹而已,前面已经演示了,即便假设为直线,多点生成依然会导致复杂的曲线。所以,如果单点生成的轨迹都假设得不必要的复杂,那么可以想像多点生成的轨迹复杂度将会奇高,模型可能会极度不稳定。

文章小结 #

接着上一篇文章的内容,本文再次讨论了ODE式扩散模型的构建思路。这一次我们从几何直观出发,通过构造特定的向量场保证结果满足初值分布条件,然后通过求解微分方程保证终值分布条件,得到一个同时满足初值和终值条件的格林函数。特别地,该方法允许我们使用任意简单分布作为先验分布,摆脱以往对高斯分布的依赖来构建扩散模型。

转载到请包括本文地址:https://spaces.ac.cn/archives/9379

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Dec. 22, 2022). 《生成扩散模型漫谈(十五):构建ODE的一般步骤(中) 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9379

@online{kexuefm-9379,
        title={生成扩散模型漫谈(十五):构建ODE的一般步骤(中)},
        author={苏剑林},
        year={2022},
        month={Dec},
        url={\url{https://spaces.ac.cn/archives/9379}},
}