Colin Carroll - The state of Bayesian workflows in JAX | PyData Vermont 2024

Colin Carroll

Learn about JAX's tools for Bayesian workflows, the bayeux library for interoperability between probabilistic programming languages (PPLs), and how vmap and NUTS enable efficient probabilistic programming and sampling.

Key takeaways
  • JAX supports Bayesian workflows through automatic differentiation (grad), vectorization (vmap), and gradient-based optimization

  • bayeux is a library that connects models written in different probabilistic programming languages (PPLs) with samplers from different libraries, allowing interoperability between TensorFlow Probability, NumPyro, PyMC, and other frameworks

  • The No-U-Turn Sampler (NUTS) is the recommended default MCMC sampler for most problems, and NumPyro’s implementation is particularly robust

  • For simpler problems, optimization is preferable to MCMC because it is much faster; use MCMC when optimization isn’t sufficient

  • JAX’s vmap transformation vectorizes a function written for a single input so it runs over a whole batch, without manually rewriting the code in vectorized form

  • Transform functions in bayeux automatically map parameters between constrained spaces (such as positive scale parameters) and unconstrained space, simplifying work with distributions whose support is restricted

  • The Bayesian ecosystem has moved toward separating model specification from inference, so libraries can now share model evaluation code such as log density functions

  • Integration with Optax provides optimizers such as Adam and L-BFGS for optimization-based inference

  • PPLs let you define joint probability distributions and handle normalizing constants automatically; samplers such as MCMC only need the density up to a constant

  • Recent developments include high-performance samplers written in Rust that can leverage GPUs/TPUs through JAX
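The grad and vmap takeaways can be illustrated with a minimal sketch, assuming a toy standard-normal log density in place of a real model. jax.grad produces the gradient that gradient-based samplers and optimizers rely on, and jax.vmap evaluates both functions over a batch of points without rewriting either one.

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    """Unnormalized log density of a standard 2-D Gaussian (a toy stand-in model)."""
    return -0.5 * jnp.sum(theta ** 2)

# Gradient of the log density, as needed by NUTS-style samplers and optimizers.
grad_log_density = jax.grad(log_density)

# vmap lifts the single-point functions to a batch of points, no rewrite needed.
batched_log_density = jax.vmap(log_density)
batched_grad = jax.vmap(grad_log_density)

points = jnp.array([[0.0, 0.0], [1.0, 2.0], [-1.0, 0.5]])
values = batched_log_density(points)  # shape (3,)
grads = batched_grad(points)          # shape (3, 2); equals -points for this density
```

The model code is written once for a single parameter vector; batching is added by the transformation rather than by hand.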
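To make the constrained/unconstrained takeaway concrete, here is a hand-written sketch of what such a transform does; it is not bayeux's implementation. It assumes a positive scale parameter sigma with an Exponential(1) density and the log transform x = log(sigma); the helper name to_unconstrained is hypothetical. The log-Jacobian term keeps the transformed density normalized, which the Riemann-sum check at the end confirms numerically.

```python
import jax
import jax.numpy as jnp

def exponential_log_density(sigma):
    """Log density of Exponential(1), defined only for sigma > 0."""
    return -sigma

def to_unconstrained(log_density_constrained):
    """Hypothetical helper: re-express a log density on sigma > 0 in terms of x = log(sigma)."""
    def log_density_unconstrained(x):
        sigma = jnp.exp(x)        # maps the whole real line onto (0, inf)
        log_abs_det_jacobian = x  # log |d sigma / dx| = log(exp(x)) = x
        return log_density_constrained(sigma) + log_abs_det_jacobian
    return log_density_unconstrained

unconstrained_log_density = to_unconstrained(exponential_log_density)

# Numerical check: the transformed density still integrates to about 1 over the reals.
xs = jnp.linspace(-20.0, 5.0, 5001)
dx = xs[1] - xs[0]
total_mass = jnp.sum(jnp.exp(jax.vmap(unconstrained_log_density)(xs))) * dx
```

Without the log-Jacobian term the density in x would no longer correspond to the original distribution, which is why PPL tooling tracks it automatically.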
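To illustrate why optimization is the cheaper first step, here is a sketch that finds a maximum a posteriori (MAP) estimate with plain gradient ascent. It assumes a toy model (three observations with known unit variance and a standard normal prior on the mean), and the hand-written update stands in for an Optax optimizer such as Adam or L-BFGS; it does not use Optax's API.

```python
import jax
import jax.numpy as jnp

# Toy data, assumed for illustration: y_i ~ Normal(mu, 1), prior mu ~ Normal(0, 1).
y = jnp.array([1.0, 2.0, 3.0])
LEARNING_RATE = 0.1

def log_posterior(mu):
    """Unnormalized log posterior of mu (additive constants dropped)."""
    log_likelihood = -0.5 * jnp.sum((y - mu) ** 2)
    log_prior = -0.5 * mu ** 2
    return log_likelihood + log_prior

@jax.jit
def ascent_step(mu):
    # One step of plain gradient ascent; an Optax optimizer would replace this update.
    return mu + LEARNING_RATE * jax.grad(log_posterior)(mu)

mu_map = jnp.array(0.0)
for _ in range(200):
    mu_map = ascent_step(mu_map)

# Closed-form MAP for this conjugate model: sum(y) / (n + 1) = 6 / 4 = 1.5
```

A few hundred cheap gradient steps recover the posterior mode, but only the mode; when uncertainty estimates matter, MCMC is the next step.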
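The point that samplers only need an unnormalized density can be sketched with random-walk Metropolis, a much simpler algorithm than NUTS, swapped in here for brevity (NUTS additionally uses gradients to propose distant moves). The target, N(3, 1) plus an arbitrary constant of 100, is an assumption for illustration: the constant cancels in the acceptance ratio, and jax.vmap runs four independent chains at once.

```python
import jax
import jax.numpy as jnp

def unnormalized_log_density(x):
    # N(3, 1) up to an arbitrary additive constant; the constant cancels in the
    # Metropolis acceptance ratio, so the normalizer is never computed.
    return -0.5 * (x - 3.0) ** 2 + 100.0

def metropolis_chain(key, x0, num_steps=20_000, step_size=1.0):
    """Random-walk Metropolis: a simple stand-in for NUTS (no gradients used)."""
    def step(x, step_key):
        key_prop, key_accept = jax.random.split(step_key)
        proposal = x + step_size * jax.random.normal(key_prop)
        log_ratio = unnormalized_log_density(proposal) - unnormalized_log_density(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples

# vmap runs several independent chains in parallel without manual batching.
num_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
initial_positions = jnp.zeros(num_chains)
samples = jax.vmap(metropolis_chain)(chain_keys, initial_positions)  # shape (4, 20000)
kept = samples[:, 1_000:]  # discard warm-up draws
posterior_mean = kept.mean()
posterior_std = kept.std()
```

Adding or removing the constant 100.0 leaves the chains unchanged, which is the practical meaning of "the normalizing constant never has to be computed".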