
In-production AI Optimization Guide for Xeon: Search and Recommendation Use Case

bibekbhattarai

Introduction

Search and recommendation models power everything from e-commerce suggestions to content discovery, shaping user experiences across the digital world. These models must deliver fast, accurate results while handling vast amounts of real-time data. While GPUs often dominate AI discussions, CPUs remain a powerful and cost-effective choice for running these workloads. Modern Intel Xeon CPUs, with their high memory bandwidth, advanced vector/matrix processing, and scalable core counts, can efficiently handle search indexing, sparse computations, and recommendation inference—often with lower power consumption and greater deployment flexibility. In this blog, we’ll explore the optimization journey of a real-world search and recommendation model, unlocking the full potential of CPUs for AI-driven personalization.

  1. Model Architecture Overview
  2. Deployment Setup
  3. HW/SW Optimization
  4. CPU Resource Optimization
  5. Results and Summary

Model Architecture Overview

The component of the recommendation pipeline we focus on is the L1-ranker. It uses a simple two-tower model architecture, in which the model learns separate embeddings for users and items so that recommendations can be generated by a simple vector similarity search, enabling millisecond-level retrieval from billions of items. Inside each tower, the model first encodes the numerous user and product features into vectors using different encoding techniques and concatenates all the encodings into a single feature vector. The concatenated feature vector then goes through a small Multi-Layer Perceptron (MLP). The encoding techniques for different combinations of input features range from simple one-hot encoding to complex transformer-based encoders. Some key observations that will be very useful in our optimization process (a minimal sketch of this structure follows the list below):

  1. The amount of compute/memory/time required for different encoding techniques varies widely. 
  2. Most of these encoding routines are independent of one another and can be dispatched in parallel.
  3. The modules that dominate the runtime include the Linear operator, which is handled as a General Matrix Multiplication (GEMM) in compute backends like oneDNN/oneMKL.
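As a point of reference, here is a minimal, illustrative sketch of one such tower in PyTorch. The encoder modules, layer sizes, and the dot-product scoring below are assumptions for illustration, not the production model.

import torch
import torch.nn as nn

class Tower(nn.Module):
    """One tower: heterogeneous feature encoders -> concat -> small MLP -> embedding."""
    def __init__(self, encoders, concat_dim, embed_dim):
        super().__init__()
        # Encoders range from trivial (one-hot + embedding) to transformer-based modules
        self.encoders = nn.ModuleList(encoders)
        # Linear layers here are executed as GEMMs by the oneDNN backend on CPU
        self.mlp = nn.Sequential(
            nn.Linear(concat_dim, 256), nn.ReLU(),
            nn.Linear(256, embed_dim),
        )

    def forward(self, feature_groups):
        # Each encoder consumes its own feature group; the outputs are concatenated
        encoded = torch.cat([enc(f) for enc, f in zip(self.encoders, feature_groups)], dim=1)
        return self.mlp(encoded)

def score(user_tower, item_tower, user_feats, item_feats):
    # Recommendations reduce to a vector similarity between the two tower outputs
    return (user_tower(user_feats) * item_tower(item_feats)).sum(dim=1)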




Figure 1: A typical inference profile of the model. Multiple parallel threads are executing the encoding routines at the beginning of the timeline. Once all the encoding threads finish and their embeddings are concatenated, the top-MLP head of the recommender is executed to learn the high-order feature interactions.

Model Deployment Overview

Triton Inference Server[1] is an open-source platform designed to streamline the deployment and serving of machine learning models in production. It supports multiple frameworks, including PyTorch, TensorFlow, and ONNX, and allows models to be served efficiently on both CPUs and GPUs. Triton optimizes inference through dynamic batching, concurrent model execution, and automatic model loading, ensuring high throughput and low latency on CPU platforms.

Figure 2 shows the deployment setup for this model. The model repository stores the recommendation model rec_model, which consists of the exported TorchScript file of the model, i.e., `model.pt`, and the model config `config.pbtxt`. When the Triton server is launched, you can decide how many model copies, i.e., N, to deploy on each server node. With a total of `n` servers, `N x n` model copies run simultaneously.
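For reference, a typical Triton model repository layout for this setup looks like the following; the numeric version directory `1/` is required by Triton, and the names mirror the example above.

model_repo/
└── rec_model/
    ├── config.pbtxt
    └── 1/
        └── model.pt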

Here are the customer's Key Performance Indicator (KPI) objectives:

  1. Keep client-side query latency within a desired range, e.g., client-side average query latency < 100ms
  2. Minimize the number of servers required to serve the given volume of traffic.  
  3. Model accuracy must remain intact through any model/resource optimization. 

 


Figure 2: The deployment setup for the recommendation model on the Triton inference server.

Model Optimization

Modern Intel Xeon® processors are equipped with built-in accelerators to speed up deep learning workloads. They include several ISA (Instruction Set Architecture) extensions that enhance inference performance, particularly by accelerating matrix and vector operations. Key among these extensions are:

  • AVX-512 (Advanced Vector Extensions 512): Enables high-throughput vectorized computations, improving the efficiency of matrix multiplications and tensor operations in deep learning workloads.
  • VNNI (Vector Neural Network Instructions): A subset of AVX-512 introduced to accelerate INT8 computations, reducing the number of instructions needed for convolution and GEMM (General Matrix Multiplication) operations.
  • AMX (Advanced Matrix Extensions): Introduced in 4th-generation Xeon, AMX provides dedicated tile-based matrix processing units that significantly boost the efficiency of deep learning workloads, particularly INT8 and BF16 operations. Starting with 6th-generation Xeon (codenamed Granite Rapids), AMX adds built-in support for FP16 operations.

Table 1: Breakdown of Intel Xeon CPU Generations and relevant ISA extensions for deep-learning workloads

Xeon Generation | AVX-512 | VNNI | AMX_BF16 | AMX_INT8 | AMX_FP16
Skylake (1st Gen Xeon Scalable) | Yes | No | No | No | No
Cascade Lake (2nd Gen Xeon Scalable) | Yes | Yes | No | No | No
Cooper Lake (3rd Gen Xeon Scalable - subset) | Yes | Yes | No | No | No
Ice Lake (3rd Gen Xeon Scalable) | Yes | Yes | No | No | No
Sapphire Rapids (4th Gen Xeon Scalable) | Yes | Yes | Yes | Yes | No
Emerald Rapids (5th Gen Xeon Scalable) | Yes | Yes | Yes | Yes | No
Granite Rapids (6th Gen Xeon Scalable) | Yes | Yes | Yes | Yes | Yes

 

One way to check the supported ISAs on your CPU server is to run `lscpu` and look for the corresponding flags. For example, running it on a 4th-generation Xeon server (codenamed Sapphire Rapids) shows the details below. Some of the relevant entries in the flags section include avx512_vnni, amx_tile, amx_int8, amx_bf16, etc.

Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          46 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   32
  On-line CPU(s) list:    0-31
Vendor ID:                GenuineIntel
  Model name:             Intel(R) Xeon(R) Platinum 8488C
    CPU family:           6
    Model:                143
    Thread(s) per core:   2
    Core(s) per socket:   16
    Socket(s):            1
    Stepping:             8
    BogoMIPS:             4800.00
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd ida arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq **avx512_vnni** avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear serialize **amx_bf16** avx512_fp16 **amx_tile** **amx_int8** flush_l1d arch_capabilities
Virtualization features:
  Hypervisor vendor:      KVM
  Virtualization type:    full
Caches (sum of all):
  L1d:                    768 KiB (16 instances)
  L1i:                    512 KiB (16 instances)
  L2:                     32 MiB (16 instances)
  L3:                     105 MiB (1 instance)
NUMA:
  NUMA node(s):           1
  NUMA node0 CPU(s):      0-31

 

These enhancements make CPUs increasingly competitive for high-performance inference, particularly for recommendation models that rely on efficient embedding lookups and matrix computations. Advanced Matrix Extensions (AMX) in particular represent the biggest leap forward for deep learning acceleration on CPUs: they introduce 2D tiled registers of 1KB each and built-in operators that compute in INT8, BF16, and FP16 (starting with 6th Gen Xeon Scalable), offering significant improvements in lower-precision matrix computations. Intel provides optimized kernel libraries, e.g., oneDNN and oneMKL, for efficient matrix and vector operations leveraging these ISA extensions. Deep learning frameworks like PyTorch and TensorFlow use the oneDNN backend by default when running on CPUs.
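One way to confirm that these ISA extensions are actually being dispatched at runtime is to enable oneDNN's verbose logging before running inference; the per-primitive log lines report the ISA used (for example, avx512_core_amx for AMX kernels). A minimal sketch, with the model and input left as placeholders:

import os
os.environ["ONEDNN_VERBOSE"] = "1"  # enable oneDNN verbose logging; set before oneDNN is first used

import torch

model = ...   # the (optimized) model, loaded elsewhere
x = ...       # a sample input batch

with torch.no_grad():
    model(x)  # stdout now lists each oneDNN primitive with the ISA it dispatched to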

 

In addition, Intel provides framework extension libraries that layer on further optimizations such as dynamic ISA dispatching, auto-mixed precision, graph optimization, operator fusion, and INT8 quantization. For example, given a PyTorch model, we can optimize it for a given CPU node using Intel® Extension for PyTorch (IPEX) as follows.



import torch
import intel_extension_for_pytorch as ipex

# take the pretrained model
pretrained_model = ...

# Apply IPEX optimization
optimized_model = ipex.optimize(pretrained_model)

Models that are heavy on GEMM operations benefit significantly from conversion to bfloat16. To run the model with auto-mixed precision on 4th-generation or newer Intel Xeon® processors, set `dtype` to `torch.bfloat16` in the ipex.optimize() call; it will automatically convert eligible parameters to bfloat16 and swap relevant operators for their bfloat16 counterparts. In addition, running the model in graph mode via TorchScript can further improve performance.

 

import torch
import intel_extension_for_pytorch as ipex

# pre-requisites: pretrained model and a sample input
pretrained_model = ...
example_inputs = ...

# Apply IPEX optimization
optimized_model = ipex.optimize(pretrained_model, dtype=torch.bfloat16)

# Export model as TorchScript file using JIT tracing
with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
    traced_model = torch.jit.trace(optimized_model, example_inputs=example_inputs)
    traced_model = torch.jit.freeze(traced_model)
    traced_model.save("optimized_model.pt")

 


Figure 3: Throughput improvements across different batch sizes with the usage of auto-mixed precision with the bfloat16 datatype

 


Figure 4: Latency improvements across different batch sizes with the usage of auto-mixed precision with the bfloat16 datatype



Further Model Optimizations 

In addition, it is recommended to utilize Intel® Neural Compressor to perform further model compression and optimization. Intel Neural Compressor can perform several model compression tasks while tuning for the model's accuracy, including INT8/INT4/FP8 quantization, sparsity pruning, knowledge distillation, and so on.
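As an illustration, a minimal post-training INT8 quantization sketch using the Intel Neural Compressor 2.x API might look like the following; the calibration dataloader and accuracy function are placeholders you would supply for your own model.

import torch
from neural_compressor import PostTrainingQuantConfig
from neural_compressor.quantization import fit

# Placeholders: the FP32 model, a calibration dataloader, and an accuracy metric
pretrained_model = ...
calib_dataloader = ...
eval_func = ...  # returns a scalar accuracy; used for accuracy-aware tuning

# Post-training static quantization with accuracy-aware tuning
conf = PostTrainingQuantConfig()
q_model = fit(model=pretrained_model, conf=conf,
              calib_dataloader=calib_dataloader, eval_func=eval_func)

# Save the quantized model for deployment
q_model.save("./quantized_model")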

 

CPU Parallel Threads Optimization

CPU multi-threading in PyTorch

When deploying model inference on a CPU, it is essential to understand the hierarchy of threads to make optimal use of the available CPU compute. Typically, model inference and any additional services are managed by an application thread pool. The inference of each model copy runs on one or more application threads. Each application thread executes the `Ops` of the model one by one. The model can use the `fork` operator to launch an asynchronous task; forking several operations at once results in a set of tasks being executed in parallel. The fork operator returns a Future object, which can be used to synchronize. For example, the `rec_model` launches all its encoding tasks asynchronously in parallel, concatenates the results from these modules, and passes them to the top MLP head for computation. All asynchronous tasks forked from the main inference thread are handled by a single `inter-op thread pool`.

 

import torch

# Assuming encoders = [encoding_1, encoding_2, ..., encoding_n], MLP, and sample_input are already defined
# Fork each encoder as an asynchronous task; they run in parallel on the inter-op thread pool
tasks = [torch.jit.fork(encoder, sample_input) for encoder in encoders]

# Wait for all tasks to complete and collect their results
results = [torch.jit.wait(task) for task in tasks]

# Concatenate the embeddings from all encoders
encoded_features = torch.cat(results, dim=1)

# Pass the concatenated feature vector to the top MLP head
output = MLP(encoded_features)

 

In addition to inter-op parallelism, PyTorch also uses multiple threads within a single op to accelerate the operator's runtime, known as intra-op parallelism. This form of parallelism is useful for accelerating large operators, including element-wise ops on large tensors, convolutions, GEMMs, embedding lookups, etc. Depending on the model's operator types, sizes, and counts, a user needs to tune the degrees of intra-op and inter-op parallelism for the best performance.


Figure 5: The CPU multithreading thread hierarchy for a typical Pytorch application

 

Thread Oversubscription

Oversubscription in PyTorch refers to using more threads than available CPU cores, potentially degrading performance due to context switching and resource contention. To avoid oversubscription, it is recommended to limit the number of threads to the number of physical cores using `torch.set_num_threads()` and `torch.set_num_interop_threads()`. Setting these values to 1 effectively disables the corresponding thread pools. `torch.set_num_threads()` takes precedence over environment variables like OMP_NUM_THREADS and MKL_NUM_THREADS.
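For example, here is a minimal sketch of capping both thread pools in a standalone PyTorch process; the specific counts are illustrative rather than a recommendation.

import torch

# Inter-op threads execute independent tasks (e.g., the encoders forked with torch.jit.fork).
# This must be called before any inter-op parallel work has started.
torch.set_num_interop_threads(16)

# Intra-op threads parallelize work inside individual operators such as GEMMs.
torch.set_num_threads(2)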

 

For most inference workloads, it is recommended that the total number of threads across all model copies not exceed the available CPU cores. However, it is sometimes beneficial to oversubscribe CPU resources. Oversubscription happens when a system creates more software threads than there are hardware execution units (CPU cores or logical processors) available. Despite the resource contention, oversubscribing can be useful when the workload is I/O-bound, when the workload is heterogeneous, or when performing speculative computation to reduce latency. For example, in the recommendation model we are optimizing, the encoding phase consists of a host of parallel tasks that require very different amounts of work, while the individual operators themselves aren't very large. Thus, oversubscribing the CPU resources by launching a large number of inter-op parallel threads but a small number of intra-op parallel threads gave us the best performance: even when using only 2 cores per model copy, launching 16 inter-op threads and 1 intra-op thread yielded the best result.

 

CPU Multi-processing

While it is common to run a single model copy with optimal inter-op and intra-op thread counts for a latency-sensitive application, throwing more cores at a single model copy often doesn't reduce latency proportionally. A common way of improving throughput while still meeting the latency requirement is to run multiple model copies. Instead of using a single process with multiple threads, multiprocessing launches multiple independent processes, each with a separate instance of the model. This approach is particularly useful when one model instance cannot fully utilize all CPU compute and memory resources, or when multiple client requests must be handled in parallel without contention or long queue times. Each process separately controls how many inter-op and intra-op threads to launch within its scope.

 

In the Triton Inference Server, you can configure the degree of process-level and thread-level parallelism in the model config file. Here is an example config specifying these parameters. Inside the instance group, the count parameter sets the number of model copies. This configuration launches 16 copies of the model, where each copy uses 16 inter-op parallel threads and 1 intra-op parallel thread.

name: "..."
platform: "pytorch_libtorch"
max_batch_size: ...
input [ ... ]
output [ ... ]

instance_group [
    {
        count: 16  # number of model copies
        kind: KIND_CPU
    }
]
parameters: {
    key: "INTER_OP_THREAD_COUNT"
    value: {
        string_value: "16"  # number of inter-op threads per process
    }
}
parameters: {
    key: "INTRA_OP_THREAD_COUNT"
    value: {
        string_value: "1"  # number of intra-op threads per process
    }
}

CPU Core Affinity

When we leave all model copies without setting any CPU affinity, the automatic affinity assignment by the Triton Inference Server is not optimal: it tends to place threads from multiple model copies on the same set of cores, causing resource contention and context switching. We can solve this problem by assigning a discrete subset of CPU cores to each model copy. For example, suppose we have a CPU with 32 cores and we need to launch 16 copies of the model, each using 2 cores for optimal throughput. Placing these 16 model copies on distinct sets of 2 cores each, instead of leaving them unpinned, produced much better performance.

 

Table 2: The CPU cores specification for a server

Number of CPU cores | 32
Physical cores | 0-15
Logical (hyper-thread) cores | 16-31
Number of model copies | 16
Cores per model copy | 2



Sibling threads are the two logical CPUs that belong to the same physical core. In our example above, cpu0 and cpu16 are sibling threads, as are cpu1 and cpu17. Assigning sibling threads to different processes (i.e., model copies) can increase context-switching overhead, reducing performance compared to assigning them to the same process, which allows more efficient resource sharing and communication. So it is recommended that sibling threads be assigned to the same model copy when possible. The following table shows the CPU affinity plan when the available CPU cores are evenly divided across all model copies.

 

Table 3: The different copies of models mapped to their distinct sets of CPU cores 

Model copy | Assigned CPU cores
model_0 | 0, 16
model_1 | 1, 17
model_2 | 2, 18
model_3 | 3, 19
model_4 | 4, 20
model_5 | 5, 21
model_6 | 6, 22
model_7 | 7, 23
model_8 | 8, 24
model_9 | 9, 25
model_10 | 10, 26
model_11 | 11, 27
model_12 | 12, 28
model_13 | 13, 29
model_14 | 14, 30
model_15 | 15, 31
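For illustration, here is a small helper sketch that generates the mapping above, assuming 32 logical CPUs where cpu k and cpu k+16 are siblings; the function name and layout are assumptions, not part of the deployment.

NUM_PHYSICAL_CORES = 16  # logical CPUs 0-15 are physical cores; 16-31 are their hyper-thread siblings

def discrete_affinity(num_model_copies):
    """Map each model copy to a distinct set of physical cores plus their sibling hyper-threads."""
    cores_per_copy = NUM_PHYSICAL_CORES // num_model_copies
    plan = {}
    for m in range(num_model_copies):
        physical = list(range(m * cores_per_copy, (m + 1) * cores_per_copy))
        siblings = [c + NUM_PHYSICAL_CORES for c in physical]
        plan[f"model_{m}"] = physical + siblings
    return plan

print(discrete_affinity(16))  # {'model_0': [0, 16], 'model_1': [1, 17], ..., 'model_15': [15, 31]}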

 

In the Triton Inference Server, you can assign the CPU affinity for each model copy using host policies. The policies and their CPU core sets are defined on the `tritonserver` command line (passed here via `docker run`), and each model instance can then be associated with a policy name in the model config file. For example, you can define the above affinity plan using the following command.

docker run -d --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \
    -v/home/ubuntu/Triton/model_repo:/models \
    --name tritonserver_container nvcr.io/nvidia/tritonserver:24.07-py3 \
    tritonserver --model-repository=/models \
    --host-policy=cpu_0,cpu-cores=0,16 --host-policy=cpu_1,cpu-cores=1,17 \
    --host-policy=cpu_2,cpu-cores=2,18 --host-policy=cpu_3,cpu-cores=3,19 \
    --host-policy=cpu_4,cpu-cores=4,20 --host-policy=cpu_5,cpu-cores=5,21 \
    --host-policy=cpu_6,cpu-cores=6,22 --host-policy=cpu_7,cpu-cores=7,23 \
    --host-policy=cpu_8,cpu-cores=8,24 --host-policy=cpu_9,cpu-cores=9,25 \
    --host-policy=cpu_10,cpu-cores=10,26 --host-policy=cpu_11,cpu-cores=11,27 \
    --host-policy=cpu_12,cpu-cores=12,28 --host-policy=cpu_13,cpu-cores=13,29 \
    --host-policy=cpu_14,cpu-cores=14,30 --host-policy=cpu_15,cpu-cores=15,31

Each instance group entry in the config.pbtxt file can then be tied to one of these host policies as follows.

instance_group [
    {
        count: 1
        kind: KIND_CPU
        host_policy: "cpu_0"  # cores 0,16 as defined by --host-policy=cpu_0
    },
    {
        count: 1
        kind: KIND_CPU
        host_policy: "cpu_1"  # cores 1,17
    },
    {
        count: 1
        kind: KIND_CPU
        host_policy: "cpu_2"  # cores 2,18
    },
    ...
]

In addition to the discrete core division, we tried running multiple model copies on the same subset of cores, CPU affinity plans we named overlap_x2, overlap_x4, and so on. For overlap_x2, we ran 2 copies of the model on each subset of CPUs. In essence, this is another form of CPU resource oversubscription.

 

Table 4: The CPU affinity for 16 model copies on a node with 32 cores with the overlap_x2 plan

Model Copies | Assigned CPU cores
model_0, model_8 | 0, 1, 16, 17
model_1, model_9 | 2, 3, 18, 19
model_2, model_10 | 4, 5, 20, 21
model_3, model_11 | 6, 7, 22, 23
model_4, model_12 | 8, 9, 24, 25
model_5, model_13 | 10, 11, 26, 27
model_6, model_14 | 12, 13, 28, 29
model_7, model_15 | 14, 15, 30, 31
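Extending the earlier affinity sketch, an overlap_xK plan simply points K model copies at each core group. The following illustrative helper reproduces Table 4 for overlap_x2; again, the names and layout are assumptions.

def overlap_affinity(num_model_copies, overlap):
    """Assign groups of `overlap` model copies to the same set of CPU cores."""
    groups = num_model_copies // overlap
    base = discrete_affinity(groups)  # discrete plan for the reduced number of core groups
    plan = {}
    for g in range(groups):
        for k in range(overlap):
            plan[f"model_{g + k * groups}"] = base[f"model_{g}"]
    return plan

print(overlap_affinity(16, 2))  # model_0 and model_8 -> [0, 1, 16, 17]; model_1 and model_9 -> [2, 3, 18, 19]; ...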

 

Tuning: putting it all together 

To obtain optimal performance, we wrote a tuner that performs a grid search over all of these parameters. The tuner searches the parameter space for the best combination of the following:

  1. Inter-op parallel threads: e.g., [1 2 4 8 16 32]  
  2. Intra-op parallel threads: e.g., [1 2 4 8 16 32]
  3. Number of model copies: e.g. [1 2 4 8 16 32]
  4. CPU core affinity plans: e.g. [discrete, overlap_x2, overlap_x4]

 

To evaluate the performance of each combination, we ran the Triton Performance Analyzer tool. We varied the query concurrency from 2 to 20 while keeping the batch size constant. 

perf_analyzer -m rec_model --concurrency-range=2:20:2 --async --input-data zero --shape input_0:batch_size,shape_0  --shape input_1:batch_size,shape_1 ... 

The skeleton of the tuning script looks something like the one below.

 

import itertools
import os
import subprocess

# Helpers and constants assumed to be defined elsewhere in the full script:
# update_model_config(), generate_triton_host_policy(), wait_for_triton(),
# TRITON_LAUNCH_PREFIX, PERF_ANALYZER_CMD, LOG_DIR

# search space, change according to your server config
num_model_copies_list = [1, 2, 4, 8, 16, 32]
interop_threads_list = [1, 2, 4, 8, 16, 32]
intraop_threads_list = [1, 2, 4, 8, 16, 32]
cpu_affinity_plans = [1, 2, 4]  # 1 = discrete, 2 = overlap_x2, 4 = overlap_x4
NUM_CPU_CORES = 32

# iterate through the search space
for count, interop, intraop, cpu_affinity_plan in itertools.product(
        num_model_copies_list, interop_threads_list, intraop_threads_list, cpu_affinity_plans):
    cpu_cores_per_model = (cpu_affinity_plan * NUM_CPU_CORES) // count

    # Update model config and launch commands
    update_model_config(count, interop, intraop, cpu_cores_per_model)

    host_policy_cmd = generate_triton_host_policy(count, cpu_cores_per_model, NUM_CPU_CORES)
    TRITON_SERVER_CMD = TRITON_LAUNCH_PREFIX + host_policy_cmd

    # Start Triton
    subprocess.run(TRITON_SERVER_CMD, shell=True, check=True)

    # wait_for_triton() returns True once the Triton server is up
    if wait_for_triton():
        # Run Triton perf_analyzer
        LOG_FILE = os.path.join(LOG_DIR, f"perf_interop{interop}_intraop{intraop}_count{count}_cores{cpu_cores_per_model}.log")
        print(f"Running perf_analyzer... Output will be saved in {LOG_FILE}")
        with open(LOG_FILE, "w") as log_file:
            subprocess.run(PERF_ANALYZER_CMD, shell=True, stdout=log_file, stderr=log_file)
        print(f"perf_analyzer output saved to {LOG_FILE}")
    else:
        print("Skipping perf_analyzer due to Triton startup failure")

    # stop Triton
    subprocess.run(["docker", "container", "stop", "tritonserver_container"])


After searching through the whole parameter grid, we landed on the optimal configs. Figure 6 compares the best-throughput configuration, the best-latency configuration, and the best-throughput configurations within given latency bands (<50 ms or <100 ms) against the default config. For example, the best-throughput configuration improves overall server throughput by 1.87x while simultaneously reducing the average client latency to ~60% of the default. Based on the application SLA, we can pick the optimal configuration once tuning completes.

 


Figure 6: Impact of optimizing CPU threading configs on the throughput and client latency. With carefully managed configs, we can more than double the throughput while simultaneously reducing the client latency.

 

Summary

In this blog, we documented two key steps of inference optimization on Intel Xeon CPUs. The first was optimizing the model itself, utilizing Intel Xeon's built-in ISA extensions like AVX-512, VNNI, and AMX, along with optimized kernel libraries such as oneDNN and oneMKL. In addition, framework extensions like Intel Extension for PyTorch (IPEX) provide Intel-native optimizations that are not yet available in stock PyTorch. We saw up to a 2.7x improvement in throughput when we converted the model to bfloat16 and serialized it with TorchScript.

 

Second, CPU thread optimization, including inter-op and intra-op parallelism, was critical to improving inference throughput. Tuning thread counts for the given model and strategically distributing model copies across CPU cores with affinity settings maximizes resource utilization; we were able to improve the default configuration's throughput by up to 1.87x. It is important to note that model optimization and threading-config optimization are orthogonal to one another, so the overall performance improvement compounds when both are applied. When model-level optimizations such as converting the model to bfloat16/INT8 are applied, the optimal threading config changes as well, so it is important to first apply model optimization and then perform the configuration tuning.

 

About the Author
I work with customers to optimize their AI training, fine-tuning, and inference workloads on Intel platforms.