Differential Transformer V2
About this article
A blog post by Microsoft on Hugging Face
Published January 20, 2026 by Tianzhu Ye, Li Dong, Yutao Sun, Furu Wei (Microsoft)

GitHub Link | Notion Link (for better readability)

Code

We compare DIFF V2 with DIFF V1 below. (For simplicity, we omit the batch dimension and assume that both the input and output of the flash_attn_func calls below are three-dimensional tensors of shape (tokens, heads, head dimension). Heads belonging to the same GQA group are arranged contiguously in the output.)

Note: DIFF V2 subtracts two heads that belong to the same GQA group, which means they share the same key and value. This is crucial to performance. See the design ablations section and the GitHub code.

```python
def DiffAttnV1(
    layer_index,
    q1, q2, k1, k2, v,
    lam_q1, lam_k1, lam_q2, lam_k2,
):
    """
    q1, q2: (N, h/2, d)
    k1, k2: (N, h_kv/2, d)
    v:      (N, h_kv/2, 2d)
    lam_*:  (d,)
    """
    # Two separate attention calls over the two query/key groups
    attn1 = flash_attn_func(q1, k1, v)
    attn2 = flash_attn_func(q2, k2, v)
    # Layer-dependent initialization of lambda
    lam_init = 0.8 - 0.6 * exp(-0.3 * layer_index)
    # Re-parameterized lambda from learnable vectors
    lam1 = exp(sum(lam_q1 * lam_k1))
    lam2 = exp(sum(lam_q2 * lam_k2))
    lam = lam1 - lam2 + lam_init
    # Differential attention: subtract the second map, then normalize
    attn = attn1 - lam * attn2
    attn = rmsnorm(attn)
    attn = attn * (1 - lam_init)
    return attn


def DiffAttnV2(
    q, k, v, lam
):
    """
    q:   (N, 2h, d)
    k:   (N, h_kv, d)
    v:   (N, h_kv, d)
    lam: (N, h, 1)
    """
    # Single attention call; even/odd heads form the pairs to subtract
    attn = flash_attn_func(q, k, v)
    attn1, attn2 = attn[:, 0::2], attn[:, 1::2]
    # Per-token, per-head lambda via sigmoid
    lam_val = sigmoid(lam)
    attn = attn1 - lam_val * attn2
    return attn
```

Full code at: unilm/Diff-Transformer/Diff-Transformer-V2 at mast...
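To make the head layout of the V2 snippet concrete, here is a minimal, self-contained sketch rather than the official implementation: flash_attn_func is replaced by a hypothetical naive causal-attention stand-in (naive_attn), and all shape values are illustrative. It shows how the even/odd split of the attention output pairs up heads within the same GQA group.

```python
# Illustrative sketch only; naive_attn is a stand-in for flash_attn_func,
# and the shapes below are example values, not the paper's configuration.
import torch

def naive_attn(q, k, v):
    # q: (N, H, d); k, v: (N, H_kv, d). Each KV head is repeated so that a
    # contiguous block of query heads (one GQA group) shares it.
    N, H, d = q.shape
    group = H // k.shape[1]
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    scores = torch.einsum("nhd,mhd->hnm", q, k) / d ** 0.5
    causal = torch.tril(torch.ones(N, N, dtype=torch.bool))
    probs = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
    return torch.einsum("hnm,mhd->nhd", probs, v)

def diff_attn_v2(q, k, v, lam):
    # q: (N, 2h, d), k/v: (N, h_kv, d), lam: (N, h, 1)
    attn = naive_attn(q, k, v)
    attn1, attn2 = attn[:, 0::2], attn[:, 1::2]
    return attn1 - torch.sigmoid(lam) * attn2

# Example shapes: 2h = 16 query heads, 4 KV heads -> 4 query heads per GQA group.
N, h, h_kv, d = 16, 8, 4, 32
q = torch.randn(N, 2 * h, d)
k = torch.randn(N, h_kv, d)
v = torch.randn(N, h_kv, d)
lam = torch.randn(N, h, 1)
out = diff_attn_v2(q, k, v, lam)
print(out.shape)  # torch.Size([16, 8, 32])
```

Because repeat_interleave expands each key/value head into a contiguous block of query heads, the subtracted pair attn[:, 0::2] and attn[:, 1::2] always combines two heads that attend through the same key and value, matching the note above.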