Improved Generalization for Image Classification
Through Memorization

Roger Waleffe and Jason Mohoney

Motivation


Deep learning models for computer vision are deployed everywhere, and increasingly in safety-critical applications. Examples include self-driving cars and medical applications like AI-assisted robotic surgery or medical image analysis.

Problem

In safety-critical applications, models need to generalize reliably from their training environments to their test environments; otherwise, the consequences can be severe.

Models often don't generalize well!

Consider the experiment from [Recht et al. (2019)]. The authors plotted the ImageNet test set accuracy of many deep learning models (blue dots) on the x-axis; from left to right, the accuracy of individual models increases. On the y-axis, they plotted the accuracy of each model on a new test set they created by following, as closely as they could, the same collection instructions and distribution as the original. If the models truly generalized well to unseen data, you would expect the accuracy on the two test sets to be quite similar (black line). Instead, even in this simple setting, all models perform about 10 percent worse on the new test set. This is rather alarming, especially considering the possible deployment applications where these differences could matter substantially.
Does ImageNet Generalize to ImageNet?
Figure 1

Problem

Generalizing well to unseen examples is even harder on instances which appear only rarely in the training data.

Real-world datasets are long-tailed!

In long-tailed distributions a few objects (or visibility patterns) occur frequently, while many objects (or visibility patterns) occur infrequently. For example, in the SUN large scale scene recognition dataset below (plot from [Zhu et al. (2014)]), people and windows occur frequently, but many other objects also appear in the dataset. Likewise, even for a single object like a person, some visibility patterns (e.g. head on) occur frequently, while many others (people on horses, people facing sideways, etc.) appear only a few times in the dataset. The total weight of the infrequent examples, however, ends up being a substantial part of the dataset. Think of all the many corner cases self-driving cars need to handle correctly.
Long-tailed distribution
Figure 2

Problem Statement


To improve the generalization of computer vision models on real-world datasets, we need to consider both common and atypical instances. In this project, we focus on the latter.

How do we improve generalization on the long-tail?

So how do we go about trying to make progress on this problem? First off, existing techniques for training with rare examples include:
  • upsampling or downsampling the training data
  • re-weighting the loss function to pay more attention to underrepresented classes
  • or utilizing training objectives which seek to maximize accuracy on the worst-performing class rather than the average accuracy (e.g. Distributionally Robust Optimization (DRO)).
One main problem with these techniques, however, is that they require knowing, and having labels for, all of the atypical examples/groups. Rare instances like uncommon visibility patterns may be missed entirely because they can be rare subpopulations within a very common labeled class.
In addition to the techniques above, a very recent line of theoretical work suggests that memorization of training data and their labels is key to achieving close-to-optimal generalization, especially in the context of long-tailed distributions ([Feldman (2021)], [Talwar et al. (2020)]). We found this very interesting. Intuitively it makes some sense: what else can you really do with only a few examples? Another advantage is that you may not need to know and label what's atypical ahead of time (as the techniques above require). If you memorize enough training data, you are likely to memorize a few labels from each of the rare classes (or visibility patterns), and you may then be able to use that information to help classify similar rare instances in the future. Note also that if memorization can help improve generalization, it may be possible to couple it with the existing techniques above for additional robustness. Conventional thinking, however, somewhat contradicts this theory: memorization is generally thought to lead to overfitting and therefore poor generalization.

Project Goal

While recent theoretical results suggest memorization may improve generalization, especially over long-tailed distributions, these results are very abstract, make a number of assumptions, and do not yet provide any practical algorithms for computer vision. The goal of our project is to:

Test whether memorization can improve generalization by incorporating it into real object recognition settings.

Approach


The basic setup of our project is as follows. The goal of our approach is to concisely memorize a representation for each training image and then use this information to classify future test images. Instead of memorizing the raw images themselves (comparing raw image intensities is not very semantically meaningful), we decided to memorize learned, high-level representations for each image along with their corresponding labels. Thus, the first step in our approach is to learn these representations. As trained convolutional neural networks (CNNs) are known to extract useful image features/representations, we decided to start by conventionally training one of these models. We chose a standard ResNet-20 ([He et al. (2016)]) model for this step. The ResNet consists of a stack of convolutional layers, followed by an average pooling layer, and finally a single dense layer.
Step 1: Train ResNet-20
Step 1
After training our ResNet model, the second step is to extract and memorize the learned representations and labels for each training image. We feed each image through the trained network and take its representation to be the output of the ResNet after the average pooling layer, but before the final dense layer and classification step. For ResNet-20, this representation is a 64-dimensional vector.
Step 2: Extract and memorize representations for each training image
Step 2
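As a rough sketch, extracting and memorizing these representations with tf.keras might look like the following. The layer name "avg_pool", the variable names, and the helper itself are illustrative assumptions, not our exact code.

import numpy as np
import tensorflow as tf

# Sketch of Step 2: cut the trained ResNet-20 off after its average pooling layer
# and run every training image through it once.
def memorize_representations(resnet, x_train, y_train, batch_size=128):
    feature_extractor = tf.keras.Model(
        inputs=resnet.input,
        outputs=resnet.get_layer("avg_pool").output)  # 64-dimensional output for ResNet-20
    reps = feature_extractor.predict(x_train, batch_size=batch_size)
    return np.asarray(reps, dtype=np.float32), np.asarray(y_train).reshape(-1)

# memorized_reps: (50000, 64) and memorized_labels: (50000,) for the full CIFAR-10 train set
# memorized_reps, memorized_labels = memorize_representations(resnet, x_train, y_train)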
Given a new image which we wish to classify, we can then use the memorized information in a k-nearest neighbor (k-NN) fashion. First we compute the representation for the new image. Then we compare this representation against our memorized representations and return the top k closest matches based on some similarity metric. Finally, given the k nearest neighbors for our test image, we predict its label using a majority vote over the neighbors' labels.
Step 3: Classify future test images using memorized representations
Step 3
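A minimal sketch of this classification step, assuming the memorized representations and integer labels from Step 2 and using Euclidean distance as one example similarity metric:

import numpy as np

# Sketch of Step 3: ties in the vote are broken toward the lowest class index,
# which is one simple (but arbitrary) choice.
def knn_predict(test_rep, memorized_reps, memorized_labels, k=10, num_classes=10):
    sims = -np.linalg.norm(memorized_reps - test_rep, axis=1)  # larger = more similar
    nearest = np.argsort(sims)[-k:]                            # indices of the k nearest neighbors
    votes = np.bincount(memorized_labels[nearest], minlength=num_classes)
    return int(np.argmax(votes))                               # majority vote over the neighbors' labels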
Implementation Details: We implemented our above approach from scratch in TensorFlow. For all experiments below, we evaluate our method on the CIFAR-10 ([Krizhevsky et al. (2009)]) dataset. During training, we use conventional data augmentation: pad four pixels on each side of the image and then randomly crop a 32x32 section from the padded image or its horizontal flip. We normalize images by subtracting off the mean and dividing by the variance of each color channel, with statistics computed on the train set. For training our ResNet implementation, we use the standard ResNet learning rate schedule for CIFAR-10: an initial learning rate of 0.1 dropped by a factor of 10 after 32k and 48k iterations (82 and 123 epochs), with 64k iterations total (164 epochs). We use a batch size of 128, stochastic gradient descent with momentum 0.9, and the categorical cross-entropy loss function. For k-NN similarity metrics, we consider three main choices:
  • the Euclidean distance (L2) between representation vectors
  • the dot product (Dot) between representation vectors
  • and the normalized dot product (NDot) between representation vectors.
We also consider first centering the representations by subtracting off the mean (computed on the memorized representations) before utilizing one of the three metrics above. Note that centering does not affect the Euclidean distance but does affect dot product based similarity metrics.
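Concretely, the three metrics and the optional centering step can be sketched as follows (the function and variable names here are ours, for illustration only):

import numpy as np

# R is the N x 64 matrix of memorized representations, q a single test representation.
# All metrics are oriented so that larger values mean "more similar".
def similarity(q, R, metric="L2", center=False):
    if center:
        mean = R.mean(axis=0)      # mean of the memorized representations
        q, R = q - mean, R - mean  # changes Dot and NDot, leaves L2 unchanged
    if metric == "L2":
        return -np.linalg.norm(R - q, axis=1)
    if metric == "Dot":
        return R @ q
    if metric == "NDot":           # i.e. cosine similarity
        return (R @ q) / (np.linalg.norm(R, axis=1) * np.linalg.norm(q) + 1e-12)
    raise ValueError("unknown metric: " + metric)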

Results


We tested the basic memorization setup described above by looking at the accuracy of the k-nearest neighbor model on the standard CIFAR-10 image classification dataset. Our default baseline was to compare to the accuracy of the original ResNet-20 model from Step 1. We varied the similarity metric used to identify top neighbors and the hyperparameter k (how many nearest neighbors are used for classification). Below (Figure 3) you can see that for some values of k and some similarity metrics, the overall accuracy is very similar to that of a normal ResNet model. Zooming in on the per class accuracies (Figure 4) shows the memorization model makes very similar predictions to a standard ResNet. This isn't necessarily too surprising, as the memorization model uses representations from the trained ResNet.
basic_knn_1
Figure 3
basic_knn_per_class
Figure 4

What about when classes are imbalanced, e.g. the long-tail setting?

While we didn't expect much difference between the memorization approach and a standard ResNet in the case where the classes are balanced, recall that our primary motivation was to improve generalization on the long-tail. To this end, we tested the k-nearest neighbor model on versions of the CIFAR-10 train set which we manually imbalanced. We first randomly removed all but 50 images from class zero (one percent of the normal 5000 per class) to create the class distribution below (Figure 5). In the per class accuracies for this case (for a specific k-NN model) there seems to be a slight improvement on the rare class when comparing the memorization approach to the standard ResNet (Figure 6).
class_0_at_50_dist
Figure 5
class_0_at_50_per_class
Figure 6
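For reference, the imbalanced training set behind Figure 5 can be constructed with a few lines of code. This is a sketch of the idea, not our exact script; the seed and sampling details are illustrative.

import numpy as np

# Keep 50 randomly chosen class-zero images and all images of the other classes.
def imbalance_class_zero(x_train, y_train, keep=50, seed=0):
    y = y_train.reshape(-1)
    rng = np.random.default_rng(seed)
    kept_zero = rng.choice(np.flatnonzero(y == 0), size=keep, replace=False)
    keep_idx = np.concatenate([kept_zero, np.flatnonzero(y != 0)])
    keep_idx = rng.permutation(keep_idx)  # shuffle so the remaining class-zero images are spread out
    return x_train[keep_idx], y_train[keep_idx]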
When computing the overall accuracy of models on imbalanced class distributions there is an important caveat to consider. In theory, the class distribution should also be imbalanced at test time, i.e. rare instances in the training data remain rare during inference. However, when a class appears infrequently in the test set, it becomes hard to calculate a reliable accuracy on those instances. For example, there are normally 1000 instances of each class in the CIFAR-10 test set. If we remove all but one percent of them for class zero to match the training distribution above (Figure 5), then we are left with only 10 examples from that class. Rather than removing examples from the test set, we instead evaluate on every example, allowing for more accurate per class accuracy calculations, but in addition to the overall test set accuracy we also report a down weighted test set accuracy. In this metric, the contribution of rare classes to the overall accuracy is down weighted so that the end result is the accuracy you would expect when these instances appear only rarely at test time.
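One way to compute such a down weighted accuracy (a sketch of the idea; our exact weighting may differ in detail) is to weight each per class accuracy by the class's share of the imbalanced training distribution:

import numpy as np

# A class kept at one percent of its original size contributes one percent of its usual weight.
def down_weighted_accuracy(per_class_accuracy, train_class_counts):
    weights = np.asarray(train_class_counts, dtype=np.float64)
    weights /= weights.sum()
    return float(np.sum(weights * np.asarray(per_class_accuracy)))

# Example for the distribution in Figure 5: class zero has 50 training images, the rest 5000.
# acc = down_weighted_accuracy(per_class_accuracy, [50] + [5000] * 9)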
We report test set accuracy (Figure 7) and down weighted test set accuracy (Figure 8) for the class distribution above (Figure 5) in the plots below. When evaluating on the full test set, lower values of k achieve slightly higher accuracy on the imbalanced class zero (as we saw above in Figure 6) when compared to a ResNet model, but also occasionally drop some accuracy on other classes (e.g. class five in Figure 6). This averages out and results in the memorization approach and standard ResNet model achieving very similar overall test set accuracies. Higher values of k achieve poor accuracy on class zero, and thus drop significant accuracy on the test set. This is because the memorized representations contain only 50 instances which have class zero labels. Unless a test point is very close to all of them, when computing a majority vote over many neighbors, other frequent classes are likely to appear more often. In contrast, when evaluating the down weighted accuracy, higher values of k perform better on the frequent classes (as seen in the balanced class setting, Figure 3) and thus achieve higher accuracy. The improvement on class zero for small values of k is negligible for the down weighted accuracy because of the rarity of those examples at test time.
class_0_at_50_37
Figure 7
dw_class_0_at_50_37
Figure 8
In addition to removing images from one class, we also created another imbalanced CIFAR-10 training set which more closely matches the long-tailed setting. The class distribution is shown below (Figure 9).
log_dist
Figure 9
In this case, if you look at the per class accuracies for the memorization approach and compare to a standard ResNet model, you can see the k-NN is performing significantly better on the rare classes. For the similarity metric on the left (Figure 10), the difference is quite noticeable. The problem is that the k-NN has dropped some accuracy on the most common class, class zero in this case. With a different similarity metric (Figure 11) the k-NN regains some of the lost accuracy on class zero, but at the cost of smaller improvements on the rare classes.
log_per_class_85
Figure 10
log_per_class_88
Figure 11
Looking at the overall accuracies of different models on the simulated long-tailed distribution (Figure 9), we can see that on the raw test set (Figure 12) a number of memorization models achieve significantly higher accuracy than the baseline ResNet. This is because they have significantly better accuracy on the rarer classes (as we saw above, e.g. Figure 10). As in the case where only one class was imbalanced, higher values of k do not achieve the same improvements on the infrequent classes and thus see less of an improvement, if any, on the test set. Note, however, that the down weighted test set accuracy, the accuracy you would see in the real world, of the ResNet and memorization models is nearly identical (Figure 13). Even though the k-NN approach significantly improves accuracy on the rarer classes, it loses some accuracy on the most common classes, and this drop has a larger impact on the true accuracy (as the common classes appear more frequently at inference time and thus contribute more to the true accuracy).
log_73
Figure 12
dw_log_73
Figure 13
Other Baselines: For another baseline in addition to the standard ResNet model, we tried training a ResNet with loss re-weighting. The idea is to re-weight the contribution of each example to the loss function during training so that infrequent classes have higher weight and the model pays more attention to them. Interestingly, this baseline approach failed to converge for the simulated long-tailed distribution above (Figure 9). We hypothesize that the class distribution is too heavily imbalanced, requiring re-weighting values that are too large and making the training objective very unstable. Certainly a strength of the k-NN memorization approach is that it can be applied to any distribution/model/training objective without suffering from similar problems.
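For concreteness, a class-weighted cross-entropy of the kind used for this baseline can be sketched as follows. The inverse-frequency weighting shown here is one common choice, not necessarily the exact weights we used.

import numpy as np
import tensorflow as tf

# Sketch of a re-weighted loss, assuming integer labels and logits as model outputs.
def make_reweighted_loss(train_class_counts):
    counts = np.asarray(train_class_counts, dtype=np.float32)
    class_weights = tf.constant(counts.sum() / (len(counts) * counts))  # rare classes get large weights
    scce = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss(y_true, y_pred):
        per_example = scce(y_true, y_pred)                                       # shape (batch,)
        weights = tf.gather(class_weights, tf.cast(tf.reshape(y_true, [-1]), tf.int32))
        return tf.reduce_mean(weights * per_example)
    return loss

# model.compile(optimizer=sgd_with_momentum, loss=make_reweighted_loss(train_class_counts))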

Results Takeaway

We noticed promising signs that memorization can help improve generalization, especially in the long-tailed setting (e.g. Figure 10). There appears to be a tradeoff, however: as accuracy improves on rare classes, it drops slightly on more common classes. This results in overall accuracies for our k-NN models that are very similar to our baseline, a standard ResNet model.

Extensions


Beyond our main approach and its corresponding results presented above, we had a number of hypotheses on how to improve our method. We term these experiments extensions.
Extension 1: Compute similarity metrics for k-NN model in PCA subspaces
One hypothesis we had was that our similarity metrics might be noisy because they operate over high dimensional (dimension 64) vectors. It is a well-known phenomenon that in high dimensions distances tend to concentrate, so every point looks roughly as close (or as far) as every other. To mitigate this issue, we tried first reducing the dimension of our representations and then computing our similarity metrics to find nearest neighbors. Specifically, after memorizing the training data representations in Step 2, we compute PCA over these vectors to reduce their dimension (down to 32 or 16). Given a test image in Step 3, after computing its representation we also reduce its dimension using the same transformation computed on the training data. The nearest neighbor model then proceeds as normal given the smaller vectors.
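A sketch of this PCA step in plain NumPy (not necessarily how we implemented it):

import numpy as np

# Fit PCA only on the memorized training representations, then apply the same
# projection to each test representation in Step 3.
def fit_pca(memorized_reps, dim=32):
    mean = memorized_reps.mean(axis=0)
    # Right singular vectors of the centered data are the principal directions.
    _, _, vt = np.linalg.svd(memorized_reps - mean, full_matrices=False)
    return mean, vt[:dim]                    # components: (dim, 64)

def pca_project(reps, mean, components):
    return (reps - mean) @ components.T      # 64-d -> dim-d, works for a single vector or a batch

# mean, components = fit_pca(memorized_reps, dim=32)
# reduced_memorized = pca_project(memorized_reps, mean, components)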
We found that this additional processing of image representations led to minimal changes in our results. For example, in the case of balanced classes below (Figure 14), the results are quite similar to the original results for this setting above (Figure 3). The only real difference manifests in the behavior of the dot product similarity metric. As another example, when looking at the down weighted test set accuracy on the long-tailed distribution for PCA similarity metrics (Figure 15), there is little change in the results when compared to using the full dimensional vectors (Figure 13).
basic_knn_19
Figure 14
dw_log_91
Figure 15
Extension 2: Learn representations using different training objectives
A second extension we tried aimed to improve the results of the memorization model by explicitly training the representations with a k-NN similarity metric in mind. In other words, instead of taking the memorized representations of the training data to be the output of the average pooling layer after standard ResNet training, we modified the training procedure in Step 1 by designing different loss functions. The goal of these loss functions was to cluster similar instances together, with the hope that this might improve the quality of the returned nearest neighbors at inference time. The notion of clustering is tied to the similarity metric we hoped to improve: if the plan was to improve the k-NN model when using Euclidean distance to compare representations, then it might benefit us to train with a loss function which penalizes large distances between instances of the same class or small distances between instances of differing classes (the same idea applies for converting the dot product or normalized dot product similarity metrics into loss functions). A sketch of this idea is shown below.
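The following is an illustration of the Euclidean-distance variant of this idea, not our exact loss functions:

import tensorflow as tf

# Within a batch of representations, pull same-class pairs together and push
# different-class pairs apart up to a margin.
def euclidean_cluster_loss(reps, labels, margin=1.0):
    # Pairwise squared Euclidean distances between all representations in the batch.
    sq_norms = tf.reduce_sum(tf.square(reps), axis=1, keepdims=True)
    sq_dists = tf.maximum(
        sq_norms - 2.0 * tf.matmul(reps, reps, transpose_b=True) + tf.transpose(sq_norms), 0.0)
    labels = tf.reshape(labels, [-1, 1])
    same = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32)
    diff = 1.0 - same
    pull = tf.reduce_sum(same * sq_dists) / tf.maximum(tf.reduce_sum(same), 1.0)
    push = tf.reduce_sum(diff * tf.square(tf.maximum(margin - tf.sqrt(sq_dists + 1e-12), 0.0))) \
           / tf.maximum(tf.reduce_sum(diff), 1.0)
    return pull + push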
We tried three different loss functions (to cluster based on the three main similarity metrics), each with possibly a few hyperparameters. We also tried varying the optimizer between SGD and Adagrad. The exact details of our loss functions can be found in the code (train_helpers.py). All loss functions we tried resulted in k-NN models which performed worse than when learning representations using the standard ResNet loss function (cross-entropy). For example, when using an objective designed to cluster classes based on the Euclidean distance between points (Figure 16), the best performing k-NN models achieve around 89 percent accuracy with balanced classes. Compare this with the 91.3 percent accuracy of the best performing k-NN models when using standard ResNet training (Figure 3). Similar results hold for another example objective which clustered points based on the dot product between representations (Figure 17). Interestingly, the normalized dot product similarity metric often appeared to perform better than the dot product similarity metric for balanced classes, but the reverse was true for the loss functions (our normalized dot product clustering objectives performed worse than our dot product clustering objectives). In hindsight, it's not necessarily surprising that our objectives performed worse than cross-entropy. After all, even with the classification loss function the network must learn some form of clustering, and there is a reason cross-entropy has become the go-to objective: it works quite well in practice.
l2_power_2_sgd_1
Figure 16
dot_pca_sgd_37
Figure 17
Extension 3: Replace basic k-NN majority vote with a more complicated model
For our third extension, we wanted to see if the memorization model could be improved by using a more complicated function to combine the nearest neighbors into a prediction (instead of just using the simple majority vote over the neighbors' labels as above). When using the majority vote to combine the neighbors, only the label information is used, and the actual representation vector of each of the neighbors is discarded. We wondered if there was any useful information in these representations themselves. To test this hypothesis, we replaced the majority vote over the neighbors with a Transformer model which takes as input the neighbors' representations and labels, along with the representation of the test image. The Transformer can then learn to combine this information and produce a final representation of the test image, called the contextualized representation, which is fed to a final function for classification (we used a single fully connected layer for the final function). The entire architecture is depicted below.
In our preliminary experiments with this architecture, we found it provided no accuracy improvements over the simple majority vote method. For example, with balanced classes a number of k-NN models achieve 91.3 percent accuracy (Figure 3), which is nearly the same accuracy achieved by the Transformer (91.38 percent) with 10 nearest neighbors. It appears there is little additional information contained in the neighbors' representations. We remark, however, that there are a number of hyperparameters and configuration details for the Transformer that we did not have time to study, so it's possible that with some tuning the results could improve.
original_design
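A rough sketch of such an architecture in tf.keras is shown below. The single attention block, the label embedding, and the layer sizes are illustrative assumptions and do not capture every configuration detail of the model we actually trained.

import tensorflow as tf

# Sketch of the Extension 3 combiner: attend over the test image token and the
# k neighbor tokens, then classify from the contextualized test token.
def build_neighbor_transformer(k=10, rep_dim=64, num_classes=10, model_dim=64, heads=4):
    test_rep = tf.keras.Input(shape=(rep_dim,), name="test_rep")
    nbr_reps = tf.keras.Input(shape=(k, rep_dim), name="neighbor_reps")
    nbr_labels = tf.keras.Input(shape=(k,), dtype="int32", name="neighbor_labels")

    project = tf.keras.layers.Dense(model_dim)             # shared projection for all representations
    label_embed = tf.keras.layers.Embedding(num_classes, model_dim)

    nbr_tokens = project(nbr_reps) + label_embed(nbr_labels)                # (batch, k, model_dim)
    test_token = tf.keras.layers.Reshape((1, model_dim))(project(test_rep))
    tokens = tf.keras.layers.Concatenate(axis=1)([test_token, nbr_tokens])  # (batch, k + 1, model_dim)

    # One self-attention block; the contextualized representation is the test image's token.
    attended = tf.keras.layers.MultiHeadAttention(num_heads=heads, key_dim=model_dim)(tokens, tokens)
    tokens = tf.keras.layers.LayerNormalization()(tokens + attended)
    contextualized = tf.keras.layers.Lambda(lambda t: t[:, 0, :])(tokens)

    logits = tf.keras.layers.Dense(num_classes)(contextualized)             # final classification layer
    return tf.keras.Model([test_rep, nbr_reps, nbr_labels], logits)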
Extension 4: Combine k-NN memorization model and baseline ResNet into one model
For our final extension, and the one we spent the most time on, the goal was to combine the memorization model with the original ResNet to create a "best of both worlds" pseudo-model. This extension was motivated by the tradeoff observed above: the memorization models seemed to improve accuracy on the rarer classes but at the expense of some accuracy on the more common classes (Figure 10). We wondered whether we could design a gate which picked when to use each model such that the pseudo-model achieved the highest possible accuracy on all classes. Another significant benefit of such a pseudo-model is that it requires zero extra training and very little additional computation at inference time. Note, however, that the gate is not as simple as deciding which model to use based on the class an image belongs to, because the true class is unknown at test time. We are only able to separate out the per class accuracies because the true labels on the test set are known and can be used for evaluation after the fact; using this information to decide the prediction itself would be cheating.
Even though using the true labels to design the gate is invalid, we define an optimal gate which does utilize this information, but only to get an upper bound on how much the pseudo-model could possibly increase accuracy. Specifically, the optimal gate picks the k-NN model if the k-NN model is correct and the ResNet if the ResNet is correct. If both are correct or both are wrong, it doesn't matter which model the gate chooses. We primarily considered two gates: a gate using the one nearest neighbor model with the normalized dot product similarity metric for the balanced class setting, and a gate using the one nearest neighbor model with the centered dot product similarity metric for the long-tailed class distribution (the same similarity metric which produced the large improvements on the rarer classes in Figure 10).
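Computing this upper bound is straightforward; the sketch below uses the true test labels and is therefore not a deployable gate:

import numpy as np

# Oracle accuracy of the combined model: a test point counts as correct if either model gets it right.
def optimal_gate_accuracy(knn_preds, resnet_preds, y_test):
    knn_preds, resnet_preds, y_test = (np.asarray(a).reshape(-1)
                                       for a in (knn_preds, resnet_preds, y_test))
    correct = (knn_preds == y_test) | (resnet_preds == y_test)
    return float(correct.mean())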
In the balanced class setting, a single run of our k-NN model achieves 90.92 percent accuracy while the corresponding ResNet reaches 91.35 percent. The optimal gate produces a pseudo-model with 92.67 percent accuracy, a 15 percent relative drop in error. On the long-tailed distribution (Figure 9), a single k-NN model reaches 59 percent accuracy on the test set and 74.35 percent down weighted accuracy. The ResNet model results in 52.60 and 90.44 percent test set and down weighted accuracy respectively. In this case, the optimal gate produces a model with 67.03 percent test set and 92.49 percent down weighted accuracy, a 21 percent relative drop in down weighted error. These results were quite exciting and led us to spend significant time trying to develop a true gate which approximated the optimal choice. Note, however, that an optimal combination of any two models which make different predictions will increase accuracy in this way; it does not mean that such a combination can be achieved in practice, for example if the two models differ only randomly.
To decide when to use each model in practice, we asked whether some simple criterion existed which would allow us to pick the k-NN model over the ResNet or vice versa. Our hypothesis was that it would be better to use the memorization approach for outlier examples or when the k-NN was very confident, and to use the ResNet model's prediction in the opposite scenarios. For example (a sketch of such a gate follows this list):
  • Should we pick the k-NN model when the similarity metric with the nearest neighbor is really high, say above some threshold?
  • Or should we pick the ResNet model when the test image representation is very similar to the centroid representation for that predicted class (i.e. the test image seems to be at the center of a cluster of training images of a specific class and therefore is likely not an outlier)?
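To make the idea concrete, a threshold gate built on the first criterion might look like the following sketch; the criterion and threshold here are hypothetical examples, and as described below, no threshold of this form ended up separating the two models in our experiments.

# Trust the memorization model only when its nearest neighbor is very similar.
def gated_prediction(knn_pred, resnet_pred, nearest_neighbor_similarity, threshold=0.9):
    if nearest_neighbor_similarity >= threshold:
        return knn_pred
    return resnet_pred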
We plotted possible gate criteria for test images where the k-NN and ResNet predictions differed (the only points where the gate matters) and color coded them by which model ended up being correct. If a clear separation emerged between colors (along the gate criterion axis), then we would be able to define a gate based on a separation threshold for that criterion.
In the balanced class setting, the color coded plots for the two example gate criteria described in the preceding paragraph are shown below (Figure 18 and Figure 19). On the x-axis, the example number shows how many instances the ResNet model and the k-NN model each got correct out of the instances where their predictions differed. For example, in Figure 18, of the 305 test images where the two models differ, the ResNet gets 174 correct and the k-NN gets 131 correct. Separation in colors along the x-axis merely shows which model was correct more often on instances from the test set where the two predictions differed. We are instead interested in separation along the y-axis, which would indicate that points which the ResNet and the k-NN each predicted correctly had different characteristics according to the specific gate criterion. In these plots, no clear separation exists between when the ResNet model is correct and when the k-NN is correct. For example, in Figure 18 it is not the case that the k-NN is correct more often when the normalized dot product with the nearest neighbor is very high. These results contradict our hypothesis above.
ndot_k1_gate_on_ndot
Figure 18
ndot_k1_gate_on_ndot_w_knn_centroid
Figure 19
Similar behavior exists for the two example gate criteria when looking at the long-tailed distribution (Figure 20 and Figure 21). In this case, the ResNet and k-NN differ on many more test images. While it appears the k-NN is correct much more often than the ResNet, recall that we evaluate on the full test set and perform down weighting to calculate the true accuracy for imbalanced class distributions. So while the k-NN is correct on many more instances from the test set and therefore has higher test set accuracy than the ResNet (Figure 12), many of these instances are from classes that would be much more atypical in a true test environment, and the down weighted accuracy of the two models is actually quite similar (Figure 13). In terms of the gate criteria, again no y-axis threshold exists which separates instances where the ResNet ended up being correct from those where the k-NN ended up being correct.
dot_c_k1_gate_on_dot
Figure 20
dot_c_k1_gate_on_dot_w_knn_centroid
Figure 21
We tried many more gate criteria than the two described and shown here, and even tried combinations of them, but none produced a decision criterion which allowed us to choose between the k-NN model and the ResNet model in a way that meaningfully improved the overall accuracy. While we were quite excited about combining the two models, it appears, based on our initial attempts, that the models differ randomly rather than in a systematic way which can be utilized in an ensemble.

Discussion


In summary, the goal of our project was to test whether memorization could help the generalization of image classification models, especially in the long-tailed setup. To evaluate this theory, we designed a first-step approach which stored learned representations of training images and utilized this information to classify test images in a k-nearest neighbor fashion. In our experiments, we saw some promising signs of benefits from our memorization model, namely in improving the accuracy on rarer classes. Yet at the same time, the majority of our experiments didn't turn out as hypothesized, providing evidence against potential benefits of incorporating memorization into computer vision models. In particular, all of our extension ideas failed to improve overall accuracy. We were particularly surprised that we were unable to combine the base ResNet model with our k-NN approach into a pseudo-model which unified the benefits of both approaches. We believe this may still be possible and are interested in continuing to study this question in future work. Finally, we note that we've only just begun studying how to add memorization into the mix and that there is still lots of work to be done to evaluate the potential benefits and to translate theory into practical algorithms. Similar to our effort to combine the k-NN approach with the ResNet model, we think another interesting line of future work is to focus on how to incorporate memorization into existing models right from the start, so that memorization can affect the training process.
Contributions
RW: developed initial codebase for our approach and extensions, ran experiments, wrote reports, created presentation and website
JM: improved codebase and its usability, added different similarity metrics, added the long-tailed distribution, ran experiments, discussed ideas, results, reports, and final presentation/website