GenJAX is a probabilistic programming system built on JAX that supports programmable inference. GenJAX programs compile to efficient, vectorized JAX code for fast inference on GPUs and TPUs.
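The general idea of vectorized inference can be shown without GenJAX. The sketch below is plain JAX and does **not** use the GenJAX API. It assumes a hypothetical one-variable model (`mu ~ Normal(0, 1)`, `y ~ Normal(mu, 1)`) and runs self-normalized importance sampling with the prior as the proposal. The per-particle logic is written once, and `jax.vmap` batches it across all particles. The helper names (`log_normal_pdf`, `one_particle`, `importance_posterior_mean`) are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical model for illustration (plain JAX, not GenJAX):
#   mu ~ Normal(0, 1);  y ~ Normal(mu, 1)
# Proposal = prior, so each importance log-weight is log p(y | mu).

def log_normal_pdf(x, loc, scale):
    # Log-density of Normal(loc, scale) evaluated at x.
    return (-0.5 * ((x - loc) / scale) ** 2
            - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi))

def one_particle(key, y):
    # Draw one particle from the prior and score it against the observation.
    mu = jax.random.normal(key)
    log_w = log_normal_pdf(y, mu, 1.0)
    return mu, log_w

def importance_posterior_mean(key, y, num_particles):
    # One PRNG key per particle; vmap runs all particles in one batched call.
    keys = jax.random.split(key, num_particles)
    mus, log_ws = jax.vmap(one_particle, in_axes=(0, None))(keys, y)
    # Self-normalize the weights in a numerically stable way.
    weights = jax.nn.softmax(log_ws)
    return jnp.sum(weights * mus)

# num_particles fixes array shapes, so it is marked static for jit.
estimate = jax.jit(importance_posterior_mean, static_argnums=2)(
    jax.random.PRNGKey(0), 2.0, 100_000)
print(float(estimate))  # close to the exact posterior mean y / 2 = 1.0
```

Because this conjugate model has a closed-form posterior (mean `y / 2`), the estimate can be checked directly. GenJAX targets the same kind of batching for general probabilistic programs; consult its documentation for its actual modeling and inference interfaces.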

See the POPL 2026 paper “Probabilistic Programming with Vectorized Programmable Inference” (Becker, Huot, Matheos, Wang, Chung, Smith, Ritchie, Saurous, Lew, Rinard, Mansinghka).