前言

本文整理自一次围绕 vLLM 代码库中 DeepSeek V4 MoE 模块的技术讨论,内容涉及 MXFP4 与 NVFP4 的量化方案对比、Block Quantized GEMM 的设计原理、FP4 packed 存储格式、以及 DeepGEMM 库中 FP8×FP4 在 Blackwell 硬件上的具体实现。


一、DeepSeek V4 MoE 核心优化概览

DeepSeek V4 的 MoE 模块在 vLLM 中的实现包含了大量优化:

优化说明
DeepGEMM MegaMoE融合 EP dispatch + L1 GEMM + SwiGLU + L2 GEMM + EP combine 为单 mega-kernel,NVLink 通信与计算重叠
FP4 (MXFP4/NVFP4) 权重量化4-bit 浮点权重 + UE8M0 block scale
Expert Parallelism 多后端DeepEP、FlashInfer NVLink、MORI、NIXL 等多种 all-to-all 策略
Fused TopK Bias Routingsqrt(softplus) 得分函数、e_score_correction_bias、hash MoE
EPLB每层跟踪 expert 负载,动态重新分配
Fused MLA KernelQ-norm + RoPE + KV quant + cache insert 融合为单 CUDA 核
MTP (Multi-Token Prediction)共享 MoE 架构的 speculative decoding

二、MXFP4 与 NVFP4 的区别

DeepSeek V4 Flash 使用 FP4 权重,有两个可选方案:MXFP4 (OCP 开放标准) 和 NVFP4 (NVIDIA 私有格式)。切换由 HuggingFace config 中的 moe_quant_algo 字段控制。

2.1 核心差异

维度MXFP4NVFP4
标准来源OCP MX (Microscaling)NVIDIA ModelOpt
量化粒度每 32 个元素共享一个 scale每 16 个元素共享一个 scale
Scale 数据类型uint8 (UE8M0,纯指数)float8_e4m3fn + float32 全局 scale
额外全局 scale有 (weight_scale_2 + input_scale)
典型量化方案W4A8 (FP4 权重 + FP8 激活)W4A4 (FP4 权重 + FP4 激活)
Act 量化 block128 元素 (per-token-group)16 元素
硬件要求SM100 Blackwell (最优路径)SM100 Blackwell
DeepSeek V4 默认是 (DeepSeek-V4-Flash)仅 moe_quant_algo=“NVFP4” 时

2.2 MXFP4 的 scale 格式:UE8M0

UE8M0 是一种纯指数的 8-bit 编码:

scale = 2^(uint8_value - 127)

没有符号位,没有尾数,8 位全是指数
可表示范围: 2^(-127) ~ 2^128

对应 weight scale shape:

