Colin Carroll - The state of Bayesian workflows in JAX | PyData Vermont 2024

Learn about JAX's powerful Bayesian workflow tools, the bayeux library for PPL interoperability, and how vmap and NUTS enable efficient probabilistic programming and sampling.

Key takeaways
  • JAX provides powerful tools for Bayesian workflows through automatic differentiation, vectorization (`vmap`), and optimization capabilities

  • bayeux is a library that integrates different probabilistic programming languages (PPLs) and samplers, allowing interoperability between TensorFlow Probability, NumPyro, PyMC and other frameworks

  • The No U-Turn Sampler (NUTS) is generally recommended as the default MCMC sampler for most problems, with NumPyro’s implementation being particularly robust

  • For simpler problems, optimization is preferable to MCMC because it is significantly faster; reach for MCMC only when optimization isn't sufficient
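The optimization route can be as simple as gradient ascent on the log density to find a MAP estimate. A minimal sketch in plain JAX (the model here is illustrative; a real workflow would use Optax):

```python
# Find a MAP estimate by gradient ascent on an unnormalized log posterior.
import jax
import jax.numpy as jnp

def log_density(mu):
    # Normal(0, 10) prior on mu, Normal(mu, 1) likelihood for each y.
    y = jnp.array([0.8, 1.1, 0.9])
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = -0.5 * jnp.sum((y - mu) ** 2)
    return log_prior + log_lik

grad_fn = jax.grad(log_density)
mu = 0.0
for _ in range(100):
    mu = mu + 0.1 * grad_fn(mu)  # gradient ascent step
# mu converges to the MAP, close to the data mean (~0.93)
```

This runs in milliseconds, whereas an MCMC run on the same model would spend far more compute to also characterize the posterior's spread.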

  • JAX’s `vmap` transformation enables easy parallelization of operations without manually rewriting code in vectorized form
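Concretely, `jax.vmap` turns a function written for a single example into one that handles a batch, with no hand-written vectorization:

```python
import jax
import jax.numpy as jnp

def predict(w, x):       # written for a single input vector
    return jnp.dot(w, x)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 input vectors

# Map over the leading axis of xs only; w is shared (in_axes=None).
batched = jax.vmap(predict, in_axes=(None, 0))
out = batched(w, xs)  # shape (4,)
```

In a Bayesian workflow the same trick evaluates a log density across many chains or draws at once.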

  • Transform functions in bayeux handle converting between constrained and unconstrained spaces automatically, simplifying work with different probability distributions
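The idea behind such transforms can be sketched by hand for a positive parameter (e.g. a scale): sample on the unconstrained real line via `exp`, and add the log-determinant of the Jacobian to the log density (illustrative example, not the bayeux internals):

```python
import jax.numpy as jnp
from jax.scipy.stats import gamma

def log_density_constrained(sigma):
    # Defined only for sigma > 0.
    return gamma.logpdf(sigma, a=2.0)

def log_density_unconstrained(z):
    # Defined for any real z; exp maps R onto (0, inf).
    sigma = jnp.exp(z)
    # Change of variables: add log|d sigma / dz| = z.
    return log_density_constrained(sigma) + z
```

Samplers then run on `z` without ever proposing invalid (negative) values of `sigma`.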

  • The Bayesian ecosystem has evolved to separate model specification from inference - libraries can now share model evaluation code

  • Integration with Optax provides access to various optimizers like Adam and L-BFGS for Bayesian optimization

  • PPLs can define joint probability distributions and handle normalization constants automatically
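Underneath, a PPL evaluates the joint log density of parameters and data; since MCMC only needs this up to an additive constant, the normalization can be ignored. Written out directly in JAX for an illustrative one-parameter model:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_joint(mu, y):
    # log p(mu, y) = log p(mu) + log p(y | mu)
    log_prior = norm.logpdf(mu, loc=0.0, scale=10.0)
    log_lik = jnp.sum(norm.logpdf(y, loc=mu, scale=1.0))
    # As a function of mu, this is the unnormalized log posterior.
    return log_prior + log_lik
```

A PPL builds exactly this kind of function automatically from the model specification, which is what lets any JAX-based sampler consume it.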

  • Recent developments include high-performance samplers written in Rust that can leverage GPUs/TPUs through JAX