前言

经典 MoE 的计算过程

在深入 MegaMoE 之前,先梳理一下经典 MoE(Mixture of Experts)层的完整计算流程。以一个具体配置为例:

T = 8          # 当前 batch 中的 token 数
H = 7168       # hidden size
I = 2048       # intermediate size(每个 expert 的 FFN 中间维度)
E = 256        # 总 expert 数量
K = 6          # 每个 token 激活的 expert 数(top-K)

第一步:路由(Routing)

输入 hidden_states 形状为 [T, H](即 [8, 7168])。

通过 Gate 线性层:

gate = hidden_states @ W_gate^T        # W_gate: [E, H]
→ gate: [T, E] = [8, 256]              # 每个 token 对每个 expert 的得分

对每个 token 施加 scoring 函数(如 softmaxsqrt(softplus)),然后取 top-K:

scores = scoring_func(gate)             # [8, 256]
topk_weights, topk_ids = topk(scores, K)  # 各 [8, 6], int [8, 6]

这里 topk_ids[t][k] 是 token t 选中的第 k 个 expert 的全局编号,topk_weights[t][k] 是对应的权重。

第二步:Dispatch(分发)

将每个 token 的 hidden state 发送到其选中 expert 所在的设备。

在专家并行(EP)模式下,每个 rank 持有 E / ep_size 个完整的 expert,通过 all-to-all 通信将 token 分发到持有对应 expert 的 rank。分发后每个 rank 拿到自己负责的 expert 的输入。

注意区分:TP(Tensor Parallelism)是对 expert 权重做切分(每个 rank 参与全部专家但只算部分 GEMM,然后 all-reduce),EP(Expert Parallelism)是对 expert 本身做切分(每个 rank 持有不同的专家集合)。DeepSeek V4 的 MoE 使用 EP,而非 TP。

# 假设 8 个 token × 6 topk = 48 个 (token, expert) 对
# 经分发后,某 rank 上的 local expert 接收到的 token 集合
# 例如 expert 0 收到 token {1, 3, 5},expert 1 收到 token {0, 2, 4, 7},...

在 MegaMoE 中这个过程是 fused 在 kernel 内部的,但在传统实现中这是独立的 all-to-all kernel。

第三步:专家计算(Expert Computation)

每个 expert 是一个 SwiGLU FFN:

h_gate = x @ W1^T      # gate projection
h_up   = x @ W3^T      # up projection
h_act  = silu(h_gate) * h_up   # SwiGLU 激活
y      = h_act @ W2^T  # down projection

其中 W1, W3 形状为 [I, H],W2 形状为 [H, I]。很多实现会将 W1 和 W3 合并为 W13:

W13: [2*I, H] = [4096, 7168]   # gate 和 up 拼接
W2:  [H, I]   = [7168, 2048]

以 token 2 分发给 expert 7 为例:

x = hidden_states[2]              # [7168]

gate_up = x @ W13[expert_7]^T     # [4096]  ← [7168] @ [4096, 7168]^T
→ gate = gate_up[:2048]            # split: 前一半是 gate
→ up   = gate_up[2048:]           # 后一半是 up

h_act = silu(gate) * up           # [2048]
y     = h_act @ W2[expert_7]^T    # [7168]

→ final: y * topk_weight          # 加权

第四步:Combine(归约)

每个 rank 上的每个 expert 产生部分输出后,需要对每个 token 的所有 top-K expert 输出做加权求和:

# token 2 选中了 expert {3, 7, 12, 45, 128, 200}
# 每个产生 [7168] 的向量
y_2 = Σ w_k * y_{expert_k}        # 加权求和,shape [7168]

这也是一个跨 rank 的 all-to-all 通信,将分散的输出收集回原始 token 所在的 rank。

完整形状链

输入:              hidden_states [8, 7168]
路由:              gate [8, 256] → topk_weights [8, 6], topk_ids [8, 6]
分发 (per expert): 例如 expert 7 收到 3 个 token → x_expert7 [3, 7168]
L1 GEMM:           gate_up_expert7 [3, 4096] = x_expert7 @ W13[7]^T
SwiGLU:            h_act_expert7 [3, 2048] = silu(gate) * up
L2 GEMM:           y_expert7 [3, 7168] = h_act_expert7 @ W2[7]^T
Combine (per token): y_token [7168] = Σ topk_weight * y_expert
输出:              y [8, 7168]

