Awesome JAX
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other
resources. Contributions are welcome!
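For readers new to JAX, here is a minimal sketch of how the pieces named above fit together. It uses only core JAX calls: `jax.numpy` for NumPy-style code, `jax.grad` for automatic differentiation, `jax.jit` for XLA compilation and `jax.vmap` for batching. The tiny model, the data and the names `predict` and `loss` are illustrative assumptions, not taken from any listed library.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Single-example model written with NumPy-style jax.numpy calls.
    return jnp.tanh(jnp.dot(x, w))

def loss(w, xs, ys):
    # jax.vmap maps predict over the leading (batch) axis of xs; w is shared.
    preds = jax.vmap(predict, in_axes=(None, 0))(w, xs)
    return jnp.mean((preds - ys) ** 2)

# jax.grad differentiates loss with respect to its first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_fn = jax.jit(jax.grad(loss))

# Illustrative data (assumed, not from any real dataset).
w = jnp.array([0.1, -0.2, 0.3])
xs = jnp.arange(12.0).reshape(4, 3) / 10.0
ys = jnp.array([0.0, 0.5, -0.5, 1.0])

print(loss(w, xs, ys))
print(grad_fn(w, xs, ys))  # an array with the same shape as w
```

The three transformations compose freely. That composability is what most of the libraries below build on.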
Libraries
- Neural Network Libraries
    - Flax - Centered on flexibility and clarity.
    - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    - Objax - Has an object-oriented design similar to PyTorch.
    - Elegy - A framework-agnostic Trainer interface for the JAX ecosystem. Supports Flax, Haiku, and Optax.
    - Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
    - Jraph - Lightweight graph neural network library.
    - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine Learning toolbox for Quantum Physics.
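Many of the libraries above (Flax, Haiku, Optax and Equinox among them) work with model parameters stored as pytrees, meaning nested containers of arrays. The pure-JAX sketch below shows the underlying idea. `jax.grad` returns gradients with the same nested structure as the parameters, so an update can be applied leaf by leaf with `jax.tree_util.tree_map`.

The two-layer MLP, `init_params`, `mlp`, `loss` and `sgd_step` are hypothetical illustrations. They do not reproduce any library's API. Optimizer libraries such as Optax generalize this kind of update step.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 4, 1)):
    # Hypothetical two-layer MLP parameters stored as a nested dict (a pytree).
    k1, k2 = jax.random.split(key)
    return {
        "hidden": {"w": jax.random.normal(k1, (sizes[0], sizes[1])) * 0.1,
                   "b": jnp.zeros(sizes[1])},
        "out": {"w": jax.random.normal(k2, (sizes[1], sizes[2])) * 0.1,
                "b": jnp.zeros(sizes[2])},
    }

def mlp(params, x):
    h = jnp.tanh(x @ params["hidden"]["w"] + params["hidden"]["b"])
    return h @ params["out"]["w"] + params["out"]["b"]

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # grads has exactly the same nested-dict structure as params.
    grads = jax.grad(loss)(params, x, y)
    # Plain SGD applied leaf by leaf across the whole pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = init_params(jax.random.PRNGKey(0))
x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = jnp.array([[0.0], [1.0], [1.0], [0.0]])
print(jax.tree_util.tree_map(jnp.shape, params))
for _ in range(100):
    params = sgd_step(params, x, y)
print(loss(params, x, y))
```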
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.
- Neural Network Libraries
    - FedJAX - Federated learning in JAX, built on Optax and Haiku.
    - Equivariant MLP - Construct equivariant neural network layers.
    - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing.
- exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
- JAXopt - Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - PIX is an image processing library in JAX, for JAX.
Models and Projects
JAX
Flax
Haiku
Trax
- Reformer - Implementation of the Reformer (efficient transformer) architecture.
Videos
- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and the JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in the Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
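The "Solving y=mx+b" series above starts from fitting a line with gradient descent. Below is a single-device sketch of that starting point using `jax.grad` and `jax.jit`. It does not attempt the data-parallel TPU Pod setup covered later in the series.

The "true" line y = 3x - 1, the noise level, the learning rate and the step count are all illustrative assumptions. So are the names `mse` and `update`.

```python
import jax
import jax.numpy as jnp

# Synthetic data from an assumed "true" line y = 3x - 1 plus small noise.
key = jax.random.PRNGKey(42)
x = jnp.linspace(-1.0, 1.0, 100)
y = 3.0 * x - 1.0 + 0.05 * jax.random.normal(key, x.shape)

def mse(params, x, y):
    # Mean squared error of the line m * x + b.
    m, b = params
    return jnp.mean((m * x + b - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.1):
    # One step of plain gradient descent on (m, b).
    grads = jax.grad(mse)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = update(params, x, y)

m, b = params
# Expected to be close to the assumed m=3, b=-1.
print(f"m={float(m):.2f}, b={float(b):.2f}")
```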
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models/Projects section.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX’s JIT and VMAP to achieve faster differentially private SGD than existing libraries.
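The core trick behind the differentially private SGD paper above is computing per-example gradients with `jax.vmap(jax.grad(...))` and compiling the step with `jax.jit`. The sketch below applies the usual DP-SGD recipe on top of that:

- clip each example's gradient to an L2 norm bound;
- sum the clipped gradients and add Gaussian noise;
- average over the batch.

It is a simplified illustration, not the paper's implementation. The linear model and the names `per_example_grads` and `private_grad` are assumptions. So are the `l2_clip` and `noise_multiplier` values, and no privacy accounting is done.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example.
    return (jnp.dot(x, w) - y) ** 2

# vmap over the batch axis of (x, y) while sharing w: one gradient row per example.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

@jax.jit
def private_grad(w, xs, ys, key, l2_clip=1.0, noise_multiplier=1.1):
    g = per_example_grads(w, xs, ys)                     # shape (batch, dim)
    norms = jnp.linalg.norm(g, axis=1, keepdims=True)
    g = g / jnp.maximum(1.0, norms / l2_clip)            # each row now has norm <= l2_clip
    noise = noise_multiplier * l2_clip * jax.random.normal(key, w.shape)
    return (jnp.sum(g, axis=0) + noise) / xs.shape[0]    # noisy average gradient

w = jnp.zeros(3)
xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones(4)
print(private_grad(w, xs, ys, jax.random.PRNGKey(0)).shape)
```

The whole batch is handled by one vectorized computation rather than a Python loop over examples. That is the kind of speedup the paper attributes to JIT compilation and vectorization.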
Tutorials and Blog Posts
Contributing
Contributions welcome! Read the contribution guidelines first.