Matrices

We'll learn how to perform various matrix operations using JAX.

Matrices are at the core of almost any data application. Even a vector can be treated as a matrix whose row or column dimension is 1.
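As a minimal sketch of that last point, a 1-D vector can be viewed as a row matrix or a column matrix by adding an axis of length 1. The names `v`, `V_row`, and `V_col` are illustrative only and do not come from the lesson.

```python
import jax.numpy as jnp

v = jnp.arange(5)            # 1-D vector with shape (5,)

V_row = v[jnp.newaxis, :]    # row matrix: 1 row, 5 columns
V_col = v[:, jnp.newaxis]    # column matrix: 5 rows, 1 column

print(v.shape)      # (5,)
print(V_row.shape)  # (1, 5)
print(V_col.shape)  # (5, 1)
```

All three arrays hold the same five values; only the shape differs.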

Note: Following the common linear algebra convention, we will use capital letters for matrices in our code as well.

import jax.numpy as jnp

a = jnp.arange(5)
b = 2 * jnp.arange(5)
c = -1 * jnp.arange(5)
A = jnp.array((a, b, c))  # Stack the three vectors as rows of a matrix
print(A)
print(A.shape)  # (3, 5)

Slicing

We can make submatrices by using the slicing (:) notation.

Note: Python uses 0-based indexing, as does JAX.

For example, a submatrix containing the first 2 rows and first 3 columns of the above matrix is:

B = A[0:2,0:3]
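A runnable sketch of the slice above, plus two further slices for comparison. The matrix `A` is rebuilt here so the block stands alone; `last_col` and `second_row` are illustrative names, not part of the lesson.

```python
import jax.numpy as jnp

# Rebuild the 3 x 5 matrix from the earlier example.
A = jnp.array([jnp.arange(5), 2 * jnp.arange(5), -1 * jnp.arange(5)])

# Rows 0 and 1, columns 0 to 2. The stop index of a slice is excluded.
B = A[0:2, 0:3]

# A bare colon selects everything along that axis.
last_col = A[:, -1]    # every row, last column only (1-D result)
second_row = A[1, :]   # row index 1, every column (1-D result)

print(B)
# [[0 1 2]
#  [0 2 4]]
print(B.shape)           # (2, 3)
print(last_col.shape)    # (3,)
print(second_row.shape)  # (5,)
```

Note that indexing with a single integer drops that axis, while a slice such as `0:1` would keep it with length 1.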

Reshaping functions

Reshaping functions are commonly used in several applications, especially computer vision.

The rule behind any reshaping function is simple: the total number of elements must stay the same. If the input and output matrices have dimensions $m \times n$ and $j \times k$ respectively, then:

$m \times n = j \times k$ ...
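The rule can be checked with a short sketch. It assumes a 3 x 5 input built with `jnp.arange`, which is an illustrative choice rather than the lesson's own example; every target shape below multiplies out to the same 15 elements.

```python
import jax.numpy as jnp

A = jnp.arange(15).reshape(3, 5)   # 3 x 5 matrix: 15 elements

C = A.reshape(5, 3)           # 5 x 3 = 15, so the reshape is valid
D = jnp.reshape(A, (1, 15))   # function form, here giving a row matrix
E = A.reshape(-1, 3)          # -1 lets JAX infer the missing dimension (5)
flat = A.ravel()              # flatten to a 1-D array of length 15

print(C.shape)     # (5, 3)
print(D.shape)     # (1, 15)
print(E.shape)     # (5, 3)
print(flat.shape)  # (15,)
```

A target shape whose product differs from 15, such as (4, 4), would be rejected with an error because the element count would no longer match.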