Major Functions in Flax and TensorFlow

Learn about the important functions in Flax and TensorFlow.

Flax and TensorFlow

Flax is a neural network library for JAX. TensorFlow is a deep learning library with a large ecosystem of tools and resources. The two have things in common (for instance, both can compile computations with XLA), but they differ in how they handle tasks such as random number generation and model definition.

Let’s look at the differences between Flax and TensorFlow.

Random number generation

In TensorFlow, we can set a global seed or an operation-level seed. Once the seed is set, we generate random numbers by calling the random functions directly.

import tensorflow as tf

tf.random.set_seed(6853)
print(tf.random.uniform([1]))
print(tf.random.uniform([1]))
print(tf.random.uniform([1]))

In the code above:

  • Line 1: We import the TensorFlow library as tf.

  • Line 3: We call the tf.random.set_seed() function to set the global random seed.

  • Lines 4–6: We call the tf.random.uniform() function three times and print the result. Each call returns a different random number.
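
As a minimal sketch of what the global seed gives us, the snippet below resets the seed and checks that the same sequence comes back. The helper name draw_three is hypothetical (it is not part of TensorFlow), and the sketch assumes TensorFlow 2.x running in eager mode.

```python
import tensorflow as tf

def draw_three(seed):
    # Hypothetical helper: reset the global seed, then draw three values.
    tf.random.set_seed(seed)
    return [float(tf.random.uniform([1])[0]) for _ in range(3)]

first_run = draw_three(6853)
second_run = draw_three(6853)

# Resetting the global seed restarts the sequence, so the two runs
# are expected to match.
print(first_run == second_run)

# Within a single run, successive calls still return different values,
# because TensorFlow advances hidden state between calls.
print(first_run)
```

The hidden state that advances between calls is exactly what the next part of the lesson contrasts with JAX.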

However, this is not the case in Flax. Flax is built on top of JAX, and JAX expects pure functions, meaning functions without side effects. A random number generator that updates hidden global state is a side effect, so JAX instead uses pseudo-random number generators (PRNGs) whose state is passed in explicitly. NumPy's generator, by contrast, keeps hidden global state, so calling it returns a different number every time.

import numpy as np

print(np.random.random())
print(np.random.random())
print(np.random.random())

In the code above:

  • Line 1: We import the NumPy library as np.

  • Lines 3–5: We call the np.random.random() function three times and print the result. Each call returns a different random number.
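
To make the hidden state visible, here is a small sketch that seeds NumPy's global generator, draws twice, and then reseeds. The variable names a, b, and c are only for illustration.

```python
import numpy as np

np.random.seed(0)        # set the hidden global state
a = np.random.random()
b = np.random.random()   # the state advanced, so b differs from a

np.random.seed(0)        # rewind the hidden state to the same point
c = np.random.random()   # the first draw after reseeding repeats a

print(a != b)
print(a == c)
```

The function takes no state argument, yet its output depends on how many times it has been called since the last seed. That dependence on hidden state is what makes it impure.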

In JAX and Flax, a random function called with the same state returns the same result on every call. Therefore, we generate random numbers from an explicit random state, called a key. A key should not be reused. Instead, we split it to obtain as many new keys as we need.

import jax

key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, num=3)

print(key1)
print(key2)
print(key3)

In the code above:

  • Line 1: We import the jax library.

  • Lines 3–4: We create a pseudo-random number generator key by calling the jax.random.PRNGKey() function with a seed value of 0. Next, we call the jax.random.split() function to split the key. We pass num=3 to obtain three new keys.

  • Lines 6–8: We print the generated PRNG keys to the console.
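
The keys printed above are not random numbers themselves. They are inputs to the sampling functions. As a rough sketch of how they are consumed, the snippet below draws values with jax.random.uniform(). The names demo_key, subkey1, and subkey2 are illustrative, and demo_key is reused on purpose only to show that the same key gives the same result.

```python
import jax

key = jax.random.PRNGKey(0)

# Split once and keep `key` for any later splits.
key, demo_key, subkey1, subkey2 = jax.random.split(key, num=4)

# Reusing a key (done here only for demonstration) repeats the draw.
x1 = jax.random.uniform(demo_key)
x2 = jax.random.uniform(demo_key)

# Distinct subkeys give separate draws.
y1 = jax.random.uniform(subkey1)
y2 = jax.random.uniform(subkey2)

print(float(x1) == float(x2))
print(float(y1), float(y2))
```

In practice, each key is used for one draw only, and the leftover key is split again whenever more randomness is needed.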

Model definition

Model definition in TensorFlow is made easy by the Keras API. We can use Keras to define Sequential or Functional networks, or we can subclass tf.keras.Model, as the example below does. Keras has many layers for designing various types of networks, such as CNNs and LSTMs.

import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

In the code above:

  • Line 1: We import the TensorFlow library as tf.

  • Lines 3–12: We define the MyModel class, which inherits from tf.keras.Model, to define the neural network architecture. Inside this class:

    • Lines 5–8: We define the __init__() constructor to add layers to the model. Inside this constructor function, we call the super().__init__() method to call the constructor of the parent class. Later, we ...