1. 学习环境
windows10 、Anaconda(向初学者推荐这个工具) 中的IDE工具Spyder 、python 3.7。
2. K-近邻算法概述
2.1. K-近邻算法工作原理
在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训 练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类。简单来说,k-近邻算法采 用测量不同特征值之间的距离方法进行分类。
2.2 K-近邻算法优缺点
优点:精度高、对异常值不敏感、无数据输入假定。
缺点:计算复杂度高、空间复杂度高。
适用数据范围:数值型和标称型。
2.3.K-近邻算法的一般流程
(1) 收集数据:可以使用任何合适的方法。
(2) 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
(3) 分析数据:可以使用任何可行的方法。
(4) 训练算法:此步骤不适用于K-近邻算法。
(5) 测试算法:计算错误率。
(6) 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行K-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。
3. 实施K-近邻算法
3.1 k-近邻算法伪代码:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的K个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类;
3.2 程序清单:k-近邻算法
使用欧式距离公式计算两个向量点A和B之间的距离:

def classify0(inX,dataSet,labels,k):
dataSetSize=dataSet.shape[0]
diffMat=tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5
sortedDistIndicies=distances.argsort()
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]]
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
sortedClassCount=sorted(classCount.items(), key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
4. k-近邻算法示例1
示例1:使用K-近邻算法改进约会网站的配对效果
4.1 基本流程:
(1) 收集数据:提供文本文件。
(2) 准备数据:使用python解析文本文件。
(3) 分析数据:使用Matplotlib画二维扩散图。
(4) 训练算法:此步骤不适用于K-近邻算法。
(5) 测试算法:使用海伦提供的部分数据作为测试样本。
测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
(6) 使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否为自己喜欢的类型。
4.2 具体实现
4.2.1 准备数据:从文本文件中解析数据
将待处理的数据改变为分类器可以接受的格式。该函数的输入为文件名字符串,输出为训练样本矩阵和标签向量
def file2matrix(filename):
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines)
returnMat = zeros((numberOfLines,3))
classLabelVector = []
index = 0
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
测试解析函数文件file2matrix( )
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
print(" datingDataMat \n " , datingDataMat," \n")
print(" datingLabels \n" , datingLabels[0:20])
4.2.2 分析数据:使用Matplotlib创建散点图
在 .py文件开头导入包
import matplotlib
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
plt.show()
散点图使用datingDataMat矩阵的第二、三列数据,分别表示特征值“玩视频游戏所耗时间比”

4.2.3 准备数据:归一化数值
方程中数字差值最大的属性对计算结果的影响最大,但这三种特征是同等重要的,因此作为三个等权重的特征之一,飞行常客里程数不应该如此严重地影响到计算结果。处理这种不同取值范围的特征时,我们采用的方法是将数值归一化,如将取值范围处理为 0 到 1 或者 -1 到 1 之间。下面公式可以将任意取值范围的特征值转化为 0 到 1 区间内的值:
其中min 和 max 分别是数据集中的最小特征值和最大特征值。
归一化特征值函数代码:
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1))
return normDataSet, ranges, minVals
normMat, ranges, minVals = autoNorm(datingDataMat)
print("normMat: \n", normMat,"\n ")
print("ranges: \n", ranges," \n ")
print("minVals: \n", minVals)
测试结果:

4.2.4 测试算法:作为完整程序验证分类器
测试代码:
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],\
datingLabels[numTestVecs:m], 20)
print("the classsifier came back with: %d, the real answer is: %d"\
%(classifierResult,datingLabels[i]))
if(classifierResult != datingLabels[i]): errorCount +=1.0
print("the total error rate is:%f" % (errorCount/float(numTestVecs)))
datingClassTest()
测试结果:

分类器处理约会数据集的错误率是 6%。
4.2.5 使用算法:构建完整可用系统
约会网站预测函数代码:
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input(\
"percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input ("liters of ice creamm consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
classifierResult = classify0((inArr-\
minVals)/ranges, normMat, datingLabels, 3)
print("You will probably like this person: ",\
resultList[classifierResult - 1])
classifyPerson()
测试结果:

5. k-近邻算法示例2
示例2:手写字识别系统
5.1 基本流程
(1) 收集数据:提供文本文件。
(2) 准备数据:编写函数img2vector(),将图像格式转化为分类器使用的向量格式。
(3) 分析数据:在python命令行中检查数据,确保它符合要求。
(4) 训练算法:此步骤不适用于K-近邻算法。
(5) 测试算法:编写函数使用提供的部分数据集作为测试样本。 测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
(6) 使用算法:本列没有此步骤,若你感兴趣你可以用此算法去完成 kaggle 上的 Digital Recognition(数字识别)题目。
5.2 具体实现
5.2.1 准备数据:将图像转化为测试向量
转化函数代码:
"""
手写数据集 准备数据:将图像转换为测试向量
"""
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
return returnVect
testVector = img2vector('testDigits/0_13.txt')
print(testVector[0,0:22])
测试结果:
5.2.2 测试算法:使用k-近邻算法识别手写数字
测试函数:
错误率为 1.2%。 - END -
关注微信公众号:迈微电子研发社,回复 “KNN” 获取本博客相关工程及数据文件[Github开源项目]。
△微信扫一扫关注「迈微电子研发社」公众号
知识星球:社群旨在分享AI算法岗的秋招/春招准备攻略(含刷题)、面经和内推机会、学习路线、知识题库等。
△扫码加入「迈微电子研发社」学习辅导群