A heads-up: this article is fairly demanding and involves a lot of mathematical derivation. If you would rather not deal with that much math, feel free to skip it.
Before reading this article, you should know: what a Markov chain is, what maximum likelihood estimation is, what KL divergence is, the KL divergence between two normal distributions, and what Bayes' theorem is.
The following content mainly draws on the blog post What are Diffusion Models? and the course by Hung-yi Lee (李宏毅).
- 1. Markov Chains and \(p_\theta(\mathbf{x})\)
- 2. Maximum Likelihood Estimation
- 3. \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) in \(L_{t-1}\)
- 4. Minimizing \(L_{t-1}\)
- 5. Summary
1. Markov Chains and \(p_\theta(\mathbf{x})\)
Conclusions derived in this section:
- \(q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)\)
- \(p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)\)
In diffusion models, to simplify the computation, we assume that the images \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) of the forward process form a Markov chain, and we write the probability distribution of an image \(\mathbf{x}\) in the forward process as \(q(\mathbf{x})\).
Therefore, we have
\[q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
\]
Meanwhile, we let \(p_\theta(\mathbf{x})\) denote the probability that the model generates the image \(\mathbf{x}\) in the reverse process.
Thus, when applying maximum likelihood estimation to a diffusion model, the samples are the noise-free images \(\mathbf{x}_0\), and the likelihood \(p_\theta(\mathbf{x}_0)\) is the probability that the model ultimately generates \(\mathbf{x}_0\). Naturally, the goal of maximum likelihood estimation is to find the model that maximizes \(p_\theta(\mathbf{x}_0)\).
Note that in the reverse process, \(\mathbf{x}_T\) is a noise image sampled directly from the standard normal distribution; it is not generated by the model, so \(p_\theta(\mathbf{x}_T)\) does not depend on the choice of model and can simply be written as \(p(\mathbf{x}_T)\).
Since \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) form a Markov chain, we have
\[p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)
\]
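To make the factorization concrete, here is a minimal NumPy sketch of the forward Markov chain (a toy example under assumed names: `betas`, `x0`, and `forward_trajectory` are mine, and the linear schedule is only an illustration). Each step samples \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\, \mathbf{x}_{t-1}, \beta_t \mathbf{I})\), the transition introduced in the previous article, so the joint density of the trajectory factorizes exactly as above.

```python
import numpy as np

def forward_trajectory(x0, betas, rng):
    """Sample x_1, ..., x_T of the forward Markov chain.

    Each step draws x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I), so the
    joint q(x_{1:T} | x_0) factorizes into the product of these transitions.
    """
    xs = [x0]
    for beta_t in betas:
        noise = rng.standard_normal(xs[-1].shape)
        xs.append(np.sqrt(1.0 - beta_t) * xs[-1] + np.sqrt(beta_t) * noise)
    return xs  # [x_0, x_1, ..., x_T]

rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)  # an illustrative linear schedule
x0 = rng.standard_normal(64)           # a toy "image" vector
trajectory = forward_trajectory(x0, betas, rng)
print(trajectory[-1].std())            # x_T should look like standard normal noise
```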
2. Maximum Likelihood Estimation
Conclusion derived in this section: \(\min -\log{p_\theta(\mathbf{x}_0)}\) is equivalent to \(\min L_T+L_{T-1}+\cdots+L_0\), where
\[\begin{aligned}
L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right) \\
L_{t-1} & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right) \quad \text { for } 2 \leq t \leq T \\
L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)
\end{aligned}
\]
As noted above, the goal of maximum likelihood estimation is \(\max{p_\theta(\mathbf{x}_0)}\); for convenience, we can convert this into the equivalent goal \(\min -\log{p_\theta(\mathbf{x}_0)}\).
Rearranging \(-\log{p_\theta(\mathbf{x}_0)}\), and using the fact that a KL divergence is always non-negative, we get
\[\begin{aligned}
-\log p_\theta\left(\mathbf{x}_0\right) &\le -\log p_\theta\left(\mathbf{x}_0\right)+D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)\right)\\
& =-\log p_\theta\left(\mathbf{x}_0\right)+\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}+\log p_\theta\left(\mathbf{x}_0\right)\right] \\
& =\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right]\\
\end{aligned}
\]
The second equality uses \(p_\theta\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)=p_\theta\left(\mathbf{x}_{0: T}\right) / p_\theta\left(\mathbf{x}_0\right)\). That is,
\[-\log p_\theta\left(\mathbf{x}_0\right) \le\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \tag{1}
\]
Here \(D_{\mathrm{KL}}(q\|p_\theta)\) denotes the KL divergence between the distributions \(q\) and \(p_\theta\), and the expectation is defined by \(\mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}(f) = \int q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \times f \ \mathrm{d} \mathbf{x}_{1:T}\).
Next, we take the expectation of both sides of Equation (1) with respect to \(q(\mathbf{x}_0)\):
\[\begin{aligned}
\int -\log p_\theta\left(\mathbf{x}_0\right) \cdot q\left(\mathbf{x}_0\right) \mathrm{d} \mathbf{x}_0 &\le \int \mathbb{E}_{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \cdot q(\mathbf{x}_0)\mathrm{d} \mathbf{x}_0 \\
-\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right)&\le \iint \left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \cdot q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right) \cdot q(\mathbf{x}_0)\mathrm{d} \mathbf{x}_{1:T}\mathrm{d}\mathbf{x}_0\\
&=\int \left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] q(\mathbf{x}_{0: T}) \mathrm{d}\mathbf{x}_{0:T}\\
&=
\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right]
\end{aligned}
\]
that is,
\[-\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right) \le \mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \tag{2}
\]
For brevity, we write \(-\mathbb{E}_{q\left(\mathbf{x}_0\right)} \log p_\theta\left({\mathbf{x}}_0\right)\) as \(L_{\mathrm{CE}}\), and \(\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right]\) as \(L_{\mathrm{VLB}}\).
Minimizing \(-\log p_\theta(\mathbf{x}_0)\) is equivalent to minimizing \(L_{\mathrm{CE}}\), and since \(L_{\mathrm{VLB}}\) is an upper bound on \(L_{\mathrm{CE}}\), minimizing \(L_{\mathrm{VLB}}\) also drives \(L_{\mathrm{CE}}\) down.
The problem of minimizing \(-\log p_\theta(\mathbf{x}_0)\) is thus converted into the problem of minimizing \(L_{\mathrm{VLB}}\).
Next, we rewrite \(L_{\mathrm{VLB}}\):
\[\begin{aligned}
& L_{\mathrm{VLB}}=\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right] \\
& =\mathbb{E}_q\left[\log \frac{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right] \\
& =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)+\sum_{t=1}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right] \\
& =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\
& =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)} \cdot \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}\right)+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\
& =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\
& =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}+\log \frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}+\log \frac{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}\right] \\
& =\mathbb{E}_q\left[\log \frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{p\left(\mathbf{x}_T\right)}+\sum_{t=2}^T \log \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
& =\mathbb{E}_q\left[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t=2}^T \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}\right]
\end{aligned}
\]
In the fourth line above, \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\) is rewritten via Bayes' rule; Section 3 spells this step out. Since minimizing every term inside an expectation also minimizes the expectation itself, we have
\[\min L_{\mathrm{VLB}} \rightarrow \min L_T+L_{T-1}+\cdots+L_0 \tag{3}
\]
where
\[\begin{aligned}
L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right) \\
L_{t-1} & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right) \quad \text { for } 2 \leq t \leq T \\
L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)
\end{aligned}
\]
Note that in \(L_T\), neither \(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)\) nor \(p\left(\mathbf{x}_T\right)\) depends on the parameters \(\theta\), so \(L_T\) can be treated as a constant; we only need to minimize \(L_{t-1}\) and \(L_0\).
Moreover, since \(q(\mathbf{x}_0\mid\mathbf{x}_1,\mathbf{x}_0)\) is a point mass concentrated at \(\mathbf{x}_0\), we have \(\min D_{\mathrm{KL}}(q(\mathbf{x}_0|\mathbf{x}_1,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_0|\mathbf{x}_1)) \rightarrow \min -\log p_\theta(\mathbf{x}_0\mid \mathbf{x}_1)\), so \(L_0\) can be written in the same form as \(L_{t-1}\). It therefore suffices to minimize \(L_{t-1}\).
3. \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) in \(L_{t-1}\)
Conclusion derived in this section: \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}_t,\tilde{\beta}_t \mathbf{I}\right)\), where \(\tilde{\boldsymbol{\mu}}_t=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right)\) and \(\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t\).
Using Bayes' theorem, we can rewrite \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) as
\[q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\]
And since \(\mathbf{x}_0,\mathbf{x}_1,\cdots,\mathbf{x}_T\) form a Markov chain, \(q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) = q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)\).
As mentioned at the end of the previous article, in the forward process \(q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right)\) and \(q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0,\left(1-\bar{\alpha}_t\right) \mathbf{I}\right)\).
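Both distributions can be sanity-checked numerically. The sketch below (illustrative code of mine, not from the referenced articles) pushes a scalar \(\mathbf{x}_0\) through the chain step by step and also samples \(\mathbf{x}_T\) in one shot from \(q(\mathbf{x}_T \mid \mathbf{x}_0)\); the empirical means and standard deviations should agree.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 200
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

x0, n = 2.0, 100_000  # a scalar "pixel", many independent runs

# Many-step simulation: iterate q(x_t | x_{t-1}) up to t = T.
x = np.full(n, x0)
for beta_t in betas:
    x = np.sqrt(1.0 - beta_t) * x + np.sqrt(beta_t) * rng.standard_normal(n)

# One-shot sampling from the closed form q(x_T | x_0).
x_direct = np.sqrt(alpha_bar[-1]) * x0 \
         + np.sqrt(1.0 - alpha_bar[-1]) * rng.standard_normal(n)

print(x.mean(), x_direct.mean())  # both close to sqrt(alpha_bar_T) * x0
print(x.std(), x_direct.std())    # both close to sqrt(1 - alpha_bar_T)
```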
Combining these, it follows that
\[\begin{aligned}
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) & =q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \\
& \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{\beta_t}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\
& =\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_t^2-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\
& =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \mathbf{x}_{t-1}+C\left(\mathbf{x}_t, \mathbf{x}_0\right)\right)\right)
\end{aligned}
\]
where \(C\left(\mathbf{x}_t, \mathbf{x}_0\right)\) collects the terms that do not contain \(\mathbf{x}_{t-1}\) and can therefore be ignored.
Matching this exponent against the density of a normal distribution lets us read off the variance and mean. Note that \(\alpha_t=1-\beta_t\) and \(\bar{\alpha}_t=\prod_{i=1}^t \alpha_i\), so \(\alpha_t-\bar{\alpha}_t+\beta_t=1-\bar{\alpha}_t\). This gives
\[\tilde{\beta}_t=1 /\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=1 /\left(\frac{\alpha_t-\bar{\alpha}_t+\beta_t}{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}\right)=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t
\]
\[\begin{aligned}
\tilde{\boldsymbol{\mu}}_t& =\left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) /\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \\
& =\left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \\
& =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0
\end{aligned}
\]
In the previous article, we obtained \(\mathbf{x}_t=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}\), and therefore
\[\mathbf{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol\epsilon\right)
\]
where \(\boldsymbol\epsilon\) denotes the combined noise added in going from \(\mathbf{x}_0\) to \(\mathbf{x}_t\).
Substituting this expression for \(\mathbf{x}_0\) into \(\tilde{\boldsymbol{\mu}}_t\), we get
\[\begin{aligned}
\tilde{\boldsymbol{\mu}}_t & =\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}\right) \\
& =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right)
\end{aligned}
\]
Therefore, we have
\[q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right),\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \mathbf{I}\right)\tag{4}
\]
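In code, the parameters of Equation (4) take only a few lines. Here is a hedged sketch (the helper name `q_posterior` and the 1-indexed timestep convention are mine) that computes \(\tilde{\boldsymbol{\mu}}_t\) in its \((\mathbf{x}_t, \mathbf{x}_0)\) form, along with \(\tilde{\beta}_t\):

```python
import numpy as np

def q_posterior(x0, xt, t, betas):
    """Mean and variance of q(x_{t-1} | x_t, x_0), with t a 1-indexed step."""
    alpha_bar = np.cumprod(1.0 - betas)
    a_t, ab_t = 1.0 - betas[t - 1], alpha_bar[t - 1]
    ab_prev = alpha_bar[t - 2] if t > 1 else 1.0  # convention: alpha_bar_0 = 1
    # mu_tilde_t, as a combination of x_t and x_0 (before substituting epsilon)
    mean = (np.sqrt(a_t) * (1 - ab_prev) / (1 - ab_t)) * xt \
         + (np.sqrt(ab_prev) * betas[t - 1] / (1 - ab_t)) * x0
    var = (1 - ab_prev) / (1 - ab_t) * betas[t - 1]  # beta_tilde_t
    return mean, var
```

As a sanity check, at \(t=1\) this returns mean \(\mathbf{x}_0\) and variance \(0\), consistent with \(q(\mathbf{x}_0 \mid \mathbf{x}_1, \mathbf{x}_0)\) being a point mass at \(\mathbf{x}_0\).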
4. Minimizing \(L_{t-1}\)
Conclusion derived in this section: minimizing \(L_{t-1}\) is equivalent to minimizing \(\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2\).
Here \(\boldsymbol\epsilon\) denotes the combined noise added in going from \(\mathbf{x}_0\) to \(\mathbf{x}_t\), and \(\boldsymbol{\epsilon}_\theta\) is the model that predicts this noise; it takes two inputs, the image \(\mathbf{x}_t\) at time \(t\) and the time \(t\) itself.
In the previous section, we showed that \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\) is a normal distribution.
Our aim is to make the reverse process match the forward process as closely as possible, so it is reasonable to assume that the reverse-process distribution \(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\) is also normal and approximates \(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)\).
Let \(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right)\). Since \(\tilde{\beta}_t\) is a constant, we simply set \(\mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)= \tilde{\beta}_t \mathbf{I}\). What remains is to make \(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\) as close as possible to \(\tilde{\boldsymbol{\mu}}_t\).
Notice that the only quantity in \(\tilde{\boldsymbol{\mu}}_t\) that is unknown during the reverse process is \(\boldsymbol\epsilon\), the combined noise added from \(\mathbf{x}_0\) to \(\mathbf{x}_t\); we can therefore train a model to predict \(\boldsymbol\epsilon\).
This model is exactly the Noise Predicter introduced in the first article. We denote it by \(\boldsymbol{\epsilon}_\theta\); it takes two inputs, the image \(\mathbf{x}_t\) at time \(t\) and the time \(t\) itself, and we write its prediction as \(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\).
Therefore,
\[\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)
\]
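With \(\boldsymbol{\mu}_\theta\) in this form, one reverse-process step can be sketched as follows (an assumption-laden sketch: `eps_model` stands for any callable implementing \(\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\), and skipping the noise at the final step is a common sampling convention, not something derived above):

```python
import numpy as np

def reverse_step(eps_model, xt, t, betas, rng):
    """Sample x_{t-1} ~ p_theta(x_{t-1} | x_t), using the posterior variance."""
    alpha_bar = np.cumprod(1.0 - betas)
    a_t, ab_t = 1.0 - betas[t - 1], alpha_bar[t - 1]
    ab_prev = alpha_bar[t - 2] if t > 1 else 1.0

    eps_hat = eps_model(xt, t)  # predicted noise epsilon_theta(x_t, t)
    mean = (xt - (1 - a_t) / np.sqrt(1 - ab_t) * eps_hat) / np.sqrt(a_t)
    if t == 1:
        return mean  # final step: return the mean without adding noise
    var = (1 - ab_prev) / (1 - ab_t) * betas[t - 1]  # Sigma_theta = beta_tilde_t * I
    return mean + np.sqrt(var) * rng.standard_normal(xt.shape)
```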
For the KL divergence, we have the following property: if \(P\) and \(Q\) are two normal distributions with means \(\mu_1\), \(\mu_2\) and variances \(\sigma_1^2\), \(\sigma_2^2\), where \(\sigma_1^2\) and \(\sigma_2^2\) are both constants, then
\[\min D_{\mathrm{KL}}(P\|Q) \rightarrow \min \|\mu_1-\mu_2\|^2
\]
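This property follows from the closed-form KL divergence between two \(d\)-dimensional isotropic normal distributions:
\[D_{\mathrm{KL}}\left(\mathcal{N}(\boldsymbol{\mu}_1, \sigma_1^2 \mathbf{I}) \,\|\, \mathcal{N}(\boldsymbol{\mu}_2, \sigma_2^2 \mathbf{I})\right)
= \frac{d}{2}\log\frac{\sigma_2^2}{\sigma_1^2} + \frac{d\,\sigma_1^2 + \left\|\boldsymbol{\mu}_1-\boldsymbol{\mu}_2\right\|^2}{2\sigma_2^2} - \frac{d}{2}
\]
When \(\sigma_1^2\) and \(\sigma_2^2\) are constants, every term except \(\frac{\left\|\boldsymbol{\mu}_1-\boldsymbol{\mu}_2\right\|^2}{2\sigma_2^2}\) is fixed, so minimizing the KL divergence is indeed equivalent to minimizing \(\left\|\boldsymbol{\mu}_1-\boldsymbol{\mu}_2\right\|^2\).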
Therefore,
\[\begin{aligned}
\min L_{t-1} &\rightarrow \min||\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)||^2\\
&\rightarrow \min\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2
\end{aligned} \tag{5}
\]
5. Summary
With this, we have completed the derivation of the loss function via maximum likelihood estimation.
Our conclusion is:
\(\min -\log{p_\theta(\mathbf{x}_0)}\) is equivalent to \(\min L_T+L_{T-1}+\cdots+L_0\), where
\[\begin{aligned}
L_T & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right) \\
L_{t-1} & =D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right) \quad \text { for } 2 \leq t \leq T \\
L_0 & =-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)
\end{aligned}
\]
Here \(L_T\) can be treated as a constant, \(L_0\) can be written in the form of \(L_{t-1}\), and minimizing \(L_{t-1}\) amounts to minimizing \(\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \|^2\).
In other words, writing \(\boldsymbol\epsilon\) for the noise used to produce each \(\mathbf{x}_t\), our objective is to minimize
\[\sum_{t=1}^T \left\| \boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t,t\right) \right\|^2
\]
We therefore conclude: the loss function of the diffusion model is simply the mean squared error loss.
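To close the loop, here is a minimal sketch of the resulting training step (PyTorch-style, with names of my own choosing; `eps_model` is any network taking \((\mathbf{x}_t, t)\)). Instead of summing over every \(t\), it draws a timestep uniformly at random, a standard Monte Carlo treatment of the sum above:

```python
import torch

def ddpm_loss(eps_model, x0, alpha_bar):
    """One stochastic estimate of the noise-prediction MSE loss.

    alpha_bar: 1-D tensor of cumulative products of (1 - beta_t), length T,
    assumed to live on the same device as x0.
    """
    b = x0.shape[0]
    t = torch.randint(1, len(alpha_bar) + 1, (b,), device=x0.device)  # t in {1..T}
    ab_t = alpha_bar[t - 1].view(b, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)                                 # the noise added to x_0
    xt = torch.sqrt(ab_t) * x0 + torch.sqrt(1.0 - ab_t) * eps  # closed-form x_t
    return torch.mean((eps - eps_model(xt, t)) ** 2)           # ||eps - eps_theta(x_t, t)||^2
```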