Imbalanced Dataset

Learn how to counter imbalanced data when using PyTorch's binary cross-entropy loss.

Introduction to the imbalanced dataset

In our dummy example with two data points, we had one of each class, positive and negative, so the dataset was perfectly balanced. Let us create another dummy example with an imbalance by adding two extra data points to the negative class. For the sake of simplicity, and to illustrate a quirk in the behavior of BCEWithLogitsLoss, we will give those two extra points the same logit as the existing negative data point. It looks like this:

import math
import torch

# Helper from earlier in the course, restated so this block runs standalone:
# it converts a probability into a logit, log(p / (1 - p))
def log_odds_ratio(prob):
    return math.log(prob / (1 - prob))

logit1 = log_odds_ratio(.9)   # ~2.1972
logit2 = log_odds_ratio(.2)   # ~-1.3863
dummy_imb_labels = torch.tensor([1.0, 0.0, 0.0, 0.0])
dummy_imb_logits = torch.tensor([logit1, logit2, logit2, logit2])
print(dummy_imb_labels, dummy_imb_logits)

Clearly, this is an imbalanced dataset. There are three times more data points in the negative class than in the positive one.
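Before compensating for anything, it helps to have a baseline: the loss on this imbalanced dataset with a plain, unweighted BCEWithLogitsLoss. The sketch below is an illustration rather than the lesson's verbatim code (the variable names are our own), and it reuses the dummy_imb_logits and dummy_imb_labels tensors defined above.

import torch
import torch.nn as nn

# Unweighted baseline on the imbalanced data (continues from the block above)
loss_fn = nn.BCEWithLogitsLoss()
baseline_loss = loss_fn(dummy_imb_logits, dummy_imb_labels)
print(baseline_loss)  # tensor(0.1937): mean of -log(.9) and three copies of -log(.8)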

The pos_weight argument

Now, let us turn to the pos_weight argument of BCEWithLogitsLoss. To compensate for the imbalance, one can set this weight equal to the ratio of negative to positive examples:

$$pos\_weight = \dfrac{\#\,\text{points in negative class}}{\#\,\text{points in positive class}}$$

In our imbalanced dummy example, the result would be 3.0. This way, every point in the positive class would have its corresponding loss multiplied by three. Since there is a single label for each data point (c = 1), ...
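Putting the pieces together, here is a sketch of how this might look in code. It again reuses the dummy tensors from above; computing pos_weight directly from the labels, and the variable names, are our own choices rather than the lesson's verbatim code.

import torch
import torch.nn as nn

# Ratio of negative (3) to positive (1) examples in the dummy data
n_neg = (dummy_imb_labels == 0).sum().float()
n_pos = (dummy_imb_labels == 1).sum().float()
pos_weight = n_neg / n_pos  # tensor(3.)

loss_fn_imb = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
weighted_loss = loss_fn_imb(dummy_imb_logits, dummy_imb_labels)
print(weighted_loss)  # tensor(0.2464)

The printed value works out to (3 · 0.1054 + 3 · 0.2231) / 4 ≈ 0.2464: the single positive term (0.1054) is tripled by pos_weight, the three negative terms contribute 0.2231 each, and, with the default reduction='mean', PyTorch divides by the number of data points (four) rather than by the sum of the weights (six, which would give ≈ 0.1643). This is presumably the quirk that giving every negative point the same logit makes easy to spot.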
