背景:什么是 UBatchWrapper

在 vLLM 的 GPU 推理引擎中,UBatchWrapper(位于 vllm/v1/worker/gpu_ubatch_wrapper.py)是一个模型包装器。它拦截对原始模型的调用,在内部将一次大 batch 的 forward 拆分成多个 micro-batch(ubatch),并利用多线程 + CUDA stream 实现计算与通信的重叠。

核心目标:解决 MoE 模型中 all2all 通信开销大的问题——让一个 ubatch 在做计算时,另一个 ubatch 的通信已经在背后并行执行。


线程模型与同步原语

threading.Barrier:集结号

UBatchWrapper 在初始化时创建一个 threading.Barrier

self.ready_barrier = threading.Barrier(num_ubatches + 1)

threading.Barrier 不关心具体有哪些线程,只计数。内部维护一个计数器,每当有线程调用 barrier.wait(),计数器就 +1。当总数达到构造时指定的 parties 值(这里是 num_ubatches + 1),所有正在 wait() 的线程同时被释放,计数器归零。

这里的 +1 包含了 N 个 ubatch 线程 + 1 个主线程

ready_barrier.wait() 在三处被调用:

文件行号调用方
gpu_ubatch_wrapper.py261主线程_capture_ubatches
gpu_ubatch_wrapper.py325主线程_run_ubatches
ubatching.py56每个 ubatch 线程UBatchContext.__enter__

少一个线程到达,所有人全部卡住。

threading.Event:红绿灯信号牌

threading.Event 是一个简单的线程间布尔标志

操作语义
event.set()举起信号牌(设为 True)。其他线程的 wait() 看到已举牌,直接通过不阻塞。
event.clear()放下信号牌(设为 False)。之后 wait() 的线程会阻塞等待。
event.wait()阻塞直到信号牌被举起set()),才继续往下走。

在 ubatch 的上下文管理中,每个 ubatch 线程有两个 event 引用:

  • cpu_wait_event[i]:ubatch[i] 用于等待自己被唤醒的信号牌
  • cpu_signal_event[i] = cpu_events[(i+1) % N]:ubatch[i] 完成后唤醒下一个 ubatch 的信号牌

由此形成一条 链式唤醒 路径:0 → 1 → 2 → ... → N-1


完整调用链路

execute_model 到 ubatch 线程执行的全路径:

execute_model()
│
├─ 1. _prepare_inputs() → 准备 logits_indices 等
│
├─ 2. _determine_batch_execution_and_padding()
│     └─ coordinate_batch_across_dp()
│         ├─ check_ubatch_thresholds()  → 判断 token 数是否达到 ubatch 阈值
│         └─ _synchronize_dp_ranks()    → 所有 DP rank 确认相同决策
│
├─ 3. maybe_create_ubatch_slices()      → 按 token 平分创建 UBatchSlice 列表
│
├─ 4. 准备 per-ubatch 数据
│     ├─ _get_slot_mappings(ubatch_slices)      → list[dict] (每个 ubatch 一份)
│     └─ _build_attention_metadata(ubatch_slices) → list[dict] (每个 ubatch 一份)
│
├─ 5. set_forward_context()             → 设置全局 ForwardContext
│     └─ 包含 attn_metadata、slot_mapping、ubatch_slices、
│         cudagraph_runtime_mode、batch_descriptor、dp_metadata
│
└─ 6. self.model(...)                   → UBatchWrapper.__call__()
      │
      └─ UBatchWrapper.__call__()
           │
           ├─ 从 ForwardContext 读取 ubatch_slices、attn_metadata、slot_mapping
           │
           ├─ ubatch_slices is None → 直接调 self.runnable(...), 结束
           │
           ├─ [首次 FULL] _make_ubatch_metadata() + _capture_ubatches()
           ├─ [已有 cudagraph] replay()
           └─ [NONE/PIECEWISE] _make_ubatch_metadata() + _run_ubatches()

_make_ubatch_metadata:创建 per-ubatch 上下文

for i in range(num_ubatches):
    fc = create_forward_context(
        attn_metadata[i],                     # ubatch i 的注意力元数据
        slot_mapping=slot_mapping[i],         # ubatch i 的 slot mapping
        dp_metadata=ubatch_dp_metadata[i],    # ubatch i 的数据并行元数据
        cudagraph_runtime_mode=NONE,          # ubatch 内禁用 cudagraph
    )
    contexts = make_ubatch_contexts(           # 见 ubatching.py
        num_micro_batches=len(ubatch_slices),
        forward_contexts=[fc1, fc2, ...],
        ready_barrier=self.ready_barrier,
    )

make_ubatch_contexts 创建 N 个 UBatchContext,每个包含:

  • 独立的 ForwardContext(attn_metadata、slot_mapping 等)
  • cpu_wait_event[i] = cpu_events[i]
  • cpu_signal_event[i] = cpu_events[(i+1) % N]

_run_ubatches:实际线程执行

def _run_ubatches(ubatch_metadata, model):
    ubatch_threads = []
    for metadata in ubatch_metadata:
        t = Thread(target=_ubatch_thread, args=(results, model, metadata))
        t.start()
        ubatch_threads.append(t)

    self.ready_barrier.wait()                          # ① 主线程等待所有 ubatch 线程就绪
    ubatch_metadata[0].context.cpu_wait_event.set()    # ② 唤醒 ubatch[0]
    for t in ubatch_threads:
        t.join()                                       # ③ 等待全部完成

    sort_and_cat_results()                             # ④ 排序并拼接输出

