Transfer Learning using ResNet Model

Learn how to train, evaluate, and visualize the performance of a ResNet model.

We train the ResNet model by calling the train_one_epoch function once per epoch for the desired number of epochs. Because we are fine-tuning a pretrained network rather than training it from scratch, a few epochs are enough.
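In Flax, the training state (commonly a flax.training.train_state.TrainState) is immutable. Each call to train_one_epoch returns a new state instead of updating the old one in place, so the epoch loop must receive the current state and carry the returned state forward. If a function reassigns state without taking it as a parameter, Python treats state as a local variable, and reading it before that assignment raises UnboundLocalError. The toy sketch below illustrates only this threading pattern and runs without JAX. ToyState, toy_train_one_epoch, toy_evaluate, and toy_train_model are hypothetical stand-ins for the tutorial's real objects, and the "model" is a single weight fitted by gradient descent on a squared-error loss.

```python
from typing import Dict, List, NamedTuple, Tuple


class ToyState(NamedTuple):
    """Hypothetical immutable stand-in for a Flax TrainState."""
    step: int
    weight: float


def toy_train_one_epoch(state: ToyState, batches: List[float]) -> Tuple[ToyState, Dict[str, float]]:
    """Hypothetical epoch: one gradient step per target on the loss (weight - target) ** 2."""
    learning_rate = 0.25
    total_loss = 0.0
    for target in batches:
        error = state.weight - target
        total_loss += error ** 2
        grad = 2 * error  # derivative of (weight - target) ** 2 w.r.t. weight
        # _replace returns a NEW tuple; the state that was passed in is never mutated.
        state = state._replace(step=state.step + 1, weight=state.weight - learning_rate * grad)
    return state, {"loss": total_loss / len(batches)}


def toy_evaluate(state: ToyState, target: float) -> Dict[str, float]:
    """Hypothetical validation step: squared error of the current weight."""
    return {"loss": (state.weight - target) ** 2}


def toy_train_model(state: ToyState, epochs: int, batches: List[float],
                    target: float) -> Tuple[ToyState, Dict[str, List[float]]]:
    """Thread the state through every epoch; return the final state and the metric history."""
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for _ in range(epochs):
        # Reassign state so the next epoch starts from the updated parameters.
        state, train_metrics = toy_train_one_epoch(state, batches)
        history["train_loss"].append(train_metrics["loss"])
        history["val_loss"].append(toy_evaluate(state, target)["loss"])
    return state, history


initial = ToyState(step=0, weight=0.0)
final, history = toy_train_model(initial, epochs=3, batches=[1.0, 1.0], target=1.0)
print(final)    # ToyState(step=6, weight=0.984375)
print(initial)  # ToyState(step=0, weight=0.0)
```

The caller always writes state, history = toy_train_model(state, ...), so the updated state is explicit at every call site and the original initial state is left untouched.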

Set up TensorBoard in Flax

To monitor training in TensorBoard, we write the training and validation metrics to a log directory that TensorBoard reads.

from torch.utils.tensorboard import SummaryWriter

logdir = "flax_logs"
writer = SummaryWriter(logdir)

In the code above:

  • Line 1: We import the SummaryWriter class from torch.utils.tensorboard; it writes event files that TensorBoard can display.

  • Line 3: We define the logdir variable to hold the path of the logging directory.

  • Line 4: We create a SummaryWriter instance that writes its TensorBoard logs to logdir.
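The metrics produced by the Flax training and evaluation steps are typically JAX arrays rather than Python floats, while PyTorch's SummaryWriter.add_scalar is built around Python numbers, NumPy arrays, and torch tensors. A conservative option is to convert each metric with float() before logging it. The sketch below assumes metrics arrive as a dict such as {'loss': ..., 'accuracy': ...}; log_metrics, ScalarWriter, and RecordingWriter are hypothetical names, not part of the original code or of any library.

```python
from typing import Any, Dict, List, Mapping, Protocol, Tuple


class ScalarWriter(Protocol):
    """Anything exposing add_scalar(tag, value, step), such as a SummaryWriter."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None: ...


def log_metrics(writer: ScalarWriter, split: str, metrics: Mapping[str, Any], step: int) -> Dict[str, float]:
    """Hypothetical helper: log each metric under '<Name>/<split>' as a plain float.

    float() accepts Python numbers, NumPy scalars, and single-element JAX arrays,
    so the writer always receives a Python float.
    """
    logged: Dict[str, float] = {}
    for name, value in metrics.items():
        scalar = float(value)
        tag = f"{name.capitalize()}/{split}"  # e.g. 'loss' + 'train' -> 'Loss/train'
        writer.add_scalar(tag, scalar, step)
        logged[tag] = scalar
    return logged


class RecordingWriter:
    """In-memory stand-in for SummaryWriter, handy for checking what gets logged."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, int]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        self.calls.append((tag, scalar_value, global_step))


recorder = RecordingWriter()
log_metrics(recorder, "train", {"loss": 0.25, "accuracy": 1}, step=3)
print(recorder.calls)  # [('Loss/train', 0.25, 3), ('Accuracy/train', 1.0, 3)]
```

With a real SummaryWriter, a call such as log_metrics(writer, "train", train_metrics, epoch) would produce the same Loss/train and Accuracy/train tags that the training loop writes, so TensorBoard groups the train and test curves under Loss and Accuracy.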

Train model

We define a function that, for each epoch, trains the model, evaluates it on a batch of validation data, records the metrics in Python lists, and writes them to TensorBoard.

(test_images, test_labels) = next(iter(validation_loader))  # evaluate on the first validation batch
test_images = test_images / 255.0  # scale pixel values to [0, 1]
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
def train_model(state, epochs):  # state is a parameter: reassigning a global here would raise UnboundLocalError
    for epoch in range(1, epochs + 1):
        state, train_metrics = train_one_epoch(state, train_loader)
        training_loss.append(train_metrics['loss'])
        training_accuracy.append(train_metrics['accuracy'])
        test_metrics = evaluate_model(state, test_images, test_labels)
        testing_loss.append(test_metrics['loss'])
        testing_accuracy.append(test_metrics['accuracy'])
        writer.add_scalar('Loss/train', float(train_metrics['loss']), epoch)  # float() turns JAX scalars into Python floats
        writer.add_scalar('Loss/test', float(test_metrics['loss']), epoch)
        writer.add_scalar('Accuracy/train', float(train_metrics['accuracy']), epoch)
        writer.add_scalar('Accuracy/test', float(test_metrics['accuracy']), epoch)
        print(f"Epoch: {epoch}, training loss: {train_metrics['loss']}, training accuracy: {train_metrics['accuracy'] * 100}, validation loss: {test_metrics['loss']}, validation accuracy: {test_metrics['accuracy'] * 100}")
    return state

In the code above: ...