
Array Operations

Explore fundamental JAX array operations including max, argmax, sum, and transpose while understanding the immutability of JAX arrays. Learn methods for safe array updates and how JAX handles out-of-bounds indexing. Discover how to place and manage data on devices such as GPUs or TPUs to optimize your deep learning workflows.
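The summary above mentions immutability, safe updates, out-of-bounds indexing, and device placement. As a preview, here is a minimal sketch of each using jax.numpy. The array x and the other variable names are illustrative only. The clamping and dropping behavior shown is JAX's default indexing behavior; other out-of-bounds modes can be requested explicitly.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)  # values 0, 1, 2, 3, 4

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    x[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True

# Safe update: .at[...].set() returns a NEW array and leaves x unchanged.
y = x.at[0].set(10)
print(x.tolist())  # [0, 1, 2, 3, 4]
print(y.tolist())  # [10, 1, 2, 3, 4]

# Out-of-bounds reads do not raise an IndexError; by default the index
# is clamped to the last valid position.
print(int(x[10]))  # 4

# Out-of-bounds updates are silently dropped by default.
z = x.at[10].set(99)
print(z.tolist())  # [0, 1, 2, 3, 4]

# Explicit placement on the first available device
# (a CPU device if no GPU or TPU is present).
device = jax.devices()[0]
x_on_device = jax.device_put(x, device)
print(device)
```

Because updates never mutate x, code that still holds a reference to the original array keeps seeing the original values. This is one reason JAX can safely trace and transform functions.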

Array operations

Operations on JAX arrays are similar to operations on NumPy arrays. For example, jnp.max(), jnp.argmax(), and jnp.sum() work the same way as their NumPy counterparts.

import jax.numpy as jnp
matrix = jnp.arange(17, 33)        # values 17, 18, ..., 32
matrix = matrix.reshape(4, 4)      # arrange them as a 4x4 matrix
print("Matrix :", matrix)
print("Maximum :", jnp.max(matrix))
print("Argmax :", jnp.argmax(matrix))    # index into the flattened array
print("Minimum :", jnp.min(matrix))
print("Argmin :", jnp.argmin(matrix))    # index into the flattened array
print("Sum :", jnp.sum(matrix))
print("Square root :", jnp.sqrt(matrix)) # element-wise; returns floats
print("Transpose :", matrix.transpose())

Let’s review the code:

  • Line 1: We import jax.numpy under the conventional alias jnp.
  • Lines 2–4: We create a JAX array matrix of values 17 to 32 and reshape it into a 4×4 matrix. Lastly, we use the ...