机器学习实验(二):梯度下降和牛顿法解逻辑回归

一:梯度下降法

  1. 加载实验数据画出画出散点图,区分出能进入大学的学生和不能进入大学的学生。
  2. 回顾逻辑回归,其假设函数为

           hθx=gθTx=11+e-θTx=p(y=1|x;θ)机器学习实验(二):梯度下降和牛顿法解逻辑回归

在使用前先声明此函数,代码如下:

                g=inline(‘1.0./(1.0+exp(-z)))

  1. 当我们给定一个训练数据集时,x(i)i=1,2,…,m机器学习实验(二):梯度下降和牛顿法解逻辑回归 我们定义似然函数为:

                     Jθ=i=1m(hθ(x(i)))y(i)(1-hθ(x(i)))1-y(i)机器学习实验(二):梯度下降和牛顿法解逻辑回归

      

       为了简化计算,我们转变成下列的似然函数:

                Lθ=1mi=1my(i)loghθxi+(1-yi)log⁡(hθ(x(i)))机器学习实验(二):梯度下降和牛顿法解逻辑回归  

       并且L的梯度可以被定义为:

                       ∇θL=1mxi=1m(hθxi-y(i))x(i)机器学习实验(二):梯度下降和牛顿法解逻辑回归

  1. 每次更新θ:机器学习实验(二):梯度下降和牛顿法解逻辑回归

                            θ←θ-α∇θL机器学习实验(二):梯度下降和牛顿法解逻辑回归

      

       直到似然函数的差值变得很小,即:

                             L+θ-Lθ<ε机器学习实验(二):梯度下降和牛顿法解逻辑回归

 二:牛顿法

 (1)牛顿法求解逻辑回归:θ机器学习实验(二):梯度下降和牛顿法解逻辑回归 的更新变为

                              θ←θ-H-1θL机器学习实验(二):梯度下降和牛顿法解逻辑回归

      逻辑回归中,海塞矩阵的求法为:

                      H=1mi=1mhθ(xi)(1-hθ(xi))xiT(x(i))机器学习实验(二):梯度下降和牛顿法解逻辑回归

 (2)其余的与上面的方法一样。     

三。实验结果:

梯度下降:学习率为0.0024,成本函数的变化情况:

机器学习实验(二):梯度下降和牛顿法解逻辑回归

求出的决策边界:

机器学习实验(二):梯度下降和牛顿法解逻辑回归

牛顿法:成本函数的变化情况(跟之前相比加负号)。

机器学习实验(二):梯度下降和牛顿法解逻辑回归

牛顿法的划分结果:

机器学习实验(二):梯度下降和牛顿法解逻辑回归

四、相关代码:(MATLAB)

%梯度下降法:
clear,clc;
x=load('ex2x.dat');
y=load('ex2y.dat');
m=length(y);
one=ones(m,1);
x=[one,x];
%x(:,3)=(x(:,3)-mean(x(:,3)))./std(x(:,3));
%x(:,2)=(x(:,2)-mean(x(:,2)))./std(x(:,2));
pos = find (y == 1); neg = find (y == 0);
plot (x(pos , 2) , x(pos ,3) , '+' );
hold on ;
plot (x(neg , 2) , x(neg , 3) ,'ro ' );
g = inline ( '1.0 ./ (1.0 + exp(-z ))'); 
c0=0;c1=0;c2=0;
c=[c0,c1,c2];
step=0.0024 ; %ÉèÖò½³¤,³õʼֵ,µü´ú¼Ç¼Êý  
L_val_current=0;
L_val_previous=100;
  min=9999999999;
  j=1; temp=100;
 while abs(L_val_current-L_val_previous)>0.000001
    %for i=1:m
     % r1=r1+(g(c*x(i,:)'))-y(i,:);
      %r2=r2+(g(c*x(i,:)')-y(i,:))*x(i,2);
      %r3=r3+(g(c*x(i,:)')-y(i,:))*x(i,3);
    %end
    r1=(g(x*c')-y)'*x(:,1);
    r2=(g(x*c')-y)'*x(:,2);
    r3=(g(x*c')-y)'*x(:,3);
    c0=c0-step*r1/m;
    c1=c1-step*r2/m;
    c2=c2-step*r3/m;
    c=[c0,c1,c2];
    cc(j,:)=c;
    %for k=1:m
     %  L_val_current=L_val_current+(y(k,:)*log(g(x(k,:)*c'))+(1-y(k,:))*log(1-g(x(k,:)*c')))/m;
    %end
    L_val_current=(-(log(g(x*c')))'*y-(log(one-g(x*c')))'*(ones(m,1)-y))/m;
        L_val_previous=temp;
        f(j,:)=L_val_current;
        temp=L_val_current;
  %  j_values(j,:)=fx;
   % if(L_val_current<min)
    %    min=L_val_current;
     %   cmin=[c0,c1,c2];
      %  num=j;
  
    j=j+1; 
 end 
%a=[-2.5:0.1:2.5];
a=[10:1:70];
plot(a,-c1/c2*a-c0/c2,'r-')
xlabel('Exam 1 score');
ylabel('Exam 2 score');
legend('admitted','Not admitted','decision boundary');
figure ;
plot(1:j-1,f(:,1),'r-');


%牛顿法:
clear,clc;
x=load('ex2x.dat');
y=load('ex2y.dat');
m=length(y);
x=[ones(m,1),x];
%x(:,3)=(x(:,3)-mean(x(:,3)))./std(x(:,3));
%x(:,2)=(x(:,2)-mean(x(:,2)))./std(x(:,2));
pos = find (y == 1); neg = find (y == 0);
plot (x(pos , 2) , x(pos ,3) , '+' );
hold on ;
plot (x(neg , 2) , x(neg , 3) ,'ro ' );
 
g = inline ( '1.0 ./ (1.0 + exp(-z ))'); 
c=[0,0,0];
step=0.009; %ÉèÖò½³¤,³õʼֵ,µü´ú¼Ç¼Êý  
L_val_current=0;
L_val_previous=100;
  min=9999999999;
  j=1; temp=100;
 while abs(L_val_current-L_val_previous)>0.000001
     L_val_current=0;
    r1=0;r2=0;r3=0;
    H=zeros(3,3);
    for i=1:m
      r1=r1+((g(c*x(i,:)'))-y(i,:));
      r2=r2+((g(c*x(i,:)')-y(i,:))*x(i,2));
      r3=r3+((g(c*x(i,:)')-y(i,:))*x(i,3));
      H=H+(g(c*x(i,:)')*(1-g(c*x(i,:)'))*x(i,:)'*x(i,:))/m;
    end
    c=c-[r1,r2,r3]/H/m;
    cc(j,:)=c;
    for k=1:m
        L_val_current=L_val_current+(y(k,:)*log(g(x(k,:)*c'))+(1-y(k,:))*log(1-g(x(k,:)*c')))/m;
    end
        L_val_previous=temp;
        f(j,:)=L_val_current;
        temp=L_val_current;
        
  %  j_values(j,:)=fx;
    if(L_val_current<min)
        min=L_val_current;
        cmin=[c(1),c(2),c(3)];
        num=j;
    end
    j=j+1; 
 end 
a=[10:0.1:70];
plot(a,-c(:,2)/c(:,3)*a-c(:,1)/c(:,3),'r-')
xlabel('Exam 1 score');
ylabel('Exam 2 score');
legend('admitted','Not admitted','decision boundary');
figure ;
plot(1:j-1,f(:,1),'r-');