背景
vLLM 中实现了多种并行策略与编译优化。本文聚焦三条密切关联的技术线:Sequence Parallelism (SP)、Async Tensor Parallelism (AsyncTP) 与 量化 (Quantization) 的协同工作方式,涵盖概念辨析、GEMM-通信融合原理、量化感知改写以及配置细节。
一、SP vs PCP vs DCP:三种并行对照
| 维度 | SP (Sequence Parallel) | PCP (Prefill Context Parallel) | DCP (Decode Context Parallel) |
|---|---|---|---|
| 作用阶段 | 编译期图变换,prefill / 大 batch 阶段生效 | Prefill 阶段 | Decode 阶段(每一步推理) |
| 解决的问题 | 减少 TP 中 RMSNorm/量化等"逐 token"计算的冗余,把通信换成更便宜的 RS+AG | 长上下文 prefill 时 TTFT 太高,把 Q/K/V 计算切到多卡 | KV head 太少(如 MLA 只有 1 个),TP 把 KV cache 复制了多份,浪费显存 |
| 切分维度 | 沿序列维 (dim=0) 切隐藏状态 | 沿序列维切 prefill 的 query | 沿序列维 (T) 切 KV cache |
| 是否新增 GPU | 否,复用 TP 的卡 | 是,独立通信组,与 TP 正交 | 否,复用 TP 的 GPU;要求 tp_size % dcp_size == 0 |
| 通信组 | TP 组 | 独立 _PCP 组 | 独立 _DCP 组(在 TP 组内细分) |
| 通信原语 | reduce_scatter + all_gather(替代 all_reduce) | DualChunkSwap / ring send-recv | ag_rs(AllGather+ReduceScatter)或 a2a(All-to-All+LSE 合成) |
| 配置入口 | pass_config.enable_sp / sp_min_token_num | --prefill-context-parallel-size | --decode-context-parallel-size |
| 与 TP 关系 | TP 内部的通信优化,不改变并行结构 | 与 TP 正交的独立维度 | TP 的子分组,在 TP 内部再切 KV |
关键差别一句话:
- SP 是"图编译优化",把
AllReduce → RMSNorm改写成ReduceScatter → RMSNorm → AllGather,减少冗余计算。 - PCP 是"prefill 阶段的算力并行",把长 prompt 的 Q/K/V 计算切到多卡。
- DCP 是"decode 阶段的 KV 存储并行",沿序列维分片 KV cache,去掉 TP 复制。
二、SP ≠ AsyncTP
两者高度耦合,但定位完全不同:SP 是基础图改写,AsyncTP 是建立在 SP 之上的"通信-计算融合"。
| SequenceParallelismPass (SP) | AsyncTPPass (AsyncTP) | |
|---|---|---|
| 配置开关 | pass_config.enable_sp / sp_min_token_num | pass_config.fuse_gemm_comms |
| 作用 | 把 AllReduce → RMSNorm 改写成 ReduceScatter → RMSNorm → AllGather | 把 GEMM + ReduceScatter / AllGather + GEMM 融成一个异步 kernel |
| 改变了什么 | 图结构(通信形态) | 执行方式(通信和 GEMM 重叠/融合) |
| 是否独立可用 | 可以单独开(虽然不开 AsyncTP 的话只是把 AR 拆成 RS+AG) | 依赖 SP 的产物,没有 RS/AG 模式可融 |
SP Pass 的 docstring 写得很清楚:
While this pass itself does not directly yield performance improvements, it lays the groundwork for subsequent fusion passes, such as GEMM + ReduceScatter and AllGather + GEMM fusions.
流水线演变:
原图: [GEMM] → AllReduce → RMSNorm → ...
SP 之后: [GEMM] → ReduceScatter → RMSNorm → AllGather → ...
SP + AsyncTP 之后: [fused GEMM+ReduceScatter] → RMSNorm → [fused AllGather+GEMM] → ...
三、GEMM 与通信融合
什么是融合
在 TP 推理里,每个 transformer block 末尾的典型计算流程是:
local GEMM → 集合通信 (ReduceScatter / AllGather)
朴素执行:整块 GEMM 跑完 → 结果落到 buffer → 送进 NCCL 做通信 → 通信结束才能用结果。
融合后:让 GEMM 一边算、一边把已经算出来的小块立刻发出去;接收端一边收、一边喂给下一个 GEMM。本质是流水线重叠 (overlap) 计算与通信。
vLLM 把这件事交给 PyTorch 内置的 torch.ops.symm_mem.* 算子(基于 NVSHMEM / CUDA P2P 的对称内存),由 AsyncTPPass 在编译期把模式替换进去。
两个核心融合模式
1. GEMM + ReduceScatter(block 出口)
改写前:
mm = torch.ops.aten.mm(mul, weight) # 本地 GEMM
out = torch.ops.vllm.reduce_scatter(mm, dim=0, ...) # NCCL ReduceScatter
改写后:
out = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul, weight, "sum",
scatter_dim=0,
group_name=...,
)
直接调一个单算子完成 GEMM + RS。
2. AllGather + GEMM(block 入口)
改写前:
ag = torch.ops.vllm.all_gather(x, dim=0, ...) # NCCL AllGather
mm = torch.ops.aten.mm(ag, weight) # 本地 GEMM
改写后:
ag_out, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
x, [weight],
gather_dim=0,
group_name=...,
)
FP8 / NVFP4 还有量化版本:ScaledMMReduceScatterPattern、AllGatherCutlassScaledMMPattern 等。
四、SP + AsyncTP 与量化的协同
量化(FP8 / NVFP4)在 vLLM 里和 SP + AsyncTP 是互相设计来配合的。量化把算子和通信的负载都缩小,SP 决定通信形态,AsyncTP 把缩小后的 GEMM 和通信融成异步流水。
链路 1:SP 改写——量化算子也被一起圈进 RS/AG
SP Pass 不只是改 AllReduce → RMSNorm,而是把后面紧跟的量化也一起包起来。源码里专门注册了量化感知的 pattern:
FirstAllReduceRMSNormStaticFP8PatternMiddleAllReduceRMSNormStaticFP8PatternFirstAllReduceRMSNormStaticNVFP4PatternMiddleAllReduceRMSNormStaticNVFP4Pattern
替换的形态(FP8 版本):
原图: AllReduce → RMSNorm → FP8_quant
SP后: ReduceScatter → RMSNorm(局部) → FP8_quant(局部) → AllGather
关键意义:
- AllGather 直接传 FP8/NVFP4 数据,比传 bf16 节省 2~4 倍带宽。
- 量化也只在
S/tp长度上做,省掉冗余计算。 - NVFP4 还会 AllGather 量化后的 scale,保证下游 GEMM 拿到正确的 per-block scale。
链路 2:AsyncTP 改写——量化 GEMM 与集合通信融成单算子
SP 之后图变成 FP8_GEMM → ReduceScatter / AllGather → FP8_GEMM,AsyncTPPass 的量化 fusion pattern 接手:
SP后: AllGather(FP8) → FP8_GEMM
→ 匹配 AllGatherScaledMMPattern
→ 替换为 fused_all_gather_matmul(输入为 FP8, 权重为 FP8, 附带 scale)
SP后: FP8_GEMM → ReduceScatter
→ 匹配 ScaledMMReduceScatterPattern
→ 替换为 fused_matmul_reduce_scatter(带 FP8 scale)
这些 pattern 注册在 collective_fusion.py 中。
Per-tensor 量化的正确性边界
| 类型 | scale 来源 | 能否只在 S/tp 上做 |
|---|---|---|
| per-tensor static | 离线校准好的常量 | ✅ 可以 |
| per-tensor dynamic | 运行时 max|x|/qmax 在整个张量上算 | ❌ 不能 |
| per-token (row-wise) | 每行一个 scale | ✅ 可以(行只属于一个 rank) |
| per-block (NVFP4) | 每 N 个元素一组 scale | ✅ 可以(块不跨 rank) |
SP Pass 只匹配 static per-tensor / per-token / per-block 量化。看 pattern 类名全部带 Static:FirstAllReduceRMSNormStaticFP8Pattern。kFp8StaticTensorSym 表示 Static + per-Tensor + Symmetric。
Dynamic per-tensor 不被 SP 匹配,保留原图 AllReduce → RMSNorm → dynamic_quant。
五、配置细节与护栏
启停规则
# AsyncTP 强制依赖 SP
if pass_config.fuse_gemm_comms:
pass_config.enable_sp = True
# TP=1 时 SP/AsyncTP 无意义,自动关闭
if pass_config.enable_sp and tensor_parallel_size == 1:
pass_config.enable_sp = False
pass_config.fuse_gemm_comms = False
sp_min_token_num 自动求阈值
get_sequence_parallelism_threshold 算阈值:如果 hidden_size < SP_MIN_HIDDEN_SIZE(H100/Blackwell = 8192),返回 None 时整条链路全关掉。小模型即便手动开也不会生效。
必须 fullgraph 编译
SP 和 AsyncTP 都断言 fullgraph。O2 默认 use_inductor_graph_partition=False,所以靠 splitting_ops==[] 满足要求。vLLM 的处理方式是 SP 优先、自动清空 splitting_ops——见 compilation_config.validate():
if (not self.use_inductor_graph_partition
and (self.pass_config.enable_sp or self.pass_config.fuse_gemm_comms)
and self.splitting_ops):
logger.warning_once("Sequence parallelism requires full-graph compilation...")
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
self.cudagraph_mode = CUDAGraphMode.FULL
注意这里还会把 cuDAGraph mode 从 FULL_AND_PIECEWISE 降级为 FULL。
SP 阈值取 min(min_token_num, max_num_batched_tokens)
所以小 batch(decode)走原 TP 路径,大 batch(prefill / 大 chunked prefill)才进 SP 路径——这是为什么这条链路对 throughput 友好、对 latency 几乎不伤害。
默认 O2,SP 默认开
O0: enable_sp=False, fuse_gemm_comms=False, cudagraph=NONE
O1: enable_sp=False, fuse_gemm_comms=False, cudagraph=PIECEWISE
O2(默认): enable_sp=IS_DENSE, fuse_gemm_comms=IS_DENSE, cudagraph=FULL_AND_PIECEWISE
O3: enable_sp=IS_DENSE, fuse_gemm_comms=IS_DENSE, cudagraph=FULL_AND_PIECEWISE
稠密模型 + 默认 O2 + TP>1 + 满足硬件阈值,SP 和 AsyncTP 是默认启用的。
总结
SP、PCP、DCP 是三种不同层面的并行——SP 是编译期图变换(TP 内部的通信优化),PCP 是 prefill 算力并行(额外 GPU),DCP 是 decode KV 存储并行(复用 TP GPU)。
SP ≠ AsyncTP——SP 把
AllReduce拆成ReduceScatter + AllGather,改变图的通信形态;AsyncTP 在此基础上把 GEMM 和通信融合成异步流水(对称内存 + 单算子)。量化协同链路:SP 感知
static_fp8_quant/nvfp4_quant,把量化框进RS → Norm → Quant → AG的模式中;AsyncTP 进一步融合AllGather + FP8_GEMM和FP8_GEMM + ReduceScatter。Dynamic per-tensor 量化不在支持范围内。配置护栏:SP 自动清空
splitting_ops保证 fullgraph;hidden_size < 8192时整条链路自动关闭;小 batch decode 不进 SP 路径。