Shallow-Deep Networks

Understanding and Mitigating Network Overthinking

Yigitcan Kaya, Sanghyun Hong, and Tudor Dumitraș
University of Maryland, College Park

In ICML 2019
Talk: 4:30-4:35 P.M. (Hall A), Poster: 6:30-9:00 P.M. (Pacific Ballroom) [Jun. 11, Tue]

Abstract

We characterize a prevalent weakness of deep neural networks (DNNs), overthinking, which occurs when a DNN can reach correct predictions before its final layer. Overthinking is computationally wasteful, and it can also be destructive when, by the final layer, a correct prediction changes into a misclassification. Understanding overthinking requires studying how each prediction evolves during a DNN's forward pass, which conventionally is opaque. For prediction transparency, we propose the Shallow-Deep Network (SDN), a generic modification to off-the-shelf DNNs that introduces internal classifiers. We apply SDN to four modern architectures, trained on three image classification tasks, to characterize the overthinking problem. We show that SDNs can mitigate the wasteful effect of overthinking with confidence-based early exits, which reduce the average inference cost by more than 50% while preserving accuracy. We also find that the destructive effect occurs for 50% of misclassifications on natural inputs and that it can be induced, adversarially, with a recent backdooring attack. To mitigate this effect, we propose a new confusion metric to quantify the internal disagreements that are likely to lead to misclassifications.


Shallow-Deep Networks: A Generic Modification to Deep Neural Networks

Shallow-Deep Network (SDN): SDN is a generic modification to off-the-shelf DNNs for introducing internal classifiers (ICs). Our modification attaches ICs to various stages of the forward pass, as the above figure illustrates, and combines shallow networks (earlier ICs) and deep networks. In the figure, we introduce two ICs—the dotted components—to a CNN, and the resulting network produces three predictions: two internal and one final.

Internal Classifiers (ICs): an internal classifier consists of two parts: a feature reduction layer followed by a single fully connected layer. The feature reduction layer takes the large output of a network's internal layers and reduces its size. The fully connected layer, using the reduced output, produces the internal prediction. We attach the ICs after a subset of internal layers: the layers closest to 15%, 30%, 45%, 60%, 75%, and 90% of the full network's computational cost (measured in FLOPs). Given an input, an SDN thus produces six internal predictions and a final prediction.
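To make the IC structure concrete, the following is a minimal PyTorch sketch of an internal classifier: a feature reduction layer followed by a single fully connected layer. The use of adaptive average pooling for feature reduction, the module name, and the pool size are illustrative assumptions, not the released implementation.

import torch
import torch.nn as nn

class InternalClassifier(nn.Module):
    def __init__(self, in_channels, num_classes, pool_size=4):
        super().__init__()
        # Feature reduction: shrink the large internal feature map
        # (here with adaptive average pooling -- one plausible choice).
        self.reduce = nn.AdaptiveAvgPool2d(pool_size)
        # A single fully connected layer produces the internal prediction.
        self.fc = nn.Linear(in_channels * pool_size * pool_size, num_classes)

    def forward(self, features):
        reduced = self.reduce(features)             # (N, C, pool, pool)
        flat = torch.flatten(reduced, start_dim=1)  # (N, C * pool * pool)
        return self.fc(flat)                        # internal logits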

Training the ICs: To complete our modification, we train the ICs. The training strategy depends on whether the original CNN is pre-trained (the IC-only training) or untrained (the SDN training). The IC-only training converts a pre-trained network into an SDN: we freeze the original weights and train only the weights of the attached ICs. However, the IC-only training has a major drawback: standard network training aims to improve only the final prediction's accuracy, which results in weak ICs. Therefore, we provide another training strategy, the SDN training, that trains the original weights, from scratch, jointly with the ICs by using a weighted loss function. A sketch of such a weighted loss follows.
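Below is a minimal sketch of a weighted SDN training loss: a weighted sum of the cross-entropy losses of all internal and final predictions. The specific weighting schedule shown (later classifiers weighted more heavily) and the assumption that the model returns one logits tensor per classifier are illustrative, not the paper's exact formulation.

import torch
import torch.nn.functional as F

def sdn_loss(all_logits, targets, weights=None):
    """all_logits: list of logits [ic1, ..., ic6, final], each of shape (N, num_classes)."""
    if weights is None:
        # Example schedule: weights grow linearly toward the final classifier.
        weights = [i / len(all_logits) for i in range(1, len(all_logits) + 1)]
    # Weighted sum of per-classifier cross-entropy losses.
    losses = [w * F.cross_entropy(logits, targets)
              for w, logits in zip(weights, all_logits)]
    return sum(losses)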


Network Overthinking

Network Overthinking: we say a network overthinks on an input sample when the simpler representations at an earlier layer, relative to the final layer, are adequate for a correct classification. The SDN modification allows us to monitor how a DNN's decision-making process evolves during the forward pass. Using SDNs, we demonstrate that overthinking can be both wasteful and destructive.

Overthinking leads to wasted computation as a network applies unnecessarily complex representations to classify the input.

The above figure shows randomly sampled TinyImageNet images that can exit at each IC in VGG16-SDN. The first column presents the samples that are correctly classified at the first IC (IC1), and so on. We see how the input complexity progressively increases, which suggests that the earlier ICs learn simple representations that allow them to recognize typical samples from a class, but that fail on atypical samples.

Overthinking results in destructive outcomes when a network gets confused: with complex inputs, correct internal predictions evolve into misclassifications by the final layer.

The above figure presents a selection of images that only the first IC can correctly classify in VGG16-SDN; all other ICs misclassify these samples. We hypothesize that these images contain confusing elements belonging to more than one class: while the first IC recognizes the prominently displayed objects, subsequent layers discover finer details that are irrelevant to the correct label.


Early Exits: Mitigate the Wasteful Effect

We compare the inference costs of the CNNs (V: VGG16, R: ResNet56, W: Wide-ResNet32-4, M: MobileNet) and SDNs with early exits. N. lists the original CNNs and their accuracies. ≤25%, ≤50%, and ≤75% report the early-exit accuracy when we limit the average inference cost to at most 25%, 50%, and 75% of the original CNN's cost. Max reports the highest accuracy that early exits can achieve. We highlight the cases where an SDN outperforms the original CNN. In each cell, the left and right accuracies are from the IC-only and the SDN training strategies, respectively.


The SDN training improves the early-exit accuracy significantly, exceeding the original accuracy while reducing the inference cost by more than 50% on the CIFAR-10 and CIFAR-100 tasks. Even with the IC-only training strategy, early exits are still effective, reducing the inference cost by more than 25%. On Tiny ImageNet, early exits can reduce the cost by more than 25%, usually without any accuracy loss.
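The following is a minimal sketch of confidence-based early exiting at inference time: run the classifiers in order and return the first prediction whose softmax confidence exceeds a threshold, falling back to the final classifier otherwise. The threshold value, the function name, and the assumption that the model returns a list of per-classifier logits are illustrative.

import torch
import torch.nn.functional as F

@torch.no_grad()
def early_exit_predict(sdn_model, x, threshold=0.9):
    """x: a single input of shape (1, C, H, W); returns (predicted class, exit index)."""
    all_logits = sdn_model(x)  # list: [ic1, ..., ic6, final], each (1, num_classes)
    for exit_idx, logits in enumerate(all_logits[:-1]):
        probs = F.softmax(logits, dim=1)
        confidence, prediction = probs.max(dim=1)
        if confidence.item() >= threshold:
            return prediction.item(), exit_idx          # exit early, skip later layers
    # No IC was confident enough: use the final classifier.
    final_pred = all_logits[-1].argmax(dim=1)
    return final_pred.item(), len(all_logits) - 1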

Confusion: Mitigate the Destructive Effect

Confusion metric: we propose the confusion metric to quantify how much the final prediction diverges from the internal predictions. The divergence of the final prediction from an internal prediction is given by the L1 distance between the two prediction vectors.
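A minimal sketch of this computation is below: sum the L1 distances between the final softmax output and each internal softmax output. Any normalization of the raw score (the plotted scores below appear to be standardized, since the averages include a negative value) is omitted here and would be an additional, assumed step.

import torch
import torch.nn.functional as F

def confusion_score(all_logits):
    """all_logits: list of logits [ic1, ..., final], each (N, num_classes); returns (N,) scores."""
    probs = [F.softmax(logits, dim=1) for logits in all_logits]
    final, internal = probs[-1], probs[:-1]
    # Sum of L1 distances between the final prediction and each internal prediction.
    return sum((final - p).abs().sum(dim=1) for p in internal)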


The plot illustrates the distribution of the confusion scores VGG16-SDN produces on Tiny ImageNet test samples. The dotted and solid lines indicate the average confusion scores of the wrong and correct predictions, respectively. The confusion metric inherently captures the destructive effect: while correct predictions tend to have low confusion scores (−0.29 on average), the misclassifications are concentrated among instances with high confusion (0.71 on average). This separation is sharper than what the confidence scores provide, which average 0.93 on correct predictions and 0.71 on wrong ones. Further, when used as an indicator of likely misclassifications, we find that confusion also produces fewer false negatives than confidence.


Acknowledgements

We thank Dr. Tom Goldstein, Dr. Nicolas Papernot, Virinchi Srinivas and Sarah Joseph for their valuable feedback. This research was partially supported by the Department of Defense.