vLLM 的两种 Activation Format

vLLM 的 MoE 模块定义了两种激活格式,决定 expert kernel 如何组织输入数据:

Format形状含义适用后端
Standard[T_local, H]标准的 2D 张量,连续存储所有 tokenDeepEP HT、FlashInfer、MORI、Triton 等
BatchedExperts[E_local, max_tokens_per_expert, H]3D 张量,按 expert 分组,每个 expert 一个 padded 的 token 块DeepEP LL、NIXL、Humming 等

Standard 格式中,dispatch 后每个 rank 拿到所有本地 expert 的 token 拼接在一起([T_local, H]),expert kernel 内部通过 sorted_token_ids 等辅助数组来确定每个 expert 处理哪段 token。这是最通用的格式。

BatchedExperts 格式中,dispatch 后 token 已经按 expert 排列成 [E_local, max_per_expert, H],不足的 padding。expert kernel 直接按 expert 维度并行处理,但需要处理 padding 引入的冗余计算。

MegaMoE 不直接使用这两种格式——它通过 symmetric buffer 使用自己的 dispatch 机制,但思想上更接近 BatchedExperts:kernel 内部维护 per-expert 的 token pool,按 BLOCK_M 对齐组织。


什么时候需要 permute/unpermute?

“Permute” 指将 token 从原始顺序按 expert 分组重排,“unpermute” 指计算完成后将结果还原回原始 token 顺序。vLLM 中有三条不同的路径:

路径 A:Standard 格式(索引间接访问)— 无显式 permute

hidden_states [T, H]  (按输入顺序)
    │
    ├── All2All dispatch → [T_local, H] (仍是连续存储)
    │
    ├── moe_align_block_size  (仅生成索引,不动数据)
    │   ├── sorted_token_ids: 按 expert 排序的虚拟 token 索引
    │   └── expert_ids: 每个 block 对应的 expert 编号
    │
    ├── fused_moe_kernel  (通过索引间接访问 hidden_states)
    │   tl.load(sorted_token_ids + pid) // top_k → 原始 token 行号
    │   → 输出 [T_local * topk, H]  (在 kernel 内部按 expert 顺序排列)
    │
    └── moe_sum / buffer.combine  (scatter-reduce 回原始顺序)
        → 输出 [T, H]

结论:无显式 permute,token 数据从未被复制移动,全靠 sorted_token_ids 做索引重映射。

路径 B:BatchedExperts 格式(显式 permute)

hidden_states [T, H]
    │
    ├── BatchedPrepare: 显式 permute
    │   for each expert: boolean mask → copy → b_a1[expert, :rows, :]
    │   → 输出 [E_local, max_per_expert, H]  (显存中按 expert 分组)
    │
    ├── batched_triton_kernel  (直接按 expert 维度取 slice)
    │   输出 [E_local, max_per_expert, H]
    │
    └── TopKWeightAndReduceNaiveBatched: 显式 unpermute
        for each expert: slice → scatter-add → output[token]
        → 输出 [T, H]

结论:显式 permute + unpermute,数据被实际搬移。

路径 C:MegaMoE(permute/unpermute fused 在核内)

hidden_states [T, H]
    │
    ├── prepare_megamoe_inputs (Triton: bf16→fp8, 不重排)
    │
    └── deep_gemm.fp8_fp4_mega_moe (单核)
        ├── Dispatch warps: TMA pull → L1 pool  ← 这是 permute,通信即重排
        ├── L1 + SwiGLU + L2
        └── Epilogue combine: TMA load → f32 acc → y[token]  ← 这是 unpermute,通信即归约

结论:permute/unpermute 不存在独立步骤,完全 fuse 在 all-to-all 通信中。

何时选择什么

条件路径
后端支持 Standard + 有 sorted_token_ids 实现无 permute(DeepEP HT、FlashInfer、Triton 等)
后端天然产生 expert 分组输出(如 DeepEP LL)显式 permute(BatchedExperts)
EP 通信和计算 fused 在单核内核内隐式(MegaMoE)
极少量 token:T * topk * 4 <= E连 sorted_token_ids 都跳过,直接用 pid_m 做 naive 分配

