感知机(下)

  1. 感知机对偶形式(Python3.6)
    未经允许,禁止转载
#from __future__ import division 把下个版本的特性加入 不加也行因为已经是最新版本
import random
import numpy as np
import matplotlib.pyplot as plt
#**函数
def sign(v):
    if v<0:
        return -1
    else:
        return 1##即v>=0
###训练过程
def train(train_num,train_datas,lr):###lr是学习效率
    #权值,偏置初始值
    w = 0.0
    b = 0.0
    datas_len = len(train_datas)###训练集长度
    alpha = [0 for i in range(datas_len)]###更新
    train_array = np.array(train_datas)##变为矩阵形式才能进行运算
    Gram = np.matmul(train_array[:,0:-1],train_array[:,0:-1].T)##Gram矩阵
    for idx in range(train_num):
        tmp = 0
        i = random.randint(0,datas_len-1)
        yi=train_array[i,-1]###输出
        for j in range(datas_len):##1~N
            tmp+=alpha[j]*train_array[j,-1]*Gram[i,j]
        tmp+=b###
        if(yi*tmp<=0):
            alpha[i]=alpha[i]+lr
            b=b+lr*yi
    for i in range(datas_len):
        w+=alpha[i]*train_array[i,0:-1]*train_array[i,-1]##更新权值
        print(w)
    return w,b,alpha,Gram##返回权值,偏置,学习效率,Gram矩阵
###可视化
def plot_points(train_datas,w,b):
    plt.figure()
    x1 = np.linspace(0,8,100)
    x2 = (-b-w[0]*x1)/(w[1])###超平面
    plt.plot(x1,x2,color='r',label='y1 data')
    datas_len = len(train_datas)
    for i in range(datas_len):
        if(train_datas[i][-1]==1):
            plt.scatter(train_datas[i][0],train_datas[i][1],s=50)##画出正样本
        else:
            plt.scatter(train_datas[i][0],train_datas[i][1],marker='x',s=50)##画出正样本
    plt.show()
if __name__=='__main__':
    train_data1 = [[1,3,1],[2,2,1],[3,8,1],[2,6,1]]###正训练集
    train_data2 = [[2,1,-1],[4,1,-1],[6,2,-1],[7,3,-1]]###负训练集
    train_datas = train_data1+train_data2 ###训练集
    w,b,alpha,Gram = train(train_num=500,train_datas=train_datas,lr=0.01)###最终权值,偏置
    plot_points(train_datas,w,b)
    ###下面表格是w的更新动态
[0.01 0.03]
[0.05 0.07]
[0.05 0.07]
[0.07 0.13]
[0.05 0.12]
[0.01 0.11]
[0.01 0.11]
[-0.06  0.08]

感知机(下)

if name == 'main’的意思是:当.py文件被直接运行时,if name == 'main’之下的代码块将被运行;当.py文件以模块形式被导入时,if name == 'main’之下的代码块不被运行。

w,b
(array([-0.06,  0.08]), 0.01)

原理部分可参考李航老师的《统计学习方法》P33,以及《矩阵论》的部分知识。
有啥问题可在博客下方留言或者联系邮箱[email protected]

  • 补充:Gram矩阵
    感知机(下)