The Limitations of Linear Regression
Explore the limitations of linear regression in this lesson.
In the previous chapter, we wrote a piece of code that learns. However, if computer scientists reviewed that code, they would find it lacking. In particular, they would object to the train() function. According to a stern computer scientist, this code might work okay for this simple example, but it would not scale to real-world problems.
In this chapter, we’ll address those concerns in two ways.
- First, we won’t get our code reviewed by a computer scientist.
- Second, we’ll analyze the shortcomings of the current train() implementation and solve them with one of machine learning’s key ideas: an algorithm called gradient descent.
Like our current train() code, gradient descent is a way to find the minimum of the loss function, but it’s faster, more precise, and more general than the code from the previous chapter.
Gradient descent is not just useful for our tiny program. In fact, we cannot go very far in ML without gradient descent. In different forms, this algorithm will accompany us to the end of this chapter.
Let’s start with the problem that gradient descent is meant to solve.
Our algorithm
Our program can successfully forecast pizza sales, but why stop there? Maybe we could use the same code to forecast other things, such as the stock market.
However, if we try to apply our linear regression program to a different problem, we will bump into an impediment. Our code is based on a simple line-shaped model with two parameters: the weight w and the bias b. Most real-life problems require complex models with more parameters. As an example, remember that our goal for the first part of this course is to build a system that recognizes images. An image is way more complicated than a single number, so it needs a model with many more parameters than the pizza forecaster.
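To make the two-parameter model concrete, here is a minimal sketch in Python. The function name predict() and the sample numbers are assumptions for illustration, and may not match the previous chapter's code exactly.

```python
def predict(x, w, b):
    """Line-shaped model: one input x, two parameters (weight w, bias b)."""
    return x * w + b


# Hypothetical values, chosen only to show the call
print(predict(20, 1.5, 10))  # prints 40.0
```

Everything the model "knows" lives in those two numbers. A model that takes a whole image as input would need far more of them.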
Unfortunately, adding more parameters to our model would kill its performance. To see why, let’s review the train() function from the previous chapter:
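The sketch below is a minimal reconstruction of that kind of train(), not necessarily the previous chapter's exact listing. It assumes a mean squared error loss() and a fixed step size lr. The helper names, the error raised on non-convergence, and the sample data are assumptions for illustration.

```python
def predict(x, w, b):
    # Line-shaped model with two parameters
    return x * w + b


def loss(X, Y, w, b):
    # Mean squared error over all examples (assumed loss function)
    return sum((predict(x, w, b) - y) ** 2 for x, y in zip(X, Y)) / len(X)


def train(X, Y, iterations, lr):
    # Nudge one parameter at a time by lr, keeping any nudge that lowers the loss
    w = b = 0.0
    for _ in range(iterations):
        current_loss = loss(X, Y, w, b)
        if loss(X, Y, w + lr, b) < current_loss:
            w += lr
        elif loss(X, Y, w - lr, b) < current_loss:
            w -= lr
        elif loss(X, Y, w, b + lr) < current_loss:
            b += lr
        elif loss(X, Y, w, b - lr) < current_loss:
            b -= lr
        else:
            # No single nudge improves the loss, so stop here
            return w, b
    raise RuntimeError("Couldn't converge within %d iterations" % iterations)


# Hypothetical data that lies on the line y = 2x + 3
X = [1, 2, 3, 4]
Y = [5, 7, 9, 11]
w, b = train(X, Y, iterations=10000, lr=0.01)
print("w=%.2f, b=%.2f" % (w, b))
```

Notice that each parameter gets its own pair of branches in this sketch, and every branch calls loss(), which scans the whole dataset. A model with more parameters would need more branches and more loss() calls per step.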