机器学习之最近邻算法
机器学习之最近邻算法
小白,最近在学习Python的机器学习课程,最近邻算法,对应中国的古话:近朱者赤近墨者黑
举个例子,你想知道某个人A的职业,你可以通过他身边朋友的职业来猜测,比如他有n个朋友,一个做设计的铁哥们B,一个做房产中介的高中同学C,还有两个玩的比较好的是做码农的D和E。我们来量化一下几人的关系AB<AC<AD,AE<A~others
k近邻算法中的k值的选择,如果我们选k=2,k=3或者k=4时,得出1:1:1,选取最近的B,那么我们猜测A的职业为设计师,当我们选择k=5时,1:1:2,我们猜测A的职业为码农
图片:
下面是看的课程的代码示例,Python版本为3.6版本
import numpy as np
import operator
# inX 要检测的数据
# dataSet 数据集
# labels 结果集
# k 选择距离最小的K个点
def classify0(inX,dataSet,labels,k):
# 计算矩阵的行数
dataSetSize = dataSet.shape[0]
# 第一个维度重复1次,第二个维度重复dataSetSize次
diffMat = np.tile(inX,(dataSetSize,1))-dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis = 1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
# argsort()函数返回的是distances元素从小到大排列后相应元素的索引。如a=array([2,1,5,3]),a.argsort() 的结果为:[1,0,3,2]
classCount = {} # 分类标签字典 标签:标签出现次数
for i in range(k):
# 选出k个距离最近的数据
voteILabel = labels[sortedDistIndicies[i]]
# 字典的get(key,default)方法 返回字典中key对应的值,若key在字典中不存在,则返回default的值
classCount[voteILabel] = classCount.get(voteILabel,0)+1
# 排序 sort() 在本地排序,不返回副本;sorted() 返回副本,原始输入不变
# sorted(iterable, key=None, reverse=False)
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]
# 用来从文件中加载数据
def file2matrix(filename):
# 打开文件并且获取数据
fr = open(filename)
# 读取文件所有内容,得到文件行数
numberOfLines = len(fr.readlines())
# 准备矩阵 numberOfLines行,3列
returnMat = np.zeros((numberOfLines,3))
# 准备结果标签
classLabelVector = []
# 转换成numpy的数组格式
returnMat = np.array(returnMat)
# 再一次打开文件
fr = open(filename)
# 索引行号
index = 0
# print(fr)
for line in fr.readlines():
# 截取所有回车字符
line = line.strip()
listFromLine = line.split('\t') # split()将一个字符串分裂成多个字符串组成的列表
returnMat[index,:] = listFromLine[0:3] # 将数据前三列提取出来
classLabelVector.append(int(listFromLine[-1])) # 将数据最后一列标签提取出来
index += 1
# 返回矩阵和标签
return returnMat,classLabelVector
# 归一化
def autoNorm(dataSet):
minVals = dataSet.min(0) # 参数0使得函数可以从列中选取最小值,而不是选取当前行的最小值
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = np.zeros(np.shape(dataSet))
m = dataSet.shape[0]
# for i in range(1,m):
# normDataSet[i,:] = (dataSet[i,:] - minVals)/ranges
normDataSet = dataSet - np.tile(minVals,(m,1))
normDataSet = normDataSet/np.tile(ranges,(m,1))
return normDataSet,ranges,minVals
# 构造dating的分类函数
def datingClassify():
hoRatio = 0.50 # 取50%的数据作为测试集 50%的数据作为训练集
datingDataMat,datingLabels = file2matrix(r'G:\work\python\datingTestSet2.txt')
normMat,ranges,minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio) # 测试集数据量
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print("this classifier came back with : %d, this real answer is : %d"%(classifierResult,datingLabels[i]))
if (classifierResult != datingLabels[i]):
errorCount += 1.0
print("this total error rate is : %f"%(errorCount/float(numTestVecs)))
print(errorCount)
补充知识点:
shape函数是numpy.core.fromnumeric中的函数,它的功能是查看矩阵或者数组的维数。
numpy.zeros(创建0矩阵) 返回给定形状和类型的新数组,用零填充。
numpy.zeros(shape, dtype=float, order=’C’)
shape:int或ints序列,新数组的形状,例如(2,3 )或2。
dtype:数据类型,可选数组的所需数据类型,例如numpy.int8。默认值为 numpy.float64。
order:{'C','F'},可选是否在存储器中以C或Fortran连续(按行或列方式)存储多维数据。
numpy.tile()是个什么函数呢,说白了,就是把数组沿各个方向复制
比如 a = np.array([0,1,2]),np.tile(a,(2,1))就是把a先沿x轴(就这样称呼吧)复制1倍,
即没有复制,仍然是 [0,1,2]。 再把结果沿y方向复制2倍,即最终得到
array([[0,1,2],
[0,1,2]])
sorted(iterable, key=None, reverse=False)
iteritems()返回的是一个迭代器
reverse是一个布尔值。如果设置为True,列表元素将被倒序排列,默认为False
key接受一个函数,这个函数只接受一个元素,默认为None
key,用来进行比较的元素,只有一个参数,参数取自可迭代对象,指定可迭代对象中的一个元素来进行排序。
itemgetter()用于获取对象指定维的数据,参数为序号
示例:sorted([36, 5, 12, 9, 21], reverse=True)就可以实现倒序
.readlines() 自动将文件内容分析成一个行的列表,该列表可以由 Python 的 for ... in ... 结构进行处理。另一方面,
.readline() 每次只读取一行,通常比 .readlines() 慢得多。仅当没有足够内存可以一次读取整个文件时,才应该使用 .readline()
strip() 方法用于移除字符串头尾指定的字符(默认为空格)或字符序列。
注意:该方法只能删除开头或是结尾的字符,不能删除中间部分的字符。
s.strip(rm) s为字符串,rm为要删除的字符序列。 删除s字符串中开头结尾处rm字符。rm为空时,默认删除空白符(‘\n’,‘\r’,‘\t’‘ ’)
如a=' 123',a='\t123',a='123\r\n' a.strip()的结果都为123 s.lstrip(rm) 删除s字符串中开头处rm字符 ;
s.rstrip(rm) 删除s字符串中结尾处rm字符
split()通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则仅分隔 num+1 个子字符串
str.split(str="", num=string.count(str))
str -- 分隔符,默认为所有的空字符,包括空格、换行(\n)、制表符(\t)等。
num -- 分割次数。默认为 -1, 即分隔所有。
这是第一次用这个编辑器,还很不习惯,
上面的代码里面归一化中的:
normDataSet = dataSet - np.tile(minVals,(m,1))
normDataSet = normDataSet/np.tile(ranges,(m,1))
这两行代码有点头疼,求高手帮忙指点一下,我比较倾向于我注释掉的两行代码