
Cost Effective Deployment of DeepSeek R1 with Intel® Xeon® 6 CPU on SGLang

IntelAI
Employee

The impressive performance of DeepSeek R1 marked the rise of giant Mixture of Experts (MoE) models among large language models (LLMs). However, its massive model size and unique architecture have created new deployment challenges: its memory requirements normally call for 8 or even 16 high-end AI accelerators.

Over the past few months, the Intel PyTorch team and the SGLang team have been contributing a CPU backend to SGLang. We propose a high-performance, CPU-only solution on the Intel® Xeon® 6 processor at a fraction of the cost. This article explains the technical details of deploying DeepSeek efficiently on a single node with Intel Xeon 6 processors.

Highlights:

  • SGLang now supports a native CPU backend on Intel® Xeon® processors with Intel® Advanced Matrix Extensions (Intel® AMX).
  • Supports BF16, INT8, and FP8 for both dense FFNs and sparse FFNs (MoE).
  • Achieves a 6-14x speedup in TTFT (time to first token) and a 2-4x speedup in TPOT (time per output token) vs. llama.cpp.
  • Reaches 85% memory bandwidth efficiency with highly optimized MoE kernels.
  • Supports multi-NUMA parallelism via Tensor Parallelism (TP).

A CPU Optimization Strategy

In this article, we discuss the technical details of kernel-level optimization, including task partition strategy, memory access efficiency, and effective utilization of Intel® AMX for highly optimized GEMM implementations.

We focus on four performance hotspots:

  • Extend Attention, one of the two backends for SGLang's RadixAttention.
  • Decode Attention, the other RadixAttention backend.
  • MoE, which accounts for the majority of the weights in DeepSeek R1.
  • FP8 GEMM, which we emulate on existing x86 platforms that lack native FP8 support.

1. Extend Attention

We implemented a native C++ backend with Intel® AMX based on the RadixAttention interface. The backend consists of two major components: a) Extend Attention, which handles the prefill phase for Multi-Head Attention (MHA); and b) Decode Attention, which handles the decoding phase. Using the GPU kernels as a reference, we mapped the Flash Attention algorithm to CPU intrinsics, as illustrated in Fig. 1 below:

Fig. 1: Flash Attention in Prefilling Phase

To avoid redundant computation, SGLang divides the query sequence into two parts: a) prefix, the historical sequence, whose attention region is a rectangle; and b) extend, the newly added prompt, whose attention region is a lower triangle. The CPU kernel maps exactly to the Flash Attention V2 algorithm, and we carefully chose the block sizes for the query and KV sequences so that the intermediate attention values S_i and the running statistics (momentums) M_i and S* fit in the L1/L2 cache. Intel AMX computes the GEMM parts, and Intel® AVX-512 computes the block-wise pointwise operations. Because Intel AMX accumulates in FP32 (for example, A: BF16; B: BF16; C: FP32), we fused the data type conversion with the momentum updates: S_i, the output of the first GEMM, stays in FP32, while S, the input of the second GEMM, is converted to BF16. This keeps rounding error minimal while achieving high computational efficiency.
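The tiling and online-softmax bookkeeping described above can be illustrated with a minimal scalar Python sketch. It is not the SGLang kernel: there is no Intel AMX, no BF16 conversion, and the two "GEMMs" are plain loops. The function names and the block sizes block_q and block_kv are hypothetical. The sketch shows how the prefix (rectangle) and extend (lower triangle) regions are masked, and how a running max and running sum per query row let each KV tile be folded into the output without materializing the full score matrix.

```python
import math
import random


def random_matrix(seed, rows, cols):
    """Deterministic small test inputs."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_extend_attention(q, k, v, prefix_len):
    """Naive softmax(q k^T / sqrt(d)) v. Extend row i sees every prefix key
    (rectangle) plus extend keys 0..i (lower triangle)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        limit = prefix_len + i  # index of the last visible key
        s = [_dot(qi, k[j]) * scale for j in range(limit + 1)]
        mx = max(s)
        p = [math.exp(x - mx) for x in s]
        z = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(limit + 1)) / z
                    for c in range(len(v[0]))])
    return out


def flash_extend_attention(q, k, v, prefix_len, block_q=2, block_kv=3):
    """Blocked attention with an online softmax (Flash Attention V2 style).
    Per query row: row_max = running max, row_sum = running sum of
    exponentials, acc = unnormalized output, all rescaled when the max grows."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for q0 in range(0, len(q), block_q):
        rows = q[q0:q0 + block_q]
        row_max = [-math.inf] * len(rows)
        row_sum = [0.0] * len(rows)
        acc = [[0.0] * dv for _ in rows]
        kv_end = prefix_len + q0 + len(rows)  # no row in this block sees beyond
        for k0 in range(0, kv_end, block_kv):
            k1 = min(k0 + block_kv, kv_end)
            for r, qi in enumerate(rows):
                limit = prefix_len + q0 + r
                # First GEMM for this tile: scores kept in full precision.
                s = [_dot(qi, k[j]) * scale if j <= limit else -math.inf
                     for j in range(k0, k1)]
                m_new = max(row_max[r], max(s))
                if m_new == -math.inf:
                    continue  # nothing visible yet for this row
                alpha = math.exp(row_max[r] - m_new)  # rescales old statistics
                p = [math.exp(x - m_new) for x in s]
                row_sum[r] = row_sum[r] * alpha + sum(p)
                # Second GEMM for this tile: accumulate p @ V.
                acc[r] = [a * alpha + sum(p[t] * v[k0 + t][c] for t in range(len(p)))
                          for c, a in enumerate(acc[r])]
                row_max[r] = m_new
        out.extend([a / row_sum[r] for a in acc[r]] for r in range(len(rows)))
    return out
```

Comparing flash_extend_attention with reference_extend_attention on small random inputs is a quick way to confirm the alpha rescaling; in a real kernel the block sizes would instead be chosen so that the per-tile state fits in L1/L2, as described above.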

2. Decode Attention

Decoding is harder to parallelize than prefilling because the query sequence length drops to one. In Multi-Head Attention, we can parallelize the kernel over the dimensions [Batches, Heads, qBlocks], which collapse to [1, Heads, 1] when decoding a single request, leaving insufficient parallelism. We implemented a Flash Decoding algorithm that chunks the KV sequence into multiple splits to increase the degree of parallelism, as shown in Fig. 2. The implementation has two phases: first, compute attention for each KV split; then reduce the intermediate results from all splits into the final output.
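A minimal Python sketch of the two-phase Flash Decoding scheme for a single query vector follows. It illustrates the math rather than SGLang's kernel, and the function names are hypothetical. Phase 1 returns, for each split, a normalized partial output plus its log-sum-exp (lse); phase 2 reweights the partial outputs by exp(lse_i - max lse), which reproduces the unsplit softmax up to floating-point rounding.

```python
import math


def _partial_attention(q, k, v, start, end, scale):
    """Phase 1 for one KV split: normalized partial output plus the split's
    log-sum-exp, which phase 2 needs to reweight this split."""
    s = [sum(a * b for a, b in zip(q, k[j])) * scale for j in range(start, end)]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    out = [sum(p[t] * v[start + t][c] for t in range(len(p))) / z
           for c in range(len(v[0]))]
    return out, m + math.log(z)


def flash_decode(q, k, v, num_splits):
    """Single-query decode attention with the KV sequence chunked into splits."""
    n = len(k)
    scale = 1.0 / math.sqrt(len(q))
    chunk = math.ceil(n / num_splits)
    # Phase 1: splits are independent (in a real kernel, one task per split).
    partials = [_partial_attention(q, k, v, s, min(s + chunk, n), scale)
                for s in range(0, n, chunk)]
    # Phase 2: reduce, weighting each split by exp(lse_i - max_lse).
    g = max(lse for _, lse in partials)
    w = [math.exp(lse - g) for _, lse in partials]
    total = sum(w)
    return [sum(wi * out[c] for wi, (out, _) in zip(w, partials)) / total
            for c in range(len(v[0]))]
```

Calling flash_decode with num_splits=1 (plain softmax attention) and with a larger value should give the same output, which makes a convenient correctness check for the reduction step.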

Fig. 2: Flash Decoding Implementation

Multi-Head Latent Attention (MLA) Optimization

MLA is one of the core features of the DeepSeek model series. In addition to Flash Decoding, we made several critical optimizations to the MLA CPU implementation. We took FlashMLA as a reference: it exploits the fact that key and value share the same tensor storage, and it pipelines memory loads with computation.

Fig. 3: MLA Decoding Implementation

  • Load Once Pack Twice: Intel AMX requires tile data in VNNI format, and key and value must be packed differently because the first GEMM is NT and the second GEMM is NN. As indicated in Fig. 3, we implemented fully vectorized packing logic: KV caches are fetched through two LUTs with prefetching, and once all 32 lanes are loaded (BLOCK_N equals 32), they are simultaneously packed into two thread-local intermediate buffers, one for the key in the format [E/2, BLOCK_N, 2] and the other for the value in the format [BLOCK_N/2, Ev, 2].
  • Head Folding: MLA employs weight absorption in the decode phase, which reduces the number of heads to 1 for both key and value. We can therefore fold the Head dimension into the GEMM to increase computational intensity, as shown below. When blocking the Head dimension, we balanced this against parallelism: with a Head dimension of 22 in DeepSeek R1, we used a BLOCK_SIZE of 6 for a single request and gradually increased it to 22 as the number of requests grew.
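The two packed layouts and the head-folded GEMMs can be sketched in Python as follows. This is an illustrative model of the data layout, not the vectorized kernel. It assumes, as in MLA's weight-absorbed decode path, that the value is the first Ev columns of the same cached row that forms the key; pack_kv_block, folded_scores, and folded_pv are hypothetical names. Each cached row is read once and written into both buffers, and the queries of all heads in one head block act as the rows (the M dimension) of a single GEMM rather than separate per-head GEMVs.

```python
def pack_kv_block(kv_block, e_v):
    """Load each cached row once and pack it twice into VNNI-style layouts.

    kv_block: BLOCK_N rows (tokens) x E columns. The key uses all E columns;
    the value reuses the first e_v columns of the same storage (assumption).
    Returns key as [E/2][BLOCK_N][2] and value as [BLOCK_N/2][e_v][2]; each
    layout pairs two consecutive elements along its GEMM's reduction dim.
    """
    block_n, e = len(kv_block), len(kv_block[0])
    assert block_n % 2 == 0 and e % 2 == 0 and e_v <= e
    key = [[[0.0, 0.0] for _ in range(block_n)] for _ in range(e // 2)]
    val = [[[0.0, 0.0] for _ in range(e_v)] for _ in range(block_n // 2)]
    for n, row in enumerate(kv_block):      # a single load per cached row
        for c, x in enumerate(row):
            key[c // 2][n][c % 2] = x       # pairs along E (Q @ K^T, "NT")
            if c < e_v:
                val[n // 2][c][n % 2] = x   # pairs along N (P @ V, "NN")
    return key, val


def folded_scores(q_heads, key):
    """Head folding: with one shared KV head, all query heads in a head block
    form the M dimension of one GEMM against the packed key."""
    e_half, block_n = len(key), len(key[0])
    return [[sum(q[2 * p] * key[p][n][0] + q[2 * p + 1] * key[p][n][1]
                 for p in range(e_half))
             for n in range(block_n)]
            for q in q_heads]


def folded_pv(probs, val):
    """Second GEMM: [heads, BLOCK_N] probabilities times the packed value."""
    half_n, e_v = len(val), len(val[0])
    return [[sum(p[2 * t] * val[t][c][0] + p[2 * t + 1] * val[t][c][1]
                 for t in range(half_n))
             for c in range(e_v)]
            for p in probs]
```

The number of rows passed to folded_scores corresponds to the head block size discussed above (for example, 6 heads for a single request); a larger block raises the arithmetic intensity of each GEMM but leaves fewer independent tasks to spread across cores.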

Figure: Head Folding

Overall, the kernel-level optimizations on MLA delivered an approximately 1.9x speedup over the vanilla implementation. We also fused the KV buffer setting (writing new entries into the KV cache) with the decoding kernels, which yielded a further 12% improvement by removing several inefficiencies from torch: implicit data type conversion for indexing, creating a TensorImpl when slicing a tensor, mapping copies through TensorIterator, and so on.

3. MoE

A naïve implementation of MoE in torch loops through the experts sequentially and gathers (masks) the activations for each expert before the linear projection. A common strategy to improve efficiency is to sort the activation indices and then chunk them into blocks. Following the existing GPU kernels in SGLang, as shown in Fig. 4, we run argsort on topk_ids and keep the activation indices in sorted_ids, ordered by expert ID. We also made several other optimizations for the CPU kernel:

  • SiLU Fusion: To fuse up_proj and SiLU, we implemented a GEMM kernel that operates in the pattern A×[B1,B2]=[C1,C2]. With B1 from the left half and B2 from the right half, we can fuse SiLU(C1)*C2 into the GEMM, eliminating the additional load/store of the up_proj output.
  • Dynamic Quant Fusion: In our INT8 dynamic quantization kernels for MoE, we fused the quantization from BF16 to UINT8 with the fetching of activations. We implemented both Intel AVX-512 and Intel AMX kernels and choose between them according to the input configuration. Unlike Intel AMX, which supports both U8S8 and S8S8, Intel AVX-512 VNNI supports only U8S8 (UINT8 for A and INT8 for B). We therefore compromise by aligning the weights to the U8S8 pattern, which requires a compensation term of -128×B to convert S8S8 to U8S8: A×B=(A+128)×B-128×B.
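Both fusions rest on simple identities, sketched here in plain Python with hypothetical function names. fused_gate_up_silu treats the weight as the concatenation [B1, B2] and applies SiLU(C1)*C2 as soon as each pair of output columns is ready, so the 2N-wide intermediate is never stored. s8s8_gemm_via_u8s8 computes a signed-by-signed INT8 product using only an unsigned-by-signed inner loop via A×B=(A+128)×B-128×B, with the compensation precomputed per output column. Python integers are unbounded, so the 32-bit accumulator limits of real hardware are not modeled.

```python
import math


def silu(x):
    """x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    ex = math.exp(x)
    return x * ex / (1.0 + ex)


def fused_gate_up_silu(a, w):
    """A x [B1, B2] with SiLU(C1) * C2 applied per output column.

    w is K x 2N: columns [0, N) hold B1 and columns [N, 2N) hold B2.
    """
    k_dim, n = len(w), len(w[0]) // 2
    out = []
    for row in a:
        out_row = []
        for j in range(n):
            c1 = sum(row[k] * w[k][j] for k in range(k_dim))
            c2 = sum(row[k] * w[k][j + n] for k in range(k_dim))
            out_row.append(silu(c1) * c2)  # consumed immediately, never stored
        out.append(out_row)
    return out


def s8s8_gemm_via_u8s8(a_s8, b_s8):
    """Signed INT8 A times signed INT8 B using an unsigned-A inner product.

    The compensation 128 * sum_k B[k][j] depends only on the weights, so it
    can be computed once when the weights are prepacked.
    """
    k_dim, n_dim = len(b_s8), len(b_s8[0])
    comp = [128 * sum(b_s8[k][j] for k in range(k_dim)) for j in range(n_dim)]
    out = []
    for row in a_s8:
        a_u8 = [x + 128 for x in row]  # shifted into [0, 255], i.e. UINT8
        out.append([sum(a_u8[k] * b_s8[k][j] for k in range(k_dim)) - comp[j]
                    for j in range(n_dim)])
    return out
```

Because the compensation vector is a per-column constant, the run-time cost of the U8S8 path is one subtraction per output element.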

Fig. 4: MoE Implementation
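The sort-and-block dispatch described at the start of this section can be sketched as follows. align_tokens_by_expert and moe_forward are hypothetical names, loosely modeled on the idea rather than on SGLang's actual helpers; experts are plain Python callables, and padding slots use a sentinel index. Because every block belongs to exactly one expert, blocks are independent and can be spread across threads, rather than looping over experts one at a time.

```python
def align_tokens_by_expert(topk_ids, num_experts, block_size):
    """Argsort flattened (token, slot) pairs by expert id, then pad each
    expert's group to a multiple of block_size.

    Returns sorted_ids (flat slot indices; padding uses the sentinel
    len(flat)) and block_expert (the expert that owns each block).
    """
    flat = [e for row in topk_ids for e in row]
    order = sorted(range(len(flat)), key=lambda i: flat[i])  # stable argsort
    sentinel = len(flat)
    sorted_ids, block_expert = [], []
    for e in range(num_experts):
        group = [i for i in order if flat[i] == e]
        if not group:
            continue
        group += [sentinel] * ((-len(group)) % block_size)
        sorted_ids += group
        block_expert += [e] * (len(group) // block_size)
    return sorted_ids, block_expert


def moe_forward(x, topk_ids, topk_weights, experts, block_size=2):
    """Process each block with a single expert; blocks are independent."""
    top_k = len(topk_ids[0])
    sorted_ids, block_expert = align_tokens_by_expert(topk_ids, len(experts), block_size)
    sentinel = len(x) * top_k
    out = [[0.0] * len(x[0]) for _ in x]
    for b, e in enumerate(block_expert):
        for slot in sorted_ids[b * block_size:(b + 1) * block_size]:
            if slot == sentinel:
                continue  # padding slot
            token, j = divmod(slot, top_k)
            y = experts[e](x[token])
            out[token] = [o + topk_weights[token][j] * yi for o, yi in zip(out[token], y)]
    return out
```

For topk_ids = [[0, 2], [2, 1], [0, 1]] and block_size=2, the slots are regrouped as [0, 4 | 3, 5 | 1, 2], one block per expert.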

With these optimizations combined, we achieved 85% memory bandwidth efficiency for INT8 MoE, or 1.45TB/s effective memory bandwidth on Multiplexed Rank Dual Inline Memory Modules (MRDIMMs).
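As a rough sanity check on the 85% figure, the theoretical peak bandwidth can be estimated from the configuration listed under Product and Performance Information. The estimate assumes 24 memory channels in total (12 per socket, one MRDIMM per channel), 8 bytes per transfer, and the 8800 MT/s operating speed shown there; these are our assumptions, not figures stated in this section.

```latex
B_{\text{peak}} = 2 \times 12 \times 8800\,\text{MT/s} \times 8\,\text{B}
               = 1689.6\,\text{GB/s} \approx 1.69\,\text{TB/s},
\qquad
\eta = \frac{1.45\,\text{TB/s}}{1.69\,\text{TB/s}} \approx 0.86
```

Under these assumptions, the result is consistent with the roughly 85% efficiency reported above.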

4. FP8 Inference

DeepSeek V3 employs FP8 hybrid training, which poses a challenge for CPU devices because existing x86 devices lack native FP8 support. Since FP8 support is essential to reproducing the original user experience, we made several optimizations for FP8 MoE and GEMM:

  • Weight-Only FP8: We followed a weight-only pattern for FP8 MoE/GEMM, in which the FP8 weights are converted to BF16 (the same data type as the activations) before the computation.
  • Efficient Vectorized Conversion: The data type conversion from FP8 to BF16 is a major performance bottleneck on the CPU, so we experimented with two approaches: a) a LUT that gathers BF16 values from a 2^8-entry table; and b) vectorized conversion with intrinsics. Both approaches turned out to be equally slow, taking 60 to 70 cycles per conversion, which is unacceptable for any performance-critical scenario. We made a trade-off with b): skipping the NaN checks and denormal handling halved the conversion time.
  • WOQ-Aware Cache Blocking: To minimize the data type conversion overhead, we made the GEMM's cache blocking aware of the weight unpacking in weight-only quantization (WOQ). Specifically, each thread visits its assigned weight blocks in a zigzag pattern and caches the unpacked BF16 blocks in L2, so the slow data type conversion happens only once per block.
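To make the trade-off in approach b) concrete, the sketch below compares a 256-entry lookup table built from an exact FP8 decoder (approach a) with a hypothetical shift-and-rescale bit conversion that has no NaN check. It is scalar Python rather than AVX-512 intrinsics, it assumes the E4M3 format (exponent bias 7, NaN encoded as S.1111.111), and it does not reproduce SGLang's actual conversion routine. It only illustrates the NaN side of the trade-off: the fast path agrees with the table on every code except the two NaN encodings, which it decodes as ±480. Python does not flush denormals, so the rescale handles zero and denormal codes correctly here; that may not hold on hardware running with denormals-are-zero settings.

```python
import struct


def fp8_e4m3_to_float(code):
    """Exact E4M3 decode (bias 7), including zero, denormals, and NaN."""
    s = -1.0 if code & 0x80 else 1.0
    e = (code >> 3) & 0xF
    m = code & 0x7
    if e == 0xF and m == 0x7:
        return float("nan")
    if e == 0:
        return s * (m / 8.0) * 2.0 ** -6  # denormal (or zero)
    return s * (1.0 + m / 8.0) * 2.0 ** (e - 7)


def float_to_bf16_bits(x):
    """Upper 16 bits of the FP32 encoding (exact for every E4M3 value)."""
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16


# Approach a): a 2^8-entry table, built once from the exact decoder.
FP8_TO_BF16_LUT = [float_to_bf16_bits(fp8_e4m3_to_float(c)) for c in range(256)]


def fp8_to_bf16_fast(code):
    """Hypothetical approach b): drop the 7 exponent+mantissa bits into the
    BF16 fields (shift left by 4), then rescale by 2^(127 - 7) to fix the
    exponent bias. There is no NaN check, so 0x7F/0xFF come out as +/-480."""
    bits = ((code & 0x80) << 8) | ((code & 0x7F) << 4)
    x = struct.unpack(">f", struct.pack(">I", bits << 16))[0]
    return float_to_bf16_bits(x * 2.0 ** 120)
```

Since E4M3 has only three mantissa bits, every finite E4M3 value is exactly representable in BF16, so neither path introduces rounding error.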

We validated accuracy on GSM8K and MMLU: our emulated FP8 implementation matched the accuracy of the GPU results. With these optimizations, the FP8 implementation reached approximately 80% to 90% of the performance of the INT8 implementation.

Multi-NUMA Parallelism

Non-uniform memory access (NUMA) is a computer memory design used in multiprocessing, commonly seen on server CPUs, where the memory access time depends on the memory location relative to the processor. Under NUMA, a processor can access its local memory faster than remote memory (memory local to another processor or shared between processors). To minimize remote memory access, we mapped Tensor Parallelism (TP), originally used across multiple GPUs, onto the multiple NUMA nodes of a CPU server.

We also implemented communication primitives, such as all-reduce and all-gather, on top of shared memory, bypassing torch.distributed and its lengthy call stack. Overall, communication accounts for only 3% of end-to-end time.
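A minimal sketch of the idea follows, with NUMA nodes modeled as threads and a Python list standing in for the shared-memory segment. A row-parallel matrix-vector product splits W along its input dimension, each rank computes a partial result, and a naive shared-memory all-reduce (write your own slot, wait at a barrier, sum all slots) gives every rank the full result. row_parallel_matvec and shm_all_reduce are hypothetical names; a production implementation would use processes pinned to NUMA nodes and a real shared-memory segment, and would likely use a more bandwidth-efficient reduction than this one.

```python
import threading


def shm_all_reduce(partials):
    """Naive all-reduce over a shared buffer: each rank writes its partial
    result into its own slot, waits at a barrier, then sums every slot."""
    world = len(partials)
    n = len(partials[0])
    shared = [[0] * n for _ in range(world)]  # stands in for shared memory
    barrier = threading.Barrier(world)
    results = [None] * world

    def rank_main(rank):
        shared[rank][:] = partials[rank]  # publish the local partial
        barrier.wait()                    # all slots are written after this
        results[rank] = [sum(shared[r][i] for r in range(world)) for i in range(n)]

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def row_parallel_matvec(x, w, world):
    """Tensor-parallel y = x @ W with W split along its input (row) dim:
    each rank multiplies its slice, then all-reduce sums the partials."""
    k = len(x)
    bounds = [k * r // world for r in range(world + 1)]
    partials = []
    for r in range(world):  # in practice, each rank computes its own slice
        lo, hi = bounds[r], bounds[r + 1]
        partials.append([sum(x[i] * w[i][j] for i in range(lo, hi))
                         for j in range(len(w[0]))])
    return shm_all_reduce(partials)
```

Each rank touches only its own slice of W, which is what keeps weight reads local to a NUMA node; only the small partial outputs cross the shared buffer.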

Evaluation

Our test platform is a state-of-the-art dual-socket Intel® Xeon® 6980P CPU server with 128 cores per socket. We used another popular LLM tool, llama.cpp, as the performance baseline to compare against the SGLang CPU backend. We evaluated four models ranging from 3B to 671B: DeepSeek-R1-671B, Qwen3-235B, DeepSeek-R1-Distilled-70B, and Llama3.2-3B.

Benchmarking notes

  • Socket Setting: We used a single socket for Llama3.2-3B and dual sockets for the other three models, because running a small 3B LLM on dual sockets degrades performance.
  • Sub-NUMA Clustering (SNC) Setting: SGLang data were collected with SNC on and llama.cpp data with SNC off, as llama.cpp cannot guarantee local NUMA access with SNC on.
  • Multi-Instance: Because llama.cpp does not implement multi-NUMA parallelism, running one instance on dual sockets is even slower than running it on a single socket. For a fair comparison, we ran two llama.cpp instances on dual sockets, one per socket, and collected TTFT and TPOT metrics.
  • Data Type for Baseline: We compared INT8 against the GGUF Q8 format. Since llama.cpp has no optimized FP8 path, we also compared FP8 against GGUF Q8.

Table 1: Performance Evaluation of SGLang vs. llama.cpp

Detailed Breakdown

  • TTFT achieved a 6-14x speedup. MoE models saw larger improvements because llama.cpp computes experts sequentially, whereas we parallelize across experts by realigning activations according to expert indices.
  • TPOT achieved a 2-4x speedup. Because the decoding phase tends to be memory-bandwidth bound, the TPOT speedup is much smaller than the TTFT speedup.
  • In general, our emulated FP8 implementation already achieves the best efficiency possible within the hardware's capabilities.

Limitations and Future Work

While our current work on SGLang CPU backend demonstrated significant throughput improvements, several limitations and opportunities for future enhancements remain:

  • Graph Mode Enabling: Python overhead accounts for a considerable share of the time when the number of concurrent requests is low. We are experimenting with removing this overhead through graph mode with torch.compile. Preliminary results suggest an additional 10% improvement in TPOT, but the work is still in progress.
  • Data Parallel MLA: The current multi-NUMA parallelism follows the Tensor Parallel pattern, which causes duplicate KV cache accesses across ranks. A more efficient solution, DP Attention, already exists on GPUs.
  • GPU/CPU Hybrid Execution: KTransformers innovatively uses a hybrid execution pattern for large MoE model inference, in which the MoE layers run on the CPU and the Attention layers run on the GPU. We are experimenting with a similar approach in SGLang and plan to further pipeline the computation stages across heterogeneous hardware.

Summary

We have demonstrated high-performance, CPU-only deployment based on SGLang and shared the technical details behind it. This work is fully open source and upstreamed into the SGLang main branch. We are working to bring more performance optimizations, not only to the CPU backend but also to other Intel platforms.

Acknowledgements

The enabling and optimization of Intel® Xeon® processors in SGLang is a big milestone, providing a new alternative solution for LLM inference. Our work would not have been possible without the valued collaboration and contributions of the community.

We would like to extend our thanks to:

  • SGLang Core Team and Community Contributors: Yineng Zhang, Jiexin Liang, Thien Tran. Thank you for sharing your invaluable ideas, meticulous PR reviews, insightful feedback on RFCs, and solid code contributions.
  • KTransformers Team: Thank you, Mingxing Zhang, for sharing your insight and innovative ideas for GPU/CPU hybrid execution.

Appendix

Related RFCs and PRs

#2807, #5150, #6216, #6339, #6404, #6405, #6408, #6419, #6452, #6456, #6458, #6493, #6549, #6614, #6641, #6657, #6769, #6770, #6771, #6833, #7390, #7462, #7486, #7647, #7818, #7838, #7885.

Install SGLang with CPU Backend

Figure: installation commands

Run SGLang with CPU Backend

Figure: launch commands

Product and Performance Information

Measurement on Intel® Xeon® 6980P, HT On, Turbo On, NUMA 6, Integrated Accelerators Available [used]: DLB [8], DSA [8], IAA [8], QAT [on CPU, 8], Total Memory 1536GB (24x64GB DDR5 12800 MT/s [8800 MT/s]), BIOS BHSDCRB1.IPC.3544.D02.2410010029, microcode 0x11000314, CentOS Stream 9. Tested by Intel on July 7th, 2025.

 

Notices and Disclaimers

Performance varies by use, configuration, and other factors. Learn more on the Performance Index site

Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. See backup for configuration details. No product or component can be absolutely secure.

Your costs and results may vary.

Intel technologies may require enabled hardware, software, or service activation.

© Intel Corporation.  Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries.  Other names and brands may be claimed as the property of others.