《机器学习实战》学习笔记之k-近邻算法3

时间:2022-01-15 12:49:11

2.3 手写识别系统

从os模块中导入listdir函数,用来读取给定目录中的文件名

from os import listdir
关于zeros函数的使用,

《机器学习实战》学习笔记之k-近邻算法3

代码及注释

#image convert to vector
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr=fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect

def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('digits/trainingDigits')
m = len(trainingFileList)#训练样本的个数
trainingMat = zeros((m,1024))#创建训练矩阵,每行有1024个元素,表示一个训练样本
for i in range(m):
fileNameStr = trainingFileList[i]#第i个训练样本
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])#样本命名的第一个数字表示实际的分类
hwLabels.append(classNumStr)#得到训练集的所有分类
trainingMat[i,:] = img2vector('digits/trainingDigits/%s'%fileNameStr)#将所有样本转换成矩阵,得到训练样本集
testFileList = listdir('digits/testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr =fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])#同样方法得到一个测试样本的分类
vectorUnderTest = img2vector('digits/testDigits/%s'%fileNameStr)#将一个测试样本转成矩阵
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)#执行分类
print "the classifier came back with:%d, the real answer is:%d"%(classifierResult,classNumStr)
if(classifierResult != classNumStr):errorCount+=1.0
print "\nthe total number of errors is:%d"%errorCount
print "\nthe total error rate is:%f"%(errorCount/float(mTest))

终端结果

《机器学习实战》学习笔记之k-近邻算法3

k-近邻算法总结:摘自《机器学习实战》

简单的说,该算法采用测量不同特征值之间的距离方法进行分类,缺陷:

1. 必须保存全部数据集,如果训练数据集很大,必须使用大量的存储空间

2.必须对每个数据计算距离值,耗时大

3.无法给出任何数据的基础结构信息,无法知晓平均实例样本和典型实例样本具有什么特征