Training models in Flax vs. TensorFlow
Learn about model training in Flax and TensorFlow.
In TensorFlow, we train a model by compiling the network and calling its fit() method. In Flax, by contrast, we create a training state that holds the training information (the parameters and the optimizer state) and then pass data through the network ourselves.
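For contrast, here is a minimal sketch of the compile-and-fit pattern the paragraph refers to. The model and data below are hypothetical placeholders, not part of the lesson's LSTM example:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data, just to make the sketch runnable.
x = np.ones((8, 4), dtype="float32")
y = np.ones((8, 1), dtype="float32")

# In Keras, the optimizer and loss are attached to the model via compile(),
# and the training loop is hidden inside fit().
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss="mse")
history = model.fit(x, y, epochs=5, verbose=0)
```

In Flax there is no equivalent of compile() or fit(); the pieces that Keras bundles into the model object are instead collected explicitly in a training state, as shown next.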
```python
from flax.training import train_state

def create_train_state(rng):
    """Creates initial `TrainState`."""
    model = LSTMModel()
    params = model.init(rng, jnp.array(X_train_padded[0]))['params']
    tx = optax.adam(0.001, 0.9, 0.999, 1e-07)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
```
In the code above:
- Line 1: We import the `train_state` module from `flax.training`.
- Lines 3–8: We define the `create_train_state()` function, which creates the initial state for model training. It takes a random number generator key, `rng`, as its argument. Inside this function:
  - Lines 5–6: We create an instance `model` of the `LSTMModel` class and obtain the initial model parameters, `params`, by calling the `init()` method of the `model`. This method takes the random number generator key and a sample input, `X_train_padded[0]`.
  - Line 7: We define the Adam optimizer with the given hyperparameters: the learning rate, beta1, beta2, and epsilon.
  - Line 8: We create and return the training state by calling the `create()` method of the `train_state.TrainState` class. This method takes three arguments: `apply_fn` is the function that applies the model, `params` are the initial model parameters, and `tx` is the optimizer.