Solution: Basics of JAX

Here are the solutions to the previous challenges.

Let’s go through each solution in detail.

Solution 1: JAX arrays

We use the arange function to create a JAX array with values ranging from 0 to 66.

import jax.numpy as jnp
array = jnp.arange(67)
print('array=', array)

Let’s review the code:

  • Line 1: We import jax.numpy under its conventional alias, jnp.
  • Line 2: We pass 67 to the arange() function to generate a JAX array from 0 to 66.
  • Line 3: We print the array.
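As a quick sanity check on the range described above, the following minimal sketch inspects the first element, the last element, and the length of the array. The variable names first, last, and count are chosen here for illustration and are not part of the original solution.

```python
import jax.numpy as jnp

# arange(stop) excludes the stop value, so 67 yields the integers 0..66.
array = jnp.arange(67)

first = int(array[0])    # first element of the array
last = int(array[-1])    # last element of the array
count = array.shape[0]   # number of elements in the array

print(first, last, count)  # prints: 0 66 67
```

Because the stop value is exclusive, asking for 67 values starting at 0 ends at 66.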

Solution 2: Random numbers

We know that JAX implements random number generation using a random state. This random state is referred to as a key. Using the same key will always generate the same output. We can split this key and generate different ...
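The key behavior described above can be illustrated with a minimal sketch. The seed value 0, the sample shape (3,), and the variable names are arbitrary assumptions made for this example and are not taken from the challenge itself.

```python
import jax
import jax.numpy as jnp

# Create a key from a seed (0 is an arbitrary choice for illustration).
key = jax.random.PRNGKey(0)

# Reusing the same key reproduces the same samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
same = bool(jnp.array_equal(a, b))

# Splitting the key yields a new subkey that produces different samples.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
differs = not bool(jnp.array_equal(a, c))

print('same key, same output:', same)
print('split key, different output:', differs)
```

Splitting instead of reusing a key is what keeps successive random draws independent while the overall program stays reproducible.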