Vectors
Explore fundamental vector operations with JAX to strengthen your understanding of linear algebra in deep learning. Learn how to compute inner products, norms, and cosine similarity, and how to apply Taylor approximations using JAX syntax and functions. This lesson helps you refresh key concepts while practicing the efficient numerical computations essential for machine learning.
Since we assume you are fairly familiar with NumPy, this chapter is by no means thorough coverage of the topic; it serves only as a quick refresher. We will practice using JAX while observing the equivalence between standard NumPy syntax and its JAX counterpart.
Note: As a reference, we'll use np and jnp, respectively, for the default and JAX versions of NumPy in our code.
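A minimal sketch of this import convention (only the aliases np and jnp come from the note above; the example calls are our own):

```python
import numpy as np       # default NumPy
import jax.numpy as jnp  # JAX's NumPy-compatible API

# Both expose a nearly identical interface.
print(np.arange(3))   # [0 1 2]
print(jnp.arange(3))  # [0 1 2]
```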
Inner product
The inner product of two vectors can be computed with any of three equivalent syntaxes (dot(), inner(), or the @ operator), as shown below:
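Here is a minimal sketch of all three forms in JAX; the example vectors a and b are our own choice:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# All three forms compute the same scalar inner product.
print(jnp.dot(a, b))    # 32.0
print(jnp.inner(a, b))  # 32.0
print(a @ b)            # 32.0
```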
Linear functions
A linear function of a vector x can be represented as an inner product with a fixed vector a:

$$f(x) = a^\top x = a_1 x_1 + a_2 x_2 + \cdots + a_n x_n$$
We can express it as a lambda function too:
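A short sketch of this, with a and x chosen purely for illustration:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])

# The linear function f(x) = aᵀx as a lambda
f = lambda x: a @ x

x = jnp.array([0.5, -1.0, 2.0])
print(f(x))  # 1*0.5 + 2*(-1.0) + 3*2.0 = 4.5
```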
Taylor approximation
Taylor approximation lies at the core of major optimization algorithms in machine learning. Near a point z, we can approximate the value of a function f to first order as:

$$\hat{f}(x) = f(z) + \nabla f(z)^\top (x - z)$$
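Since this lesson works in JAX, here is a sketch of the approximation using jax.grad to obtain the gradient; the function f and the points z and x are our own illustrative choices:

```python
import jax
import jax.numpy as jnp

# Example function: f(x) = sum of squares (chosen for illustration).
f = lambda x: jnp.sum(x ** 2)

z = jnp.array([1.0, 2.0])  # expansion point
grad_f = jax.grad(f)       # ∇f via automatic differentiation

# First-order Taylor approximation of f near z.
f_hat = lambda x: f(z) + grad_f(z) @ (x - z)

x = jnp.array([1.1, 2.1])
print(f(x))      # exact value: 5.62
print(f_hat(x))  # approximation: 5.6
```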
...