Just-in-Time (JIT) Compilation

JIT compilation is a core feature of JAX. This lesson provides an overview of it.

What is JIT?

Anyone with some experience of Java or the .NET Framework will be familiar with just-in-time (JIT) compilation. In JIT compilation, code is compiled at run time, while the program is executing, rather than ahead of time. This approach is also known as dynamic translation, and the compiled code typically runs much faster than interpreted code.

Python and JIT

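One of the major criticisms of Python is its slow execution. However, inspired by JIT compilers in ecosystems such as Java and .NET, there have been several attempts to bring JIT compilation to Python. Notable examples include PyPy and Numba. (CPython, the reference implementation, is traditionally a bytecode interpreter rather than a JIT compiler.)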

XLA compiler

JAX uses JIT compilation, which enables simple Python functions to target the XLA compiler. Before we move on, it would be worthwhile to give a brief introduction to XLA.

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra. Originally developed for TensorFlow, XLA can produce faster code and use memory more efficiently.

JIT compilation enables XLA to compile the given code into computation kernels that are specific to the given model, which helps to improve performance.
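
To make this more concrete, the sketch below inspects the intermediate program that JAX traces out of a Python function before handing it to XLA. The function name `affine_square` and the sample input are hypothetical, chosen only for illustration; `jax.make_jaxpr` is used here as one convenient way to view the traced operations, not as a required step in normal usage.

```python
import jax
import jax.numpy as jnp

def affine_square(x):
    # Hypothetical example function: an elementwise multiply followed by an add.
    return x * x + 1.0

x = jnp.arange(3.0)  # hypothetical sample input

# make_jaxpr traces the function and returns the recorded operations
# (a "jaxpr"). This traced program is what gets lowered to XLA under jit.
traced = jax.make_jaxpr(affine_square)(x)
print(traced)

# The jitted version computes the same values as the plain function.
compiled_result = jax.jit(affine_square)(x)
```

The printed jaxpr lists the primitive operations recorded during tracing, which is the form of the program that XLA goes on to compile.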

JAX JIT

JAX operations can be either:

  1. Static
  2. Dynamic/Traced

Static operations are evaluated at compile time in ordinary Python and cannot target the XLA compiler, whereas dynamic/traced operations are recorded during tracing and compiled by XLA.
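
The distinction can be sketched with a small example. Here the function name `normalize` and the `python_calls` list are hypothetical helpers introduced only for illustration. The array shape is known statically while the function is traced, whereas the array values are traced; the Python body itself runs only when JAX traces the function, not on every call.

```python
import jax
import jax.numpy as jnp

python_calls = []  # hypothetical helper: records each time the Python body runs

def normalize(x):
    python_calls.append("traced")
    # Static: the shape is an ordinary Python int, known while tracing.
    n = x.shape[0]
    # Dynamic/traced: the division on the array values is recorded and compiled by XLA.
    return x / n

normalize_jit = jax.jit(normalize)

first = normalize_jit(jnp.ones(4))
second = normalize_jit(jnp.ones(4))  # same shape and dtype: reuses the compiled code

print(len(python_calls))
```

Because both calls use the same input shape and dtype, the Python body should run only once, during the initial trace; the second call goes straight to the cached compiled code.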

Luckily, most JAX operations are expressible in XLA terms. All we need to do is import JAX and wrap the function with jax.jit, either by calling it directly or by applying it as a decorator.

Let’s illustrate some of the features of jit using an example.

We’ll begin with a simple square function to demonstrate the difference between standard Python execution and JIT compilation.
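
The original interactive example is not reproduced here, so what follows is a minimal sketch of what such a comparison might look like. The names `square`, `square_jit`, and the sample input are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def square(x):
    # Plain version: the multiplication is dispatched eagerly on each call.
    return x * x

# jax.jit wraps the function so that it is traced and compiled by XLA
# the first time it is called with a given input shape and dtype.
square_jit = jax.jit(square)

x = jnp.arange(5.0)  # hypothetical sample input

eager_result = square(x)
jit_result = square_jit(x)  # first call compiles; later calls reuse the compiled code

print(eager_result)
print(jit_result)
```

Both versions return the same values. For a function this small the benefit of compilation is limited, and the first jitted call also pays the compilation cost, so any timing comparison should exclude that first call and call `block_until_ready()` on the result, because JAX dispatches work asynchronously.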
