JAX 001 - A closer look at its background
I recently wanted to get into parallel programming so that I could optimize one of the optics simulation projects I was working on (it was built without much focus on distributed training). With all the buzz around JAX and some cool applications I found (listed below), I thought I'd give it a try.
- Protein Structure Prediction - AlphaFold 2 [1]
- Differentiable, Hardware Accelerated, Molecular Dynamics - JAX-MD [2]
- Massively parallel rigid-body physics simulation - Brax
- Chemical Modelling - JAXChem
- Computational Fluid Dynamics - JAX-CFD
- Differentiable Cosmology - JAX-Cosmo
Why JAX?
With established libraries such as PyTorch and TensorFlow already available, why do we need a new library in the first place? Let's find out.
We all know that increasing FLOPS (floating point operations per second) is a huge deal in machine learning for training models efficiently. JAX helps with this goal by letting researchers write Python programs that are automatically compiled and scaled to utilize accelerators (GPUs/TPUs). It is often hard to write Python code optimized enough to leverage the full potential of hardware accelerators. JAX aims to strike a balance between a research-friendly programming experience and hardware acceleration.
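To make this concrete, here is a tiny sketch of my own (not taken from the JAX docs): a plain NumPy-style Python function that JAX compiles for whichever backend is available. The function and values are made up for illustration.

```python
# A minimal sketch (my own example): ordinary NumPy-style Python,
# compiled by JAX and run on whatever accelerator is available.
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # Plain Python + NumPy-style math; nothing device-specific here.
    return jnp.tanh(jnp.dot(x, w) + b)

# jit traces the function once and compiles it with XLA for the
# available backend (CPU, GPU, or TPU) without changing the function body.
fast_predict = jax.jit(predict)

w = jnp.ones((3, 2))
b = jnp.zeros(2)
x = jnp.ones((5, 3))
print(fast_predict(w, b, x).shape)  # (5, 2)
```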
To do so, JAX aims to accelerate pure-and-statically-composed (PSC) subroutines. This is done through a just-in-time (JIT) compiler which traces PSC routines. For this compilation process, the execution of the code only needs to be monitored once. Therefore, JAX stands for "Just After eXecution". (Source: JAX Paper)
Key features of JAX
JAX addresses several limitations of NumPy. It provides:
- Lightweight NumPy-like API for array-based computing
- Composable function transformations (autodiff, JIT compilation, vectorization, parallelization)
- Execute on CPU, GPU, or TPU without changing your code
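As a quick illustration of the first and last points, here is a small sketch of mine using standard jax.numpy calls (the values are made up):

```python
# Sketch of the NumPy-like API: jax.numpy mirrors numpy's interface,
# but returns device-backed, immutable arrays.
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

print(np.sum(x_np ** 2))   # classic NumPy on the host
print(jnp.sum(x_jx ** 2))  # same expression, runs on the default device

# One visible difference: JAX arrays are immutable, so in-place
# updates go through the .at[...] syntax instead of item assignment.
y = x_jx.at[0].set(10.0)
print(y)
```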
JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
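To show what "composable" means in practice, here is a toy example of my own that stacks grad, vmap, and jit on a made-up squared-error function:

```python
# Sketch of composing transformations: grad, vmap, and jit stack freely.
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Simple squared error for a scalar linear model.
    return (w * x - y) ** 2

# Differentiate w.r.t. w, vectorize over a batch of (x, y) pairs,
# then JIT-compile the whole composition.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = 2.0
xs = jnp.arange(4.0)
ys = jnp.array([0.0, 1.0, 2.0, 3.0])
print(per_example_grads(w, xs, ys))  # one gradient per (x, y) pair
```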
Getting familiar with related concepts and history
There were a few loosely defined terms in my mind when I first read through the documentation. So I thought I'd look into them and get a basic idea of each.
Autograd - Autograd is an automatic differentiation library released in 2014. It is a lightweight tool to automatically differentiate native Python and NumPy code. However, it does not target hardware accelerators such as GPUs/TPUs. Since 2018, the main developers of Autograd have been focusing on JAX, which has more functionality. Dougal Maclaurin, one of the main authors of Autograd, calls JAX the second generation of the project that started with Autograd.
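For a feel of the family resemblance, this is roughly what Autograd usage looks like (a tiny sketch of mine; the tanh function is just an illustration):

```python
# Tiny Autograd sketch (my own example): differentiate plain
# Python/NumPy code on the CPU, no accelerators involved.
import autograd.numpy as np   # thin wrapper around NumPy that records ops
from autograd import grad

def tanh(x):
    return (1.0 - np.exp(-2.0 * x)) / (1.0 + np.exp(-2.0 * x))

dtanh = grad(tanh)            # returns a function computing d tanh / dx
print(dtanh(1.0))             # ~0.4199743
```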
"XLA - (Accelerated Linear Algebra) is a compiler-based linear algebra execution engine. It is the backend that powers machine learning frameworks such as TensorFlow and JAX at Google, on a variety of devices including CPUs, GPUs, and TPUs." - (to read more)
JIT - Just-in-time compilation: this enables a JAX Python function to be compiled so it can be executed efficiently in XLA.
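Here is a small sketch of mine showing what the "trace once" behaviour mentioned earlier looks like in practice (the function and values are made up):

```python
# Sketch of what "tracing once" means: on the first call, jit runs the
# Python function with abstract tracer values to record the computation,
# then compiles it with XLA; later calls with the same shapes/dtypes
# reuse the compiled program without re-running the Python code.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing with:", x)   # only runs during tracing
    return x * 2.0 + 1.0

x = jnp.arange(3.0)
print(f(x))   # first call: prints the tracer, then the result
print(f(x))   # second call: no "tracing" print, compiled code is reused
```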
Digging into these concepts made me realize that while the libraries we use make it really easy to build end-to-end differentiable models, under the hood these tools are the result of years-long effort by amazing teams making cool design choices.
In the next writeup I'm going to share my experience with the device parallelization (pmap) and automatic vectorization (vmap) aspects.
Resources :
- https://jax.readthedocs.io/en/latest/index.html
- Compiling machine learning programs via high-level tracing [pdf]
- Matthew Johnson - JAX: accelerated ML research via composable function transformations (video) [slides]
- JAX at DeepMind (video) [slides]
- Lecture 6: Automatic Differentiation [pdf]
- 🌟 Accelerate your cutting-edge machine learning research with free Cloud TPUs. (apply here).