論文原文:Auto-Encoding Variational Bayes [OpenReview (ICLR 2014) | arXiv]
本文記錄了我在學習 VAE 程序中的一些公式推導和思考,如果你希望從頭開始學習 VAE,建議先看一下蘇劍林的博客(本文末尾有鏈接),
VAE 的整體框架
VAE 認為,隨機變數 \(\boldsymbol{x} \sim p(\boldsymbol{x})\) 由兩個隨機程序得到:
- 根據先驗分布 \(p(\boldsymbol{z})\) 生成隱變數 \(\boldsymbol{z}\),
- 根據條件分布 \(p(\boldsymbol{x} | \boldsymbol{z})\) 由 \(\boldsymbol{z}\) 得到 \(\boldsymbol{x}\),
于是 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 就是我們所需要的生成模型,
一種樸素的想法是:先用亂數生成器生成隱變數 \(\boldsymbol{z}\),然后用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 從 \(\boldsymbol{z}\) 中生成出(或者說重構出) \(\boldsymbol{x}\),通過最小化重構損失來訓練模型,這個想法的問題在于:我們無法找到生成的樣本與原始樣本之間的對應關系,重構損失算不了,無法訓練,
VAE 的做法是引入后驗分布 \(p(\boldsymbol{z} | \boldsymbol{x})\),訓練程序變為:
- 采樣一批原始樣本 \(\boldsymbol{x}\),
- 用 \(p(\boldsymbol{z} | \boldsymbol{x})\) 獲得每個樣本 \(\boldsymbol{x}\) 對應的隱變數 \(\boldsymbol{z}\),
- 用 \(p(\boldsymbol{x} | \boldsymbol{z})\) 從隱變數 \(\boldsymbol{z}\) 中重構出 \(\boldsymbol{x}\),通過最小化重構損失來訓練模型,
從這個角度來看,\(p(\boldsymbol{z} | \boldsymbol{x})\) 相當于編碼器,\(p(\boldsymbol{x} | \boldsymbol{z})\) 相當于解碼器,訓練結束后只需要保留解碼器 \(p(\boldsymbol{x} | \boldsymbol{z})\) 即可,
除了重構損失以外,VAE 還有一項 KL 散度損失,希望近似的后驗分布 \(q(\boldsymbol{z} | \boldsymbol{x})\) 盡量接近先驗分布 \(p(\boldsymbol{z})\),即最小化二者的 KL 散度,
變分下界的推導
現有 \(N\) 個由分布 \(P(\boldsymbol{x}; \boldsymbol{\theta})\) 生成的樣本 \(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(N)}\),我們可以使用極大似然估計從這些樣本中估計出分布的引數 \(\boldsymbol{\theta}\),即
\[\begin{aligned} \boldsymbol{\theta} & = \operatorname*{argmax}_{\boldsymbol{\theta}} p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta}) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \ln(p(\boldsymbol{x}^{(1)}; \boldsymbol{\theta}) \cdots p(\boldsymbol{x}^{(N)}; \boldsymbol{\theta})) \\ & = \operatorname*{argmax}_{\boldsymbol{\theta}} \sum_{i=1}^n \ln p(\boldsymbol{x}^{(i)}; \boldsymbol{\theta}). \end{aligned} \]后驗分布 \(p(\boldsymbol{z} | \boldsymbol{x}) = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})}{\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}}\) 是 intractable 的,因為分母處的邊緣分布 \(p(\boldsymbol{x})\) 積不出來,具體來說,聯合分布 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{z})p(\boldsymbol{x} | \boldsymbol{z})\) 的運算式非常復雜,\(\int_{\boldsymbol{z}} p(\boldsymbol{x}, \boldsymbol{z}) \mathrm{d}\boldsymbol{z}\) 這個積分找不到決議解,
需要使用變分推斷解決后驗分布無法計算的問題,我們使用一個形式已知的分布 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\) 來近似后驗分布 \(p(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\theta})\),于是有
\[\begin{aligned} \log p(\boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot 1 \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log\frac{q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}{p(\boldsymbol{z}|\boldsymbol{x}^{(i)})} \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \log p(\boldsymbol{x}^{(i)}) \cdot \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}) \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z}|\boldsymbol{x}^{(i)})p(\boldsymbol{x}^{(i)}))] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \\ & = \mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z}|\boldsymbol{x}^{(i)})] + L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) \\ & \geq L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}). \end{aligned} \]利用 KL 散度大于等于 0 這一特性,我們得到了對數似然 \(\log p(\boldsymbol{x}^{(i)})\) 的一個下界 \(L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)})\),于是可以將最大化對數似然改為最大化這個下界,
這個下界可以進一步寫成
\[\begin{aligned} L(\boldsymbol{\theta}, \boldsymbol{\phi}; \boldsymbol{x}^{(i)}) & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{x}^{(i)}, \boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log (p(\boldsymbol{z})p(\boldsymbol{x}^{(i)}|\boldsymbol{z}))] \mathrm{d}\boldsymbol{z} \\ & = \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[-\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) + \log p(\boldsymbol{z}) + \log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})[\log q(\boldsymbol{z}|\boldsymbol{x}^{(i)}) - \log p(\boldsymbol{z})] \mathrm{d}\boldsymbol{z} + \int_{\boldsymbol{z}} q(\boldsymbol{z}|\boldsymbol{x}^{(i)})\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})] \mathrm{d}\boldsymbol{z} \\ & = -\mathrm{KL}[q(\boldsymbol{z}|\boldsymbol{x}^{(i)}), p(\boldsymbol{z})] + \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z}|\boldsymbol{x}^{(i)})}[\log p(\boldsymbol{x}^{(i)}|\boldsymbol{z})]. \\ \end{aligned} \]其中的第一項是 KL 散度損失,第二項是重構損失,
KL 散度損失
使用標準正態分布作為先驗分布,即 \(p(\boldsymbol{z}) = N(\boldsymbol{z}; \boldsymbol{0}, \boldsymbol{I})\),
使用一個由 MLP 的輸出來引數化的正態分布作為近似后驗分布,即 \(q(\boldsymbol{z}|\boldsymbol{x}^{(i)}; \boldsymbol{\phi}) = N(\boldsymbol{z}; \boldsymbol{\mu}(\boldsymbol{x}^{(i)}; \boldsymbol{\phi}), \boldsymbol{\sigma}^2(\boldsymbol{x}^{(i)}; \boldsymbol{\phi})\boldsymbol{I})\),
選擇正態分布的好處在于 KL 散度的這個積分可以寫出決議解,訓練時直接按照公式計算即可,無需通過采樣的方式來算積分,
由于我們選擇的是各分量獨立的多元正態分布,因此只需要推導一元正態分布的情形即可:
\[\begin{aligned} \mathrm{KL}[N(z; \mu, \sigma^2), N(z; 0, 1)] & = \int_z N(z; \mu, \sigma^2)\log\frac{N(z; \mu, \sigma^2)}{N(z; 0, 1)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\frac{\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{(z - \mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}}\exp\left(-\frac{z^2}{2}\right)} \mathrm{d}z \\ & = \int_z N(z; \mu, \sigma^2) \log\left(\frac{1}{\sqrt{\sigma^2}}\exp\left(\frac{1}{2}\left(-\frac{(z - \mu^2)^2}{\sigma^2} + z^2\right)\right)\right) \mathrm{d}z \\ & = \frac{1}{2}\int_z N(z; \mu, \sigma^2) \left(-\log\sigma^2 - \frac{(z - \mu)^2}{\sigma^2} + z^2\right)\mathrm{d}z \\ & = \frac{1}{2}\left(-\log\sigma^2\int_z N(z; \mu, \sigma^2) \mathrm{d}z - \frac{1}{\sigma^2}\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z + \int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\right) \\ & = \frac{1}{2}\left(-\log\sigma^2 \cdot 1 - \frac{1}{\sigma^2} \cdot \sigma^2 + \mu^2 + \sigma^2\right) \\ & = \frac{1}{2}(-\log\sigma^2 - 1 + \mu^2 + \sigma^2). \end{aligned} \]解釋一下倒數第三行的三個積分:
- \(\int_z N(z; \mu, \sigma^2) \mathrm{d}z\) 是概率密度函式的積分,也就是 1,
- \(\int_z N(z; \mu, \sigma^2)(z - \mu)^2\mathrm{d}z\) 是方差的定義,也就是 \(\sigma^2\),
- \(\int_z N(z; \mu, \sigma^2)z^2\mathrm{d}z\) 是正態分布的二階矩,結果為 \(\mu^2 + \sigma^2\),
重構損失
伯努利分布模型
當 \(\boldsymbol{x}\) 是二值向量時,可以用伯努利分布(兩點分布)來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),即認為向量 \(\boldsymbol{x}\) 的每個維度都服從對應的相互獨立的伯努利分布,使用一個 MLP 來計算各維度所對應的伯努利分布的引數,第 \(i\) 維伯努利分布的引數為 \(y_i = \boldsymbol{y}(\boldsymbol{z})_i\),于是有
\[p(\boldsymbol{x}|\boldsymbol{z}) = \prod_{i=1}^D y_i^{x_i}(1 - y_i)^{1 - x_i}, \]\[\log p(\boldsymbol{x}|\boldsymbol{z}) = \sum_{i=1}^D x_i\log y_i + (1 - x_i)\log(1 - y_i). \]其中 \(D\) 表示向量 \(\boldsymbol{x}\) 的維度,可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價于最小化交叉熵損失,
正態分布模型
當 \(\boldsymbol{x}\) 是實值向量時,可以用正態分布來建模 \(p(\boldsymbol{x}|\boldsymbol{z})\),使用一個 MLP 來計算正態分布的引數,于是有
\[\begin{aligned} p(\boldsymbol{x}|\boldsymbol{z}) & = N(\boldsymbol{x}; \boldsymbol{\mu}, \boldsymbol{\sigma}^2\boldsymbol{I}) \\ & = \prod_{i=1}^D N(x_i; \mu_i, \sigma_i^2) \\ & = \left(\prod_{i=1}^D\frac{1}{\sqrt{2\pi}\sigma_i}\right)\exp\left(\sum_{i=1}^D-\frac{(x_i - \mu_i)^2}{2\sigma_i^2}\right), \end{aligned} \]\[\log p(\boldsymbol{x}|\boldsymbol{z}) = -\frac{D}{2}\log 2\pi - \frac{1}{2}\sum_{i=1}^D\log\sigma_i^2 - \frac{1}{2}\sum_{i=1}^D\frac{(x_i - \mu_i)^2}{\sigma_i^2}. \]很多時候我們會假設 \(\sigma_i^2\) 是一個常數,于是 MLP 只需要輸出均值引數 \(\boldsymbol{\mu}\) 即可,此時有
\[\log p(\boldsymbol{x}|\boldsymbol{z}) \sim -\frac{1}{2}\sum_{i=1}^D(x_i - \mu_i)^2 = -\frac{1}{2}\|\boldsymbol{x} - \boldsymbol{\mu}(\boldsymbol{z})\|^2. \]可見此時最大化 \(\log p(\boldsymbol{x}|\boldsymbol{z})\) 等價于最小化 MSE 損失,
重引數化技巧
需要使用重引數化技巧解決采樣 \(z\) 時不可導的問題,解決的思路是先從無引數分布中采樣一個 \(\varepsilon\),再通過變換得到 \(z\),
從 \(N(\mu, \sigma^2)\) 中采樣一個 \(z\),相當于先從 \(N(0, 1)\) 中采樣一個 \(\varepsilon\),然后令 \(z = \mu + \varepsilon\cdot\sigma\),
相關知識
技巧,通過取對數把乘除變成加減:
\[\ln ab = \ln a + \ln b,\ \ln\frac{a}{b} = \ln a - \ln b. \]隨機變數的函式的期望:
\[\mathbb{E}_{x \sim P(x)} g(x) = \int_x p(x)g(x) \mathrm{d}x, \]利用此公式可以將積分改寫成期望的形式,這樣就可以用采樣的方式計算積分了(蒙特卡羅積分法),
條件概率密度的定義:
\[p_{Y|X}(y|x) = \frac{p(x, y)}{p_X(x)}, \]此處的 \(p\) 并不是概率而是概率密度函式,但是這個公式在形式上跟條件概率公式是一樣的,
參考資料
蘇劍林的 VAE 系列博客:
- 變分自編碼器(一):原來是這么一回事 - 科學空間
- 變分自編碼器(二):從貝葉斯觀點出發 - 科學空間
- 變分自編碼器(三):這樣做為什么能成? - 科學空間
15 分鐘了解變分推理:
- 【15分鐘】了解變分推理 - 嗶哩嗶哩
- 【15分鐘】了解變分自編碼器 - 嗶哩嗶哩
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/556470.html
標籤:其他
下一篇:返回列表
