【机器学习】求解逻辑回归参数(梯度上升算法和牛顿法)

回顾

这篇博客【链接】我们简单介绍了逻辑回归模型,留下了一个问题:怎么求解使J(θ)最大的θ值呢?

J(θ)=i=1m(y(i)loghθ(x(i))+(1y(i))log(1hθ(x(i))))

前面我们提到了用梯度上升法和牛顿法。那么什么是梯度上升法和牛顿法呢?

梯度上升算法

由于J(θ)过于复杂,我们从一个简单的函数求极大值说起。
一元二次函数

f(x)=x2+4x

图像如下:
【机器学习】求解逻辑回归参数(梯度上升算法和牛顿法)

根据高中所学知识:
1. 求极值,先求函数的导数

f(x)=2x+4

2. 令导数为0,可求出x=2即取得函数f(x)的极大值。极大值等于f(2)=4

但是真实环境中的函数不会像上面这么简单,就算求出了函数的导数,也很难精确计算出函数的极值。此时我们就可以用迭代的方法来做。就像爬坡一样,一点一点逼近极值。这种寻找最佳拟合参数的方法,就是最优化算法。爬坡这个动作用数学公式表达即为:

xi+1=xi+αf(xi)xi

其中,α为步长,也就是学习速率,控制更新的幅度。效果如下图:
【机器学习】求解逻辑回归参数(梯度上升算法和牛顿法)

比如从(0,0)开始,迭代路径就是1->2->3->4->…->n,直到求出的x为函数极大值的近似值,停止迭代。
这一过程,就是梯度上升算法。那么同理,J(θ)这个函数的极值,也可以这么求解。公式可以写为:

θj:=θj+αJ(θ)θj

那么,我们现在只要求出J(θ)的偏导,就可以利用梯度上升算法求解J(θ)的极大值了。

J(θ)=i=1m{y(i)loghθ(x(i))+(1y(i))log(1hθ(x(i)))}

hθ(x)=g(θTx)=11+eθTx

令:
g(z)=11+ez

求导:
g(z)=ez(1+ez)2=11+ezez1+ez=11+ez(111+ez)=g(z)(1g(z))

可得:
g(θTx)=g(θTx)(1g(θTx))

J(θ)

J(θ)θj=i=1m(y(i)hθ(x(i))1y(i)1hθ(x(i)))hθ(x(i))θj

=i=1m(y(i)g(θTx(i))1y(i)1g(θTx(i)))g(θTx(i))θj

=i=1m(y(i)g(θTx(i))1y(i)1g(θTx(i)))g(θTx(i))(1g(θTx(i)))θTx(i)θj

其中:
θTx(i)θj=(θ1x1(i)+θ2x2(i)+θ3x3(i)+...+θnxn(i))θj=xj(i)

=i=1m{y(i)(1g(θTx(i)))(1y(i))(g(θTx(i))}xj(i)=i=1m(y(i)g(θTx(i)))xj(i)

综上:

θj:=θj+αi=1m(y(i)hθ(x(i)))xj(i)

θj:=θj+α(y(i)hθ(x(i)))xj(i)

牛顿法

同样,我们先来看个简单的例子。求函数值为0时的x的值。
用牛顿法迭代公式:

xn+1=xnf(xn)f(xn)xn+2=xn+1f(xn+1)f(xn+1)

【机器学习】求解逻辑回归参数(梯度上升算法和牛顿法)

这个迭代 公式的意思就是:在x=x1时,求得(x1,f(x1))的切线与x轴的交点为x2,再求(x2,f(x2))的切线与x轴的交点x3,依次迭代,直到找到满足要求的点。

然而,对于J(θ)我们需要求得一阶导数为0的点,那么牛顿法迭代公式可以更新为:

xn+1=xnJ(xn)J(xn)xn+2=xn+1J(xn+1)J(xn+1)

拓展

在多元的情况下,J(xn)=H(θ^) 海塞矩阵

H(θ^)=[2Jθ1θ12Jθ1θ22Jθ2θ12Jθ2θ2]

三阶海塞矩阵形式为:

H(θ^)=[2Jθ1θ12Jθ1θ22Jθ1θ32Jθ2θ12Jθ2θ22Jθ2θ32Jθ3θ12Jθ3θ22Jθ3θ3]

H(θ^)=[i=1nhθ(xi)(1hθ(xi))xi,1xi,1, i=1nhθ(xi)(1hθ(xi))xi,1xi,2, i=1nhθ(xi)(1hθ(xi))xi,1i=1nhθ(xi)(1hθ(xi))xi,2xi,1, i=1nhθ(xi)(1hθ(xi))xi,2xi,2, i=1nhθ(xi)(1hθ(xi))xi,2,i=1nhθ(xi)(1hθ(xi))xi,1, i=1nhθ(xi)(1hθ(xi))xi,2, i=1nhθ(xi)(1hθ(xi))]hθ(xi)=11+ezz=θ1xi,1+θ2xi,2+θ3

一阶导数

J=i=1n(yihθ(xi))xi,1i=1n(yihθ(xi))xi,2i=1n(yihθ(xi))

注:
此外,还可以用sklearn自带函数求解逻辑回归参数
此三种方法的python3代码实现,点击这里,对比本文公式看