梯度提升(Gradient Boosting)算法
- 我的个人微信公众号:Microstrong
微信公众号ID: MicrostrongAI
微信公众号介绍: Microstrong(小强)同学主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容,分享在学习过程中的读书笔记!期待您的关注,欢迎一起学习交流进步! - 我的知乎主页: https://www.zhihu.com/people/MicrostrongAI/activities
- Github: https://github.com/Microstrong0305
- 个人博客: https://blog.****.net/program_developer
- 本文首发在我的微信公众号里,地址:https://mp.weixin.qq.com/s/Ods1PHhYyjkRA8bS16OfCg,如有公式和图片不清楚,可以在我的微信公众号里阅读。
1. 引言
提升树利用加法模型与前向分歩算法实现学习的优化过程。当损失函数是平方误差损失函数和指数损失函数时,每一步优化是很简单的。但对一般损失函数而言,往往每一步优化并不那么容易。针对这一问题,Freidman提出了梯度提升(gradient boosting)算法。Gradient Boosting是Boosting中的一大类算法,它的思想借鉴于梯度下降法,其基本原理是根据当前模型损失函数的负梯度信息来训练新加入的弱分类器,然后将训练好的弱分类器以累加的形式结合到现有模型中。采用决策树作为弱分类器的Gradient Boosting算法被称为GBDT,有时又被称为MART(Multiple Additive Regression Tree)。GBDT中使用的决策树通常为CART。
2. 梯度下降法
在机器学习任务中,需要最小化损失函数,其中 是要求解的模型参数。梯度下降法通常用来求解这种无约束最优化问题,它是一种迭代方法:选取初值 ,不断迭代,更新 的值,进行损失函数的极小化。这里我们还需要初始化算法终止距离 以及步长 。
使用梯度下降法求解的基本步骤为:
(1)确定当前位置的损失函数的梯度,对于 ,其梯度表达式如下:
(2)用步长 乘以损失函数的梯度,得到当前位置下降的距离,即:
(3)确定是否 梯度下降的距离小于 ,如果小于 则算法终止,当前的即为最终结果。否则进入步骤(4)。
(4)更新 ,其更新表达式如下。更新完毕后继续转入步骤(1)。
我们也可以用泰勒公式表示损失函数,用更数学的方式解释梯度下降法:
- 迭代公式:
- 将 在 处进行一阶泰勒展开:
- 要使得 ,可取: ,则:
这里多说一点,我为什么要用泰勒公式推导梯度下降法,是因为我们在面试中经常会被问到GBDT与XGBoost的区别和联系?其中一个重要的回答就是:GBDT在模型训练时只使用了代价函数的一阶导数信息,XGBoost对代价函数进行二阶泰勒展开,可以同时使用一阶和二阶导数。当然,GBDT和XGBoost还有许多其它的区别与联系,感兴趣的同学可以自己查阅一些相关的资料。
补充:泰勒公式知识点
3. 梯度提升算法
在梯度下降法中,我们可以看出,对于最终的最优解 ,是由初始值经过次迭代之后得到的,这里设 ,则为:
其中, 表示 在 处泰勒展开式的一阶导数。
在函数空间中,我们也可以借鉴梯度下降的思想,进行最优函数的搜索。对于模型的损失函数 ,为了能够求解出最优的函数 ,首先设置初始值为: 。
以函数 作为一个整体,与梯度下降法的更新过程一致,假设经过次迭代得到最优的函数为:
其中, 为:
可以看到,这里的梯度变量是一个函数,是在函数空间上求解,而我们以前梯度下降算法是在多维参数空间中的负梯度方向,变量是参数。为什么是多维参数,因为一个机器学习模型中可以存在多个参数。而这里的变量是函数,更新函数通过当前函数的负梯度方向来修正模型,使模型更优,最后累加的模型为近似最优函数。
总结:Gradient Boosting算法在每一轮迭代中,首先计算出当前模型在所有样本上的负梯度,然后以该值为目标训练一个新的弱分类器进行拟合并计算出该弱分类器的权重,最终实现对模型的更新。
4. 梯度提升原理推导
在梯度提升的 步中,假设已经有一些不完美的模型 (最初可以使用非常弱的模型,它只是预测输出训练集的平均值)。梯度提升算法不改变 ,而是通过增加估计器 构建新的模型 来提高整体模型的效果。那么问题来了,如何寻找 函数呢?梯度提升方法的解决办法是认为最好的 应该使得:
或者等价于:
因此,梯度提升算法将 与残差 拟合。与其他boosting算法的变体一样, 修正它的前身 。我们观察到残差 是损失函数 的负梯度方向,因此可以将其推广到其他不是平方误差(分类或是排序问题)的损失函数。也就是说,梯度提升算法是一种梯度下降算法,只需要更改损失函数和梯度就能将其推广。
当采用平方误差损失函数时, ,其损失函数变为:
其中, 是当前模型拟合数据的残差(residual)。在使用更一般的损失函数时,我们使用损失函数的负梯度在当前模型的值 作为提升树算法中残差的近似值,拟合一个梯度提升模型。当使用一般的损失函数时,为什么会出现上式的结果呢?下面我们就来详细阐述。
我们知道,对函数 在 处的泰勒展示式为:
因此,损失函数 处的泰勒展开式就是:
将 带入上式,可得:
因此,应该对应于平方误差损失函数中的 ,这也是我们为什么说对于平方损失函数拟合的是残差;对于一般损失函数,拟合的就是残差的近似值。
5. 对梯度提升算法的若干思考
(1)梯度提升与梯度下降的区别和联系是什么?
GBDT使用梯度提升算法作为训练方法,而在逻辑回归或者神经网络的训练过程中往往采用梯度下降作为训练方法,二者之间有什么联系和区别呢?
下表是梯度提升算法和梯度下降算法的对比情况。可以发现,两者都是在每一轮迭代中,利用损失函数相对于模型的负梯度方向的信息来对当前模型进行更新,只不过在梯度下降中,模型是以参数化形式表示,从而模型的更新等价于参数的更新。而在梯度提升中,模型并不需要进行参数化表示,而是直接定义在函数空间中,从而大大扩展了可以使用的模型种类。
(2)梯度提升和提升树算法的区别和联系?
提升树利用加法模型与前向分歩算法实现学习的优化过程。当损失函数是平方误差损失函数和指数损失函数时,每一步优化是很简单的。但对一般损失函数而言,往往每一步优化并不那么容易。针对这一问题,Freidman提出了梯度提升(gradient boosting)算法。这是利用损失函数的负梯度在当前模型的值 作为提升树算法中残差的近似值,拟合一个梯度提升模型。
(3)梯度提升和GBDT的区别和联系?
- 采用决策树作为弱分类器的Gradient Boosting算法被称为GBDT,有时又被称为MART(Multiple Additive Regression Tree)。GBDT中使用的决策树通常为CART。
- GBDT使用梯度提升(Gradient Boosting)作为训练方法。
(4)梯度提升算法包含哪些算法?
Gradient Boosting是Boosting中的一大类算法,其中包括:GBDT(Gradient Boosting Decision Tree)、XGBoost(eXtreme Gradient Boosting)、LightGBM (Light Gradient Boosting Machine)和CatBoost(Categorical Boosting)等。
(5)对于一般损失函数而言,为什么可以利用损失函数的负梯度在当前模型的值作为梯度提升算法中残差的近似值呢?
我们观察到在提升树算法中,残差 是损失函数 的负梯度方向,因此可以将其推广到其他不是平方误差(分类或是排序问题)的损失函数。也就是说,梯度提升算法是一种梯度下降算法,不同之处在于更改损失函数和求其负梯度就能将其推广。即,可以将结论推广为对于一般损失函数也可以利用损失函数的负梯度近似拟合残差。
6. 总结
我之前已经写过关于Regression Tree 回归树、深入理解提升树(Boosting Tree)算法的文章。回归树可以利用集成学习中的Boosting框架改良升级得到提升树,提升树再经过梯度提升算法改造就可以得到GBDT算法,GBDT再进一步可以升级为XGBoost、LightGBM或者CatBoost。在学习这些模型的时候,我们把它们前后连接起来,就能更加系统的理解这些模型的区别与联系。
7. Reference
【1】《统计学习方法》,李航著。
【2】《百面机器学习》,诸葛越主编、葫芦娃著。
【3】机器学习算法中 GBDT 和 XGBOOST 的区别有哪些? - wepon的回答 - 知乎 https://www.zhihu.com/question/41354392/answer/98658997
【4】GBDT算法原理与系统设计简介,地址:http://wepon.me/files/gbdt.pdf
【5】Boosting方法-从AdaBoost到LightGBM!,地址:https://mp.weixin.qq.com/s/iY_Sfyxe2VdW9FJCKNOBSw
【6】梯度提升(Gradient Boosting)算法,地址:http://www.360doc.com/content/19/0713/18/1353678_848501530.shtml
【7】梯度提升算法 介绍,地址:https://blog.****.net/wutao1530663/article/details/71235727
【8】梯度提升算法, 地址:https://blog.****.net/u013263891/article/details/97893803