一、架构概览

整体思路

传统 MoE 前向的步骤:

输入 hidden_states [T, H]
  │
  ├──路由: GateLinear → TopK → scoring → topk_ids/topk_weights
  │
  ├──Dispatch (All2All): 按 topk 将 token 分发到对应 expert 所在 GPU
  │
  ├──L1 GEMM: hidden @ W13  → gate + up 两个投影
  │
  ├──SwiGLU: silu(gate) * up
  │
  ├──L2 GEMM: intermediate @ W2 → down 投影
  │
  └──Combine (All2All Reduce): 收集各 expert 输出,加权求和

传统实现至少有 6~8 个独立的 kernel launch(dispatch kernel、两个 GEMM kernel、activation kernel、combine kernel、若干通信 barrier 同步 kernel),每一步都要读写显存,通信和计算串行执行。

MegaMoE 的做法:单个 persistent CUDA kernel,一个 launch 完成全部。

线程分工

每个 SM 运行一个 persistent CTA(grid size = SM 数量),CTA 内有 4 组不同角色的 warp:

Warp 角色线程数职责
Dispatch Warps128+管理 all-to-all 通信:统计 expert 接收 token 数、通过 NVLink 写 source indices、Grid Sync + NVLink Barrier 跨所有 rank 同步、从远端 rank 的对称缓冲区 pull token 数据到本地 L1 pool
MMA Load Warp (A)32TMA 加载激活矩阵 (FP8 token + UE8M0 scales) → shared memory
MMA Load Warp (B)32TMA 加载权重矩阵 (FP4 weights + UE8M0 scales) → shared memory
MMA Issue Warp32通过 UTCCP 将 scale factor 从 shared memory 拷贝到 TMEM,然后发出 UMMA block-scaled MMA 指令 (FP8×FP4)
Epilogue Warps128+L1 完成时:从 TMEM 读取结果,执行 SwiGLU + FP8 requant,写入 L2 pool;L2 完成时:将结果转为 BF16 通过 NVLink 写到远端的 combine buffer,再从各远端 load 贡献值,float32 累加,BF16 写回输出

TMEM(Tensor Memory)是 Blackwell SM100 新增的片上内存,专为 Tensor Core UMMA 指令服务,线程不可直接读写,仅通过 UMMA/UTCCP 指令访问。


二、Fused 了什么——远远不止 3 个 kernel

MegaMoE 把以下 6 个阶段 全部融合进一个 kernel:

阶段 1:Dispatch(路由分发)

  • 每个 dispatch warp 遍历自己负责的 token-topk 对,atomicAdd_block 统计每个 expert 的 token 数
  • 将 per-expert 计数写到远端 rank 的 workspace(NVLink 原子操作)
  • 写入 source metadata(远端 rank 需要用这个来知道每个 token 来自哪里,以便 combine 阶段写回)
  • Grid Sync + NVLink Barrier:所有 SM 在各自 rank 内同步,然后所有 rank 跨 NVLink 同步
  • TMA pull:用 round-robin 策略从各远端 rank 的 symmetric buffer 拉取 FP8 token + scale 到本地的 L1 token pool(TMA 异步拷贝 + mbarrier 同步)

阶段 2:L1 GEMM

pool_tokens [pool_M, H] FP8 @ W13 [2*I, H] FP4^T → TMEM [pool_M, 2*I]
  • 使用 Blackwell 的 tcgen05 UMMA block-scaled 指令
  • A 矩阵(activation)= FP8 E4M3,B 矩阵(weight)= FP4 E2M1(MXFP4 packed)
  • 流水线:TMA load A + SFA → shared memory → UTCCP → TMEM → MMA → TMEM accum
  • 双 buffer pipeline(kNumStages 个 stage),通过 full/empty barrier 同步

阶段 3:SwiGLU(在 TMEM 中原地完成)

