Matrices
We'll learn how to perform various matrix operations using JAX.
Matrices are at the core of almost any data application. Even a vector can be treated as a matrix with a single row or a single column, that is, a row or column dimension of 1.
Note: As is common in linear algebra, we will use capital letters for matrices in our code as well.
import jax.numpy as jnp

a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)
A = jnp.array((a,b,c))  # Stack the three vectors as rows to make a 3x5 matrix
print(A)
print(A.shape)
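To make the point that a vector is just a matrix with one row or one column, here is a minimal sketch using reshape. The names row, col and A_stacked are our own illustrative choices. It also checks that building A from a tuple of equal-length vectors, as above, gives the same result as jnp.stack with its default axis=0, i.e. the vectors become rows.

```python
import jax.numpy as jnp

a = jnp.arange(5)          # 1-D vector, shape (5,)
row = a.reshape(1, 5)      # same data viewed as a 1x5 matrix (row vector)
col = a.reshape(5, 1)      # same data viewed as a 5x1 matrix (column vector)
print(a.shape, row.shape, col.shape)   # (5,) (1, 5) (5, 1)

# Building a matrix from equal-length vectors places them as rows,
# which matches jnp.stack along axis 0.
b = 2 * jnp.arange(5)
c = -1 * jnp.arange(5)
A = jnp.array((a, b, c))
A_stacked = jnp.stack((a, b, c))
print(A.shape)                         # (3, 5)
print(jnp.array_equal(A, A_stacked))   # True
```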
Slicing
We can extract submatrices using the slicing (:) notation.
Note: Python uses 0-based indexing, as does JAX.
For example, a submatrix containing the first 2 rows and first 3 columns of the above matrix is:
B = A[0:2,0:3]
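A few more slicing patterns, sketched on the same 3x5 matrix A, may help; each comment gives the resulting shape. These follow the NumPy-style rules that JAX shares: negative indices count from the end, an omitted bound means "from the start" or "to the end", a third value sets the step, and an integer index (rather than a slice) drops that axis. The variable names other than A and B are illustrative.

```python
import jax.numpy as jnp

a = jnp.arange(5)
A = jnp.array((a, 2 * a, -1 * a))   # the 3x5 matrix from above

B = A[0:2, 0:3]           # first 2 rows, first 3 columns -> shape (2, 3)
first_row = A[0, :]       # integer index drops the row axis -> shape (5,)
first_col = A[:, 0:1]     # a slice keeps the column axis -> shape (3, 1)
last_two = A[:, -2:]      # last two columns -> shape (3, 2)
every_other = A[:, ::2]   # columns 0, 2 and 4 -> shape (3, 3)

print(B)
# [[0 1 2]
#  [0 2 4]]
```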
Reshaping functions
Reshaping functions are commonly used in several applications, especially computer vision.
The rule behind any reshaping function is simple: if the input and output matrices have dimensions m×n and j×k, respectively, then:
...
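As a small sketch of reshaping in practice: JAX only accepts a target shape that keeps the total number of elements unchanged, so the 15-element matrix A can become 5x3 or a flat vector of 15, while 4x4 (16 elements) is rejected. Passing -1 for one dimension lets JAX infer it. The last lines mirror a common computer-vision step, flattening each image in a batch into a vector; the batch of 8 grayscale 28x28 images is an assumed shape chosen only for illustration. Because the exact exception type raised for a bad shape is not something we rely on, the sketch catches both TypeError and ValueError.

```python
import jax.numpy as jnp

a = jnp.arange(5)
A = jnp.array((a, 2 * a, -1 * a))   # shape (3, 5): 15 elements

C = A.reshape(5, 3)         # 5 * 3 == 15, allowed
D = jnp.reshape(A, (15,))   # flatten to a 1-D vector (row-major order)
E = A.reshape(-1, 5)        # -1 is inferred as 3

try:
    A.reshape(4, 4)         # 4 * 4 == 16 != 15, so this is rejected
    bad_reshape_raised = False
except (TypeError, ValueError) as err:
    bad_reshape_raised = True
    print("reshape failed:", err)

# Hypothetical image batch: 8 grayscale images of 28x28 pixels.
images = jnp.zeros((8, 28, 28))
flat = images.reshape(images.shape[0], -1)   # one 784-element row per image
print(C.shape, D.shape, E.shape, flat.shape)  # (5, 3) (15,) (3, 5) (8, 784)
```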