
Model Configuration, Training, and Predictions for Classification

Get to know the changes to model configuration and training for classification problems, along with the steps for making predictions.

Model configuration

In the chapter Going Classy, we ended up with a lean model configuration part: we only need to define a model, an appropriate loss function, and an optimizer. Let's define a model that produces logits and use BCEWithLogitsLoss as the loss function. Since we have two features and we are producing logits instead of probabilities, our model has one layer and one layer alone: Linear(2, 1). We will keep using the SGD optimizer with a learning rate of 0.1 for now.

This is what the model configuration looks like for our classification problem:

import torch
import torch.nn as nn
import torch.optim as optim

# Sets learning rate - this is "eta", the Greek letter η that looks like an "n"
lr = 0.1
torch.manual_seed(42)
# A single linear layer: two features in, one logit out
model = nn.Sequential()
model.add_module('linear', nn.Linear(2, 1))
# Defines an SGD optimizer to update the parameters
optimizer = optim.SGD(model.parameters(), lr=lr)
# Defines a BCE with logits loss function
loss_fn = nn.BCEWithLogitsLoss()
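BCEWithLogitsLoss expects raw logits and applies the sigmoid internally; the PyTorch documentation notes that this is more numerically stable than a plain sigmoid followed by BCELoss. As a quick sanity check, the sketch below uses made-up logits and labels (purely hypothetical values, not this lesson's data) and computes the same binary cross-entropy three ways: BCEWithLogitsLoss on logits, BCELoss on sigmoid probabilities, and the formula written out by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
# Hypothetical logits and binary labels, for illustration only
logits = torch.randn(8, 1)
labels = torch.randint(0, 2, (8, 1)).float()

# Path 1: raw logits go straight into BCEWithLogitsLoss
loss_logits = nn.BCEWithLogitsLoss()(logits, labels)

# Path 2: convert logits to probabilities first, then use BCELoss
probs = torch.sigmoid(logits)
loss_probs = nn.BCELoss()(probs, labels)

# Path 3: binary cross-entropy written out explicitly (mean over samples)
loss_manual = -(labels * torch.log(probs)
                + (1 - labels) * torch.log(1 - probs)).mean()

print(loss_logits.item(), loss_probs.item(), loss_manual.item())
```

All three values agree up to floating-point rounding, which is why the model itself can stop at the linear layer and leave the sigmoid to the loss function.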

Model training

Time to train our model! We can leverage the StepByStep class we built in the chapter Going Classy and use pretty much the same code as before:

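n_epochs = 100
# Creates a StepByStep instance with the model, loss function, and optimizer
sbs = StepByStep(model, loss_fn, optimizer)
# Assigns the training and validation data loaders
sbs.set_loaders(train_loader, val_loader)
# Trains the model for the given number of epochs
sbs.train(n_epochs)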
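The StepByStep class is not reproduced in this lesson. For readers who want a rough picture of what a call like sbs.train(n_epochs) amounts to, here is a hypothetical, stripped-down stand-in: a plain mini-batch loop over a synthetic, linearly separable dataset (made up here; it is not the dataset used in this lesson) that records the average training and validation loss per epoch. It is a sketch under these assumptions, not the StepByStep implementation itself.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(42)

# Hypothetical two-feature dataset: label is 1 when x1 - x2 > 0
x = torch.randn(200, 2)
y = (x[:, 0] - x[:, 1] > 0).float().unsqueeze(1)
train_loader = DataLoader(TensorDataset(x[:160], y[:160]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x[160:], y[160:]), batch_size=16)

# Same configuration as in this lesson
model = nn.Sequential()
model.add_module('linear', nn.Linear(2, 1))
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()

def run_epoch(loader, train):
    # One pass over the loader; returns the average mini-batch loss
    model.train(train)  # train=False switches the model to eval mode
    batch_losses = []
    for x_batch, y_batch in loader:
        # Gradients are only tracked during training
        with torch.set_grad_enabled(train):
            loss = loss_fn(model(x_batch), y_batch)
        if train:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)

n_epochs = 100
losses, val_losses = [], []
for epoch in range(n_epochs):
    losses.append(run_epoch(train_loader, train=True))
    val_losses.append(run_epoch(val_loader, train=False))

print(f'first: {losses[0]:.4f}  last: {losses[-1]:.4f}  val: {val_losses[-1]:.4f}')
```

The two lists, losses and val_losses, are exactly the kind of per-epoch values one would plot as training and validation loss curves.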

After training our model, we can plot the training and validation losses:

“Wait, there is something weird with this plot…”

With that cleared up, it is time to inspect the model’s trained parameters:

# Checking model parameters
print(model.state_dict())
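state_dict() returns an ordered dictionary keyed by parameter name, and those names are built from the name passed to add_module(). The sketch below builds a freshly initialized copy of the same model (so its values are random, not the trained ones) to show how b, w1, and w2 can be pulled out of that dictionary.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
model = nn.Sequential()
model.add_module('linear', nn.Linear(2, 1))

state = model.state_dict()
# Keys are prefixed with the module name given in add_module()
print(list(state.keys()))  # ['linear.weight', 'linear.bias']

# One output unit and two input features, so the weight has shape (1, 2)
print(state['linear.weight'].shape, state['linear.bias'].shape)
w1, w2 = state['linear.weight'][0].tolist()
b = state['linear.bias'].item()
print(f'b={b:.4f}, w1={w1:.4f}, w2={w2:.4f}')
```

Calling .tolist() and .item() also copies the values back to the CPU, so the same lines work when the model lives on a GPU.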

GPU users will get an output similar to the following:

Our model produces logits, right? So we can plug the weights above into the corresponding logit equation (equation 6.3), and we end up with:

$$z = b + w_1x_1 + w_2x_2$$

...
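To check the logit equation numerically and take the next step toward predictions, the sketch below computes z by hand from b, w1, and w2, compares it with the model's own output, and then turns logits into probabilities with the sigmoid. It uses a freshly initialized model and a few made-up points rather than the trained weights and this lesson's data, and the 0.5 threshold is an assumed (though common) choice. Since sigmoid(0) = 0.5 and the sigmoid is increasing, thresholding probabilities at 0.5 is the same as thresholding logits at 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
model = nn.Sequential()
model.add_module('linear', nn.Linear(2, 1))

# Pull b, w1, w2 out of the model, as in z = b + w1*x1 + w2*x2
w1, w2 = model.state_dict()['linear.weight'][0].tolist()
b = model.state_dict()['linear.bias'].item()

# A few hypothetical points (x1, x2)
x = torch.tensor([[0.5, -1.0], [-0.3, 0.8], [1.2, 0.1]])

# Logits computed by hand from the equation...
z_manual = b + w1 * x[:, 0] + w2 * x[:, 1]
# ...should match what the model itself produces
with torch.no_grad():
    z_model = model(x).squeeze(1)
print(torch.allclose(z_manual, z_model, atol=1e-6))

# The sigmoid maps logits to probabilities of the positive class
probs = torch.sigmoid(z_model)
# Assumed threshold of 0.5 on probabilities (equivalently, 0 on logits)
preds = (probs >= 0.5).int()
print(probs, preds)
```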