output = silu(gate) * up * topk_weight
  • 权重是 interleaved 存储的(gate[0..7], up[0..7], gate[8..15], up[8..15], …)
  • 使用 SM100_TMEM_LOAD_16dp256b1x 指令从 TMEM 加载 gate/up 对
  • 计算 silu(gate) × up × topk_weight
  • 计算 amax → UE8M0 scale → 量化到 FP8 E4M3
  • TMA store 到 L2 token pool(复用 L1 pool 的空间)
  • 这个过程不需要离开 TMEM,也不需要写回显存再读取

阶段 4:L2 GEMM

pool_tokens [pool_M, I] FP8 @ W2 [H, I] FP4^T → TMEM [pool_M, H]
  • 与 L1 相同模式,但依赖 L1 SwiGLU 输出的 arrival mask
  • 由于 L1 的输出 N 维度是 BLOCK_K/2(因为 SwiGLU 把 2*I 减半为 I),L2 需要等待 2 倍 L1 block 的 arrival

阶段 5:Combine(结果归约)

  • L2 完成后,epilogue warp 从 TMEM 读取结果
  • 转为 BF16,通过 NVLink 写到各远端 rank 的 combine buffer
  • 然后对于本地的每个 token,读取 topk indices,从各远端 rank 的 combine buffer TMA load 对应贡献值
  • float32 accumulate 所有 top-k 贡献
  • 转为 BF16,TMA store 到输出 y[token_idx]

阶段 6:Workspace 清理

  • Dispatch warps 在 combine 结束后,清理 per-expert counts、arrival masks 等
  • 再次 NVLink Barrier 确保所有 rank 的 cleanup 完成

三、数据流图

Hidden states [T, H] bf16, topk_weights [T, K], topk_ids [T, K] int64
    │
    │  [外部 Triton kernel] prepare_megamoe_inputs
    │    └─ bf16 → FP8 E4M3 + UE8M0 scales (group=32)
    │    └─ 4 scales 打包为 uint32
    │    └─ topk_ids/weights → int64/float32
    ↓
SymmBuffer: x [T, H] fp8, x_sf [T, H/128] i32,
            topk_idx [T, K] i64, topk_weights [T, K] f32
    │
    │  [单个 persistent kernel] deep_gemm.fp8_fp4_mega_moe
    │
    ├── DISPATCH WARPS
    │   ├── atomic 统计 per-expert token 数 → workspace
    │   ├── NVLink 写 source indices 到远端
    │   ├── Grid Sync + NVLink Barrier (跨所有 rank)
    │   ├── TMA round-robin pull 远端 token → L1 pool
    │   └── 写入 topk_weight 和 source metadata
    │
    ├── MMA WARPS (pipelined)
    │   ├── L1: TMA load A(fp8) + B(fp4) → UMMA → TMEM
    │   ├── [Epilogue] TMEM → silu(gate)*up*weight → fp8 → L2 pool
    │   └── L2: TMA load A(fp8) + B(fp4) → UMMA → TMEM
    │
    └── EPILOGUE WARPS (combine)
        ├── TMEM → bf16 → NVLink 写远端 combine buffer
        ├── TMA load 各远端贡献 → f32 accumulate
        └── bf16 → TMA store → y [T, H]

四、关键数据结构

对称缓冲区(Symmetric Buffer)

所有参与 EP 的 rank 分配一块相同大小的缓冲区,包含:

  • x / x_sf:输入 token 的 FP8 数据和 UE8M0 scales
  • topk_idx / topk_weights:路由结果
  • Expert send/receive counts(NVLink 可原子访问)
  • Source token-topk indices(dispatch 时写入,combine 时读取)
  • L1/L2 到达计数和 mask
  • Combine buffer(每个 token 的 top-k 贡献)

对称缓冲区使用 get_symm_buffer_for_mega_moe 一次性预分配,按 (max_num_tokens, num_experts, topk, hidden, intermediate) 参数缓存。运行时按实际 num_tokens 切片使用。

TMEM Allocation

每 2 个 CTA(一个 cluster)共享 TMEM,分为:

  • kNumAccumTmemCols:UMMA 累加结果
  • kNumSFATmemCols:SFA 的 scale factor(UTCCP 每 4 col 一组)
  • kNumSFBTmemCols:SFB 的 scale factor