每个 ubatch 线程内部:

def _ubatch_thread(results, model, metadata):
    with metadata.context:                # 进入 UBatchContext
        # __enter__:
        #   ready_barrier.wait()          → 等待主线程 + 所有 ubatch 线程
        #   cpu_wait_event.wait()         → 等待被唤醒
        #   cpu_wait_event.clear()        → 放下自己的信号牌
        #   _restore_context()            → 恢复自己的 ForwardContext
        #   update_stream(compute_stream) → 切到 compute stream

        output = model(**sliced_inputs)   # 实际推理

        # __exit__:
        #   maybe_run_recv_hook()
        #   cpu_signal_event.set()        → 唤醒下一个 ubatch
        #   cpu_wait_event.clear()        → 防御性重置

    results.append((metadata.context.id, output))

链式唤醒时序

主线程:  启动 N 个线程 → ready_barrier.wait() → ubatch[0].set() → join...
            ↓               ↓                      ↓
ubatch[0]:  start → ready_barrier.wait() → wait(event[0]) → 被唤醒 → 跑 model → exit: set(event[1])
            ↓               ↓                          ↓                    ↑
ubatch[1]:  start → ready_barrier.wait() → wait(event[1]) → 被唤醒 → 跑 model → exit: set(event[2])
            ↓               ↓                          ↓                    ↑
ubatch[2]:  start → ready_barrier.wait() → wait(event[2]) → 被唤醒 → 跑 model → ...

ubatch 线程之间是顺序执行,不是并行的。每个 ubatch 跑完 model 后通过 cpu_signal_event.set() 唤醒下一个。


CPU 信号链 vs CUDA Stream 并行

线程级:互斥(CPU 串行)

一次只有一个 ubatch 线程在 CPU 上活跃。cpu_wait_event / cpu_signal_event 保证线程间互斥——ubatch[i] 没跑完之前,ubatch[i+1] 一定阻塞在 wait() 上。

GPU 级:并行(Stream 异步)

虽然 CPU 线程串行,但每个线程持有两个 CUDA Stream

Stream用途
compute_streamML 计算(matmul、attention 等)
comm_streamall2all 通信(dispatch 发 / combine 收)

当 ubatch[0] 在 compute stream 上做矩阵乘法时,ubatch[1] 的 dispatch 操作可能已经在 comm stream 上异步运行了——因为 CUDA stream 之间是并行的。

DBO(Double Buffered Overlap)的 Layer 级别流切换

在 MoE 的 _prepare(dispatch)和 _finalize(combine)阶段,每个 ubatch 在计算和通信之间反复切换:

_prepare() 阶段:
  compute: [matmul][matmul]...
    → dbo_yield_and_switch_from_compute_to_comm()
        ├─ compute stream record gpu_compute_done_event
        ├─ cpu_signal_event.set()         → 唤醒下一个 ubatch 线程
        ├─ 当前线程阻塞等待被唤醒
        ├─ 切到 comm_stream
        └─ comm_stream.wait_event(gpu_compute_done_event)  → 等 compute 做完
    → comm: [dispatch all2all]
    → dbo_switch_to_compute_sync()
        ├─ comm stream record gpu_comm_done_event
        ├─ 切到 compute_stream
        └─ compute_stream.wait_event(gpu_comm_done_event)  → 等 comm 做完
    → 注册 recv_hook 给下一个 ubatch
    → dbo_yield() → 唤醒下一个 ubatch

_finalize() 阶段:
  compute: [计算完成]
    → dbo_yield_and_switch_from_compute_to_comm()
    → comm: [combine all2all]
    → dbo_switch_to_compute()
    → receiver: comm copy output → yield back

关键模式是 dbo_yield_and_switch_from_*:它同时做两件事:

  1. 线程级 yieldcpu_signal_event.set() 唤醒下一个 ubatch 线程,然后 cpu_wait_event.wait() 阻塞自己
  2. Stream 级同步:在当前 stream 上 record event,切换到目标 stream 后 wait_event() 确保前面的操作已完成

双 ubatch 重叠效果

时间线 →
─────────────────────────────────────────────────────────
ubatch[0]: | compute | dispatch(yield) | comm: dispatch | compute ...
ubatch[1]: | idle    | comm: dispatch  | compute         | ...
                      ↑──── overlap ────↑

ubatch[0] 在 compute 上算完专家后 yield,ubatch[1] 被唤醒开始做 dispatch 通信,而 ubatch[0] 的 dispatch 已经在 comm stream 上异步完成了——计算和通信在 GPU 层面完全重叠


哪些 Layer 使用了 DBO

只有 MoE all2all 相关代码:

文件角色
modular_kernel.py编排 _prepare / _finalize,注册 hook、yield
deepep_ht.pyDeepEP 高吞吐:完整的 yield + stream 切换
deepep_ll.pyDeepEP 低延迟:只在 hook 层面参与
nixl_ep.pyNIXL EP:同低延迟模式

Attention layer 等非 MoE 部分不涉及 DBO 流切换。


总结

概念实现
线程间同步threading.Barrier(集结号)+ threading.Event(信号牌链)
ubatch 执行顺序串行(CPU 层面),通过 event 链 0→1→2→... 依次唤醒
GPU 并行每个 ubatch 有 compute + comm 双 stream,异步重叠
Layer 级 yielddbo_yield_and_switch_from_* 同时做线程 yield + stream 切换
适用范围MoE dispatch/combine 中的 all2all 通信