Overfitting is the bane of machine learning algorithms and arguably the most common snare for rookies. It cannot be stressed enough: do not pitch your boss on a machine learning algorithm until you know what overfitting is and how to deal with it. It will likely be the difference between a soaring success and catastrophic failure.
With that said, overfitting is an interesting problem with fascinating solutions embedded in the very structure of the algorithms you’re using. Let’s break down what overfitting is and how we can provide an antidote to it in the real world.
Your Model is Too Wiggly
Overfitting is a very basic problem that seems counterintuitive on the surface. Simply put, overfitting arises when your model has fit the training data too well.
That can seem weird at first glance. The whole point of machine learning is to fit the data. How can it be that your model is too good at that?
The problem is how we frame the objective of ‘fitting the data’. In machine learning, there are two really important measures you should be paying attention to at all times: the training error and the test error. These terms are fairly self-explanatory: training error measures how far off your model was on the data it was trained on, and test error measures how far off it is on data it has never seen, which is our stand-in for performance in the wild.
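To make the two measures concrete, here is a minimal Python sketch. The data points, the `mean_squared_error` helper, and the toy `model` are all made up for illustration, and squared error is just one common choice of error measure.

```python
def mean_squared_error(model, xs, ys):
    """Average squared gap between the model's predictions and the observed values."""
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Hypothetical observations, split into the part the model learns from
# and a held-out part that stands in for "the wild".
train_x, train_y = [0, 1, 2, 3], [0.1, 0.9, 2.2, 2.8]
test_x, test_y = [4, 5], [4.3, 4.9]


def model(x):
    """Hypothetical model (assume it was fit on the training points): y = x."""
    return 1.0 * x


training_error = mean_squared_error(model, train_x, train_y)
test_error = mean_squared_error(model, test_x, test_y)

print(f"training error: {training_error:.3f}")
print(f"test error:     {test_error:.3f}")
```

The same function computes both numbers; the only thing that changes is which data you hand it.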
When we develop an algorithm, our goal is to create a model that performs out in the wild. We don’t necessarily care if our model completely bombs in training, as long as it works in the real world.
In fact, the only reason we really care about our training error is that it can give us a clue about how the model will perform in test. If that link breaks down, then measuring our training error no longer provides us with any useful information.
That link breaks down because our data will inevitably have some amount of noise. The real world tends not to operate along perfect curves, and even when it does, our measurements of those curves are imperfect. Think about measuring rainfall: we can sample the data and get a good estimate of how much rain actually came down, but do we really think that that exact amount of rain fell everywhere in a multi-mile radius? That’d be crazy.
Which is why it is crazy to assume that fitting your data perfectly in training will lead to equally good results in test. Consider a dataset that looks something like this.
Now, we can draw a curve that fits the data perfectly, and in point of fact, many algorithms are very good at finding complex solutions that do so. We might send out a non-linear regression algorithm and find something like this.
Do you see the problem? We actually fit the data perfectly in training, but why would we ever expect this to work in the real world?
Say we start evaluating the model on the test data, colored in green.
It may look like your model is cutting a reasonable swath through the data, but we need to measure the error on the test points. In regression, the error on each point is typically measured by the vertical distance between the point and the model’s prediction at the same X value, usually squared. What will that look like?
Make no bones about it, those test errors are huge, even though the training error is zero. That’s the crux of the problem: we overfit the training data to the detriment of real-world performance.
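Here is a small sketch of that failure in Python. Everything in it is an assumption for illustration: the data is hypothetical (a roughly straight trend plus noise), and the too-wiggly model is a Lagrange interpolating polynomial, picked only because it is guaranteed to pass through every training point. It is not the algorithm behind the figures.

```python
def interpolating_model(train_x, train_y):
    """Return the polynomial that passes exactly through every training point."""
    def predict(x):
        total = 0.0
        for i, (xi, yi) in enumerate(zip(train_x, train_y)):
            # Lagrange basis term: equals 1 at xi and 0 at every other training x.
            basis = 1.0
            for j, xj in enumerate(train_x):
                if j != i:
                    basis *= (x - xj) / (xi - xj)
            total += yi * basis
        return total
    return predict


def mean_squared_error(model, xs, ys):
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Hypothetical noisy measurements of a roughly straight trend (y is about x).
train_x = [0, 1, 2, 3, 4, 5]
train_y = [0.5, 0.3, 2.9, 2.4, 4.8, 4.6]
test_x = [0.5, 1.5, 2.5, 3.5, 4.5]
test_y = [0.2, 1.9, 2.3, 3.9, 4.2]

wiggly = interpolating_model(train_x, train_y)
train_mse = mean_squared_error(wiggly, train_x, train_y)
test_mse = mean_squared_error(wiggly, test_x, test_y)

print(f"training error: {train_mse:.4f}")  # the curve hits every training point
print(f"test error:     {test_mse:.4f}")   # far larger on the unseen points
```

The curve swings well away from the trend between training points, which is exactly where the test points sit.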
Straightening the Model Out
Our model was able to overfit the data because we allowed it to find complex solutions to simple problems. That explanation is an oversimplification, but the intuition is there. What if we just look for simpler solutions to simple data?
This concept is called regularization. Without delving deep into the math, the idea is that we adjust the error measure to reflect a preference for simple solutions. In addition to measuring the performance on the data, we measure the model’s complexity and penalize it for coming up with crazy solutions like what we saw above. This is expressed as an error measure that is the sum of the training error and a complexity penalty.
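As a rough sketch of that sum in Python: the penalty below is the sum of squared coefficients (an L2 penalty), which is one common choice and an assumption here, since each algorithm defines complexity in its own way. The data, the weights, and the penalty strength `lam` are hypothetical.

```python
def training_error(weights, xs, ys):
    """Mean squared error of the polynomial w0 + w1*x + w2*x^2 + ..."""
    def predict(x):
        return sum(w * x ** k for k, w in enumerate(weights))
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def complexity_penalty(weights):
    """L2 penalty: sum of squared coefficients, leaving the intercept alone."""
    return sum(w ** 2 for w in weights[1:])


def augmented_error(weights, xs, ys, lam):
    """Training error plus lam times the complexity penalty."""
    return training_error(weights, xs, ys) + lam * complexity_penalty(weights)


# Hypothetical data and two candidate models.
xs, ys = [0, 1, 2], [0.0, 1.2, 1.9]
straight = [0.0, 1.0]          # y = x, misses the points slightly
curved = [0.0, 1.45, -0.25]    # passes through all three points
lam = 0.1                      # hypothetical penalty strength

for name, weights in [("straight", straight), ("curved", curved)]:
    print(name,
          f"training={training_error(weights, xs, ys):.4f}",
          f"augmented={augmented_error(weights, xs, ys, lam):.4f}")
```

The curved model wins on training error alone, but once the penalty is added the straight line comes out ahead.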
By penalizing our model for complexity, we nudge it towards simpler solutions. By doing so, we increase the model’s bias (its tendency to miss real structure in the data), but we also dramatically reduce the variance (its sensitivity to the noise in one particular training set). Variance is oftentimes the larger source of error in the model, so the small price we pay in increased bias will actually get us a net gain in performance.
Consider the data we were originally interested in training on. What if we regularize our model so that it prefers a simple solution – say a straight line?
Our training error is obviously higher since we’re not cutting through each point perfectly, but the complexity penalty is much, much lower than it might be for the other model. How does it look with our test data?
Not too shabby. What about the error bars?
It’s plain to see that the average error bar is significantly smaller. Simply by virtue of choosing an unsophisticated solution, we created a model that performs better in the real world. It’s sort of paradoxical, but it works.
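A quick way to check this for yourself is to fit an ordinary least-squares line and measure both errors. The sketch below uses hypothetical data (a roughly straight trend plus noise) and a hand-rolled `fit_line` helper; it forces the straight line directly rather than arriving at it through a penalty, which is a simplification.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit of y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
             / sum((x - mean_x) ** 2 for x in xs))
    intercept = mean_y - slope * mean_x
    return slope, intercept


def mean_squared_error(model, xs, ys):
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Hypothetical noisy measurements of a roughly straight trend (y is about x).
train_x = [0, 1, 2, 3, 4, 5]
train_y = [0.5, 0.3, 2.9, 2.4, 4.8, 4.6]
test_x = [0.5, 1.5, 2.5, 3.5, 4.5]
test_y = [0.2, 1.9, 2.3, 3.9, 4.2]

slope, intercept = fit_line(train_x, train_y)


def line(x):
    return slope * x + intercept


train_mse = mean_squared_error(line, train_x, train_y)
test_mse = mean_squared_error(line, test_x, test_y)

print(f"fitted line: y = {slope:.3f} * x + {intercept:.3f}")
print(f"training error: {train_mse:.3f}")  # no longer zero
print(f"test error:     {test_mse:.3f}")   # stays small on unseen points
```

The line misses every training point by a little, yet it holds up on points it never saw.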
Every machine learning algorithm has a different approach to regularizing, but generally speaking, it works by restricting the range of values the model parameters can take on. This is done by using another parameter that the user is responsible for choosing, called a hyperparameter.
Think of your hyperparameter as a knob you twist to control your regularization. The goal is to twist the knob just so, finding the exact point where more complexity will increase the augmented error instead of decreasing it.
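To see the knob in action, here is a deliberately tiny sketch: a one-parameter model `y = w * x` fit with an L2 penalty, where the hyperparameter `lam` shrinks `w` towards zero. The data, the candidate values of `lam`, and the use of a held-out set to compare them are all assumptions for illustration, not a recipe for choosing the value.

```python
def fit_ridge_slope(xs, ys, lam):
    """Closed-form penalized fit of y = w * x; a larger lam shrinks w towards zero."""
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + lam)


def mse(w, xs, ys):
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Hypothetical data: the training noise happens to push the slope too high.
train_x, train_y = [1, 2, 3], [1.5, 2.6, 3.4]
held_x, held_y = [1.5, 2.5], [1.4, 2.6]

results = {}
for lam in [0, 0.1, 1, 3, 10, 100]:
    w = fit_ridge_slope(train_x, train_y, lam)
    results[lam] = (w, mse(w, train_x, train_y), mse(w, held_x, held_y))
    print(f"lam={lam:<5} w={w:.3f} "
          f"training={results[lam][1]:.3f} held-out={results[lam][2]:.3f}")

# The candidate with the smallest held-out error in this toy setup.
best_lam = min(results, key=lambda lam: results[lam][2])
print("best lam:", best_lam)
```

Twisting the knob up always makes the training error worse, but the held-out error first improves and then degrades as the model slides from overfitting into underfitting.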
To reiterate, overfitting will kill your model. At the same time, so will underfitting (too much regularization). We’ll discuss how to choose the optimal hyperparameter for your model at a later date. For now, learn how regularization works with the algorithm you’ve chosen, and use it well.