机器学习算法笔记之9:偏差与方差、学习曲线

1. 偏差与方差的理解

在训练机器学习模型时,使用不同的训练集很可能会得到不同的估计模型,估计模型随着训练集的改变而变化的程度就叫做方差variance。我们训练得到的估计模型与实际真实模型的偏差即为bias,估计与实际差距越大,bias就越高。为了得到较低的误差,需要尽可能地降低方差和偏差,然而这两者不能同时减小,在bias与variance之间存在一个权衡trade-off。

低偏差的模型可以很好的适应训练数据,改变训练集会得到特别不同的模型,及低Bias的方法能够捕捉到训练集中的大部分差异,改变数据集时估计模型会变化很多,意味着该模型是高方差的(high variance)。模型的 bias 越低,它适应数据的能力就越强,同时 variance 也越高。所以,bias 越低,variance 越高。

反过来也说得通:bias 越高,variance 越低。一个高 variance 的模型构建的简单模型通常是不能很好适应数据集的。当我们改变数据集的时候,从高 bias 的算法得到的模型 f^ 通常不会有很大不同。如果我们改变训练集的时候 f^ 不会改变太多,那么 variance 就比较低,这恰好证明了我们的观点:bias 越高,variance 越低。

在实际中,我们需要接受一个 trade-off。我们不可能同时得到低 bias 和低 variance,所以我们期望得到某种中间结果。

 

机器学习算法笔记之9:偏差与方差、学习曲线

2. 学习曲线

学习曲线会展示误差随着训练集的改变而如何发生变化。

机器学习算法笔记之9:偏差与方差、学习曲线

由上图左图可以得到,当训练集样本数增大到某个值以后时,验证集的误差保持大致不变,表明增加更多的训练数据点并不能带来更好的模型,与其增大训练集的规模,不如尝试构建更加复杂的模型算法。而右图则表明增加更多训练样本会降低模型误差,改善模型性能。

偏差Bias问题诊断,偏差问题的主要标志是高验证误差,如下图:

机器学习算法笔记之9:偏差与方差、学习曲线

高验证集误差表明是一个偏差问题,但并不能直接指明具体的偏差问题。与此同时,高训练集误差表明是高偏差问题(欠拟合),模型不能很好地拟合训练数据;而低训练集误差表明是低偏差问题,模型可以很好地拟合训练数据。

方差variance问题诊断:首先检查验证学习曲线和训练学习曲线之间的差距,然后检查训练误差(检查误差值随训练样本数的增加的变化)。

两曲线较小的差距代表较小的variance,差距越小则variance越小,反之亦然。高方差即variance较大说明出现了过拟合问题(过度拟合训练数据)。当过拟合的模型分别在训练集和验证集上测试时,训练误差较低而验证误差较高,且随着训练样本数的增加这种模式继续存在,训练集和验证集之间的差异程度决定了这两条曲线之间的距离。

训练误差和验证误差之间的关系,以及训练学习曲线和验证学习曲线之间的差距可以总结如下:gap = validation_error − training_error。两个误差之间的差距越大,曲线之间的距离越大,variance 越大。

机器学习算法笔记之9:偏差与方差、学习曲线

通常,以下两种修正方式在处理高 bias 和低 variance 的问题时会比较奏效:

  • 用更多的特征训练当前的学习算法,即通过增加模型的复杂度来降低 bias。
  • 减少对当前算法的正则化。正则化能够避免算法在训练数据上过拟合。如果我们减少了正则化,模型会更好地拟合训练数据,就会增加 variance,降低 bias。

理想化的学习曲线和不可约化的误差

理想化的学习曲线应该是两条学习曲线都收敛至误差为0的时候,而实际上这是不可能的,这是由于不可约误差(irreducible error)的存在。在实际中最好的模型会收敛于某个不可约误差而不是理想的误差值0。在更多数技术性写作中,贝叶斯误差通常指的是分类器的可能最佳错误得分。这个概念和不可约误差是类似的。

机器学习算法笔记之9:偏差与方差、学习曲线

3. sklearn 中的学习曲线应用

from sklearn.model_selection import learning_curve

使用 learning_curve() 函数生成需要的数据来绘制学习曲线。函数会返回一个包含三个元素的元组:训练集大小、训练集和验证集上的误差得分。在这个函数内部,我们使用了以下参数:

  •  estimator-代表我们估计实际模型时所用的学习算法;
  •  X-包含特征的数据;
  •  y-包含目标的数据;
  •  train_sizes—所用的特定的训练集大小;
  •  cv-确定交叉验证分割策略(我们马上会讨论这个内容);
  •  scoring-代表所用的误差指标;我们使用 nearest proxy 和负 MSE,我们随后必须颠倒一下符号。