Solution: Basics of JAX
Here are the solutions to the previous challenges.
Let’s go through each solution in detail.
Solution 1: JAX arrays
We use the arange() function to create a JAX array with values ranging from 0 to 66.
array = jnp.arange(67)  # assumes jnp is jax.numpy (import jax.numpy as jnp)
print('array=', array)
Let’s review the code:
- Line 1: We pass 67 to the arange() function to generate a JAX array from 0 to 66.
- Line 2: We print the array.
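To confirm the range, here is a small self-contained sketch that includes the import and inspects the endpoints. The alias jnp for jax.numpy is assumed from the lesson's environment, and the extra prints are only for illustration.

```python
import jax.numpy as jnp

# arange(67) yields the integers 0, 1, ..., 66; the stop value 67 is excluded.
array = jnp.arange(67)

print('shape =', array.shape)  # one-dimensional, 67 elements
print('first =', array[0])     # smallest value in the range
print('last =', array[-1])     # largest value in the range
```

Because the stop value is exclusive, passing 67 rather than 66 is what makes the array end at 66.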
Solution 2: Random numbers
We know that JAX implements random number generation using a random state. This random state is referred to as a key. Using the same key will always generate the same output. We can split this key and generate different ...
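The key behavior described above can be sketched as follows. This is a minimal illustration that assumes the standard jax.random API; the seed 0, the shape (3,), and the choice of a normal distribution are arbitrary and not taken from the challenge.

```python
import jax

# Create a key from a seed (0 is an arbitrary choice for this sketch).
key = jax.random.PRNGKey(0)

# Reusing the same key reproduces the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print('same key, equal output:', bool((a == b).all()))

# Splitting the key gives a new subkey that produces different numbers.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print('subkey, equal output:', bool((a == c).all()))
```

Splitting before each draw is the usual way to get fresh random values while keeping the whole computation reproducible from a single seed.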