Training Neural Networks with Regularization
Regularization refers to a suite of techniques used to prevent overfitting, which is the tendency of highly expressive models such as deep neural networks to memorize details of the training data in a way that does not generalize to unseen test data. Common ways of combating overfitting when training with NetTrain include:
- Using the ValidationSet option of NetTrain to return the net with the best performance, as measured by the validation loss (or the validation error for a simple classification net).
- Using the TrainingStoppingCriterion option of NetTrain in conjunction with ValidationSet to stop training when overfitting starts to happen.
- Using weight decay during optimization to shrink the weights of the net toward zero after each update.
- Using regularization layers such as DropoutLayer, or built-in features such as the "Dropout" parameter of layers like LongShortTermMemoryLayer.
- Using BatchNormalizationLayer to normalize activations within the net, which also provides a degree of regularization.
Before we can describe the solutions, we will demonstrate the problem with a simple example. We create a synthetic training dataset by taking noisy samples from a Gaussian curve. Next, we train a net on those samples. The net has much higher capacity than needed, meaning that it can model functions that are far more complex than necessary to fit the Gaussian curve. Because we know the form of the true model, it is visually obvious when overfitting occurs: the trained net produces a function that is quite different from the Gaussian, as it has "learned the noise" in the original data. This can be quantified by sampling a second set of points from the Gaussian and using the trained net to predict their values. The trained net fails to generalize: while it is a good fit to the training data, it does not approximate the new test data well.
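Below is a minimal sketch of this setup. The sample spacing, the noise level, the two-hidden-layer architecture of 150 units each and the number of training rounds are illustrative assumptions, not prescribed values:
    (* noisy samples of the Gaussian curve Exp[-x^2] *)
    data = Table[x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]], {x, -3, 3, 0.2}];
    (* a deliberately overcapacity net for scalar regression *)
    net = NetChain[{150, Tanh, 150, Tanh, LinearLayer[]}, "Input" -> "Scalar", "Output" -> "Scalar"];
    (* train for many rounds and keep the full results object *)
    resultsObject = NetTrain[net, data, All, MaxTrainingRounds -> 3000];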
The resulting net overfits the data, learning the noise in addition to the underlying function. To see this, we plot the function learned by the net alongside the original data.
Obtain the net from the NetTrainResultsObject:
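Continuing with the illustrative names introduced above, a sketch of this step:
    trainedNet = resultsObject["TrainedNet"];
    Show[Plot[trainedNet[x], {x, -3, 3}], ListPlot[List @@@ data]]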
A more quantitative way to demonstrate that overfitting has occurred is to test the net on data that comes from the same underlying distribution but that was not used to train the net.
The average loss on the test set is much higher than on the training set, showing that overfitting has occurred:
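One way to make this comparison, assuming the default mean-squared loss used for scalar regression nets, is to evaluate the net on a fresh sample and compare average squared errors; the helper function meanLoss and the sampling details are illustrative:
    testData = Table[x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]], {x, -2.9, 2.9, 0.2}];
    meanLoss[net_, examples_] := Mean[(net[Keys[examples]] - Values[examples])^2];
    meanLoss[trainedNet, testData]   (* average loss on unseen test data *)
    meanLoss[trainedNet, data]       (* average loss on the training data *)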
The first common approach to mitigate overfitting is to measure the performance of the net on the secondary test data (which is not otherwise used for training) and to choose the particular net that performed best on that data across the entire history of training. This is possible with the ValidationSet option, which measures the net on the test data after each round. Those measurements have two consequences: they produce validation loss curves (and validation error curves when classification is being performed), and they change which net NetTrain returns, picking the intermediate net with the lowest validation loss rather than the one with the lowest training loss.
Use the ValidationSet option of NetTrain to ensure that the net we actually obtain minimizes the validation loss. Note that we limit the training rounds to make it easier to see the portion of training before overfitting starts to occur:
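A sketch of the corresponding call, reusing the illustrative test data from above as the validation set; the round limit is an assumption:
    resultsObject2 = NetTrain[net, data, All,
       ValidationSet -> testData,
       MaxTrainingRounds -> 2000];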
Extract the trained net from the NetTrainResultsObject and plot it. The result is much smoother, as NetTrain effectively took a snapshot of the net before it started to memorize the idiosyncrasies of the noise in the training set:
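For example, continuing with the names above:
    validatedNet = resultsObject2["TrainedNet"];
    Show[Plot[validatedNet[x], {x, -3, 3}], ListPlot[List @@@ data]]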
The results object also stores the average loss on the test set for every round. We can examine the loss for the net that was picked. Notice that it is much lower than the loss we computed for the overfitted net:
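A sketch of how this can be queried; "ValidationLossList" is assumed here to be the relevant NetTrainResultsObject property:
    (* lowest per-round validation loss, corresponding to the net that was picked *)
    Min[resultsObject2["ValidationLossList"]]
    (* compare with the test loss of the overfitted net computed earlier *)
    meanLoss[trainedNet, testData]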
A second common approach to mitigate overfitting is to perform early stopping: stopping training when some measurement of the net's performance starts to get worse. This is accomplished using the TrainingStoppingCriterion option of NetTrain in conjunction with the ValidationSet option. This has two potential advantages over simply using ValidationSet. Firstly, notice that in the previous example we ended up training the net for an extra 1300 rounds. In this case, because the net and the training dataset are both small, this was not an issue, but in general it could be very wasteful. Secondly, we might want to determine the best net using some measurement other than the loss.
Let's use the R-squared measurement (also known as the coefficient of determination) to determine when to stop training. We use a patience of 20 rounds to avoid stopping due to noise in the training process:
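A sketch of the corresponding call; the association passed to TrainingStoppingCriterion and the round limit follow the description above:
    resultsObject3 = NetTrain[net, data, All,
       ValidationSet -> testData,
       TrainingStoppingCriterion -> <|"Criterion" -> "RSquared", "Patience" -> 20|>,
       MaxTrainingRounds -> 2000];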
Extract the trained net from the NetTrainResultsObject and plot it as before:
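Continuing with the illustrative names:
    earlyStoppedNet = resultsObject3["TrainedNet"];
    Show[Plot[earlyStoppedNet[x], {x, -3, 3}], ListPlot[List @@@ data]]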
A third common regularization technique is called weight decay. In this approach, the magnitude of the weights of the net is decreased slightly after each batch update, effectively moving the weights closer to zero. This is loosely equivalent to adding a loss term that corresponds to the L2 norm of the weights.
Performing weight decay encourages the net to find a parsimonious configuration of its weights that can still adequately model the data or, equivalently, penalizes the net for the extra complexity incurred by fitting the noise rather than the underlying signal.
In general, the optimal value for the strength of the weight decay is difficult to derive a priori, so a hyperparameter search should be performed using a validation set to find a good value.
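One way to sketch this in NetTrain is via the "L2Regularization" suboption of the Method option, which adds a penalty on the L2 norm of the weights; the optimizer choice and the coefficient below are illustrative assumptions rather than tuned values:
    resultsObject4 = NetTrain[net, data, All,
       Method -> {"ADAM", "L2Regularization" -> 0.01},  (* coefficient would normally come from a hyperparameter search *)
       MaxTrainingRounds -> 2000];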
The resulting net is a good fit to the Gaussian and demonstrates better generalization than the original overfitted net:
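Continuing with the illustrative names:
    decayedNet = resultsObject4["TrainedNet"];
    Show[Plot[decayedNet[x], {x, -3, 3}], ListPlot[List @@@ data]]
    meanLoss[decayedNet, testData]   (* compare with the overfitted net's test loss *)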
A fourth common regularization technique is dropout. Dropout introduces noise into the hidden activations of a net, but in such a fashion that the overall statistics of the activations at a given layer do not change. The noise takes the form of a random pattern of deactivation, in which a random set of components of the input array (often referred to as units or neurons) is zeroed, and the magnitudes of the remaining components are increased to compensate. The basic idea is that dropout prevents neurons from depending too heavily on any particular neuron in the layer below them and hence encourages the learning of more robust representations.
Dropout can be introduced into a net using a DropoutLayer or by specifying the "Dropout" parameter of certain layers such as LongShortTermMemoryLayer.
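A sketch of a net with a DropoutLayer inserted after each hidden activation; the dropout probability of 0.5 and the rest of the architecture are illustrative:
    dropoutNet = NetChain[
       {150, Tanh, DropoutLayer[0.5], 150, Tanh, DropoutLayer[0.5], LinearLayer[]},
       "Input" -> "Scalar", "Output" -> "Scalar"];
    resultsObject5 = NetTrain[dropoutNet, data, All, MaxTrainingRounds -> 2000];
Note that DropoutLayer only injects noise during training; when the trained net is evaluated normally, dropout is disabled.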
The resulting net is an acceptable fit to the Gaussian and demonstrates better generalization than the original overfitted net:
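Continuing with the illustrative names:
    droppedNet = resultsObject5["TrainedNet"];
    Show[Plot[droppedNet[x], {x, -3, 3}], ListPlot[List @@@ data]]
    meanLoss[droppedNet, testData]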
A fifth regularization technique is batch normalization. A BatchNormalizationLayer normalizes its input using the mean and variance computed over each training batch, while learning running estimates of these statistics that are used at evaluation time. Batch normalization has a number of useful properties in practice: it speeds up and stabilizes training, and it provides a degree of regularization. A BatchNormalizationLayer is typically inserted between a LinearLayer or ConvolutionLayer and its activation function.
Create a multilayer perceptron that includes a BatchNormalizationLayer:
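A sketch of such a net, with each BatchNormalizationLayer placed between a LinearLayer and its activation as described above; the layer sizes and activation are illustrative:
    bnNet = NetChain[
       {LinearLayer[150], BatchNormalizationLayer[], Tanh,
        LinearLayer[150], BatchNormalizationLayer[], Tanh, LinearLayer[]},
       "Input" -> "Scalar", "Output" -> "Scalar"];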