KNN的简单实现

时间:2022-11-24 21:24:54

前言

KNN算法是一种分类算法,一般的,我们知道“物以类聚”,在一个已经分类好的样本中,每个样本点都有一个对应的类别,KNN算法所要做的,就是对于一个未知的样本点,给出它的类别。

算法

KNN算法的思想很简单:
1. 计算给出的样本点到各个训练点的欧式距离;
2. 按照升序排序,取前k个训练点的类别;
3. 取类别中出现频率最高的作为这个样本点的类别;

表1是已知类别的电影与未知电影G的距离(这里取欧氏距离),如果k=3的话,那么与G距离最近的三个电影为B、C、A。而这三个类别全是爱情片,所以可以判断G的类别为爱情片。

电影名称 打斗镜头 接吻镜头 电影类型
A 3 104 爱情片
B 2 100 爱情片
C 1 81 爱情片
D 101 10 动作片
E 99 5 动作片
F 98 2 动作片
G 18 90 未知

表1 电影分类数据

电影名称 与电影G的距离
A 20.5
B 18.7
C 19.2
D 115.3
E 117.4
F 118.9

表1-2. 已知电影与未知电影G的距离

实现

根据我们的算法,可以编制核心函数myKNN(test,data,label,k),使用matlab进行判别,代码很简单:

%KNN算法简单实现
function KNN
clc;clear;
train_data=load('movie_train.txt'); %读取训练数据
test_data=load('movie_test.txt');   %读取测试数据
train_label=train_data(:,3);        %提取训练数据类别
k=4;
m=size(test_data,1);                %测试数据的大小
data_size=size(train_data,1);       %训练数据的大小
%对于每一条数据使用KNN进行判别
result_m=[];                        %结果矩阵
for i=1:m
    result_m(i)=myKNN(test_data(i,:),train_data(:,1:2),train_label,k)
end

%%
%KNN算法核心
function result=myKNN(test,data,label,k)
    test=repmat(test,data_size,1);  %平铺测试数据,为计算距离做准备
    diff=test-data;                 %计算差值
    diff=sqrt(sum(diff.^2,2));      %计算欧式距离
    [re,loc]=sort(diff,'ascend');   %排序,loc为位置
    result=mode(label(loc(1:k)));        %根据loc位置找到对应的k个结果,取频率最高的类别
end
end
该算法的不足点在于样本不均衡的时候,k个值中大多数都是大样本的点,容易误判。