Researchers Shaojie Bai from Carnegie Mellon University (CMU), Vladlen Koltun from Intel Labs, and J. Zico Kolter from CMU and Bosch Center for AI presented the paper Multiscale Deep Equilibrium Models as an oral presentation at NeurIPS 2020.
Researchers took on the challenge of determining whether implicit deep learning is relevant for general pattern recognition tasks. Implicit networks do away with the deep stacks of layers and stages that explicit models use to capture multiscale structure. Explicit architectures express hierarchical structure directly, with upsampling and downsampling layers that transition between consecutive blocks operating at different scales. Researchers began the project with this question: Can implicit models that forego deep sequences of layers and stages attain competitive accuracy in domains characterized by rich multiscale structure, such as computer vision?
“We developed the first one-layer implicit deep model that is able to scale to realistic visual tasks, such as megapixel-level images, and achieve competitive results,” said Vladlen Koltun, Director of the Intelligent Systems Lab at Intel Labs, and Chief Scientist for Intelligent Systems at Intel. “MDEQ exemplifies a different approach to differentiable modeling.”
Constant Memory-Efficient Model
The original deep equilibrium models (DEQs) solve for the fixed point of a sequence model with black-box root-finding methods, which is equivalent to finding the limit state of an infinite-layer network. MDEQ extends this setting to multiple feature resolutions, which coexist side by side within a single layer. The input is injected at the highest resolution and then propagated implicitly to the other scales, which are optimized simultaneously by a black-box solver that drives them to satisfy a joint equilibrium condition.
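The equilibrium-solving idea can be illustrated with a toy sketch. The map `f`, the matrices `W` and `U`, and the plain fixed-point iteration below are all illustrative assumptions; the actual models use Broyden's method as the black-box root finder and a far richer transformation than this single `tanh` layer.

```python
import numpy as np

# Toy sketch: find z* such that z* = f(z*, x), i.e. the limit of
# applying the same "layer" f infinitely many times.
rng = np.random.default_rng(0)
d = 8
# Scale W down so f is a contraction and plain iteration converges
# (real DEQs use Broyden's method instead of naive iteration).
W = 0.25 * rng.standard_normal((d, d)) / np.sqrt(d)
U = rng.standard_normal((d, d)) / np.sqrt(d)
x = rng.standard_normal(d)

def f(z, x):
    """One layer, applied repeatedly: z_{t+1} = tanh(W z_t + U x)."""
    return np.tanh(W @ z + U @ x)

z = np.zeros(d)
for _ in range(200):              # iterate until z stops changing
    z_next = f(z, x)
    if np.linalg.norm(z_next - z) < 1e-10:
        break
    z = z_next

# z now (numerically) satisfies the equilibrium condition f(z, x) = z.
residual = np.linalg.norm(f(z, x) - z)
print(residual < 1e-8)
```

Note that the loop never deepens the model: however many iterations the solver takes, the network is still defined by the single layer `f`.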
The main advantage of MDEQ is that this one-layer network uses a simplified layer design that consumes less memory, according to Shaojie Bai, a machine learning doctoral student at Carnegie Mellon University.
For example, the design of a 100-layer explicit network is labor intensive, according to Bai. Additionally, during training of such an explicit network, the forward-pass computation graph, including the intermediate feature representations from all layers, must be recorded and stored. This allows the system to backpropagate through the exact same number of layers, making a backward pass in the reverse direction through the same graph. However, storing all of these activations can cause a memory bottleneck.
“MDEQ is designed to resolve or alleviate these issues with a constant memory-efficient model for hierarchical pattern recognition,” said Bai. “This implicit model can be trained for multiple tasks, such as image classification and semantic segmentation, without paying the same memory costs as an explicit model.”
In contrast to explicit networks, implicit models do not have prescribed computation graphs. They instead posit a specific criterion that the model must satisfy, such as the root of an equation. Importantly, the algorithm that drives the model to fulfill this criterion is not prescribed. Therefore, implicit models can leverage black-box solvers in their forward passes and use analytical backward passes that are independent of the forward pass trajectories. An MDEQ directly solves for and backpropagates through the equilibrium points of multiple feature resolutions simultaneously, using implicit differentiation to avoid storing intermediate states.
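The decoupled backward pass can be sketched with the implicit function theorem: the gradient at the equilibrium depends only on the Jacobians of `f` there, not on the solver's trajectory. The toy map `f` below and the names `solve_equilibrium`, `J_z`, and `J_x` are illustrative assumptions, not the paper's actual architecture or solver.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
W = 0.25 * rng.standard_normal((d, d)) / np.sqrt(d)  # contraction
U = rng.standard_normal((d, d)) / np.sqrt(d)
c = rng.standard_normal(d)        # toy loss L(x) = c . z*(x)

def solve_equilibrium(x, iters=300):
    """Forward pass: any black-box solver works; its path is never stored."""
    z = np.zeros(d)
    for _ in range(iters):
        z = np.tanh(W @ z + U @ x)
    return z

x = rng.standard_normal(d)
z_star = solve_equilibrium(x)

# Implicit differentiation at the equilibrium only:
#   dz*/dx = (I - J_z)^{-1} J_x,  with J_z = df/dz, J_x = df/dx at z*.
D = np.diag(1.0 - np.tanh(W @ z_star + U @ x) ** 2)  # tanh' at z*
J_z, J_x = D @ W, D @ U
grad_x = c @ np.linalg.solve(np.eye(d) - J_z, J_x)   # dL/dx

# Sanity check against finite differences through the solver itself.
eps = 1e-6
fd = np.array([(c @ solve_equilibrium(x + eps * e) - c @ z_star) / eps
               for e in np.eye(d)])
print(np.allclose(grad_x, fd, atol=1e-4))
```

The key point is that `grad_x` is computed from the equilibrium `z_star` alone; no intermediate solver iterates are ever retained, which is why the memory cost stays constant regardless of how many iterations the forward pass took.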
“There really isn't much that we need to store in the process,” said Bai. “We only need to store the eventual output, which we need to use to train the network, but not the previous 99 layers, for example.”
Experiments with MDEQs
Researchers tested the effectiveness of this approach on two large-scale vision tasks: ImageNet classification and semantic segmentation on high-resolution images from the Cityscapes dataset. In both settings, MDEQs are able to match or exceed the performance of recent competitive computer vision models.
MDEQ saved more than 60% of the GPU memory at training time compared to explicit models such as ResNets and DenseNets, while maintaining competitive accuracy. Training a large MDEQ on ImageNet consumes about 6 GB of memory. This low memory footprint is a direct result of the analytical backward pass.
However, MDEQs are generally slower than explicit networks. Researchers observed a 2.7X slowdown for MDEQ compared to ResNet-101, a tendency similar to that observed in the sequence domain. A major factor contributing to the slowdown is that MDEQs maintain features at all resolutions throughout, whereas explicit models such as ResNets gradually downsample their activations and thus reduce computation.
In the future, MDEQs could be useful in edge computing on mobile devices, where memory is limited. Equilibrium networks could also potentially be used in large-scale generative modeling, such as generating new images. The team also proposed ways in which future research could stabilize implicit models, making them scalable to large-scale applications.