背景

vLLM 利用 PyTorch 2.x 的 torch.compile 路径,通过 TorchDynamo 捕获模型计算图,再经过图分割(graph splitting)和分段编译(piecewise compilation)来优化 GPU kernel 执行效率。本文将深入剖析 FX Graph 的分割原理及其背后的设计思想。

一、TorchDynamo 编译概览

TorchDynamo 是一个 Python 级别的 JIT 编译器,它通过 PEP 523 的 frame evaluation callback 在 Python 字节码执行之前捕获计算图。vLLM 利用这一机制,将模型 forward 函数中的计算捕获为 fx.GraphModule,然后送入自定义后端 VllmBackend 进行编译。

编译流程大致如下:

model.forward()
  └── TorchDynamo 捕获计算图
        └── fx.GraphModule (原始完整图)
              └── VllmBackend.__call__()
                    ├── split_graph()       → 图分割
                    ├── PiecewiseCompileInterpreter → 分段编译
                    └── codegen              → 生成胶水函数
                          └── 返回可调用对象给 Dynamo

二、FX Graph 分割:split_graph 的工作原理

2.1 为什么需要分割

vLLM 的模型计算中包含两种性质不同的算子:

  • 静态计算:MLP、LayerNorm、RMSNorm 等连续矩阵运算,shape 稳定,适合 Inductor 编译和 CUDA graph capture
  • 动态算子:Attention(包含动态 shape、string 参数等),无法直接 CUDA graph capture,且 Inductor 难以优化

如果对整个图做单一编译,Inductor 需要对 attention 等动态算子做特殊处理,编译效率和运行效率都不理想。因此 vLLM 采用了图分割策略。

2.2 split_graph 函数的实现

split_graph 内部调用 torch.fx.passes.split_module.split_module,根据指定的 splitting_ops(如 attention 相关算子)对原始图进行切分。关键代码位于 vllm/compilation/backends.py

split_gm, outputs = split_module(
    graph_module, None, tagging_fn, split_by=SplitBy.TAGS
)

tagging_fn 决定每个节点属于哪个子图:

  • 遇到 splitting_ops 中的算子时,开启一个新子图
  • 连续的非 splitting 算子合并到同一个子图中

2.3 返回值的真相:为什么 split_gm 还是一个 fx.GraphModule

这是最常见的疑问:调用 split_graph 后,返回的 split_gm 仍然是一个 fx.GraphModule,而不是多个独立的 graph。

原因是 split_module 的设计就是返回一个根 GraphModule,内部把每个子图作为子 module 挂载其上split_gm 的结构如下:

split_gm (fx.GraphModule)
├── graph: stitching graph(胶水图)
│   └── nodes:
│       ├── placeholder (原始输入)
│       ├── call_module submod_0  → 调用子 module
│       ├── call_module submod_1  → 调用子 module
│       ├── ...
│       └── output
├── submod_0 (fx.GraphModule)  ← 子图 0
├── submod_1 (fx.GraphModule)  ← 子图 1
├── ...
└── submod_N (fx.GraphModule)  ← 子图 N

根 module 自己的 graph 是一个"stitching graph"(胶水图),只负责按顺序调用子 module 并传递 tensor。stitching graph 的 forward 大致如下:

def forward(self, ...):
    submod_0 = self.submod_0(...)
    getitem = submod_0[0]
    submod_1 = self.submod_1(getitem, ...)
    ...
    return output

额外返回的 outputs: list[SplitItem] 是这些子 module 的列表,调用方可以通过 split_gm.submod_0 或遍历 outputs 拿到每个独立的子图。

这是一种组合模式——既保留了完整的调用关系,又允许单独编译每个子 module。

三、分段编译架构

3.1 PiecewiseCompileInterpreter

得到 split_gm 后,vLLM 使用自定义的 PiecewiseCompileInterpreter(继承自 torch.fx.Interpreter)来驱动编译。它用 fake tensor 执行 split_gm 的 stitching graph,逐节点模拟 forward:

