Imbalanced Dataset
Learn how to counter imbalanced data when using PyTorch's binary cross-entropy loss.
Introduction to the imbalanced dataset
In our dummy example with two data points, we had one of each class: positive and negative. The dataset was perfectly balanced. Let us create another dummy example with an imbalance by adding two extra data points belonging to the negative class. For the sake of simplicity, and to illustrate a quirk in the behavior of BCEWithLogitsLoss, we will give those two extra points the same logit as the existing negative data point. It looks like this:
import numpy as np
import torch

# Assumed helper from earlier in the course: the log odds ratio, log(p / (1 - p))
def log_odds_ratio(prob):
    return np.log(prob / (1 - prob))

logit1 = log_odds_ratio(.9)
logit2 = log_odds_ratio(.2)

# One positive point (p = .9) and three negative points sharing the same logit (p = .2)
dummy_imb_labels = torch.tensor([1.0, 0.0, 0.0, 0.0])
dummy_imb_logits = torch.tensor([logit1, logit2, logit2, logit2])

print(dummy_imb_labels, dummy_imb_logits)
Clearly, this is an imbalanced dataset: there are three times as many data points in the negative class as in the positive one.
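To double-check the imbalance programmatically, here is a quick sketch that counts the points in each class. It reuses the dummy_imb_labels tensor created above:

# Reusing dummy_imb_labels from the snippet above
n_pos = (dummy_imb_labels == 1.0).sum()  # tensor(1)
n_neg = (dummy_imb_labels == 0.0).sum()  # tensor(3)
print(n_pos, n_neg)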
The pos_weight argument
Now, let us turn to the pos_weight argument of BCEWithLogitsLoss. To compensate for the imbalance, one can set this weight equal to the ratio of negative to positive examples:

pos_weight = (# of points in the negative class) / (# of points in the positive class)
In our imbalanced dummy example, the result would be 3.0. This way, every point in the positive class would have its corresponding loss multiplied by three. Since there is a single label for each data point (c = 1), ...
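To see pos_weight in action, here is a minimal sketch. It reuses the dummy tensors created above; the loss_fn_imb name is ours, and reduction='mean' is PyTorch's default, written out here for clarity:

import torch
import torch.nn as nn

# Ratio of negative to positive points: 3 / 1 = 3.0
# pos_weight expects one weight per label dimension (c = 1 here)
pos_weight = torch.tensor([3.0])

loss_fn_imb = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight)

# Reusing dummy_imb_logits and dummy_imb_labels from the snippet above
loss = loss_fn_imb(dummy_imb_logits, dummy_imb_labels)
print(loss)

Note that, even with pos_weight set, the 'mean' reduction divides the weighted sum of the individual losses by the total number of data points (four), not by the sum of the weights (six), so the result is not a true weighted average.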