Transfer Learning using ResNet Model
Learn how to train, evaluate, and visualize the performance of a ResNet model.
We train the ResNet model by applying the train_one_epoch function for the desired number of epochs. Because we are only fine-tuning a pretrained network, a few epochs are usually enough.
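The loop below is a minimal, framework-free sketch of that pattern. In Flax the training state (typically a `flax.training.train_state.TrainState`) is immutable, so each call must receive the state returned by the previous call. Otherwise every epoch restarts from the same weights. `toy_train_one_epoch` and its dictionary state are hypothetical stand-ins for the real `train_one_epoch` and train state, chosen only to make the state threading visible.

```python
from typing import Dict, List, Tuple

State = Dict[str, float]


def toy_train_one_epoch(state: State, batches: List[float]) -> Tuple[State, Dict[str, float]]:
    """Hypothetical stand-in for train_one_epoch: returns a NEW state plus epoch metrics."""
    weight = state["weight"]
    step = state["step"]
    losses = []
    for target in batches:
        # One gradient step on the squared error (weight - target) ** 2
        grad = 2.0 * (weight - target)
        weight = weight - state["lr"] * grad
        step += 1
        losses.append((weight - target) ** 2)
    # Build a fresh state instead of mutating the one passed in
    new_state = {**state, "weight": weight, "step": step}
    return new_state, {"loss": sum(losses) / len(losses)}


def train(state: State, batches: List[float], epochs: int) -> Tuple[State, List[float]]:
    history = []
    for _ in range(epochs):
        # Feed the state returned by the previous epoch into the next one
        state, metrics = toy_train_one_epoch(state, batches)
        history.append(metrics["loss"])
    return state, history


initial = {"weight": 0.0, "lr": 0.1, "step": 0}
final_state, losses = train(initial, batches=[1.0, 1.0, 1.0], epochs=3)
print(final_state["step"], losses)
```

Because the returned state is carried forward, the step counter keeps increasing and the per-epoch loss keeps falling. Passing `initial` to every call instead would reproduce the first epoch's loss each time.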
Set up TensorBoard in Flax
To monitor training in TensorBoard, we write the training and validation metrics to a log directory that TensorBoard can read.
```python
from torch.utils.tensorboard import SummaryWriter

logdir = "flax_logs"
writer = SummaryWriter(logdir)
```
In the code above:
Line 1: We import the SummaryWriter class from torch.utils.tensorboard, which writes event files that TensorBoard can display.
Line 3: We define the logdir variable to hold the path of the logging directory.
Line 4: We create a SummaryWriter instance that writes its TensorBoard logs to logdir.
Train model
We define a function to train and evaluate the model while writing the metrics to TensorBoard.
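The training function relies on `train_one_epoch` and `evaluate_model`, which are not shown in this section; it only assumes that both return dictionaries with `'loss'` and `'accuracy'` entries. As a rough illustration of what those two numbers typically mean for classification, here is a minimal NumPy sketch that computes them from a batch of logits and integer labels using mean softmax cross-entropy and top-1 accuracy. `compute_metrics` is a hypothetical name, and the course's own implementation may differ.

```python
import numpy as np


def compute_metrics(logits: np.ndarray, labels: np.ndarray) -> dict:
    """Hypothetical helper: mean softmax cross-entropy and accuracy for one batch.

    logits: shape (batch, num_classes); labels: integer class ids, shape (batch,).
    """
    # Numerically stable log-softmax: subtract each row's max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy: negative log-probability of the true class, averaged over the batch
    loss = -log_probs[np.arange(len(labels)), labels].mean()
    # Accuracy: fraction of rows whose highest logit is the true class
    accuracy = (logits.argmax(axis=1) == labels).mean()
    return {"loss": float(loss), "accuracy": float(accuracy)}


logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.0, 1.5, 0.0]])
labels = np.array([0, 2, 0])
metrics = compute_metrics(logits, labels)
print(metrics)
```

Accuracy here is a fraction between 0 and 1, which is why the training loop multiplies it by 100 before printing.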
```python
# Take one validation batch and divide pixel values by 255 (assumes 0-255 inputs)
(test_images, test_labels) = next(iter(validation_loader))
test_images = test_images / 255.0

# Metric history, one entry per epoch
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []

def train_model(epochs):
    # Start from the global `state` and carry the updated state from epoch to epoch
    current_state = state
    for epoch in range(1, epochs + 1):
        current_state, train_metrics = train_one_epoch(current_state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])

        test_metrics = evaluate_model(current_state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])

        # Write this epoch's scalars to TensorBoard as plain Python floats
        writer.add_scalar('Loss/train', float(train_metrics['loss']), epoch)
        writer.add_scalar('Loss/test', float(test_metrics['loss']), epoch)
        writer.add_scalar('Accuracy/train', float(train_metrics['accuracy']), epoch)
        writer.add_scalar('Accuracy/test', float(test_metrics['accuracy']), epoch)

        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, "
              f"training accuracy: {train_metrics['accuracy'] * 100}, "
              f"validation loss: {test_metrics['loss']}, "
              f"validation accuracy: {test_metrics['accuracy'] * 100}")
    writer.flush()  # make sure all pending events are written to disk
    return current_state
```
In the code above: ...