
Array Operations

Learn about the common array-related operations of NumPy and JAX.

Array operations

Operations on JAX arrays closely mirror operations on NumPy arrays. For example, the max(), argmax(), and sum() functions in jax.numpy (imported as jnp) work the same way as their NumPy counterparts.

import jax.numpy as jnp
matrix = jnp.arange(17, 33)                # values 17 to 32
matrix = matrix.reshape(4, 4)              # arrange them as a 4x4 matrix
print("Matrix :", matrix)
print("Maximum :", jnp.max(matrix))
print("Argmax :", jnp.argmax(matrix))      # index into the flattened array
print("Minimum :", jnp.min(matrix))
print("Argmin :", jnp.argmin(matrix))      # index into the flattened array
print("Sum :", jnp.sum(matrix))
print("Square root :", jnp.sqrt(matrix))   # applied element-wise
print("Transpose :", matrix.transpose())

Let’s review the code:

  • Lines 2–4: We create a JAX array matrix containing the values 17 to 32 and reshape it into a 4×4 ...
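To check that these functions really do behave like their NumPy counterparts, here is a minimal sketch, assuming both `jax` and `numpy` are installed. It builds the same 4×4 matrix with each library and compares the results. It also makes explicit a detail the printed output can hide: on a 2-D array, `argmax` and `argmin` return an index into the flattened array, and `jnp.unravel_index` converts that index back into a (row, column) pair. The variable names `j_matrix`, `n_matrix`, and `results` are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

# Build the same 4x4 matrix (values 17 to 32) in JAX and in NumPy.
j_matrix = jnp.arange(17, 33).reshape(4, 4)
n_matrix = np.arange(17, 33).reshape(4, 4)

# Each reduction returns the same value in both libraries.
results = {
    "max": (int(jnp.max(j_matrix)), int(np.max(n_matrix))),
    "argmax": (int(jnp.argmax(j_matrix)), int(np.argmax(n_matrix))),
    "min": (int(jnp.min(j_matrix)), int(np.min(n_matrix))),
    "argmin": (int(jnp.argmin(j_matrix)), int(np.argmin(n_matrix))),
    "sum": (int(jnp.sum(j_matrix)), int(np.sum(n_matrix))),
}
for name, (j_val, n_val) in results.items():
    print(f"{name:7s} jax={j_val}  numpy={n_val}")

# Element-wise sqrt agrees within floating-point tolerance
# (JAX defaults to float32, NumPy to float64); transpose agrees exactly.
assert np.allclose(np.asarray(jnp.sqrt(j_matrix)), np.sqrt(n_matrix))
assert np.array_equal(np.asarray(j_matrix.transpose()), n_matrix.transpose())

# argmax on a 2-D array indexes the flattened array; convert it to (row, col).
flat_idx = jnp.argmax(j_matrix)
row, col = jnp.unravel_index(flat_idx, j_matrix.shape)
print("argmax position:", int(row), int(col))
```

With the values 17 to 32, both libraries report a maximum of 32 at flat index 15, a minimum of 17 at flat index 0, and a sum of 392. Flat index 15 unravels to row 3, column 3, the bottom-right element of the 4×4 matrix.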
