Introduction to JAX Data Types and Arrays
Learn about the important data types and arrays in JAX.
Overview
JAX is a Python library for high-performance machine learning. Its key features include:
- Automatic differentiation
- Vectorization
- Just-in-time (JIT) compilation
Data types in JAX
JAX arrays use data types similar to NumPy's, such as float32 and int32. For instance, here is how we can create float and integer values in JAX.
import jax.numpy as jnp

x = jnp.float32(1.25844)
print("x :", x)

y = jnp.int32(45.25844)  # converting to int32 drops the fractional part
print("y :", y)
In the code above, we import the JAX version of NumPy and name it jnp. We define two JAX variables, x and y, of types float32 and int32, respectively. Lastly, we print the values of both variables.
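To see which data type a value actually holds, we can read its dtype attribute, just as in NumPy. The sketch below is a small illustration that assumes JAX's default configuration, in which 64-bit types are disabled (the jax_enable_x64 flag is off); with that flag on, the two untyped arrays would default to float64 and int64 instead.

```python
import jax.numpy as jnp

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)

# Every JAX array carries its dtype, like a NumPy array.
print("x dtype :", x.dtype)
print("y dtype :", y.dtype)
print("y :", y)  # the fractional part was dropped during conversion

# Without an explicit dtype, JAX picks 32-bit defaults
# (assumption: the jax_enable_x64 flag is left at its default, off).
z = jnp.array([1.5, 2.5])
n = jnp.array([1, 2, 3])
print("z dtype :", z.dtype, "| n dtype :", n.dtype)

# astype converts between dtypes, as in NumPy (float -> int truncates toward zero).
print("z as int32 :", z.astype(jnp.int32))
```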
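When we check the type of these values, we see that the float32 and the int32 variable share the same type: a JAX array. Older JAX releases reported this type as DeviceArray; since JAX 0.4.1, all arrays are instances of jax.Array, and the exact class name printed (for example, ArrayImpl) depends on the version. The code below prints the type of both variables.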
import jax.numpy as jnp

x = jnp.float32(1.25844)
print("type of x: ", type(x))

y = jnp.int32(45.25844)
print("type of y: ", type(y))
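Because the printed class name varies between releases, an isinstance check is a more stable way to confirm that a value is a JAX array. This sketch assumes JAX 0.4.1 or later, where jax.Array is the public array type, and uses NumPy's np.asarray to copy a JAX value back into a regular NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.float32(1.25844)
y = jnp.int32(45.25844)

# Both values are jax.Array instances, whatever their dtype.
print(isinstance(x, jax.Array), isinstance(y, jax.Array))

# A JAX array is not a NumPy array...
print(isinstance(x, np.ndarray))

# ...but np.asarray copies it into one, keeping the dtype.
x_np = np.asarray(x)
print(type(x_np), x_np.dtype)
```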
The JAX array type is the counterpart of numpy.ndarray in NumPy, and jax.numpy provides an interface similar to NumPy's. However, JAX also provides jax.lax, a lower-level API that is more powerful but stricter. For example, jax.numpy automatically promotes mixed types, so adding an int32 to a float32 gives a float32, but jax.lax does not allow operands of different types.
Ways to create JAX arrays
We can create JAX arrays like we would in NumPy. For example, we can use:
- The arange() function
- The linspace() function
- Python lists
- The zeros() function
- The ones() function
- The identity() or eye() function
Let’s look at the outputs of the functions above:
import jax.numpy as jnp

a = jnp.arange(10)
print("a : ", a)

b = jnp.linspace(0, 10, 30)
print("b :", b)

scores = [50, 60, 70, 30, 25, 70]
scores_array = jnp.array(scores)
print("scores_array :", scores_array)

c = jnp.zeros(5)
print("c :", c)

d = jnp.ones(5)
print("d :", d)

e = jnp.eye(5)
print("e :", e)

f = jnp.identity(5)
print("f :", f)
Let’s understand the code above:
- Line 3: We call the jnp.arange() method, which generates a JAX array of 10 elements, from 0 to 9.
- Line 6: We call the jnp.linspace() method, which creates a JAX array of 30 evenly spaced values from 0 to 10, inclusive. By default, linspace() generates 50 values, but the third argument lets us choose any number of values in the given range.
- Lines 9–10: We define a Python list, scores, and use the jnp.array() method to convert it into a JAX array.
- Line 13: We call the jnp.zeros() method to generate a JAX array of 5 zeros.
- Line 16: Similarly, we call the jnp.ones() method to generate a JAX array of 5 ones.
- Line 19: We create a 5 × 5 identity matrix (a square matrix whose diagonal elements are one and all other elements are zero) by calling the jnp.eye() method.
- Line 22: Just like jnp.eye(), the jnp.identity() method generates an identity matrix.
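These constructors accept more parameters than the example uses. The sketch below exercises a few of them: a step for arange(), the endpoint flag of linspace(), a shape tuple and dtype for zeros(), and the column count and k (diagonal offset) arguments of eye(), which identity() does not offer. The specific values are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# arange takes start, stop and step, like Python's built-in range().
evens = jnp.arange(0, 10, 2)
print("evens :", evens)

# linspace includes the stop value by default; endpoint=False excludes it.
closed = jnp.linspace(0, 1, 5)
half_open = jnp.linspace(0, 1, 4, endpoint=False)
print("closed :", closed)
print("half_open :", half_open)

# zeros (and ones) take a shape tuple and an optional dtype.
grid = jnp.zeros((2, 3), dtype=jnp.int32)
print("grid shape and dtype :", grid.shape, grid.dtype)

# eye can build a non-square matrix (3 rows, 4 columns)
# with the ones shifted one diagonal above the main one (k=1).
shifted = jnp.eye(3, 4, k=1)
print("shifted :", shifted)
```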