前言
本文整理自一次围绕 vLLM 代码库中 DeepSeek V4 MoE 模块的技术讨论,内容涉及 MXFP4 与 NVFP4 的量化方案对比、Block Quantized GEMM 的设计原理、FP4 packed 存储格式、以及 DeepGEMM 库中 FP8×FP4 在 Blackwell 硬件上的具体实现。
一、DeepSeek V4 MoE 核心优化概览
DeepSeek V4 的 MoE 模块在 vLLM 中的实现包含了大量优化:
| 优化 | 说明 |
|---|---|
| DeepGEMM MegaMoE | 融合 EP dispatch + L1 GEMM + SwiGLU + L2 GEMM + EP combine 为单 mega-kernel,NVLink 通信与计算重叠 |
| FP4 (MXFP4/NVFP4) 权重量化 | 4-bit 浮点权重 + UE8M0 block scale |
| Expert Parallelism 多后端 | DeepEP、FlashInfer NVLink、MORI、NIXL 等多种 all-to-all 策略 |
| Fused TopK Bias Routing | sqrt(softplus) 得分函数、e_score_correction_bias、hash MoE |
| EPLB | 每层跟踪 expert 负载,动态重新分配 |
| Fused MLA Kernel | Q-norm + RoPE + KV quant + cache insert 融合为单 CUDA 核 |
| MTP (Multi-Token Prediction) | 共享 MoE 架构的 speculative decoding |
二、MXFP4 与 NVFP4 的区别
DeepSeek V4 Flash 使用 FP4 权重,有两个可选方案:MXFP4 (OCP 开放标准) 和 NVFP4 (NVIDIA 私有格式)。切换由 HuggingFace config 中的 moe_quant_algo 字段控制。
2.1 核心差异
| 维度 | MXFP4 | NVFP4 |
|---|---|---|
| 标准来源 | OCP MX (Microscaling) | NVIDIA ModelOpt |
| 量化粒度 | 每 32 个元素共享一个 scale | 每 16 个元素共享一个 scale |
| Scale 数据类型 | uint8 (UE8M0,纯指数) | float8_e4m3fn + float32 全局 scale |
| 额外全局 scale | 无 | 有 (weight_scale_2 + input_scale) |
| 典型量化方案 | W4A8 (FP4 权重 + FP8 激活) | W4A4 (FP4 权重 + FP4 激活) |
| Act 量化 block | 128 元素 (per-token-group) | 16 元素 |
| 硬件要求 | SM100 Blackwell (最优路径) | SM100 Blackwell |
| DeepSeek V4 默认 | 是 (DeepSeek-V4-Flash) | 仅 moe_quant_algo=“NVFP4” 时 |
2.2 MXFP4 的 scale 格式:UE8M0
UE8M0 是一种纯指数的 8-bit 编码:
scale = 2^(uint8_value - 127)
没有符号位,没有尾数,8 位全是指数
可表示范围: 2^(-127) ~ 2^128
对应 weight scale shape:
w13_weight_scale: (num_experts, 2*intermediate_size, hidden_size // 32)
每个 uint8 代表 32 个 FP4 权重的共享 scale
2.3 NVFP4 的双层 scale
NVFP4 使用两层 scale:
block_scale = float8_e4m3fn 每 16 个元素一个
global_scale = float32 per-tensor 或 per-output
最终 scale = block_scale x global_scale
global scale 和 input scale 在 weight 加载时融合:
# process_weights_after_loading
w13_weight_scale_2 *= w13_input_scale
三、Block Quantized GEMM 的设计原理
3.1 为什么只在 K (reduction) 维分组?
核心在于 GEMM 的语义:
C[m, n] = sum_k A[m, k] * W[n, k]
K 是 reduction 维。当权重量化为 FP4 + block scale 后:
W[n, k] 约等于 qW[n, k] * sW[n, floor(k/B)]
C[m, n] = sum_b sW[n,b] * sum_{k=bB} A[m,k] * qW[n,k]
B=32 时,dequant 操作从 O(K) 降到 O(K/32)。scale 是块的常数乘因子,可以提取到内层循环外部。
3.2 为什么不在 N 或 M 维分组?
数学上可行,但精度不可接受:
M-only 分组 (激活整个 hidden 向量共享 1 个 scale):
一个 token 的 hidden state 有 outlier 12.456 和大量正常值 -0.002
outlier 决定全局 scale,正常元素被压缩到几乎没有有效精度
N-only 分组 (多个输出神经元共享 1 个 scale):
gate 门控权重在 [-0.5, 0.5],up 投影权重在 [-5.0, 5.0]
强行共享 scale 导致细值域行精度全部浪费
结论:K 分组不是设计权衡,是 FP4 量化精度可用的必要条件。
3.3 实际上是 2D blocking
很多人误以为"只在 K 维分组"。实际 weight scale shape 是:
(N, K // B)
^ └── 每 B 元素共享一个 scale
└── 每个输出 N 行有独立的 scale 向量
这是 N 维 dense、K 维 blocked 的 2D 分组。每行独立 scale 的存储成本(N * K/B 个 uint8,约 3% 额外参数)远小于精度收益。
四、FP4 打包存储:为什么 shape 是 hidden // 2
4.1 PyTorch 的 dtype 限制
# PyTorch 没有 4-bit dtype,最小可寻址单位是 1 byte
# 两个 FP4 nibble (4-bit each) pack 进一个 uint8
想要: tensor of shape (..., 7168), dtype=虚拟fp4, itemsize=0.5
实际: tensor of shape (..., 3584), dtype=uint8, itemsize=1
每个 byte 的布局:
byte[n]: [a3 a2 a1 a0 | b3 b2 b1 b0]
^--- 值 B ---^ ^--- 值 A ---^
4.2 Kernel 层面的不同视角
| 层级 | 存储/传递 | 理解方式 |
|---|---|---|
| PyTorch Python | tensor.view(torch.uint8), shape K//2 | 当 int8 传指针 |
| 调 DeepGEMM/TRTLLM | int8* 传 C API | 当 int8 传,API 内部处理 |
| 自写 CUDA/Triton | __ldg((uint4*)ptr) 手动拆 nibble | 必须按 E2M1 查表解码 |
绝对不能做的事:
uint8_t nibble = packed_byte & 0x0F;
int8_t val = nibble; // 错误!FP4 0x2=1.0,当 int 是 2
// 送入 mma 会算出 2*3=6,但正确结果应是 1.0*1.5=2.25
五、DeepGEMM 如何做 FP8 x FP4 矩阵乘
5.1 数据类型
// sm100_fp8_fp4_mega_moe.cuh
using a_dtype_t = cutlass::float_e4m3_t; // FP8 激活
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; // FP4 权重
b_dtype_t 使用 _unpacksmem_t 变体,表示在 SMEM 中按 1 byte/value 展开存放。
5.2 TMA 加载:自动拆 nibble
Global memory 中是 packed (2 values/byte),TMA 加载时自动展开:
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(
tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx],
k_idx, n_idx, 2);
TMA descriptor 使用 CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B:
GMEM: [packed nibble | packed nibble | ...]
|
| cp.async.bulk.tensor .b4x16_p64
| 每 16 个 packed FP4 展开 + padding 到 16B-aligned
v
SMEM: [FP4(1B) | pad | FP4(1B) | pad | ...]
对齐约束:基地址 32B 对齐、leading dim 128 元素倍数、仅 128B swizzle。
5.3 UMMA 指令:硬件内部的 FP4→FP8 cast
// 创建 block scaled 变体的指令描述符
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N,
cute::UMMA::Major::K, cute::UMMA::Major::K
>();
// 发射 UMMA
ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
b_desc, a_desc, // B(FP4) x A(FP8)
accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, // 是否累加
runtime_instr_desc, // 含 scale factor ID
kTmemStartColOfSFB, // B 的 UE8M0 scale
kTmemStartColOfSFA // A 的 UE8M0 scale
);
编译到 SASS 为 OMMA (FP4) 或 QMMA (FP8)。硬件内部数据路径:
UMMA 数据路径 (kind::f8f6f4, block_scale):
1. 从 TMEM 加载 B (FP4 byte)
2. 查表 E2M1->E4M3 (FP4 只有 16 个值,组合逻辑)
3. 从 TMEM 加载 A (FP8, 直通)
4. FP8 x FP8 tensor core 矩阵乘
5. 从 TMEM 加载 UE8M0 scale,乘到对应 block 输出
6. FP32 累加 -> 写回 TMEM
FP4->FP8 cast 是 UMMA 指令内部的一个 pipeline stage,不是独立 kernel。
5.4 Scale 的 TMEM 路径
UE8M0 scale 通过 UTCCP (4x32dp128bit_2cta) 从 SMEM 搬入 TMEM:
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) {
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
5.5 SwiGLU 后 Requant
L1 GEMM 输出到 L2 GEMM 之间,epilogue warp 手动做:
// 1. Amax reduction (warp 间归约找 absmax)
math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);
// 2. 乘 sf_inv 缩放到 FP8 范围
const auto fp8x4_values = __nv_fp8x4_e4m3(...);
// 3. STSM 写回 SMEM (L2 act buffer)
ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);
// 4. UE8M0 scale 存到 l2_sf_buffer
sf_base_ptr[...] = (*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
六、容易忽略的细节
6.1 权重是 checkpoint 预量化的
DeepSeek V4 的 checkpoint 里权重已经是 FP4 + UE8M0 scale 格式。quant_config.py 直接从 HF config 读取 expert_dtype="fp4"。scale 也是 checkpoint 自带的,不需要推理时计算。
6.2 Gate/Up 权重 interleaving
MegaMoE 的 L1 权重在 transform_weights_for_mega_moe() 中做了 granularity-8 的 interleave:
原始: gate[0..7] gate[8..15] ... up[0..7] up[8..15] ...
变换: gate[0] up[0] gate[1] up[1] ... | gate[8] up[8] ...
目的:L1 GEMM 输出后 gate/up 值在 TMEM 中相邻,一次 TMEM_LOAD 即可加载到寄存器直接算 SwiGLU。
6.3 UE8M0 4-in-1 打包
prepare_megamoe.py 中 4 个 UE8M0 scale 被 packed 进一个 int32:
packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
BLOCK_K=128, GROUP_K=32 -> 128 元素分 4 组 -> 4 字节 -> (num_tokens, hidden//128) 的 uint32 tensor。
6.4 L2 act requant 的 block size 可能不同
L1 用 block=128,L2 的 activation scale block size 可以独立配置。SM90 移植版用了 per-64-K。MegaMoE 中 L2 act scale 有独立的 TMA descriptor。
七、总结
整体数据流:
Weight (checkpoint 预量化):
FP4 packed -> TMA unpack -> SMEM(1B/val) -> TMEM -> UMMA 内部 E2M1->E4M3 cast
Activation (推理时动态量化):
per_token_group_quant_fp8 (block=128) -> FP8 TMEM -> UMMA 直通
L1->L2 之间:
TMEM(FP32 acc) -> SwiGLU(FP32) -> amax -> sf_inv -> fp8x4_e4m3 -> SMEM -> L2 TMEM
Scale 路径:
4xUE8M0 packed in uint32 -> SMEM -> UTCCP 4x32 -> TMEM -> UMMA block_scale 硬件消费
乘法:
不是 "反量化到 f32 再乘标量"
而是 UMMA 硬件内部: FP4->FP8 cast + FP8xFP8 MMA + UE8M0 scale 融合
整个过程在 tensor core 数据路径内完成
核心要点:FP8 x FP4 在 Blackwell 上不是用 f32 标量乘做的,而是在 UMMA 指令内部通过硬件 pipeline 实现 E2M1->E4M3 查表转换后做原生 FP8 tensor core MMA,scale 也由 block_scale 变体在硬件层面处理。