概述
Cascade Attention(级联注意力)是 vLLM 推理引擎中针对多请求共享长前缀场景的一种注意力优化技术。它将标准 attention 拆解为 prefix(前缀)和 suffix(后缀)两个阶段,显著降低 KV cache 的全局内存读取量,在共享 system prompt 的批量推理场景中可实现最高数十倍的 attention 加速。
本文将深入剖析 cascade attention 的设计思想、使用方式、适用场景、实现原理,并与 FlashAttention、FlashDecoding 等其他优化技术进行对比。
一、Cascade Attention 的核心思想
标准 Attention 的冗余
在大语言模型的批量推理中,多个请求往往共享一个较长的 system prompt。例如:
- Chatbot 场景:所有请求共享 “You are a helpful assistant…” 等 system prompt
- Document QA:多用户对同一篇文档提问,文档内容为公共前缀
- Self-Consistency:对同一 prompt 采样多条推理路径
标准 attention 的处理方式是每个 request 独立计算其完整的注意力——包括共享的 system prompt。这意味着同一份前缀 KV cache 被重复加载多次。
Cascade 的拆解思路
Cascade attention 将一次 attention 计算拆成三步:
- Prefix 阶段:将所有请求的 query 拼成一个"大序列",对共享前缀做一次非因果(bidirectional)attention
- Suffix 阶段:每个请求各自对其独有的后缀做因果(causal)attention
- Merge:通过 LSE(log-sum-exp)rescaling 将两阶段结果加权合并
数学上等价于标准 attention,但计算量和显存带宽需求大幅降低。
二、使用方式
Cascade attention 默认关闭,需显式启用:
CLI 启动
vllm serve <model> --disable-cascade-attn False
Python API
from vllm import LLM
llm = LLM(model="...", disable_cascade_attn=False)
自动禁用条件
即使已启用,以下情况也会被强制关闭:
- CPU 平台
- 异步 speculative decoding
- Full CUDA graphs(退化为 piecewise 或 eager 模式)
- Microbatching / DBO(disaggregated batch overlap)
VLLM_BATCH_INVARIANT环境变量启用
三、适用场景
触发条件
内置启发式规则通过 use_cascade_attention() 判断是否启用:
| 条件 | 阈值 |
|---|---|
| 公共前缀长度 | ≥ 256 tokens |
| Batch size | ≥ 8 个请求 |
| 不支持的特性 | ALiBi、sliding window、local attention、context parallelism |
| FlashDecoding 分析 | 对比 cascade 和 FlashDecoding 的 tile 数/SM 资源消耗 |
典型场景
- 多轮对话 / Chatbot:所有请求共享相同的 system prompt
- Document QA:文档内容为公共前缀,不同问题为后缀
- Self-Consistency / Chain-of-Thought:同 prompt 采样多条推理路径
- Few-shot / In-context learning:示例作为共享前缀
不受益的场景
- 单请求推理(无共享前缀)
- 请求之间无公共前缀块
- batch 过小(<8 请求)
- 使用 sliding window attention 或 ALiBi 的模型
四、Prefix 与 Suffix 的直观理解
以两个请求共享 system prompt [A, B, C] 为例:
| 请求 | KV Cache 全部内容 | 已计算的 KV | 本次输入 query |
|---|---|---|---|
| R1 | [A, B, C, D, E, X] | [A, B, C] | [D, E, X] |
| R2 | [A, B, C, D, E, Y] | [A, B, C, D] | [E, Y] |
- Prefix = 所有请求共享的那部分 KV cache,即
[A, B, C](cap 到 minnum_computed_tokens以避免 attention mask 泄漏) - Suffix = 每个请求独有的那部分 KV cache:R1 的
[D, E, X]、R2 的[D, E, Y]
计算流程:
- Prefix kernel:所有 query 拼起来对 prefix 做 non-causal attention(一次 kernel 调用)
- Suffix kernel:每个请求各自对后缀做 causal attention(每个请求独立)
- Merge:用 LSE rescaling 合并两个结果
五、为什么 Cascade 会加速
Memory Bandwidth 之省
假设 N=64 个请求,公共前缀 P=4096 tokens,每个请求后缀 S=128 tokens:
标准 attention(无 cascade):
- 每个 request 独立加载全部 KV (P+S)
- KV cache 加载量 = N × (P + S) = 270,336 个 token
Cascade attention:
- Prefix kernel:KV 加载量 = P = 4,096(加载一次,所有 query 共享)
- Suffix kernel:KV 加载量 = N × S = 8,192
- Merge kernel:几乎免费,仅 element-wise 加权求和
- 总计 = 12,288 个 token
节省:22 倍的 KV cache memory bandwidth 减少。
核心洞察
节省来源于 Flash Attention 的并行化策略。在标准路径中:
gridDim = (num_m_block, batch_size, num_heads)
batch_size 等于 cu_seqlens_q 定义的序列数。标准路径中 cu_seqlens_q = [0, 1, 2, ..., N],batch_size = N,每个请求分配给独立的 thread block,各自加载 block_table[seq_id] 中的完整 KV pages(含 prefix + suffix),同一份 prefix pages 被 N 个 TB 独立加载。
Cascade prefix 将 cu_seqlens_q 设为 [0, total_q],batch_size = 1,所有 query 归属一个序列。Flash Attention 内部按 ceil(total_q / BLOCK_M) 个 M-block tiling,每个 block 内所有 Q token 共享 prefix KV cache。
节省倍数 ≈ BLOCK_M(通常 64 或 128)。
六、Attention Backend 支持矩阵
| Backend | 支持情况 |
|---|---|
| FLASH_ATTN(默认) | ✅ 完整支持,自定义 cascade 实现 |
| FLASH_ATTN_DIFFKV | ✅ 完整支持(继承自 FLASH_ATTN) |
| FLASHINFER | ❌ 有 MultiLevelCascadeAttentionWrapper 但因 bug 禁用 |
| ROCm | ❌ 硬阻断 |
| FlexAttention | ❌ “Not implemented yet” |
| Triton | ❌ |
| CPU | ❌ |
| MLA 系列(12+ 后端) | ❌ 均不支持,MLA 与标准 attention 路径完全分离 |
只有 FLASH_ATTN 真正实现
vLLM 的 FLASH_ATTN backend 不依赖 FlashInfer,自己实现了 cascade:两次 flash_attn_varlen_func(prefix non-causal + suffix causal)+ 自定义 merge_attn_states CUDA kernel(LSE rescaling)。
FlashInfer 库提供了 MultiLevelCascadeAttentionWrapper,vLLM 的 flashinfer backend 导入了它并构建了完整的 metadata 和 forward 路径,但 use_cascade_attention() 返回 False——包括 KV cache dtype 不匹配(不支持 FP8 KV cache)和功能性 bug(“Cascade attention doesn’t work, disable it for now”)两个原因。
七、与 FlashInfer Cascade 的对比
FlashInfer 的 cascade 是 kernel 级设计,vLLM FLASH_ATTN 的实现只是函数级拆分:
| 维度 | vLLM FLASH_ATTN | FlashInfer |
|---|---|---|
| KV 复用粒度 | Q tile 内(BLOCK_M 个 query 共享) | SMEM 片上显式复用 |
| 级数 | 2 级(prefix + suffix) | 多级(num_levels 任意) |
| CUDA Graph | ❌ 不兼容 | ✅ plan/run 分离,支持 capture |
| 加载方式 | KV 仍走 Global Memory | KV 加载到 SMEM 跨 query 复用 |
| 实现方式 | 两次 flash_attn_varlen_func + merge | 自定义 kernel MultiLevelCascadeAttentionWrapper |
FlashInfer 的博客报告在 H100 上最高可达 31x 加速(batch ≥ 128,shared prefix 32768,suffix ≤ 256)。
八、实现流程详解
完整的 cascade attention 流程分为 6 步:
Scheduler → GPU Model Runner → Metadata Builder → Forward Dispatch → cascade_attention() → Merge
Step 1: Scheduler 计算公共前缀 blocks
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
)
遍历 KV cache 的 block table,找到所有 running request 共享的物理 block 数(从 block 0 开始连续匹配)。
Step 2: GPU Model Runner 计算 cascade prefix 长度
common_prefix_len = num_common_prefix_blocks * block_size
common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
common_prefix_len = (common_prefix_len // block_size) * block_size
use_cascade = attn_metadata_builder.use_cascade_attention(...)
启发式检查 batch ≥ 8、prefix ≥ 256 tokens、以及 FlashDecoding 成本对比。
Step 3: 构建 cascade metadata
cu_prefix_query_lens = [0, num_actual_tokens] # 所有 Q 视为一个序列
prefix_kv_lens = [common_prefix_len] # 只有公共前缀
suffix_kv_lens = seq_lens - common_prefix_len # 每个 request 各自的 suffix
Step 4: Forward 分发
if not attn_metadata.use_cascade:
flash_attn_varlen_func(...) # 标准路径
else:
cascade_attention(...) # cascade 路径
Step 5: cascade_attention() 执行
# Phase 1: Prefix — batch=1, non-causal, 只加载 prefix blocks
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query, cu_seqlens_q=[0, total_q],
max_seqlen_k=common_prefix_len, causal=False,
block_table=block_table[:1],
)
# Phase 2: Suffix — batch=N, causal, 从 prefix 之后开始
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query, cu_seqlens_q=per_request_lens,
causal=True,
block_table=block_table[:, num_common_kv_blocks:],
)
# Phase 3: Merge — LSE rescaling
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
九、与其他 Attention 优化的比较
FlashDecoding (Split-KV)
FlashDecoding 将 KV 序列沿 KV 维度切分成 num_splits 份,每份由一个独立的 CTA 做 partial attention,再用 LSE rescaling 合并。本质上是对KV 维度的并行化。
Cascade 的启发式会建模 FlashDecoding 的 CTA 数量来对比性能:
flash_decoding_ctas = num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
只有当 cascade 比 FlashDecoding 快时才启用 cascade,否则直接用标准的 FlashDecoding。
两者的核心区别:
| 维度 | Cascade Attention | FlashDecoding |
|---|---|---|
| 目标 | 减少全局 KV 读取 | 并行化 KV 计算 |
| 优化对象 | Memory bandwidth | Compute latency |
| 适用条件 | 多请求共享长前缀 | 长 KV 序列的 decode |
| 并行维度 | 合并 batch 维度 | 拆分 KV 维度 |
| 与 FA3 关系 | 外部策略 | 内置于 FA3 kernel |
Decode Context Parallelism (DCP)
将每个 request 的 KV cache 切到多个 rank 上分布式计算,与 cascade 互斥。
Attention Sinks
保留序列开头 token 始终参与 attention 的机制。vLLM 框架层面提供了通用支持——任何 backend 都接收可选的 sinks 参数,透传到 CUDA kernel。但目前仅有 Laguna 和 MimoV2 两个模型实际使用。
FA2 vs FA3
| FA2 | FA3 | |
|---|---|---|
| Grid 策略 | 静态 (M_tiles, batch, heads) | Persistent (num_sm) + scheduler metadata |
| 循环顺序 | M-outer(Q 外层) | N-outer(KV 外层) |
| 负载均衡 | 短序列 TB 空转 | 动态 work queue 均衡 |
| Cascade 兼容 | ✅ | ✅ |
FA3 的 persistent kernel 通过预计算 work 分配表 (get_scheduler_metadata),只启动 num_sm 个 persistent thread block,按大任务优先排序 + 动态 N-split 填满所有 SM。
十、FA2 vs FA3 Tiling 差异
Flash Attention 的 cu_seqlens_q 参数决定了 grid 中 batch 维度的大小。这直接影响了 cascade 的优化效果:
// flash_fwd_launch_template.h
const int num_m_block = (params.seqlen_q + kBlockM - 1) / kBlockM;
dim3 grid(num_m_block, params.b, params.h);
标准路径:cu_seqlens_q = [0, 1, 2, ..., N] → b = N
256 个 request 各 1 个 Q token,每个被独立分配 thread block,独立加载全部 KV pages。Prefix 的物理 pages 被 256 个 TB 各加载一次。
Cascade prefix:cu_seqlens_q = [0, 256] → b = 1
所有 Q 作为一个序列,Flash Attention 在 Q 维度上 tiling 为 ceil(256/64) = 4 个 M-block。Prefix KV 只加载 4 次(每个 M-block 加载一次),每次供 64 个 Q token 共享。
这就是为什么 cascade 的节省倍数 ≈ BLOCK_M(64 或 128)。
十一、局限性与注意事项
CUDA Graph 不兼容:由于动态控制流(
use_cascadeper-batch 决策)、不固定的 kernel 调用次数(1 vs 3)、动态 tensor shapes(block table 切片),cascade 无法支持 full CUDA graphs,退化为 piecewise 或 eager 模式非 k 级融合:vLLM 的 cascade 是函数级拆分(两次
flash_attn_varlen_func),prefix 阶段仍然走 Global Memory,没有达到 FlashInfer 的 SMEM 级 KV 复用MLA 模型不支持:DeepSeek V2/V3/V4 等使用 MLA 的模型走独立的 attention 路径,cascade attention 不适用
正确性保障:前缀长度被 cap 到
min(num_computed_tokens)避免 attention mask 泄漏;Kernel 级测试 (test_cascade_flash_attn.py) 和 E2E 测试 (test_cascade_attention.py) 保证输出与标准路径一致
总结
Cascade Attention 是 vLLM 面对多请求共享长前缀这一高频场景时,在 attention 计算层面的一个精妙优化。它的核心思路并不复杂——把共享的前缀拆出来只算一次——但实现层面涉及对 Flash Attention 底层 kernel tiling 策略的深刻理解。
关键洞察:Cascade 通过把 cu_seqlens_q 从 [0, 1, ..., N] 改为 [0, total_q],将 Flash Attention grid 的 batch 维度从 N 压缩到 1,就省掉了 (N-1) 次 prefix KV 的全局内存读取。对于 decode 阶段这个 memory-bound 的计算,这是一个优雅且高效的优化。
对于需要高频处理共享长前缀的 LLM 服务(多轮对话、文档问答、few-shot 推理),开启 cascade attention 是一个低成本、高收益的推理优化手段。