...
Dealing with Imbalanced Datasets in Python Programming
Learn about the fundamentals of imbalanced datasets and explore how to use SMOTE to effectively handle imbalanced datasets.
In this lesson, we will handle an imbalanced dataset built from MNIST, focusing on the digits 0 and 1. We will examine what happens when one class has many more examples than another and how this affects a model's performance. Then, we will apply SMOTE (Synthetic Minority Over-sampling Technique) to balance the dataset. Next, we will train a CNN on both the imbalanced and the balanced datasets. Finally, we will compare the two models using accuracy, F1 score, precision, and recall.
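To make the idea behind SMOTE concrete, here is a minimal from-scratch sketch of its core interpolation step. It assumes the minority-class samples are already flattened into a 2-D NumPy array. For each synthetic point, it picks a random minority sample, picks one of that sample's k nearest minority neighbors, and places a new point at a random position on the line segment between them. The function `smote_sketch` and its parameters are hypothetical and this is not the imbalanced-learn implementation. In practice, that library is typically used through `SMOTE().fit_resample(X, y)`.

```python
import numpy as np

def smote_sketch(x_minority, n_synthetic, k=5, seed=0):
    """Create n_synthetic new minority samples by SMOTE-style interpolation.

    x_minority: array of shape (n_samples, n_features) holding only the
    minority class. Hypothetical helper for illustration, not a library API.
    """
    x = np.asarray(x_minority, dtype=float)
    n = len(x)
    if n < 2:
        raise ValueError("need at least two minority samples")
    k = min(k, n - 1)
    rng = np.random.default_rng(seed)

    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_norms = np.sum(x ** 2, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2 * x @ x.T
    np.fill_diagonal(d2, np.inf)  # a sample is not its own neighbor
    # Indices of the k nearest minority neighbors of every sample
    neighbors = np.argsort(d2, axis=1)[:, :k]

    synthetic = np.empty((n_synthetic, x.shape[1]))
    for i in range(n_synthetic):
        base = rng.integers(n)                 # random minority sample
        nb = neighbors[base, rng.integers(k)]  # one of its k nearest neighbors
        gap = rng.random()                     # random position in [0, 1)
        synthetic[i] = x[base] + gap * (x[nb] - x[base])
    return synthetic

# Four minority points at the corners of the unit square
x_min = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
new_points = smote_sketch(x_min, n_synthetic=6, k=2)
print(new_points.shape)  # (6, 2)
```

For MNIST images, each 28×28 image would first be flattened, for example with `x.reshape(len(x), -1)`, and the synthetic rows reshaped back afterwards. The full pairwise distance matrix is manageable for a few thousand samples but grows quadratically with the size of the minority class.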
This lesson is divided into the following three steps:
Step 1: Using a bar chart, we will visualize how many images of the digits 0 and 1 are in our imbalanced subset of the MNIST dataset.
Step 2: We will apply SMOTE to balance the dataset and create a bar chart showing the updated distribution.
Step 3: We will train a CNN model on the imbalanced and the balanced datasets and measure each model's performance using accuracy, F1 score, precision, and recall.
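Step 3 compares several metrics because accuracy alone can hide poor performance on the minority class. The sketch below computes the four metrics from confusion-matrix counts with plain NumPy. It treats digit 0 (the minority class) as the positive class, which is an assumption since the lesson does not fix one, and scores a trivial model that always predicts digit 1 on labels with the class counts of the imbalanced subset (2,923 zeros and 6,742 ones, since the MNIST training set has 5,923 zeros and 3000 are removed). The helper `binary_metrics` is hypothetical; scikit-learn's `accuracy_score`, `precision_score`, `recall_score`, and `f1_score` compute the same quantities.

```python
import numpy as np

def binary_metrics(y_true, y_pred, positive=0):
    """Return accuracy, precision, recall, and F1 with `positive` as the positive class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == positive) & (y_true == positive))  # true positives
    fp = np.sum((y_pred == positive) & (y_true != positive))  # false positives
    fn = np.sum((y_pred != positive) & (y_true == positive))  # false negatives
    accuracy = np.mean(y_pred == y_true)
    # Define 0/0 as 0, a common convention when a class is never predicted
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return float(accuracy), float(precision), float(recall), float(f1)

# Labels with the same class counts as the imbalanced subset:
# 2,923 images of digit 0 (minority) and 6,742 of digit 1 (majority)
y_true = np.array([0] * 2923 + [1] * 6742)
# A trivial "model" that always predicts the majority digit 1
y_pred = np.ones_like(y_true)

acc, prec, rec, f1 = binary_metrics(y_true, y_pred, positive=0)
print(f"accuracy={acc:.3f} precision={prec:.3f} recall={rec:.3f} f1={f1:.3f}")
# accuracy=0.698 precision=0.000 recall=0.000 f1=0.000
```

This model reaches almost 70% accuracy without ever recognizing a single 0, which is why precision, recall, and F1 on the minority class are the more informative numbers when the classes are imbalanced.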
Step 1: Visualizing the MNIST dataset (digits 0 and 1)
The code below keeps only the digits 0 and 1 from the MNIST training set, removes 3000 images of the digit 0 to make the data imbalanced, and plots the resulting class distribution as a bar chart.
```python
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Filter out digits 0 and 1 only: np.where returns the indices
# of every training label that equals 0 or 1
index = np.where((y_train == 0) | (y_train == 1))
x_filtered = x_train[index]
y_filtered = y_train[index]

# Delete 3000 samples of digit 0 to make the dataset imbalanced
samples_to_delete = 3000
index_to_delete = np.where(y_filtered == 0)[0][:samples_to_delete]
x_filtered = np.delete(x_filtered, index_to_delete, axis=0)
y_filtered = np.delete(y_filtered, index_to_delete, axis=0)

# Show the distribution before balancing
plt.bar([0, 1], [np.sum(y_filtered == 0), np.sum(y_filtered == 1)])
plt.xlabel('Digits')
plt.ylabel('Count')
plt.title('Digit Distribution (0 and 1) Before Balancing')
plt.show()
```
Lines 1–5: We import the necessary libraries: `numpy`, `matplotlib`, and the `mnist` dataset loader from Keras.
Line 7: We load the MNIST dataset into training and test sets along with their respective labels.
Lines 11–13: We filter the training set to retain only the samples of the digits 0 and 1. The `x_filtered` array holds the corresponding images, and the `y_filtered` array holds their labels.
Lines 15–19: We remove the first 3000 samples of the digit 0 from the filtered dataset, which creates the imbalance, and store the result back in `x_filtered` (images) and `y_filtered` (labels). Because the MNIST training set contains 5,923 zeros and 6,742 ones, this leaves 2,923 zeros against 6,742 ones.
Lines 22–26: We create a bar chart showing the number of samples of each digit before balancing.
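The same filter-and-delete logic can be wrapped in a small function and checked on synthetic data without downloading MNIST. In this sketch, `make_imbalanced` is a hypothetical helper, `np.isin` replaces the `|` mask (both select the same samples), and the tiny label array is made up purely for illustration.

```python
import numpy as np

def make_imbalanced(x, y, keep=(0, 1), drop_label=0, n_drop=3000):
    """Keep only labels in `keep`, then delete the first n_drop samples of drop_label.

    Hypothetical helper mirroring the Step 1 code.
    """
    mask = np.isin(y, keep)  # same selection as (y == 0) | (y == 1)
    x_sub, y_sub = x[mask], y[mask]
    drop_idx = np.where(y_sub == drop_label)[0][:n_drop]
    return np.delete(x_sub, drop_idx, axis=0), np.delete(y_sub, drop_idx, axis=0)

# Made-up stand-in for MNIST: 6 zeros, 8 ones, 5 twos; each "image" is 2x2
y = np.array([0] * 6 + [1] * 8 + [2] * 5)
x = np.arange(len(y) * 4).reshape(len(y), 2, 2)

x_imb, y_imb = make_imbalanced(x, y, n_drop=4)
counts = np.bincount(y_imb)           # samples per digit
print(counts)                         # [2 8]
print(counts.max() / counts.min())    # 4.0 (imbalance ratio)
```

Applied to the real MNIST training set with the default `n_drop=3000`, the same logic leaves 2,923 zeros and 6,742 ones, an imbalance ratio of roughly 2.3 to 1.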
Step 2: Using SMOTE
SMOTE works by selecting a point from the ...