Batch normalization implementation in PyTorch

Batch normalization

Batch normalization is a technique that normalizes the inputs to each layer of a network so that they have zero mean and unit variance, and then applies a learnable scale and shift. Because the values flowing into every layer end up on a similar scale, the network can learn more easily and converge faster, which is why the technique is widely used in deep learning to improve both the performance and the training speed of neural networks.

In this Answer, we are using the PyTorch library to implement batch normalization. PyTorch is an open-source machine learning library that is primarily used for building and training deep neural networks.
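To make the math concrete, here is a minimal sketch (using PyTorch tensors with made-up values) of what batch normalization computes for a single mini-batch: the per-feature mean and variance, the normalization itself, and the learnable scale (gamma) and shift (beta) that follow it:

import torch

# A toy mini-batch: 4 samples, 3 features (values are made up for illustration)
x = torch.tensor([[1.0, 2.0,  3.0],
                  [2.0, 4.0,  6.0],
                  [3.0, 6.0,  9.0],
                  [4.0, 8.0, 12.0]])

# Per-feature mean and (biased) variance over the batch dimension
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)

# Normalize to zero mean and unit variance (eps avoids division by zero)
eps = 1e-5
x_hat = (x - mean) / torch.sqrt(var + eps)

# Learnable scale (gamma) and shift (beta); they start at 1 and 0,
# so the initial output is just the normalized input
gamma, beta = torch.ones(3), torch.zeros(3)
out = gamma * x_hat + beta

print(out.mean(dim=0))                   # approximately 0 for every feature
print(out.std(dim=0, unbiased=False))    # approximately 1 for every feature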


Batch normalization is applied to mini-batches of data during training, and the parameters learned during training are then used to normalize the data during inference. The technique can also help to reduce overfitting by adding noise to the inputs and acting as a regularizer.
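For example, the following minimal sketch (with an arbitrary feature size and a random batch) shows the two modes of an nn.BatchNorm1d layer: training mode uses the statistics of the current mini-batch and updates the running estimates, while evaluation mode normalizes with those stored running estimates:

import torch
from torch import nn

bn = nn.BatchNorm1d(4)        # a layer that normalizes 4 features

batch = torch.randn(8, 4)     # a random mini-batch of 8 samples

bn.train()                    # training mode: use batch statistics
out_train = bn(batch)         # also updates bn.running_mean / bn.running_var

bn.eval()                     # inference mode: use the running statistics
out_eval = bn(batch)          # no statistics are updated here

print(bn.running_mean, bn.running_var)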

Syntax

The basic syntax for creating a one-dimensional batch normalization layer is given below:

nn.BatchNorm1d(num_features)
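Here, num_features is the only required argument: it is the size of the feature dimension that gets normalized (C for inputs of shape (N, C) or (N, C, L)). A short sketch of the constructor with its optional arguments written out at their default values:

from torch import nn

# Only the feature count is required
bn = nn.BatchNorm1d(num_features=64)

# The same layer with the optional arguments shown at their defaults
bn = nn.BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)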

Implementation with batch normalization

The following code trains a Multilayer Perceptron (MLP) neural network on the CIFAR-10 dataset with batch normalization. The MLP has two hidden layers and one output layer and uses the ReLU activation function.

import os
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import datetime

class MLP(nn.Module):
 
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Flatten(),
      nn.Linear(32 * 32 * 3, 64),
      nn.BatchNorm1d(64),
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.BatchNorm1d(32),
      nn.ReLU(),
      nn.Linear(32, 10)
    )

  def forward(self, y):
    '''Forward pass'''
    return self.layers(y)
  
  
if __name__ == '__main__':
  
  torch.manual_seed(47)
  
  # Load the CIFAR-10 dataset
  data = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
  train_loader = DataLoader(data, batch_size=10, shuffle=True, num_workers=1)
 
  mlp = MLP()
  
  # Define the loss function and optimizer
  loss_function = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
  
  # record start time
  start_time = time.time()  
  
  # Run the training epochs
  for i in range(0, 3): 
    print(f'Starting epoch {i+1}')
    current_loss = 0
    
    # Iterate over the DataLoader for training data
    for j, data in enumerate(train_loader, 0):
      
      # Get inputs
      inputs, targets = data
      optimizer.zero_grad()
      # Perform forward pass
      outputs = mlp(inputs)
      
      # Compute loss
      loss = loss_function(outputs, targets)
      
      # Execution of backward pass
      loss.backward()
      
      # Performing optimization
      optimizer.step()
      
      # Print the statistics
      current_loss += loss.item()
      if j % 500 == 499:
          print('Loss after mini-batch %5d: %.3f' %
                (j + 1, current_loss / 500))
          current_loss = 0

  # record end time
  end_time = time.time()  
  print('Training process has been completed. ')
  training_time = end_time - start_time

  # calculate and print the training time in minutes and seconds format
  print('Training time:', str(datetime.timedelta(seconds=training_time)))
MLP with batch normalization

Note: The loss after mini-batch 5000 of epoch three with batch normalization is 1.642.

Code explanation

  • Lines 1–8: We import the necessary packages for implementing the MLP and loading the CIFAR-10 dataset.

  • Line 10: We define MLP as a subclass of the nn.Module class. It contains three linear layers; the first two are each followed by batch normalization and a ReLU activation.

  • Lines 12–23: First, we use the nn.Flatten layer to flatten the input data (which has shape (batch_size, 3, 32, 32)) into a vector of length 32 x 32 x 3 = 3072. Then, there are three linear layers with 64, 32, and 10 output units, respectively. A batch normalization layer and a ReLU activation function follow each of the first two linear layers (a quick shape check is sketched after this list).

  • Lines 25–27: In the forward method, the input tensor y is passed through the sequence of layers defined in self.layers. The output of the last linear layer is returned as the output of the MLP.

  • Lines 35–42: We load the CIFAR-10 dataset and create a data loader to iterate over the training data in batches of size 10. We also define the loss function and set up the optimizer (Adam with learning rate 1e-4).

  • Lines 48–75: We train our model on the dataset for three epochs. For each epoch, the loop iterates over the training data in batches, performs forward and backward passes through the MLP, updates the model parameters using the optimizer, and prints the current loss after every 500 mini-batches. The loss is accumulated over the mini-batches and divided by 500 before it is printed, giving us the average loss per mini-batch.

  • Lines 78–83: We record the end time with the time.time() function and compute the elapsed training time. To make it more readable, we use the datetime.timedelta() function to format the time into hours, minutes, and seconds.
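As a quick sanity check of the shapes described above (assuming the MLP class defined earlier is in scope), we can pass a dummy CIFAR-10-sized batch through the model:

import torch

mlp = MLP()
dummy = torch.randn(10, 3, 32, 32)   # a batch of 10 RGB 32x32 images

flat = torch.nn.Flatten()(dummy)
print(flat.shape)                    # torch.Size([10, 3072])

logits = mlp(dummy)
print(logits.shape)                  # torch.Size([10, 10]), one score per class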

Implementation without batch normalization

The code below is identical to the one above, except that the batch normalization layers have been removed.

import os
import torch
from torch import nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import datetime

class MLP(nn.Module):
 
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Flatten(),
      nn.Linear(32 * 32 * 3, 64),
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.ReLU(),
      nn.Linear(32, 10)
    )

  def forward(self, y):
    '''Forward pass'''
    return self.layers(y)
  
  
if __name__ == '__main__':
  
  torch.manual_seed(47)
  
  # Load the CIFAR-10 dataset
  data = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
  train_loader = DataLoader(data, batch_size=10, shuffle=True, num_workers=1)
 
  mlp = MLP()
  
  # Define the loss function and optimizer
  loss_function = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
  
  # record start time
  start_time = time.time()  
  
  # Run the training epochs
  for i in range(0, 3): 
    print(f'Starting epoch {i+1}')
    current_loss = 0
    
    # Iterate over the DataLoader for training data
    for j, data in enumerate(train_loader, 0):
      
      # Get inputs
      inputs, targets = data
      optimizer.zero_grad()
      # Perform forward pass
      outputs = mlp(inputs)
      
      # Compute loss
      loss = loss_function(outputs, targets)
      
      # Execution of backward pass
      loss.backward()
      
      # Performing optimization
      optimizer.step()
      
      # Print the statistics
      current_loss += loss.item()
      if j % 500 == 499:
          print('Loss after mini-batch %5d: %.3f' %
                (j + 1, current_loss / 500))
          current_loss = 0

  # record end time
  end_time = time.time()  
  print('Training process has been completed. ')
  training_time = end_time - start_time

  # calculate and print the training time in minutes and seconds format
  print('Training time:', str(datetime.timedelta(seconds=training_time)))
MLP without batch normalization

Note: The loss after mini-batch 5000 of epoch three without using batch normalization is 1.702.

By normalizing the activations of each layer, batch normalization can help to prevent the vanishing and exploding gradient problems that can occur in deep neural networks, leading to faster convergence and better performance.
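One rough way to observe this in the training loops above is to print per-parameter gradient norms right after loss.backward(); the sketch below assumes mlp exists and a backward pass has just been run, as in the code above:

# Inspect gradient magnitudes after loss.backward()
for name, param in mlp.named_parameters():
  if param.grad is not None:
    print(f'{name}: gradient norm = {param.grad.norm().item():.4f}')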

Conclusion

We can see that adding batch normalization has a clear positive impact on the model’s training, reducing both the training loss (1.642 vs. 1.702 after the same number of mini-batches) and the training time. In essence, incorporating batch normalization improves the overall efficiency and effectiveness of our model.
