Refining Neural Networks with Class Weights

Discover how adjusting class weights in neural network models can address the challenge of predicting rare events.


Loss function

A rare event problem has very few positively labeled samples. As a result, even if the classifier misclassifies the positive labels, their effect on the loss function is minuscule (the short numeric sketch after the definitions below illustrates this).

$$\mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^{n}\big( y_{i}\log(p_{i}) + (1-y_{i})\log(1-p_{i}) \big)$$

where

  • $y_{i} \in \{0, 1\}$ are the true labels.
  • $p_{i} = \Pr[y_{i} = 1]$ is the predicted probability that $y_{i} = 1$.
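
To make this concrete, here is a minimal numeric sketch (with hypothetical numbers: 1,000 samples and only 10 positives, not taken from the lesson) of how little the unweighted loss reacts when every positive is missed:

```python
import numpy as np

# Hypothetical rare-event dataset: 1,000 samples, only 10 positives.
y_true = np.zeros(1000)
y_true[:10] = 1

# Suppose the model simply predicts a low probability for every sample,
# thereby misclassifying all 10 positives.
p = np.full(1000, 0.01)

eps = 1e-12  # numerical guard against log(0)
loss = -np.mean(y_true * np.log(p + eps) + (1 - y_true) * np.log(1 - p + eps))
print(loss)  # ~0.056: the loss stays small even though every positive is missed
```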

The loss function in the equation above gives equal importance to the positive and negative samples. We can overweight the positives and underweight the negatives to overcome this. The weighted binary cross-entropy loss function then becomes (a short implementation sketch follows the definitions below):

$$\mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^{n}\big( w_{1}\,y_{i}\log(p_{i}) + w_{0}\,(1-y_{i})\log(1-p_{i}) \big)$$

where

  • $w_{1} > w_{0}$ are the class weights.
  • $y_{i} \in \{0, 1\}$ are the true labels.
  • $p_{i} = \Pr[y_{i} = 1]$ is the predicted probability that $y_{i} = 1$.
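
Here is a minimal NumPy sketch of this weighted loss; the function name `weighted_bce` and the toy inputs are illustrative assumptions, not part of the original lesson:

```python
import numpy as np

def weighted_bce(y_true, p, w1, w0, eps=1e-12):
    """Weighted binary cross-entropy matching the equation above."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.asarray(p, dtype=float)
    return -np.mean(
        w1 * y_true * np.log(p + eps) + w0 * (1 - y_true) * np.log(1 - p + eps)
    )

# With w1 = 0.99 and w0 = 0.01, a confidently misclassified positive costs
# about 99x more than an equally misclassified negative.
print(weighted_bce([1], [0.1], w1=0.99, w0=0.01))  # ~2.28
print(weighted_bce([0], [0.9], w1=0.99, w0=0.01))  # ~0.023
```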

Class-weighting approach

The class-weighting approach works as follows:

  • The model estimation objective is to minimize the loss. In a perfect case, if the model could predict all the labels perfectly, that is, $p_i = 1$ when $y_i = 1$ and $p_i = 0$ when $y_i = 0$, the loss would be zero. Therefore, the best model estimate is the one with the loss closest to zero.

  • With the class weights $w_1 > w_0$, misclassifying the positive samples, that is, $p_i \to 0$ when $y_i = 1$, pushes the loss farther from zero than misclassifying the negative samples does. In other words, model training penalizes the misclassification of positives more than that of negatives.

  • Therefore, the model estimation strives to classify the minority positive samples correctly, as the training sketch after this list shows.
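
In practice, we usually don't have to write the weighted loss by hand. As a minimal sketch (not code from the lesson), Keras accepts per-class weights through the `class_weight` argument of `model.fit`; `X_train` and `y_train` are assumed placeholders for your features and 0/1 labels, and `w0` and `w1`, with `w1 > w0`, can come from the rule of thumb below:

```python
import tensorflow as tf

# Minimal sketch of class weighting during training. X_train and y_train
# are assumed placeholders for your features and 0/1 labels; w0 and w1
# (with w1 > w0) can come from the rule of thumb in the next section.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Keras scales each sample's loss term by the weight of its class,
# which is equivalent to the weighted cross-entropy defined above.
model.fit(X_train, y_train, epochs=10, class_weight={0: w0, 1: w1})
```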

In principle, any arbitrary weights with $w_1 > w_0$ can be used, but a common rule of thumb (computed in the sketch after this list) is:

  • $w_1$:

    $$\text{positive class weight} = \frac{\text{number of negative samples}}{\text{total samples}}$$

  • $w_0$:

    $$\text{negative class weight} = \frac{\text{number of positive samples}}{\text{total samples}}$$
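
A small sketch of this rule of thumb, assuming `y_train` is a 0/1 NumPy array of labels (the example counts are hypothetical):

```python
import numpy as np

# Assumed example labels: 1,000 samples with 10 positives.
y_train = np.zeros(1000)
y_train[:10] = 1

n_total = len(y_train)
n_pos = int(y_train.sum())
n_neg = n_total - n_pos

w1 = n_neg / n_total  # positive class weight (0.99 here)
w0 = n_pos / n_total  # negative class weight (0.01 here)
print(w1, w0)
```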
