Decouple KL

Decoupled Kullback-Leibler Divergence Loss[1]

作者是来自NTU、CUHK等机构的Jiequan Cui等人,论文引用[1]:Cui, Jiequan et al. “Decoupled Kullback-Leibler Divergence Loss.” ArXiv abs/2305.13948 (2023): n. pag.

Time

  • 2024.Oct

Key Words

  • breaking asymmetric optimization property
  • inserting class-wise global information to mitigate sample-wise bias
  • KL loss在反向传播上等价于DKL loss

总结

  1. 在本文中,作者深入研究了KL散度损失,数学上证明了它等价于Decoupled Kullback-Leibler(DKL) Divergence loss, DKL包含一个weighted Mean Square Error(wMSE) loss和一个引入了soft labels交叉熵损失。多亏了DKL loss的decomposed formulation,作者有两方面的改进,首先,通过打破它的非对称优化的特点,解决了KL/DKL在像知识蒸馏等场景的局限。这个修改确保了wMSE在训练的时候总是有效的,提供了额外的constructive cues。其次,作者引入了class-wise global information,来缓解来自个体岩本的bias。有了这两个提高,作者推出了Improved Kullback-Leibler Divergence Loss (IKL)。

  2. Loss函数是训练model的重要的组成部分,交叉熵损失在图像分类任务中特别重要,同时MSE loss在回归任务中很常见,对比损失是representation learning中流行的一个损失函数。选择合适的loss对模型的性能会产生重要的影响。因此,有效的loss函数是ML和CV领域中一个重要的研究课题。

  3. KL量化了一个概率分布和一个参考分布(reference distribution) 之间的不相似的程度,作为一个最常用的loss functions,在很多场景中都用应用,例如对抗训练,知识蒸馏,增强学习,还有分布之外数据的Robustness。尽管很多研究引入了KL 散度损失,它们可能没有完全研究这个loss函数背后的机制。为了弥补这个问题,本文旨在解释KL散度关于梯度优化的工作机制。作者的研究聚焦于从梯度优化的角度对KL 散度的分析。对于用softmax activation的模型,作者提供了理论证明:它等价于DKL 散度损失,其包含一个加权的MSE loss和一个带有soft labels的交叉熵损失。图表明了KL和DKL在梯度反向传播上,两者是等价的,有了这个decomposed的形式,更方便分析KL loss是如何在训练优化中工作的。作者用新的DKL loss发现了KL loss的问题,特别地,它的梯度优化是在输入方面具有不对称性。如图所示, \(o_m\)\(o_n\) 的梯度是非对称的,通过wMSE和交叉熵推导出来,这个优化非对称会导致wMSE在特定的场景中被忽略,例如知识蒸馏,\(o_m\) 是teacher model的logits,不参与梯度反向传播,通过打破非对称优化,用解耦的DKL loss的形式很容易解决这个问题。 另外,wMSE是受sample-wise predictions所指导,有不正确分数的难的样本会导致有挑战的优化,因此将class-wise 全局信息进行插入,来正则化训练过程。将DKL和这两点结合起来,作者推导出了Improved Kullback-Leibler(IKL) 散度损失。

    为了展示作者提出的IKL 损失的有效性,作者在对抗训练和知识蒸馏任务中进行评估。

  4. 作者的贡献如下:

    • 作者揭示了:KL loss在数学上等价于一个加权的MSE loss和一个用soft labels的交叉熵损失。
    • 基于作者的分析,作者提出了两个修改来提高:打破它的非对称优化和引入class-wise global information,推导出了Improved Kulllback-Leibler (IKL) loss。
  5. Adversarial Robustness: 自从Szegedy识别对抗samples以来,DNNs的安全性受到了很大的关注,确保DNNs的可靠性在ML中变成了一个prominent topic,对抗训练,是最有效的方法,由于其持续的高性能变得突出。对抗训练在训练过程中引入了对抗样本,有人提出了采用通用的first-order adversary。特别是对抗训练中的PGD攻击,Zhang等人通过KL loss平衡了精度和Robustness,Wu等人引入了对抗权重扰动,来显式地调节weight loss的flatness。Cui等人利用训练的model来正则化对抗训练中的decision boundary。另外,多种其它的techniques也被开发了,聚焦于优化或者训练方面的。另外,一些工作探索了数据增强,来提高对抗训练。作者探索了KL loss机制用于adversarial robustness,提出的IKL loss的有效性在有合成数据和没有合成数据的设定中都进行了测试。

  6. 知识蒸馏:知识蒸馏的概念首次是被Hinton提出,他设计从accurate teacher model中提取dark knowledge来指导student model的学习过程。这是通过利用KL loss来正则化student model输出概率来实现的,在给定相同的输入的时候,将它们和teacher models进行对齐。这个简单有效的方法显著地提高了更小的模型的泛化能力,在多个domains中有广泛的应用。自从KD的成功,一些先进的方法,包括logits-based和features-based方法都被引入了,文本将KL loss解耦成一个新的formulation, DKL,解决了KL loss在KD等场景中的局限。

  7. 其它KL 散度loss的应用: 在半监督学习中,KL loss是作为weakly 和strongly augmented images输出之间的consistency loss。在连续学习中,KL loss 通过encouraging pretrained和newly updated models输出之间的consistency,来保留之前的knowledge,另外,KL loss也用于提高分布外数据的robustness。

  8. KL 散度的定义: KL散度测量两个概率分布之间的differences,对于一个连续随机变量的 P和Q,它的定义如下:

    p和q表示 P和Q的概率密度 KL loss是深度学习中最广泛使用的loss,在涉及类别分布的多个contexts中都有应用。本文主要检验它在对抗训练和知识蒸馏任务中的role。

    在对抗训练中,KL loss通过将对抗样本的分布的输出概率和它们对应表的clean images来提高model robustness,这最小化了output changes,在知识蒸馏中,KL loss使得一个student model来模仿一个teacher model的behavior,加速了konwledge transfer,增强了student model的泛化能力。

  9. KL loss的应用:图像分类模型用softmax来预测概率,KL loss通常用于encourage \(s_m\)\(s_n\) 之间的similarity,目标函数如下: \[\mathcal{L}_{\text{KL}}(x_m, x_n) = \sum_{j=1}^{C} \mathbf{s}_m^j * \log \frac{\mathbf{s}_m^j}{\mathbf{s}_n^j}.\]

    例如,在对抗训练中,\(x_m\) 表示一个clean image,\(x_n\) 是对应的对抗样本,在知识蒸馏中,\(x_m\)\(x_n\) 是同一个image,它们分别输入给teacher和student model,在知识蒸馏过程中,\(s_m\) 是不参与梯度的反向传播的,teacher model是预训练的,在训练过程中是固定的。

  10. 之前的工作引入了KL loss,没有研究它的机制,本文旨在通过分析KL loss function,揭示梯度优化背后的driving force。利用不定积分的工具,作者引入了一个新的formulation,称之为Decoupled Kullback-Leibler 散度损失,DKL loss和KL loss是等价的,通过提供更容易分析的tractable alternative,用于之后的exploration和study。 理论1: 从梯度优化的角度看,当 \(\alpha\) =1, \(\beta\) = 1 KL 散度损失等价于下面的DKL 散度损失。 \[ \mathcal{L}_{\text{DKL}}(x_m, x_n) = \underbrace{ \frac{\alpha}{4} \left\| \sqrt{\mathcal{S}(\mathbf{w}_m)} (\Delta \mathbf{m} - \mathcal{S}(\Delta \mathbf{n})) \right\|^2 }_{\text{weighted MSE (wMSE)}} - \underbrace{ \beta \cdot \mathcal{S}(\mathbf{s}_m^\top) \cdot \log \mathbf{s}_n }_{\text{Cross-Entropy}} \] \(S(.)\) 代表stop gradients操作, \(s^T_m\)\(s_m\) 的转置。

    作者更新了这个公式,使得 \(\delta(n)\) 可以进行梯度更新,更新之后如下: \[ \mathcal{L}_{\text{DKL}}(x_m, x_n) = \underbrace{ \frac{\alpha}{4} \left\| \sqrt{\mathcal{S}(\mathbf{w}_m)} (\Delta \mathbf{m} - \Delta \mathbf{n}) \right\|^2 }_{\text{weighted MSE (wMSE)}} - \underbrace{ \beta \cdot \mathcal{S}(\mathbf{s}_m^\top) \cdot \log \mathbf{s}_n }_{\text{Cross-Entropy}} \] 插入class-wise Global information:

    它表明 \(w_m\) 依赖于 sample-wise prediction socres,然而,这个model当处理outliers或者hard examples的时候,不能输出正确的predictions。在这种情况下,wMSE将会给预测的class \(\hat{y} = argmax(o_m)\) most importance,而不是ground truth class,会误导优化导致训练不稳定。

    作者将class-wise global information引入到wMSE中,用 \(\tilde{w}_y\) 代替 \(w_m\)\[ \bar{\mathbf{y}}^{j,k} = \bar{\mathbf{s}}^j * \bar{\mathbf{s}}^k \]

    y 是 \(x_m\) 的ground-truth label,\(\bar{s}_y = \frac{1}{|X_y|} \sum_{x_i \in X_y} s_i\), \(\bar{w}_y\) 插入的class-wise global information 作为一个正则化,来增强intra-class一致性,缓解samples noises的biases,特别地,在训练的早期阶段,\(\bar{w}_y\) 能够提供正确的预测,有益于 \(\bar{w}MSE\) 的优化。通过引入这两个设计,IKL loss是这样的: \[ \mathcal{L}_{\text{IKL}}(x_m, x_n) = \underbrace{ \frac{\alpha}{4} \left\| \sqrt{\mathcal{S}(\bar{\mathbf{w}}_y)} (\Delta \mathbf{m} - \mathcal{S}(\Delta \mathbf{n})) \right\|^2 }_{\text{weighted MSE (wMSE)}} - \underbrace{ \beta \cdot \mathcal{S}(\mathbf{s}_m^\top) \cdot \log \mathbf{s}_n }_{\text{Cross-Entropy}} \] \(y\)\(x_m\) 的ground truth label, \(\bar{w}_y \in \mathcal{R}^{C \times C}\) 是 类别 \(y\) 的weights。

!(Comparisons)[https://figures.semanticscholar.org/2702d965eb53cba99f9afa36357efca2b87d7d9c/2-Figure1-1.png] \(Fig.1^{[1]}\) KL, DKL, IKL loss之间的对比,DKL loss在反向传播优化上等价于KL loss, \(M\)\(N\) 可以是相同的,或者两个separate models,这由具体的应用场景决定。类似地, \(x_m, x_n \in X\) 也可以是相同的,或者两个不同的images,\(o_m, o_n \in O\) 是logits output,概率是通过用softmax activation计算得到的,黑色箭头表示前向过程,彩色箭头表示对应loss functions的反向过程。