Black-box Variational Inference with TFP

This is a quick post demonstrating black box variational inference with TensorFlow Probability (using TF2).

First, import all the modules we need:
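A typical set of imports for this kind of example (numpy and matplotlib are only used for plotting later; the tfd and tfb aliases are just convenient shorthand):

```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions  # shorthand for the distributions module
tfb = tfp.bijectors      # shorthand for the bijectors module
```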

The challenge we are trying to address is that we have a posterior, \(\text{p}(x)\), that we can't easily integrate and that might be very complicated, and we want to approximate it with a far more tractable distribution, \(\text{q}(x)\). Amazingly, for black box VI the only thing that we need to be able to do with this tricky posterior is evaluate \(\log\; \text{p}(x)\). In this case we use the sum of two Gaussian distributions.

The distribution and its log-density are defined in the code below. We also choose the discrepancy function, i.e. the divergence we want to optimise; typically one uses the KL divergence, and by choosing the order of p and q in the divergence you can select whether you're more concerned with the approximation, \(\text{q}(x)\), putting weight on low-probability locations in \(\text{p}(x)\), or with it missing places where there is density in \(\text{p}(x)\). We also need to configure the optimiser and pick some initial values for the approximating distribution's parameters.

The VI algorithm is an optimisation algorithm that performs gradient descent on a proxy to the KL divergence. Black box variational inference approximates this gradient by sampling lots of times from the approximating distribution \(q\), specifically (from equation 3 in the BBVI paper):

\[\nabla_\lambda L \approx \frac{1}{S} \sum_{s=1}^S \nabla_\lambda \log q(z_s|\lambda) \Big(\log p(x,z_s) - \log q(z_s|\lambda)\Big).\]

Luckily we don't need to worry about the details of doing this: we just iterate over a simple training loop, sketched below; we then go through it step by step.
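Here is a minimal sketch of these steps (not the post's original listing), continuing from the imports above and assuming a two-dimensional target built from two Gaussian densities. The names logpdist, discrepancy_fn, surrogate posterior and sample_size=1000 follow the text; the component means and scales, the variational parameter names (loc, scale_raw, fill_tril), the Adam settings and the number of steps are illustrative assumptions.

```python
# Target: the log of the (unnormalised) sum of two Gaussian densities.
# The component means and scales are made up for illustration.
p1 = tfd.MultivariateNormalDiag(loc=[-1., -1.], scale_diag=[1., 1.])
p2 = tfd.MultivariateNormalDiag(loc=[1., 1.], scale_diag=[1., 1.])

def logpdist(x):
    return tf.reduce_logsumexp(
        tf.stack([p1.log_prob(x), p2.log_prob(x)], axis=-1), axis=-1)

# Discrepancy function: reverse KL, i.e. KL[q || p] (q avoids low-probability
# regions of p).  Swap in tfp.vi.kl_forward for the other ordering.
discrepancy_fn = tfp.vi.kl_reverse

# Optimiser and initial values of the variational parameters: an unconstrained
# mean and the three unconstrained entries of a 2x2 lower-triangular scale.
optimizer = tf.optimizers.Adam(learning_rate=0.05)
loc = tf.Variable([0., 0.])
scale_raw = tf.Variable([0., 0., 0.])
fill_tril = tfb.FillScaleTriL()  # unconstrained vector -> valid scale_tril

# The loop we iterate over.
for step in range(500):
    with tf.GradientTape() as tape:
        q = tfd.MultivariateNormalTriL(loc=loc, scale_tril=fill_tril(scale_raw))
        loss = tfp.vi.monte_carlo_variational_loss(
            logpdist, surrogate_posterior=q,
            sample_size=1000, discrepancy_fn=discrepancy_fn)
    grads = tape.gradient(loss, [loc, scale_raw])
    optimizer.apply_gradients(zip(grads, [loc, scale_raw]))
```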
The tf.GradientTape context is telling TensorFlow that we will want to evaluate the gradient of the computations that follow (note we can write everything relatively easily and rely on eager execution).
Here we're specifying \(q\) - it's a single multivariate Gaussian, with a fully flexible covariance matrix (one could instead specify that the covariance is diagonal).
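In the sketch above this is the line that builds the surrogate; with the assumed parameter names it looks like the following, with a diagonal-covariance alternative commented alongside.

```python
# Fully flexible covariance: a learnable lower-triangular scale factor.
q = tfd.MultivariateNormalTriL(loc=loc, scale_tril=fill_tril(scale_raw))

# Diagonal-covariance alternative (scale_diag_raw would be another
# unconstrained tf.Variable of length 2):
# q = tfd.MultivariateNormalDiag(loc=loc, scale_diag=tf.nn.softplus(scale_diag_raw))
```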
This is the variational line; here we're using TFP's VI library to specify the following (the corresponding call is sketched after the list):

  • logpdist - the function that gives us \(\log p(x)\).
  • surrogate_posterior - the function that gives us \(q(x)\).
  • sample_size=1000 - the number of samples we are using to estimate the gradient.
  • discrepancy_fn=discrepancy_fn - the discrepancy function we discussed earlier (e.g. forward or backward).
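With the assumed names from the sketch above, the call looks like this (tfp.vi.kl_reverse and tfp.vi.kl_forward give the two orderings of the KL divergence):

```python
loss = tfp.vi.monte_carlo_variational_loss(
    logpdist,                          # evaluates log p(x) (up to a constant)
    surrogate_posterior=q,             # the approximating distribution q
    sample_size=1000,                  # Monte Carlo samples for the estimate
    discrepancy_fn=tfp.vi.kl_reverse)  # or tfp.vi.kl_forward
```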
Here we use the tape.gradient(z,x) method, which gives the derivative of z with respect to tensor x, and pass these gradients to the optimiser.
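In the sketch above, that step is:

```python
grads = tape.gradient(loss, [loc, scale_raw])            # d(loss) / d(parameters)
optimizer.apply_gradients(zip(grads, [loc, scale_raw]))  # one optimisation step
```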
The result is a Gaussian with appropriately chosen mean and covariance that roughly describes the true posterior:
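As a quick numerical check (reusing the illustrative names from the sketch above), one can rebuild q from the fitted parameters and inspect its mean and covariance:

```python
# Rebuild q from the fitted parameters and inspect it.
q = tfd.MultivariateNormalTriL(loc=loc, scale_tril=fill_tril(scale_raw))
print(q.mean().numpy())        # fitted mean
print(q.covariance().numpy())  # fitted covariance matrix
```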

Full code block

Here is the full block of code; you'll need to have TensorFlow 2 and TensorFlow Probability installed, etc.
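The following is a minimal end-to-end sketch consistent with the snippets above, not the original listing; everything not named in the post (the target's component means and scales, loc/scale_raw/fill_tril, the Adam settings, the number of steps, and the contour-plot comparison at the end) is an illustrative assumption.

```python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Target: the log of the (unnormalised) sum of two Gaussian densities.
p1 = tfd.MultivariateNormalDiag(loc=[-1., -1.], scale_diag=[1., 1.])
p2 = tfd.MultivariateNormalDiag(loc=[1., 1.], scale_diag=[1., 1.])

def logpdist(x):
    return tf.reduce_logsumexp(
        tf.stack([p1.log_prob(x), p2.log_prob(x)], axis=-1), axis=-1)

# Divergence choice: reverse KL (KL[q || p]); use tfp.vi.kl_forward for the
# other ordering.
discrepancy_fn = tfp.vi.kl_reverse

# Optimiser and initial variational parameters.
optimizer = tf.optimizers.Adam(learning_rate=0.05)
loc = tf.Variable([0., 0.])
scale_raw = tf.Variable([0., 0., 0.])
fill_tril = tfb.FillScaleTriL()

# Black-box VI loop.
for step in range(500):
    with tf.GradientTape() as tape:
        q = tfd.MultivariateNormalTriL(loc=loc, scale_tril=fill_tril(scale_raw))
        loss = tfp.vi.monte_carlo_variational_loss(
            logpdist, surrogate_posterior=q,
            sample_size=1000, discrepancy_fn=discrepancy_fn)
    grads = tape.gradient(loss, [loc, scale_raw])
    optimizer.apply_gradients(zip(grads, [loc, scale_raw]))

# Compare the fitted q with the true (unnormalised) density on a grid.
xs = np.linspace(-4., 4., 200).astype(np.float32)
X, Y = np.meshgrid(xs, xs)
grid = np.stack([X.ravel(), Y.ravel()], axis=-1)

q = tfd.MultivariateNormalTriL(loc=loc, scale_tril=fill_tril(scale_raw))
plt.contour(X, Y, np.exp(logpdist(grid).numpy()).reshape(X.shape))
plt.contour(X, Y, np.exp(q.log_prob(grid).numpy()).reshape(X.shape),
            linestyles='dashed')
plt.show()
```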