背景
mHC(multi-Head Combinatorial)是 DeepSeek V4 模型中引入的一种多头残差混合机制。它将传统 Transformer 中单一的残差向量扩展为 $M$ 个并行的残差副本(multi-head residual),并在每个 block 前后通过可学习的门控和组合矩阵对多头残差进行变换。
本文以 mhc_pre_torch 为核心,从数学公式出发逐行对照 PyTorch 代码,并延伸至 mhc_post_torch 和整个 mHC 流水线,帮助读者完整理解这一算子的设计思路与实现细节。
符号约定
| 符号 | 代码变量 | 形状 | 含义 |
|---|---|---|---|
| $R$ | residual | (..., M, H) | 多头残差输入 |
| $W_{fn}$ | fn | (M_3, M \cdot H) | 共享投影矩阵 |
| $s_0, s_1, s_2$ | hc_scale[0/1/2] | 标量 | pre/post/comb 的 scale |
| $b_0, b_1, b_2$ | hc_base 分段 | 见描述 | pre/post/comb 的 bias |
| $\varepsilon_{\text{rms}}$ | rms_eps | 标量 | RMS 归一化 epsilon |
| $\varepsilon_{\text{pre}}$ | hc_pre_eps | 标量 | pre-mix epsilon |
| $\varepsilon_{\text{s}}$ | hc_sinkhorn_eps | 标量 | Sinkhorn epsilon |
| $v_{\text{post}}$ | hc_post_mult_value | 标量 | post-mix 乘数 |
| $T$ | num_tokens | — | token 数 |
| $M$ | hc_mult | — | 残差 head 数 |
| $H$ | hidden_size | — | 隐层维度 |
| $M_3$ | hc_mult3 = $2M + M^2$ | — | 投影输出总维度 |
mhc_pre_torch 整体流程
mhc_pre_torch 的作用是:在每个 Transformer block 的入口处,从多头残差 $R$ 中计算出三种混合系数(pre-mix、post-mix、comb-mix),并将多头残差加权求和为单向量 layer_input,送入 sublayer(attention 或 FFN)。
完整流程分 7 步,下面逐步骤展开。
第 1 步:展平与类型转换
x = residual_flat.view(num_tokens, hc_mult * hidden_size).to(torch.float32)
将形状为 (T, M, H) 的多头残差展平为 (T, M·H),并将 bf16 转换为 fp32。$M$ 个 head 的向量被首尾拼接成一维。
第 2 步:线性投影 — 计算 mixes
mixes = torch.matmul(x, fn_flat.t())
其中 $W_{fn} \in \mathbb{R}^{M_3 \times MH}$ 是一个共享投影矩阵。一次矩阵乘同时计算出三种 logits:
| 区间 | 数量 | 用途 |
|---|---|---|
[:M] | $M$ | pre-mix logits |
[M:2M] | $M$ | post-mix logits |
[2M:2M+M²] | $M^2$ | combinatorial mix logits(展平的 $M \times M$ 矩阵) |
第 3 步:RMS 归一化
sqrsum = x.square().sum(dim=-1, keepdim=True)
mixes = mixes * torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
逐 token 做 RMS 归一化,消除输入 $x$ 的幅度波动对投影结果的影响。
第 4 步:Pre-mix — 门控权重
pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult]
pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps
$p \in \mathbb{R}^{T \times M}$ 是每个 head 的门控权重,决定每个残差 head 对最终输出的贡献大小。加 $\varepsilon_{\text{pre}}$ 防止梯度消失。
第 5 步:Post-mix — 缩放因子
post_logits = mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1] + hc_base[hc_mult : 2 * hc_mult]
post_mix = torch.sigmoid(post_logits) * hc_post_mult_value
$q \in \mathbb{R}^{T \times M}$ 是逐 head 的缩放因子,乘以固定值 $v_{\text{post}}$ 控制量级。将在 mhc_post_torch 中用于对 sublayer 输出进行逐头加权。
第 6 步:Combinatorial mix — Sinkhorn 双重随机化
这是 mHC 最核心的步骤,通过 Sinkhorn-Knopp 迭代将 $M \times M$ 的 logits 矩阵转化为一个近似双随机矩阵(行和与列和都接近 1)。
6a. 重塑为矩阵
comb_logits = mixes[:, 2 * hc_mult :].view(num_tokens, hc_mult, hc_mult) * hc_scale[2] + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult)
将 $M^2$ 维的向量重塑为 $M \times M$ 矩阵,加 scale 和 bias。
6b. Softmax + epsilon
comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
对每行做 softmax,使每行初始和为 1。
6c. 首次列归一化
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
6d. Sinkhorn 迭代
for _ in range(sinkhorn_repeat - 1):
comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
交替进行行归一化和列归一化,使 $A$ 逼近双随机矩阵。经过足够多次迭代后,$A$ 的所有行和与列和都趋近于 1,此时 $A$ 可以看作 head 之间的一个"软路由"矩阵。
第 7 步:输出 Layer Input
layer_input = torch.sum(pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1).to(torch.bfloat16)
Shape 变化过程:
| 表达式 | 形状 | 说明 |
|---|---|---|
pre_mix | (T, M) | 每个 head 的标量权重 |
pre_mix.unsqueeze(-1) | (T, M, 1) | 插入 size-1 维度,准备广播 |
residual_flat (bf16→fp32) | (T, M, H) | 多头残差 |
* 逐元素乘(广播) | (T, M, H) | 第 3 维 1 广播到 H |
.sum(dim=1) | (T, H) | 在 $M$ 维上求和,合并所有 head |
.to(torch.bfloat16) | (T, H) | 降回 bf16 |
数学表达:
$$ \text{layer\_input}[t, j] = \sum_{k=0}^{M-1} p_k[t] \cdot R[t, k, j], \quad j \in [0, H) $$用 pre-mix 权重 $p_k$ 对 $M$ 个 residual head 做加权求和,将维度从 (T, M, H) 坍缩为 (T, H)。
返回值总结
| 输出 | 形状 | 含义 |
|---|---|---|
post_mix | (..., M, 1) | 每个 head 的 post 缩放因子 $q_k$ |
comb_mix | (..., M, M) | Sinkhorn 归一化后的双随机组合矩阵 $A_{ij}$ |
layer_input | (..., H) | 门控聚合后的残差输出,送入 sublayer |
整体设计思路
为什么用共享投影矩阵?
$W_{fn}$ 一次性将 $MH$ 维的多头残差投影到 $2M + M^2$ 维空间,同时产出 pre、post、comb 三种 logits。这样设计参数高效:不需要为每种 logits 单独维护一套投影。
三类系数的分工
pre-mix (p) : 决定每个 head 对最终 layer_input 的贡献权重(软选择)
post-mix (q) : 决定每个 head 对 sublayer 输出的缩放幅度
comb-mix (A) : 建立 head 间的两两组合关系(经 Sinkhorn 归一化为双随机矩阵)
RMS 归一化的作用
消除 $x$ 的幅度波动对投影结果的影响,确保 $W_{fn}$ 的输出稳定。
mhc_post_torch — 更新多头残差
mhc_pre_torch 是 block 入口,mhc_post_torch 则是 block 出口。它将 sublayer 的输出重新合并回多头残差。
# Step 1: 用 comb_mix 矩阵线形组合旧残差 head
mixed_residual = torch.einsum("...ij,...ih->...jh", comb_res_mix, residual)
# Step 2: 把 sublayer 输出 x 按 head 加权注入
post_term = post_layer_mix * x.unsqueeze(-2)
# Step 3: 相加得到新残差
return (mixed_residual + post_term).to(residual.dtype)
einsum("...ij,...ih->...jh", comb_res_mix, residual) 等价于:
输入输出形状:
| 参数 | 形状 | 含义 |
|---|---|---|
x | (..., H) | sublayer 输出(attention 或 FFN) |
residual | (..., M, H) | 进入 sublayer 之前的多头残差 |
post_layer_mix | (..., M, 1) | post-mix 权重 |
comb_res_mix | (..., M, M) | 双随机组合矩阵 |
| 返回值 | (..., M, H) | 新的多头残差,流到下一个 block |
einsum 直观理解
einsum("...ij,...ih->...jh") 可以拆解为三步:
- 转置:
comb_res_mix从(i, j)转置为(j, i),让公共下标i对齐 - 矩阵乘:
comb_res_mix^T @ residual,(M, M) @ (M, H) → (M, H) - 输出映射:保留
j(来自第一个输入的第 2 维)和h(来自第二个输入的第 2 维)
完整 mHC 流水线
hc_pre hc_post
↓ ↑
R(M,H) ──→ layer_input(H) → sublayer → x(H)
└── pre_mix(M)──→ │
├── post_term = post_mix(M,1) * x
│ (sublayer 输出按头加权注入)
R(M,H) ──────────────────────────→┼── mixed_residual = comb_mix^T @ R
│ (双随机矩阵重排旧残差 head)
↓
新 R(M,H)
与传统 Transformer 残差的对比
传统残差流:
x₀ → sublayer → x₁ → sublayer → x₂ → ...
↕ ↕
x₀ + out₀ x₁ + out₁
残差只是一条直连边。
mHC 残差流:
残差从单个向量扩展为 $M$ 个并行副本。每个 block 中:
hc_pre:从(M, H)多头残差算出三种系数,加权求和为单向量送入 sublayerhc_post:用组合矩阵重排旧残差 head,将 sublayer 输出按头加权注入,得到新的多头残差
这等价于在隐层之上多了一个可学习的"头间路由/组合"维度,让信息可以有选择地流向不同 head,而不是简单地加总到同一向量中。
与其他 Op 的对比
| 特性 | MHCPreOp | MHCPostOp | HCHeadOp |
|---|---|---|---|
| 输入 | residual(M,H) + fn + scale/base + eps/iter | x(H) + residual(M,H) + post(M,1) + comb(M,M) | hidden_states(M,H) + fn + scale/base |
| 输出 | post(M,1), comb(M,M), layer_input(H) | 新 residual(M,H) | hidden_states(H) |
| Sinkhorn | ✅ 生成 | ❌ 不涉及 | ❌ 不涉及 |
| post_mix | ✅ 生成 | ✅ 消费 | ❌ 不涉及 |
| comb_mix | ✅ 生成(Sinkhorn 归一化) | ✅ 消费(einsum 混合) | ❌ 不涉及 |
HCHeadOp 是模型最后一层使用的简化版:去掉 Sinkhorn 和 post_mix,只保留 RMS norm → linear → sigmoid → weighted sum,将多头残差坍缩为最终的输出向量。
总结
mhc_pre_torch 是 mHC 机制的核心算子之一,通过一个共享投影矩阵同时计算出三种混合系数,再经 Sinkhorn 迭代将组合矩阵转化为双随机矩阵。配合 mhc_post_torch,它实现了 Transformer 中多头残差的生成、门控、组合与更新,为模型提供了更灵活的信息路由能力。
从数学公式到 CUDA kernel 的实现映射清晰直接:矩阵乘对应 torch.matmul,逐元素运算对应 broadcast 操作,Sinkhorn 迭代对应循环中的行/列归一化。整个算子可以高效地利用 GPU 的 tensor core 和并行计算能力。