K-近邻算法的思想如下:首先,计算新样本与训练样本之间的距离,找到距离最近的K 个邻居;然后,根据这些邻居所属的类别来判定新样本的类别,如果它们都属于同一个类别,那么新样本也属于这个类;否则,对每个后选类别进行评分,按照某种规则确定新样本的类别。(统计出现的频率)
该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分当K值较小时可能产生过拟合,因为训练误差很小,但是测试误差可能很大;相反,当K值较大时可能产生欠拟合。
算法伪代码
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别的数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的K个点;
(4) 确定前K个点所在类别的出现频率;
(5) 返回前K个点出现频率最高的类别作为当前点的预测分类。
%
%手写数字识别系统的测试代码
%
function handWritingTest()
tic; %开始计时
K = 3; % 这里可以调整k值
trainLabels = [];
direct = mfilename('fullpath');%
traindirect = strrep(direct,'handWritingTest','trainingDigits'); %trainingDigits
%获得路径
traindirfile = dir(fullfile(traindirect,'*.txt'));%提取后缀名.txt
traindircell = struct2cell(traindirfile)';
trainfilenames = traindircell(:,1);
trainfileNums = length(trainfilenames);
trainMat = zeros(trainfileNums,1024);
for i = 1:trainfileNums
fileNameStr = trainfilenames(i);
str = deblank(fileNameStr);
s = regexp(str,'\.','split'); %
fileStr = s{1}(1);
classNumStr = regexp(fileStr,'\_','split');
trainLabels(i)=str2num(char(classNumStr{1}(1))); %得到类别 0 - 9
filePath = strcat(traindirect,'\',fileNameStr); %文件路径
trainMat(i,:) = img2vector(filePath);%处理文件 获得向量
end
%测试样本
direct = mfilename('fullpath');
testdirect = strrep(direct,'handWritingTest','testDigits');%testDigits
testdirfile = dir(fullfile(testdirect,'*.txt'));
testdircell = struct2cell(testdirfile)';
testfilenames = testdircell(:,1);
testfileNums = length(testfilenames);
errorcount = 0;
for j = 1:testfileNums
fileNameStr = testfilenames(j);
str = deblank(fileNameStr);
s = regexp(str,'\.','split');
fileStr = s{1}(1);
classNumStr = regexp(fileStr,'\_','split');
testLabel = str2num(char(classNumStr{1}(1))); %得到类别 0 - 9
filePath = strcat(testdirect,'\',fileNameStr);
testVector = img2vector(filePath);
classifyRet = classify(testVector,trainMat,trainLabels,K);
if(classifyRet ~= testLabel)
errorcount = errorcount + 1;
fprintf('test result: %d, real result: %d , here error!!! \n',classifyRet,testLabel);
else
fprintf('test result: %d, real result: %d \n',classifyRet,testLabel);
end
end
lastTime = num2str(toc);
fprintf('\n the sum numbers of errors : %d ',errorcount);
fprintf('\n the total error rate : %f ' ,(errorcount / testfileNums));
fprintf('\n total time : %f',lastTime);
end
%
%KNN算法 classify(test,dataSet,labels,k)
%四个参数:test用于分类的输入向量;输入的训练样本集为dataSet;
%标签向量为labels; k 表示用于选择最近邻居的数目;
%
function maxClass = classify(test,dataSet,labels,k)
[dataRow,dataCol] = size(dataSet);%dataRow:样本个数;dataCol:特征
%求距离 test 与样本数据之间的距离 这里为欧式距离
diffMat = dataSet;
for i = 1:dataRow
diffMat(i,:) = diffMat(i,:) - test;
end
sqdiffMat = diffMat.^2;
sqDistances = sum(sqdiffMat,2).^(0.5);
[p,q] = sort(sqDistances); %p代表要排序的数,q代表要排序的数原来对应的索引
%通过k 来求最邻居的前k 个数据,然后找的在这些数据中类别最多的
classCount=zeros(10,1);
class = [];
for j = 1:k
tempLabel = labels(q(j));
class(j) = tempLabel;%没用到
classCount(tempLabel+1) = classCount(tempLabel+1)+1;
end
[r,s] = max(classCount);
maxClass = s - 1; %返回 相似个数最多的 那个类
end
%
%将32*32的二进制图形矩阵转换为1*1024的向量
%
function retVector = img2vector(fileName)
fileName = char(fileName);
tempVector = [];
% 读文件
fileData = textread(fileName,'%s');
fileData = char(fileData);%读取文件,并将文件转换矩阵的格式
temp = fileData(:)';
for i = 1 : length(temp)
tempVector(i) = str2num(temp(i));
end
retVector = tempVector;
end