JAX Overview
Let's get started with an introduction to JAX and its uses.
Python and deep learning
Most programmers today are at least somewhat familiar with the role deep learning plays in shaping the world and with some of its applications.
Part of the progress in deep learning can be attributed to the rapid rise of Python and of libraries such as NumPy, SciPy, and Keras, as well as more specialized frameworks like PyTorch and TensorFlow.
What is JAX?
JAX (a name often expanded as “Just After eXecution”) is a relatively recent library for high-performance numerical computing, machine learning, and deep learning. Why should we invest our time in learning yet another library, though?
Before delving deeper into JAX and its architecture, it’s useful to have a quick overview of its features. We’ll find the answer to this “Why?” below.
JAX features
JAX's main features include:
- Just-in-Time (JIT) compilation.
- Running NumPy code not only on CPUs but also on GPUs and TPUs.
- Automatic differentiation of both NumPy and native Python code: derivatives are computed automatically, without deriving them by hand.
- Automatic vectorization: computations are batched automatically using a vectorized map (vmap).
- Expressing and composing transformations of numerical programs.
- Advanced pseudorandom number generation based on explicit, splittable keys.
- More options for control flow, such as structured primitives (e.g. jax.lax.cond) that can be used inside JIT-compiled functions.
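The features above can be illustrated with a small sketch. It assumes a standard `jax` installation and uses only the public transformations `jax.jit`, `jax.grad`, `jax.vmap`, the `jax.random` module, and `jax.lax.cond`; the functions `sum_of_squares` and `relu_like` are hypothetical examples, not part of JAX.

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function (hypothetical example): f(x) = sum(x**2).
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Automatic differentiation: jax.grad returns a new function computing df/dx.
grad_fn = jax.grad(sum_of_squares)

# JIT compilation: jax.jit compiles the function with XLA for whichever
# backend is available (CPU, GPU, or TPU).
fast_grad_fn = jax.jit(grad_fn)

x = jnp.array([1.0, 2.0, 3.0])
g = fast_grad_fn(x)  # the gradient of sum(x**2) is 2*x, i.e. [2, 4, 6]

# Automatic vectorization: jax.vmap maps sum_of_squares over the leading axis,
# so a batch of vectors is processed without an explicit Python loop.
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batched_vals = jax.vmap(sum_of_squares)(batch)  # [1+4, 9+16] = [5, 25]

# Composition: transformations nest freely, here giving per-example gradients.
per_example_grads = jax.jit(jax.vmap(jax.grad(sum_of_squares)))
pe_grads = per_example_grads(batch)  # [[2, 4], [6, 8]]

# Pseudorandom numbers: state lives in explicit keys, which are split
# rather than reused, instead of in a hidden global generator.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, shape=(3,))

# Control flow inside jit: a Python `if` on a traced value would fail,
# so jax.lax.cond selects a branch at run time instead.
@jax.jit
def relu_like(v):  # hypothetical helper
    return jax.lax.cond(v > 0, lambda u: u, lambda u: jnp.zeros_like(u), v)

print(g, batched_vals, pe_grads, sample, relu_like(2.0), relu_like(-3.0))
```

The point of the per-example gradient line is that jit, vmap, and grad are ordinary functions that take and return functions, so they can be stacked in any order that makes sense for the problem.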
JAX ecosystem
JAX does not stop there, though. A whole ecosystem of exciting libraries is built on top of it, such as:
- Haiku is a neural network library that offers a simple object-oriented programming model for defining networks.
- RLax is a library for deep reinforcement learning.
- Jraph, pronounced “giraffe”, is a library used for Graph Neural Networks (GNNs).
- Optax provides a simple, often one-line, interface for using gradient-based optimization methods efficiently.
- Chex provides utilities for testing and for writing reliable JAX code, such as assertions on array shapes and types.