We characterize a prevalent weakness of deep neural networks (DNNs), overthinking, which occurs when a DNN can reach correct predictions before its final layer. Overthinking is computationally wasteful, and it can also be destructive when, by the final layer, a correct prediction changes into a misclassification. Understanding overthinking requires studying how each prediction evolves during a DNN's forward pass, which conventionally is opaque. For prediction transparency, we propose the Shallow-Deep Network (SDN), a generic modification to off-the-shelf DNNs that introduces internal classifiers. We apply SDN to four modern architectures, trained on three image classification tasks, to characterize the overthinking problem. We show that SDNs can mitigate the wasteful effect of overthinking with confidence-based early exits, which reduce the average inference cost by more than 50% and preserve the accuracy. We also find that the destructive effect occurs for 50% of misclassifications on natural inputs and that it can be induced, adversarially, with a recent backdooring attack. To mitigate this effect, we propose a new confusion metric to quantify the internal disagreements that will likely lead to misclassifications.
Shallow-Deep Network (SDN): SDN is a generic modification to off-the-shelf DNNs that introduces internal classifiers (ICs). Our modification attaches ICs to various stages of the forward pass, as the above figure illustrates, combining shallow networks (earlier ICs) and deep networks. In the figure, we introduce two ICs (the dotted components) into a CNN, and the resulting network produces three predictions: two internal and one final.
Internal Classifiers (ICs): an internal classifier consists of two parts: a feature reduction layer followed by a single fully connected layer. The feature reduction layer takes the large output of a network's internal layer and reduces its size. The fully connected layer then produces the internal prediction from the reduced output. We attach the internal classifiers after a subset of internal layers: those closest to 15%, 30%, 45%, 60%, 75%, and 90% of the full network's computational cost, measured in the number of FLOPs. Given an input, an SDN produces six internal predictions and a final prediction.
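The two-part IC structure described above can be sketched as a small PyTorch module. The specific pooling operation and the pooled grid size are illustrative assumptions; the essential structure is a feature reduction layer followed by a single fully connected layer.

```python
import torch
import torch.nn as nn

class InternalClassifier(nn.Module):
    """Sketch of an internal classifier (IC): a feature reduction layer
    followed by a single fully connected layer. The pooling choice and
    the reduced grid size (4x4) are illustrative assumptions."""

    def __init__(self, in_channels, num_classes, reduced_size=4):
        super().__init__()
        # Feature reduction: shrink the large internal feature map to a small grid.
        self.reduce = nn.AdaptiveAvgPool2d(reduced_size)
        self.fc = nn.Linear(in_channels * reduced_size * reduced_size, num_classes)

    def forward(self, features):
        reduced = self.reduce(features)      # reduce spatial dimensions
        flat = reduced.flatten(start_dim=1)  # flatten for the linear layer
        return self.fc(flat)                 # internal prediction (logits)
```

Attaching one such module to an internal layer's output yields that stage's internal prediction.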
Training the ICs: To complete our modification, we train the ICs. The training strategy depends on whether the original CNN is pre-trained (the IC-only training) or untrained (the SDN training). The IC-only training converts a pre-trained network into an SDN: we freeze the original weights and train only the weights of the attached ICs. However, the IC-only training has a major drawback: standard network training aims to improve only the final prediction's accuracy, resulting in weak ICs. Therefore, we provide another training strategy, the SDN training, that trains the original weights from scratch, jointly with the ICs, using a weighted loss function.
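The weighted loss used by the SDN training can be sketched as follows. The particular weighting scheme here, where a classifier's weight grows with its relative depth, is an illustrative assumption; the key point is that every IC contributes a weighted cross-entropy term to one joint objective.

```python
import torch
import torch.nn.functional as F

def sdn_training_loss(internal_logits, final_logits, targets, weights=None):
    """Sketch of the weighted loss for SDN training: a weighted sum of the
    cross-entropy losses of all classifiers, internal and final. The default
    depth-proportional weights are a hypothetical choice for illustration."""
    all_logits = internal_logits + [final_logits]
    if weights is None:
        # Hypothetical scheme: deeper classifiers get larger weights.
        weights = [(i + 1) / len(all_logits) for i in range(len(all_logits))]
    losses = [w * F.cross_entropy(logits, targets)
              for w, logits in zip(weights, all_logits)]
    return sum(losses)
```

Backpropagating this single loss updates the original weights and all IC weights jointly.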
Network Overthinking: we say that a network overthinks on an input sample when the simpler representations at an earlier layer, relative to the final layer, are adequate for a correct classification. The SDN modification allows us to monitor how a DNN's decision-making process evolves during the forward pass. Using SDNs, we demonstrate that overthinking can be both wasteful and destructive.
Overthinking leads to wasted computation as a
network applies unnecessarily complex representations to
classify the input.
The above figure shows randomly sampled Tiny ImageNet images that can exit at each IC in VGG16-SDN. The first column presents the samples that are correctly classified at the first IC (IC1), and so on. We see how the input complexity progressively increases, which suggests that the earlier ICs learn simple representations that allow them to recognize typical samples from a class, but that fail on atypical samples.
Overthinking results in destructive outcomes as a network gets confused: on complex inputs, correct internal predictions evolve into misclassifications.
The above figure presents a selection of images that only the first IC can correctly classify in VGG16-SDN. All other ICs in VGG16-SDN misclassify these samples. We hypothesize that these images contain confusing elements belonging to more than one class; while the first IC recognizes the prominently displayed objects, subsequent layers discover finer details that are irrelevant to the correct class.
We compare the inference costs of the CNNs (V: VGG16, R: ResNet56, W: Wide-ResNet32-4, M: MobileNet) and SDNs with early exits. N. lists the original CNNs and their accuracies. ≤25%, ≤50%, and ≤75% report the early exit accuracy when we limit the average inference cost to at most 25%, 50%, and 75% of the original CNN's. Max reports the highest accuracy early exits can achieve. We highlight the cases where an SDN outperforms the original CNN. In each cell, the left and right accuracies result from the IC-only and the SDN training strategies, respectively.
The SDN training improves the early exit accuracy significantly, exceeding the original accuracy while reducing the inference cost by more than 50% on the CIFAR-10 and CIFAR-100 tasks. Even with the IC-only training strategy, early exits remain effective, reducing the inference cost by more than 25%. On Tiny ImageNet, early exits can reduce the cost by more than 25%, usually without any accuracy loss.
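The confidence-based early exit mechanism behind these results can be sketched as follows: the forward pass stops at the first classifier whose softmax confidence exceeds a threshold. The threshold value of 0.9 is an illustrative assumption, and for clarity this sketch takes the full list of logits rather than halting the forward pass itself.

```python
import torch
import torch.nn.functional as F

def early_exit(predictions, threshold=0.9):
    """Sketch of a confidence-based early exit: return the index and label
    of the first prediction (internal or final) whose softmax confidence
    exceeds the threshold. `predictions` is the ordered list of logit
    vectors an SDN produces for one input; the 0.9 threshold is illustrative."""
    for exit_idx, logits in enumerate(predictions):
        probs = F.softmax(logits, dim=-1)
        confidence, label = probs.max(dim=-1)
        if confidence.item() >= threshold:
            # Exiting here would skip the remaining layers and ICs.
            return exit_idx, label.item()
    # No classifier was confident enough: fall back to the final prediction.
    return len(predictions) - 1, label.item()
```

In an actual deployment, stopping at an earlier IC skips the remaining layers, which is what produces the reported reductions in average inference cost.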
Confusion metric: we propose the confusion metric to quantify how much the final prediction diverges from the internal predictions. The divergence of the final prediction from an internal prediction is given by the L1 distance between them.
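A minimal sketch of this computation, summing the L1 distances between the final softmax prediction and each internal softmax prediction. Note that the scores reported below are normalized (hence negative averages are possible); this sketch omits the normalization step.

```python
import torch
import torch.nn.functional as F

def confusion_score(internal_logits, final_logits):
    """Sketch of the confusion metric: the sum of the L1 distances between
    the final softmax prediction and each internal softmax prediction.
    Normalization of the raw score is omitted in this sketch."""
    final_probs = F.softmax(final_logits, dim=-1)
    score = torch.zeros(final_logits.shape[0])
    for logits in internal_logits:
        internal_probs = F.softmax(logits, dim=-1)
        score = score + (final_probs - internal_probs).abs().sum(dim=-1)
    return score  # higher score: more internal disagreement
```

A sample whose internal predictions all agree with the final one gets a score of zero; disagreement among the classifiers drives the score up.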
The plot illustrates the distribution of the confusion scores VGG16-SDN produces on Tiny ImageNet test samples. The dotted and straight lines indicate the average confusion scores of the wrong and correct predictions, respectively. The confusion metric inherently captures the destructive effect: while correct predictions tend to have low confusion scores (−0.29 on average), misclassifications are concentrated among instances with high confusion (0.71 on average). This separation is larger than the one confidence scores provide: 0.93 vs. 0.71 on the correct and wrong predictions. Further, when used as an indicator of likely misclassifications, we found that confusion also produces fewer false negatives than confidence.