Introduction to Flax and Linen
This lesson offers an overview of Flax, a neural network library built on top of JAX.
While JAX has powerful features, writing deep learning applications directly in it can still be tricky. This isn’t surprising, since JAX is intended to be a general-purpose numerical computing library.
The JAX ecosystem does, however, offer some useful libraries for designing neural networks. We’ll review them in this chapter and then use what we’ve learned to build the project at the end.
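To see why plain JAX can feel tricky, consider a small multilayer perceptron written without any neural network library. In this minimal sketch, the function names (init_dense, init_mlp, mlp) and the layer sizes are hypothetical choices for illustration. Parameters are stored in ordinary dicts and lists, so creating them, splitting random keys, and passing them into every call all have to be done by hand.

```python
import jax
import jax.numpy as jnp

def init_dense(key, in_dim, out_dim):
    # Hypothetical helper: weights and bias are plain arrays in a dict
    # that we must create and track ourselves.
    w = jax.random.normal(key, (in_dim, out_dim)) * 0.01
    b = jnp.zeros((out_dim,))
    return {"w": w, "b": b}

def dense(params, x):
    # A linear layer is just a function of (params, input).
    return x @ params["w"] + params["b"]

def init_mlp(key, sizes):
    # One random key per layer, split manually.
    keys = jax.random.split(key, len(sizes) - 1)
    return [init_dense(k, m, n) for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    # Apply ReLU after every layer except the last.
    for layer in params[:-1]:
        x = jax.nn.relu(dense(layer, x))
    return dense(params[-1], x)

params = init_mlp(jax.random.PRNGKey(0), [4, 8, 2])
y = mlp(params, jnp.ones((3, 4)))
print(y.shape)  # (3, 2)
```

Every additional layer type adds more of this bookkeeping, which is the kind of work a library such as Flax is meant to take over.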
Flax
Flax is a high-performance neural network library that aims to provide flexible model design while letting us keep coding in JAX.
The main packages in Flax are:
- Neural networks
- Utilities
Neural networks
The package flax.linen provides all the neural network classes we need. Because it covers a wide range of functionality, we’ll restrict ourselves to only the most relevant parts here.
Linen philosophy
The Linen designers acknowledge that PyTorch and TensorFlow are already established libraries in the deep learning community, so there is little need to build a clone of either.
Note: This section is more about Linen’s philosophy and can be skipped without any loss of continuity.
Flax Linen instead focuses on the strength of JAX ...