...

Dealing with Imbalanced Datasets in Python Programming

Learn about the fundamentals of imbalanced datasets and explore how to use SMOTE to effectively handle imbalanced datasets.

In this lesson, we will work with an imbalanced subset of the MNIST dataset that contains only the digits 0 and 1. We will examine how having more examples of one class than the other affects a model’s performance. Then, we will apply SMOTE (Synthetic Minority Oversampling Technique) to balance the dataset and train a CNN on both the imbalanced and the balanced versions. Finally, we will compare the two models using accuracy, F1 score, precision, and recall.

This lesson is divided into the following three steps:

  • Step 1: Using a bar chart, we will visualize how many images of the numbers 0 and 1 are in the MNIST dataset.

  • Step 2: We will apply SMOTE to balance the imbalanced dataset and create a bar chart to show the updated distribution.

  • Step 3: We will train a CNN model on the imbalanced dataset and on the balanced dataset, and then measure the performance of each using accuracy, F1 score, precision, and recall.
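To see why the lesson reports F1 score, precision, and recall alongside accuracy, consider the following minimal sketch. It uses made-up labels (90 samples of class 1 and 10 of class 0) and a hypothetical "model" that always predicts the majority class. The `binary_metrics` helper is written for this illustration only and is not part of the lesson's code.

```python
def binary_metrics(y_true, y_pred, positive):
    """Compute accuracy, precision, recall, and F1 for one class of interest."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = len(pairs) - tp - fp - fn

    accuracy = (tp + tn) / len(pairs)
    # Guard against division by zero when the class is never predicted or never present
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Hypothetical imbalanced labels: class 1 is the majority, class 0 the minority
y_true = [1] * 90 + [0] * 10
# A degenerate model that ignores its input and always predicts the majority class
y_pred = [1] * 100

# Score the minority class (0), which is the one an imbalanced model tends to miss
metrics = binary_metrics(y_true, y_pred, positive=0)
print(metrics)
```

In this toy case, accuracy looks high even though the model never identifies a single minority sample, while recall and F1 for the minority class expose the problem.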

Step 1: Visualizing the MNIST dataset (digits 0 and 1)

The code provided below generates a bar chart that displays the imbalanced distribution of the digits 0 and 1 in the dataset.

Run the code to see how the samples of the digits 0 and 1 are distributed in the imbalanced dataset.

# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Keep only the samples labeled 0 or 1
index = np.where((y_train == 0) | (y_train == 1))
x_filtered = x_train[index]
y_filtered = y_train[index]
# Delete the first 3000 samples of digit 0 to create an imbalanced dataset
samples_to_delete = 3000
index_to_delete = np.where(y_filtered == 0)[0][:samples_to_delete]
x_filtered = np.delete(x_filtered, index_to_delete, axis=0)
y_filtered = np.delete(y_filtered, index_to_delete, axis=0)
# Show the distribution before balancing
plt.bar([0, 1], [np.sum(y_filtered == 0), np.sum(y_filtered == 1)])
plt.xlabel('Digits')
plt.ylabel('Count')
plt.title('Digit Distribution (0 and 1) Before Balancing')
plt.show()
  • Lines 1–4: We import the necessary libraries: NumPy, Matplotlib, and the MNIST loader from Keras.

  • Lines 5–6: We load the MNIST dataset, which is split into training and testing sets, each with its images and labels.

  • Lines 7–10: We filter the training set to retain only the samples of the digits 0 and 1. The x_filtered variable holds the images, and the y_filtered variable holds their labels.

  • Lines 11–15: We remove the first 3000 samples of the digit 0 from the filtered dataset, which makes it imbalanced. The result is stored back in the x_filtered array (images) and the y_filtered array (labels).

  • Lines 16–21: We create a bar chart that shows the number of samples of the digits 0 and 1 before balancing.
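The bar chart shows the imbalance visually, and the same information can also be expressed as a number. The sketch below repeats the deletion pattern from the code above on a small, made-up label array (a stand-in for y_filtered, so that it runs without downloading MNIST) and then computes the per-class counts and the majority-to-minority ratio. The sample sizes and the `imbalance_ratio` name are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical stand-in for y_filtered: 8 zeros and 10 ones in shuffled order
rng = np.random.default_rng(0)
labels = rng.permutation(np.array([0] * 8 + [1] * 10))

# Same pattern as the lesson: drop the first k samples of digit 0
samples_to_delete = 5
index_to_delete = np.where(labels == 0)[0][:samples_to_delete]
labels = np.delete(labels, index_to_delete, axis=0)

# Count the remaining samples per class
classes, counts = np.unique(labels, return_counts=True)

# Ratio of the largest class to the smallest; 1.0 would mean perfectly balanced
imbalance_ratio = counts.max() / counts.min()

for cls, count in zip(classes, counts):
    print(f"digit {cls}: {count} samples")
print(f"imbalance ratio: {imbalance_ratio:.2f}")
```

Replacing `labels` with the lesson's y_filtered array should report the counts that the bar chart displays.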

Step 2: Using SMOTE

SMOTE works by selecting a point from the ...