
Running Llama3.3-70B on Intel® Gaudi® 2 with vLLM: A Step-by-Step Inference Guide


Authors: Jaideep Kamisetti (Senior AI Engineer, Intel® Liftoff), Rahul Unnikrishnan Nair (Engineering Lead, Intel® Liftoff)

This guide demonstrates how to efficiently run inference with Llama3.3 70B on Intel® Gaudi® 2 AI accelerators using vLLM. We’ll cover the setup process, configuration options, and performance optimization techniques for deploying this large language model in production environments.

 

Gaudi2 Hardware Overview

Intel® Gaudi® 2 accelerators offer specific architectural advantages for LLM inference workloads:

  • 24 x 100 Gbps RDMA NICs for efficient networking
  • 96 GB of HBM2E memory delivering 2.45 TB/s bandwidth
  • 48 MB of on-chip SRAM
  • Specialized Matrix Multiplication Engines (MMEs) and Tensor Processing Cores (TPCs)
  • Support for FP32, TF32, BF16, FP16, FP8, INT8, and INT4 data types

These specifications make Gaudi2 particularly well-suited for large model inference where memory bandwidth and efficient matrix operations are critical performance factors.

 

vLLM Inference Engine

Running inference on a 70B parameter model requires specialized software architecture to manage memory efficiently and parallelize computation effectively. The vLLM inference engine provides several key optimizations:

 

PagedAttention

vLLM’s PagedAttention mechanism uses a block-based KV cache to efficiently manage attention keys and values without requiring contiguous memory allocation. This approach significantly improves memory utilization and enables serving more concurrent requests.
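To make the mechanism concrete, here is a minimal, illustrative Python sketch of a block-based KV cache. This is not vLLM's internal code, just a toy block table that maps a sequence's logical token positions to non-contiguous physical blocks:

# Toy illustration of a paged KV cache (not vLLM's actual implementation).
# Physical KV memory is split into fixed-size blocks; each sequence keeps a
# block table mapping logical block indices to physical block indices.

BLOCK_SIZE = 128  # tokens per KV cache block

class PagedKVCache:
    def __init__(self, num_physical_blocks: int):
        self.free_blocks = list(range(num_physical_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids

    def slot_for_token(self, seq_id: int, position: int) -> tuple[int, int]:
        """Return (physical_block, offset) where this token's K/V are stored."""
        table = self.block_tables.setdefault(seq_id, [])
        logical_block, offset = divmod(position, BLOCK_SIZE)
        if logical_block == len(table):           # sequence grew past its blocks
            table.append(self.free_blocks.pop())  # grab any free block:
        return table[logical_block], offset       # no contiguity is required

cache = PagedKVCache(num_physical_blocks=1024)
print(cache.slot_for_token(seq_id=0, position=0))    # first block for sequence 0
print(cache.slot_for_token(seq_id=0, position=130))  # spills into a second block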

 

Step 1: Obtaining the Source Code

First, clone the vLLM repository, which contains the necessary code for running inference on Gaudi2 hardware:

git clone https://github.com/vllm-project/vllm.git
cd vllm

This repository contains the specialized code that enables vLLM to interface with Habana’s Gaudi2 accelerators through the SynapseAI SDK.

 

Step 2: Building the Optimized Container

The Dockerfile.hpu contains specialized configurations that enable vLLM to interface with Habana’s Synapse AI software stack:

docker build . -t vllm_hpu -f ./docker/Dockerfile.hpu

This build process creates a container with:

  1. PyTorch with Habana device support
  2. Habana’s Synapse AI libraries
  3. vLLM with HPU-specific optimizations
  4. Required dependencies for efficient inference

The resulting container is specifically engineered to translate vLLM’s tensor operations into efficient execution paths on the Gaudi2 architecture.

 

Step 3: Executing the Inference Server

Note: Before using Llama models, ensure you have read and accepted Meta’s license agreement on the Hugging Face model page. Access to these models requires agreement to the terms of use.

Before running the container, we need to set up environment variables for the model and authentication:

export MODEL=meta-llama/Llama-3.3-70B-Instruct
export HF_TOKEN=<hf_token>

Now we can launch the vLLM server with Gaudi2-specific configurations:

docker run -it \
  --runtime=habana \
  -e VLLM_SKIP_WARMUP=true \
  -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN \
  -e HABANA_VISIBLE_DEVICES=all \
  --cap-add=sys_nice \
  --net=host \
  --rm \
  vllm_hpu \
  --model=$MODEL \
  --tensor-parallel-size=<number_of_devices>

 

Parameter Breakdown

Let’s examine the critical parameters that affect performance:

| Parameter | Purpose | Optimization Notes |
| --- | --- | --- |
| --runtime=habana | Activates the Habana device driver | Required for HPU access |
| VLLM_SKIP_WARMUP=true | Bypasses the initial warmup phase | Reduces startup time but may affect initial inference latency |
| HABANA_VISIBLE_DEVICES=all | Makes all HPUs available | For multi-card setups |
| --tensor-parallel-size | Controls model parallelism | Should match available HPU count |

 

HPU Graph Compilation Process

A key optimization in vLLM for Gaudi2 is the HPU graph capture process, which involves ahead-of-time compilation of execution graphs to minimize runtime overhead:

  1. During initialization, vLLM captures computational graphs for both the “prefill” phase (processing all input tokens) and the “decode” phase (generating each output token).
  2. These graphs are compiled into optimized execution plans that can be reused across multiple inference requests.
  3. The environment variable VLLM_CAPTURE_GRAPH=true enables this optimization.

This approach significantly improves inference performance by reducing computational overhead during model execution.

 

Step 4: Interacting with the Inference Engine

Once the server is running, you’ll see a startup sequence as the model is loaded and the HPU graphs are compiled.

[Screenshot: Llama3 on Gaudi2, vLLM server startup]

Wait for the server initialization to complete, indicated by this message:

INFO:     Application startup complete.

At this point, the vLLM server is listening for API requests that conform to the OpenAI-compatible protocol, allowing you to interact with the model through a standardized REST API.
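Because the protocol is OpenAI-compatible, you can also call the server from Python. The sketch below assumes the openai Python package (pip install openai) and the server from Step 3 listening on localhost:8000; adjust the model name if you launched a different one:

# Query the vLLM server through its OpenAI-compatible endpoint.
# Assumes `pip install openai` and the server from Step 3 running on port 8000.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",  # vLLM does not check the key unless configured to
)

completion = client.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    prompt="San Francisco is a",
    max_tokens=7,
    temperature=0,
)
print(completion.choices[0].text)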

 

Testing with a Simple Prompt

Let’s send a basic completion request to verify the system is working properly:

curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "meta-llama/Llama-3.3-70B-Instruct",
      "prompt": "San Francisco is a",
      "max_tokens": 7,
      "temperature": 0
  }'

[Screenshot: Llama3 on Gaudi2, inference response]

 

API Request Parameters

Key Parameters:

  • prompt: Input text to complete
  • max_tokens: Output length limit
  • temperature: Randomness control

Processing Pipeline:

  • Tokenization
  • Prefill phase (process all prompt tokens)
  • Decode phase (generate output tokens)
  • Response formatting

 

Inference Pipeline Breakdown

When processing a request, the system goes through several distinct phases, each with different resource requirements:

| Phase | Description | Primary Bottleneck |
| --- | --- | --- |
| Tokenization | Converting text to token IDs | CPU |
| Prefill | Processing all prompt tokens | HPU compute |
| Decode | Generating each output token | HPU compute |
| Response | Formatting and returning JSON | CPU |

 

The prefill phase involves a full forward pass through the model over all input tokens and is typically the most compute-intensive part of serving a request. For longer generation tasks, the decode phase becomes the dominant factor in overall latency.

 

Experimenting with Different Parameters

For more complex interactions, you can adjust various parameters to control the model’s behavior:

curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
      "model": "meta-llama/Llama-3.3-70B-Instruct",
      "prompt": "Write a recursive algorithm to calculate Fibonacci numbers",
      "max_tokens": 250,
      "temperature": 0.7,
      "top_p": 0.95,
      "frequency_penalty": 0.5
  }'

These parameters provide fine-grained control over the generation process, allowing you to balance between deterministic outputs and creative responses.

 

Advanced Considerations and Optimizations

Model Compatibility and Hardware Utilization

Not all models are created equal when it comes to specialized hardware acceleration. The architectural decisions made during model development can significantly impact performance on different accelerators, so it’s important to verify compatibility before deployment.

Compatible Model Families

  • Llama
  • Mistral
  • Falcon
  • MPT
  • BLOOM
  • Gemma

Before deploying a model, verify its compatibility with Gaudi2 by checking the Habana Model Zoo. This repository maintains an up-to-date list of validated models and any required patches or adaptations.

 

Hardware Configuration Optimization

The --tensor-parallel-size parameter is critical for performance optimization and must be adjusted based on your specific hardware configuration. This parameter determines how the model’s weights are distributed across multiple HPUs:

--tensor-parallel-size=<number_of_devices>

For example, with a full 8-HPU Gaudi2 server, you would set:

--tensor-parallel-size=8

This distributes the 70B parameters across all available HPUs, with each device handling approximately 8.75B parameters. The communication fabric between HPUs becomes a critical performance factor at this scale, as efficient inter-device communication is essential for maintaining low latency.
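As a rough back-of-the-envelope check (assuming BF16 weights at 2 bytes per parameter and ignoring KV cache, activations, and framework overhead), the arithmetic below shows why a single 96 GB card is not enough and how the weight shards land per device:

# Rough memory estimate for a 70B-parameter model in BF16 (2 bytes/parameter).
# Ignores KV cache, activations, and framework overhead, so real usage is higher.
params = 70e9
bytes_per_param = 2          # BF16
hbm_per_card_gb = 96         # Gaudi2 HBM2E capacity

weights_gb = params * bytes_per_param / 1e9
print(f"Total weights: ~{weights_gb:.0f} GB (vs {hbm_per_card_gb} GB per card)")

for tp in (2, 4, 8):
    print(f"tensor-parallel-size={tp}: ~{weights_gb / tp:.1f} GB of weights per HPU")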

 

Memory Management Strategies

Large language models require careful memory management, particularly for the attention key-value cache which grows linearly with sequence length. vLLM implements several strategies to optimize memory usage on Gaudi2:

  1. Block-based KV Cache: Instead of contiguous memory allocation, vLLM uses a block structure for efficient memory utilization
  2. Attention Bucketing: Similar sequence lengths are processed together to minimize padding waste
  3. Graph Memory Reservation: HPU graph compilation requires dedicated memory, controlled by environment variables:
# Reserve 15% of usable memory for graph compilation (default: 0.1 or 10%)
export VLLM_GRAPH_RESERVED_MEM=0.15

# Allocate 40% of graph memory to prefill operations (default: 0.3 or 30%)
export VLLM_GRAPH_PROMPT_RATIO=0.4

These parameters allow fine-tuning the memory allocation between model weights, KV cache, and graph compilation to optimize for your specific workload requirements.
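To see how these two knobs interact, here is a small illustrative calculation. It follows the documented meaning of the variables rather than reproducing vLLM's exact internal accounting, and the usable-memory figure is only a placeholder:

# Illustrative split of usable device memory with the settings above.
# Follows the documented intent of the variables; vLLM's internal accounting
# also factors in gpu_memory_utilization and runtime overhead.
usable_memory_gb = 96 * 0.9          # placeholder: assume ~90% of HBM is usable

graph_reserved_mem = 0.15            # VLLM_GRAPH_RESERVED_MEM
graph_prompt_ratio = 0.4             # VLLM_GRAPH_PROMPT_RATIO

graph_mem = usable_memory_gb * graph_reserved_mem
remaining_mem = usable_memory_gb - graph_mem          # weights shard + KV cache
prompt_graph_mem = graph_mem * graph_prompt_ratio
decode_graph_mem = graph_mem - prompt_graph_mem

print(f"HPU graphs: ~{graph_mem:.1f} GB "
      f"({prompt_graph_mem:.1f} GB prefill, {decode_graph_mem:.1f} GB decode)")
print(f"Weights shard + KV cache: ~{remaining_mem:.1f} GB")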

 

Offline Serving Options

While this guide focuses on online serving via a REST API, vLLM also supports offline batch inference for scenarios where latency is less critical than throughput. This approach is ideal for processing large volumes of text without the overhead of API calls.

The offline mode can be accessed through the Python API:

from vllm import LLM, SamplingParams

# Initialize the model on Gaudi2
llm = LLM(model="meta-llama/Llama-3.3-70B-Instruct",
          tensor_parallel_size=8,
          device="hpu")

# Define generation parameters
sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=512)

# Process a batch of prompts
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a recursive algorithm for binary search.",
    "Describe the architecture of modern CPUs."
]

outputs = llm.generate(prompts, sampling_params)

# Process results
for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Generated text: {output.outputs[0].text}")
    print("---")

For more details on offline serving options, refer to the official documentation.

 

Performance Tuning and Benchmarking

Key Performance Metrics

When optimizing LLM inference, three primary metrics determine performance:

  • Throughput: Measured in tokens per second, this represents the total output generation capacity of the system. Higher is better for batch processing and data generation tasks.
  • Time to First Token (TTFT): The latency between submitting a request and receiving the first generated token. Lower is better for interactive applications where initial responsiveness matters.
  • Inter-Token Latency: The time between consecutive token generations. Consistent, low inter-token latency creates a smooth, natural experience in interactive applications.

Different use cases prioritize these metrics differently:

  • Interactive chat applications: Prioritize TTFT and consistent inter-token latency for responsive user experience
  • Batch document processing: Prioritize overall throughput for maximum efficiency
  • Real-time assistants: Require balanced optimization across all metrics
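If you want to measure TTFT and inter-token latency against your running server, a small streaming client is enough. The sketch below assumes the requests package and the OpenAI-compatible server from Step 3 on localhost:8000; it treats each streamed chunk as one token event, which is a reasonable approximation for this purpose:

# Rough TTFT and inter-token latency measurement against the vLLM server.
# Assumes `pip install requests` and the OpenAI-compatible server on port 8000.
import time
import requests

payload = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "prompt": "Explain the difference between prefill and decode in one paragraph.",
    "max_tokens": 128,
    "temperature": 0,
    "stream": True,
}

start = time.perf_counter()
timestamps = []
with requests.post("http://localhost:8000/v1/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        # Each streamed SSE chunk carries roughly one generated token.
        if line and line.startswith(b"data: ") and line != b"data: [DONE]":
            timestamps.append(time.perf_counter())

if timestamps:
    ttft = timestamps[0] - start
    gaps = [b - a for a, b in zip(timestamps, timestamps[1:])]
    print(f"TTFT: {ttft * 1000:.0f} ms")
    if gaps:
        mean_gap = sum(gaps) / len(gaps)
        print(f"Mean inter-token latency: {mean_gap * 1000:.1f} ms")
        print(f"Decode throughput: {1 / mean_gap:.1f} tokens/s")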

 

Tuning Parameters

To optimize for your specific use case, experiment with these parameters:

  1. Batch Size: Larger batches improve throughput at the cost of latency
  2. Sequence Length Bucketing: Configure bucket parameters to optimize for your workload patterns. Gaudi2 accelerators work best with fixed tensor shapes, and bucketing minimizes graph recompilations by grouping similar-sized requests together. Each bucket corresponds to a separate optimized binary for specific tensor shapes.
  3. Quantization: Experiment with FP8 precision for improved throughput

 

Testing Matrix

Create a comprehensive benchmark using these variables:

  • Batch sizes: 1, 4, 16, 32
  • Prompt lengths: 32, 128, 512, 1k, 2k, 4k, 8k, 16k, 32k tokens
  • Output lengths: 32, 128, 512, 1k, 2k, 4k, 8k tokens
  • Different Bucket configurations (Refer below on how to tune bucketing configs)

Plot the results to identify performance patterns and optimal configurations for your deployment scenario.
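As a starting point for such a sweep, the following sketch uses the offline vLLM API to iterate over batch sizes and prompt/output lengths and report tokens per second. The prompt construction is a rough approximation (repeated filler words), which is sufficient for exercising tensor shapes and bucketing behavior:

# Sketch of a throughput sweep over batch size and prompt/output lengths.
# Uses the offline vLLM API; prompt lengths are approximated with filler text.
import time
from itertools import product
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.3-70B-Instruct",
          tensor_parallel_size=8,
          device="hpu")

batch_sizes = [1, 4, 16, 32]
prompt_lens = [32, 128, 512]      # extend toward 32k to cover long contexts
output_lens = [32, 128, 512]

for bs, p_len, o_len in product(batch_sizes, prompt_lens, output_lens):
    prompts = [("hello " * p_len).strip()] * bs        # roughly p_len tokens each
    params = SamplingParams(temperature=0, max_tokens=o_len, ignore_eos=True)
    start = time.perf_counter()
    outputs = llm.generate(prompts, params)
    elapsed = time.perf_counter() - start
    generated = sum(len(o.outputs[0].token_ids) for o in outputs)
    print(f"bs={bs:3d} prompt={p_len:5d} out={o_len:5d} "
          f"-> {generated / elapsed:8.1f} tokens/s")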

 

Understanding Bucketing on Gaudi2

Bucketing is a critical optimization technique for Gaudi accelerators. Here’s how it works and how to configure it:

 

How Bucketing Works

Intel® Gaudi® accelerators perform best with fixed tensor shapes. The Intel® Gaudi® Graph Compiler translates neural network operations into an optimized computation graph for particular tensor dimensions. Without bucketing, each new tensor shape would trigger a graph recompilation, causing significant runtime overhead.

Bucketing solves this by:

  1. Pre-defining a set of tensor shapes (“buckets”) for both batch size and sequence length dimensions
  2. Padding incoming requests to the nearest bucket size
  3. Using pre-compiled graphs for each bucket

This minimizes runtime compilation and improves overall performance, especially for production serving.
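The core mechanic is simply rounding a request's dimensions up to the nearest pre-defined bucket. Below is a simplified sketch of the min/step/max scheme described under Configuring Buckets; real warmup also supports exponential spacing, and the exact ramp-up behavior may differ between vLLM versions:

# Simplified sketch of how HPU bucketing rounds a request up to a bucket.
# Follows the min/step/max description below: buckets ramp up (doubling) from
# `min` until `step`, then grow linearly by `step` up to `max`.
def build_buckets(bmin: int, step: int, bmax: int) -> list[int]:
    buckets, value = [], bmin
    while value < step:                # ramp-up phase
        buckets.append(value)
        value *= 2
    value = step
    while value <= bmax:               # linear phase
        buckets.append(value)
        value += step
    return buckets

def pad_to_bucket(size: int, buckets: list[int]) -> int:
    return next(b for b in buckets if b >= size)   # raises if size exceeds max

seq_buckets = build_buckets(bmin=128, step=2048, bmax=32768)
print(pad_to_bucket(1500, seq_buckets))   # -> 2048: the request is padded to 2048
print(pad_to_bucket(4100, seq_buckets))   # -> 6144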

 

Optimizing for Long-Context Workloads

For optimal performance with long-context models, bucket configuration should be tailored to your specific workload requirements. The default conservative settings are designed for general use cases, but long-context applications require careful tuning to unlock maximum performance potential.

 

Configuring Buckets

Buckets are defined by three parameters:

  • min: The smallest bucket size
  • step: The interval between buckets after the ramp-up phase
  • max: The largest bucket size

These can be configured separately for prompt (prefill) and decode phases using environment variables:

# Prompt phase batch size buckets
export VLLM_PROMPT_BS_BUCKET_MIN=1                   # Default: 1
export VLLM_PROMPT_BS_BUCKET_STEP=32                 # Default: min(max_num_seqs, 32)
export VLLM_PROMPT_BS_BUCKET_MAX=256                  # Default: min(max_num_seqs, 256)

# Prompt phase sequence length buckets
export VLLM_PROMPT_SEQ_BUCKET_MIN=128                # Default: block_size
export VLLM_PROMPT_SEQ_BUCKET_STEP=2048               # Default: block_size
export VLLM_PROMPT_SEQ_BUCKET_MAX=32768               # Match: max_model_len

# Decode phase batch size buckets
export VLLM_DECODE_BS_BUCKET_MIN=1                   # Default: 1
export VLLM_DECODE_BS_BUCKET_STEP=32                 # Default: min(max_num_seqs, 32)
export VLLM_DECODE_BS_BUCKET_MAX=256                  # Default: max_num_seqs

# Decode phase block size buckets
export VLLM_DECODE_BLOCK_BUCKET_MIN=128              # Default: block_size
export VLLM_DECODE_BLOCK_BUCKET_STEP=1024             # Default: block_size
export VLLM_DECODE_BLOCK_BUCKET_MAX=33792

 You can control the bucketing strategy further with these environment variables:

# Enable exponential bucket spacing instead of linear (default: true)
export VLLM_EXPONENTIAL_BUCKETING=true

# Strategy for prompt graph capture: min_tokens or max_bs (default: min_tokens)
export VLLM_GRAPH_PROMPT_STRATEGY=min_tokens

# Strategy for decode graph capture: min_tokens or max_bs (default: max_bs)
export VLLM_GRAPH_DECODE_STRATEGY=max_bs

 

Key Differences: Prompt vs Decode Bucketing

  • Prompt Phase uses sequence length bucketing because prompts are processed in parallel, making the total token count the critical dimension.

  • Decode Phase uses block-based bucketing because it processes multiple sequences simultaneously, and memory is managed in fixed-size blocks (typically 128 tokens per block). The decode bucket represents the total number of KV cache blocks needed across all sequences in the batch.
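For example (a small illustrative calculation assuming the default 128-token block size), the decode block bucket is driven by the total number of blocks across all in-flight sequences:

# How many KV cache blocks a decode batch needs with 128-token blocks.
import math

BLOCK_SIZE = 128
sequence_lengths = [900, 2100, 350, 4096]    # current lengths of in-flight sequences

blocks_per_seq = [math.ceil(n / BLOCK_SIZE) for n in sequence_lengths]
print(blocks_per_seq)                        # [8, 17, 3, 32]
print("total blocks:", sum(blocks_per_seq))  # 60, then padded up to the nearest
                                             # VLLM_DECODE_BLOCK bucket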

 

Optimizing Bucket Configuration

When configuring buckets, consider these trade-offs:

  1. More buckets = better padding efficiency but longer warmup time and more memory for compiled graphs
  2. Fewer buckets = faster warmup but potentially more padding overhead
  3. Bucket distribution should match your expected workload distribution

 For long-context applications: Maximize performance by setting VLLM_PROMPT_SEQ_BUCKET_MAX to match the model's max_model_len and use VLLM_EXPONENTIAL_BUCKETING=true to efficiently manage memory and warmup time.

For interactive applications with varied input lengths, consider a wider range of sequence length buckets. For batch processing with consistent sizes, tighter bucket ranges can reduce padding overhead.

 

Warmup and Bucketing

Warmup is a critical process that pre-compiles the computation graphs for all defined buckets before the server starts accepting requests:

# Warmup is enabled by default (VLLM_SKIP_WARMUP defaults to false)
export VLLM_SKIP_WARMUP=false

If warmup is enabled (the default):

  • All bucket configurations are pre-compiled during startup
  • No compilation overhead occurs during inference (within bucket boundaries)
  • Server startup takes longer but runtime performance is more consistent

If warmup is disabled:

  • Server starts faster
  • The first request for each bucket size will trigger compilation
  • This causes significant latency spikes during production use

Important: While disabling warmup may be acceptable during development, it is strongly recommended to enable warmup in production deployments for consistent performance.

 

Intel® Liftoff for Startups

Startups building AI solutions globally can benefit from the Intel® Liftoff program through:

  • Compute Access: Project-based credits for ITAC and early access to Intel hardware and software
  • Engineering and GTM Support: Technical guidance from Intel engineers and optional co-marketing opportunities
  • Zero Equity Model: Intel Liftoff is a no-equity program focused on technical enablement

Apply or learn more at developer.intel.com/liftoff.

 

Related resources

Intel® Gaudi® 2 AI accelerator - High-performance AI training processor designed for deep learning workloads

Benchmark Intel® Gaudi® 2 AI Accelerator for Large Language Models

About the Author
I'm a proud team member of the Intel® Liftoff for Startups, an innovative, free virtual program dedicated to accelerating the growth of early-stage AI startups.