...

Training a model in Flax vs. TensorFlow

Learn about model training in Flax and TensorFlow.

In TensorFlow, we train a model by compiling the network and calling its fit() method. In Flax, by contrast, we create a training state that holds the training information (the model parameters and the optimizer) and then write a training loop that passes data to the network.
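For reference, a minimal sketch of the TensorFlow workflow might look like the following. The layer sizes, vocabulary size, and training arguments here are illustrative placeholders rather than the course's actual configuration, and y_train is assumed to be the integer label array:

import tensorflow as tf

# A small Keras LSTM classifier, analogous in spirit to LSTMModel.
# All shapes and hyperparameters below are illustrative assumptions.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10000, output_dim=64),
    tf.keras.layers.LSTM(64),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Compiling attaches the optimizer, loss, and metrics to the model...
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# ...and fit() runs the whole training loop for us.
model.fit(X_train_padded, y_train, epochs=5, batch_size=32)

The rest of this lesson builds the Flax equivalent, starting with the training state below.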

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng):
    """Creates the initial `TrainState`."""
    model = LSTMModel()
    # Initialize the parameters by tracing a sample input through the model
    params = model.init(rng, jnp.array(X_train_padded[0]))['params']
    # Adam optimizer with learning rate, beta1, beta2, and epsilon
    tx = optax.adam(learning_rate=0.001, b1=0.9, b2=0.999, eps=1e-07)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In the code above:

  • Lines 1–4: We import jax, jax.numpy, optax, and the train_state module from flax.training.

  • Lines 6–13: We define the create_train_state() function, which creates the initial state for model training. This function takes the random number generator key, rng, as its argument. Inside this function:

    • Lines 8–10: We create an instance model of the LSTMModel class and get the initial model parameters params by calling the model's init() method. This method takes the random number generator key and a sample input, X_train_padded[0].

    • Line 12: We define the Adam optimizer with the given hyperparameters: the learning rate, beta1, beta2, and epsilon.

    • Line 13: We create and return the train state by calling the create() method of the train_state.TrainState class. This method takes three arguments: apply_fn, the function that applies the model; params, the model parameters; and tx, the optimizer.
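As a quick usage sketch, we generate a PRNG key (the seed value here is arbitrary) and pass it to the function to obtain the initial state:

rng = jax.random.PRNGKey(0)  # the seed value is arbitrary
state = create_train_state(rng)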

Next, we define a training step that computes the loss and the gradients, uses those gradients to update the model parameters, and returns the new state along with the model metrics.

@jax.jit
def train_step(state, text, labels):
    def loss_fn(params):
        logits = LSTMModel().apply({'params': params}, text)
        # Softmax cross-entropy loss against one-hot encoded labels
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits,
            labels=jax.nn.one_hot(labels, num_classes=2)))
        return loss, logits
    # Compute the loss and its gradients with respect to the parameters
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    # Update the model parameters with the computed gradients
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, labels)
    return state, metrics
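Before walking through this code, note that the call to compute_metrics() refers to a helper defined elsewhere in this course. For reference only, a minimal sketch of what such a helper might look like for this two-class setup is shown below; the exact metrics returned by the course's version may differ:

def compute_metrics(logits, labels):
    # Hypothetical helper: computes the loss and accuracy for a two-class problem
    one_hot = jax.nn.one_hot(labels, num_classes=2)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}

A training epoch then loops over the batches, calling state, metrics = train_step(state, batch_text, batch_labels) on each one.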

We define the train_step() function that performs the model ...