Skip to content

Denoising Diffusion Implicit Models

SONG J, MENG C, ERMON S. Denoising Diffusion Implicit Models[J]. arXiv preprint arXiv:2010.02502, 2020.

Denoising Diffusion Implicit Models


去噪扩散隐式模型

Abstract

去噪扩散概率模型(DDPMs)无需对抗训练即可实现高质量图像生成,但其采样过程需要模拟一个马尔可夫链的多个步骤,耗时较长。为了加速采样,我们提出了去噪扩散隐式模型(DDIMs),这是一类更高效的迭代隐式概率模型,其训练过程与 DDPMs 相同。在 DDPMs 中,生成过程被定义为特定马尔可夫扩散过程的逆过程。我们通过一类非马尔可夫扩散过程对 DDPMs 进行了推广,这些过程能够导出相同的训练目标。这些非马尔可夫过程可以对应确定性的生成过程,从而形成能够更快生成高质量样本的隐式模型。实验表明,与 DDPMs 相比,DDIMs 能够以实际时间快 10 至 50 倍的速度生成高质量样本,允许我们在计算量和样本质量之间进行权衡,直接在隐空间中进行语义有意义的图像插值,并以极低误差重建观测数据。

目标

我们想学习一个模型分布 pθ(x0),让它能近似真实数据分布 q(x0),并且容易采样生成新数据。

DDPM

DDPM 引入了一系列潜在变量 x1,x2,,xT,它们与原始数据 x0 相同。。模型通过以下方式定义:

pθ(x0)=pθ(x0:T)dx1:T(1)pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)

参数 θ 通过最大化变分下界来学习以拟合数据分布 q(x0)

(2)maxθEq(x0)[logpθ(x0)]maxθEq(x0,x1,,xT)[logpθ(x0:T)logq(x1:Tx0)]

其中 q(x1:Tx0) 是隐变量上的某个推断分布。与典型的隐变量模型(如变分自编码器)不同,DDPM 使用一个固定的(而非可训练的)推断过程 q(x1:Tx0) 进行学习,且隐变量维度相对较高。

前向过程表示:

q(x1:Tx0)=t=1Tq(xtxt1)(3)q(xtxt1)=N(αtαt1xt1,(1αtαt1)I)

前向过程的一个特殊性质是:

q(xtx0)=q(x1:tx0)dx1:(t1)=N(xt;αtx0,(1αt)I)

因此,我们可以将 xt 表示为 x0 与一个噪声变量 ϵ 的线性组合:

(4)xt=αtx0+1αtϵ

其中 ϵN(0,I)

当我们将 αt 设置得足够接近于 0 时,对于所有的 x0q(xTx0) 都会收敛为标准高斯分布。因此,很自然地设定 pθ(xT)=N(0,I) ,如果所有的条件分布都被建模为具有可训练均值函数和固定方差的高斯分布,那么目标可以简化为:

(5)Lγ(ϵθ)=t=1TγtEx0q(x0),ϵtN(0,I)[ϵθ(t)(αtx0+1αtϵt)ϵt22]

前向过程的长度 T 是 DDPM 中的一个重要超参数。从变分的角度来看,较大的 T 能使反向过程更接近高斯分布,从而使使用高斯条件分布建模的生成过程成为一个良好的近似;这促使人们选择较大的 T 值,例如 DDPM 中 T=1000。然而,由于所有 T 次迭代必须顺序执行(而非并行)才能获得一个样本 x0,从 DDPM 中采样比从其他深度生成模型中采样要慢得多,这使得它们在计算资源有限且延迟至关重要的任务中不切实际。

非马尔可夫前向过程的变分推断

由于生成模型近似于推断过程的逆过程,我们需要重新思考推断过程,以减少生成模型所需的迭代次数。我们的关键观察是:DDPM的目标函数 Lγ 仅依赖于边缘分布 q(xtx0),而不直接依赖于联合分布 q(x1:Tx0)。既然存在许多具有相同边缘分布的推断分布(联合分布),我们探索了非马尔可夫的替代推断过程,从而引出了新的生成过程。

非马尔可夫前向过程

考虑一个由实向量 σR0T 索引的推断分布族 Q

(6)qσ(x1:Tx0):=qσ(xTx0)t=2Tqσ(xt1xt,x0)

与 DDPM 的区别

  • DDPM(马尔可夫):q(x1:Tx0)=t=1Tq(xtxt1),每一步只依赖前一步。
  • DDIM(非马尔可夫):qσ(xt1xt,x0) 同时依赖 xtx0,允许跳过中间状态。

其中 qσ(xTx0)=N(αTx0,(1αT)I),并且对于所有 t>1

(7)qσ(xt1xt,x0)=N(αt1x0+1αt1σt2xtαtx01αtμσ(xt,x0),σt2I)

设计原理

  • 我们希望给定 x0 时,xt 的边缘分布仍为 N(αtx0,(1αt)I)(和 DDPM 相同)
  • 通过特定的均值函数设计,可以确保这一点
  • 参数 σt2 控制条件分布的方差:
    • σt2=β~t=1αt11αtβt 时,退化为 DDPM
    • σt2=0 时,方差为0,成为确定性转移(即 DDIM)

选择该均值函数是为了确保对于所有 t,都有 qσ(xtx0)=N(αtx0,(1αt)I),从而使其定义的联合推断分布能够如期望地匹配“边缘分布”。前向过程可以通过贝叶斯规则导出:

(8)qσ(xtxt1,x0)=qσ(xt1xt,x0)qσ(xtx0)qσ(xt1x0)

这里的前向过程不再是马尔可夫的,因为每个 xt 都可能依赖于 xt1x0σ 的大小控制了前向过程的随机程度,当 σ0 时,我们到达一个极端情况:只要观测到某个 t 对应的 x0xt,那么 xt1 就变成已知且固定的。

公式 (8)

实际并不需要这个前向过程

生成过程与统一的变分推断目标

接下来,我们定义一个可训练的生成过程 pθ(x0:T),其中每个 pθ(t)(xt1xt) 都利用了 qσ(xt1xt,x0)。直观上,给定一个带噪观测对应的 x0,然后通过我们已定义的反向条件分布 pθ(t)(xt1xt) 来获取样本 xt1

对于某个 x0q(x0)ϵtN(0,I),可以利用 (4) 得到 xt。模型 ϵθ(t)(xt) 在不知道 x0 的情况下,尝试从 xt 预测 ϵt。通过改写 (4),可以进而预测去噪后的观测值,即给定 xt 时对 x0 的预测:

(9)x^0=fθ(t)(xt)=xt1αtϵθ(t)(xt)αt

然后,我们可以用固定的先验 pθ(xT)=N(0,I) 和以下定义生成过程:

(10)pθ(t)(xt1xt)={N(fθ(1)(x1),σ12I) if t=1qσ(xt1xt,fθ(t)(xt)) otherwise 

公式 (10)

  • t>1 时:用预测的 x^0 代入 (7),计算 xt1 的分布
  • t=1 时:直接输出最终图像,加一点噪声保证分布性质

其中 qσ(xt1xt,fθ(t)(xt)) 按照 (7) 定义,只是将 x0 替换为 fθ(t)(xt)。对于 t=1 的情况,我们添加了一些高斯噪声(协方差为 σ12I)以确保生成过程处处有定义。

我们通过以下变分推断目标(这是关于 ϵθ 的一个泛函)来优化 θ

(11)Jσ(ϵθ)=Ex0:Tqσ(x0:T)[logqσ(x1:Tx0)logpθ(x0:T)]=Ex0:Tqσ(x0:T)[logqσ(xTx0)+t=2Tlogqσ(xt1xt,x0)t=1Tlogpθ(t)(xt1xt)logpθ(xT)]

这里我们根据 (6) 分解 qσ(x1:Tx0),并根据 (1) 分解 pθ(x0:T)

Jσ 的定义来看,似乎对于 σ 的每一种选择都必须训练一个不同的模型,因为它对应着不同的变分目标(和不同的生成过程)。然而,对于某些权重 γJσ 等价于 Lγ

定理1.对于所有 σ>0,存在 γR>0TCR,使得 Jσ=Lγ+C

变分目标 Lγ 的特殊之处在于:如果模型 ϵθ(t) 的参数 θ 在不同的 t 之间不共享,那么 ϵθ 的最优解将不依赖于权重 γ(因为全局最优可以通过分别最大化求和式中的每一项来实现)。Lγ 的这一性质有两层含义,一方面,这证明了在 DDPM 中使用 L1 作为变分下界的代理目标函数是合理的;另一方面,由于 Jσ 等价于定理1的某个 LγJσ 的最优解也与 L1 相同。因此,如果在模型 ϵθ 中参数不在不同的 t 之间共享,那么 DDPM 使用的 L1 目标也可以作为变分目标 Jσ 的代理目标。

定理1

用 DDPM 目标训练出的模型 ϵθ,就是 Jσ 的最优解,所以无需重新训练

从广义生成过程中采样

去噪扩散隐式模型

根据 (10) 中的 pθ(x1:T),可以从样本 xt 生成样本 xt1

(12)xt1=αt1x^0朝向原始信号的部分 +1αt1σt2ϵθ(t)(xt)朝向当前噪声 xt 的部分 +σtϵt随机噪声 

其中 ϵtN(0,I) 是与 xt 独立的标准高斯噪声,我们定义 α0=1。不同的 σ 值选择会导致不同的生成过程,但都是用相同的模型 ϵθ,因此无需重新训练模型。当对所有 tσt=(1αt1)/(1αt)1αt/αt1 时,前向过程变为马尔可夫的,生成过程则变为 DDPM。

DDPM

σt2=β~t=1αt11αtβtxt1=αt1x^0+1αt1β~tϵθ(xt,t)+β~tϵt

我们注意到另一个特殊情况,即当对所有 tσt=0 时:除了 t=1 的情况外,给定 xt1x0 时,前向过程变为确定性的;在生成过程中,随机噪声 ϵt 前的系数变为零。由此产生的模型成为一个隐式概率模型,其中样本通过固定过程(从 xTx0)从隐变量生成。我们将其命名为去噪扩散隐式模型(DDIM,发音为 /dɪm/),因为它是一个使用 DDPM 目标训练的隐式概率模型(尽管前向过程不再是扩散过程)。

DDIM

xt1=αt1x^0+1αt1ϵθ(xt,t)

TIP

尽管这种情况未被定理1所涵盖,但我们总可以通过使 σt 非常小来近似它。

加速生成过程

与神经 ODE 的相关性

此外,我们可以根据 (12) 重写 DDIM 的迭代式,它与求解常微分方程 (ODEs) 的欧拉积分的相似性变得更加明显:

(13)xtΔtαtΔt=xtαt+(1αtΔtαtΔt1αtαt)ϵθ(t)(xt)