通过 Allocator2Sm 分配,每个 warpgroup 使用自己的切片。

Workspace Layout

在 symmetric buffer 的基础上组织各 rank 的 per-expert 管理数据:

  • expert_send_count[global_expert_id]:每个 global expert 已分配的 token 数(NVLink atomic)
  • expert_recv_count[rank][local_expert_id]:某 rank 上的某 local expert 预期接收 token 数
  • src_token_topk_idx[local_expert_id][rank][slot]:source token-topk 索引(dispatch 写,combine 读)
  • token_src_metadata[pool_token_idx]:pool 中 token 来自哪个 rank、原始 token 索引、topk 索引
  • l1_arrival_count[pool_block]:L1 流水线等待 token 到达计数器
  • l2_arrival_mask[pool_block]:L2 等待 L1 SwiGLU 完成的 bitmask

五、为什么能这么高效

1. 融合通信与计算

传统 MoE 中,dispatch kernel 完成 all-to-all 后显存写回,然后 GEMM kernel 再读入。MegaMoE 中 dispatch warps 直接 TMA pull token 到 L1 pool,其后的 MMA warps 立即消费。Combine 同理,epilogue warps 计算完直接 NVLink 写到远端,再 TMA load 回来归约。通信和计算完全 overlap。

2. TMEM 消除中间显存读写

L1 GEMM 结果在 TMEM 中,SwiGLU 直接在 TMEM 中计算,结果 FP8 requant 后 TMA store 到 L2 pool。L2 GEMM 再从 L2 pool 读入。SwiGLU 的中间结果不需要写回显存再读取,省去了约 2× 的显存带宽。

3. 单核调度开销极低

传统多 kernel 方案中,每个 kernel launch 有微秒级的启动延迟,且中间 tensors 需要在显存和寄存器之间反复搬运。MegaMoE 使用 persistent kernel 模式(所有 SM 持续运行一个 kernel,使用 wave 调度器分批处理 token),launch 开销和 barrier 同步开销降到了最低。

4. Wave 调度器

MegaMoEScheduler 根据 num_tokensnum_experts 自动确定 wave 大小(每 wave 处理 num_experts_per_wave 个 expert 的 BLOCK_M 个 token),多 SM 协作处理各 wave。scheduler 寄存器中缓存了各 expert 的 recv_count,避免重复全局加载。

5. FP4 权重存储减半

FP4 权重(每 byte 存 2 个 4-bit 浮点值)配合 block=32 的 UE8M0 scale,相比 FP8 的权重存储再减半,且 FP4 E2M1 格式在 4-bit 精度下保持了较好的数值范围。

6. TMA 与流水线

所有数据搬运使用 TMA(Tensor Memory Accelerator),不占用线程资源。双 buffer pipeline 通过 full/empty barrier 同步,有效地隐藏了访存延迟。


六、限制与代价

限制原因
仅 Blackwell SM100依赖 UMMA + TMEM + UTCCP 等 Blackwell 新指令
必须开启 Expert Parallelism核内硬编码了跨 rank NVLink 通信逻辑
hidden/intermediate 必须 128 对齐TMA 描述子和 block 划分要求
仅支持 sqrtsoftplus 路由与其他 scoring_func 的数值精度和 topk 语义未对齐
仅支持 FP4 expert dtype核内硬编码 FP8×FP4 UMMA 描述子
代码复杂度极高单核 ~1400 行 CUDA 模板,调度器、布局、通信逻辑高度耦合

七、总结

维度传统 MoE 实现MegaMoE
Kernel 数量6~8+ 个 launch1 个 persistent kernel
SwiGLU 中间显存需要写回再读取TMEM 原地计算
Dispatch/Combine 与计算串行完全重叠
数据搬运多轮显存读写TMA 异步 + zero-copy
权重精度FP8FP4(存储减半)
适用 GPUH100+B200/Blackwell 独占

MegaMoE 通过极致的算子融合和硬件特性利用,将 MoE 推理的延迟推到了接近单个 GEMM kernel 的理论下限,代价是完全绑定 Blackwell 硬件和 Expert Parallelism 拓扑。对于运行 DeepSeek V4 的 B200 集群而言,这是当前最高效的推理方案。