背景

mHC(multi-Head Combinatorial)是 DeepSeek V4 模型中引入的一种多头残差混合机制。它将传统 Transformer 中单一的残差向量扩展为 $M$ 个并行的残差副本(multi-head residual),并在每个 block 前后通过可学习的门控和组合矩阵对多头残差进行变换。

本文以 mhc_pre_torch 为核心,从数学公式出发逐行对照 PyTorch 代码,并延伸至 mhc_post_torch 和整个 mHC 流水线,帮助读者完整理解这一算子的设计思路与实现细节。

符号约定

符号代码变量形状含义
$R$residual(..., M, H)多头残差输入
$W_{fn}$fn(M_3, M \cdot H)共享投影矩阵
$s_0, s_1, s_2$hc_scale[0/1/2]标量pre/post/comb 的 scale
$b_0, b_1, b_2$hc_base 分段见描述pre/post/comb 的 bias
$\varepsilon_{\text{rms}}$rms_eps标量RMS 归一化 epsilon
$\varepsilon_{\text{pre}}$hc_pre_eps标量pre-mix epsilon
$\varepsilon_{\text{s}}$hc_sinkhorn_eps标量Sinkhorn epsilon
$v_{\text{post}}$hc_post_mult_value标量post-mix 乘数
$T$num_tokenstoken 数
$M$hc_mult残差 head 数
$H$hidden_size隐层维度
$M_3$hc_mult3 = $2M + M^2$投影输出总维度

mhc_pre_torch 整体流程

mhc_pre_torch 的作用是:在每个 Transformer block 的入口处,从多头残差 $R$ 中计算出三种混合系数(pre-mix、post-mix、comb-mix),并将多头残差加权求和为单向量 layer_input,送入 sublayer(attention 或 FFN)。

完整流程分 7 步,下面逐步骤展开。


第 1 步:展平与类型转换

x = residual_flat.view(num_tokens, hc_mult * hidden_size).to(torch.float32)
$$ x \in \mathbb{R}^{T \times MH} $$

将形状为 (T, M, H) 的多头残差展平为 (T, M·H),并将 bf16 转换为 fp32。$M$ 个 head 的向量被首尾拼接成一维。


第 2 步:线性投影 — 计算 mixes

mixes = torch.matmul(x, fn_flat.t())
$$ \text{mixes}[t, j] = \sum_{i=0}^{MH-1} x[t, i] \cdot W_{fn}[j, i], \quad j \in [0, M_3) $$

其中 $W_{fn} \in \mathbb{R}^{M_3 \times MH}$ 是一个共享投影矩阵。一次矩阵乘同时计算出三种 logits:

区间数量用途
[:M]$M$pre-mix logits
[M:2M]$M$post-mix logits
[2M:2M+M²]$M^2$combinatorial mix logits(展平的 $M \times M$ 矩阵)

第 3 步:RMS 归一化

sqrsum = x.square().sum(dim=-1, keepdim=True)
mixes = mixes * torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
$$ \text{mixes}[t, j] \gets \frac{\text{mixes}[t, j]}{\sqrt{\frac{1}{MH}\sum_i x[t,i]^2 + \varepsilon_{\text{rms}}}} $$

逐 token 做 RMS 归一化,消除输入 $x$ 的幅度波动对投影结果的影响。


第 4 步:Pre-mix — 门控权重

pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult]
pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps
$$ p_k[t] = \sigma\big( \text{mixes}[t, k] \cdot s_0 + b_0[k] \big) + \varepsilon_{\text{pre}}, \quad k \in [0, M) $$

$p \in \mathbb{R}^{T \times M}$ 是每个 head 的门控权重,决定每个残差 head 对最终输出的贡献大小。加 $\varepsilon_{\text{pre}}$ 防止梯度消失。


第 5 步:Post-mix — 缩放因子

post_logits = mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1] + hc_base[hc_mult : 2 * hc_mult]
post_mix = torch.sigmoid(post_logits) * hc_post_mult_value
$$ q_k[t] = \sigma\big( \text{mixes}[t, M+k] \cdot s_1 + b_1[k] \big) \cdot v_{\text{post}}, \quad k \in [0, M) $$

