...

Track Model Training in JAX Using TensorBoard

Learn about tracking JAX model training through TensorBoard.

Logging evaluation metrics

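When training machine learning models with JAX, we can log metrics to TensorBoard. The training and evaluation metrics are computed during the training stage, so the end of each epoch is a natural point at which to write them. In the example below, we log the training and test loss and accuracy once per epoch, using the epoch number as the step. The `writer` object is assumed to be a TensorBoard summary writer created beforehand.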

for epoch in range(1, num_epochs + 1):
    # Train for one epoch and record the training metrics.
    state, train_metrics = train_one_epoch(state, train_loader)
    training_loss.append(train_metrics['loss'])
    training_accuracy.append(train_metrics['accuracy'])
    print(f"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}")

    # Evaluate on the test set and record the test metrics.
    test_metrics = evaluate_model(state, test_images, test_labels)
    testing_loss.append(test_metrics['loss'])
    testing_accuracy.append(test_metrics['accuracy'])

    # Log the metrics to TensorBoard, using the epoch as the step.
    writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
    writer.add_scalar('Loss/test', test_metrics['loss'], epoch)
    writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
    writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)
    print(f"Test epoch: {epoch}, loss: {test_metrics['loss']}, accuracy: {test_metrics['accuracy'] * 100}")

# Write any pending events to disk.
writer.flush()
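
The four `add_scalar` calls follow one pattern: a `Metric/split` tag, a value, and the epoch as the step. As a minimal sketch, that pattern could be wrapped in a small helper. The names `log_metrics` and `RecordingWriter` are hypothetical and not part of the lesson. `RecordingWriter` is an in-memory stand-in used only so the example runs without TensorBoard installed. It mimics only the `add_scalar(tag, scalar_value, global_step)` call used in the loop.

```python
class RecordingWriter:
    """Hypothetical stand-in for a summary writer: stores add_scalar calls in memory."""

    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.scalars.append((tag, scalar_value, global_step))


def log_metrics(writer, split, metrics, step):
    """Log every entry of a metrics dict under a '<Metric>/<split>' tag."""
    for name, value in metrics.items():
        # float() turns a 0-d JAX or NumPy array into a plain Python scalar.
        writer.add_scalar(f"{name.capitalize()}/{split}", float(value), step)


# Illustrative values only; in the training loop these would be
# train_metrics and test_metrics for the current epoch.
writer = RecordingWriter()
log_metrics(writer, "train", {"loss": 0.25, "accuracy": 0.9}, step=1)
log_metrics(writer, "test", {"loss": 0.5, "accuracy": 0.75}, step=1)
print(writer.scalars)
```

Inside the training loop, the four `add_scalar` lines would then reduce to `log_metrics(writer, "train", train_metrics, epoch)` and `log_metrics(writer, "test", test_metrics, epoch)`, provided the real writer exposes the same `add_scalar` method. Grouping tags by metric name (`Loss/...`, `Accuracy/...`) places the training and test curves of each metric in the same TensorBoard section.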
...