Major Functions in Flax and TensorFlow
Learn about the important functions in Flax and TensorFlow.
Flax and TensorFlow
Flax is a neural network library for JAX. TensorFlow is a deep learning library with a large ecosystem of tools and resources. Flax and TensorFlow share some features (for instance, both can compile computations with XLA) but differ in several important ways.
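As a rough illustration of the shared XLA backend, the sketch below compiles the same small function once with jax.jit and once with tf.function(jit_compile=True). The function names scaled_sin_jax and scaled_sin_tf and the input values are arbitrary choices for this sketch, not part of either library.

```python
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

@jax.jit  # JAX traces the function and compiles it with XLA
def scaled_sin_jax(x):
    return 2.0 * jnp.sin(x)

@tf.function(jit_compile=True)  # asks TensorFlow to compile this function with XLA
def scaled_sin_tf(x):
    return 2.0 * tf.sin(x)

x = np.linspace(0.0, 1.0, 5, dtype=np.float32)
jax_out = np.asarray(scaled_sin_jax(x))
tf_out = scaled_sin_tf(tf.constant(x)).numpy()

# Both compiled functions compute the same values (up to float32 rounding).
print(np.allclose(jax_out, tf_out, atol=1e-6))
```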
Let’s look at the differences between Flax and TensorFlow.
Random number generation
In TensorFlow, we can set a global seed, an operation-level seed, or both. Generating random numbers in TensorFlow is straightforward.
import tensorflow as tf

tf.random.set_seed(6853)
print(tf.random.uniform([1]))
print(tf.random.uniform([1]))
print(tf.random.uniform([1]))
In the code above:
Line 1: We import the TensorFlow library as tf.
Line 3: We call the tf.random.set_seed() method to set the global random seed.
Lines 4–6: We call the tf.random.uniform() method to print different random numbers.
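The seeding rules can be checked directly. The sketch below assumes TensorFlow 2.x running eagerly. It shows that resetting the global seed with tf.random.set_seed() restarts the same sequence, and that an operation-level seed (the seed argument of tf.random.uniform(); the value 42 is an arbitrary choice) combines with the global seed to fix the sequence as well.

```python
import tensorflow as tf

# Global seed only: successive calls return different numbers,
# but resetting the global seed restarts the sequence from the beginning.
tf.random.set_seed(6853)
run_a = [float(tf.random.uniform([1])[0]) for _ in range(3)]
tf.random.set_seed(6853)
run_b = [float(tf.random.uniform([1])[0]) for _ in range(3)]

# Global + operation-level seed: both seeds together determine the sequence.
tf.random.set_seed(6853)
op_a = float(tf.random.uniform([1], seed=42)[0])
tf.random.set_seed(6853)
op_b = float(tf.random.uniform([1], seed=42)[0])

print(run_a == run_b, op_a == op_b)
```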
Flax works differently. Flax is built on top of JAX, and JAX's transformations expect pure functions, meaning functions without side effects. To achieve this, JAX uses stateless pseudo-random number generators (PRNGs). Compare this with NumPy, whose random number generator is stateful: each call updates a hidden global state, so calling it repeatedly returns a different number every time.
import numpy as np

print(np.random.random())
print(np.random.random())
print(np.random.random())
In the code above:
Line 1: We import the NumPy library as np.
Lines 3–5: We call the np.random.random() method to print different random numbers.
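To make the contrast concrete, the sketch below uses NumPy's legacy global-state API (np.random.seed and np.random.random). Re-seeding resets the hidden state and restarts the same sequence, while add_noise(), a hypothetical helper written only for this example, returns different outputs for the same input. That hidden-state side effect is exactly what JAX avoids.

```python
import numpy as np

def add_noise(x):
    # Hypothetical helper: reads and updates NumPy's hidden global RNG state,
    # a side effect, so the same input can give a different output on every call.
    return x + np.random.random()

# Re-seeding resets the hidden state, which restarts the same sequence.
np.random.seed(0)
first = [np.random.random() for _ in range(3)]
np.random.seed(0)
second = [np.random.random() for _ in range(3)]

print(first == second)                   # same seed, same sequence
print(add_noise(1.0) == add_noise(1.0))  # hidden state changed between the calls
```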
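In JAX and Flax, a random function called with the same key always returns the same result. Therefore, we generate random numbers from an explicit random state, called a PRNG key, which we pass to every random function. A key should not be reused; instead, we split it to obtain several new keys, each of which can produce its own pseudo-random numbers.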
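A minimal sketch of this behavior, using jax.random.uniform as the random function: reusing a key reproduces the same number, and a fresh number requires a new key from jax.random.split(). The add_noise() function is a hypothetical helper that shows how randomness enters a pure function only through its key argument.

```python
import jax

key = jax.random.PRNGKey(0)

# Calling a random function twice with the same key gives the same number.
x1 = jax.random.uniform(key)
x2 = jax.random.uniform(key)

# A new number needs a new key: split `key` instead of reusing it.
key, subkey = jax.random.split(key)
x3 = jax.random.uniform(subkey)

def add_noise(key, x):
    # Hypothetical helper: the randomness comes only from the `key` argument,
    # so the output depends only on the inputs (a pure function).
    return x + jax.random.uniform(key)

print(float(x1) == float(x2))
print(float(x3) == float(x1))
```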
import jax

key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, num=3)

print(key1)
print(key2)
print(key3)
In the code above:
Line 1: We import the jax library.
Lines 3–4: We initialize the pseudo-random number generator key by calling the jax.random.PRNGKey() method with the seed value of 0. Next, we call the jax.random.split() method to split the key. We set the argument num=3 to generate three random keys.
Lines 6–8: We print the generated PRNG keys to the console.
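In practice, a key is usually split repeatedly as a program runs. The sketch below shows one common pattern: sample_uniforms() is a hypothetical helper that carries the key forward and consumes a fresh subkey for every draw, so no key is ever used twice. It also shows that split() with num=3 returns the new keys stacked along the first axis.

```python
import jax

def sample_uniforms(key, n):
    """Hypothetical helper: draw n uniform samples without reusing a key.

    Each draw consumes a fresh subkey, and the remaining key is returned
    so the caller can keep splitting it later.
    """
    samples = []
    for _ in range(n):
        key, subkey = jax.random.split(key)  # never reuse a consumed key
        samples.append(float(jax.random.uniform(subkey)))
    return key, samples

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, num=3)  # one row per new key
print(keys.shape[0])

key, samples = sample_uniforms(key, 3)
print(samples)
```

Because the helper returns the unused key, the caller can keep drawing numbers later without ever reusing a consumed key, and the same starting seed always reproduces the same samples.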
Model definition
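Model definition in TensorFlow is made easy by the Keras API. With Keras, we can define networks using the Sequential API, the Functional API, or by subclassing tf.keras.Model. Keras provides many layers for designing various types of networks, such as convolutional layers for CNNs and LSTM layers for recurrent networks.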
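Subclassing, as in the model that follows, is only one of these styles. As a sketch, here is the same two-layer architecture (Dense(4) with ReLU, then Dense(5) with softmax) written with the Sequential and Functional APIs. The input size of 3 features is an assumption made for this example, since the subclassed model does not fix one.

```python
import tensorflow as tf

# Sequential API: a linear stack of layers (input size of 3 is an assumption).
sequential_model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(5, activation="softmax"),
])

# Functional API: wire layers together by calling them on symbolic tensors.
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)

# A batch of two examples with three features each.
batch = tf.ones((2, 3))
print(sequential_model(batch).shape)
print(functional_model(batch).shape)
```

The Sequential API suits plain layer stacks, while the Functional API also handles branching or multiple inputs and outputs.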
import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
In the code above:
Line 1: We import the TensorFlow library as tf.
Lines 3–12: We define the MyModel class, which inherits from tf.keras.Model, to define the neural network architecture. Inside this class:
Lines 5–8: We define the __init__() constructor to add layers to the model. Inside this constructor, we call the super().__init__() method to call the constructor of the parent class. Later, we ...