Revisiting the Local Reparameterization Trick
Brief introduction
The paper can be found in [1]. It proposes a local reparameterization technique that reduces the variance of stochastic gradients for variational Bayesian inference (SGVB) of a posterior over model parameters, while retaining parallelizability.
Notation: the dataset is $\mathcal{D} = \{\nbr{\x_i, y_i}\}_{i=1}^N$, and a single data point is $\mathcal{D}^i = \{(\x_i, y_i)\}$.
Todo:
- Introduction
- Local Reparameterization Trick
- Variance of Gradient
Variance of the SGVB estimator
Recall the Bayesian framework $P(\w\mid\mathcal{D}) = \frac{P(\mathcal{D}\mid\w) P(\w)}{P(\mathcal{D})}$ and the variational posterior $q(\w\mid\mathcal{D})$. Recall also that the loss function is the negative ELBO:
\[\begin{align} L(\theta;\mathcal{D}) &= \underbrace{KL[q(\w\mid\theta)\,\|\,P(\w)]}_{\text{complexity cost}} - \underbrace{\mathbb{E}_{q(\w\mid\mathcal{D})}[\log P(\mathcal{D}\mid\w)]}_{\text{likelihood cost}} \end{align}\]
We assume that the complexity cost can be computed analytically, while the likelihood cost has to be approximated by sampling. Hence, the variance of the estimator comes from the likelihood term, which can be written as
\[\begin{align} \mathbb{E}_{q(\w\mid\mathcal{D})}[\log P(\mathcal{D}\mid\w)] &= \mathbb{E}_{q(\w\mid\mathcal{D})}\sbr{\sum_{i=1}^N\log P(\mathcal{D}^i\mid\w)}\\ &\approx\sum_{i=1}^{N}\sbr{\frac{1}{n}\sum_{j=1}^n\log P(\mathcal{D}^i\mid\w^j)}\\ &\approx\sum_{i=1}^{N}\log P(\mathcal{D}^i\mid\w) \quad \text{for $n=1$} \end{align}\]
where the first equality assumes that the data points are conditionally independent given $\w$. We draw $n$ samples of $\w$ for each data point $\mathcal{D}^i$, so there are $N\times n$ samples of $\w$ in total. For $n = 1$, Eq. (4) is simply the sum of the log-likelihoods of the individual data points.
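The Monte Carlo approximation in Eq. (3) can be sketched in NumPy as follows. This is only an illustration under assumptions that are not in the original: a toy linear model with a Gaussian likelihood, a factorized Gaussian $q(\w)$ with parameters `mu` and `sigma`, and the hypothetical helper name `mc_expected_log_lik`.

```python
import numpy as np

def mc_expected_log_lik(x, y, mu, sigma, n_samples, rng, noise_std=1.0):
    """Estimate E_q[log P(D | w)] as in Eq. (3), with n_samples draws of w per data point.

    Assumes (for illustration only) y_i ~ N(x_i . w, noise_std^2) and q(w) = N(mu, diag(sigma^2)).
    """
    n_total, dim = x.shape
    # N x n independent weight samples via the reparameterization w = mu + sigma * eps
    eps = rng.standard_normal((n_total, n_samples, dim))
    w = mu + sigma * eps
    # Prediction of data point i under its j-th weight sample
    pred = np.einsum("id,ind->in", x, w)
    resid = y[:, None] - pred
    log_lik = -0.5 * np.log(2.0 * np.pi * noise_std**2) - 0.5 * (resid / noise_std) ** 2
    # Average over the n samples, then sum over the N data points
    return log_lik.mean(axis=1).sum()

rng = np.random.default_rng(0)
N, D = 200, 5
x = rng.standard_normal((N, D))
y = x @ np.ones(D) + 0.1 * rng.standard_normal(N)
mu, sigma = np.ones(D), 0.05 * np.ones(D)

estimate_n1 = mc_expected_log_lik(x, y, mu, sigma, n_samples=1, rng=rng)   # Eq. (4)
estimate_n10 = mc_expected_log_lik(x, y, mu, sigma, n_samples=10, rng=rng)  # Eq. (3)
```

Both calls estimate the same expectation; the second averages more weight samples per data point and is therefore less noisy, at ten times the sampling cost.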
To estimate the expected log-likelihood of the full dataset from a minibatch of size $M$, Eq. (3) is further written as
\[\mathbb{E}_{q(\w\mid\mathcal{D})}[\log P(\mathcal{D}\mid\w)]\approx L_{M} = \frac{N}{M}\sum_{i=1}^{M}\log P(\mathcal{D}^i\mid\w),\]
where we set $n = 1$. The VAE paper [2] notes that $n$ can be set to $1$ as long as the minibatch size $M$ is large enough. The variance of $L_M$ can be decomposed as
\[\begin{align} Var[L_M] &= \frac{N^2}{M^2}\nbr{\sum_{i=1}^{M}Var[L_i] + 2\sum_{i=1}^{M}\sum_{j=i+1}^{M}Cov[L_i,L_j]}\nonumber \\ &= N^2\nbr{\frac{1}{M}Var[L_i] + \frac{M-1}{M}Cov[L_i, L_j]}, \end{align}\]
where $L_M = \frac{N}{M}\sum_{i=1}^{M}L_i$ and $L_i = \log P(\mathcal{D}^i\mid\w)$. The second equality assumes that all $L_i$ have the same variance and all pairs $(L_i, L_j)$ have the same covariance. The problem is the coefficient $\frac{M-1}{M}$ in front of the covariance: unlike the coefficient $\frac{1}{M}$ of the variance term, it does not shrink as $M$ grows, so a larger minibatch cannot reduce this part of the variance. Our goal is therefore to remove the covariance, i.e., to make $Cov[L_i, L_j] = 0$.
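To see how the two terms of Eq. (5) behave as the minibatch grows, the formula can be evaluated directly. The function name and the numeric values for $Var[L_i]$ and $Cov[L_i, L_j]$ below are made up purely for illustration.

```python
def var_minibatch_estimator(n_total, m, var_i, cov_ij):
    """Evaluate Eq. (5): N^2 * (Var[L_i] / M + (M - 1) / M * Cov[L_i, L_j])."""
    return n_total**2 * (var_i / m + (m - 1) / m * cov_ij)

N = 1000
rows = []
for M in (10, 100, 1000):
    # Illustrative values: unit variance, with and without correlation between L_i and L_j
    correlated = var_minibatch_estimator(N, M, var_i=1.0, cov_ij=0.1)
    independent = var_minibatch_estimator(N, M, var_i=1.0, cov_ij=0.0)
    rows.append((M, correlated, independent))
    print(f"M={M:5d}  correlated={correlated:12.1f}  independent={independent:12.1f}")
```

With zero covariance the variance falls as $1/M$, whereas with a positive covariance it levels off near $N^2\,Cov[L_i, L_j]$ no matter how large $M$ becomes.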
Local reparameterization trick
Some notation:
- $\bm{A}: M \times 1000$ feature matrix with $M$ samples
- $\bm{W}: 1000 \times 1000$ weight matrix with elements $w_{i,j} \sim \N(\mu_{i,j}, \sigma^{2}_{i,j})$
- $\bm{B} = \bm{AW}$: $M \times 1000$ output before activation
With the reparameterization trick, $w_{i,j} = \mu_{i,j} + \sigma_{i,j}\epsilon_{i,j}$ where $\epsilon_{i,j} \sim \N(0,1)$. To make $Cov[L_i, L_j] = 0$, we can sample a separate weight $\w^1,…,\w^M$ for each observation in the minibatch: the $\w^1,…,\w^M$ are i.i.d., and the only randomness in the likelihood comes from the weight $\w$. However, this requires $M \times 1000 \times 1000$ noise samples per minibatch for a single layer alone.
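The difference between sharing one weight sample across the minibatch and drawing one per observation can be sketched as below. The layer sizes are shrunk from $1000 \times 1000$ so the example runs quickly, and all variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
M, d_in, d_out = 8, 20, 30  # small stand-ins for M x 1000 x 1000
A = rng.standard_normal((M, d_in))
mu = rng.standard_normal((d_in, d_out))
sigma = 0.1 * np.ones((d_in, d_out))

# Shared weights: one noise matrix for the whole minibatch, so the rows of B are correlated
eps_shared = rng.standard_normal((d_in, d_out))
B_shared = A @ (mu + sigma * eps_shared)

# Per-example weights: an independent noise matrix for every observation m
eps_per_example = rng.standard_normal((M, d_in, d_out))
W_per_example = mu + sigma * eps_per_example
# b_{m,j} = sum_i a_{m,i} * w^m_{i,j}
B_per_example = np.einsum("mi,mij->mj", A, W_per_example)

print("noise draws (shared):     ", eps_shared.size)
print("noise draws (per example):", eps_per_example.size)
```

Both variants produce an $M \times d_{out}$ pre-activation matrix, but the per-example variant needs $M$ times as many noise draws and materializes $M$ separate weight matrices.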
To address this issue, we can sample $\bm{B}$ directly. Since $\bm{B}=\bm{AW}$ with $\w \sim \N(\bm{\mu}, \bm{\Sigma})$, each entry of $\bm{B}$ is a linear combination of Gaussian random variables and is therefore itself Gaussian:
\[\bm{B} \sim \N(\bs{A\mu},\bs{A\Sigma A^T})\]
For each element $b_{m,j}$, the mean $\gamma_{m,j}$ and variance $\delta_{m,j}$ are:
\[\begin{align} \gamma_{m,j} &= \sum_{i=1}^{1000}a_{m,i}\mu_{i,j}\\ \delta_{m,j} &= \sum_{i=1}^{1000}a^2_{m,i}\sigma^2_{i,j} \end{align}\]
We then sample $b_{m,j} = \gamma_{m,j} + \sqrt{\delta_{m,j}}\zeta_{m,j}$, where $\zeta_{m,j} \sim \N(0,1)$. In this case, only $M \times 1000$ noise samples are needed.
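A minimal NumPy sketch of this sampling step is given below, assuming a factorized Gaussian posterior stored as a mean matrix `mu` and a variance matrix `sigma_sq`. The function name `local_reparam_sample` and the small layer sizes are illustrative choices, not taken from [1]. The second half checks empirically that explicitly sampling $\bm{W}$ gives the same elementwise mean and variance for $\bm{B}$.

```python
import numpy as np

def local_reparam_sample(A, mu, sigma_sq, rng):
    """Sample the pre-activations B directly instead of sampling the weights W."""
    gamma = A @ mu               # Eq. (6): mean of each b_{m,j}
    delta = (A**2) @ sigma_sq    # Eq. (7): variance of each b_{m,j}
    zeta = rng.standard_normal(gamma.shape)  # only M x d_out noise draws
    return gamma + np.sqrt(delta) * zeta, gamma, delta

rng = np.random.default_rng(0)
M, d_in, d_out = 4, 6, 3
A = rng.standard_normal((M, d_in))
mu = rng.standard_normal((d_in, d_out))
sigma_sq = 0.04 * np.ones((d_in, d_out))

B, gamma, delta = local_reparam_sample(A, mu, sigma_sq, rng)

# Sanity check: sample W explicitly many times and compare the moments of A @ W
eps = rng.standard_normal((20000, d_in, d_out))
W = mu + np.sqrt(sigma_sq) * eps
B_from_W = np.einsum("mi,nij->nmj", A, W)
emp_mean = B_from_W.mean(axis=0)
emp_var = B_from_W.var(axis=0)
```

Up to Monte Carlo error, `emp_mean` and `emp_var` should match `gamma` and `delta`, which is what justifies replacing the weight noise with the much smaller activation noise.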
Variance of gradient
The local reparameterization trick also leads to a lower-variance gradient estimator than sampling a separate $\w$ for each data point. The gradients with respect to the posterior variance $\sigma^2_{i,j}$ with and without the local reparameterization trick are given in Eq. (8) and Eq. (9), respectively:
\[\begin{align} \frac{\partial L_M}{\partial \sigma^2_{i,j}} &= \frac{\partial L_M}{\partial b_{m,j}}\frac{\partial b_{m,j}}{\partial \delta_{m,j}}\frac{\partial \delta_{m,j}}{\partial \sigma^2_{i,j}}\nonumber \\ &= \frac{\partial L_M}{\partial b_{m,j}}\frac{\zeta_{m,j}}{2\sqrt{\delta_{m,j}}}a^2_{m,i} \end{align}\]
\[\begin{equation} \frac{\partial L_M}{\partial \sigma^2_{i,j}} = \frac{\partial L_M}{\partial b_{m,j}}\frac{\epsilon_{i,j}}{2\sigma_{i,j}}a_{m,i} \end{equation}\]
The two sources of randomness are $b_{m,j}$ and $\zeta_{m,j}$ in Eq. (8), and $b_{m,j}$ and $\epsilon_{i,j}$ in Eq. (9). Note the different subscripts of the random variables $\zeta$ and $\epsilon$. An intuitive explanation for why Eq. (8) has lower variance than Eq. (9) is that in Eq. (9) many random variables $\epsilon$ contribute to $b_{m,j}$, namely $\epsilon_{i,j}$ for $i = 1,…,1000$, so each individual $\epsilon_{i,j}$ explains only a small part of the variation in $b_{m,j}$. In contrast, in Eq. (8) only one variable, $\zeta_{m,j}$, contributes to $b_{m,j}$. The mathematical analysis can be found in Appendix D of [1]; it decomposes the variance using the law of total variance, conditioning on $b_{m,j}$.
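The gap between the two estimators can be checked empirically on a toy problem. The sketch below assumes a single example and a single output unit with the made-up loss $L = \frac{1}{2}b^2$, so that $\partial L/\partial b = b$. This loss, the layer size, and all variable names are illustrative assumptions and are not taken from [1]. For this loss the exact gradient with respect to $\sigma^2_i$ is $a_i^2/2$, which both estimators should recover on average.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, n_draws = 20, 100_000
a = np.ones(d_in)            # one row of A (single example m)
mu = np.zeros(d_in)          # posterior means for one output unit j
sigma = 0.1 * np.ones(d_in)  # posterior standard deviations
i = 0                        # gradient is taken w.r.t. sigma_i^2

gamma = a @ mu
delta = (a**2) @ sigma**2

# Eq. (8): local reparameterization, noise zeta on the activation b
zeta = rng.standard_normal(n_draws)
b_local = gamma + np.sqrt(delta) * zeta
grad_local = b_local * zeta / (2.0 * np.sqrt(delta)) * a[i] ** 2

# Eq. (9): noise eps on every weight, a fresh weight sample per draw
eps = rng.standard_normal((n_draws, d_in))
b_weight = (mu + sigma * eps) @ a
grad_weight = b_weight * eps[:, i] / (2.0 * sigma[i]) * a[i]

print("mean:", grad_local.mean(), grad_weight.mean())
print("var: ", grad_local.var(), grad_weight.var())
```

In this setup both sample means should land close to $a_i^2/2 = 0.5$, while the sample variance of the weight-noise estimator should be noticeably larger, in line with the intuition that a single $\epsilon_{i,j}$ carries little information about $b_{m,j}$.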
References
[1] Kingma, Durk P., Tim Salimans, and Max Welling. “Variational dropout and the local reparameterization trick.” Advances in neural information processing systems 28 (2015).
[2] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).