机器学习之KNN算法思想及其实现

时间:2023-03-09 00:32:34
机器学习之KNN算法思想及其实现

从一个例子来直观感受KNN思想

如下图 , 绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

                    机器学习之KNN算法思想及其实现

从这个例子中,我们再来看KNN思想:

, 计算已知类别数据集合中的点与当前点之间的距离(使用欧式距离公司: d =sqrt(pow(x-x1),)+pow(y-y1),)

, 按照距离递增次序排序(由近到远)

, 选取与当前点距离最小的的K个点(如上题中的 k=,k=)

, 确定前K个点所在类别的出现频率

, 将频率最高的那组,作为该点的预测分类

实现代码:

 package com.data.knn;

 /**
* *********************************************************
* <p/>
* Author: XiJun.Gong
* Date: 2016-09-06 12:02
* Version: default 1.0.0
* Class description:
* <p/>
* *********************************************************
*/
public class Point { private double x; //x坐标
private double y; //y坐标
private double dist; //距离另一个点的距离 private String label; //所属类别 public Point() {
this(0d, 0d, "");
} public Point(double x, double y, String label) {
this.x = x;
this.y = y;
this.label = label;
} /*计算两点之间的距离*/
public double distance(final Point a) {
return Math.sqrt((a.x - x) * (a.x - x) + (a.y - y) * (a.y - y));
} public double getX() {
return x;
} public void setX(double x) {
this.x = x;
} public double getY() {
return y;
} public void setY(double y) {
this.y = y;
} public String getLabel() {
return label;
} public void setLabel(String label) {
this.label = label;
} public double getDist() {
return dist;
} public void setDist(double dist) {
this.dist = dist;
}
}

KNN实现

 package com.data.knn;

 import com.google.common.base.Preconditions;
import com.google.common.collect.Maps; import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map; /**
* *********************************************************
* <p/>
* Author: XiJun.Gong
* Date: 2016-09-06 11:59
* Version: default 1.0.0
* Class description:
* <p/>
* *********************************************************
*/
public class knn { private List<Point> dataSet; //统计频率
private Point newPoint; //当前点 //进行KNN分类
public String classify(List<Point> dataSet, final Point newPoint, Integer K) { Preconditions.checkArgument(K < dataSet.size(), "K的值超过了dataSet的元素");
//求解每一个点到新的点的距离
for (Point point : dataSet) {
point.setDist(newPoint.distance(point));
}
//进行排序
Collections.sort(dataSet, new Comparator<Point>() {
@Override
public int compare(Point o1, Point o2) {
//return o1.distance(newPoint) < o2.distance(newPoint) ? 1 : -1;
return o1.getDist() < o2.getDist() ? 1 : -1;
}
}); //统计前K个标签的频率
Map<String, Integer> map = Maps.newHashMap();
Integer maxCnt = -9999; //最高频率
String label = ""; //最高频率标签
Integer currentCnt = 0; //当前标签的频率
Integer times = 0;
for (Point point : dataSet) {
currentCnt = 1;
if (map.containsKey(point.getLabel())) {
currentCnt += map.get(point);
}
if (maxCnt < currentCnt) {
maxCnt = currentCnt;
label = point.getLabel();
}
map.put(point.getLabel(), currentCnt);
times++;
if (times > K) break;
}
return label;
} }
 package com.data.knn;

 import com.google.common.collect.Lists;

 import java.util.List;

 /**
* *********************************************************
* <p/>
* Author: XiJun.Gong
* Date: 2016-09-06 14:45
* Version: default 1.0.0
* Class description:
* <p/>
* *********************************************************
*/
public class Main { public static void main(String args[]) {
List<Point> list = Lists.newArrayList();
list.add(new Point(1., 1.1, "A"));
list.add(new Point(1., 1., "A"));
list.add(new Point(0., 0., "B"));
list.add(new Point(0., 0.1, "B"));
Point point = new Point(0.5, 0.5, null);
KNN knn = new KNN();
System.out.println(knn.classify(list, point, 3));
}
}

结果:

A