Vectors

Learn how you can perform various vector operations using JAX.

Since we assume you are already fairly familiar with NumPy, this chapter is not a thorough treatment of the topic; it serves only as a quick refresher. We will practice using JAX while seeing how standard NumPy syntax carries over to JAX's version of NumPy.

Note: Throughout our code, np refers to standard NumPy and jnp to JAX's version of NumPy (jax.numpy).
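As a quick illustration of that correspondence, here is a minimal sketch that makes the same calls through np and jnp and compares the results. It assumes a recent JAX release (0.4 or later) where `jax.Array` is the public array type.

```python
import numpy as np
import jax
import jax.numpy as jnp

# The same constructor call works in both libraries
x_np = np.arange(1, 10)
x_jnp = jnp.arange(1, 10)

print(isinstance(x_np, np.ndarray))  # True
print(isinstance(x_jnp, jax.Array))  # True

# Many NumPy functions have a jnp counterpart with the same name and arguments
print(np.sum(x_np), jnp.sum(x_jnp))    # 45 45
print(np.mean(x_np), jnp.mean(x_jnp))  # 5.0 5.0
```

Default dtypes can differ (JAX uses 32-bit types unless 64-bit mode is enabled), so compare values rather than exact array types when checking equivalence.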

Inner product

The inner product of two vectors can be computed with any of three syntaxes, dot(), inner(), or the @ operator, as shown below:

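```python
import jax.numpy as jnp

a = jnp.arange(1, 10)   # [1, 2, ..., 9]
b = jnp.arange(11, 20)  # [11, 12, ..., 19]

# Three equivalent ways to compute the inner product; each line prints 735
print(jnp.dot(a, b))
print(jnp.inner(a, b))
print(a @ b)
print(b @ a)  # the inner product of real vectors is symmetric
```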
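Each of these syntaxes computes the sum of element-wise products, $\sum_i a_i b_i$. The sketch below writes that definition out explicitly, once with jnp and once as a plain Python loop, and checks it against the @ result.

```python
import jax.numpy as jnp

a = jnp.arange(1, 10)   # [1, ..., 9]
b = jnp.arange(11, 20)  # [11, ..., 19]

# Inner product written out: multiply element-wise, then sum
manual = jnp.sum(a * b)
print(manual)           # 735
print(manual == a @ b)  # True

# Equivalent plain-Python loop, for comparison
total = 0
for ai, bi in zip(a.tolist(), b.tolist()):
    total += ai * bi
print(total)  # 735
```

For 1-D vectors, dot(), inner() and @ agree. For higher-dimensional arrays they follow different rules (for example, inner() contracts over the last axis of both arguments), so check the documentation before treating them as interchangeable.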

Linear functions

A linear function can be represented as:

$$f(x) = \alpha x_1 + \beta x_2 + \gamma x_3 + \cdots$$
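Collecting the coefficients into a vector $w = (\alpha, \beta, \gamma, \ldots)$ turns this function into an inner product, $f(x) = w \cdot x$. The sketch below uses hypothetical coefficients $\alpha = 2$, $\beta = -1$, $\gamma = 0.5$ (chosen only for illustration) and checks numerically that $f(cx + dy) = c\,f(x) + d\,f(y)$, the defining property of a linear function.

```python
import jax.numpy as jnp

# Hypothetical coefficients (alpha, beta, gamma) chosen for illustration
w = jnp.array([2.0, -1.0, 0.5])

def f(x):
    """Linear function alpha*x1 + beta*x2 + gamma*x3, written as w @ x."""
    return w @ x

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 0.0, -2.0])
c, d = 3.0, -2.0

print(f(x))  # 2*1 - 1*2 + 0.5*3 = 1.5

# Superposition: f(c*x + d*y) equals c*f(x) + d*f(y)
lhs = f(c * x + d * y)
rhs = c * f(x) + d * f(y)
print(jnp.allclose(lhs, rhs))  # True
```

jnp.allclose is used instead of == because floating-point rounding can make the two sides differ slightly for other inputs.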