Diff_Transformer
Differential Transformer[1]
作者是来自MSRA和Tsinghua的Tianzhu Ye等人。论文引用[1]:Ye, Tianzhu et al. “Differential Transformer.” ArXiv abs/2410.05258 (2024): n. pag.
Time
- 2025.Apr
Key Words
- 一句话来说:用两个softmax attention functions之间的差,作为attention socres,来消除attention noise。
总结
- Transformer 倾向于将attnetion过多地分配给不相关的context,在这个工作中,作者介绍了Diff Transformer,放大了relevant context的attention,同时抵消了noise,特别地,differential attention机制通过计算两个独立的 softmax 注意力图之间的差值来得到注意力分数。subtraction 操作cancel 了noise,提升了sparse attention patterns的出现。实验结果表明:Diff Transformer在多个scaling up model size和training token的多种设置下,超过了Transformer。另外更有趣的是,它在实际应用中,提供了notable advantages,例如long-context modeling,key information retrieval和幻觉缓解,in-context learning,activation outliers的reduction。通过减少不相关context的distract, Diff Transformer在question answering和text summarization上缓解了幻觉。对于in-context learning,Diff Transformer不仅能增强精度,也对于order permutation更加robust,order permutation被认为是chronic robustness issue。结果表明Diff Transformer是一个高效和有前途的架构。
decoder-only 的Transformer 作为一个事实上的LLMs的标准。Transformer的heart是attention mechanism,采用Softmax function来对sequence中的多个tokens的importance进行加权。然而,最近的研究表明, LLMs面临从context中精确地检索key information的挑战。 作者将分配给context的不同部分的normalized attention scores可视化。这个task是检索嵌入在一堆文档中的一个answer。这个可视化揭示了Transformer倾向于将一小部分的attention scores分配给正确的answer,同时,不成比例地关注不相关的context。实验进一步证明:Transformers在处理此类能力方面存在困难。这个问题源自于不可忽视的分配给不相关context的attention scores,压过了correct answer。作者将这些无关的scores称之为attention noise。在这个文章中,作者提出了DIFF Transformer,是一个foundation 架构 for LLMs。提出了differential attention mechanism,用differential denoising来消除attention noise,特别地,将query和key vectors分为两个groups,计算两个separate softmax attention maps,讲个maps的subtracting的结果被视为attention scores。这个differential attention mechanism消除了attention noise,鼓励models聚焦于critical information。这个方法类似于noise-canceling headphones和differential amplifiers in electrical engineering,两个信号的difference消除了common-mode noise。作者提出了attention scores的normalized distribution,观察到,Diff Transformer将higher scores分配给correct answer,更多的lower scores分配给不相关的context。图展示了提出了方法实现了notable improvements。
作者做了大量的实验 on language modeling,将DIFF Transformer在parameter count、training tokens和context length上进行scale up。这个scaling curves表明:DIFF Transformer仅需要Transformer的65%的model size或者training tokens,实现了comparable的language modeling performance。另外,DIFF Transformer在多个下游任务上超过了Transformer,另外,DIFF Transformer有很多的intriguing的优点。
作者将decoder-only model作为一个example来描述架构,这个model堆叠了 \(L\) 个DIFF Transformer layers,给定一个input sequence,将input embeddings pack到 \(X^0 = [x_1, ..., x_N] \in mathcal{R}^{N \times d_{model}}\),这个输入进一步被contextualized,得到输出 \(X^L\),每个layer包含两个modules:一个differential attention followed by a feed-forward network module,相比于Transformer,主要的不同在于:传统的softmax attention用differential attention替换了,然而macro layout是一样的。作者也采用pre-RMSNorm和SwiGLU。
differential attention mechanism将query、key和value vectors 映射到outputs,用query、key vectors来计算attention scores,然年后计算value vectors的加权和。critical design是用一对softmax functions来cancel noise of attention scores,特别地,给定输入 \(X\),首先将它们project到query、key和value, \(Q_1, Q_2, K_1, K_2 \in mathcal{R}^{N \times d}, V \in mathcal{R}^{N \times 2d}\)。然后,differential attention operator \(DiffAttn(.)\) 计算输出如下:
\[ [Q_1; Q_2] = XW^Q, \quad [K_1; K_2] = XW^K, \quad V = XW^V \]
\[ \text{DiffAttn}(X) = (\text{softmax}(\frac{Q_1 K_1^T}{\sqrt{d}}) - \lambda \, \text{softmax}(\frac{Q_2 K_2^T}{\sqrt{d}})) V \]
\(\lambda\) 是一个learnable scalar,为了synchronize learning dynamics,将scalar \(\lambda\) 重新参数化为:
\[ \lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}} \]
\(\lambda_{q_1}, \lambda_{k_1}, \lambda_{q_2}, \lambda_{k_2}, \lambda_{\text{init}}\) 是learnable scalars,初始化 \(\lambda_{\text{init}} \in (0,1)\)
Differential attention 将两个softmax attentin functions之间difference来消除attention noise,这个idea类似于差分放大器,两个信号的差作为输入,因为可以将输入的common-mode noise消除。也有人证明了differential attentino使得attention matrices的spectral distribution更加平衡,有效地解决了rank collapse。另外,降噪耳机的设计也是基于类似的idea,作者直接用了FlashAttention,提高了model efficiency。
Headwise Normalization:用了GroupNorm来emphasize \(LN\)独立地用于每个head, differential attention倾向于有一个sparse pattern,heads之间的statistical information更加diverse,\(LN\) operator将每个head进行normalize before concatenation,来提高gradient statistics。
- Overall Architecture:整个的架构堆叠了 \(L\) layers,每个layer包含一个Multi-head differential attention module和一个feed-forward network module。
\[ Y^l = \text{MultiHead}(\text{LN}(X^l)) + X^l \]
\[ X^{l+1} = \text{SwiGLU}(\text{LN}(Y^l)) + Y^l \]