Preparing Lessons: Improve Knowledge Distillation with Better Supervision论文笔记
论文地址:http://arxiv.org/abs/1911.07471
github地址:无
本文提出两种监督方式来提高知识蒸馏的效果,旨在解决teacher网络出现错分类结果和模糊分类结果时的蒸馏问题,保证student网络始终学习到有效的知识。
Methods
Bad phenomenon 1:Genetic errors
含义:学生网络和教师网络都得到相同错误的预测结果,当教师网络预测错时,学生网络很难自身纠正这个错误的知识,此时产生genetic error。
Method 1:knowledge adjustment
作者提出对原本知识蒸馏的loss的修改,具体而言,明确定义学生网络与教师网络之间的度量方式为KL divergence,增加了对教师网络的logits的修正函数,舍弃与real label之间的cross entropy项。
修正函数修正错误的logits,对正确的logits不做修改。文中作者分别使用Lable Smooth Regularization (LSR) 与提出的Probability Shift (PS) 作为修正函数。LSR将预测错误的类标签加入其他类别标签的影响,起到软化标签的效果。PS将预测错误的类的置信度与真实对应的标签置信度交换,以保证最大置信度落在正确的标签上。
Bad phenomenon 2:Uncertainty of supervision
含义:教师网络的softened logits输出分布比较平坦时(即temperature scaling较大),其提供的监督信息可能会损失一些具有判别性的信息,这使得学生网络学到模糊的logits分布,从而产生错误的预测结果。
Method 2:Dynamic Temperature Distillation
作者提出动态调整temperature scaling的蒸馏方式DTD,即sample-wise的自适应设定参数。
其中是batch的大小,分别是基准和偏置项,是batch-wise normalization后的样本的权重,用于描述该样本的迷惑程度,即难分类程度。当难分类时,较大,此时变小,使得softened logits根据判别性,反之亦然。
对于的计算方法,作者提出两种方法:
- Focal Loss Style Weights (FLSW)
受focal loss对困难样本的惩罚权重调整的启发,作者提出FLSW方法。
其中分别为学生和教师的logits。为超参。- Confidence Weighted by Student Max (CWSM)
其中为学生的logits经归一化后的最大值,作者认为该值能反映学生对此样本的置信度。
最终,将两种方法相结合,可以同时解决两个问题。其整体loss function调整为:
Experiments
数据集:CIFAR-10, CIFAR-100, Tiny ImageNet
方法比较:标准蒸馏 (KD),注意力机制 (AT),神经元选择性迁移 (NST)