$q \in \mathbb{R}^{T \times M}$ 是逐 head 的缩放因子,乘以固定值 $v_{\text{post}}$ 控制量级。将在 mhc_post_torch 中用于对 sublayer 输出进行逐头加权。


第 6 步:Combinatorial mix — Sinkhorn 双重随机化

这是 mHC 最核心的步骤,通过 Sinkhorn-Knopp 迭代将 $M \times M$ 的 logits 矩阵转化为一个近似双随机矩阵(行和与列和都接近 1)。

6a. 重塑为矩阵

comb_logits = mixes[:, 2 * hc_mult :].view(num_tokens, hc_mult, hc_mult) * hc_scale[2] + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult)
$$ C_0[t, i, j] = \text{mixes}[t, 2M + i\cdot M + j] \cdot s_2 + b_2[i, j], \quad i, j \in [0, M) $$

将 $M^2$ 维的向量重塑为 $M \times M$ 矩阵,加 scale 和 bias。

6b. Softmax + epsilon

comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
$$ A[t, i, j] = \frac{\exp(C_0[t,i,j])}{\sum_{j'} \exp(C_0[t,i,j'])} + \varepsilon_s $$

对每行做 softmax,使每行初始和为 1。

6c. 首次列归一化

comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
$$ A[t, i, j] \gets \frac{A[t, i, j]}{\sum_{i'} A[t, i', j] + \varepsilon_s} $$

6d. Sinkhorn 迭代

for _ in range(sinkhorn_repeat - 1):
    comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
    comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
$$ \begin{aligned} A[t, i, j] &\gets \frac{A[t, i, j]}{\sum_{j'} A[t, i, j'] + \varepsilon_s} \quad \text{(行归一化)} \\ A[t, i, j] &\gets \frac{A[t, i, j]}{\sum_{i'} A[t, i', j] + \varepsilon_s} \quad \text{(列归一化)} \end{aligned} $$

交替进行行归一化和列归一化,使 $A$ 逼近双随机矩阵。经过足够多次迭代后,$A$ 的所有行和与列和都趋近于 1,此时 $A$ 可以看作 head 之间的一个"软路由"矩阵。


第 7 步:输出 Layer Input

layer_input = torch.sum(pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1).to(torch.bfloat16)

Shape 变化过程:

表达式形状说明
pre_mix(T, M)每个 head 的标量权重
pre_mix.unsqueeze(-1)(T, M, 1)插入 size-1 维度,准备广播
residual_flat (bf16→fp32)(T, M, H)多头残差
* 逐元素乘(广播)(T, M, H)第 3 维 1 广播到 H
.sum(dim=1)(T, H)在 $M$ 维上求和,合并所有 head
.to(torch.bfloat16)(T, H)降回 bf16

数学表达:

$$ \text{layer\_input}[t, j] = \sum_{k=0}^{M-1} p_k[t] \cdot R[t, k, j], \quad j \in [0, H) $$

用 pre-mix 权重 $p_k$ 对 $M$ 个 residual head 做加权求和,将维度从 (T, M, H) 坍缩为 (T, H)


返回值总结

输出形状含义
post_mix(..., M, 1)每个 head 的 post 缩放因子 $q_k$
comb_mix(..., M, M)Sinkhorn 归一化后的双随机组合矩阵 $A_{ij}$
layer_input(..., H)门控聚合后的残差输出,送入 sublayer

整体设计思路

为什么用共享投影矩阵?

$W_{fn}$ 一次性将 $MH$ 维的多头残差投影到 $2M + M^2$ 维空间,同时产出 pre、post、comb 三种 logits。这样设计参数高效:不需要为每种 logits 单独维护一套投影。

三类系数的分工

pre-mix  (p) : 决定每个 head 对最终 layer_input 的贡献权重(软选择)
post-mix (q) : 决定每个 head 对 sublayer 输出的缩放幅度
comb-mix (A) : 建立 head 间的两两组合关系(经 Sinkhorn 归一化为双随机矩阵)

RMS 归一化的作用

消除 $x$ 的幅度波动对投影结果的影响,确保 $W_{fn}$ 的输出稳定。

mhc_post_torch — 更新多头残差

