Haiku

This lesson will introduce Haiku, a high-level neural network library that provides an object-oriented interface.

Haiku is a tool

For building neural networks
Think: “Sonnet for JAX”

Introduction

If we check Haiku's official documentation, we will find the literal haiku above.

Like Keras or Sonnet for TensorFlow, Haiku is a high-level library that provides an object-oriented interface for building neural networks in JAX.

As a starter, let's build a small MLP.

import jax.numpy as jnp
import haiku as hk

def forward(x):
    mlp = hk.nets.MLP([4, 3, 3])  # an MLP with output sizes 4, 3, and 3
    return mlp(x)

x = jnp.ones((5, 5))
# The line below will throw an error
y = forward(x)

If we run the code above, it fails with the following error:

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

Transforms

Recall that, throughout the course, JAX transformations have operated only on pure functions. Haiku, in contrast, is object-oriented.

Luckily, Haiku’s simple hk.transform helps us by converting an impure Haiku function into a pair of pure functions: init() and apply().

  • Initialize: init() takes a PRNGKey and a sample input (usually a matrix of 1's or 0's, since only its shape matters) and returns the randomly initialized parameters as a nested Pytree of weights and biases. This is yet another application of JAX PRNG.
  • Apply: apply() takes the initialized parameters, a PRNGKey, and the actual input, and runs the forward pass. Since the PRNGKey has little use in most forward passes (dropout being a notable exception), we can simply pass None for it.

import jax
import jax.numpy as jnp
import haiku as hk

def feedForward(x):
    mlp = hk.nets.MLP([4, 3, 3])
    return mlp(x)

# hk.transform converts the impure function into a pair of pure
# functions: init() and apply()
transformedForward = hk.transform(feedForward)

key = jax.random.PRNGKey(0)
X = jnp.ones((5, 5))

initialized_X = transformedForward.init(key, X)        # initialized parameters
Y = transformedForward.apply(initialized_X, None, X)   # forward pass; no PRNGKey needed

# You can check the outputs by uncommenting the lines below
#print("----")
#print(initialized_X)
#print("----X and Y-----")
#print(X)
#print(Y)

We can analyze the neural network's structure by treating its parameters as a Pytree.

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
    mlp = hk.nets.MLP([40, 30, 30, 12])
    return mlp(x)

forward = hk.transform(forward)
key = jax.random.PRNGKey(0)
X = jnp.ones((5, 5))

initialized_X = forward.init(key, X)
Y = forward.apply(initialized_X, None, X)

# Map jnp.shape over the parameter Pytree to see each layer's shapes
print(jax.tree_map(jnp.shape, initialized_X))
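
The printed Pytree maps each layer's module name (typically mlp/~/linear_0, mlp/~/linear_1, and so on) to the shapes of its w and b parameters, which makes it easy to sanity-check the architecture at a glance.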

Having seen the basic mechanisms of Haiku, we’ll explore its vast function library a bit more.


Haiku has support for almost every neural network module:

Linear

For linear modules, we have these three functions:

  • Linear()
  • Bias()
  • nets.MLP()

Linear layer

The Linear() adds a linear (fully connected) layer with the following parameters:

  • output_size, a single integer giving the number of output units. This is the only required parameter.
  • with_bias, which controls whether a bias is added and is set to True by default.

By default, Linear() uses the same initializer for weights as the one used in batch normalization’s original paper.
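
To see these defaults in action, here is a minimal sketch of a single Linear() layer wrapped in hk.transform (the output size of 3, the key, and the variable names are our own illustrative choices):

import jax
import jax.numpy as jnp
import haiku as hk

# A minimal sketch: one hk.Linear layer wrapped in hk.transform.
def linear_forward(x):
    layer = hk.Linear(output_size=3)  # 3 output units; bias enabled by default
    return layer(x)

linear_transformed = hk.transform(linear_forward)
key = jax.random.PRNGKey(42)
x = jnp.ones((5, 5))

params = linear_transformed.init(key, x)          # weights and bias for the layer
out = linear_transformed.apply(params, None, x)   # deterministic, so PRNGKey is None
print(jax.tree_map(jnp.shape, params))            # expected: {'linear': {'b': (3,), 'w': (5, 3)}}
print(out.shape)                                  # expected: (5, 3)

Note that the weight shape (5, 3) is only determined when init() sees the sample input, since Haiku infers the input size from the data passed to init().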

Bias

If we want to add a bias separately, we can use the Bias() function. All its parameters are optional and are listed below for reference (see the sketch after the list):

  • output_size optionally specifies the expected output size.
  • bias_dims specifies which input dimensions the bias vector covers.
  • b_init specifies the algorithm for bias initialization.
  • name allows us to specify a name for the module, as it does for any Haiku module.
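
As a rough sketch of how Bias() behaves with its defaults (the shapes and names below are illustrative), it adds a learnable bias to its input, so the output keeps the input's shape:

import jax
import jax.numpy as jnp
import haiku as hk

# A minimal sketch: hk.Bias with default arguments adds a learnable bias to x.
def bias_forward(x):
    return hk.Bias()(x)  # all parameters optional; bias shape inferred from x

bias_transformed = hk.transform(bias_forward)
key = jax.random.PRNGKey(0)
x = jnp.ones((5, 5))

params = bias_transformed.init(key, x)
out = bias_transformed.apply(params, None, x)
print(jax.tree_map(jnp.shape, params))  # expected: {'bias': {'b': (5,)}}
print(out.shape)                        # expected: (5, 5)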

Multi-layer perceptron (MLP)

As we saw earlier, we can build an MLP directly with the haiku.nets.MLP() function. Its parameters are:

  • output_sizes specifies the output size of each layer as a sequence of integers.
...