Paper: Denoising Diffusion Probabilistic Models
Code implementation: DDPM模型——pytorch实现
Recommended video: 54、Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
Mathematical prerequisites:
Joint probability:
P(A, B, C)=P(C \mid B, A) P(B, A)=P(C \mid B, A) P(B \mid A) P(A)
Conditional probability:
P(B, C \mid A)=P(B \mid A) P(C \mid A, B)
Markov chain:
p\left(X_{t+1} \mid X_{t}, \ldots, X_{1}\right)=p\left(X_{t+1} \mid X_{t}\right)
Bayes' rule:
P\left(A_{i} \mid B\right)=\frac{P\left(B \mid A_{i}\right) P\left(A_{i}\right)}{\sum_{j} P\left(B \mid A_{j}\right) P\left(A_{j}\right)}
Normal distribution: the probability density function of X \sim N\left(\mu, \sigma^{2}\right) is
f(x)=\frac{1}{\sqrt{2 \pi} \sigma} e^{-\frac{(x-\mu)^{2}}{2 \sigma^{2}}}
Sum of two independent normal distributions X \sim N\left(\mu_{X}, \sigma_{X}^{2}\right) and Y \sim N\left(\mu_{Y}, \sigma_{Y}^{2}\right):
U=X+Y \sim N\left(\mu_{X}+\mu_{Y}, \sigma_{X}^{2}+\sigma_{Y}^{2}\right)
KL divergence (Kullback-Leibler divergence) between two normal distributions p and q:
K L(p, q)=\log \frac{\sigma_{q}}{\sigma_{p}}+\frac{\sigma_{p}^{2}+\left(\mu_{p}-\mu_{q}\right)^{2}}{2 \sigma_{q}^{2}}-\frac{1}{2}
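Not from the original post: a quick numerical sanity check of this closed-form KL against a Monte Carlo estimate in PyTorch; the distribution parameters below are arbitrary example values.

```python
import torch

def kl_normal(mu_p, sigma_p, mu_q, sigma_q):
    # Closed-form KL(p || q) for two univariate Gaussians, as in the formula above.
    return (torch.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)

# Arbitrary example parameters, for illustration only.
mu_p, sigma_p = torch.tensor(0.5), torch.tensor(1.2)
mu_q, sigma_q = torch.tensor(0.0), torch.tensor(1.0)

analytic = kl_normal(mu_p, sigma_p, mu_q, sigma_q)

# Monte Carlo estimate: E_p[log p(x) - log q(x)].
x = torch.distributions.Normal(mu_p, sigma_p).sample((100000,))
log_p = torch.distributions.Normal(mu_p, sigma_p).log_prob(x)
log_q = torch.distributions.Normal(mu_q, sigma_q).log_prob(x)
print(analytic.item(), (log_p - log_q).mean().item())  # the two values should be close
```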
Reparameterization trick:
If X \sim N\left(\mu, \sigma^{2}\right), then Y=\frac{X-\mu}{\sigma} \sim N(0,1).
Sampling z from the normal distribution X is therefore equivalent to sampling z' from the standard normal distribution Y and setting z = \mu + \sigma \times z'.
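A minimal PyTorch sketch of this trick; the example means and standard deviations are arbitrary illustrative values.

```python
import torch

mu = torch.tensor([0.0, 2.0])      # example means
sigma = torch.tensor([1.0, 0.5])   # example standard deviations

# Sample from N(mu, sigma^2) via the reparameterization trick:
# draw z' ~ N(0, I), then shift and scale. Gradients can flow through mu and sigma.
z_prime = torch.randn_like(mu)
z = mu + sigma * z_prime
```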
Completing the square for a quadratic:
a x^{2}+b x=a\left(x+\frac{b}{2 a}\right)^{2}+C
Notation:
t: time step (the number of noising steps applied so far)
T: total number of steps (total number of noising steps)
\mathbf{x}: an image
\mathbf{x}_{0}: the image at the initial step
\mathbf{x}_{t}: the image at step t
\mathbf{x}_{T}: the image at the final step
x_0 \sim q(x_0), where q(x_0) is the real image distribution
p_\theta (x_0) := \int p_\theta (x_{0:T}) \, d x_{1:T}, where p_\theta (x_0) is the distribution of generated images
\theta: the (network) parameters
\beta_{t}: the variance of the noise added at step t of the diffusion (forward) process
\beta: the sequence of noise variances, of length T, monotonically increasing within (0,1)
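As an illustration of how such a schedule and its derived quantities are typically precomputed (a sketch, not code from the linked implementation): the DDPM paper uses a linear schedule from 1e-4 to 0.02 with T = 1000, which is assumed below; \alpha_t and \bar{\alpha}_t are defined later in this post.

```python
import torch

T = 1000                                     # total number of noising steps
beta = torch.linspace(1e-4, 0.02, T)         # noise variance schedule, increasing in (0, 1)
alpha = 1.0 - beta                           # alpha_t := 1 - beta_t
alpha_bar = torch.cumprod(alpha, dim=0)      # alpha_bar_t := prod_{s<=t} alpha_s
```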
Reverse process:
Mathematical expression of the reverse diffusion process:
p_{\theta}\left(\mathbf{x}_{0: T}\right):=p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)
This is the joint distribution p_{\theta}\left(\mathbf{x}_{0: T}\right) of the images at all steps; the whole process is a Markov chain. Here,
p\left(\mathbf{x}_{T}\right)=\mathcal{N}\left(\mathbf{x}_{T} ; \mathbf{0}, \mathbf{I}\right)
i.e. p\left(\mathbf{x}_{T}\right) is a standard normal distribution, \mathbf{x}_{T} is a sample from it, and it does not depend on the network parameters.
Denoising at step t:
p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right):=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)
\mathbf{x}_{t-1} follows a normal distribution with mean \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right) and variance \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right). In the paper the authors fix the variance \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right) to \sigma_{t}^{2}=\tilde{\beta}_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \beta_{t} (experimentally, \sigma_{t}^{2}=\beta_{t} and \sigma_{t}^{2}=\tilde{\beta}_{t} give similar results), so it does not depend on the model parameters (\tilde{\beta}_{t} is derived in the calculations below).
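A minimal sketch of how this reverse chain would be sampled, assuming a hypothetical `model(x, t)` that returns the predicted mean \mu_\theta(\mathbf{x}_t, t) and using the fixed variance \tilde{\beta}_t from above (the noise-prediction parameterization the paper actually trains with appears later in this post).

```python
import torch

@torch.no_grad()
def sample(model, shape, alpha_bar, beta, T):
    """Walk the reverse Markov chain x_T -> x_0.

    `model(x, t)` is assumed to return the mean mu_theta(x_t, t);
    the variance is fixed to beta_tilde_t as in the post.
    """
    x = torch.randn(shape)                      # x_T ~ N(0, I)
    for t in reversed(range(T)):                # t = T-1, ..., 0 (0-indexed schedule)
        mu = model(x, t)
        if t > 0:
            beta_tilde = (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t]) * beta[t]
            x = mu + beta_tilde.sqrt() * torch.randn_like(x)   # reparameterized sample
        else:
            x = mu                              # no noise is added at the final step
    return x
```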
Forward process:
Mathematical expression of the diffusion (forward) process:
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right):=\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)
Given the initial image \mathbf{x}_{0}, this is the joint distribution of the images at all later steps (t>0); the whole process is again a Markov chain.
Adding noise at step t:
q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right):=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I}\right)
\mathbf{x}_{t} follows a normal distribution with mean \sqrt{1-\beta_{t}} \mathbf{x}_{t-1} and variance \beta_{t}.
Using the reparameterization trick, the image \mathbf{x}_{t} at any step can be written in terms of the initial image \mathbf{x}_{0} and the noise variance sequence \beta. To simplify the notation, define \alpha_{t}:=1-\beta_{t} and \bar{\alpha}_{t}:=\prod_{s=1}^{t} \alpha_{s}; then:
\begin{aligned}
\mathbf{x}_{t} &=\sqrt{\alpha_{t}} \mathbf{x}_{t-1}+\sqrt{1-\alpha_{t}} \boldsymbol{\epsilon}_{t-1} \\
&=\sqrt{\alpha_{t}}\left(\sqrt{\alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}\right)+\sqrt{1-\alpha_{t}} \boldsymbol{\epsilon}_{t-1} \\
&=\sqrt{\alpha_{t} \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_{t}} \sqrt{1-\alpha_{t-1}} \boldsymbol{\epsilon}_{t-2}+\sqrt{1-\alpha_{t}} \boldsymbol{\epsilon}_{t-1} \\
&=\sqrt{\alpha_{t} \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_{t}-\alpha_{t} \alpha_{t-1}+1-\alpha_{t}}\, \overline{\boldsymbol{\epsilon}}_{t-2} \\
&=\sqrt{\alpha_{t} \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t} \alpha_{t-1}}\, \overline{\boldsymbol{\epsilon}}_{t-2} \\
&=\ldots \\
&=\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}
\end{aligned}
Rearranging, \mathbf{x}_{0} can be expressed in terms of \mathbf{x}_{t} and \boldsymbol{\epsilon}:
\mathbf{x}_{0}=\frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}\right)
(The derivation uses the formula for the sum of two independent normal distributions.) Here each \boldsymbol{\epsilon}_{i} \sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right).
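A minimal PyTorch sketch of this closed-form forward sampling; the function name `q_sample` and the precomputed `alpha_bar` tensor are assumptions carried over from the schedule sketch above, not code from the linked implementation.

```python
import torch

def q_sample(x0, t, alpha_bar):
    """Sample x_t directly from x_0 using the closed form
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

    `t` is a batch of integer step indices; `alpha_bar` is the precomputed
    cumulative product from the schedule sketch above.
    """
    eps = torch.randn_like(x0)
    ab = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))   # broadcast over image dims
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    return xt, eps
```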
Loss:
Upper bound on the negative log-likelihood:
\begin{aligned}
\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right] & \leq \mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right]+D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)\right) \\
&=\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right]+\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right) / p_{\theta}\left(\mathbf{x}_{0}\right)}\right] \\
&=\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right]+\mathbb{E}_{q}\left[-\log \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}+\log p_{\theta}\left(\mathbf{x}_{0}\right)\right] \\
&=\mathbb{E}_{q}\left[-\log \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}\right]
\end{aligned}
Define the loss L:
\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right] \leq \mathbb{E}_{q}\left[-\log \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}\right]:=L
Further derivation of L:
\begin{aligned}
L &=\mathbb{E}_{q}\left[-\log \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_{0}\right)}\right] \\
&=\mathbb{E}_{q}\left[-\log p\left(\mathbf{x}_{T}\right)-\sum_{t \geq 1} \log \frac{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}\right] \\
&=\mathbb{E}_{q}\left[-\log p\left(\mathbf{x}_{T}\right)-\sum_{t>1} \log \frac{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}\right] \\
&=\mathbb{E}_{q}\left[-\log p\left(\mathbf{x}_{T}\right)-\sum_{t>1} \log \frac{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)}-\log \frac{p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}{q\left(\mathbf{x}_{1} \mid \mathbf{x}_{0}\right)}\right] \\
&=\mathbb{E}_{q}\left[-\log \frac{p\left(\mathbf{x}_{T}\right)}{q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right)}-\sum_{t>1} \log \frac{p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)}-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right] \\
&=\mathbb{E}_{q}\left[D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{T}\right)\right)+\sum_{t>1} D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)\right]
\end{aligned}
Writing this with the three terms L_{T}, L_{t-1}, and L_{0}:
L = \mathbb{E}_{q}[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{T}\right)\right)}_{L_{T}}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)}_{L_{t-1}} \underbrace{-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}]
The first term L_{T} does not depend on the network parameters \theta and can be ignored.
Analysis of the third term L_{0}:
p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)=\prod_{i=1}^{D} \int_{\delta_{-}\left(x_{0}^{i}\right)}^{\delta_{+}\left(x_{0}^{i}\right)} \mathcal{N}\left(x ; \mu_{\theta}^{i}\left(\mathbf{x}_{1}, 1\right), \sigma_{1}^{2}\right) d x
\delta_{+}(x)=\left\{\begin{array}{ll}\infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1\end{array}\right. \qquad \delta_{-}(x)=\left\{\begin{array}{ll}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{array}\right.
This is a mapping from a continuous space to a discrete one: the continuous Gaussian is turned into a discrete (binned) distribution so that it matches the discrete pixel values of the input images.
Analysis of the second term L_{t-1}:
The first distribution in the KL term is q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right). Write it as
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right)
Using Bayes' rule and completing the square, q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) can be computed:
\begin{aligned}
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) &=q\left(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}, \mathbf{x}_{0}\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)} \\
& \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_{t}-\sqrt{\alpha_{t}} \mathbf{x}_{t-1}\right)^{2}}{\beta_{t}}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{0}\right)^{2}}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_{t}-\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}\right)^{2}}{1-\bar{\alpha}_{t}}\right)\right) \\
&=\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^{2}-\left(\frac{2 \sqrt{\alpha_{t}}}{\beta_{t}} \mathbf{x}_{t}+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_{0}\right) \mathbf{x}_{t-1}+C\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)\right)\right)
\end{aligned}
This gives the variance \tilde{\beta}_{t}:
\tilde{\beta}_{t}=1 /\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{1-\bar{\alpha}_{t-1}}{\alpha_{t}+\beta_{t}-\bar{\alpha}_{t}} \cdot \beta_{t}=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}} \cdot \beta_{t}
And the mean \tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right):
\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)=\left(\frac{\sqrt{\alpha_{t}}}{\beta_{t}} \mathbf{x}_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_{0}\right) /\left(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=\frac{\sqrt{\alpha_{t}}\left(1-\bar{\alpha}_{t-1}\right)}{\alpha_{t}+\beta_{t}-\bar{\alpha}_{t}} \mathbf{x}_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\alpha_{t}+\beta_{t}-\bar{\alpha}_{t}} \mathbf{x}_{0}=\frac{\sqrt{\alpha_{t}}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_{t}} \mathbf{x}_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{1-\bar{\alpha}_{t}} \mathbf{x}_{0}
Substituting the expression of \mathbf{x}_{0} in terms of \mathbf{x}_{t} and \boldsymbol{\epsilon}:
\tilde{\boldsymbol{\mu}}_{t}=\frac{\sqrt{\alpha_{t}}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_{t}} \mathbf{x}_{t}+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{1-\bar{\alpha}_{t}} \cdot \frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}\right)
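A small sketch of the posterior quantities just derived, reusing the 0-indexed schedule tensors assumed in the earlier sketches (the function name and indexing convention are illustrative choices, not from the post).

```python
import torch

def posterior_mean_variance(x0, xt, t, beta, alpha, alpha_bar):
    """Mean mu_tilde_t and variance beta_tilde_t of q(x_{t-1} | x_t, x_0).

    `t` is a scalar index into the 0-indexed schedule tensors, assumed >= 1
    so that alpha_bar[t - 1] (i.e. alpha_bar_{t-1}) exists.
    """
    beta_tilde = (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t]) * beta[t]
    mu_tilde = (alpha[t].sqrt() * (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t]) * xt
                + alpha_bar[t - 1].sqrt() * beta[t] / (1 - alpha_bar[t]) * x0)
    return mu_tilde, beta_tilde
```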
The second distribution in the KL term is p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right), already defined in the reverse process, with the variance \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right) fixed to \sigma_{t}^{2}=\tilde{\beta}_{t}:
p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right):=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \sigma_{t}^{2}\right)
Using the KL divergence formula for two normal distributions (here with the same variance), we get:
L_{t-1}=\mathbb{E}_{q}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right]+C
Expressing \mathbf{x}_{t} in terms of \mathbf{x}_{0} and \boldsymbol{\epsilon}, with \boldsymbol{\epsilon} \sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right):
\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)=\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}
Then L_{t-1}-C can be written as:
\begin{aligned}
L_{t-1}-C &=\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), \frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}\right)\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), t\right)\right\|^{2}\right] \\
&=\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{1}{2 \sigma_{t}^{2}}\left\|\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}\right)-\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), t\right)\right\|^{2}\right]
\end{aligned}
This shows that, given \mathbf{x}_{t}, the network output \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right), t\right) should predict \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}\left(\mathbf{x}_{0}, \boldsymbol{\epsilon}\right)-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}\right).
However, the authors do not build the network this way; instead they parameterize \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right) as:
\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)=\tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \frac{1}{\sqrt{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}\right)\right)\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)
So the network takes \mathbf{x}_{t} and t as inputs and actually outputs \boldsymbol{\epsilon}_{\theta} (the predicted noise), while \mathbf{x}_{t} can in turn be expressed through \mathbf{x}_{0}. Finally, L_{t-1}-C can be written as:
\mathbb{E}_{\mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\frac{\beta_{t}^{2}}{2 \sigma_{t}^{2} \alpha_{t}\left(1-\bar{\alpha}_{t}\right)}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}, t\right)\right\|^{2}\right]
Finally, the authors drop the coefficient in the expression above to obtain the simplified loss:
L_{\text {simple }}(\theta):=\mathbb{E}_{t, \mathbf{x}_{0}, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}, t\right)\right\|^{2}\right]
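To tie the pieces together, a minimal sketch of one training step under L_simple, reusing the schedule tensors assumed in the earlier sketches; `model` stands for any noise-prediction network taking (x_t, t) (in the paper a U-Net), and the helper names are assumptions of this sketch rather than code from the linked implementation.

```python
import torch
import torch.nn.functional as F

def training_step(model, x0, alpha_bar, T):
    """One optimization step of L_simple: predict the noise that was added to x_0."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)    # uniform random step per sample
    eps = torch.randn_like(x0)                                    # epsilon ~ N(0, I)
    ab = alpha_bar.to(x0.device)[t].view(-1, *([1] * (x0.dim() - 1)))
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps                   # closed-form forward sample
    eps_pred = model(xt, t)                                       # network predicts the noise
    return F.mse_loss(eps_pred, eps)                              # || eps - eps_theta ||^2
```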