概述
本文围绕 vLLM 推理引擎中的两个关键注意力机制展开深入分析:Attention Sink(注意力沉没)和 Sliding Window Attention(滑动窗口注意力,SWA)。内容涵盖它们的设计原理、vLLM 中的代码实现、KV cache 管理流程,以及 DeepSeek V4 模型如何运用这些技术。
所有分析基于 vLLM 代码库中 DeepSeek V4 相关的实现。
一、Attention Sink(注意力沉没)
1.1 背景
在稀疏注意力(sparse attention)中,每个 query 只 attend 到 top-k 个 selected KV token,softmax 的归一化也只在选中的 token 上做。问题是:模型无法表达"选中的这些 token 我其实都不太关心",因为概率被强制在可选集上归一化到 1。
Attention Sink 解决这个问题:引入一个逐头(per-head)的可学习偏置(bias),在 softmax 归一化时增加一个虚拟的"沉没"项。
1.2 数学原理
标准 softmax 对一个 query 的 top-k scores 做归一化:
$$LSE = \log \sum_{i=1}^{k} e^{score_i}$$$$\text{prob}_i = e^{score_i - LSE}$$加入 attn sink 后,分母多了一项:
$$LSE_{\text{merged}} = \log\left(\sum_{i=1}^{k} e^{score_i} + e^{attn\_sink}\right)$$$$\text{prob}'_i = e^{score_i - LSE_{\text{merged}}}$$关键点:
- 分子不变——不需要对 score 或 v 做任何修改
- 分母膨胀——当
attn_sink接近 0 时,exp(attn_sink) ≈ 1,每个prob'_i被稀释 - 输出幅度衰减:
Σprob'_i = Σprob_i / (1 + exp(attn_sink - LSE)),差值就是被沉没掉的概率质量 - 只有分母变了,没有对应的 v 向量——sink 不对应任何实际 token 的 value,概率质量相当于被"丢弃"了
当 attn_sink = -inf(初始值):退化为普通 softmax。
当 attn_sink 很大(接近 0):模型保留了大量概率质量作为"注意力预算",表示"当前 query 不关心任何可用的 token"。
1.3 vLLM 实现
参数定义
vllm/model_executor/models/deepseek_v4.py (L965-971):
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
- 1D tensor,shape
[padded_heads](pad 到至少 64 以兼容 FlashMLA) - 初始化
-inf,无 sink 效果 - 从 checkpoint 加载,不是配置参数——每个 head 独立学出
权重加载
deepseek_v4.py (L1512-1518):按 tensor parallelism 切分,只加载当前 rank 负责的 head 的 sink 值。
前向传播
作为参数传入 FlashMLA kernel:
- Prefill:
flash_mla_sparse_fwd(..., attn_sink=self.attn_sink, ...) - Decode:
flash_mla_with_kvcache(..., attn_sink=self.attn_sink, ...)
ROCm 参考实现
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py 中有纯 PyTorch 的参考实现,逻辑最清晰:
Prefill(L908-951):合并 LSE:
orig_lse = torch.logsumexp(scores, dim=-1)
lse_for_o = torch.logsumexp(
torch.stack([orig_lse, attn_sink[:h_q].expand_as(orig_lse)], dim=0), dim=0
)
probs = torch.exp(scores - lse_for_o.unsqueeze(-1))
out = probs @ gathered_kv[..., :head_dim]
Decode(L1016-1083):用衰减因子缩放输出:
output *= (1.0 / (1.0 + torch.exp(attn_sink - lse))).unsqueeze(-1)
二、Sliding Window Attention(SWA)
2.1 基本原理
每个 query i 只 attend 到前面最多 W 个 token([i-W+1, i]):
token: t1 t2 t3 t4 t5 t6 t7 t8
window W=3
query t6 能看到: [t3, t4, t5]
query t7 能看到: [t4, t5, t6]
- KV cache 只保留窗口内:超过 W 的 token 可以丢弃,显存 O(W) 而非 O(seq_len)
- 感受野逐层叠加:虽然单层只看 W,堆叠 L 层后每个位置的隐状态可以通过层层传递,获得最大
L × (W-1)的感受野
2.2 vLLM 全流程实现
配置层
HF config 中的 sliding_window: N → ModelConfig.get_sliding_window() → CacheConfig.sliding_window
模型层
模型文件逐层决定是否启用 SWA(如 llama.py 根据 layer_types),把 per_layer_sliding_window 传给 Attention()。
Attention 层
attention.py 中确定 self.sliding_window,并:
get_kv_cache_spec()→ 返回SlidingWindowSpec(告知 KV cache 管理器)AttentionImpl.__init__()→ 把sliding_window转成(window_left, 0)元组(如 4096→(4095, 0)),传给 kernel
KV Cache 管理器
SlidingWindowManager(single_type_kv_cache_manager.py)负责 block 回收:
def get_num_skipped_tokens(self, num_computed_tokens):
return max(0, num_computed_tokens - self.sliding_window + 1)
def remove_skipped_blocks(self, request_id, total_computed_tokens):
num_skipped = self.get_num_skipped_tokens(total_computed_tokens)
if num_skipped <= 0:
return
num_skipped_blocks = num_skipped // self.block_size
for i in range(num_skipped_blocks - 1, -1, -1):
if blocks[i] == self._null_block:
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
Kernel 层
SWA 的正确性靠 kernel 的 window_size 参数保证,block 回收只是内存优化。不同 backend:
| Backend | 实现方式 |
|---|---|
| Triton | SLIDING_WINDOW constexpr,kernel 内只加载 [q_start-W+1, q_start] |
| FlashInfer | 传入 (window_left, 0) 元组 |
| FlexAttention | mask_mod = causal & (abs(q_idx-kv_idx) < window) |
三、Block 回收与生命周期
3.1 回收触发点
SWA 的 block 回收只发生在 allocate_slots() 函数中,每次 decode/prefill step 调用一次。另外在 request 结束时(KV transfer 前)也有一次清理。
allocate_slots 分三阶段:
- Stage 1:
remove_skipped_blocks()释放窗口外 block - Stage 2:处理前缀命中 block(包含 NULL padding)
- Stage 3:
allocate_new_blocks()分配新 block +cache_blocks()写入前缀缓存
3.2 Block 生命周期
remove_skipped_blocks 释放 block 后:
def free_blocks(self, ordered_blocks):
for block in ordered_blocks:
block.ref_cnt -= 1 # 仅减引用
self.free_block_queue.append_n(
[block for block in ordered_blocks
if block.ref_cnt == 0 and not block.is_null]
) # ref_cnt=0 的入 free queue
# block_hash 不清除,hash table 不变
释放后的 block 状态:
ref_cnt = 0,在 free queue 队尾block_hash不清除,仍在cached_block_hash_to_block哈希表中- KV 数据不清除
3.3 两种后继路径
路径 A——前缀命中抢救:
def touch(self, blocks):
for block in blocks:
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block) # 从 free queue 捞回
block.ref_cnt += 1
另一请求 P 的 find_longest_cache_hit 在 hash table 中找到该 block → touch() 将其从 free queue 移除并 ref_cnt++,block 被重新"钉住"。
路径 B——被其他请求分配走:
def get_new_blocks(self, num_blocks):
ret = self.free_block_queue.popleft_n(num_blocks)
for block in ret:
self._maybe_evict_cached_block(block) # 从 hash table 摘除,清 block_hash
block.ref_cnt += 1
block 从队首被取走 → _maybe_evict_cached_block 从 hash table 摘除并清除 block_hash → 后续请求无法再命中。
Free queue 是 FIFO 结构,先释放(队尾追加)后分配(队首取出),所以释放的 block 通常会留在队尾,路径 B 更常见。路径 A 只在空闲队列短、前缀恰好匹配时才发生。
四、SWA vs Full Attention:KV Cache Manager 差异
| 方面 | FullAttention | Sliding Window |
|---|---|---|
| block 回收 | 从不回收 | remove_skipped_blocks 释放窗口外 block |
| block table | 全部有效 block 连续排列 | 头部 NULL 填充,尾部窗口内有效 block |
| admission cap | 无(按 max_model_len 算) | 有(max_admission_blocks_per_request) |
| 前缀缓存方向 | 从左到右找最长前缀 | 从右到左找窗口内最长连续后缀 |
num_common_prefix_blocks | 正常返回 ref_cnt 计数 | 返回 0(不兼容 cascade attention) |
前缀缓存方向
FullAttention——左对齐,从左到右扫:
for block_hash in block_hashes:
if hit = get_cached_block(block_hash):
computed.append(hit) # 追加
else:
break # 首个 miss 停
结果:[B0, B1, B2, ...]——最长共享前缀。
SWA——右对齐,从右到左扫:
for i in range(max_blocks-1, -1, -1):
if hit = get_cached_block(block_hashes[i]):
computed[i] = hit
if contiguously_hit >= window_contiguous_blocks:
trim, break
else:
num_contiguous = 0
结果:[NULL, NULL, ..., B3, B4, B5]——窗口内最长连续后缀。
原因:SWA 中最早的 token 随时被窗口滑出回收,前缀不稳定;只有窗口内的后缀才是稳定的。
五、SlidingWindowSpec vs FullAttentionSpec
SlidingWindowSpec 与 FullAttentionSpec 的核心差异:
# FullAttentionSpec.max_memory_usage_bytes:
return cdiv(max_model_len, block_size) * page_size
# → 按 max_model_len 算
# SlidingWindowSpec.max_memory_usage_bytes:
num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)
return cdiv(num_tokens, block_size) + 1 # +1 因窗口可能不对齐 block 边界
# → 按滑动窗口算,远小于 max_model_len
最大区别:Full 的 KV cache 随序列增长(O(seq_len)),SWA 有上界(O(window))。同时 SlidingWindowSpec 有 admission cap,保证单请求的峰值 block 数不超过预计算的值,避免死锁。
六、ChunkedLocalAttention
ChunkedLocalAttention 是固定窗口 size 的 local attention,通过把 KV cache block table 做"虚拟分块"来复用已有 attention backend,不需要改 kernel。
目前只有 Llama 4 使用:
attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention
与 SWA 的区别:SWA 的窗口连续滑动,CLA 把序列切成不重叠的 chunk。
七、DeepSeek V4 中的 SWA 实现
7.1 三层注意力架构
每个 decoder layer 由 compress_ratio 决定类型:
| 类型 | compress_ratio | 含义 |
|---|---|---|
| SWA-only | ≤ 1 | 纯滑动窗口,无压缩 KV cache |
| C4A | = 4 | 4× 压缩稀疏 attention + SWA(带 Indexer) |
| C128A | = 128 | 128× 压缩稀疏 attention + SWA(无 Indexer) |
7.2 多个子模块
每个 DeepseekV4Attention layer 内创建多个子模块:
| 子模块 | spec | backend | KV cache manager |
|---|---|---|---|
DeepseekV4SWACache | SlidingWindowMLASpec (block=64) | DeepseekSparseSWABackend | SlidingWindowManager |
DeepseekV4MLAAttention (SWA-only) | None | DeepseekV4FlashMLASparseBackend | 无独立 cache(复用 SWA cache) |
DeepseekV4MLAAttention (C4A/C128A) | MLAAttentionSpec (block=256) | 同上 | FullAttentionManager |
DeepseekV4IndexerCache (仅 C4A) | MLAAttentionSpec (block=256) | DeepseekV4IndexerBackend | FullAttentionManager |
CompressorStateCache | SlidingWindowMLASpec (block=4/8) | CompressorBackend | SlidingWindowManager |
7.3 DeepseekV4SWACache:只有存储,没有计算
DeepseekV4SWACache 不参与 attention 计算。它只负责:
- 管理 SWA 的 KV cache buffer(通过
SlidingWindowManager分配 block) - 生成 SWA backend metadata(slot mapping、indices 等)
- 返回
SlidingWindowMLASpec给 KV cache 管理器
真正的计算在 DeepseekV4MLAAttention._forward_decode / _forward_prefill 中进行。
7.4 SWA indices 计算
sparse_swa.py 中的 Triton kernel _compute_swa_indices_and_lens_kernel 每 decode step 运行一次。它把逻辑位置转成 SWA cache 的物理 slot ID(通过 block table 查 slot mapping):
start_pos = tl.maximum(pos - window_size + 1, 0)
end_pos = pos + 1
swa_len = end_pos - start_pos
# ... 查 slot mapping → 得到 swa_indices
三种 layer type 共享同一份 tile scheduling plan(tile_sched_swaonly / tile_sched_c4a / tile_sched_c128a),每 decode step 每个 type 只算一次。
7.5 Decode 路径
flash_mla_with_kvcache 同时传入 SWA cache 和压缩 cache:
out, _ = flash_mla_with_kvcache(
q=q,
k_cache=swa_cache, # SWA KV cache (滑动窗口内)
indices=swa_indices, # SWA 的 slot 索引
topk_length=swa_lens,
extra_k_cache=kv_cache, # 压缩稀疏 KV cache (C4A/C128A)
extra_indices_in_kvcache=topk_indices,
extra_topk_length=topk_lens,
)
7.6 Compressor 的滑动窗口
Compressor 也用 SlidingWindowMLASpec 管理其 state 的滑动窗口。窗口大小 coff * compress_ratio(如 C4A 时为 32,C128A 时为 256)。
7.7 多个 KV cache 组
group_and_unify_kv_cache_specs() 为 DeepSeek V4 创建多个 KV cache 组,按 (block_size, sliding_window) 分组:
- MLA 主 cache:
MLAAttentionSpec,block_size=256,FullAttentionManager - SWA cache:
SlidingWindowMLASpec,block_size=64,SlidingWindowManager - C4A compressor state:
SlidingWindowMLASpec,block_size=4,SlidingWindowManager - C128A compressor state:
SlidingWindowMLASpec,block_size=8,SlidingWindowManager
八、总结
本文深入分析了 vLLM 中的 Attention Sink 和 Sliding Window Attention 两种机制:
Attention Sink 是一个逐头的可学习 bias,在 softmax 归一化时增加虚拟项来解决稀疏注意力中"模型无法表达不关心"的问题。实现上只需在 log-sum-exp 上做 logaddexp,不影响 score 或 value。
Sliding Window Attention 限制每个 query 只 attend 到窗口内 token,配套的 KV cache block 回收机制显著降低长序列推理的显存占用。vLLM 的实现同时作用于两个层面:
- 内存层:
SlidingWindowManager主动回收窗口外 block - 正确性层:kernel 的
window_size约束保证计算准确
两个层面互相独立——不回收也能正确计算,回收只是优化。
DeepSeek V4 将 SWA 与稀疏压缩注意力(C4A/C128A)组合使用,通过多个 KV cache 组分别管理不同类型 cache,DeepseekV4SWACache 仅负责存储,不参与计算。其 prefix cache 右对齐设计是与标准 full attention 最显著的区别之一。