
Training Causal Language Models on SDSC’s Gaudi-based Voyager Supercomputing Cluster

Chen_Levkovich

The SDSC Voyager supercomputer is an innovative AI system designed specifically for science and engineering research at scale. Funded by the National Science Foundation, Voyager is a collaboration among the San Diego Supercomputer Center at UC San Diego, Supermicro, and Intel’s Habana Labs. It fosters deep engagement with the AI research community and enables the application of deep learning techniques to interdisciplinary problems that require Natural Language Processing (NLP) and image analysis. A key trend in NLP is the development of large language models (LLMs), and training them efficiently is critical to any AI research project that relies on them. Voyager provides the scale and compute required to train such LLMs.

DeepSpeed [1] is a popular deep learning software library that includes ZeRO (Zero Redundancy Optimizer), a memory-efficient approach to distributed training [2]. ZeRO has multiple stages of memory optimization, and the Habana® SynapseAI® software currently supports ZeRO-1 and ZeRO-2. In an earlier blog [3], we discussed the theoretical underpinnings of DeepSpeed ZeRO. In this blog, we describe how we trained large causal language models, namely the Hugging Face GPT2-XL model with 1.5 billion parameters and GPT3-XL with 1.3 billion parameters, on the Voyager cluster, and show how training scales across multiple nodes. Habana’s SynapseAI software and DeepSpeed support are integrated into the Hugging Face Optimum Habana library, which makes it easy to configure and train these causal language models at scale.
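To make the ZeRO stages concrete, here is a back-of-the-envelope sketch of per-device model-state memory, using the accounting from the ZeRO paper [2]: with mixed-precision Adam, each parameter costs 2 bytes of 16-bit weights, 2 bytes of 16-bit gradients and 12 bytes of fp32 optimizer state. It assumes bf16 costs the same 2 bytes as fp16 and ignores activations, communication buffers and framework overhead, so real usage is higher. ZeRO-3 appears only for comparison, since SynapseAI currently supports ZeRO-1 and ZeRO-2.

```python
"""Back-of-the-envelope model-state memory per device under ZeRO stages.

Follows the mixed-precision Adam accounting of the ZeRO paper:
2 bytes/param of 16-bit weights, 2 bytes/param of 16-bit gradients and
K = 12 bytes/param of fp32 optimizer state. Activations are ignored.
"""

K_OPTIMIZER_BYTES = 12  # fp32 master weights + Adam momentum + Adam variance


def model_state_bytes(num_params: float, num_devices: int, stage: int) -> float:
    """Bytes of model state held by one data-parallel rank."""
    params = 2 * num_params
    grads = 2 * num_params
    optim = K_OPTIMIZER_BYTES * num_params
    if stage == 0:  # plain data parallelism: everything is replicated
        return params + grads + optim
    if stage == 1:  # ZeRO-1: partition optimizer states
        return params + grads + optim / num_devices
    if stage == 2:  # ZeRO-2: partition optimizer states and gradients
        return params + (grads + optim) / num_devices
    if stage == 3:  # ZeRO-3: partition the parameters as well
        return (params + grads + optim) / num_devices
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    gpt2_xl_params = 1.5576e9  # approximate GPT2-XL parameter count
    for stage in (0, 1, 2):
        gb = model_state_bytes(gpt2_xl_params, num_devices=8, stage=stage) / 1e9
        print(f"ZeRO-{stage} on 8 devices: {gb:.1f} GB of model state per device")
```

For a 7.5B-parameter model on 64 devices, this reproduces the ZeRO paper's example of about 120 GB, 31.4 GB and 16.6 GB per device for no partitioning, ZeRO-1 and ZeRO-2.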

A brief peek into the Voyager AI supercomputer


Science is one of the critical domains where petabytes and exabytes of data need to be analyzed to extract insights and discoveries that directly impact humanity in health, precision medicine, physics, climate change and many other sciences. Scientists are increasingly focusing on AI and deep learning as the means to help them in their data-driven scientific exploration and discovery. The Voyager supercomputer is an innovative AI system designed specifically for science and engineering research at scale [4]. Funded by the National Science Foundation, Voyager represents a collaboration with the San Diego Supercomputer Center at UC San Diego, Supermicro, and Habana Labs, an Intel company, focused on supporting research in science and engineering that is increasingly dependent upon artificial intelligence and deep learning as a critical element in the experimental and/or computational work.

Voyager includes 42 training nodes, each a Supermicro X12 Gaudi server with Intel’s third-generation Xeon Scalable (“Ice Lake”) processors and eight first-generation Gaudi cards. Training AI models on today’s massive datasets in reasonable time requires systems with very large numbers of processor cores and nodes, and at this scale communication across the system is often constrained by the network. Each Gaudi processor integrates ten 100 Gbps RDMA over Converged Ethernet (RoCE) ports, and the interconnect can be configured in several ways, making it flexible for different applications. In Voyager, each of the eight Gaudi cards in a server dedicates seven 100 Gbps ports to an all-to-all, non-blocking connection with the other cards in the node. The remaining three 100 Gbps ports on each Gaudi are dedicated to scale-out, giving each Voyager node 24 scale-out ports of 100 Gbps each. This integrated RoCE scaling makes the overall system more efficient. Voyager entered its operational testbed phase in May 2022, with several applications already running on the system with minimal code changes [5]. With the system in this phase, there is considerable interest in exploring key AI advances in interdisciplinary scientific topics as diverse as climate research and high-energy physics.
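The port budget described above can be checked with a few lines of arithmetic. This sketch uses only the figures given in the text and reports nominal line rates (the sum of port speeds), not measured bandwidth.

```python
"""Port and bandwidth arithmetic for one Voyager training node (figures from the text)."""

GAUDI_PER_NODE = 8      # first-generation Gaudi cards per node
PORTS_PER_GAUDI = 10    # integrated 100 Gbps RoCE ports per Gaudi
INTRA_NODE_PORTS = 7    # ports used for the all-to-all links inside the node
PORT_GBPS = 100         # nominal line rate of each port
TRAINING_NODES = 42     # training nodes in Voyager

# Ports left for scale-out once the intra-node all-to-all links are wired.
scale_out_ports_per_gaudi = PORTS_PER_GAUDI - INTRA_NODE_PORTS
scale_out_ports_per_node = GAUDI_PER_NODE * scale_out_ports_per_gaudi
scale_out_gbps_per_node = scale_out_ports_per_node * PORT_GBPS
total_gaudi_cards = TRAINING_NODES * GAUDI_PER_NODE

print(f"scale-out ports per node: {scale_out_ports_per_node}")
print(f"nominal scale-out line rate per node: {scale_out_gbps_per_node} Gbps")
print(f"Gaudi cards across the training nodes: {total_gaudi_cards}")
```

This works out to 24 scale-out ports (2.4 Tbps nominal) per node and 336 Gaudi cards across the 42 training nodes.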

Natural Language Processing is key to data-driven scientific exploration in myriad ways, including automatic extraction of structured knowledge from scientific publications, extraction of insights from electronic health records, data-driven drug discovery techniques [6], and even predicting future scientific discoveries simply by extracting meaningful data from research publications [7]. The field of natural language processing has lately been driven by large language models (LLMs), which have been shown to excel at tasks such as natural language understanding, question answering, summarization, translation, and natural language generation. As LLMs keep growing, it becomes critical to train them efficiently at scale. Next, we discuss how we trained a large causal language model on the Voyager cluster and demonstrate how it scales across multiple nodes.

Training causal language models on Voyager

Currently, the standard approach to building new NLP models for any task follows the well-known paradigm of pre-train, then fine-tune: a large language model is first pre-trained on a huge dataset and then fine-tuned for the task at hand. Pre-training is typically driven by either (a) a masked language modelling objective or (b) a causal language modelling objective. In masked language modelling, a certain percentage of the words in a text is masked, and the model is trained to predict the masked words correctly. The model thus learns the representation of a masked word in a bidirectional context, looking at words on both its left and its right. In causal language modelling, the model uses only the words to the left of the word being predicted as context. Models in the GPT family, such as GPT-1, GPT-2 and GPT-3, are causal language models [8,9,10]. We trained a causal language model, GPT2-XL, on the Voyager platform, scaling the training across multiple Gaudi devices.
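The difference between the two objectives is easiest to see on a toy sentence. The sketch below uses plain Python with hypothetical helper names. It lists the (context, target) pairs a causal language model learns from, builds the lower-triangular attention mask that enforces left-only context, and, for contrast, shows a masked-language-modelling input. In practice, models such as Hugging Face's GPT2LMHeadModel score all positions in a single forward pass using such a causal mask, with the labels shifted by one position inside the model.

```python
"""Toy contrast between causal and masked language modelling objectives."""

from typing import Dict, List, Set, Tuple


def causal_lm_examples(tokens: List[str]) -> List[Tuple[List[str], str]]:
    """Each token is predicted from the tokens strictly to its left."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def causal_mask(n: int) -> List[List[int]]:
    """mask[i][j] == 1 when position i may attend to position j (only j <= i)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def masked_lm_example(tokens: List[str], masked: Set[int],
                      mask_token: str = "[MASK]") -> Tuple[List[str], Dict[int, str]]:
    """Hide the chosen positions; the model may use context on both sides."""
    inputs = [mask_token if i in masked else t for i, t in enumerate(tokens)]
    targets = {i: tokens[i] for i in sorted(masked)}
    return inputs, targets


if __name__ == "__main__":
    sentence = ["the", "cat", "sat", "down"]
    for context, target in causal_lm_examples(sentence):
        print(context, "->", target)
    print(causal_mask(len(sentence)))
    print(masked_lm_example(sentence, {1}))
```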

A brief refresher on GPT-2

GPT-2 is the second-generation language model in OpenAI's Generative Pre-trained Transformer (GPT) family [9]. It is a transformer model pre-trained on English text with a causal language modeling (CLM) objective. GPT-2 was a pioneer of the LLM revolution: trained in an unsupervised fashion on WebText, a dataset of millions of web pages, it showed that a language model can perform well on a variety of NLP tasks without being explicitly trained on those tasks with supervised learning. The GPT-2 family consists of four models of different sizes: 117M (GPT-2), 345M (GPT-2 medium), 762M (GPT-2 large) and 1,542M (GPT-2 XL) parameters. The capacity of the language model is essential to the success of zero-shot task transfer.

Training the Hugging Face GPT-2 XL model with DeepSpeed ZeRO2

We used Hugging Face Transformers with the Optimum Habana library [12] for our experiments. Optimum Habana makes it easy to train and run Hugging Face Transformers models on Habana Gaudi with maximum efficiency; more details are available in the Optimum Habana documentation.

The GPT-2 XL model released by OpenAI was trained on WebText, a 40 GB English corpus built by scraping all the web pages linked from Reddit posts that received at least 3 upvotes. All Wikipedia pages were removed from this dataset, so the model was not trained on any part of Wikipedia. Because WebText has not been publicly released, we trained GPT-2 XL on its open-source equivalent, the OpenWebText dataset [13], which we obtained from the Hugging Face datasets repository.
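As we understand the Hugging Face run_clm.py example, it tokenizes the corpus, concatenates all token sequences, and cuts them into fixed-length blocks, dropping the incomplete tail; each block becomes one training sample of "max sequence length" tokens. Below is a minimal pure-Python sketch of that grouping step. The function name mirrors the example script, but this version works on plain dicts of lists rather than Hugging Face datasets batches.

```python
"""Sketch of the concatenate-and-chunk step used for causal LM training data."""

from itertools import chain
from typing import Dict, List


def group_texts(examples: Dict[str, List[List[int]]], block_size: int) -> Dict[str, List[List[int]]]:
    """Concatenate tokenized documents and split them into block_size chunks."""
    concatenated = {key: list(chain.from_iterable(seqs)) for key, seqs in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the remainder that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        key: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }
    # For causal LM the labels equal the inputs; the model shifts them internally.
    result["labels"] = [block[:] for block in result["input_ids"]]
    return result


if __name__ == "__main__":
    batch = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
    print(group_texts(batch, block_size=4))
```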

The model configuration is the same as the one in the Hugging Face model repository, shown below.

{
  "_name_or_path": "gpt2-xl",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1600,
  "n_head": 25,
  "n_inner": null,
  "n_layer": 48,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.23.1",
  "use_cache": false,
  "vocab_size": 50257
}
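As a sanity check on this configuration, the parameter count of a GPT-2-style model can be derived from n_embd, n_layer, n_positions and vocab_size. The sketch assumes the layer layout of Hugging Face's GPT2LMHeadModel: a fused QKV projection, a 4x MLP when n_inner is null, biases and LayerNorms throughout, and an LM head tied to the token embedding. It counts trainable parameters only.

```python
"""Parameter count of a GPT-2-style decoder from its Hugging Face config fields."""

from typing import Optional


def gpt2_param_count(n_embd: int, n_layer: int, vocab_size: int,
                     n_positions: int, n_inner: Optional[int] = None) -> int:
    d = n_embd
    inner = n_inner if n_inner is not None else 4 * d  # "n_inner": null means 4 * n_embd
    # Token (wte) and position (wpe) embeddings.
    embeddings = vocab_size * d + n_positions * d
    # Attention: fused QKV projection plus output projection, each with a bias.
    attention = (d * 3 * d + 3 * d) + (d * d + d)
    # MLP: up-projection to `inner` and down-projection back to d, with biases.
    mlp = (d * inner + inner) + (inner * d + d)
    # Two LayerNorms per block (ln_1, ln_2), each with a weight and a bias.
    layer_norms = 2 * 2 * d
    per_block = attention + mlp + layer_norms
    # Final LayerNorm (ln_f); the tied LM head adds no extra parameters.
    final_norm = 2 * d
    return embeddings + n_layer * per_block + final_norm


if __name__ == "__main__":
    total = gpt2_param_count(n_embd=1600, n_layer=48, vocab_size=50257, n_positions=1024)
    print(f"GPT2-XL parameters: {total:,}")
```

For the GPT2-XL values this gives 1,557,611,200 parameters, the "1.5 billion" quoted for the model. The same function applies to any config in this family.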

We trained the model on the Voyager platform with different numbers of Gaudi devices on the OpenWebText dataset, using the language-modeling example from the Optimum Habana repository (https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling). The experiment settings are listed below:

  • SynapseAI Version: 1.7.0
  • Dataset: OpenWebText (~12 GB)
  • Global Batch Size (BS): 512
  • Micro BS per Device: 8
    • For 128 Cards, Micro BS is 4
  • Max Steps: 400
  • Warm Up Steps: 300

The command is shown below:

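# NUM_NODES must be set beforehand; each Voyager node has 8 Gaudi cards.
N_CARDS=$((NUM_NODES*8));
# Global batch size, kept fixed for every run.
TRAIN_BATCH_SIZE=512;
# Micro batch size per device: defaults to 8 (set it to 4 for 128 cards).
TRAIN_MICRO_BATCH_SIZE_PER_GPU="${TRAIN_MICRO_BATCH_SIZE_PER_GPU:-8}";
# Gradient accumulation steps = global batch / (cards * micro batch), unless overridden.
GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-$((TRAIN_BATCH_SIZE/N_CARDS/TRAIN_MICRO_BATCH_SIZE_PER_GPU))}";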

CMD="python optimum-habana/examples/language-modeling/run_clm.py \
        --config_name gpt2-xl \
        --tokenizer_name gpt2-xl \
        --dataset_name openwebtext \
        --dataset_config_name plain_text \
        --do_train \
        --save_steps 1000 \
        --per_device_train_batch_size ${TRAIN_MICRO_BATCH_SIZE_PER_GPU}  \
        --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
        --max_steps 400 \
        --gaudi_config_name Habana/gpt2 \
        --use_habana \
        --logging_steps 1 \
        --use_lazy_mode \
        --gradient_checkpointing \
        --throughput_warmup_steps 2 \
        --use_cache=False \
        --log_on_each_node=False \
        --deepspeed configs/deepspeed_zero_2.json"
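The shell arithmetic in this script uses integer division, so with 128 cards and the default micro batch of 8 it would yield 0 accumulation steps; this is why the micro batch is reduced to 4 for 128 cards (for example by exporting TRAIN_MICRO_BATCH_SIZE_PER_GPU=4). The hypothetical helper below is not part of Optimum Habana. It mirrors the computation and rejects settings that would silently change the global batch size.

```python
"""Mirror of the gradient-accumulation arithmetic in the launch script."""


def gradient_accumulation_steps(global_batch: int, num_cards: int, micro_batch: int) -> int:
    # Same result as $((TRAIN_BATCH_SIZE/N_CARDS/TRAIN_MICRO_BATCH_SIZE_PER_GPU)) for
    # positive integers, but refuse zero steps or a remainder.
    steps, remainder = divmod(global_batch, num_cards * micro_batch)
    if steps == 0 or remainder != 0:
        raise ValueError(
            f"global batch {global_batch} is not a positive multiple of "
            f"{num_cards} cards x micro batch {micro_batch}"
        )
    return steps


if __name__ == "__main__":
    # (cards, micro batch per device) for the GPT2-XL runs in Table 1.
    for cards, micro in [(8, 8), (16, 8), (32, 8), (64, 8), (128, 4)]:
        print(cards, micro, gradient_accumulation_steps(512, cards, micro))
```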

We used the recommended DeepSpeed configuration settings for our experiments. The batch-size fields in the configuration are set to "auto", so their values come from the training script arguments (--per_device_train_batch_size and --gradient_accumulation_steps), which can be set manually in the Python command. We report experimental results with Habana’s DeepSpeed ZeRO-2 support using the following settings.

ZeRO-2 configuration (configs/deepspeed_zero_2.json)

{
    "steps_per_print": 1,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {
        "enabled": true
    },
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": false,
        "reduce_scatter": false,
        "contiguous_gradients": false
    }
}
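With the Hugging Face DeepSpeed integration, the "auto" entries are filled in from the training arguments, and DeepSpeed requires train_batch_size to equal the micro batch size times the gradient accumulation steps times the number of devices. The helper below is a simplified, hypothetical stand-in for that resolution step, not the library's code, applied to the ZeRO-2 configuration above.

```python
"""Simplified resolution of "auto" batch-size fields in a DeepSpeed config."""

import copy
import json

ZERO2_CONFIG = json.loads("""
{
    "steps_per_print": 1,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": true},
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": false,
        "reduce_scatter": false,
        "contiguous_gradients": false
    }
}
""")


def resolve_auto_batch_fields(config: dict, micro_batch: int, grad_accum: int, world_size: int) -> dict:
    """Return a copy of config with the "auto" batch-size fields filled in."""
    resolved = copy.deepcopy(config)
    values = {
        "train_micro_batch_size_per_gpu": micro_batch,
        "gradient_accumulation_steps": grad_accum,
        "train_batch_size": micro_batch * grad_accum * world_size,
    }
    for key, value in values.items():
        if resolved.get(key) == "auto":
            resolved[key] = value
    # DeepSpeed expects: train_batch_size = micro batch * accumulation steps * world size.
    expected = (resolved["train_micro_batch_size_per_gpu"]
                * resolved["gradient_accumulation_steps"] * world_size)
    if resolved["train_batch_size"] != expected:
        raise ValueError(f"inconsistent batch settings: {resolved['train_batch_size']} != {expected}")
    return resolved


if __name__ == "__main__":
    print(resolve_auto_batch_fields(ZERO2_CONFIG, micro_batch=8, grad_accum=8, world_size=8))
```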

We used the following mpirun command to launch distributed training across multiple nodes.

mpirun -n ${N_CARDS} \
        --allow-run-as-root \
        --bind-to core \
        --map-by ppr:4:socket:PE=6 \
        --rank-by core --report-bindings \
        --tag-output \
        --merge-stderr-to-stdout \
        --prefix /opt/amazon/openmpi/ \
        -x PYTHONPATH="/usr/lib/habanalabs/:$PYTHONPATH" \
        ${CMD};
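The --map-by ppr:4:socket:PE=6 and --bind-to core options start four ranks per CPU socket and bind six cores to each rank. Assuming dual-socket nodes (an assumption, not stated above), that gives eight ranks per node, one per Gaudi card, so -n ${N_CARDS} covers every card. A small sketch of the arithmetic, with hypothetical parameter names:

```python
"""Rank and core arithmetic for mpirun --map-by ppr:4:socket:PE=6 --bind-to core."""


def mpirun_layout(num_nodes: int, sockets_per_node: int = 2,
                  ranks_per_socket: int = 4, cores_per_rank: int = 6) -> dict:
    # ppr:4:socket -> four ranks on each socket; sockets_per_node=2 is an assumption.
    ranks_per_node = sockets_per_node * ranks_per_socket
    return {
        "ranks_per_node": ranks_per_node,
        # Should match N_CARDS = NUM_NODES * 8 in the launch script.
        "total_ranks": num_nodes * ranks_per_node,
        # PE=6 -> six cores bound to each rank.
        "cores_bound_per_node": ranks_per_node * cores_per_rank,
    }


if __name__ == "__main__":
    for nodes in (1, 2, 4, 8, 16):
        print(nodes, mpirun_layout(nodes))
```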

Samples per second is computed with the performance calculation code in Optimum Habana [14].

Tokens per second is calculated by multiplying samples per second by the maximum sequence length (1,024 tokens for GPT2-XL). In the table below, ideal throughput assumes perfect (100%) linear scaling: the single-node (8-device) throughput multiplied by the number of nodes. Scaling efficiency is the ratio of the actual throughput (samples per second) to this ideal throughput.

# Devices | Samples per second | Tokens per second | Ideal throughput (samples/s, linear scaling) | Scaling efficiency | Gradient accumulation steps
8 | 19.17 | 19630 | 19.17 | 100% | 8
16 | 37.50 | 38404 | 38.34 | 98% | 4
32 | 72.63 | 74370 | 76.68 | 95% | 2
64 | 119.00 | 121856 | 153.36 | 78% | 1
128 | 233.42 | 239022 | 306.72 | 76% | 1
Table 1.  GPT2-XL pretraining throughput
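The derived columns of Table 1 can be recomputed from the measured samples per second. The sketch below takes the smallest device count (one node, 8 devices) as the baseline, exactly as described above.

```python
"""Recompute tokens/s, ideal throughput and scaling efficiency from samples/s."""

from typing import Dict, List, Tuple


def scaling_rows(samples_per_sec: Dict[int, float],
                 seq_len: int) -> List[Tuple[int, float, float, float, float]]:
    """Rows of (devices, samples/s, tokens/s, ideal samples/s, efficiency)."""
    base_devices = min(samples_per_sec)  # single node: 8 devices
    base = samples_per_sec[base_devices]
    rows = []
    for devices in sorted(samples_per_sec):
        actual = samples_per_sec[devices]
        ideal = base * devices / base_devices  # perfect linear scaling
        rows.append((devices, actual, actual * seq_len, ideal, actual / ideal))
    return rows


if __name__ == "__main__":
    gpt2_xl = {8: 19.17, 16: 37.50, 32: 72.63, 64: 119.00, 128: 233.42}  # Table 1
    for devices, actual, tokens, ideal, eff in scaling_rows(gpt2_xl, seq_len=1024):
        print(f"{devices:>4} {actual:8.2f} {tokens:10.0f} {ideal:8.2f} {eff:6.0%}")
```

Because the published samples-per-second values are rounded, recomputed tokens per second can differ from the table by a few tokens (38,400 versus 38,404 at 16 devices, for example).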

Training scales well as the number of devices increases. With DeepSpeed ZeRO-2, throughput grows from 19.17 samples per second on 8 devices to 37.50 on 16, 72.63 on 32, 119.00 on 64 and 233.42 on 128 devices. Scaling efficiency stays at 95% or higher up to 32 devices and falls to 78% and 76% at 64 and 128 devices, the two configurations in which gradient accumulation drops to a single step.

Figure 1. GPT2-XL pretraining throughput

Training GPT3-XL with DeepSpeed ZeRO2

We also trained another causal language model, GPT3-XL [10], which has 1.3B parameters. GPT-3 largely reuses the GPT-2 architecture (the GPT-3 paper additionally alternates dense and locally banded sparse attention layers), so we used the Hugging Face GPT-2 model and adjusted its parameters to match GPT3-XL: 24 layers, a hidden size of 2048 and a context length of 2048. We set the number of attention heads to 16, as configured in the EleutherAI GPT-Neo repository [11]. The model configuration used in our experiment is listed below.

{
  "_name_or_path": "gpt3-xl",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 2048,
  "n_embd": 2048,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 2048,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.23.1",
  "use_cache": false,
  "vocab_size": 50257
}
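Only six fields differ between the GPT2-XL and GPT3-XL configurations listed in this post, so the GPT3-XL file can be derived from the GPT2-XL one. In the following sketch, the helper name is hypothetical, and the caller is assumed to pass in the GPT2-XL config loaded from a local copy. It also checks that the hidden size divides evenly across the attention heads (2048 / 16 = 128 per head).

```python
"""Derive the GPT3-XL config file from the GPT2-XL Hugging Face config."""

import json
from pathlib import Path

# Fields that differ between the GPT2-XL and GPT3-XL configs used in this post.
GPT3_XL_OVERRIDES = {
    "_name_or_path": "gpt3-xl",
    "n_ctx": 2048,
    "n_positions": 2048,
    "n_embd": 2048,
    "n_layer": 24,
    "n_head": 16,  # head count from the EleutherAI GPT-Neo configuration
}


def write_gpt3_xl_config(gpt2_xl_config: dict, out_path: Path) -> dict:
    """Merge the overrides into a copy of the GPT2-XL config and write it as JSON."""
    config = {**gpt2_xl_config, **GPT3_XL_OVERRIDES}
    if config["n_embd"] % config["n_head"] != 0:
        raise ValueError("n_embd must be divisible by n_head")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(config, indent=2) + "\n")
    return config
```

Writing to configs/gpt3-xl/config.json produces the file that the --config_name argument of the command below points to.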

The example run command to train the GPT3-XL model is shown below:

N_CARDS=$((NUM_NODES*8));
TRAIN_BATCH_SIZE=512;
TRAIN_MICRO_BATCH_SIZE_PER_GPU="${TRAIN_MICRO_BATCH_SIZE_PER_GPU:-8}";
GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-$((TRAIN_BATCH_SIZE/N_CARDS/TRAIN_MICRO_BATCH_SIZE_PER_GPU))}";

CMD="python optimum-habana/examples/language-modeling/run_clm.py \
        --config_name configs/gpt3-xl/config.json \
        --tokenizer_name gpt2-xl \
        --model_type gpt2 \
        --lr_scheduler_type cosine \
        --adam_beta1 0.9 \
        --adam_beta2 0.95 \
        --adam_epsilon 1e-8 \
        --max_grad_norm 1.0 \
        --weight_decay 0.1 \
        --learning_rate 2.0e-4 \
        --warmup_steps 300 \
        --dataset_name openwebtext \
        --dataset_config_name plain_text \
        --do_train \
        --save_steps 1000 \
        --per_device_train_batch_size ${TRAIN_MICRO_BATCH_SIZE_PER_GPU}  \
        --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
        --max_steps 400 \
        --gaudi_config_name Habana/gpt2 \
        --use_habana \
        --logging_steps 1 \
        --use_lazy_mode \
        --gradient_checkpointing \
        --throughput_warmup_steps 2 \
        --use_cache=False \
        --log_on_each_node=False \
        --deepspeed configs/deepspeed_zero_2.json"
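With --lr_scheduler_type cosine, Hugging Face Transformers applies (to our understanding) a linear warmup followed by a half-cosine decay to zero, in the style of its get_cosine_schedule_with_warmup helper with the default half cycle. A minimal re-implementation for the values in this command (peak learning rate 2.0e-4, 300 warmup steps, 400 total steps):

```python
"""Linear-warmup + cosine-decay learning rate for the GPT3-XL command's settings."""

import math


def lr_at_step(step: int, peak_lr: float = 2.0e-4,
               warmup_steps: int = 300, total_steps: int = 400) -> float:
    if step < warmup_steps:
        # Linear ramp from 0 to peak_lr over the warmup steps.
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, from 0 to 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half a cosine period: peak_lr at the end of warmup, 0 at total_steps.
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    for step in (0, 150, 300, 350, 400):
        print(step, f"{lr_at_step(step):.2e}")
```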

As with GPT2-XL, we launched distributed training across multiple nodes with mpirun:

mpirun -n ${N_CARDS} \
        --allow-run-as-root \
        --bind-to core \
        --map-by ppr:4:socket:PE=6 \
        --rank-by core --report-bindings \
        --tag-output \
        --merge-stderr-to-stdout \
        --prefix /opt/amazon/openmpi/ \
        -x PYTHONPATH="/usr/lib/habanalabs/:$PYTHONPATH" \
        ${CMD};

The table below shows the results for training GPT3-XL on 8 to 64 Gaudi devices. Scaling efficiency stays at or above 95% all the way to 64 devices.

# Devices | Samples per second | Tokens per second | Ideal throughput (samples/s, linear scaling) | Scaling efficiency | Gradient accumulation steps
8 | 29.14 | 59670 | 29.14 | 100% | 8
16 | 57.00 | 116736 | 58.27 | 98% | 4
32 | 114.57 | 234639 | 116.54 | 98% | 2
64 | 222.06 | 454779 | 233.09 | 95% | 1

Table 2. GPT3-XL pretraining throughput


Figure 2.  GPT3-XL pretraining throughput

Conclusion

Support for memory-efficient optimizations with DeepSpeed enables effective training of large models on the Habana Gaudi platform. In this blog, we demonstrated how causal language models can be trained effectively on multiple Gaudi devices on the Voyager platform using the SynapseAI 1.7 software release [15] and the Hugging Face Optimum Habana library with DeepSpeed support. Check out the Optimum Habana DeepSpeed guide [16] and Habana’s DeepSpeed Usage Guide [17] for more information.

On the Habana GitHub, we have also published examples of training 1.5B- and 5B-parameter BERT models on Gaudi with DeepSpeed ZeRO-1 and ZeRO-2, which you can refer to for more details. Wishing you happy and speedy language model training with Habana Gaudi, SynapseAI, DeepSpeed and the Hugging Face Optimum Habana library!

 

Originally published at https://developer.habana.ai on January 9, 2023.

References

  1. https://www.deepspeed.ai/
  2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase and Yuxiong He.
  3. Memory-Efficient Training on Habana® Gaudi® with DeepSpeed
  4. Voyager AI Supercomputer Gives Investigators New Deep Learning Experimental Platform
  5. Voyager Supercomputer Enters Testbed Phase
  6. AI-based language models powering drug discovery and development
  7. https://www.nature.com/nature-index/news-blog/how-an-ai-trained-to-read-scientific-papers-could-predict-future-discoveries
  8. Improving Language Understanding by Generative Pre-Training. Alec Radford et al.
  9. Language Models are Unsupervised Multitask Learners. Alec Radford et al.
  10. Language Models are Few-Shot Learners. Tom Brown et al.
  11. https://github.com/EleutherAI/gpt-neo#5-choose-a-model-configuration
  12. https://huggingface.co/docs/optimum/main/habana/index
  13. https://skylion007.github.io/OpenWebTextCorpus/
  14. Habana Optimum Performance Calculation Code
  15. Habana Gaudi 1.7 Release Notes
  16. https://huggingface.co/docs/optimum/main/habana/usage_guides/deepspeed
  17. DeepSpeed Usage Guide

 