interpreter = PiecewiseCompileInterpreter(
    self.split_gm, submod_names_to_compile, ...
)
interpreter.run(*fake_args)

每遇到一个 call_module 节点(即一个子图 submod_N),call_module 方法会:

  1. 为子图创建 PiecewiseBackend 实例(内部触发 Inductor 编译)
  2. wrap_with_cudagraph_if_needed 包装(CUDA graph 支持)
  3. split_gm 上的原子 module 替换为编译后的可调用对象

3.2 dict 替换技巧

替换子 module 的方式非常巧妙——直接操作 __dict__ 而非 _modules

self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
    piecewise_backend, ...
)

这利用了 Python 属性查找顺序:__dict____getattr__(查 _modules。编译后的 PiecewiseBackend 直接写入实例的 __dict__,之后 getattr(self.split_gm, "submod_0") 拿到的是编译后的版本,原始 GraphModule 被 shadow 了。

好处是:不干扰 PyTorch 对 _modules 的管理named_children()state_dict()to(device) 等仍然正常工作),同时运行时通过属性访问拿到的是编译后的可调用对象。

四、PiecewiseBackend:多 shape 编译与运行时调度

4.1 整体职责

PiecewiseBackend 是每个子图的编译管理与运行时调度器。一个子图对应一个 PiecewiseBackend 实例。

PiecewiseBackend
├── range_entries: dict[Range, RangeEntry]
│   ├── Range(256,256)    → RangeEntry(runnable=Inductor 编译产物_256)
│   ├── Range(512,512)    → RangeEntry(runnable=Inductor 编译产物_512)
│   └── Range(1024,4096)  → RangeEntry(runnable=Inductor 编译产物_dynamic)
└── __call__(*args)
      → 根据 runtime_shape 找到对应 range_entry
      → 调用 range_entry.runnable(*args)  ← 真正的编译产物

4.2 两种初始化模式

  • 编译模式graph 不为空):首次编译,调用 compile_all_ranges() 为每个 shape range 生成 Inductor 编译产物
  • 预编译加载模式compiled_runnables 不为 None):从 AOT artifact 加载已编译产物
def compile_all_ranges(self):
    for range_entry in self.range_entries:
        range_entry.runnable = self.vllm_backend.compiler_manager.compile(
            self.graph, args_list, inductor_config, ...
        )

4.3 Shape Range 调度