w13_weight_scale: (num_experts, 2*intermediate_size, hidden_size // 32)
每个 uint8 代表 32 个 FP4 权重的共享 scale

2.3 NVFP4 的双层 scale

NVFP4 使用两层 scale:

block_scale = float8_e4m3fn    每 16 个元素一个
global_scale = float32         per-tensor 或 per-output

最终 scale = block_scale x global_scale

global scale 和 input scale 在 weight 加载时融合:

# process_weights_after_loading
w13_weight_scale_2 *= w13_input_scale

三、Block Quantized GEMM 的设计原理

3.1 为什么只在 K (reduction) 维分组?

核心在于 GEMM 的语义:

C[m, n] = sum_k A[m, k] * W[n, k]

K 是 reduction 维。当权重量化为 FP4 + block scale 后:

W[n, k] 约等于 qW[n, k] * sW[n, floor(k/B)]

C[m, n] = sum_b sW[n,b] * sum_{k=bB} A[m,k] * qW[n,k]

B=32 时,dequant 操作从 O(K) 降到 O(K/32)。scale 是块的常数乘因子,可以提取到内层循环外部。

3.2 为什么不在 N 或 M 维分组?

数学上可行,但精度不可接受:

M-only 分组 (激活整个 hidden 向量共享 1 个 scale):
  一个 token 的 hidden state 有 outlier 12.456 和大量正常值 -0.002
  outlier 决定全局 scale,正常元素被压缩到几乎没有有效精度

N-only 分组 (多个输出神经元共享 1 个 scale):
  gate 门控权重在 [-0.5, 0.5],up 投影权重在 [-5.0, 5.0]
  强行共享 scale 导致细值域行精度全部浪费

结论:K 分组不是设计权衡,是 FP4 量化精度可用的必要条件。

3.3 实际上是 2D blocking

很多人误以为"只在 K 维分组"。实际 weight scale shape 是:

(N, K // B)
  ^     └── 每 B 元素共享一个 scale
  └── 每个输出 N 行有独立的 scale 向量

这是 N 维 dense、K 维 blocked 的 2D 分组。每行独立 scale 的存储成本(N * K/B 个 uint8,约 3% 额外参数)远小于精度收益。


四、FP4 打包存储:为什么 shape 是 hidden // 2

4.1 PyTorch 的 dtype 限制

# PyTorch 没有 4-bit dtype,最小可寻址单位是 1 byte
# 两个 FP4 nibble (4-bit each) pack 进一个 uint8

想要: tensor of shape (..., 7168), dtype=虚拟fp4, itemsize=0.5
实际: tensor of shape (..., 3584), dtype=uint8, itemsize=1

每个 byte 的布局:

byte[n]:  [a3 a2 a1 a0 | b3 b2 b1 b0]
           ^--- 值 B ---^ ^--- 值 A ---^

4.2 Kernel 层面的不同视角

层级存储/传递理解方式
PyTorch Pythontensor.view(torch.uint8), shape K//2当 int8 传指针
调 DeepGEMM/TRTLLMint8* 传 C API当 int8 传,API 内部处理
自写 CUDA/Triton__ldg((uint4*)ptr) 手动拆 nibble必须按 E2M1 查表解码

绝对不能做的事:

uint8_t nibble = packed_byte & 0x0F;
int8_t val = nibble;   // 错误!FP4 0x2=1.0,当 int 是 2
// 送入 mma 会算出 2*3=6,但正确结果应是 1.0*1.5=2.25

五、DeepGEMM 如何做 FP8 x FP4 矩阵乘

5.1 数据类型

// sm100_fp8_fp4_mega_moe.cuh
using a_dtype_t = cutlass::float_e4m3_t;                    // FP8 激活
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; // FP4 权重

b_dtype_t 使用 _unpacksmem_t 变体,表示在 SMEM 中按 1 byte/value 展开存放。

5.2 TMA 加载:自动拆 nibble

Global memory 中是 packed (2 values/byte),TMA 加载时自动展开:

tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(
    tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx],
    k_idx, n_idx, 2);

TMA descriptor 使用 CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B

GMEM:  [packed nibble | packed nibble | ...]
         |
         | cp.async.bulk.tensor .b4x16_p64
         | 每 16 个 packed FP4 展开 + padding 到 16B-aligned
         v
SMEM:  [FP4(1B) | pad | FP4(1B) | pad | ...]

对齐约束:基地址 32B 对齐、leading dim 128 元素倍数、仅 128B swizzle。

5.3 UMMA 指令:硬件内部的 FP4→FP8 cast

// 创建 block scaled 变体的指令描述符
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
    b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
    UMMA_M, UMMA_N,
    cute::UMMA::Major::K, cute::UMMA::Major::K
>();

// 发射 UMMA
ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
    b_desc, a_desc,                              // B(FP4) x A(FP8)
    accum_stage_idx * UMMA_N,
    k_block_idx > 0 or k > 0,                    // 是否累加
    runtime_instr_desc,                          // 含 scale factor ID
    kTmemStartColOfSFB,                          // B 的 UE8M0 scale
    kTmemStartColOfSFA                           // A 的 UE8M0 scale
);

编译到 SASS 为 OMMA (FP4) 或 QMMA (FP8)。硬件内部数据路径:

