Generalization Error Estimation Methods for Fine-Tuned Models
machine learning
validation
Author
Matt Kosko
Introduction
One of the great advantages of the proliferation of pre-trained transformer-based large language models (LLMs) like BERT and GPT is that these models can be fruitfully applied to a variety of tasks they were not initially trained on (Wei, Xie, and Ma 2021). In particular, they can be fine-tuned for other tasks using datasets that are much smaller than the ones used to train the original model; this is attractive precisely because you can leverage the enormous amount of information the model learned during pre-training on a task where labeled data is expensive to obtain (Church, Chen, and Ma 2021).
As with any modeling task, once we fit a model we want to measure how well it performs; that is, we want to validate it. Those using LLMs on downstream tasks should remember that the validation techniques commonly used in large-sample settings do not work as well when datasets are small. In particular, when a dataset is too small for a train/test split, you can make use of resampling methods like K-fold cross-validation. In this post, we explore how resampling methods can improve validation estimates compared to data-splitting in the small-N case.
Validation
When we say we want to see how well a model performs, we mean that we want to estimate its generalization performance on an independent “test” set with data that was not seen during training (Hastie et al. 2009, 2:219).
Following the notation of Hastie et al. (2009), given an outcome \(Y\), a vector of inputs \(X\), and a loss function \(\mathcal{L}\), the generalization error is given by:
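\[
\operatorname{Err}_{\mathcal{T}} = \operatorname{E}_{X, Y}\left[\mathcal{L}\big(Y, \hat{f}(X)\big) \mid \mathcal{T}\right] \tag{1}
\]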
In other words, this is the expected loss of the model trained on the training data \(\mathcal{T}\), averaged over future draws of new test data; the training data is held fixed.
This tells us about the trained model itself: how well it will perform conditional on having seen the training data \(\mathcal{T}\). It is often easier, however, to estimate a different quantity, the expectation of this error over training sets (Hastie et al. 2009):
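\[
\operatorname{Err} = \operatorname{E}_{\mathcal{T}}\left[\operatorname{Err}_{\mathcal{T}}\right] = \operatorname{E}\left[\mathcal{L}\big(Y, \hat{f}(X)\big)\right] \tag{2}
\]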
Resampling methods like K-fold cross-validation estimate this quantity (more on this below).
Data-Splitting
The most straightforward method for validating a model is data-splitting, the classic train/test split introduced in every machine learning textbook. The data is divided into a training set \(\mathcal{T}\) and a test set \(\mathcal{R}\) (Harrell et al. 2001; Ghojogh and Crowley 2019). Because the model is trained once on the fixed training set, the error computed on the held-out test set estimates Equation 1 rather than Equation 2.
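As a minimal sketch of the procedure, with synthetic data and a generic scikit-learn classifier standing in for the text data and fine-tuned model used later in the post:

# Data-splitting sketch; the synthetic data and logistic regression are
# stand-ins for the text dataset and fine-tuned model used below.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, n_features=20, weights=[0.8], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
# The score on the held-out test set estimates Equation 1 for this fitted model.
print(f1_score(y_test, model.predict(X_test)))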
K-Fold Cross-validation
In K-fold cross-validation, the data is split into \(K\) folds; for each fold, the model is trained on the other \(K-1\) folds and a performance metric is computed on the held-out fold. These metrics are then averaged over all folds. Repeated K-fold cross-validation repeats this procedure multiple times with different random partitions and averages the results. Because the training data is randomly partitioned into folds rather than held constant, these methods estimate Equation 2.
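A corresponding sketch with scikit-learn (again using synthetic data and a generic classifier rather than the fine-tuning setup):

# K-fold and repeated K-fold cross-validation sketch with scikit-learn.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, RepeatedKFold, cross_val_score

X, y = make_classification(n_samples=500, n_features=20, weights=[0.8], random_state=0)
model = LogisticRegression(max_iter=1000)

# K-fold: each fold serves once as the holdout set; scores are averaged.
kfold_scores = cross_val_score(model, X, y, scoring="f1",
                               cv=KFold(n_splits=5, shuffle=True, random_state=0))
# Repeated K-fold: the same procedure repeated over new random partitions.
repeat_scores = cross_val_score(model, X, y, scoring="f1",
                                cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=0))

print(kfold_scores.mean(), repeat_scores.mean())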
Choosing a Method
When data is abundant, data-splitting is the most obvious validation solution, both for computational convenience and because it works well in those cases: the held-out test set is representative of the data as a whole. When datasets are small, however, problems emerge. For one thing, you are removing data that would be useful in training; as Harrell et al. (2001, 608:109) note, “The surest method to have a model fit the data at hand is to discard much of the data.” Second, because you only split once, you may get a “fortuitous” test set and calculate an estimate of performance that is too optimistic (Harrell et al. 2001, 608:112). Depending on the sample size, a single split can also yield too noisy an estimate of model performance. In the large-data setting these concerns are attenuated somewhat; during pre-training, language models are trained on corpora of hundreds of millions or even billions of words. BERT, for example, was pre-trained on BookCorpus and English Wikipedia, consisting of 800M and 2.5B words respectively (Devlin et al. 2018).
You might argue that you should always prefer data-splitting, since the quantity of interest (how well the fitted model performs) is given by Equation 1, which only data-splitting estimates. In practice, however, analysts often combine the training and test data and present a final model trained on both. In that case, the error obtained on the test set no longer estimates Equation 1, because the training set, and hence the trained model, has changed. Moreover, in practice Equation 2 is a good estimate of the generalization performance of the model.
Data Generating Process
To check different validation methods, we construct a classification task and apply an open-source LLM to it. We test how well the DistilBERT model can differentiate between text generated by two different GPT models: GPT-2 and DistilGPT-2. Although this is a contrived example, it is analogous to real-world problems of distinguishing human-written from computer-generated text. We construct a highly imbalanced training set of 4,000 observations, where 80% of the observations are generated by GPT-2 and the other 20% by the distilled version. The full code can be found on GitHub. An example of the training data is shown below:
train_dat.head()
                                                text  label
0  surroundings.\n\n\n"I don't think he's a bad p...      0
1  complaint" filed in the U.S. district court fo...      1
2  high by the way, the "lazy kid" wasn't a boy.\...      1
3  say, that if I was on the verge of becoming a ...      1
4  normal". "We are pleased with this particular ...      0
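The generation code itself lives in the GitHub repository; a rough sketch of how such a dataset can be built is below. The prompt, generation settings, and label mapping here are illustrative assumptions, not the exact choices used.

# Rough sketch of building the GPT-2 vs. DistilGPT-2 dataset. The prompt,
# max_new_tokens, and label mapping (1 = DistilGPT-2) are assumptions.
import pandas as pd
from transformers import pipeline

gpt2_gen = pipeline("text-generation", model="gpt2")
distil_gen = pipeline("text-generation", model="distilgpt2")

def sample_texts(generator, n, label):
    outputs = generator(["The"] * n, max_new_tokens=64, do_sample=True)
    return pd.DataFrame({"text": [out[0]["generated_text"] for out in outputs],
                         "label": label})

# 80% GPT-2 and 20% DistilGPT-2, matching the imbalanced design described above.
train_dat = pd.concat([sample_texts(gpt2_gen, 3200, 0),
                       sample_texts(distil_gen, 800, 1)],
                      ignore_index=True).sample(frac=1, random_state=0).reset_index(drop=True)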
We use three different validation methods: data-splitting, K-fold cross-validation, and repeated K-fold cross-validation.
Results
After constructing a single dataset according to our data-generating process, we perform 50 replications of each validation method. Because of the time it takes to run this many simulations, we deploy the simulation code to AWS and train every replication on a GPU. Figure 1 shows the F1-scores (a metric particularly good for imbalanced datasets). Figure 2 shows a similar plot for accuracy.
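For reference, the F1-score is the harmonic mean of precision and recall, so it is not inflated by a classifier that simply predicts the majority class:

\[
F_1 = 2 \cdot \frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}
\]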
We can see clearly that all the methods are centered around a similar value, but the resampling methods show much greater precision. The summary statistics bear this out.
Illustrating the variance of a single split, some replications of the data-splitting method return accuracy scores as low as 0.92. We also see, as is obvious from the box plots, that the variance of the resampling methods is much smaller than that of data-splitting.
At first glance, these results don't indicate why we would want to use anything other than data-splitting: yes, there is variation, but it is not that extreme. In this case there appears to be enough signal in the data that data-splitting works, although with more variance, as noted. However, in an earlier version of this experiment with a dataset of about 1,500 observations, the model merely predicted the majority class for every case, and the high variance in the F1 and accuracy scores given by data-splitting could lead an analyst to conclude, incorrectly, that the model was doing better than guessing. In addition, this variance can cause problems when validation is used to choose between models. To show this, we look at a computationally simpler example.
Model Selection
For computational convenience, we consider how to choose a neural network architecture and hyperparameters for a regression task, using a simulated data-generating process with 10 features.
The outcome is a complex function of the predictors with high-variance additive noise, which will likely make the underlying function difficult to learn at small sample sizes.
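As an illustration only, the sketch below simulates a process of this general shape; the specific functional form, coefficients, and noise level are stand-ins rather than the exact data-generating process used in the experiment.

# Illustrative stand-in for the data-generating process: 10 features, a
# nonlinear outcome, and high-variance additive noise.
import numpy as np

def simulate(n, noise_sd=3.0, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 10))
    signal = (np.sin(X[:, 0] * X[:, 1]) + X[:, 2] ** 2
              + np.exp(-np.abs(X[:, 3])) + X[:, 4] * X[:, 5])
    y = signal + rng.normal(scale=noise_sd, size=n)
    return X, y

X_train, y_train = simulate(10_000)
X_test, y_test = simulate(1_000_000, seed=1)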
We consider neural network estimators of the following form:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super(Net, self).__init__()
        # First hidden layer maps the inputs to the hidden dimension,
        # followed by (num_layers - 2) further hidden layers.
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_size)])
        self.hidden_layers.extend(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers - 2)]
        )
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        for layer in self.hidden_layers:
            x = self.activation(layer(x))
        x = self.output_layer(x)
        return x
We vary the number of hidden layers, the number of hidden units in each layer, and the weight-decay hyperparameter.
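The exact grid is not reproduced here; a hypothetical grid of this kind could be defined as follows (the specific values are assumptions).

# Hypothetical hyperparameter grid; the values actually searched may differ.
from sklearn.model_selection import ParameterGrid

param_grid = ParameterGrid({
    "num_layers": [2, 3, 4],
    "hidden_size": [16, 64, 256],
    "weight_decay": [0.0, 1e-4, 1e-2],
})
print(len(list(param_grid)), "candidate models")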
We generate a training set of 10,000 observations and a test set of 1,000,000 observations. After training each candidate network on the full training set, we define the “correct” model as the one with the smallest error on the large test set. We then compare how often data-splitting, K-fold cross-validation, and repeated K-fold cross-validation select this “correct” model.
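A condensed sketch of the selection logic is below. It assumes a hypothetical helper, fit_and_mse(p, X_tr, y_tr, X_ev, y_ev), that trains Net with hyperparameters p on the first dataset and returns the MSE on the second; the actual training loop is omitted.

# Model-selection sketch. `fit_and_mse` is a hypothetical helper that trains
# Net with hyperparameters `p` and returns MSE on the evaluation data.
import numpy as np
from sklearn.model_selection import KFold, train_test_split

candidates = list(param_grid)

# "Correct" model: smallest MSE on the large test set after training on the
# full training set.
correct = min(candidates,
              key=lambda p: fit_and_mse(p, X_train, y_train, X_test, y_test))

# Data-splitting: pick the candidate with the best MSE on a single holdout.
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.2)
split_choice = min(candidates,
                   key=lambda p: fit_and_mse(p, X_tr, y_tr, X_val, y_val))

# K-fold CV: pick the candidate with the best MSE averaged over the folds.
def cv_mse(p, n_splits=5):
    folds = KFold(n_splits=n_splits, shuffle=True).split(X_train)
    return np.mean([fit_and_mse(p, X_train[tr], y_train[tr], X_train[va], y_train[va])
                    for tr, va in folds])

cv_choice = min(candidates, key=cv_mse)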
Let’s look at how often each validation method selects the model with the smallest loss on the test set (loss here is MSE):
Type
Data Split 0.06
KFold CV 0.12
Repeat CV 0.14
Name: Choice, dtype: float64
Although none of the methods does that well given the relatively small sample size, the resampling methods choose the correct model far more often. Given that many of the models' test-set losses were quite similar, we next look at how often each validation method chooses a model among the 5 with the smallest test error.
Type
Data Split 0.34
KFold CV 0.54
Repeat CV 0.46
Name: Choice, dtype: float64
We see that the resampling methods now choose one of these models about half the time, while simple data-splitting does so only about a third of the time.
Conclusion
For large datasets with a large signal-to-noise ratio, data-splitting does not pose a problem and is a computationally simpler form of validation. But when datasets are small, even when working with a sophisticated pre-trained model, analysts should consider resampling methods.
References
Church, Kenneth Ward, Zeyu Chen, and Yanjun Ma. 2021. “Emerging Trends: A Gentle Introduction to Fine-Tuning.” Natural Language Engineering 27 (6): 763–78.
Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. “BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding.” arXiv Preprint arXiv:1810.04805.
Ghojogh, Benyamin, and Mark Crowley. 2019. “The Theory Behind Overfitting, Cross Validation, Regularization, Bagging, and Boosting: Tutorial.” arXiv Preprint arXiv:1905.12787.
Harrell, Frank E, et al. 2001. Regression Modeling Strategies: With Applications to Linear Models, Logistic Regression, and Survival Analysis. Vol. 608. Springer.
Hastie, Trevor, Robert Tibshirani, and Jerome H Friedman. 2009. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Vol. 2. Springer.
Wei, Colin, Sang Michael Xie, and Tengyu Ma. 2021. “Why Do Pretrained Language Models Help in Downstream Tasks? An Analysis of Head and Prompt Tuning.” Advances in Neural Information Processing Systems 34: 16158–70.