前言
DeepSeek 提出的 MLA(Multi-head Latent Attention)通过将 KV 压缩到低维 latent 空间,大幅降低了推理时的 KV cache 开销。但在 DCP(Decode Context Parallel,即上下文并行)分布式环境下,MLA 的 prefill 阶段设计与 decode 阶段有显著差异。本文从实现角度展开分析。
什么是 DCP
DCP(Decode Context Parallel)是一种将 KV cache 按序列维度切分到多个 GPU 的分布式策略。每个 rank 只持有完整 KV cache 的 1/dcp_world_size,从而减少单卡显存占用,支持更长的上下文。
与更常见的 DP(Data Parallel,扛并发)和 EP(Expert Parallel,分摊 MoE 参数显存)不同,DCP 解决的是单请求长上下文场景下 KV cache 放不下的问题。
MLA Prefill vs Decode:两条不同的路径
MLA 在 prefill 和 decode 阶段走了截然不同的计算路径:
Prefill (forward_mha) | Decode (forward_mqa) | |
|---|---|---|
| KV 形态 | 完整 MHA(N 头) | Latent(1 头) |
| Head dim | P+R(~192) | Lkv+R(~576) |
| 计算特性 | Sq ≈ Skv,计算密集 | Sq ≪ Skv,避免显存搬运 |
Prefill 走 MHA 路径:kv_c 通过 W_UK/W_UV 解压成完整多头 K/V(N 个头),然后做标准的多头注意力。因为 prefill 时新 token 数和 context 长度在同一量级,展开 KV 做计算密集的 attention 是划算的。
Decode 走 MQA 路径:KV 保持 latent 形式(1 头),Q 通过 einsum 吸收进 latent 空间。因为 decode 每次只有 1 个新 token,避免展开 KV 可以大幅减少显存搬运。
Prefill 阶段的 DCP 设计
核心思路:AllGather KV
MLA prefill 在 DCP 下做的是 AllGather latent KV,而非标准 context parallelism 中常见的 AllGather Q。
为什么?关键在于 MLA latent 空间的极度紧凑性:
| 方案 | 每个 token 的通信量 |
|---|---|
| 标准 MHA(AllGather Q) | N × qk_dim = 128 × 192 = 24576 |
| MLA(AllGather latent KV) | kv_lora_rank + rope_dim = 512 + 64 = 576 |
Latent KV 每个 token 仅 576 维,而展开后的 Q 有 24576 维,差了 40 多倍。即使 context token 数远多于 prefill token 数,AllGather latent KV 的总通信量仍然更小。
与非 DCP 的流程对比
非 DCP (_compute_prefill_context):每个 chunk 从本地完整的 KV cache 里 gather_and_maybe_dequant_cache 取出需要的部分 → 解压 → attention → LSE merge。
DCP (_context_parallel_compute_prefill_context):
- 本地 gather——
cp_gather_cache从本地 block table 取出当前 rank 负责的那段 KV(按padded_local_cu_seq_lens寻址) - Workspace 布局——workspace 大小为
(1 + dcp_world_size) × chunk_size,前半段存本地 gather 结果,后半段留作 allgather 目标 - AllGather——
get_dcp_group().all_gather(local_gathered_kvcache, dim=0),各 rank 把自己那部分 KV 广播出去。由于各 rank 的 padding 长度可能不同,allgather 后数据是交错排列的 - Reorg——
reorg_kvcache()把交错排列的 KV 重排成按原始序列顺序连续排列的格式,依赖padded_local_chunk_seq_lens和local_context_lens_allranks两套元数据 - 解压 K/V——
kv_b_proj(kv_c_normed)展开成完整多头 K/V - Attention——
run_prefill_context_chunk,非 causal(context 全在 new tokens 之前) - Merge——chunk 间
merge_attn_states(LSE rescaling)
DCP 对 backend 透明
DCP 的所有复杂性都集中在 gather + reorg 这一步。prefill backend(FlashAttention 等)不感知 DCP 的存在——它拿到的是拼好的完整 K/V,直接做 attention。
已知限制
DCP prefill 目前不支持 KV cache 量化(k_scale is None 断言),这是当前实现的一个限制。
Chunked Prefill 与 LSE Merge
当有 cached context 时,forward_mha 把计算拆成三步:
# 第一步:new tokens causal self-attention
output_prefill = prefill_metadata.prefill_backend.run_prefill_new_tokens(
q=q, k=k, v=v, return_softmax_lse=has_context,
) # 内部 causal=True
# 第二步:new tokens → cached context(分 chunk 非 causal attention)
context_output, context_lse = self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)
# 第三步:LSE merge
merge_attn_states(output=output, prefix_output=context_output, ...)
如果 context 被分成 N 个 chunk,则内部有 N-1 次 LSE merge(chunk 间合并),加上 final merge(context 与 new tokens 合并),共 N 次 merge。但每次 merge 只是 O(B×N×V) 的 element-wise rescaling,比 attention 本身便宜得多。
无 cached context 时,跳过整个 chunked context 流程,forward_mha 只做普通 causal MHA prefill。
DCP 在 Prefill 与 Decode 中的不同角色
一个值得注意的对比是 DCP 在 prefill 和 decode 中操作的对象不同:
- Prefill:DCP 在 KV cache 层面操作——各 rank allgather 自己的 KV 片段,拼成完整 context,attention kernel 看到的是完整数据
- Decode:DCP 在 输出层面操作——各 rank 算 partial attention,通过 LSE 加权合并
这背后的决定因素是通信量的不对称:
| 通信内容 | 每 token 维度 | 参与 token 数 | 总通信量 | |
|---|---|---|---|---|
| Prefill AllGather KV | latent KV | 576 | Skv(context 长度) | 576 × Skv |
| Prefill AllGather Q | 展开的 Q | 24576 | Sq(prefill 长度) | 24576 × Sq |
| Decode AllGather KV | latent KV | 576 | Skv(context 长度) | 576 × Skv |
| Decode AllGather Q | latent Q | 576 | 1(单个 token) | 576 × 1 |
选择逻辑很直接:谁更小/更少就传谁。prefill 时 latent KV 每 token 维度远小于 Q;decode 时 Q 只有 1 个 token,远小于全部 context。
关于 cp_gather_cache 的命名澄清
cp_gather_cache 并非 DCP 专用函数,它只是不做 dequant 的 gather:
- 非 DCP FP8 路径:用
cp_gather_cache保持 FP8 格式不反量化,FP8 prefill backend 内部自行处理 - DCP 路径:始终用
cp_gather_cache,因为 DCP 不支持 KV cache 量化,纯 gather 即可
这里的 cp_ 前缀更可能是 “cache page” 或历史遗留命名,与 “context parallel” 无关。
总结
MLA 的 DCP prefill 设计充分利用了 latent 空间的紧凑性,选择 AllGather KV 而非 AllGather Q,在通信效率和代码复用之间取得了良好平衡。与 decode 阶段共用 DCP 基础设施,但操作对象和通信模式截然不同,体现了 MLA 在不同计算阶段的灵活适配能力。
DCP 并非 DP+EP 的替代品,而是特定场景的补充优化——在极长上下文 + 低并发场景下,通过去重 KV cache 来提升单请求的吞吐能力。