Hesham Mostafa is an AI research scientist at Intel Labs, where he works on large-scale graph learning and the systems and algorithmic challenges involved in leveraging web-scale graph data in common machine learning tasks.
This week, at the Fifth Conference on Machine Learning and Systems (MLSys), we are excited to present our approach to distributed large-scale graph training on Intel CPUs. The Sequential Aggregation and Rematerialization (SAR) scheme is fully open sourced and available at https://github.com/IntelLabs/SAR. On common graph learning benchmarks, SAR delivers strong results in both speed and memory efficiency.
Training GNNs on Large Graphs
Graph neural networks (GNNs) are widely applied to many types of graph-related problems. In each layer of a GNN, every node (or vertex) produces an output feature vector by using a learnable transformation to aggregate the input feature vectors of its neighbors in the graph. The same process repeats in every subsequent layer, so after K layers a node's receptive field spans the input features of all nodes that are at most K hops away in the graph. While this allows a GNN to consider a wider neighborhood when learning each node's features, it also leads to a large computational graph at the output, where every node's output features depend on a significant portion of the entire input graph as well as on intermediate node features, a phenomenon known as neighbor explosion.
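To make the neighbor-explosion problem concrete, here is a minimal, self-contained sketch of K rounds of neighbor aggregation on a toy graph (plain PyTorch; the toy graph, sizes, and names are ours, not from the paper):

```python
import torch

# Toy graph with 4 nodes and directed edges (src -> dst).
src = torch.tensor([0, 1, 2, 3, 1])
dst = torch.tensor([1, 2, 3, 0, 0])
num_nodes, feat_dim, K = 4, 8, 3

# Sparse adjacency matrix A (rows index destinations, columns index sources).
A = torch.sparse_coo_tensor(torch.stack([dst, src]),
                            torch.ones(src.numel()),
                            (num_nodes, num_nodes))

x = torch.randn(num_nodes, feat_dim)          # input node features
weights = [torch.nn.Linear(feat_dim, feat_dim) for _ in range(K)]

h = x
for k in range(K):
    # One message-passing layer: aggregate neighbor features (SpMM),
    # then apply a learnable node-wise transformation.
    h = torch.relu(weights[k](torch.sparse.mm(A, h)))

# After K layers, each node's output depends on every node within K hops,
# so the computational graph behind h touches a large fraction of the graph.
```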
When training deep GNNs on large graphs, storing the computational graph in memory quickly becomes a challenge, which makes full-batch GNN training difficult to scale. More scalable alternatives such as sampling-based methods have become common because they keep memory requirements in check by sampling only a small part of the graph in each training iteration. However, sampling leads to noisy and biased gradient estimates, and many sampling-based methods still run into memory issues as the GNN gets deeper and larger neighborhoods must be sampled for each node. The sampling operation itself also introduces extra computational overhead.

Other scalable GNN training methods avoid constructing the large computational graph altogether by using non-learnable message propagation followed by learnable node-wise networks (sketched in the snippet below). This simplifies the GNN considerably, since the expensive propagation of messages between neighbors in the graph is done only once, in a non-learnable manner, during pre-processing. The use of non-learnable messages, however, limits the expressiveness of these models compared to traditional GNNs.

Distributed training across several machines might seem like a way to handle the large memory requirements of full-batch GNN training. However, distributed approaches such as model-parallel training still run into trouble if a single device cannot accommodate the input graph or the activations of a single GNN layer.
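The non-learnable propagation trick can be illustrated in a few lines of plain PyTorch (a minimal sketch of the general idea; the function and variable names are ours, and the adjacency matrix is assumed to be pre-normalized):

```python
import torch

def precompute_propagation(A, x, K):
    """Non-learnable message propagation: apply the (normalized, sparse)
    adjacency matrix A to the input features x for K hops as a one-time
    pre-processing step, outside of autograd."""
    with torch.no_grad():
        for _ in range(K):
            x = torch.sparse.mm(A, x)
    return x

# Training then only involves a node-wise network on the pre-propagated
# features, so the multi-hop computational graph is never built:
#
#   mlp = torch.nn.Sequential(torch.nn.Linear(feat_dim, 256),
#                             torch.nn.ReLU(),
#                             torch.nn.Linear(256, num_classes))
#   logits = mlp(precompute_propagation(A, x, K=3))
```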
A more promising approach is domain-parallel training, in which the input graph is split into several parts and each machine handles the computation for one part. The catch when applying this method to GNNs is that, even though each machine initially stores only a small part of the input graph, it eventually needs to store a substantial portion of the entire graph as part of its output's computational graph. Our work builds on this framework but avoids constructing the computational graph during the forward pass. Instead, SAR constructs the computational graph, and frees it, piece by piece during the backward pass. This results in excellent memory scaling behavior, where memory consumption per worker goes down linearly with the number of workers, even for densely connected graphs, allowing SAR to scale to arbitrarily large graphs by simply adding more workers. We show that the communication overhead SAR incurs to re-materialize the computational graph during the backward pass can be avoided for many popular GNN variants, making the memory savings of SAR practically free.
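The forward-pass side of this idea can be sketched as follows: each worker aggregates messages from one partition at a time and immediately discards the fetched remote features instead of keeping them alive in the forward computational graph. The gradient bookkeeping that SAR performs during the backward pass is omitted here, and all names and the exact structure below are illustrative rather than SAR's actual code:

```python
import torch

def sar_style_layer(A_blocks, local_x, fetch_remote_x, weight, rank):
    """One domain-parallel GNN layer on worker `rank` (illustrative only).

    A_blocks[i] is the sparse adjacency block routing messages from nodes
    owned by worker i to the nodes owned by this worker, and
    fetch_remote_x(i) stands in for the communication call that returns
    worker i's node features.
    """
    # Messages from the local partition.
    agg = torch.sparse.mm(A_blocks[rank], local_x)

    # Sequentially aggregate messages from each remote partition, dropping
    # the fetched features right away instead of keeping them in the forward
    # computational graph.  (The real scheme re-fetches them during the
    # backward pass to re-materialize, and then free, each piece of the
    # graph and to route gradients back to their owners; that bookkeeping
    # is omitted in this sketch.)
    for i, block in enumerate(A_blocks):
        if i == rank:
            continue
        remote_x = fetch_remote_x(i).detach()
        agg = agg + torch.sparse.mm(block, remote_x)
        del remote_x  # free the remote features immediately

    return torch.relu(agg @ weight)
```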
Unfortunately, for some variants, such as Graph Attention Networks (GATs), the communication overhead of SAR cannot be avoided. In response, we identified a couple of optimizations for attention-based models that synergize particularly well with SAR. These optimizations avoid materializing the costly attention-coefficient tensors and instead compute them on the fly using fused kernels during the forward and backward passes. This speeds up computation in GAT-like networks by reducing redundant memory accesses and, in conjunction with SAR, further reduces the memory footprint of attention-based models. After incorporating these optimizations, training GAT using SAR is as fast as vanilla domain-parallel training while consuming a fraction of the memory.
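To see why materializing the attention coefficients is so costly, a quick back-of-envelope calculation (our own, using the edge count for ogbn-papers100M quoted in the results below and the 4-head GAT configuration used there) gives the size of a single layer's attention-coefficient tensor:

```python
# Rough size of one layer's attention-coefficient tensor for a 4-head GAT
# on a graph with ~3.2 billion edges, stored in float32 (4 bytes per value).
num_edges = 3.2e9
num_heads = 4
bytes_per_float = 4

tensor_bytes = num_edges * num_heads * bytes_per_float
print(f"{tensor_bytes / 2**30:.0f} GiB per layer")   # ~48 GiB
```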
Using SAR
We built SAR directly on top of DGL, one of the most popular GNN training libraries, which is in turn built on top of PyTorch. This allows us to directly use standard DGL layers and GNN networks, and it means that using SAR requires only minor modifications to existing single-node DGL code.
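The workflow looks roughly like the sketch below: partition the graph once with DGL's partitioning tooling, then have each worker load its partition and train. The `dgl.distributed.partition_graph` call is standard DGL; the `sar.*` calls in the comments are written from memory of the SAR README, should be treated as assumptions, and may not match the exact API, so please consult the repository's documentation.

```python
import dgl
import torch

# One-time pre-processing: partition the graph with DGL's own tooling.
# (A small built-in dataset is used here just to make the sketch runnable.)
graph = dgl.data.CoraGraphDataset()[0]
dgl.distributed.partition_graph(graph, graph_name='cora',
                                num_parts=4, out_path='partition_data')

# On each of the 4 workers, SAR then loads its partition and constructs a
# distributed graph object to train on.  The calls below are assumptions
# based on our recollection of the SAR README, not a verified API:
#
#   import sar
#   sar.initialize_comms(rank, world_size, master_ip, backend)
#   partition_data = sar.load_dgl_partition_data(
#       'partition_data/cora.json', rank, device)
#   full_graph = sar.construct_full_graph(partition_data).to(device)
#   ...train a standard DGL model on full_graph as usual...
```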
Unlike prior full-batch GNN training frameworks that reimplement basic graph operations, we directly leverage DGL's graph operations, allowing us to capitalize on its efficient, continuously updated kernels for basic graph operations such as sparse-dense matrix multiplication (SpMM). We integrate the sequential rematerialization steps of SAR into PyTorch's autograd mechanics in a way that is transparent to the user, so users can describe arbitrary GNN topologies using PyTorch's standard model-definition workflow. During training, SAR runs under the hood to manage inter-machine communication and the dynamic construction and deletion of pieces of the computational graph.
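As an illustration of this kind of autograd integration (not SAR's actual implementation), a custom `torch.autograd.Function` can drop the remote features in the forward pass and re-fetch them in the backward pass to rebuild, and then free, just the piece of the computational graph it needs:

```python
import torch

class RematerializedAggregation(torch.autograd.Function):
    """Illustrative sketch (not SAR's actual code) of aggregating transformed
    messages from a remote partition without keeping the remote features
    alive between the forward and backward passes."""

    @staticmethod
    def forward(ctx, local_out, remote_adj, fetch_remote_x):
        # Fetch the remote features, use them, and let them be freed.  Only
        # the (dense, for simplicity) adjacency block and the fetch callback
        # are kept around for the backward pass.
        remote_x = fetch_remote_x()
        ctx.save_for_backward(remote_adj)
        ctx.fetch_remote_x = fetch_remote_x
        return local_out + remote_adj @ torch.relu(remote_x)

    @staticmethod
    def backward(ctx, grad_out):
        (remote_adj,) = ctx.saved_tensors
        # Re-materialize the dropped piece of the computational graph by
        # fetching the remote features again, compute the gradient owed to
        # them, then free everything.
        remote_x = ctx.fetch_remote_x()
        grad_remote_x = (remote_x > 0).float() * (remote_adj.t() @ grad_out)
        # In a distributed run, grad_remote_x would be sent to the worker
        # that owns remote_x; this single-process sketch simply discards it.
        del remote_x, grad_remote_x
        # Gradient w.r.t. local_out passes straight through; the adjacency
        # block and the fetch callback need no gradient.
        return grad_out, None, None
```

`RematerializedAggregation.apply(local_out, remote_adj, fetch_fn)` can then be used inside an ordinary `nn.Module`'s forward method, which is what keeps the mechanism invisible to the model definition.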
Performance Results
SAR cuts peak memory consumption by up to 2x when training a 3-layer GraphSage network on ogbn-papers100M (111M nodes, 3.2B edges), and by up to 4x when training a 3-layer Graph Attention Network (GAT). SAR achieves near-linear scaling of the peak memory requirement per worker. We use a 3-layer GraphSage network with a hidden layer size of 256, and a 3-layer GAT network with a hidden layer size of 128 and 4 attention heads, with batch normalization between all layers.
Figure 1. Peak memory consumption per machine as we vary the number of machines for two GNN variants: GraphSage and GAT. SAR cuts memory consumption by 2x and 4x, respectively, when training on 128 machines. Without SAR, we are unable to train the GAT variant on 32 machines due to OOM (out-of-memory) errors.
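For reference, the GraphSage configuration used in these experiments can be written with standard DGL layers roughly as follows (a sketch rather than the exact training script; the class and variable names are ours):

```python
import torch
import torch.nn as nn
import dgl.nn as dglnn

class GraphSage3(nn.Module):
    """3-layer GraphSage with hidden size 256 and batch normalization
    between layers, matching the configuration described above."""

    def __init__(self, in_feats, num_classes, hidden=256):
        super().__init__()
        self.convs = nn.ModuleList([
            dglnn.SAGEConv(in_feats, hidden, 'mean'),
            dglnn.SAGEConv(hidden, hidden, 'mean'),
            dglnn.SAGEConv(hidden, num_classes, 'mean'),
        ])
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden),
                                    nn.BatchNorm1d(hidden)])

    def forward(self, graph, x):
        for i, conv in enumerate(self.convs):
            x = conv(graph, x)
            if i < len(self.norms):       # batch norm + ReLU between layers
                x = torch.relu(self.norms[i](x))
        return x

# The GAT variant swaps in dglnn.GATConv(..., num_heads=4) with a hidden
# size of 128 and flattens the head dimension between layers.
```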
Run time also improves as we add more workers: at 128 workers, the epoch time of our open-source GNN training library is 3.8 s, which is the fastest reported epoch time for the ogbn-papers100M dataset. Each worker is a 2-socket machine with two Ice Lake processors (36 cores each), and the workers are connected by InfiniBand HDR (200 Gbps) links. Training converges after 100 epochs. We use a 3-layer GraphSage network with a hidden layer size of 256 and batch normalization between all layers. The training curve is the same regardless of the number of workers/partitions.
Figure 2. Epoch time when training on ogbn-papers100M using our open-source distributed GNN training library
Figure 3. Test accuracy at each training epoch when doing full-batch GNN training on ogbn-papers100M
To get started using SAR, check out the SAR GNN library at https://github.com/IntelLabs/SAR, and be sure to take a look at SAR's documentation.