vLLM 的编译系统在标准 PyTorch torch.compile 之上做了大量定制:分段编译(Piecewise Compilation)、字节码 Hook、AOT 缓存、动态形状管 理等。本文从多个实际调试问题出发,系统梳理 vLLM 编译系统的核心机制。

一、FixFunctionalizationPass:消除冗余 Tensor 拷贝

背景

PyTorch 在 torch.compile 时为了保证函数式语义(图中没有 in-place 操作),会自动将 in-place 算子包装成 auto_functionalized(op, ...) 形式。

例如原始的 in-place 调用:

rotary_embedding(query, key, ...)  # in-place 修改 query 和 key

函数化(functionalization)后变成:

result = auto_functionalized(rotary_embedding, query=query, key=key, ...)
query_new = result[1]   # getitem[1]
key_new   = result[2]   # getitem[2]

这会带来额外的 tensor 拷贝,在 vLLM 中是不必要的性能开销。

Pass 做的事

遍历图中所有 auto_functionalized 节点,对已知的特定算子执行反函数化(De-functionalization):

  1. replace_users_with_mutated_args:把后续用 getitem[1]getitem[2] 取出"修改后 tensor"的节点,直接替换回原始输入 tensor(因为 in-place 修改了它,结果就在原地)。

  2. insert_defunctionalized:在图中插入原始的 in-place 调用。

  3. 移除原来的 auto_functionalized 节点和多余的 getitem 节点。

处理的算子列表

算子作用
rotary_embeddingRoPE 位置编码,in-place 修改 query/key
fused_add_rms_normRMSNorm + 残差,in-place 修改 input/residual
fused_add_rms_norm_static_fp8_quantFP8 量化版 fused RMSNorm
rms_norm_dynamic_per_token_quant动态量化 RMSNorm
rms_norm / rms_norm_static_fp8_quantRMSNorm
silu_and_mul / silu_and_mul_quantSiLU 激活 + 乘法(FFN 中间层)
fused_qk_norm_ropeQ/K norm + RoPE 融合算子
flashinfer_trtllm_fused_allreduce_normAllReduce + Norm 融合(分布式)

注意事项

此 Pass 之后不能再运行 DCE(死代码消除),因为反函数化后的 in-place 节点没有返回值,在图中看起来像"死代码",DCE 会误删它们。


二、DynamicShapesConfig:控制动态形状行为

三个配置项

