前言
经典 MoE 的计算过程
在深入 MegaMoE 之前,先梳理一下经典 MoE(Mixture of Experts)层的完整计算流程。以一个具体配置为例:
T = 8 # 当前 batch 中的 token 数
H = 7168 # hidden size
I = 2048 # intermediate size(每个 expert 的 FFN 中间维度)
E = 256 # 总 expert 数量
K = 6 # 每个 token 激活的 expert 数(top-K)
第一步:路由(Routing)
输入 hidden_states 形状为 [T, H](即 [8, 7168])。
通过 Gate 线性层:
gate = hidden_states @ W_gate^T # W_gate: [E, H]
→ gate: [T, E] = [8, 256] # 每个 token 对每个 expert 的得分
对每个 token 施加 scoring 函数(如 softmax 或 sqrt(softplus)),然后取 top-K:
scores = scoring_func(gate) # [8, 256]
topk_weights, topk_ids = topk(scores, K) # 各 [8, 6], int [8, 6]
这里 topk_ids[t][k] 是 token t 选中的第 k 个 expert 的全局编号,topk_weights[t][k] 是对应的权重。
第二步:Dispatch(分发)
将每个 token 的 hidden state 发送到其选中 expert 所在的设备。
在专家并行(EP)模式下,每个 rank 持有 E / ep_size 个完整的 expert,通过 all-to-all 通信将 token 分发到持有对应 expert 的 rank。分发后每个 rank 拿到自己负责的 expert 的输入。
注意区分:TP(Tensor Parallelism)是对 expert 权重做切分(每个 rank 参与全部专家但只算部分 GEMM,然后 all-reduce),EP(Expert Parallelism)是对 expert 本身做切分(每个 rank 持有不同的专家集合)。DeepSeek V4 的 MoE 使用 EP,而非 TP。
# 假设 8 个 token × 6 topk = 48 个 (token, expert) 对
# 经分发后,某 rank 上的 local expert 接收到的 token 集合
# 例如 expert 0 收到 token {1, 3, 5},expert 1 收到 token {0, 2, 4, 7},...
在 MegaMoE 中这个过程是 fused 在 kernel 内部的,但在传统实现中这是独立的 all-to-all kernel。
第三步:专家计算(Expert Computation)
每个 expert 是一个 SwiGLU FFN:
h_gate = x @ W1^T # gate projection
h_up = x @ W3^T # up projection
h_act = silu(h_gate) * h_up # SwiGLU 激活
y = h_act @ W2^T # down projection
其中 W1, W3 形状为 [I, H],W2 形状为 [H, I]。很多实现会将 W1 和 W3 合并为 W13:
W13: [2*I, H] = [4096, 7168] # gate 和 up 拼接
W2: [H, I] = [7168, 2048]
以 token 2 分发给 expert 7 为例:
x = hidden_states[2] # [7168]
gate_up = x @ W13[expert_7]^T # [4096] ← [7168] @ [4096, 7168]^T
→ gate = gate_up[:2048] # split: 前一半是 gate
→ up = gate_up[2048:] # 后一半是 up
h_act = silu(gate) * up # [2048]
y = h_act @ W2[expert_7]^T # [7168]
→ final: y * topk_weight # 加权
第四步:Combine(归约)
每个 rank 上的每个 expert 产生部分输出后,需要对每个 token 的所有 top-K expert 输出做加权求和:
# token 2 选中了 expert {3, 7, 12, 45, 128, 200}
# 每个产生 [7168] 的向量
y_2 = Σ w_k * y_{expert_k} # 加权求和,shape [7168]
这也是一个跨 rank 的 all-to-all 通信,将分散的输出收集回原始 token 所在的 rank。
完整形状链
输入: hidden_states [8, 7168]
路由: gate [8, 256] → topk_weights [8, 6], topk_ids [8, 6]
分发 (per expert): 例如 expert 7 收到 3 个 token → x_expert7 [3, 7168]
L1 GEMM: gate_up_expert7 [3, 4096] = x_expert7 @ W13[7]^T
SwiGLU: h_act_expert7 [3, 2048] = silu(gate) * up
L2 GEMM: y_expert7 [3, 7168] = h_act_expert7 @ W2[7]^T
Combine (per token): y_token [7168] = Σ topk_weight * y_expert
输出: y [8, 7168]
vLLM 的两种 Activation Format
vLLM 的 MoE 模块定义了两种激活格式,决定 expert kernel 如何组织输入数据:
| Format | 形状 | 含义 | 适用后端 |
|---|---|---|---|
| Standard | [T_local, H] | 标准的 2D 张量,连续存储所有 token | DeepEP HT、FlashInfer、MORI、Triton 等 |
| BatchedExperts | [E_local, max_tokens_per_expert, H] | 3D 张量,按 expert 分组,每个 expert 一个 padded 的 token 块 | DeepEP LL、NIXL、Humming 等 |
Standard 格式中,dispatch 后每个 rank 拿到所有本地 expert 的 token 拼接在一起([T_local, H]),expert kernel 内部通过 sorted_token_ids 等辅助数组来确定每个 expert 处理哪段 token。这是最通用的格式。
BatchedExperts 格式中,dispatch 后 token 已经按 expert 排列成 [E_local, max_per_expert, H],不足的 padding。expert kernel 直接按 expert 维度并行处理,但需要处理 padding 引入的冗余计算。
MegaMoE 不直接使用这两种格式——它通过 symmetric buffer 使用自己的 dispatch 机制,但思想上更接近 BatchedExperts:kernel 内部维护 per-expert 的 token pool,按 BLOCK_M 对齐组织。
什么时候需要 permute/unpermute?
“Permute” 指将 token 从原始顺序按 expert 分组重排,“unpermute” 指计算完成后将结果还原回原始 token 顺序。vLLM 中有三条不同的路径:
路径 A:Standard 格式(索引间接访问)— 无显式 permute
hidden_states [T, H] (按输入顺序)
│
├── All2All dispatch → [T_local, H] (仍是连续存储)
│
├── moe_align_block_size (仅生成索引,不动数据)
│ ├── sorted_token_ids: 按 expert 排序的虚拟 token 索引
│ └── expert_ids: 每个 block 对应的 expert 编号
│
├── fused_moe_kernel (通过索引间接访问 hidden_states)
│ tl.load(sorted_token_ids + pid) // top_k → 原始 token 行号
│ → 输出 [T_local * topk, H] (在 kernel 内部按 expert 顺序排列)
│
└── moe_sum / buffer.combine (scatter-reduce 回原始顺序)
→ 输出 [T, H]
结论:无显式 permute,token 数据从未被复制移动,全靠 sorted_token_ids 做索引重映射。
路径 B:BatchedExperts 格式(显式 permute)
hidden_states [T, H]
│
├── BatchedPrepare: 显式 permute
│ for each expert: boolean mask → copy → b_a1[expert, :rows, :]
│ → 输出 [E_local, max_per_expert, H] (显存中按 expert 分组)
│
├── batched_triton_kernel (直接按 expert 维度取 slice)
│ 输出 [E_local, max_per_expert, H]
│
└── TopKWeightAndReduceNaiveBatched: 显式 unpermute
for each expert: slice → scatter-add → output[token]
→ 输出 [T, H]
结论:显式 permute + unpermute,数据被实际搬移。
路径 C:MegaMoE(permute/unpermute fused 在核内)
hidden_states [T, H]
│
├── prepare_megamoe_inputs (Triton: bf16→fp8, 不重排)
│
└── deep_gemm.fp8_fp4_mega_moe (单核)
├── Dispatch warps: TMA pull → L1 pool ← 这是 permute,通信即重排
├── L1 + SwiGLU + L2
└── Epilogue combine: TMA load → f32 acc → y[token] ← 这是 unpermute,通信即归约
结论:permute/unpermute 不存在独立步骤,完全 fuse 在 all-to-all 通信中。
何时选择什么
| 条件 | 路径 |
|---|---|
后端支持 Standard + 有 sorted_token_ids 实现 | 无 permute(DeepEP HT、FlashInfer、Triton 等) |
| 后端天然产生 expert 分组输出(如 DeepEP LL) | 显式 permute(BatchedExperts) |
| EP 通信和计算 fused 在单核内 | 核内隐式(MegaMoE) |
极少量 token:T * topk * 4 <= E | 连 sorted_token_ids 都跳过,直接用 pid_m 做 naive 分配 |
一、架构概览
整体思路
传统 MoE 前向的步骤:
输入 hidden_states [T, H]
│
├──路由: GateLinear → TopK → scoring → topk_ids/topk_weights
│
├──Dispatch (All2All): 按 topk 将 token 分发到对应 expert 所在 GPU
│
├──L1 GEMM: hidden @ W13 → gate + up 两个投影
│
├──SwiGLU: silu(gate) * up
│
├──L2 GEMM: intermediate @ W2 → down 投影
│
└──Combine (All2All Reduce): 收集各 expert 输出,加权求和
传统实现至少有 6~8 个独立的 kernel launch(dispatch kernel、两个 GEMM kernel、activation kernel、combine kernel、若干通信 barrier 同步 kernel),每一步都要读写显存,通信和计算串行执行。
MegaMoE 的做法:单个 persistent CUDA kernel,一个 launch 完成全部。
线程分工
每个 SM 运行一个 persistent CTA(grid size = SM 数量),CTA 内有 4 组不同角色的 warp:
| Warp 角色 | 线程数 | 职责 |
|---|---|---|
| Dispatch Warps | 128+ | 管理 all-to-all 通信:统计 expert 接收 token 数、通过 NVLink 写 source indices、Grid Sync + NVLink Barrier 跨所有 rank 同步、从远端 rank 的对称缓冲区 pull token 数据到本地 L1 pool |
| MMA Load Warp (A) | 32 | TMA 加载激活矩阵 (FP8 token + UE8M0 scales) → shared memory |
| MMA Load Warp (B) | 32 | TMA 加载权重矩阵 (FP4 weights + UE8M0 scales) → shared memory |
| MMA Issue Warp | 32 | 通过 UTCCP 将 scale factor 从 shared memory 拷贝到 TMEM,然后发出 UMMA block-scaled MMA 指令 (FP8×FP4) |
| Epilogue Warps | 128+ | L1 完成时:从 TMEM 读取结果,执行 SwiGLU + FP8 requant,写入 L2 pool;L2 完成时:将结果转为 BF16 通过 NVLink 写到远端的 combine buffer,再从各远端 load 贡献值,float32 累加,BF16 写回输出 |
TMEM(Tensor Memory)是 Blackwell SM100 新增的片上内存,专为 Tensor Core UMMA 指令服务,线程不可直接读写,仅通过 UMMA/UTCCP 指令访问。
二、Fused 了什么——远远不止 3 个 kernel
MegaMoE 把以下 6 个阶段 全部融合进一个 kernel:
阶段 1:Dispatch(路由分发)
- 每个 dispatch warp 遍历自己负责的 token-topk 对,
atomicAdd_block统计每个 expert 的 token 数 - 将 per-expert 计数写到远端 rank 的 workspace(NVLink 原子操作)
- 写入 source metadata(远端 rank 需要用这个来知道每个 token 来自哪里,以便 combine 阶段写回)
- Grid Sync + NVLink Barrier:所有 SM 在各自 rank 内同步,然后所有 rank 跨 NVLink 同步
- TMA pull:用 round-robin 策略从各远端 rank 的 symmetric buffer 拉取 FP8 token + scale 到本地的 L1 token pool(TMA 异步拷贝 + mbarrier 同步)
阶段 2:L1 GEMM
pool_tokens [pool_M, H] FP8 @ W13 [2*I, H] FP4^T → TMEM [pool_M, 2*I]
- 使用 Blackwell 的
tcgen05UMMA block-scaled 指令 - A 矩阵(activation)= FP8 E4M3,B 矩阵(weight)= FP4 E2M1(MXFP4 packed)
- 流水线:TMA load A + SFA → shared memory → UTCCP → TMEM → MMA → TMEM accum
- 双 buffer pipeline(kNumStages 个 stage),通过 full/empty barrier 同步
阶段 3:SwiGLU(在 TMEM 中原地完成)
output = silu(gate) * up * topk_weight
- 权重是 interleaved 存储的(gate[0..7], up[0..7], gate[8..15], up[8..15], …)
- 使用
SM100_TMEM_LOAD_16dp256b1x指令从 TMEM 加载 gate/up 对 - 计算 silu(gate) × up × topk_weight
- 计算 amax → UE8M0 scale → 量化到 FP8 E4M3
- TMA store 到 L2 token pool(复用 L1 pool 的空间)
- 这个过程不需要离开 TMEM,也不需要写回显存再读取
阶段 4:L2 GEMM
pool_tokens [pool_M, I] FP8 @ W2 [H, I] FP4^T → TMEM [pool_M, H]
- 与 L1 相同模式,但依赖 L1 SwiGLU 输出的 arrival mask
- 由于 L1 的输出 N 维度是
BLOCK_K/2(因为 SwiGLU 把 2*I 减半为 I),L2 需要等待 2 倍 L1 block 的 arrival
阶段 5:Combine(结果归约)
- L2 完成后,epilogue warp 从 TMEM 读取结果
- 转为 BF16,通过 NVLink 写到各远端 rank 的 combine buffer
- 然后对于本地的每个 token,读取 topk indices,从各远端 rank 的 combine buffer TMA load 对应贡献值
- float32 accumulate 所有 top-k 贡献
- 转为 BF16,TMA store 到输出
y[token_idx]
阶段 6:Workspace 清理
- Dispatch warps 在 combine 结束后,清理 per-expert counts、arrival masks 等
- 再次 NVLink Barrier 确保所有 rank 的 cleanup 完成
三、数据流图
Hidden states [T, H] bf16, topk_weights [T, K], topk_ids [T, K] int64
│
│ [外部 Triton kernel] prepare_megamoe_inputs
│ └─ bf16 → FP8 E4M3 + UE8M0 scales (group=32)
│ └─ 4 scales 打包为 uint32
│ └─ topk_ids/weights → int64/float32
↓
SymmBuffer: x [T, H] fp8, x_sf [T, H/128] i32,
topk_idx [T, K] i64, topk_weights [T, K] f32
│
│ [单个 persistent kernel] deep_gemm.fp8_fp4_mega_moe
│
├── DISPATCH WARPS
│ ├── atomic 统计 per-expert token 数 → workspace
│ ├── NVLink 写 source indices 到远端
│ ├── Grid Sync + NVLink Barrier (跨所有 rank)
│ ├── TMA round-robin pull 远端 token → L1 pool
│ └── 写入 topk_weight 和 source metadata
│
├── MMA WARPS (pipelined)
│ ├── L1: TMA load A(fp8) + B(fp4) → UMMA → TMEM
│ ├── [Epilogue] TMEM → silu(gate)*up*weight → fp8 → L2 pool
│ └── L2: TMA load A(fp8) + B(fp4) → UMMA → TMEM
│
└── EPILOGUE WARPS (combine)
├── TMEM → bf16 → NVLink 写远端 combine buffer
├── TMA load 各远端贡献 → f32 accumulate
└── bf16 → TMA store → y [T, H]
四、关键数据结构
对称缓冲区(Symmetric Buffer)
所有参与 EP 的 rank 分配一块相同大小的缓冲区,包含:
x/x_sf:输入 token 的 FP8 数据和 UE8M0 scalestopk_idx/topk_weights:路由结果- Expert send/receive counts(NVLink 可原子访问)
- Source token-topk indices(dispatch 时写入,combine 时读取)
- L1/L2 到达计数和 mask
- Combine buffer(每个 token 的 top-k 贡献)
对称缓冲区使用 get_symm_buffer_for_mega_moe 一次性预分配,按 (max_num_tokens, num_experts, topk, hidden, intermediate) 参数缓存。运行时按实际 num_tokens 切片使用。
TMEM Allocation
每 2 个 CTA(一个 cluster)共享 TMEM,分为:
kNumAccumTmemCols:UMMA 累加结果kNumSFATmemCols:SFA 的 scale factor(UTCCP 每 4 col 一组)kNumSFBTmemCols:SFB 的 scale factor
通过 Allocator2Sm 分配,每个 warpgroup 使用自己的切片。
Workspace Layout
在 symmetric buffer 的基础上组织各 rank 的 per-expert 管理数据:
expert_send_count[global_expert_id]:每个 global expert 已分配的 token 数(NVLink atomic)expert_recv_count[rank][local_expert_id]:某 rank 上的某 local expert 预期接收 token 数src_token_topk_idx[local_expert_id][rank][slot]:source token-topk 索引(dispatch 写,combine 读)token_src_metadata[pool_token_idx]:pool 中 token 来自哪个 rank、原始 token 索引、topk 索引l1_arrival_count[pool_block]:L1 流水线等待 token 到达计数器l2_arrival_mask[pool_block]:L2 等待 L1 SwiGLU 完成的 bitmask
五、为什么能这么高效
1. 融合通信与计算
传统 MoE 中,dispatch kernel 完成 all-to-all 后显存写回,然后 GEMM kernel 再读入。MegaMoE 中 dispatch warps 直接 TMA pull token 到 L1 pool,其后的 MMA warps 立即消费。Combine 同理,epilogue warps 计算完直接 NVLink 写到远端,再 TMA load 回来归约。通信和计算完全 overlap。
2. TMEM 消除中间显存读写
L1 GEMM 结果在 TMEM 中,SwiGLU 直接在 TMEM 中计算,结果 FP8 requant 后 TMA store 到 L2 pool。L2 GEMM 再从 L2 pool 读入。SwiGLU 的中间结果不需要写回显存再读取,省去了约 2× 的显存带宽。
3. 单核调度开销极低
传统多 kernel 方案中,每个 kernel launch 有微秒级的启动延迟,且中间 tensors 需要在显存和寄存器之间反复搬运。MegaMoE 使用 persistent kernel 模式(所有 SM 持续运行一个 kernel,使用 wave 调度器分批处理 token),launch 开销和 barrier 同步开销降到了最低。
4. Wave 调度器
MegaMoEScheduler 根据 num_tokens 和 num_experts 自动确定 wave 大小(每 wave 处理 num_experts_per_wave 个 expert 的 BLOCK_M 个 token),多 SM 协作处理各 wave。scheduler 寄存器中缓存了各 expert 的 recv_count,避免重复全局加载。
5. FP4 权重存储减半
FP4 权重(每 byte 存 2 个 4-bit 浮点值)配合 block=32 的 UE8M0 scale,相比 FP8 的权重存储再减半,且 FP4 E2M1 格式在 4-bit 精度下保持了较好的数值范围。
6. TMA 与流水线
所有数据搬运使用 TMA(Tensor Memory Accelerator),不占用线程资源。双 buffer pipeline 通过 full/empty barrier 同步,有效地隐藏了访存延迟。
六、限制与代价
| 限制 | 原因 |
|---|---|
| 仅 Blackwell SM100 | 依赖 UMMA + TMEM + UTCCP 等 Blackwell 新指令 |
| 必须开启 Expert Parallelism | 核内硬编码了跨 rank NVLink 通信逻辑 |
| hidden/intermediate 必须 128 对齐 | TMA 描述子和 block 划分要求 |
| 仅支持 sqrtsoftplus 路由 | 与其他 scoring_func 的数值精度和 topk 语义未对齐 |
| 仅支持 FP4 expert dtype | 核内硬编码 FP8×FP4 UMMA 描述子 |
| 代码复杂度极高 | 单核 ~1400 行 CUDA 模板,调度器、布局、通信逻辑高度耦合 |
七、总结
| 维度 | 传统 MoE 实现 | MegaMoE |
|---|---|---|
| Kernel 数量 | 6~8+ 个 launch | 1 个 persistent kernel |
| SwiGLU 中间显存 | 需要写回再读取 | TMEM 原地计算 |
| Dispatch/Combine 与计算 | 串行 | 完全重叠 |
| 数据搬运 | 多轮显存读写 | TMA 异步 + zero-copy |
| 权重精度 | FP8 | FP4(存储减半) |
| 适用 GPU | H100+ | B200/Blackwell 独占 |
MegaMoE 通过极致的算子融合和硬件特性利用,将 MoE 推理的延迟推到了接近单个 GEMM kernel 的理论下限,代价是完全绑定 Blackwell 硬件和 Expert Parallelism 拓扑。对于运行 DeepSeek V4 的 B200 集群而言,这是当前最高效的推理方案。