UMMA 数据路径 (kind::f8f6f4, block_scale):
  1. 从 TMEM 加载 B (FP4 byte)
  2. 查表 E2M1->E4M3 (FP4 只有 16 个值,组合逻辑)
  3. 从 TMEM 加载 A (FP8, 直通)
  4. FP8 x FP8 tensor core 矩阵乘
  5. 从 TMEM 加载 UE8M0 scale,乘到对应 block 输出
  6. FP32 累加 -> 写回 TMEM

FP4->FP8 cast 是 UMMA 指令内部的一个 pipeline stage,不是独立 kernel。

5.4 Scale 的 TMEM 路径

UE8M0 scale 通过 UTCCP (4x32dp128bit_2cta) 从 SMEM 搬入 TMEM:

for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++i) {
    cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}

5.5 SwiGLU 后 Requant

L1 GEMM 输出到 L2 GEMM 之间,epilogue warp 手动做:

// 1. Amax reduction (warp 间归约找 absmax)
math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);

// 2. 乘 sf_inv 缩放到 FP8 范围
const auto fp8x4_values = __nv_fp8x4_e4m3(...);

// 3. STSM 写回 SMEM (L2 act buffer)
ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);

// 4. UE8M0 scale 存到 l2_sf_buffer
sf_base_ptr[...] = (*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);

六、容易忽略的细节

6.1 权重是 checkpoint 预量化的

DeepSeek V4 的 checkpoint 里权重已经是 FP4 + UE8M0 scale 格式。quant_config.py 直接从 HF config 读取 expert_dtype="fp4"。scale 也是 checkpoint 自带的,不需要推理时计算。

6.2 Gate/Up 权重 interleaving

MegaMoE 的 L1 权重在 transform_weights_for_mega_moe() 中做了 granularity-8 的 interleave:

原始: gate[0..7] gate[8..15] ... up[0..7] up[8..15] ...
变换: gate[0] up[0] gate[1] up[1] ... | gate[8] up[8] ...

目的:L1 GEMM 输出后 gate/up 值在 TMEM 中相邻,一次 TMEM_LOAD 即可加载到寄存器直接算 SwiGLU。

6.3 UE8M0 4-in-1 打包

prepare_megamoe.py 中 4 个 UE8M0 scale 被 packed 进一个 int32:

packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)

BLOCK_K=128, GROUP_K=32 -> 128 元素分 4 组 -> 4 字节 -> (num_tokens, hidden//128) 的 uint32 tensor。

6.4 L2 act requant 的 block size 可能不同

L1 用 block=128,L2 的 activation scale block size 可以独立配置。SM90 移植版用了 per-64-K。MegaMoE 中 L2 act scale 有独立的 TMA descriptor。


七、总结

整体数据流:

Weight (checkpoint 预量化):
  FP4 packed -> TMA unpack -> SMEM(1B/val) -> TMEM -> UMMA 内部 E2M1->E4M3 cast

Activation (推理时动态量化):
  per_token_group_quant_fp8 (block=128) -> FP8 TMEM -> UMMA 直通

  L1->L2 之间:
  TMEM(FP32 acc) -> SwiGLU(FP32) -> amax -> sf_inv -> fp8x4_e4m3 -> SMEM -> L2 TMEM

Scale 路径:
  4xUE8M0 packed in uint32 -> SMEM -> UTCCP 4x32 -> TMEM -> UMMA block_scale 硬件消费

乘法:
  不是 "反量化到 f32 再乘标量"
  而是 UMMA 硬件内部: FP4->FP8 cast + FP8xFP8 MMA + UE8M0 scale 融合
  整个过程在 tensor core 数据路径内完成

核心要点:FP8 x FP4 在 Blackwell 上不是用 f32 标量乘做的,而是在 UMMA 指令内部通过硬件 pipeline 实现 E2M1->E4M3 查表转换后做原生 FP8 tensor core MMA,scale 也由 block_scale 变体在硬件层面处理。