KNN算法——python实现(参考《机器学习实战》)
一、k-近邻算法概述 (KNN)
简单地说,k-近邻算法采用测量不同特征值之间的距离方法进行分类。
它的工作原理是:存在一个样本数 据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据 与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的 特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们 只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。 最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
二、k-近邻算法的一般流程
(1) 收集数据:可以使用任何方法。
(2) 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
(3) 分析数据:可以使用任何方法。
(4) 训练算法:此步骤不适用于k-近邻算法。
(5) 测试算法:计算错误率。
(6) 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输 入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。
三、从文本文件中解析数据
伪代码如下:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。
四、距离计算
五、完整代码
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import operator
"""k近邻算法步骤:对未知类别属性的数据集的每个点依次执行以下操作
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3)选取与当前点距离最小的k个点;
(4)确定前k个点所在类别的出现频率;
(5)返回前k个点所出现频率最高的类别作为当前点的预测分类
"""
'''
函数功能:创建数据集
Input: 无
Output: group:数据集
labels:类别标签
'''
def createDataSet():#创建数据集
group = np.array([[3,104],[2,100],[99,5],[98,2]])
labels = ['爱情片','爱情片','动作片','动作片']
return group, labels
"""
函数功能:KNN分类
输入数据参数:测试集:inX(1xN)
已知数据的特征:dataset(NxM)
已知数据的标签或者类别:labels(1xM vector)
k 近邻算法的 k,一般是小于20
输出:测试样本最可能属于哪类标签
"""
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] # shape[0]返回dataSet的行数
#距离计算
"""tile(inX,(a,b))函数将inX重复a行,重复b列"""
"""tile:numpy中的函数。tile将原来的一个数组,扩充成了4个一样的数组。diffMat得到了目标与训练数值之间的差值。"""
diffMat = np.tile(inX, (dataSetSize,1)) - dataSet # tile(inX,(a,b))函数将inX重复a行,重复b列
sqDiffMat = diffMat**2 #作差后平方
sqDistances = sqDiffMat.sum(axis=1)#sum()求和函数,sum(0)每列所有元素相加,sum(1)每行所有元素相加
distances = sqDistances**0.5 #开平方,求欧式距离
sortedDistIndicies = distances.argsort() #argsort函数返回的是数组值从小到大的索引值
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]#取出前k个距离对应的标签
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
#计算每个类别的样本数。字典get()函数返回指定键的值,如果值不在字典中返回默认值0
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
#reverse降序排列字典
#python2版本中的iteritems()换成python3的items()
#key=operator.itemgetter(1)按照字典的值(value)进行排序
#key=operator.itemgetter(0)按照字典的键(key)进行排序
return sortedClassCount[0][0] #返回字典的第一条的key,也即是测试样本所属类别
'''
函数功能: 主函数
'''
if __name__ == '__main__':
group,labels = createDataSet()#创建数据集
print('group:\n',group)#打印数据集
print('labels:',labels)
zhfont = matplotlib.font_manager.FontProperties(fname=r'c:\windows\fonts\simsun.ttc')#设置中文字体路径
fig = plt.figure(figsize=(10,8))#可视化
ax = plt.subplot(111) #图片在第一行,第一列的第一个位置
ax.scatter(group[0:2,0],group[0:2,1],color='red',s=50)
ax.scatter(group[2:4,0],group[2:4,1],color='blue',s=50)
ax.scatter(18,90,color='orange',s=50)
plt.annotate('which class?', xy=(18, 90), xytext=(3, 2),arrowprops=dict(facecolor='black', shrink=0.05),)
plt.xlabel('打斗镜头',fontproperties=zhfont)
plt.ylabel('接吻镜头',fontproperties=zhfont)
plt.title('电影分类可视化',fontproperties=zhfont)
plt.show()
testclass = classify0([18,90], group, labels, 3)#用未知的样本来测试算法
print('测试结果:',testclass)#打印测试结果
六、结果显示
可视化结果:
参考:https://blog.****.net/u013829973/article/details/77942942感谢这位博主的博文!