Just-in-Time (JIT) Compilation
JIT compilation is a core component of JAX. This lesson gives an overview of the feature.
What is JIT?
Anyone with some experience of Java or the .NET framework will be familiar with Just-In-Time compilation. In Just-In-Time (JIT) compilation, code is compiled at run time rather than ahead of time (this is also known as dynamic translation), which can make the code execute much faster.
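In JAX, this run-time compilation is easy to observe. jax.jit traces and compiles a function the first time it is called with a given input shape and dtype, then reuses the compiled version on later calls with the same signature. The sketch below counts traces using a Python side effect; add_one and trace_count are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how many times JAX traces (and compiles) add_one

@jax.jit
def add_one(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    return x + 1

add_one(jnp.ones(3))  # first call with shape (3,): traced and compiled
add_one(jnp.ones(3))  # same shape and dtype: reuses the compiled version
add_one(jnp.ones(4))  # new shape (4,): triggers a fresh trace and compile
print(trace_count)    # 2
```

Because the cache is keyed on input shapes and dtypes, a function called with many different shapes pays the compilation cost for each one.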
Python and JIT
One of the major criticisms of Python is its slow execution. To address this, there have been several attempts to bring JIT compilation to Python, inspired by JIT-compiled runtimes such as the JVM and .NET. Notable examples include PyPy and Numba.
XLA compiler
JAX uses JIT compilation, which enables simple Python functions to target the XLA compiler. Before we move on, it would be worthwhile to give a brief introduction to XLA.
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra. Originally developed for TensorFlow, XLA can produce faster code and use memory more efficiently.
JIT compilation lets XLA compile the given code into computation kernels specialized for that particular model (for example, by fusing several operations into a single kernel), which improves performance.
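To see what XLA actually receives, JAX lets you lower a jitted function ahead of time and inspect it. The sketch below uses jax.jit(...).lower(...), Lowered.as_text(), and Lowered.compile(); affine_relu and the array shapes are illustrative assumptions. The printed IR format (StableHLO in recent JAX versions) and the fusion decisions depend on your JAX version and backend.

```python
import jax
import jax.numpy as jnp

def affine_relu(x, w, b):
    # Matrix multiply, add, and max: candidates for XLA to fuse into fewer kernels
    return jnp.maximum(jnp.dot(x, w) + b, 0.0)

x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
b = jnp.zeros((2,))

lowered = jax.jit(affine_relu).lower(x, w, b)  # trace and lower to XLA's input IR
print(lowered.as_text()[:500])                 # first part of the IR text handed to XLA

compiled = lowered.compile()                   # XLA compiles for the current backend
result = compiled(x, w, b)                     # run the compiled executable
print(result)
```

Each entry of result is 3.0 here, since every row of x dotted with every column of w sums three ones.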
JAX JIT
JAX operations can be either:
- Static
- Dynamic/Traced
Static operations are evaluated in Python at compile time, so, unlike dynamic/traced operations, they are not staged out to the XLA compiler.
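The distinction can be sketched as follows. Array shapes are static (plain Python values at trace time), while array contents are traced. When Python control flow needs an argument's concrete value, that argument can be marked static with jax.jit's static_argnums parameter. The functions normalize and power are hypothetical examples chosen for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # x.shape is static: a plain Python tuple, evaluated during tracing
    n = x.shape[0]
    # jnp.sum(x) is traced: its value is only known when XLA runs the computation
    return x / (jnp.sum(x) / n)

# range(n) needs a concrete integer, so n is marked static (argument index 1)
@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n is static, so this Python loop is unrolled while tracing
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

print(normalize(jnp.array([1.0, 2.0, 3.0])))  # [0.5 1.  1.5]
print(power(jnp.array(2.0), 3))               # 8.0
```

Static arguments become part of the compilation cache key, so calling power with a different n triggers a new compilation.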
Luckily, most JAX operations are expressible in XLA terms. To take advantage of this, all we need to do is import jit from JAX and apply it to the function we want to compile.
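Here is a minimal sketch of the two common ways to apply jit: wrapping an existing function, or using jit as a decorator. The function names sum_of_squares and sum_of_cubes are hypothetical.

```python
import jax.numpy as jnp
from jax import jit

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Option 1: wrap an existing function; the original stays available un-jitted
sum_of_squares_jit = jit(sum_of_squares)

# Option 2: apply jit as a decorator at definition time
@jit
def sum_of_cubes(x):
    return jnp.sum(x ** 3)

x = jnp.arange(4.0)           # [0., 1., 2., 3.]
print(sum_of_squares_jit(x))  # 14.0
print(sum_of_cubes(x))        # 36.0
```

Wrapping is handy when you want to compare the jitted and un-jitted versions side by side; the decorator is more concise when you only ever need the compiled version.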
Let’s illustrate some of the features of jit using an example.
We’ll begin with a simple square function to demonstrate the difference between standard Python execution and JIT compilation:
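A minimal sketch of such a comparison, assuming a pure-Python list loop as the "standard Python" baseline, might look like this. The names square_jit, values, and x and the array size are illustrative choices. Because JAX dispatches work asynchronously, block_until_ready() is called so the timings cover the actual computation. No timings are shown because they depend heavily on hardware.

```python
import time

import jax
import jax.numpy as jnp

def square(x):
    # Plain function: works on Python numbers and on JAX arrays
    return x * x

square_jit = jax.jit(square)  # same function, compiled with XLA on first call

values = list(range(100_000))
x = jnp.arange(100_000, dtype=jnp.float32)

# Standard Python: one interpreter step per element
start = time.perf_counter()
py_result = [square(v) for v in values]
py_time = time.perf_counter() - start

# JIT, first call: includes tracing and XLA compilation
start = time.perf_counter()
square_jit(x).block_until_ready()
first_call = time.perf_counter() - start

# JIT, later calls: reuse the cached compiled kernel
start = time.perf_counter()
jit_result = square_jit(x).block_until_ready()
cached_call = time.perf_counter() - start

print(f"pure Python loop: {py_time:.6f} s")
print(f"jit, first call:  {first_call:.6f} s")
print(f"jit, cached call: {cached_call:.6f} s")
```

The first jitted call is usually the slowest because it pays for compilation; the benefit shows up on repeated calls. For a single multiply, the gain over un-jitted JAX is typically small, and it tends to grow as XLA has more operations to fuse.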