Haiku
This lesson will introduce Haiku, a high-level neural network library that provides an object-oriented interface.
Haiku is a tool
For building neural networks
Think: “Sonnet for JAX”
Introduction
If we check the official documentation of Haiku, we will find the above literal haiku.
Like Keras or Sonnet for TensorFlow, Haiku is a high-level library that provides an object-oriented interface for building neural networks in JAX.
As a starter, let’s use a small MLP.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    mlp = hk.nets.MLP([4, 3, 3])
    return mlp(x)

x = jnp.ones((5, 5))

# The line below will throw an error
y = forward(x)
If we run the code above, it will not execute and will instead throw the following error:
ValueError: All hk.Modules must be initialized inside an hk.transform.
Transforms
If you recall, throughout the course, JAX transformations operate only on pure functions. In contrast, Haiku is object-oriented.
Luckily, Haiku’s simple hk.transform helps us by converting an impure Haiku function into a pair of pure functions: init() and apply().
- Initialize: init() takes a PRNGKey and an input matrix (usually of ones or zeros, since its only purpose is to supply the input shape) and returns the randomly initialized parameters. This is yet another application of JAX’s PRNG.
- Apply: apply() takes the initialized parameters, a PRNGKey, and the input matrix. Since the PRNGKey has little use in most forward passes (dropout being a notable exception), we can simply pass None. The input matrix must be passed again so that apply() can actually run the forward pass on it.
def feedForward(x):
    mlp = hk.nets.MLP([4, 3, 3])
    return mlp(x)

transformedForward = hk.transform(feedForward)

key = jax.random.PRNGKey(0)
X = jnp.ones((5, 5))

initialized_X = transformedForward.init(key, X)
Y = transformedForward.apply(initialized_X, None, X)

# You can check the outputs by uncommenting the lines below
# print("----")
# print(initialized_X)
# print("----X and Y-----")
# print(X)
# print(Y)
We can analyze the neural network’s parameter structure using pytrees.
def forward(x):
    mlp = hk.nets.MLP([40, 30, 30, 12])
    return mlp(x)

forward = hk.transform(forward)

key = jax.random.PRNGKey(0)
X = jnp.ones((5, 5))

initialized_X = forward.init(key, X)
Y = forward.apply(initialized_X, None, X)

print(jax.tree_map(jnp.shape, initialized_X))
Having seen the basic mechanisms of Haiku, we’ll explore its vast function library a bit more.
Haiku has support for almost every neural network module:
Linear
For linear modules, we have these three functions:
- Linear()
- Bias()
- nets.MLP()
Linear layer
The Linear() function adds a linear layer with the following parameters:
- Output size, which is an integer.
- Whether to use bias or not, which is set to True by default.
By default, Linear() uses the same initializer for weights as the one used in batch normalization’s original paper.
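To see these parameters in action, here is a minimal sketch that reuses the transform pattern from above to wrap a single Linear() layer; the output size of 3, the with_bias=False setting, and the linearForward name are arbitrary choices for illustration.

import jax
import jax.numpy as jnp
import haiku as hk

def linearForward(x):
    # A single linear layer: integer output size, no bias term
    layer = hk.Linear(output_size=3, with_bias=False)
    return layer(x)

linearForward = hk.transform(linearForward)

key = jax.random.PRNGKey(42)
X = jnp.ones((5, 5))

params = linearForward.init(key, X)
Y = linearForward.apply(params, None, X)

print(jax.tree_map(jnp.shape, params))  # only a weight matrix, since with_bias=False
print(Y.shape)  # (5, 3)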
Bias
If we wanted to add a bias separately, we would use the Bias() function. All its parameters are optional and are mentioned below as a reference:
- output_size specifies the output size.
- bias_dims is a parameter for the bias vector dimensions.
- b_init specifies the algorithm for bias initialization.
- name allows us to specify a name for the Haiku module.
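To make these parameters concrete, here is a small sketch of a standalone Bias() module; the constant initializer and the bias_demo name are arbitrary choices for illustration.

import jax
import jax.numpy as jnp
import haiku as hk

def biasedForward(x):
    # A standalone bias module with a custom initializer and name
    bias = hk.Bias(b_init=hk.initializers.Constant(1.0), name="bias_demo")
    return bias(x)

biasedForward = hk.transform(biasedForward)

key = jax.random.PRNGKey(0)
X = jnp.zeros((2, 4))

params = biasedForward.init(key, X)
Y = biasedForward.apply(params, None, X)

print(jax.tree_map(jnp.shape, params))  # the bias vector, named after the module
print(Y)  # zeros plus the bias, so every entry is 1.0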
Multi-layer perceptron (MLP)
As we saw earlier, we can skip building individual layers and create an MLP directly with the haiku.nets.MLP() function. Its parameters are:
- output_sizes specifies the output size of each layer as a sequence of integers.