mhc_pre_torch 是 block 入口,mhc_post_torch 则是 block 出口。它将 sublayer 的输出重新合并回多头残差。

# Step 1: 用 comb_mix 矩阵线形组合旧残差 head
mixed_residual = torch.einsum("...ij,...ih->...jh", comb_res_mix, residual)
# Step 2: 把 sublayer 输出 x 按 head 加权注入
post_term = post_layer_mix * x.unsqueeze(-2)
# Step 3: 相加得到新残差
return (mixed_residual + post_term).to(residual.dtype)

einsum("...ij,...ih->...jh", comb_res_mix, residual) 等价于:

$$ \text{mixed\_residual}[j, h] = \sum_i A[i, j] \cdot R[i, h] \quad \Leftrightarrow \quad \text{mixed\_residual} = A^\top R $$

输入输出形状:

参数形状含义
x(..., H)sublayer 输出(attention 或 FFN)
residual(..., M, H)进入 sublayer 之前的多头残差
post_layer_mix(..., M, 1)post-mix 权重
comb_res_mix(..., M, M)双随机组合矩阵
返回值(..., M, H)新的多头残差,流到下一个 block

einsum 直观理解

einsum("...ij,...ih->...jh") 可以拆解为三步:

  1. 转置comb_res_mix(i, j) 转置为 (j, i),让公共下标 i 对齐
  2. 矩阵乘comb_res_mix^T @ residual(M, M) @ (M, H) → (M, H)
  3. 输出映射:保留 j(来自第一个输入的第 2 维)和 h(来自第二个输入的第 2 维)

完整 mHC 流水线

        hc_pre                     hc_post
          ↓                          ↑
R(M,H) ──→ layer_input(H) → sublayer → x(H)
                └── pre_mix(M)──→ │
                                   ├── post_term = post_mix(M,1) * x
                                   │   (sublayer 输出按头加权注入)
R(M,H) ──────────────────────────→┼── mixed_residual = comb_mix^T @ R
                                   │   (双随机矩阵重排旧残差 head)
                                   ↓
                               新 R(M,H)

与传统 Transformer 残差的对比

传统残差流:

x₀ → sublayer → x₁ → sublayer → x₂ → ...
      ↕              ↕
   x₀ + out₀      x₁ + out₁

残差只是一条直连边。

mHC 残差流:

残差从单个向量扩展为 $M$ 个并行副本。每个 block 中:

  1. hc_pre:从 (M, H) 多头残差算出三种系数,加权求和为单向量送入 sublayer
  2. hc_post:用组合矩阵重排旧残差 head,将 sublayer 输出按头加权注入,得到新的多头残差

这等价于在隐层之上多了一个可学习的"头间路由/组合"维度,让信息可以有选择地流向不同 head,而不是简单地加总到同一向量中。

与其他 Op 的对比

特性MHCPreOpMHCPostOpHCHeadOp
输入residual(M,H) + fn + scale/base + eps/iterx(H) + residual(M,H) + post(M,1) + comb(M,M)hidden_states(M,H) + fn + scale/base
输出post(M,1), comb(M,M), layer_input(H)residual(M,H)hidden_states(H)
Sinkhorn✅ 生成❌ 不涉及❌ 不涉及
post_mix✅ 生成✅ 消费❌ 不涉及
comb_mix✅ 生成(Sinkhorn 归一化)✅ 消费(einsum 混合)❌ 不涉及

HCHeadOp 是模型最后一层使用的简化版:去掉 Sinkhorn 和 post_mix,只保留 RMS norm → linear → sigmoid → weighted sum,将多头残差坍缩为最终的输出向量。

总结

mhc_pre_torch 是 mHC 机制的核心算子之一,通过一个共享投影矩阵同时计算出三种混合系数,再经 Sinkhorn 迭代将组合矩阵转化为双随机矩阵。配合 mhc_post_torch,它实现了 Transformer 中多头残差的生成、门控、组合与更新,为模型提供了更灵活的信息路由能力。

从数学公式到 CUDA kernel 的实现映射清晰直接:矩阵乘对应 torch.matmul,逐元素运算对应 broadcast 操作,Sinkhorn 迭代对应循环中的行/列归一化。整个算子可以高效地利用 GPU 的 tensor core 和并行计算能力。