运行时 __call__ 根据实际输入 shape 做 O(1) 级别的 dispatch:

  1. 提取 runtime_shape(动态维度值)
  2. _find_range_for_shape() 匹配:
    • 优先匹配 compile_sizes 中的精确 size(如 256 → Range(256,256)
    • 再匹配 compile_ranges 中的区间(如 2048 → Range(1024,4096)
  3. 调用对应的 range_entry.runnable

4.4 编译时间线

完整的编译时间线如下:

VllmBackend.__call__(graph, example_inputs)     ← 被 Dynamo 调用
├── split_graph(graph)                           → self.split_gm
├── PiecewiseCompileInterpreter.run(*fake_args)
│   └── call_module(target, args)
│       ├── PiecewiseBackend(submod, ...)
│       │   └── __init__
│       │       └── compile_all_ranges()
│       │           └── compiler_manager.compile(graph, args, ...)
│       │               └── range_entry.runnable = 编译好的 callable
│       └── self.module.__dict__[target] = wrapped_piecewise_backend
├── generate_execution_code(split_gm)
└── compile_execution_fn(...)
      └── 返回最终 callable 给 Dynamo

所有 Inductor 编译在 VllmBackend.__call__ 返回前同步完成,运行时 __call__ 只做 shape dispatch,零编译开销。

五、与 CUDA Graph 的配合

5.1 分层架构

CUDA graph 与 Inductor 编译是分层叠加关系,不是二选一:

最终 callable
  └── make_copy_and_call (输入拷贝层)
       └── codegen 生成的 stitching 函数
            ├── CUDAGraphWrapper (submod_0)
            │   └── PiecewiseBackend
            │       └── range_entry.runnable (Triton kernel)
            ├── fx.GraphModule (submod_1, attention, eager 执行)
            ├── CUDAGraphWrapper (submod_2)
            │   └── PiecewiseBackend
            │       └── range_entry.runnable (Triton kernel)
            └── ...

5.2 Capture 时机

CUDA graph capture 发生在第一次推理调用时(也是 warmup run),而非编译阶段:

第一次推理调用:
  CUDAGraphWrapper.__call__(*args)
    → entry.cudagraph is None → 需要 capture
    → torch.cuda.graph(cudagraph):
        output = PiecewiseBackend.__call__(*args)
               → range_entry.runnable(*args)  # Inductor Triton kernel
    → entry.cudagraph = cudagraph (缓存)
    → return output

后续同 shape 调用:
  CUDAGraphWrapper.__call__(*args)
    → entry.cudagraph is not None → replay
    → entry.cudagraph.replay()
    → return entry.output

5.3 输入拷贝责任分离

CUDAGraphWrapper 不做输入拷贝——这是一种正交设计。输入拷贝由外层 make_copy_and_call 负责:

def copy_and_call(*args):
    for i, index in enumerate(sym_tensor_indices):
        static_tensor[:runtime_shape].copy_(runtime_tensor)
        list_args[index] = static_tensor  # 保证地址不变
    return callable_fn(*list_args)

这样 CUDAGraphWrapper 保持与编译逻辑正交,不关心 shape 信息,只依赖"输入地址不变"这个外部保证。

六、四层 Graph 的设计思想

整个流程涉及 4 层 graph,各有不同角色:

层次名称角色
1Dynamo 原始图TorchDynamo 捕获的完整计算图
2split_gm (Stitching Graph)胶水图,只编排子图调用顺序
3Submodule Graphs (submod_N)子图,非 splitting ops 走 Inductor 编译,splitting ops 保留为 fx.GraphModule
4Inductor 编译产物Triton kernel,存于 range_entry.runnable
[Dynamo Graph]           完整原始图
      │ split_graph(splitting_ops)
      ▼
[split_gm]               胶水图
  ├── submod_0 ──PiecewiseBackend──Triton Kernel (CUDA Graph)
  ├── submod_1 ──fx.GraphModule───attention (Python eager)
  ├── submod_2 ──PiecewiseBackend──Triton Kernel (CUDA Graph)
  └── ...
      │ generate_execution_code()
      ▼
[gen_fn]                 纯 Python 缝合函数
      │ compile_execution_fn()
      ▼
[runtime_callable]       最终可调用对象

七、总结

vLLM 的 FX Graph 分割与分段编译架构的精髓在于:

  1. 算子级切分:在原始图上按算子类型切分,attention 等动态算子独立出来,避免干扰 Inductor 对静态计算的优化
  2. 组合式结构split_module 返回单根 GraphModule + 子 module 的组合,而非多独立 graph,兼顾完整调用关系和独立编译能力
  3. 延迟 capture:Inductor 编译在 VllmBackend.__call__ 中同步完成;CUDA graph capture 推迟到第一次推理时,按需进行
  4. 正交分层:Inductor 编译、CUDA graph capture、输入拷贝三层各司其职,通过 __dict__ 替换、make_copy_and_call 等技巧解耦
  5. 多 shape 支持PiecewiseBackend 通过 range_entries 管理不同 shape 的编译产物,运行时 O(1) dispatch

这套架构让 vLLM 在保持动态算子灵活性的同时,对静态计算部分获得了接近手写 Triton kernel 的性能,并通过 CUDA graph 进一步消除 Python 开销。