Deep learning models for computer vision are deployed everywhere, and increasingly in
safety-critical applications.
Examples include self-driving cars and medical applications like AI-assisted robotic surgery
or medical image analysis.
Problem
In safety-critical applications, models need to generalize reliably from their training
environment to their test environment; otherwise the consequences could be severe.
Models often don't generalize well!
Consider the experiment from [Recht et al. (2019)].
The authors plotted the ImageNet test set
accuracy of a bunch of deep learning models (blue dots) on the x-axis, so from left to right the
accuracy of individual models increases. On the y-axis, the authors plotted each model's accuracy
on a new test set that they created by following the same
instructions and distribution as the original (as best they could). If the models truly generalized
well to unseen data, you would expect the accuracy on the two test sets to be quite similar
(black line). Instead, even in this simple setting, all models perform about 10 percent worse on
the new test set. This is rather alarming, especially considering the possible deployment
applications where these differences could matter substantially.
Problem
Generalizing well to unseen examples is even harder on instances which appear only rarely
in the training data.
Real-world datasets are long-tailed!
In long-tailed distributions a few objects (or visibility patterns) occur frequently,
while many objects (or visibility patterns) occur infrequently. For example, in the SUN
large scale scene recognition dataset below (plot from [Zhu et al. (2014)]),
people and windows occur frequently, while many other objects appear only a handful of times. Likewise,
even for a single object like a person, some visibility patterns (e.g. head on) occur frequently, but
many others (people on horses or facing sideways, etc.) appear in the dataset only a few times. The
total weight of the infrequent examples, however, ends up being a substantial part of the
dataset. Think of all the many corner cases self-driving cars need to handle correctly.
Problem Statement
To improve the generalization of computer vision models on real-world datasets, we need to consider
both common and atypical instances. In this project, we focus on the latter.
How do we improve generalization on the long-tail?
So how do we go about trying to make progress on this problem? First off,
existing techniques for training with rare examples include:
upsampling or downsampling the training data
re-weighting the loss function to pay more attention to underrepresented classes (a minimal sketch of this idea follows the list)
or utilizing training objectives which seek to maximize the accuracy on the worst
performing class rather than the average accuracy (e.g.
Distributionally Robust Optimization (DRO)
).
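To make the re-weighting idea concrete, here is a minimal sketch of inverse-frequency loss re-weighting in TensorFlow; the weighting scheme and helper names are illustrative assumptions rather than code from any of the works above.

```python
import numpy as np
import tensorflow as tf

def inverse_frequency_weights(train_labels, num_classes):
    """One weight per class, inversely proportional to its training frequency."""
    counts = np.maximum(np.bincount(train_labels, minlength=num_classes), 1)
    weights = counts.sum() / (num_classes * counts)  # rare classes get weight > 1
    return tf.constant(weights, dtype=tf.float32)

def reweighted_cross_entropy(class_weights):
    """Cross-entropy where each example's loss is scaled by its class weight."""
    def loss_fn(y_true, logits):
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true, logits=logits)
        return tf.reduce_mean(tf.gather(class_weights, y_true) * per_example)
    return loss_fn
```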
One main problem with these techniques, however, is that you need to know, and have labeled, all
of the atypical examples/groups. Rare instances like uncommon visibility patterns may be missed
because they could be rare subpopulations within a very common labeled class.
In addition to the techniques above, a very recent line of theoretical work suggests that
memorization of training data and their labels is key to achieving close-to-optimal
generalization, especially in the context of long-tailed distributions (
[Feldman (2021)],
[Talwar et al. (2020)]
).
We found this very interesting. Intuitively it makes some sense,
as what else can you really do with only a few examples? Another obvious advantage is that you
may not need to actually know and label what’s atypical (as above). If you just memorize a bunch of
training data, you are likely to memorize a few labels from each of the rare classes (or
visibility patterns), and then you may be able to use that information to help classify similar
rare instances in the future. Note also that if memorization can help improve
generalization, it may be possible to couple it with the existing techniques above for additional
robustness. On the other hand,
conventional wisdom somewhat contradicts this theory,
as memorization is generally thought to lead to overfitting and therefore poor generalization.
Project Goal
While recent theoretical results suggest memorization may improve generalization, especially
over long-tailed distributions, these results are very abstract, make a number of assumptions,
and do not yet provide any practical algorithms for computer vision. The goal of our project is
to:
Test whether memorization can improve generalization by incorporating it into real object
recognition settings.
Approach
The basic setup of our project is as follows. The goal of our approach is to
concisely memorize a representation of each training image and then use this information to
classify future test images. Instead of memorizing raw images themselves (comparing raw image
intensities is not that semantically meaningful), we decided to memorize learned, high-level
representations for each image and their corresponding labels. Thus, the first step in our approach
is to learn these representations. As trained convolutional neural networks (CNNs) are known to
extract useful image features/representations, we decided to start by conventionally training one
of these models. We chose a normal ResNet-20
([He et al. (2016)])
model for this step. The ResNet consists of a
bunch of convolution layers, followed by an average pooling layer, and finally a single dense layer.
Step 1: Train ResNet-20
After training our ResNet model, the second step is to extract and memorize the learned
representations and labels for each training image. For each image, we feed it through the trained
network and take its representation to be the output of the ResNet after the average pooling layer,
but before the final dense layer and classification step. For ResNet-20, this representation
is a 64 dimensional vector.
Step 2: Extract and memorize representations for each training image
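As a rough illustration of this step (the layer name and variable names below are assumptions, not our actual code), the representations can be extracted in TensorFlow/Keras by cutting the trained network at its pooling layer:

```python
import tensorflow as tf

# Illustrative assumptions: `resnet20` is the trained Keras model from Step 1,
# its average pooling layer is named "avg_pool", and `x_train`, `y_train` hold
# the (already normalized) CIFAR-10 training images and labels.
feature_extractor = tf.keras.Model(
    inputs=resnet20.input,
    outputs=resnet20.get_layer("avg_pool").output)  # 64-dimensional output for ResNet-20

# Memorize a 64-d representation and a label for every training image.
memorized_reps = feature_extractor.predict(x_train, batch_size=128)  # shape (N, 64)
memorized_labels = y_train.reshape(-1)                               # shape (N,)
```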
Given a new image which we wish to classify, we can then use the memorized information in a
k-nearest neighbor (k-NN) fashion. First we compute the representation for the new image.
Then we compare this representation against our memorized representations
and return the top k closest matches based on some similarity metric. Finally, given the
k nearest neighbors for our test image, we can predict its label using a majority vote
over the neighbors' labels.
Step 3: Classify future test images using memorized representations
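A minimal NumPy sketch of this classification step (shown here with Euclidean distance; the similarity metrics we actually compare are listed under Implementation Details) might look like the following, with function names of our own choosing:

```python
import numpy as np

def knn_predict(test_rep, memorized_reps, memorized_labels, k=10):
    """Classify one test representation by a majority vote over the labels of
    its k nearest memorized representations (Euclidean distance shown here)."""
    dists = np.linalg.norm(memorized_reps - test_rep, axis=1)  # distance to every memorized image
    nearest = np.argsort(dists)[:k]                            # indices of the k closest matches
    votes = np.bincount(memorized_labels[nearest])             # tally the neighbors' labels
    return int(votes.argmax())                                 # majority-vote prediction
```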
Implementation Details:
We implemented our above approach from scratch in TensorFlow.
For all experiments below, we evaluate our method on the CIFAR-10
([Krizhevsky et al. (2009)])
dataset. During training, we use conventional data augmentation: pad four pixels on each
side of the image and then
randomly crop a 32x32 section from the padded image or its horizontal flip.
We normalize images by subtracting off the mean and dividing by the variance of each color
channel, with statistics computed on the train set. For training our ResNet implementation,
we use the standard ResNet learning rate schedule for CIFAR-10: an initial learning rate of 0.1
dropped by a factor of 10 after 32k and 48k iterations (82 and 123 epochs) with 64k iterations total
(164 epochs). We use a batch size of 128 and stochastic gradient descent with momentum 0.9 and
categorical cross-entropy loss function. For k-NN similarity metrics, we consider three main
choices:
the Euclidean distance (L2) between representation vectors
the dot product (Dot) between representation vectors
and the normalized dot product (NDot) between representation vectors.
We also consider first centering the representations by subtracting off the mean (computed on
the memorized representations) before utilizing one of the three metrics above.
Note that centering does not affect the Euclidean distance but
does affect dot product based similarity metrics.
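For concreteness, the three similarity metrics and the optional centering can be written as follows; this is a sketch of the definitions above rather than our exact implementation:

```python
import numpy as np

def similarities(test_rep, memorized_reps, metric="L2", center=False):
    """Score every memorized representation against a test representation;
    higher scores mean more similar."""
    if center:  # centering only changes the dot-product based metrics
        mean = memorized_reps.mean(axis=0)
        test_rep, memorized_reps = test_rep - mean, memorized_reps - mean
    if metric == "L2":    # negative Euclidean distance, so larger is more similar
        return -np.linalg.norm(memorized_reps - test_rep, axis=1)
    if metric == "Dot":   # raw dot product
        return memorized_reps @ test_rep
    if metric == "NDot":  # normalized dot product (cosine similarity)
        norms = np.linalg.norm(memorized_reps, axis=1) * np.linalg.norm(test_rep)
        return (memorized_reps @ test_rep) / norms
    raise ValueError(f"unknown metric {metric}")
```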
Results
We tested the basic memorization setup described above by looking at the
accuracy of the k-nearest
neighbor model on the standard CIFAR-10 image classification dataset. Our default baseline was
the original ResNet-20 model from Step 1. We varied the similarity
metric used to identify top neighbors and the hyperparameter k (how many nearest neighbors are
used for classification). Below
(Figure 3)
you can see that for some values of k and some
similarity
metrics, the overall accuracy is very similar to that of a normal ResNet model.
Zooming in on the per class
accuracies
(Figure 4)
shows the memorization model makes very similar predictions to a standard
ResNet. This isn't necessarily too surprising, as the memorization model uses representations
from the trained ResNet.
What about when classes are imbalanced, e.g. the long-tail setting?
While we didn't expect much difference between the memorization approach and a standard ResNet in
the case where the classes are balanced, recall that our primary motivation was to improve
generalization
on the long-tail. To this end, we tested the k-nearest neighbor model on versions of the CIFAR-10
train set which we manually imbalanced.
We first randomly removed all but 50 images from class zero
(one percent of the normal 5000 per class) to create the class distribution
below
(Figure 5).
In the per class accuracies
for this case (for a specific k-NN model) there seems to be a slight improvement on the rare
class when comparing the memorization approach to the standard ResNet
(Figure 6).
When computing the overall accuracy of models on imbalanced class distributions there is an
important caveat to consider. In theory, the class distribution should also be imbalanced at
test time, i.e. rare instances in the training data remain rare during inference. However, when a
class appears infrequently in the test set, it becomes hard to estimate its accuracy reliably.
For example, there are normally 1000 instances of each class in the CIFAR-10 test
set. If we remove all but one percent of them for class zero to match the training distribution
above
(Figure 5),
then we are left with only 10 examples from that class.
Instead of removing examples from the
test set, we evaluate on every example, allowing for more accurate per class accuracy
calculations, but in addition to the overall test set accuracy we also report a down weighted
test set accuracy. In this metric, the contribution of rare classes to the overall accuracy is
down weighted so that the end result is the accuracy you would expect when these instances appear
only rarely at test time.
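One way to compute such a down weighted accuracy is to average the per class accuracies (measured on the full test set) with weights proportional to the class frequencies of the imbalanced training distribution; the sketch below illustrates this, though our exact weighting may differ in details:

```python
import numpy as np

def down_weighted_accuracy(y_true, y_pred, class_frequencies):
    """Average per-class accuracy, weighted by how often each class would
    appear in a test environment matching the imbalanced training distribution."""
    y_true, y_pred = np.asarray(y_true).reshape(-1), np.asarray(y_pred).reshape(-1)
    weights = np.asarray(class_frequencies, dtype=float)
    weights = weights / weights.sum()             # normalize to a distribution
    acc = 0.0
    for c, w in enumerate(weights):
        mask = y_true == c
        acc += w * (y_pred[mask] == c).mean()     # per-class accuracy on the full test set
    return acc

# e.g. class_frequencies = [50] + [5000] * 9 for the distribution in Figure 5
```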
We report test set accuracy
(Figure 7)
and down weighted test set accuracy
(Figure 8)
for the class distribution above
(Figure 5)
in the plots below.
When evaluating on the full test set,
lower values of k achieve slightly higher accuracy on the imbalanced class zero (as we saw above
in
Figure 6)
when compared to a ResNet model,
but also occasionally drop some accuracy on other classes (e.g. class five in
Figure 6).
This
averages out and results in the memorization approach and standard ResNet model achieving
very similar overall test set accuracies. Higher values of k achieve poor accuracy
on class zero, and thus drop significant accuracy on the test set. This is because
the memorized representations contain only 50 instances which have class zero labels.
Unless a test point is very close to many of them, a majority vote over many neighbors is likely
to be dominated by the more frequent classes. In contrast, when evaluating the down weighted
accuracy, higher values of k perform better on the frequent classes (as seen in the balanced
class setting,
Figure 3)
and
thus achieve higher accuracy. The improvement on class zero for small values of k is negligible
for the down weighted accuracy because of the rarity of those examples at test time.
In addition to removing images from one class, we also created another imbalanced CIFAR-10
training set which more closely matches the long-tailed setting. The class distribution is shown
below (Figure 9).
In this case, if you look at the per class accuracies for the memorization approach and compare
to a standard ResNet model, you can see the k-NN is performing significantly better on the rare
classes. For the similarity metric on the left
(Figure 10),
the difference is quite noticeable. The problem
is that the k-NN has dropped some accuracy on the most common class, class zero in this case.
With a different similarity metric
(Figure 11)
the k-NN regains some of the
lost accuracy on class zero,
but at the cost of smaller improvements on the rare classes.
Looking at the overall accuracies of different models on the simulated long-tailed
distribution
(Figure 9),
we can see that on the raw test set
(Figure 12)
a number of memorization models
achieve significantly
higher accuracy than the baseline ResNet. This is because they have significantly better
accuracy on the rarer classes (as we saw above, e.g.
Figure 10).
As in the case where only one class was imbalanced,
higher values of k do not achieve the same improvements on the infrequent classes and thus see less
of an improvement, if any, on the test set. Note however, that the down weighted test set
accuracy,
the accuracy you would see in the real world,
of the ResNet and memorization models is nearly identical
(Figure 13).
Even though the k-NN approach
significantly improves
accuracy on the rarer classes, it loses some accuracy on the most common classes, and this drop
has a larger impact on the true accuracy (as the common classes appear more frequently at inference
time and thus contribute more to the true accuracy).
Other Baselines: For another baseline in addition to the standard ResNet model, we tried
training a ResNet with loss re-weighting. The idea is to re-weight the contribution of each
example to the loss function during training so that infrequent classes have higher weight and
the model pays more attention to them. Interestingly, this baseline
approach failed to converge for the
simulated long-tailed distribution above
(Figure 9).
We hypothesize that the class distribution is too heavily
imbalanced, requiring re-weighting values so large that the training objective becomes very unstable.
Certainly a strength of the k-NN memorization approach
is that it can be applied to any distribution/model/training objective without suffering from
similar problems.
Results Takeaway
We noticed promising signs that memorization can help improve generalization, especially
in the long-tailed setting
(i.e. Figure 10).
There appears to be a tradeoff however, where
as accuracy improves on rare classes, it drops slightly on more common classes. This results
in overall accuracies for our k-NN models that are very similar to our baseline, standard
ResNet model.
Extensions
Beyond our main approach and its corresponding results presented above, we had a number of
hypotheses on how to improve our method. We term these experiments extensions.
Extension 1: Compute similarity metrics for k-NN model in PCA subspaces
One hypothesis we had was that our similarity metrics might be noisy because they are operating
over high-dimensional (dimension 64) vectors. It's a well-known phenomenon that in high dimensions,
distances tend to concentrate, so all points end up roughly equally close to (or far from) each
other. To mitigate this issue, we tried
first reducing the dimension of our representations and then computing our similarity metrics to
find nearest neighbors. Specifically, after memorizing the training data representations in Step 2,
we then computed PCA over these vectors to reduce their dimension (down to 32 or 16). Given
a test image in Step 3, after computing its representation we also reduce its dimension using the
same transformation as computed on the training data. The nearest neighbor model can then proceed
as normal given the smaller vectors.
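A minimal sketch of this PCA step (illustrative helpers, not our exact code):

```python
import numpy as np

def fit_pca(memorized_reps, dim=32):
    """Compute a PCA projection from the memorized training representations."""
    mean = memorized_reps.mean(axis=0)
    _, _, vt = np.linalg.svd(memorized_reps - mean, full_matrices=False)
    return mean, vt[:dim]                  # mean and the top `dim` principal directions

def apply_pca(reps, mean, components):
    """Project representations (train or test) with the same transformation."""
    return (reps - mean) @ components.T    # 64-d -> `dim`-d

# mean, comps = fit_pca(memorized_reps, dim=16)
# reduced_memory = apply_pca(memorized_reps, mean, comps)
# reduced_test = apply_pca(test_rep, mean, comps)
```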
We found that this additional processing of image representations led to minimal changes in our
results. For example, in the case of balanced classes below
(Figure 14), the results are quite
similar to the original results for this setting above
(Figure 3).
The only real difference
manifests in the behavior of the dot product similarity metric.
As another example, when looking at the down weighted test
set accuracy on the long-tailed distribution for PCA similarity metrics
(Figure 15), there is
little change in the results when compared to using the full dimensional
vectors (Figure 13).
Extension 2: Learn representations using different training objectives
A second extension we tried aimed to improve the results of the memorization model by explicitly
training the representations with a k-NN similarity metric in mind. In other words, instead of
taking the memorized representations of the training data to be the output of the average pooling
layer after standard ResNet training, we modified the training procedure in Step 1 by designing
different loss functions. The goal of these loss functions was to cluster similar instances together
with the hope that this might improve the quality of the returned nearest neighbors at inference
time. The notion of clustering is tied to the similarity metric we hoped to improve: i.e. if the
plan was to improve the k-NN model when using Euclidean distance to compare representations, then
it might benefit us to train with a loss function which penalizes large distances between instances
of the same class or small distances between instances of differing classes (same idea for
converting dot product or normalized dot product similarity metrics into loss functions).
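To illustrate the Euclidean-distance variant, a contrastive-style batch loss in this spirit could look like the sketch below; our actual objectives live in train_helpers.py, and the margin and pair weighting here are assumptions:

```python
import tensorflow as tf

def euclidean_clustering_loss(representations, labels, margin=1.0):
    """Pull same-class representations together and push different-class
    representations at least `margin` apart, within a batch."""
    labels = tf.reshape(tf.cast(labels, tf.int32), [-1])
    # Pairwise squared Euclidean distances between all representations in the batch.
    sq_norms = tf.reduce_sum(tf.square(representations), axis=1, keepdims=True)
    sq_dists = sq_norms - 2.0 * tf.matmul(representations, representations,
                                          transpose_b=True) + tf.transpose(sq_norms)
    sq_dists = tf.maximum(sq_dists, 0.0)
    same = tf.cast(tf.equal(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0)), tf.float32)
    pull = same * sq_dists                                   # same class: penalize large distances
    dists = tf.sqrt(sq_dists + 1e-12)
    push = (1.0 - same) * tf.square(tf.maximum(0.0, margin - dists))  # different class: enforce a margin
    return tf.reduce_mean(pull + push)
```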
We tried three different loss functions (to cluster based on the three main
similarity metrics) each with possibly a few hyperparameters. We also tried varying the optimizer
between SGD and Adagrad.
The exact details of our loss functions can be found in
the code (train_helpers.py). All loss functions we tried resulted in k-NN models which performed
worse than when learning representations using the standard ResNet loss function
(cross-entropy). For example, when using an objective designed to cluster classes based on
the Euclidean distance between points
(Figure 16),
the best performing k-NN models achieve around
89 percent accuracy with balanced classes.
Compare this with the 91.3 percent accuracy of the best performing k-NN models
when using standard ResNet training
(Figure 3).
Similar results hold for another example objective
which clustered points based on the dot product between representations
(Figure 17). Interestingly,
the normalized dot product similarity metric often appeared to perform better than the dot product
similarity metric for balanced classes,
but the reverse was true for the loss function (our normalized dot product
clustering objectives performed worse than the dot product clustering objectives). In hindsight,
it's not necessarily surprising our objectives performed worse than cross-entropy. After all,
even with the classification loss function the network must learn some form of clustering, and there
is a reason cross-entropy has become the go-to objective: it works quite well in practice.
Extension 3: Replace basic k-NN majority vote with a more complicated model
For our third extension, we wanted to see if the memorization model could be improved by using
a more complicated function to combine the nearest neighbors into a prediction (instead of
just using the simple majority vote over the neighbors' labels as above). When using the majority
vote to combine the neighbors, only the label information is used, and the actual representation
vector of each of the neighbors is discarded. We wondered if there was any useful information
in these representations themselves. To test this hypothesis, we replaced the majority vote over
the neighbors with a
Transformer
model which takes as input the neighbors' representations and labels, along with the representation
of the test image. The Transformer can then learn to combine this information and produce a
final representation, called the contextualized representation, of the test image which can
then be fed to a final function for classification (we used a single fully connected
layer for the final
function). The entire architecture is depicted below.
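A rough Keras sketch of this architecture follows; the model width, number of attention heads, single Transformer block, and label-embedding scheme are illustrative choices, and our exact configuration differs in details.

```python
import tensorflow as tf

def build_neighbor_transformer(k=10, rep_dim=64, num_classes=10, d_model=64, num_heads=4):
    """Attend over the test representation and its k neighbors (representations
    plus label embeddings), then classify the contextualized test token."""
    test_rep = tf.keras.Input(shape=(rep_dim,), name="test_rep")
    nbr_reps = tf.keras.Input(shape=(k, rep_dim), name="neighbor_reps")
    nbr_labels = tf.keras.Input(shape=(k,), dtype="int32", name="neighbor_labels")

    # One token per neighbor (projected representation + label embedding) and
    # one token for the test image itself.
    label_emb = tf.keras.layers.Embedding(num_classes, d_model)(nbr_labels)
    nbr_tokens = tf.keras.layers.Add()([tf.keras.layers.Dense(d_model)(nbr_reps), label_emb])
    test_token = tf.keras.layers.Reshape((1, d_model))(tf.keras.layers.Dense(d_model)(test_rep))
    tokens = tf.keras.layers.Concatenate(axis=1)([test_token, nbr_tokens])  # (k+1, d_model)

    # A single Transformer block: self-attention + feed-forward, with residuals.
    attn = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(tokens, tokens)
    tokens = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([tokens, attn]))
    ff = tf.keras.layers.Dense(d_model, activation="relu")(tokens)
    tokens = tf.keras.layers.LayerNormalization()(
        tf.keras.layers.Add()([tokens, tf.keras.layers.Dense(d_model)(ff)]))

    # The contextualized representation of the test image is the first token;
    # a single dense layer produces the final classification.
    contextualized = tf.keras.layers.Lambda(lambda t: t[:, 0])(tokens)
    logits = tf.keras.layers.Dense(num_classes)(contextualized)
    return tf.keras.Model([test_rep, nbr_reps, nbr_labels], logits)
```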
In our preliminary experiments with this architecture, we found it provided no accuracy
improvements over the simple majority vote method. For example, with balanced classes a number
of k-NN models achieve 91.3 percent accuracy
(Figure 3)
which is nearly the same accuracy achieved by the Transformer (91.38) with 10 nearest neighbors.
It appears there
is little additional information contained in the neighbors' representations.
We remark, however, that there are a number
of hyperparameters and configuration details for the Transformer that we did not have time to
study, so it's possible that with some tuning the results could improve.
Extension 4: Combine k-NN memorization model and baseline ResNet into one model
For our final extension, and the one we spent the most time on, the goal was to try and combine
the memorization model with the original ResNet to create a "best of both worlds" pseudo-model.
This extension was motivated by the tradeoff observed above: the memorization models seemed to
improve accuracy on the rarer classes
but at the expense of some accuracy on the more common classes
(Figure 10).
We wondered whether we could design a gate which picked when to use each model such that the
pseudo-model achieved the highest possible accuracy on all classes. Another significant benefit
of such a pseudo-model is that it requires zero extra training and very little additional computation
at inference time.
Note, however, that the gate cannot simply decide
which model to use based on the class an image belongs to, because the true class is unknown
at test time. We are only able to separate out the per class accuracies because the true labels
on the test set are known and can be used for evaluation after the fact, but using this information
to decide the prediction itself would be cheating.
Even though using the true labels to design the gate is invalid, we define the optimal gate which
does utilize this information, but only to get an upper bound on how much the sudo-model could
possibly increase accuracy. Specifically, the optimal gate picks the k-NN model if the k-NN model
is correct and the ResNet if the ResNet is correct. If both are correct or both are wrong, it
doesn't matter which model is chosen in the gate. We primarily considered two gates: a gate using
the one nearest neighbor model with normalized dot product similarity metric for the balanced class
setting and a gate using the one nearest neighbor model with centered
dot product similarity metric for
the long-tailed class distribution (the same similarity metric which produced
the large improvements on the rarer classes in
Figure 10).
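Computing this upper bound is straightforward once both models' test set predictions are in hand; a sketch (with illustrative names):

```python
import numpy as np

def optimal_gate_accuracy(y_true, resnet_pred, knn_pred):
    """Upper bound for the gated pseudo-model: an example counts as correct if
    either the ResNet or the k-NN predicts it correctly (an oracle picks the right one)."""
    y_true = np.asarray(y_true).reshape(-1)
    either_correct = (np.asarray(resnet_pred) == y_true) | (np.asarray(knn_pred) == y_true)
    return either_correct.mean()
```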
In the balanced class setting, a single run of our k-NN model achieves 90.92 percent accuracy
while the corresponding ResNet reaches 91.35 percent. The optimal gate produces a pseudo-model
with 92.67 percent accuracy, a 15 percent relative drop in error. On the long-tailed
distribution
(Figure 9),
a single k-NN model reaches 59 percent accuracy on the test set and 74.35 percent down weighted
accuracy. The ResNet model results in 52.60 and 90.44 percent test set and down weighted accuracy
respectively. In this case, the optimal gate produces a model with 67.03 percent test set and
92.49 percent down weighted accuracy, a 21 percent relative drop in down weighted error. These
results
were quite exciting and lead us to spend significant time trying to develop a true gate which
approximated the optimal choice. Note, however, that the optimal combination of any two models
which make different predictions will increase accuracy in this manner; it does not mean that such
a combination can be achieved in practice, for example if the two models simply disagree at random.
To decide when to use each model in practice, we asked whether some simple criterion existed
which would tell us when to pick the k-NN model over the ResNet and vice versa. Our hypothesis was
that it would be better to use the memorization approach for outlier examples or when the k-NN was
very confident, and the ResNet model's prediction in the opposite scenarios.
For example:
Should we pick the k-NN
model when the similarity metric with the nearest neighbor is really high,
say above some threshold?
Or should we pick the ResNet model when the test image representation is very similar to the
centroid
representation for that predicted class (i.e. the test image seems to be at the center of a
cluster of training images of a specific class and therefore is likely not an outlier)?
We plotted possible gate criteria for test images where the k-NN and ResNet predictions
differed (the only points where
the gate matters) and color coded them by which model ended up being correct. If a clear separation
emerged between colors (along the gate criterion axis),
then we would be able to define a gate based on a separation threshold
for that criterion.
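Concretely, for each candidate criterion we collected the disagreement points together with their criterion values and which model was right; something along these lines (a sketch, not our plotting code):

```python
import numpy as np

def disagreement_points(y_true, resnet_pred, knn_pred, criterion_values):
    """For every test image where the two models disagree, return the gate
    criterion value (e.g. similarity to the nearest neighbor) and whether the
    k-NN was the correct one, so a separating threshold can be looked for."""
    y_true = np.asarray(y_true).reshape(-1)
    differ = np.asarray(resnet_pred) != np.asarray(knn_pred)
    knn_correct = np.asarray(knn_pred)[differ] == y_true[differ]   # color coding
    return np.asarray(criterion_values)[differ], knn_correct
```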
In the balanced class setting, the color coded plots for the two example gate criteria
described in the preceding paragraph are plotted below
(Figure 18 and
Figure 19).
On the x-axis, the example number shows how many instances the ResNet model and k-NN model each
got correct out of the instances where their predictions differed. I.e. in
Figure 18, of the 305 test images where the two models
differ, the ResNet gets 174
correct and the k-NN gets 131 correct. Separation in colors along the x-axis merely shows which
model was correct more often on the test instances where the two predictions differed.
We are instead interested in separation along the y-axis which would indicate that points which
the ResNet and k-NN each predicted correctly had different characteristics according to the specific
gate criterion.
In these plots, no clear separation exists
between when the ResNet model is correct and when the k-NN is correct. For example, in
Figure 18 it is not the case that the k-NN is
correct more often
when the normalized dot product with the nearest neighbor is really high. These results contradict
our above hypothesis.
Similar behavior exists for the two example gate criteria when looking at the long-tailed
distribution
(Figure 20 and
Figure 21).
In this case, the ResNet and k-NN differ on many more test images. While it appears
the k-NN is correct much more often than the ResNet, recall that we evaluate on the
full test set and perform down weighting to calculate the true accuracy for imbalanced class
distributions. So while the k-NN is correct on many more instances from the test set and therefore
has higher test set accuracy than the ResNet
(Figure 12), many of these instances are from classes
that would be much more atypical in a true test environment and the down weighted accuracy of
the two models is actually quite similar
(Figure 13).
In terms of the gate criteria, again no y-axis threshold exists
which separates instances where the ResNet ended up being correct from those where the k-NN ended
up being correct.
We tried many more gate criteria than the two described and shown here,
and even tried combinations
of them, but none produced a decision criterion which allowed us to choose when to use
the k-NN model or
ResNet model in such a way as to meaningfully improve the overall accuracy. While we were
quite excited about combining the two models, it appears, based on our initial attempts, that the
two models differ randomly rather than in a systematic way which can be exploited in an ensemble.
Discussion
In summary, the goal of our project was to test whether memorization could help the generalization
of image classification models, especially in the long-tailed setup. To evaluate this theory,
we designed a first-step approach which stored learned representations of training images and
utilized this information to classify test images in a k nearest neighbor fashion. In our
experiments, we saw some promising signs for benefits of our memorization model, namely in
improving the accuracy on rarer classes. Yet at the same time, the majority of our experiments
didn't turn out as hypothesized, which is evidence against the potential benefits of incorporating
memorization into computer vision models.
In particular, all of our extension ideas failed to improve
overall accuracy. We were particularly surprised that we were unable to combine the base ResNet
model with our k-NN approach into a pseudo-model which unified the benefits of both approaches.
We believe
this may still be possible and are interested in continuing to study this question in future work.
Finally, we note that we’ve only just begun studying how to add memorization into the mix and that
there is still lots of work to be done to evaluate the potential benefits and to
translate theory into practical algorithms.
Similarly to our effort to combine the k-NN approach with the ResNet model, we think another
line of interesting future work is to
focus on how to incorporate memorization into existing models right from the start so memorization
can affect the training process.
Contributions
RW: developed initial codebase for our approach and extensions, ran experiments,
wrote reports, created presentation and website
JM: improved codebase and its usability, added different similarity metrics, added the long-tailed
distribution, ran experiments, discussed ideas, results, reports, and final presentation/website