蒸馏法训练网络

《Distilling the Knowledge in a Neural Network》

摘要

在ML领域中有一种最为简单的提升模型效果的方式,在同一训练集上训练多个不同的模型,在预测阶段采用综合均值作为预测值。但是,运用这样的组合模型需要太多的计算资源,特别是当单个模型都非常大的时候。已经有相关的研究表明,复杂模型或者组合模型的中“知识”通过合适的方式是可以迁移到一个相对简单模型之中,进而方便模型推广部署。

简介

在大规模的机器学习领域,如物体检测、语音识别等为了获得较好的performance常常会训练很复杂的模型,因为不需要考虑实时性、计算量等因素。但是,在部署阶段就需要考虑模型的大小、计算复杂度、速度等诸多因素,因此我们需要更小更精炼的模型用于部署。这种训练和部署阶段不同的模型形态,可以类比于自然界中很多昆虫有多种形态以适应不同阶段的需求。具体地,如蝴蝶在幼虫以蛹的形式存储能量和营养来更好的发育,但是到了后期就为了更好的繁殖和移动它就呈现了另外一种完全不一样的形态。

有一种直观的概念就是,越是复杂的网络具有越好的描述能力,可以用来解决更为复杂的问题。我们所说的模型学习得到“知识”就是模型参数,说到底我们想要学习的是一个输入向量到输出向量的映射,而不必太过于去关心中间映射过程。

模型蒸馏

模型蒸馏就是将训练好的复杂模型的推广能力“知识”迁移到一个结构更为简单的网络中。或者通过简单的网络去学习复杂模型中“知识”。其基本流程如下图:

基本可以分为两个阶段:

原始模型训练:

A1. 根据提出的目标问题,设计一个或多个复杂网络(N1,N2,…,Nt)。

A2. 使用足够的训练数据,按照常规CNN模型训练流程,并行训练多个复杂网络,得到(M1,M2,…,Mt)

精简模型训练:

B1.      根据(N1,N2,…,Nt)设计一个简单网络N0。

B2.      使用简单模型训练数据,此处的训练数据可以是训练原始网络的有标签数据,也可以是额外的无标签数据。

B3.      将A2中收集到的样本输入原始模型(M1,M2,…,Mt),修改原始模型softmax层中温度参数T为一个较大值,如T=20。每一个样本在每个原始模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前样本的判定结果。对于t个原始模型就可以得到t个概率向量。然后对t概率向量求取均值作为当前样本最后的概率输出向量,记为soft_target,保存。

B4.  标签融合B2中收集到的数据定义为hard_target,有标签数据的hard_target取值为其标签值1,无标签数据hard_taret取值为0。Target =a*hard_target + b*soft_target(a+b=1)。Target最终作为训练数据的标签去训练精简模型。参数a,b是用于控制标签融合权重的,推荐经验值为(a=0.1 b=0.9)

5. 设置精简模型softmax层温度参数与原始复杂模型产生Soft-target时所采用的温度,按照常规模型训练精简网络模型。

6. 部署时将精简模型中的softmax温度参数重置为1,即采用最原始的softmax

                        蒸馏法训练网络