Disclaimer: This is part of my notes on AI research papers. I do this to learn and to communicate what I understand. Feel free to comment if you have any suggestions; they would be very much appreciated.
The following post is a commentary on the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
Their contributions are motivated by the following question:
How can we perform efficient approximate inference and learning with directed probabilistic models whose continuous latent variables and/or parameters have intractable posterior distributions?
They show how a reparameterization of the variational lower bound yields a simple, differentiable, unbiased estimator of the lower bound, which they call the Stochastic Gradient Variational Bayes (SGVB) estimator. It can be used for efficient approximate posterior inference and learning in directed probabilistic models with continuous latent variables.
For the case of an i.i.d. dataset with continuous latent variables per datapoint, they propose the Auto-Encoding Variational Bayes (AEVB) algorithm. It uses the SGVB estimator to optimize a recognition model that allows for efficient approximate posterior inference. When a neural network is used as the recognition model, the result is the Variational Autoencoder (VAE).
Method
Problem Statement
Given a dataset \(\mathcal{D} = \{\textbf{x}^{(i)}\}_{i=1}^N\) of $N$ i.i.d. samples of some continuous or discrete variable \(\textbf{x}\), and assuming that the data is generated by a random process, involving the continuous random variable \(\textbf{z}\), the generative process consists of two steps:
A value \(\textbf{z}^{(i)}\) is generated from a prior distribution \(p_{\theta^*}(\textbf{z})\).
A value \(\textbf{x}^{(i)}\) is generated from a conditional distribution \(p_{\theta^*}(\textbf{x}|\textbf{z})\).
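The two-step generative process above can be sketched as ancestral sampling. This is only an illustration under an assumed toy model that is not taken from the paper: a scalar latent with prior $\mathcal{N}(0, 1)$ and a conditional $\mathcal{N}(z, 1)$. The function name `sample_dataset` is hypothetical.

```python
import random

def sample_dataset(n, seed=0):
    """Draw n (z, x) pairs by ancestral sampling from a toy model.

    Assumed toy model (an illustration, not the paper's):
      z ~ N(0, 1)        prior p(z)
      x | z ~ N(z, 1)    conditional p(x | z)
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)  # step 1: sample the latent from the prior
        x = rng.gauss(z, 1.0)    # step 2: sample the observation given the latent
        pairs.append((z, x))
    return pairs

pairs = sample_dataset(5)
# In the inference problem only the x values are observed; z stays hidden.
observed = [x for _, x in pairs]
```

In this toy model the marginal of $x$ is $\mathcal{N}(0, 2)$ and the posterior is available in closed form, which is exactly what is not the case in the setting the paper targets.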
The parameters $\theta^*$ and the latent variables $\textbf{z}$ are unknown. The goal is to define an algorithm that performs approximate posterior inference of $\textbf{z}$ and learns $\theta$, given the observed dataset $\mathcal{D}$. The algorithm must work in the worst-case scenario where the posterior distribution $p_\theta(\textbf{z}|\textbf{x})$ is intractable, the integral of the marginal likelihood $p_\theta(\textbf{x})$ is intractable, and $\mathcal{D}$ is so large that sampling-based solutions are not feasible.
The authors introduce the recognition model $q_\phi(\textbf{z}|\textbf{x})$ to approximate the true intractable posterior $p_\theta(\textbf{z}|\textbf{x})$. This can be seen as a probabilistic encoder that maps a datapoint $\textbf{x}$ to a distribution over the latent space of $\textbf{z}$. Similarly, the generative model $p_\theta(\textbf{x}|\textbf{z})$ is a probabilistic decoder that maps a latent value $\textbf{z}$ to a distribution over the data space of $\textbf{x}$.
The Variational Lower Bound
Since $\mathcal{D}$ is i.i.d., the marginal log-likelihood of the data can be written as: $$ \log p_\theta(\textbf{x}^{(1)}, \dots, \textbf{x}^{(N)}) = \sum_{i=1}^N \log p_\theta(\textbf{x}^{(i)}) $$ which given the recognition model can be rewritten using: $$ \log p_\theta(\textbf{x}^{(i)}) = D_{KL}(q_\phi(\textbf{z}|\textbf{x}^{(i)}) || p_\theta(\textbf{z}|\textbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \textbf{x}^{(i)}) $$ where $D_{KL}$ is the Kullback-Leibler divergence between the recognition model and the true posterior, and $\mathcal{L}$ is the variational lower bound, defined as:
\begin{align} \log p_\theta(\textbf{x}^{(i)}) \geq \mathcal{L}(\theta, \phi; \textbf{x}^{(i)}) =& \mathbb{E}_{q_\phi(\textbf{z}|\textbf{x}^{(i)})}[\log p_\theta(\textbf{x}^{(i)},\textbf{z}) - \log q_\phi(\textbf{z}|\textbf{x}^{(i)})] \\ =& \mathbb{E}_{q_\phi(\textbf{z}|\textbf{x}^{(i)})}[\log p_\theta(\textbf{x}^{(i)}|\textbf{z})] - D_{KL}(q_\phi(\textbf{z}|\textbf{x}^{(i)}) || p_\theta(\textbf{z})) \end{align}
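In eq. (2), the first term is the expected log-likelihood of the data under the decoder, which can be seen as an expected negative reconstruction error, and the second term acts as a regularizer that keeps the approximate posterior \(q_\phi(\textbf{z}|\textbf{x})\) close to the prior $p_\theta(\textbf{z})$. We want to maximize the variational lower bound w.r.t. both the variational parameters $\phi$ and the generative parameters $\theta$. However, the gradient w.r.t. $\phi$ is problematic: the expectation is taken under $q_\phi$, which itself depends on $\phi$, and the usual (naive) Monte Carlo gradient estimator for this case exhibits very high variance.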
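The decomposition of $\log p_\theta(\textbf{x})$ into a KL term plus the lower bound can be checked numerically in a case where everything is tractable. The sketch below assumes a toy model that is not from the paper: $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. The approximate posterior is taken to be $q(z|x) = \mathcal{N}(m, v)$ with arbitrary $m$ and $v$; all function names are hypothetical.

```python
import math

def log_marginal(x):
    # log p(x) with p(x) = N(x; 0, 2) for the assumed toy model
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x * x / 4.0

def kl_gauss(m1, v1, m2, v2):
    # KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians; v1, v2 are variances
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x, m, v):
    # Eq. (2) in closed form for q(z|x) = N(m, v):
    # E_q[log p(x|z)] - KL(q(z|x) || p(z))
    expected_loglik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    return expected_loglik - kl_gauss(m, v, 0.0, 1.0)

x, m, v = 1.3, 0.2, 0.8
gap = kl_gauss(m, v, x / 2.0, 0.5)  # KL(q || true posterior)
print(log_marginal(x), elbo(x, m, v) + gap)  # the two values should agree
```

Because the KL term is non-negative, the lower bound never exceeds $\log p(x)$, and it is tight exactly when $q$ equals the true posterior.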
SGVB Estimator and AEVB Algorithm
Kingma and Welling introduce SGVB as a practical estimator of the lower bound and its derivatives w.r.t. the parameters. It relies on the reparameterization trick: the random variable $\textbf{z}\sim q_\phi(\textbf{z}|\textbf{x})$ is expressed as a deterministic, differentiable transformation $g_\phi(\textbf{x}, \epsilon)$ of an auxiliary noise variable $\epsilon \sim p(\epsilon)$ whose distribution does not depend on the parameters $\phi$ and $\theta$. The expectation is then taken over $p(\epsilon)$, so the gradient w.r.t. $\phi$ can pass through the samples. This makes it possible to form Monte Carlo estimates of the lower bound in eq. (1), which they call the SGVB estimator:
$$\begin{align} \mathcal{L}(\theta, \phi; \textbf{x}^{(i)}) =& \mathbb{E}_{q_\phi(\textbf{z}|\textbf{x}^{(i)})}[\log p_\theta(\textbf{x}^{(i)},\textbf{z}) - \log q_\phi(\textbf{z}|\textbf{x}^{(i)})] \\ =& \mathbb{E}_{p(\epsilon)}[\log p_\theta(\textbf{x}^{(i)}, g_\phi(\textbf{x}^{(i)}, \epsilon)) - \log q_\phi(g_\phi(\textbf{x}^{(i)}, \epsilon)|\textbf{x}^{(i)})] \\ \simeq& \frac{1}{L} \sum_{l=1}^L \Big[\log p_\theta(\textbf{x}^{(i)}, \textbf{z}^{(i,l)}) - \log q_\phi(\textbf{z}^{(i,l)}|\textbf{x}^{(i)})\Big] \\ =:& \mathcal{L}^{\text{SGVB}}(\theta, \phi; \textbf{x}^{(i)}) \end{align}$$
where $\textbf{z}^{(i,l)} = g_\phi(\textbf{x}^{(i)}, \epsilon^{(l)})$ and $\epsilon^{(l)} \sim p(\epsilon)$. When the KL-divergence term in eq. (2) can be computed analytically, only the expected reconstruction term needs to be estimated by sampling. This yields a second version of the SGVB estimator, which typically has lower variance: $$ \mathcal{L}^{\text{SGVB'}}(\theta, \phi; \textbf{x}^{(i)}) = \frac{1}{L} \bigg(\sum_{l=1}^L \log p_\theta(\textbf{x}^{(i)}|\textbf{z}^{(i,l)})\bigg) - D_{KL}(q_\phi(\textbf{z}|\textbf{x}^{(i)}) || p_\theta(\textbf{z})) $$ In practice, for a large dataset with $N$ datapoints, the lower bound of the full dataset can be estimated from mini-batches of $M$ datapoints by scaling the sum of the per-datapoint estimates by $N/M$. This leads to the AEVB algorithm, a stochastic optimization procedure that computes the gradients of the SGVB estimator for each mini-batch and updates the parameters $\theta$ and $\phi$ with a gradient-based optimizer such as SGD or Adagrad.
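As a rough sketch of the two estimators and the mini-batch scaling, the code below evaluates them on an assumed scalar toy model that is not from the paper: $p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, 1)$ and $q(z|x) = \mathcal{N}(\mu, \sigma^2)$, with the reparameterization $g(\epsilon) = \mu + \sigma\epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$. The names `sgvb_a`, `sgvb_b`, `minibatch_bound` and the `encode` argument are hypothetical, and only the estimates are computed here, not their gradients.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)

def log_normal(y, mean, var):
    # log N(y; mean, var) for scalars
    return -0.5 * (LOG_2PI + math.log(var) + (y - mean) ** 2 / var)

def g(mu, sigma, eps):
    # Reparameterization: z is a deterministic function of (mu, sigma) and noise eps
    return mu + sigma * eps

def sgvb_a(x, mu, sigma, eps_samples):
    # First estimator: average of log p(x, z) - log q(z | x) over the samples
    total = 0.0
    for eps in eps_samples:
        z = g(mu, sigma, eps)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        total += log_joint - log_normal(z, mu, sigma ** 2)
    return total / len(eps_samples)

def sgvb_b(x, mu, sigma, eps_samples):
    # Second estimator: sampled reconstruction term minus the analytic KL term
    recon = sum(log_normal(x, g(mu, sigma, eps), 1.0) for eps in eps_samples)
    kl = 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - math.log(sigma ** 2))
    return recon / len(eps_samples) - kl

def exact_elbo(x, mu, sigma):
    # Closed-form lower bound, available only because the toy model is Gaussian
    kl = 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - math.log(sigma ** 2))
    return -0.5 * LOG_2PI - 0.5 * ((x - mu) ** 2 + sigma ** 2) - kl

def minibatch_bound(batch, n_total, encode, rng, estimator=sgvb_b, num_samples=1):
    # Scale the mini-batch sum by N / M to estimate the full-dataset bound
    total = 0.0
    for x in batch:
        mu, sigma = encode(x)  # encode is a stand-in for the recognition model
        eps = [rng.gauss(0.0, 1.0) for _ in range(num_samples)]
        total += estimator(x, mu, sigma, eps)
    return n_total / len(batch) * total

rng = random.Random(0)
eps_samples = [rng.gauss(0.0, 1.0) for _ in range(20000)]
x, mu, sigma = 1.3, 0.2, 0.9
print(exact_elbo(x, mu, sigma), sgvb_a(x, mu, sigma, eps_samples), sgvb_b(x, mu, sigma, eps_samples))
```

In an actual AEVB implementation the same computation would be written in an automatic differentiation framework, so that the gradient of the estimate w.r.t. $\mu$ and $\sigma$ (and hence $\phi$) flows through `g`.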
VAE
As an example of their proposed method, they introduce the VAE. In this setting, the prior over the latent variables is the standard Gaussian $p_\theta(\textbf{z}) = \mathcal{N}(\textbf{z}; 0, I)$. The generative model $p_\theta(\textbf{x}|\textbf{z})$ is either a Gaussian (real-valued data) or a Bernoulli (binary data) distribution, whose distribution parameters are computed from $\textbf{z}$ using a neural network. The recognition model $q_\phi(\textbf{z}|\textbf{x})$ is also a neural network; it outputs the mean and variance of a Gaussian with diagonal covariance over $\textbf{z}$ given $\textbf{x}$. The VAE is trained using the AEVB algorithm, where the gradients of the SGVB estimator are computed using backpropagation through the recognition model and the generative model. The neural networks they use are Multilayer Perceptrons (MLPs) with one hidden layer.
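The pieces of the VAE objective described above can be sketched for a single binary datapoint as follows. This is a minimal illustration rather than the authors' implementation: the encoder output (`mu`, `logvar`) is simply given instead of being computed by an MLP, the log-variance parameterization is an assumption made for numerical convenience, and `toy_decoder` with its fixed weights is a hypothetical stand-in for the decoder network.

```python
import math
import random

def gaussian_kl_to_standard_normal(mu, logvar):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def reparameterize(mu, logvar, rng):
    # z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]

def bernoulli_log_likelihood(x, probs, tiny=1e-7):
    # log p(x | z) for binary x, given the decoder's per-dimension probabilities
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, tiny), 1.0 - tiny)  # clamp to avoid log(0)
        total += xi * math.log(p) + (1 - xi) * math.log(1.0 - p)
    return total

def toy_decoder(z):
    # Hypothetical decoder: a fixed linear map followed by a sigmoid
    weights = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]
    return [1.0 / (1.0 + math.exp(-sum(w * zj for w, zj in zip(row, z))))
            for row in weights]

def vae_lower_bound(x, mu, logvar, decoder, rng, num_samples=1):
    # Sampled reconstruction term minus the analytic KL term for one datapoint
    recon = 0.0
    for _ in range(num_samples):
        z = reparameterize(mu, logvar, rng)
        recon += bernoulli_log_likelihood(x, decoder(z))
    return recon / num_samples - gaussian_kl_to_standard_normal(mu, logvar)

rng = random.Random(0)
x = [1, 0, 1]
bound = vae_lower_bound(x, mu=[0.3, -0.2], logvar=[-1.0, -0.5],
                        decoder=toy_decoder, rng=rng)
```

Training would maximize this quantity, averaged over mini-batches, w.r.t. the encoder and decoder weights; in practice its negative is usually minimized as a loss.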
Personal Thoughts
To my understanding, the main contribution of this paper is the introduction of the SGVB estimator, which is a practical way to compute the gradients of the variational lower bound. This makes it possible to optimize the parameters of the recognition model and the generative model jointly with gradient-based optimization algorithms. This is a very important contribution to the field of variational inference, as it allows deep learning models to be used to approximate the posterior distribution of complex probabilistic models.
The VAE is a very interesting model that can be used for unsupervised learning, semi-supervised learning, and generative modeling. It is a very flexible model that can be used for a wide range of applications, such as image generation, text generation, and speech generation. The VAE is a very active area of research, and many extensions and improvements have been proposed since the publication of this paper.
Despite the complex theoretical background, the paper is very well written and easy to understand. The authors provide a lot of details and explanations that make it easy to follow the derivations and the algorithms. The paper is also very well organized, with a clear introduction, a detailed explanation of the method, and a thorough evaluation of the results.
References
[1] Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv:1312.6114.