A Fresh Take on Neural Network Pruning from the Angle of Graph Theory

Souvik_Kundu
Employee

This article, written jointly with Duc N.M Hoang (UT Austin), summarizes our recent findings on sparse neural networks from our collaboration with the VITA group at UT Austin. The work has been accepted at NeurIPS 2023.


Figure 1: Pruning is a well-explored research area. However, we believe many unknowns remain to be explored as models and training dynamics evolve. In this work, we examine sparse graph patterns to assess their ability to yield good subnetworks.

 

Pruning: The Old Faithful 

Model pruning is arguably one of the oldest methods for reducing the size of deep neural networks (DNNs), dating back to the 1990s, and, remarkably, it is still a very active area of research in the AI community. In a nutshell, pruning creates sparsely connected DNNs that aim to retain the performance of the original dense model. There is a plethora of research on pruning, covering methods applied before training (pruning at initialization, or PaI), during training (sparse learning), and after training (pruning after training, or PaT). However, have you ever wondered whether the sparse graph patterns generated by pruning carry any specific information worth paying attention to (see Fig. 1)? If so, how can we make such sparse graph information useful?

In exploring ways to assess the performance of pruned neural networks at initialization, we take inspiration from Ramanujan graphs and present a new methodology: using Ramanujan graph theory to sample sparse masks. This approach shifts the focus from traditional magnitude pruning to how a network’s structure, such as its expander-graph properties, influences its performance. The key takeaway is a novel use of structural and weighted expander properties that marries network structure with parameter importance. The outcome is a better way to sample sparse masks at initialization, bypassing the need for full fine-tuning.
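The “traditional magnitude pruning” referred to above ranks weights by absolute value and keeps the largest ones. The sketch below is a minimal, dependency-free version of that classic baseline, shown only for contrast; it is not the graph-based method of this work, and the function name `magnitude_mask` is our own.

```python
def magnitude_mask(weights: list[float], sparsity: float) -> list[int]:
    """Return a 0/1 mask that keeps the largest-|w| weights (classic magnitude pruning).

    `sparsity` is the fraction of weights set to zero, e.g. 0.99 keeps 1%.
    Ties are broken by position, because Python's sort is stable.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    n_keep = max(1, round(len(weights) * (1.0 - sparsity)))
    # Indices ordered by descending magnitude.
    order = sorted(range(len(weights)), key=lambda i: -abs(weights[i]))
    keep = set(order[:n_keep])
    return [1 if i in keep else 0 for i in range(len(weights))]


w = [0.05, -0.9, 0.3, -0.01, 0.6]
print(magnitude_mask(w, 0.6))  # → [0, 1, 0, 0, 1]
```

Magnitude pruning looks only at individual weight values; the approach described here additionally asks how the surviving connections are wired together as a graph.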

 

Basics of Ramanujan Graphs 

Ramanujan graphs, which combine sparsity with strong connectivity, have been a key focus of our previous research. In our previous work [1] at ICLR 2023, the authors demonstrated the use of Ramanujan criteria as a benchmark for evaluating highly sparse deep networks prior to training, by introducing a novel recursive structure that softens the criteria. Beyond deep learning, these graphs hold significant value in computer science, especially in designing efficient communication networks, optimizing data structures, and cryptography. There are several ways to measure how closely a bipartite graph (here, a compute layer of a neural network) satisfies the Ramanujan property. Let A and W be a compute layer’s binary and weighted adjacency matrices, and let µ and µ̂ denote their corresponding trivial and non-trivial eigenvalues.
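For reference, the classical definition these criteria build on: a connected d-regular graph is Ramanujan when every non-trivial eigenvalue of its adjacency matrix A is bounded in magnitude by 2√(d−1), which is asymptotically the smallest possible bound (Alon-Boppana). In a bipartite graph such as a compute layer, both d and −d are trivial eigenvalues. Pruned layers are generally irregular, which is why relaxed variants such as the softened criteria of [1] are used in practice.

```latex
\mu(A) = d, \qquad
\hat{\mu}(A) = \max_{\lvert \mu_i(A) \rvert \neq d} \lvert \mu_i(A) \rvert \;\le\; 2\sqrt{d-1}
```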

In our work, we formalize these definitions in Table 1.

 


Table 1: Important graph notations, their formulations, and descriptions
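The sketch below computes two of these quantities for a single layer, under explicit assumptions. A layer with mask M (out × in) is treated as a bipartite graph whose adjacency eigenvalues are ± the singular values of M, so µ is the largest singular value and µ̂ the second largest. We assume the irregular-graph Ramanujan gap Δr = 2√(d_avg − 1) − µ̂(A), with d_avg the average node degree, and a weighted spectral gap λ = µ(|W|) − µ̂(|W|) computed on the absolute values of the masked weights. Table 1 of the paper [2] gives the authoritative definitions, including the iterative variant. All function names are hypothetical, and power iteration is used only to keep the example dependency-free.

```python
import math
import random

Matrix = list[list[float]]


def top_two_singular_values(m: Matrix, iters: int = 1000, seed: int = 0) -> tuple[float, float]:
    """Two largest singular values of m, via power iteration on m^T m with deflation.

    For the bipartite graph of a layer, the adjacency eigenvalues are +/- these values,
    so the first is the trivial eigenvalue and the second the largest non-trivial one.
    """
    rng = random.Random(seed)
    rows, cols = len(m), len(m[0])

    def mtm(v: list[float]) -> list[float]:
        # Computes m^T (m v) without forming m^T m explicitly.
        mv = [sum(m[i][j] * v[j] for j in range(cols)) for i in range(rows)]
        return [sum(m[i][j] * mv[i] for i in range(rows)) for j in range(cols)]

    found: list[list[float]] = []  # unit right-singular vectors already extracted
    sigmas: list[float] = []
    for _ in range(2):
        v = [rng.uniform(-1.0, 1.0) for _ in range(cols)]
        eig = 0.0
        for _ in range(iters):
            w = mtm(v)
            for u in found:  # deflation: project out directions already found
                dot = sum(a * b for a, b in zip(w, u))
                w = [a - dot * b for a, b in zip(w, u)]
            eig = math.sqrt(sum(x * x for x in w))  # approaches an eigenvalue of m^T m
            if eig < 1e-12:  # nothing left in the remaining subspace
                eig = 0.0
                break
            v = [x / eig for x in w]
        found.append(v)
        sigmas.append(math.sqrt(eig))
    return sigmas[0], sigmas[1]


def ramanujan_gap(mask: Matrix) -> float:
    """ASSUMED form: delta_r = 2 * sqrt(d_avg - 1) - mu_hat(A), d_avg = average degree."""
    rows, cols = len(mask), len(mask[0])
    edges = sum(1 for row in mask for x in row if x)
    d_avg = 2.0 * edges / (rows + cols)  # every edge joins one input and one output node
    _, mu_hat = top_two_singular_values(mask)
    return 2.0 * math.sqrt(max(d_avg - 1.0, 0.0)) - mu_hat


def weighted_spectral_gap(weights: Matrix, mask: Matrix) -> float:
    """ASSUMED form: lambda = mu(|W|) - mu_hat(|W|) on the masked absolute weights."""
    w_abs = [[abs(w) * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weights, mask)]
    mu, mu_hat = top_two_singular_values(w_abs)
    return mu - mu_hat


# Two disconnected 2x2 blocks: average degree 2 and mu_hat = 2, so the gap is (numerically) zero.
blocks = [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]
print(ramanujan_gap(blocks))
print(weighted_spectral_gap([[3.0, 0.0], [0.0, -2.0]], [[1, 1], [1, 1]]))
```

A disconnected layer has no spectral gap at all, while a fully connected block (µ̂ = 0) has the largest possible one; sparse masks fall in between, which is what makes these gaps informative. For real layers, a dense linear-algebra library would replace the power iteration.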

 

Observation of correlation between Ramanujan expansion and performance in pruned networks: 

Having established the criteria for both weighted and unweighted topologies in compute graphs, we apply them to a random sparse mask generator (In-Time Over-Parameterization with random regrowth, or ITOP-Rand for short) with ResNet18 and ResNet34, and observe a surprisingly near-linear correlation with performance.

Specifically, we empirically find a potential relationship between the Ramanujan gap and the weighted spectral gap, from both the canonical and the iterative perspectives (see Fig. 2). Sparse topology and weights jointly influence sparse DNN performance, which prompts the central question: is it possible to fuse all these metrics into a single representation? Toward that goal, the “Full spectrum” combines the separate notations in Table 1 into a single metric.


Figure 2: The evolution of ∆r_imdb and ∆r over time as a function of performance, using ITOP.

 

We summarize two key observations with this new metric: 

  1. Linear Correlation with Performance: Under ITOP’s random-growth regime, the ℓ2-moving distance of a sparse subnetwork from its original position in the “Full-spectrum” space aligns almost linearly with its performance (see Fig. 3 below).
  2. Advantage Over Current Graph Metrics: Compared to traditional graph metrics, the “Full spectrum” shows a more direct linear correlation with performance (see Fig. 4). This validates combining these metrics within a unified framework and underscores the strength of the integrated approach.
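Both observations can be made concrete with a short sketch. As a simplifying assumption, a layer’s “Full-spectrum” position is taken to be the 2-D coordinate (Δr, λ), and its ℓ2-moving distance is the Euclidean distance between its coordinate at a later checkpoint and its coordinate at initialization; summing over layers is also our assumption. The paper [2] defines the exact coordinates and aggregation. The Pearson correlation ρ reported in Figs. 3 and 4 is then computed between such distances and accuracy. Function names are hypothetical and the inputs below are synthetic.

```python
import math

Coord = tuple[float, float]  # assumed (delta_r, lambda) "Full-spectrum" coordinate


def l2_moving_distance(initial: Coord, current: Coord) -> float:
    """Euclidean distance a layer has moved in the (delta_r, lambda) plane."""
    return math.dist(initial, current)


def network_distance(initial: list[Coord], current: list[Coord]) -> float:
    """Sum of per-layer moving distances (aggregation by summation is an assumption)."""
    return sum(l2_moving_distance(a, b) for a, b in zip(initial, current))


def pearson(xs: list[float], ys: list[float]) -> float:
    """Pearson correlation coefficient rho between two equal-length samples."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


# Synthetic example: two layers move away from their initial coordinates.
start = [(0.0, 0.0), (1.0, 1.0)]
later = [(3.0, 4.0), (1.0, 2.0)]
print(network_distance(start, later))  # → 6.0
print(pearson([1.0, 2.0, 3.0], [2.0, 4.0, 6.5]))
```

Given one distance per subnetwork and the matching accuracies, `pearson(distances, accuracies)` yields the kind of ρ value quoted in the figure captions.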


Figure 3: Full-spectrum ℓ2-distance against classification performance on CIFAR-10 across different models. The Pearson correlation (ρ) with performance is given in parentheses.

 

The “Full-spectrum” is a powerful metric that combines the two critical aspects of DNN structure studied here: sparse topology and weights. Because it correlates more linearly with performance than either aspect alone, it paves the way for practical pruning applications in which DNNs are not just efficient but also aligned with their underlying topological and weight-based properties.

 

Figure 4: Illustration of the correlation between topology (∆r), weights (λ), and the combined “full spectrum” with respect to the classification performance on CIFAR-100 with ResNet18 (top row) and ResNet34 (bottom row). ρ indicates the Pearson correlation.

 

Efficient Sampling of Sparse Graph Structures: 

The ability to identify the most promising sparse structures right at initialization is the holy grail, as it sidesteps the tedious and costly process of full fine-tuning or training, saving many GPU hours and dollars. Given the near-linear correlation we observed between “Full-spectrum” coordinates and performance, it is natural to use this metric as a criterion for sampling sparse DNNs. Below, we briefly describe two augmented pruning methods built on these findings.

 

PAGS: A Lightweight, Effective Pruning Method 

Pruning at Initialization as Graph Sampling (PAGS) is a novel approach that does not require pre-training, making it an appealing option for efficient model optimization when training resources are limited. The core idea behind PAGS is to maximize the layer-wise full-spectrum ℓ2-moving distance, a metric we identified as a strong predictor of performance. The technique augments existing PaI methods by oversampling the PaI mask generator and selecting the mask that scores highest on the ℓ2-moving distance criterion.
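The oversample-and-select loop at the heart of PAGS can be sketched generically. Here `generate_mask` stands in for any PaI generator (SNIP, GraSP, ERK, or Random) and `score` for the layer-wise full-spectrum ℓ2-moving distance; both are hypothetical hooks supplied by the caller, and the sample count is an arbitrary choice. The toy generator and the placeholder score exist only to make the example runnable; the placeholder is not the paper’s criterion.

```python
import random
from typing import Callable

Mask = list[int]


def pags_select(
    generate_mask: Callable[[int], Mask],
    score: Callable[[Mask], float],
    n_samples: int = 10,
) -> tuple[Mask, float]:
    """Oversample a PaI mask generator and keep the mask with the highest score.

    generate_mask(seed): hypothetical hook for SNIP / GraSP / ERK / Random.
    score(mask): hypothetical hook for the full-spectrum l2-moving distance.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    best_mask: Mask = []
    best_score = float("-inf")
    for seed in range(n_samples):  # each draw uses a different seed
        mask = generate_mask(seed)
        s = score(mask)
        if s > best_score:
            best_mask, best_score = mask, s
    return best_mask, best_score


def toy_random_generator(seed: int, size: int = 100, density: float = 0.1) -> Mask:
    """Toy 'Random' PaI generator: keeps exactly round(density * size) positions."""
    rng = random.Random(seed)
    kept = set(rng.sample(range(size), round(density * size)))
    return [1 if i in kept else 0 for i in range(size)]


def placeholder_score(mask: Mask, width: int = 10) -> float:
    """Placeholder, NOT the paper's criterion: counts rows and columns of a
    width x width layer that keep at least one connection (a crude connectivity proxy)."""
    rows = {i // width for i, m in enumerate(mask) if m}
    cols = {i % width for i, m in enumerate(mask) if m}
    return float(len(rows) + len(cols))


mask, best = pags_select(toy_random_generator, placeholder_score, n_samples=20)
print(sum(mask), best)
```

Since the loop only calls the generator and the scorer, its cost grows linearly with `n_samples` and involves no weight updates, which is what lets the method wrap any existing PaI generator.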

The beauty of PAGS lies in its simplicity and computational efficiency. Generating a mask only requires a forward pass over a minibatch of data, significantly reducing computational costs. In practice, PAGS operates as a meta-framework, capable of enhancing any existing PaI method, whether proxy-driven or not. 


 

 PEGS: Minimal Pre-Training for Maximal Efficiency 

Building on the zero-shot capability of PAGS, Pruning Early as Graph Sampling (PEGS) takes it a step further by incorporating a small amount of pre-training. This method involves lightly training the dense network for a few iterations (typically 500 in the experiments) before applying the PAGS technique. This approach draws inspiration from the concepts of Lottery Ticket Hypothesis (LTH) rewinding and early-bird ticket identification in neural networks. 
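Compared with PAGS, PEGS changes only which weights the selection sees: the dense weights after a short warm-up instead of the initial ones. A minimal sketch of that control flow follows, with hypothetical hooks: `train_step` performs one dense training iteration and `select_mask` runs a PAGS-style selection on the given weights. The default of 500 steps mirrors the typical setting mentioned above. The toy quadratic loss and the largest-weight selector are stand-ins chosen only to make the sketch runnable.

```python
from typing import Callable

Weights = list[float]
Mask = list[int]


def pegs(
    weights: Weights,
    train_step: Callable[[Weights], Weights],
    select_mask: Callable[[Weights], Mask],
    warmup_steps: int = 500,
) -> tuple[Weights, Mask]:
    """Lightly train the dense weights, then select a sparse mask from the result.

    train_step: one dense training iteration (hypothetical hook).
    select_mask: a PAGS-style oversample-and-score selection (hypothetical hook).
    """
    for _ in range(warmup_steps):
        weights = train_step(weights)
    return weights, select_mask(weights)


# Toy problem so the sketch runs: quadratic loss 0.5 * sum((w_i - t_i)^2).
TARGET = [0.0, 2.0, -0.5, 0.1]


def toy_step(w: Weights, lr: float = 0.01) -> Weights:
    # Gradient of the quadratic loss with respect to w_i is (w_i - t_i).
    return [wi - lr * (wi - ti) for wi, ti in zip(w, TARGET)]


def toy_select(w: Weights) -> Mask:
    # Stand-in for PAGS: keep only the largest-magnitude weight.
    best = max(range(len(w)), key=lambda i: abs(w[i]))
    return [1 if i == best else 0 for i in range(len(w))]


w_init = [1.0, 0.0, 0.0, 0.0]
print(pegs(w_init, toy_step, toy_select, warmup_steps=0)[1])    # → [1, 0, 0, 0]
print(pegs(w_init, toy_step, toy_select, warmup_steps=500)[1])  # → [0, 1, 0, 0]
```

Even in this toy setting, a short warm-up changes which connection is kept, which mirrors the intuition behind LTH rewinding and early-bird tickets.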

 

Results and Evaluations Summary: 

We use four different PaI mask generators (SNIP, GraSP, ERK, and Random), as well as LTH, and compare them across various benchmarks (we show only CIFAR-100 here; see our paper [2] for further evaluations). Unless otherwise stated, we use a high target sparsity of 99% in all experiments, meaning that only 1% of the weights remain active and all other weights are zero. Unsurprisingly, the masks sampled with the “Full spectrum” criterion outperform those produced by the generators alone. Tables 2 and 3 show results with the ResNet18 and ResNet34 models, respectively. EB refers to the early-bird ticket; interestingly, we find that it fails to yield any meaningful accuracy with the VGG16 model at high sparsity. LTH refers to the model obtained via the lottery ticket hypothesis through iterative pruning, which is far more compute-heavy than our approach.
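At 99% sparsity, a uniform mask keeps 1% of every layer. ERK, one of the generators listed above, instead distributes the same 1% budget unevenly so that small layers stay denser. The following is our own minimal re-implementation of the Erdős-Rényi-Kernel rule as popularized by RigL (Evci et al.), for illustration only: it is not the code used in these experiments, and the layer shapes are hypothetical.

```python
import math


def erk_densities(shapes: list[tuple[int, ...]], sparsity: float) -> list[float]:
    """Per-layer densities under the Erdos-Renyi-Kernel (ERK) rule.

    A layer's density is proportional to sum(shape) / prod(shape), scaled by one
    global factor eps so that the kept weights total (1 - sparsity) * all weights.
    Layers whose density would exceed 1 are made dense and eps is recomputed.
    """
    sizes = [math.prod(s) for s in shapes]
    raw = [sum(s) / n for s, n in zip(shapes, sizes)]
    budget = (1.0 - sparsity) * sum(sizes)  # number of weights allowed to stay non-zero
    dense: set[int] = set()
    while True:
        free = [i for i in range(len(shapes)) if i not in dense]
        if not free:  # every layer ended up dense
            return [1.0] * len(shapes)
        eps = (budget - sum(sizes[i] for i in dense)) / sum(raw[i] * sizes[i] for i in free)
        over = [i for i in free if eps * raw[i] > 1.0]
        if not over:
            return [1.0 if i in dense else eps * raw[i] for i in range(len(shapes))]
        dense.update(over)


# Hypothetical conv shapes (out, in, kh, kw) followed by a linear (out, in) layer.
layers = [(64, 3, 3, 3), (128, 64, 3, 3), (256, 128, 3, 3), (10, 256)]
dens = erk_densities(layers, sparsity=0.99)
kept = sum(d * math.prod(s) for d, s in zip(dens, layers))
print([round(d, 4) for d in dens], round(kept))
```

The total number of kept weights still equals 1% of all weights; only its split across layers changes.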


Table 2: Results on CIFAR-100 with ResNet18 using PAGS/PEGS in comparison to vanilla PaI methods, LTH, and EB. Baseline refers to the PaI-found sparse mask. ✗, ✗✗, and ✗✗✗ represent low, high, and very high pre-training costs, respectively. @100 and @500 refer to 100 and 500 “pre-training” iterations with PEGS; @0 means we start from random initialization using PAGS.

 

 


Table 3: Results on CIFAR-100 with ResNet34 using PAGS/PEGS in comparison to vanilla PaI methods, LTH, and EB. Baseline refers to the PaI-found sparse mask. ✗, ✗✗, and ✗✗✗ represent low, high, and very high pre-training costs, respectively. @100 and @500 refer to 100 and 500 “pre-training” iterations with PEGS; @0 means we start from random initialization using PAGS.

 

Conclusion: 

In summary, we address a critical gap in neural network optimization by applying graph theory to better understand sparse subnetworks. Focusing on the weighted spectral gap in Ramanujan structures, we found a significant correlation with network performance. This led to the “full spectrum” coordinate, a metric we use to enhance pruning both at initialization (PAGS) and after minimal pre-training (PEGS).

This approach not only improves sparse neural networks but also sets the stage for future research into the complex interplay of topology and weights in DNNs. 

References: 

[1] ICLR 2023 (spotlight): https://openreview.net/pdf?id=uVcDssQff_

[2] NeurIPS 2023: https://openreview.net/pdf?id=DIBcdjWV7k

 
