Track Model Training in JAX Using TensorBoard
Learn how to track JAX model training with TensorBoard.
Logging evaluation metrics
When training machine learning models with JAX, we can log evaluation metrics to TensorBoard. The training loop already computes these metrics for every epoch, so that is the natural point to log them. In the example below, we log the training and test loss and accuracy after each epoch, using the epoch number as the step.
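# Assumes `writer` is a TensorBoard SummaryWriter created earlier, that the
# metric lists start out empty, and that train_one_epoch/evaluate_model
# return dicts with 'loss' and 'accuracy' entries.
for epoch in range(1, num_epochs + 1):
    # Train for one epoch and keep the metrics for later plotting.
    state, train_metrics = train_one_epoch(state, train_loader)
    training_loss.append(train_metrics['loss'])
    training_accuracy.append(train_metrics['accuracy'])
    print(f"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}")

    # Evaluate the updated model on the test set.
    test_metrics = evaluate_model(state, test_images, test_labels)
    testing_loss.append(test_metrics['loss'])
    testing_accuracy.append(test_metrics['accuracy'])

    # Log both splits to TensorBoard, using the epoch as the step.
    writer.add_scalar('Loss/train', train_metrics['loss'], epoch)
    writer.add_scalar('Loss/test', test_metrics['loss'], epoch)
    writer.add_scalar('Accuracy/train', train_metrics['accuracy'], epoch)
    writer.add_scalar('Accuracy/test', test_metrics['accuracy'], epoch)
    print(f"Test epoch: {epoch}, loss: {test_metrics['loss']}, accuracy: {test_metrics['accuracy'] * 100}")

    # Write pending events to disk so TensorBoard picks them up right away.
    writer.flush()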
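The four add_scalar calls in the loop follow one pattern: a "<Metric>/<split>" tag, a scalar value, and the epoch as the step. The sketch below folds that pattern into a helper. It assumes only that `writer` has an add_scalar(tag, scalar_value, global_step) method, as torch.utils.tensorboard.SummaryWriter does. The names log_epoch_metrics and RecordingWriter are hypothetical, and RecordingWriter merely records calls so the example runs without TensorBoard installed; the metric values are made up.

```python
from __future__ import annotations

from typing import Mapping


def log_epoch_metrics(writer, split: str, metrics: Mapping[str, float], epoch: int) -> None:
    """Log each metric under a '<Metric>/<split>' tag, e.g. 'Loss/train'.

    `writer` is duck-typed: anything with add_scalar(tag, value, step) works.
    """
    for name, value in metrics.items():
        # float() turns NumPy scalars or 0-d JAX arrays into plain Python floats.
        tag = f"{name.capitalize()}/{split}"
        writer.add_scalar(tag, float(value), epoch)


class RecordingWriter:
    """Hypothetical stand-in for a SummaryWriter that just records its calls."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        self.calls.append((tag, scalar_value, global_step))


# Usage with made-up metric values for a single epoch.
writer = RecordingWriter()
log_epoch_metrics(writer, "train", {"loss": 0.35, "accuracy": 0.9}, epoch=1)
log_epoch_metrics(writer, "test", {"loss": 0.4, "accuracy": 0.88}, epoch=1)
for call in writer.calls:
    print(call)
```

With a real writer, log_epoch_metrics(writer, "train", train_metrics, epoch) and log_epoch_metrics(writer, "test", test_metrics, epoch) would replace the four separate calls. Because both splits share the Loss and Accuracy prefixes, TensorBoard shows their charts in the same section of the Scalars dashboard.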
...