1. type: DynamicShapesType(默认 BACKED

控制 Symbol 的"有后备值"策略:

类型含义特点
BACKED有后备值的符号,PyTorch 默认行为会对 0/1/>=2 做特化(specialization),可能产生 guard
UNBACKED无后备值,不保证被 guard最"正确",不会被特化,但遇到数据相关分支可能报错
BACKED_SIZE_OBLIVIOUS实验性:对 backed symbol 也按 unbacked 处理介于两者之间

2. evaluate_guards: bool(默认 False

调试用途,检测是否有 Dynamo 对动态 shape 做了非预期的特化。开启后不丢弃 guard,如果 shape 变化触发重编译则直接报错。需要同时设置 VLLM_USE_BYTECODE_HOOK=0,BACKED 模式下还需关闭 AOT Compile(因 PyTorch bug)。

3. assume_32_bit_indexing: bool(默认 False

开启后 Inductor 可生成更高效的 32-bit 索引代码。需要 PyTorch 2.10+。

Backed vs Unbacked 的本质区别

第一次跑:seq_len = 128
Backed:   s0 = 128(有后备值),PyTorch 知道"它曾经是 128"
Unbacked: s0 = ???(无后备值),PyTorch 完全不知道它是多少

Backed 符号因为知道后备值,Dynamo 会倾向于做特化(如 s0 != 0s0 >= 2),这些 guard 在每次运行时都要验证,不满足则触发重编译。

vLLM 的 guard 策略

默认使用 Backed(借助后备值帮助编译推导),但主动丢弃 Dynamo 生成的 guard(通过 guard_filter_fn 设为 skip_all_guards_unsafe),从而在不同形状下复用同一份编译产物,避免反复重编译:

# wrapper.py
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe

三、VLLM_USE_BYTECODE_HOOK:字节码 Hook 绕过 Dynamo

作用

控制 vLLM 是否使用一种更高效的方式来执行编译后的代码——直接注入 Python 字节码,绕过 Dynamo 的 dispatch 开销。

VLLM_USE_BYTECODE_HOOK=1(默认,推荐)

第一次调用:
  torch._dynamo.eval_frame.remove_from_cache(original_code)  # 强制触发编译
  → Dynamo 编译,触发 bytecode_hook 回调
  → hook 里捕获编译后的字节码,保存到 self._compiled_bytecode

后续调用:
  直接 dispatch_to_compiled_code()  ← 用编译后的字节码直接执行
  完全绕过 Dynamo 的 eval_frame 机制

VLLM_USE_BYTECODE_HOOK=0

每次调用:
  走标准 torch.compile 路径
  → Dynamo eval_frame → 检查 guard → 执行编译代码
  多了 Dynamo guard 检查的开销

bytecode hook 额外做的两件事

  1. 调试:保存反编译后的源码 — 用 depyf 把字节码反编译成可读 Python 代码写入 transformed_code.py
  2. 安全检查:检测 buffer mutation — 如果字节码里有 "update" 调用(说明 forward 里修改了 nn.Module 的 buffer),在 cudagraph 模式下会导致静默错误,直接抛异常

为什么某些特性需要关闭它(=0

特性原因
evaluate_guards=Truebytecode hook 模式下 remove_from_cache 会清掉所有缓存,导致 guard 检测失效
UNBACKED dynamic shapes同上,remove_from_cache 会强制每次重新编译,破坏 unbacked symbol 的追踪

四、guard_filter_fn 详解:控制哪些 Guard 保留

evaluate_guards=True 时,guard_filter_fn 设为只保留 SHAPE_ENV 类型的 guard:

options["guard_filter_fn"] = lambda x: [
    entry.guard_type == "SHAPE_ENV" for entry in x
]

SHAPE_ENV 类型的 guard 即形状相关的 guard(如 s0 >= 2s0 != 1)。只保留 shape guard 意味着 shape 变化会触发重编译(并报错),帮助定位"哪个 shape 被 Dynamo 意外特化了"。

分支保留的 guard效果
evaluate_guards=True只保留 SHAPE_ENVshape 变化 → recompile → 报错(调试用)
evaluate_guards=False(默认)全部丢弃任何 shape 都不触发 recompile

注意:通过 inductor_compile_config 传入自定义 guard_filter_fn 无效,因为代码后面会直接覆盖写入:

options = vllm_config.compilation_config.inductor_compile_config  # 先读
options["guard_filter_fn"] = lambda x: [False for _ in x]          # 后覆盖!

要让不同 shape 触发 recompile,只有 STOCK_TORCH_COMPILE 模式能真正工作——该模式下整段 guard 覆盖逻辑被跳过。


五、compile_sizes 与 PiecewiseBackend:真正的用途

compile_sizes 不是 recompile 触发器

compile_sizes 的语义是预热(warmup),不是 recompile 触发器:

compile_sizes = [128, 256, 512]
→ 启动时主动对这三个 size 各跑一次 dummy forward,提前触发 Inductor 编译
→ 目的是"提前编译好,避免运行时第一次遇到时的延迟"

PiecewiseBackend 的真实机制

vLLM 在 VLLM_COMPILE 模式下用的不是标准 Dynamo cache 机制,而是自己的 PiecewiseBackend

VllmBackend.__call__ 中,Dynamo trace 生成的完整 FX graph 被 split_graphsplitting_ops 分割成多个子图,PiecewiseCompileInterpreter 遍历 split_gm,对每个需编译的子图创建 PiecewiseBackend

PiecewiseBackend 在初始化时为每个 compile_size 和 compile_range 各预编译一份 Inductor kernel

compile_sizes = [512]
compile_ranges = [Range(1,8), Range(9,64)]

→ 生成 3 个编译产物:
  Range(512, 512) → 专门为 shape=512 编译的静态 kernel
  Range(1, 8)     → 为 shape 1~8 编译的通用动态 kernel
  Range(9, 64)    → 为 shape 9~64 编译的通用动态 kernel

运行时 dispatch 逻辑(_find_range_for_shape):

def _find_range_for_shape(self, runtime_shape):
    # 1. 精确匹配 compile_sizes
    if runtime_shape in self.compile_sizes:
        return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
    # 2. 范围匹配 compile_ranges
    for range in self.compile_ranges:
        if runtime_shape in range:
            return self.range_entries[range]
    return None  # → assert 报错,不是 recompile

如果 shape 不在任何 compile_sizes 和 compile_ranges 里,直接 assert 报错(硬崩溃),不是 recompile,不是 cache miss。

注意:_find_range_for_shape 的 bug

第 344-345 行在 compile_sizes is None跳过了 compile_ranges 的查找,直接返回 None

if self.compile_sizes is None:
    return None  # ← 即使 compile_ranges 有匹配也找不到,看似是 bug

不过实际运行中 compile_sizes 会被初始化为 [](空列表)而非 None,所以很少触发。


六、create_concrete_args 与 get_fake_args_from_graph

这两个函数是一对,分别用于 compile_sizes(精确 shape 编译)和 compile_ranges(范围编译)的输入准备。

create_concrete_args — 静态 kernel

def concretize(sym_val):
    # 把符号表达式里所有符号变量替换为具体 size
    expr = sym_val.node.expr
    return int(expr.subs({s: size for s in expr.free_symbols}))

例如 size=512s0 → 512s0 * 2 → 1024s0 + 3 → 515

遍历 graph 的所有 placeholder,把 SymInt 替换为具体整数,创建具体 shape 的 fake tensor,交给 Inductor 生成针对该 size 优化的静态 kernel

get_fake_args_from_graph — 动态 kernel

直接从 FX graph 的 placeholder 取出 example_value,这些值里的 shape 维度仍然是 SymInt 符号变量,不做任何替换。

Inductor 看到 SymInt 就知道"这个维度是动态的,需要生成为运行时参数",生成对该维度参数化的 Triton kernel

create_concrete_argsget_fake_args_from_graph
用途compile_sizes(精确 shape 编译)compile_ranges(范围编译)
输入 shape符号变量替换为具体数字保留符号变量原样
编译出的 kernel静态 kernel,只能处理那一个 size动态 kernel,能处理范围内任意 size

分支点(piecewise_backend.py L258-263):

if range_entry.compile_range.is_single_size():
    args_list = create_concrete_args(self.graph, range_entry.compile_range.start)
else:
    args_list = get_fake_args_from_graph(self.graph)

七、dynamic=False 但仍有 SymInt?mark_dynamic 覆盖

虽然 torch.compile(dynamic=False, ...) 硬编码在 wrapper.py 中,但 vLLM 在第一次编译前手动调用了 _mark_dynamic_inputs,对特定输入 tensor 调用 torch._dynamo.mark_dynamic(arg, dims)

覆盖了 dynamic=False 的全局设置。mark_dynamic 告诉 Dynamo:“即使全局是 dynamic=False,这几个 tensor 的这几个维度也要当符号变量处理。”

结果:被标记的维度 → SymInt(符号变量),其他维度 → 具体整数。

这样 PiecewiseBackend 才能:

  • compile_ranges:保留 s0,让 Inductor 生成对 s0 参数化的动态 kernel
  • compile_sizes:把 s0 替换为具体值,生成静态 kernel

设计哲学dynamic=False + 手动 mark_dynamic = 精确控制哪些维度是动态的。只在 num_tokens 等必要维度上引入符号变量,其余维度享受静态优化。


八、self.forward vs self._compiled_callable:字节码替换机制

# __init__ 时:
self._compiled_callable = torch.compile(self.forward, ...)
  • self.forward:原始的 Python 函数(未编译)
  • self._compiled_callable:经过 torch.compile 包装后的 callable

首次调用

torch._dynamo.eval_frame.remove_from_cache(...)  # 清缓存,强制触发编译
self._compiled_callable(*args)

→ 走 Dynamo eval_frame 机制 → trace self.forward → 编译 → bytecode_hook 捕获编译后的字节码存入 self._compiled_bytecode

后续调用

with self._dispatch_to_compiled_code():
    self.forward(*args)

关键在 _dispatch_to_compiled_code

self.__class__.forward.__code__ = self._compiled_bytecode  # 偷换字节码

直接把 forward 函数的字节码替换成编译后的版本,调用 self.forward 时 Python 执行的是编译后的字节码,完全绕过 Dynamo。

调的是谁经过 Dynamo?目的
首次self._compiled_callable触发 Dynamo trace + 编译,捕获字节码
后续self.forward(字节码被替换)直接执行编译产物,零 Dynamo 开销

九、编译后的字节码是否包含 torch dispatch?

分两部分:

编译子图(__compiled_subgraph_N

不走 torch dispatch。Inductor 生成的代码内部直接调用 Triton kernel 或 CUDA kernel,完全绕过了 Python 层的 torch dispatch 栈(__torch_dispatch__、autograd 等)。这是 torch.compile 加速的核心来源之一。

Splitting ops(如 attention)

仍走 torch dispatch。这些 op 被排除在编译图外,保留在字节码里作为普通 Python 调用,经过完整的 torch dispatch 栈。

部分是否走 torch dispatch原因
编译子图内的 op否,直接调 Triton/CUDA kernel已被 Inductor 编译替换
Splitting ops被排除在编译图外,保持原始调用

这也是 piecewise(分段编译) 的核心意义。

torch dispatch 的完整路径

torch.nn.functional.silu(x) 为例:

Python API → torch.ops.aten.silu → C++ dispatcher

Dispatcher 按优先级检查 dispatch key:
  FuncTorchBatched (vmap)
  Autograd (自动求导)
  AutocastCUDA (混合精度)
  FunctionalizeKey
  ProxyTorchDispatch (Dynamo tracing 时)
  CUDA / CPU / XPU (实际硬件后端)

→ 找到最高优先级的 active key → 调用其注册的 handler
→ 最终到达 CUDA backend → silu_kernel<<<blocks, threads>>>

Dynamo tracing 时 ProxyTorchDispatch 最高优先,不实际执行,只记录 FX node 并返回 fake tensor。

编译后:直接 triton_fused_silu_0.run(x_ptr, output_ptr, numel, ...),跳过整个 Python → C++ dispatcher → backend lookup 过程。


十、VLLM_USE_AOT_COMPILE:提前编译与缓存持久化

AOT vs JIT

JIT(默认)AOT(AOT_COMPILE=1
编译时机运行时第一次遇到时编译warmup 阶段编译,持久化到磁盘
第二次启动重新编译直接从磁盘加载,跳过编译
缓存路径仅内存缓存VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}

相关环境变量

变量作用
VLLM_USE_AOT_COMPILE=1开启 AOT 编译
VLLM_FORCE_AOT_LOAD=1强制从磁盘加载,找不到缓存则报错
VLLM_USE_MEGA_AOT_ARTIFACT=1更高效地加载编译产物,跳过重新分割图
VLLM_DISABLE_COMPILE_CACHE=1禁止保存/加载 AOT 缓存

hash 包含模型结构、编译配置、PyTorch 版本等因子,任何变化都会导致 cache 失效。


十一、首次编译路径详解(decorators.py L545-632)

1. 收集 trace 到的源文件(L545-565)

记录 Dynamo trace 过程中涉及的所有 Python 源文件到 self.compilation_config.traced_files。当这些文件发生变化时 AOT 缓存失效并重新编译。

通过 monkey-patch InliningInstructionTranslator.inline_call_,在 Dynamo 内联每个函数调用时顺便记录文件名。

2. 配置编译 patch(L567-600)

patch作用
enable_cpp_symbolic_shape_guards = False关闭 C++ 版 shape guard 编译(vLLM 反正会丢弃 guard,编译它浪费时间)
backed_size_oblivious = True仅在 BACKED_SIZE_OBLIVIOUS 模式下开启
assume_32bit_indexing让 Inductor 生成 32-bit 索引代码(PyTorch 2.10+)

3. 编译分支:AOT vs JIT(L601-632)

AOT 路径

self.aot_compiled_fn = self.aot_compile(*args, **kwargs)  # Dynamo trace + Inductor 编译
self.save_aot_compiled_function()                          # 保存到磁盘
output = self.aot_compiled_fn(self, *args, **kwargs)       # 预热 run

JIT 路径

output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
# 编译和首次执行合并在一起

4. 标记编译完成

self.compiled = True

后续调用走已编译路径,不再进入首次编译逻辑。

为什么首次编译要 mark_dynamic?

如果不 mark dynamic,Dynamo trace 时所有维度都是具体值(因 dynamic=False):

首次输入:hidden_states.shape = [128, 4096]
Dynamo 认为 128 是常量

FX graph:所有地方都写死 128
PiecewiseBackend:只能处理 128,其他 shape 全崩

mark dynamic 后:

hidden_states.shape = [s0, 4096]

FX graph:第 0 维是 s0(符号),第 1 维是 4096(常量)

PiecewiseBackend 才能为不同 shape range 分别编译 kernel。且只需要首次编译时 mark,因为 Dynamo 只 trace 一次(guard 被丢弃,永远 cache hit)。


十二、NVTX:NVIDIA 性能标注工具

NVTX(NVIDIA Tools Extension)用于在 GPU 性能分析工具(Nsight Systems、Nsight Compute)中给代码段打标签。

layerwise_nvtx_tracing_enabled 对应 --enable-layerwise-nvtx-tracing。开启后每层模型的 forward 前后插入 nvtx.range_push/pop,Nsight Systems 中能清晰看到每层的 kernel 分布。

没有 NVTX:
  ┌──kernel──┐┌──kernel──┐┌──kernel──┐
  │  ???     ││  ???     ││  ???     │

有 NVTX:
  ┌─── Layer 0: RMSNorm ───┐┌─── Layer 0: Attention ───┐
  │  kernel1  │  kernel2   ││  kernel3  │  kernel4     │

十三、torch.compile 是惰性的

torch.compile() 不会立刻触发编译,它只是把函数包装成一个"编译就绪"的 callable:

self._compiled_callable = torch.compile(compiled_ptr, ...)
# 此时没有任何编译发生

真正的编译在第一次调用 self._compiled_callable(*args, **kwargs) 时才触发。

torch.compiler.is_compiling() 返回 True 的时间窗口涵盖整个编译过程:Dynamo tracing → AOT Autograd → Inductor 优化/代码生成。

vLLM 使用 is_compiling() 来防止编译嵌套:

if self.do_not_compile or torch.compiler.is_compiling():
    return self.forward(*args, **kwargs)

如果当前已在编译中,直接走原始 forward,不要再尝试编译。


十四、每个子 FX 图分别编译

以一个 Transformer layer 为例,split 后有 3 个子图:

split_gm
  ├── submod_0  (RMSNorm)        ← 需编译
  ├── submod_1  (Attention)       ← splitting op,不编译
  └── submod_2  (RMSNorm + MLP)  ← 需编译

PiecewiseCompileInterpreter.run() 遍历 split_gm,对每个需编译的子图:

# backends.py L730
piecewise_backend = PiecewiseBackend(submod, ...)
# __init__ 内部调用 compile_all_ranges()
# 对这个子图的所有 shape range 逐一编译

总编译次数:

总编译数 = 可编译子图数 × (compile_sizes 个数 + compile_ranges 个数)

每个子图是独立的计算图(不同的 FX graph、不同的输入/输出签名),Inductor 需要分别做算子融合、内存规划、生成 Triton kernel,无法合并。这就是 piecewise 的由来。

Dynamo trace 的计时日志("Dynamo bytecode transform time: %.2f s")位于 VllmBackend.__call__ 开头,在 split_graphPiecewiseCompileInterpreter.run() 之前,所以这个 log 只打印一次,不包含后续子图编译的时间。

CompileContext 在 piecewise 编译期间有效

整个 piecewise 编译都发生在 VllmBackend.__call__ 内部,而 VllmBackend.__call__ 是 Dynamo 作为 backend 回调调用的。从 Dynamo 设置 CompileContextVllmBackend.__call__ 返回,整个调用栈都没有退出,所以 CompileContext.current_compile_id() 始终返回有效的 compile_id,不会是 None。

注意区分 ShapeEnvCompileContext

  • ShapeEnv:记录符号变量的约束关系,Dynamo trace 时创建。piecewise 后续为不同 shape 调用 Inductor 时原始 ShapeEnv 不再适用,需要 AlwaysHitShapeEnv 绕过。
  • CompileContext:记录编译 ID,是整个 backend 回调期间的线程局部状态。

十五、splitting_ops 的格式与配置

splitting_ops 决定了在 FX graph 的哪些算子处分割图。should_split 的匹配逻辑:

FX node 类型匹配字段示例值
OpOverloadPackettarget._qualified_op_name"torch_thrive::rotary_embedding"
OpOverloadtarget.name()"torch_thrive::rotary_embedding"
OpOverload(精确重载)f"{name}.{overloadname}""torch_thrive::rotary_embedding.default"

PyTorch 的 op 命名格式为 namespace::op_name。配置时:

splitting_ops = ["torch_thrive::rotary_embedding"]

可同时匹配 packet 和所有 overload。


十六、causal_conv1d:Mamba Conv State 更新逻辑

state_len 修正

IS_VARLEN 模式下,输入 x 把所有序列的 token 打包在一起,每个序列的实际 token 数不同。

传入的参数:

  • seqlen = max_query_len(全局最大序列长度)
  • state_len = width - 1 + (seqlen - 1)(spec decoding 下,保留历史 width-1 个 + 新 draft token seqlen-1 个)
  • 当前序列实际长度 = query_end_index - query_start_index

修正逻辑:

state_len = state_len - (seqlen - (query_end_index - query_start_index))
seqlen    = query_end_index - query_start_index

本质:把多出来的 max_seqlen - actual_seqlenstate_len 里减掉,使后续所有 state_len - seqlen 相关逻辑(读取旧 state、拼接新 token、写回新 state 的 mask 计算)都能对齐到该序列真实的 token 数量。

idx_tokens - VAL 的含义

VAL = state_len - seqlen 是新 conv_state 中从旧历史保留下来的 token 数量。

新 conv_state 布局:
索引:   0         ...   (VAL-1) | VAL       ...  (state_len-1)
内容: [ 旧conv_state的后半段历史  |  新输入x的所有token           ]
       ←——— VAL = state_len-seqlen 个 ———→←——— seqlen 个 ———→

idx_tokens - VAL 把新 conv_state 的槽位坐标映射到 x 的 token 坐标:

idx_tokensidx_tokens - VAL含义
0 ~ VAL-1负数对应旧历史,不从 x
VAL0对应 x[0],第一个新 token
state_len-1seqlen-1对应 x[seqlen-1],最后一个新 token

掩码逻辑:

new_conv_state = tl.where(mask, conv_state, loaded_x)
# mask(idx_tokens + seqlen < state_len,即 idx_tokens < VAL):用旧 conv_state
# 反之:用从 x 读来的新 token

总结

梳理 vLLM 编译系统的全貌:

编译流程torch.compile(dynamic=False) → Dynamo trace(mark_dynamic 标记必要维度)→ VllmBackend.__call__split_graph 分割 → PiecewiseCompileInterpreter 遍历子图 → 每个子图的 PiecewiseBackend 为每个 shape range 编译 Inductor kernel → 字节码 hook 捕获编译产物

Guard 策略:vLLM 通过 guard_filter_fn 丢弃所有 Dynamo guard,避免 shape 变化触发重编译。STOCK_TORCH_COMPILE 模式可恢复原生 guard 行为。evaluate_guards 用于调试。

Piecewise 编译:将 Transformer layer 按 splitting_ops 分割为多个子图,分别编译。compile_sizes 对应静态 kernel(精确 shape),compile_ranges 对应动态 kernel(符号参数化)。运行时 _find_range_for_shape 按 shape 分发。

字节码 Hook:捕获编译后字节码,后续调用直接替换 forward.__code__,完全绕过 Dynamo,零额外开销。

AOT 编译:编译产物序列化到磁盘,后续启动直接加载,跳过编译。

FixFunctionalizationPass:消除 auto_functionalized 带来的冗余拷贝,恢复 in-place 语义。