K近邻算法思想非常简单,总结起来就是根据某种距离度量检测未知数据与已知数据的距离,统计其中距离最近的k个已知数据的类别,以多数投票的形式确定未知数据的类别。
一直想自己实现knn的java实现,但限于自己的编程水平,java刚刚入门,所以就广泛搜索网上以实现的java代码来研习。下面这个简单的knn算法的java实现是在这篇博客中找到的:http://blog.csdn.net/luowen3405/article/details/6278764
下面给出我对代码的注释,如果有错误请指正。
源程序共定义了三个class文件,分别是:public class KNNNode;public class KNN;public class TestKNN。
Description:
KNNNode: KNN结点类,用来存储最近邻的k个元组相关的信息
KNN: KNN算法主体类
TestKNN: KNN算法测试类
首先,按照程序执行顺序依次解释class的思想。
1、 TestKNN
Method: public void read()
读取文件中的数据,存储为数组的形式(以嵌套链表的形式实现)List<List<Double>> datas
程序主体执行:main
首先读入训练数据文件和测试数据文件的数据,然后输出测试数据的类别。此程序中K=3,根据对这个数据集的了解,k=3时效果是最好的。Knn算法k的确定一直是一个值得研究的problem。
2、 算法主体:KNN
此程序中比较一个难点是作者定义了一个大小为k优先级队列来存储k个最近邻节点。优先级队列初始默认是距离越远越优先,然后根据算法中的实现,将与测试集最近的k个节点保存下来。
3、 定义了一个数据节点数据结构:KNNNode
源码如下:
package KNN;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue; /**
* KNN算法主体类
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class KNN {
/**
* 设置优先级队列的比较函数,距离越大,优先级越高
*/
private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 获取K个不同的随机数
* @param k 随机数的个数
* @param max 随机数最大的范围
* @return 生成的随机数数组
*/
public List<Integer> getRandKNum(int k, int max) {
List<Integer> rand = new ArrayList<Integer>(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 计算测试元组与训练元组之前的距离
* @param d1 测试元组
* @param d2 训练元组
* @return 距离值
*/
public double calDistance(List<Double> d1, List<Double> d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 执行KNN算法,获取测试元组的类别
* @param datas 训练数据集
* @param testData 测试元组
* @param k 设定的K值
* @return 测试元组的类别
*/
public String knn(List<List<Double>> datas, List<Double> testData, int k) {
PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);//按照自然顺序存储容量为k的优先级队列
List<Integer> randNum = getRandKNum(k, datas.size()); // 建立一个列表,列表中保存的是训练数据集中实例的个数
//计算当前一个测试数据实例与训练数据集的距离,并按照距离来排序
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List<Double> currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
pq.add(node);
// System.out.println("距离"+node.getDistance()+"测试样例"+index+"k值"+k); }
//统计与测试实例距离最近的数据,然后将
for (int i = 0; i < datas.size(); i++) {
List<Double> t = datas.get(i);
double distance = calDistance(testData, t);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
}
} return getMostClass(pq);
}
/**
* 获取所得到的k个最近邻元组的多数类
* @param pq 存储k个最近近邻元组的优先级队列
* @return 多数类的名称
*/
private String getMostClass(PriorityQueue<KNNNode> pq) {
Map<String, Integer> classCount = new HashMap<String, Integer>();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
package KNN;
/**
* KNN结点类,用来存储最近邻的k个元组相关的信息
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class KNNNode {
private int index; // 元组标号
private double distance; // 与测试元组的距离
private String c; // 所属类别
public KNNNode(int index, double distance, String c) {
super();
this.index = index;
this.distance = distance;
this.c = c;
} public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
package KNN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
/**
* KNN算法测试类
* @author Rowen
* @qq 443773264
* @mail luowen3405@163.com
* @blog blog.csdn.net/luowen3405
* @data 2011.03.25
*/
public class TestKNN { /**
* 从数据文件中读取数据
* @param datas 存储数据的集合对象
* @param path 数据文件的路径
*/
public void read(List<List<Double>> datas, String path){
try { BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String data = br.readLine();
List<Double> l = null;
while (data != null) {
String t[] = data.split(" ");
l = new ArrayList<Double>(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i]));
// System.out.println(l);
}
datas.add(l);
data = br.readLine(); }
br.close();
} catch (Exception e) {
e.printStackTrace();
}
} /**
* 程序执行入口
* @param args
*/
public static void main(String[] args) {
TestKNN t = new TestKNN();
String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
// System.out.println(datafile);
try {
List<List<Double>> datas = new ArrayList<List<Double>>();
List<List<Double>> testDatas = new ArrayList<List<Double>>();
t.read(datas, datafile);
t.read(testDatas, testfile);
// System.out.println(datas);
KNN knn = new KNN();
for (int i = 0; i < testDatas.size(); i++) {
List<Double> test = testDatas.get(i);
System.out.print("测试元组: ");
for (int j = 0; j < test.size(); j++) {
System.out.print(test.get(j) + " ");
}
System.out.print("类别为: ");
System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
附上待分类数据:
文件名字:datafile
0.1887 0.3276 -1
0.8178 0.7703 1
0.6761 0.4849 -1
0.6022 0.6878 -1
0.1759 0.8217 -1
0.2607 0.3502 1
0.2875 0.6713 -1
0.916 0.7363 -1
0.1615 0.2564 1
0.2653 0.9452 1
0.0911 0.4386 -1
0.0012 0.3947 -1
0.4253 0.8419 1
0.0067 0.4424 -1
0.8244 0.2089 1
0.3868 0.3592 -1
0.9174 0.216 -1
0.6074 0.3968 -1
0.068 0.5201 -1
0.9686 0.9937 1
0.0908 0.3658 1
0.3411 0.7691 -1
0.4609 0.4423 -1
0.1078 0.4501 1
0.3445 0.0445 -1
0.9827 0.7093 1
0.2428 0.3774 -1
0.0358 0.1971 -1
0.82 0.721 1
0.6718 0.6714 -1
0.6753 0.2428 -1
0.7218 0.4299 -1
0.3127 0.8329 1
0.0225 0.4162 1
0.5313 0.2187 1
0.7847 0.4243 -1
0.2518 0.6476 1
0.4076 0.5439 1
0.9063 0.4587 1
0.4714 0.2703 -1
0.7702 0.0196 -1
0.2548 0.3477 -1
0.0942 0.5407 1
0.1917 0.8085 -1
0.6834 0.7689 -1
0.1056 0.1097 1
0.9577 0.5303 -1
0.9436 0.0938 -1
0.6959 0.3181 1
0.4235 0.4484 1
0.6171 0.6358 1
0.5309 0.5447 1
0.8444 0.2621 -1
0.5762 0.8335 -1
0.281 0.772 1
0.224 0.15 -1
0.4243 0.704 -1
0.7384 0.7551 -1
0.4401 0.9329 1
0.2665 0.7635 1
0.5944 0.662 1
0.3225 0.3309 -1
0.4709 0.2648 1
0.6444 0.9899 -1
0.5271 0.9727 1
0.7788 0.4046 1
0.7302 0.2362 1
0.5181 0.6963 -1
0.5841 0.6073 1
0.7184 0.5225 1
0.6999 0.1192 1
0.3439 0.1194 1
0.6951 0.7413 -1
0.611 0.0636 1
0.4229 0.5822 1
0.4735 0.8878 -1
0.2891 0.3935 -1
0.3196 0.6393 1
0.1527 0.3912 -1
0.6385 0.9398 1
0.2904 0.679 1
0.4574 0.192 1
0.3251 0.1058 1
0.6377 0.5254 -1
0.5985 0.8699 1
0.4257 0.862 -1
0.2691 0.7904 -1
0.8754 0.1389 1
0.0336 0.6456 1
0.6544 0.6473 1
文件名称:testfile
0.9516 0.0326
0.9203 0.5612
0.0527 0.8819
0.7379 0.6692
0.2691 0.1904
0.4228 0.3689
0.5479 0.4607
0.9427 0.9816
0.4177 0.1564
0.9831 0.8555
0.3015 0.6448
0.7011 0.3763
0.6663 0.1909
0.5391 0.4283
0.6981 0.4820
0.6665 0.1206
0.1781 0.5895
0.1280 0.2262
0.9991 0.3846
0.1711 0.5830
通过KNN算法对未知数据集分类,设置k=3,分类结果如下:
测试元组: 0.9516 0.0326 类别为: -1
测试元组: 0.9203 0.5612 类别为: -1
测试元组: 0.0527 0.8819 类别为: -1
测试元组: 0.7379 0.6692 类别为: -1
测试元组: 0.2691 0.1904 类别为: -1
测试元组: 0.4228 0.3689 类别为: -1
测试元组: 0.5479 0.4607 类别为: -1
测试元组: 0.9427 0.9816 类别为: 1
测试元组: 0.4177 0.1564 类别为: 1
测试元组: 0.9831 0.8555 类别为: -1
测试元组: 0.3015 0.6448 类别为: -1
测试元组: 0.7011 0.3763 类别为: -1
测试元组: 0.6663 0.1909 类别为: -1
测试元组: 0.5391 0.4283 类别为: -1
测试元组: 0.6981 0.482 类别为: -1
测试元组: 0.6665 0.1206 类别为: 1
测试元组: 0.1781 0.5895 类别为: 1
测试元组: 0.128 0.2262 类别为: 1
测试元组: 0.9991 0.3846 类别为: -1
测试元组: 0.1711 0.583 类别为: 1
KNN算法java实现代码注释的更多相关文章
-
★ java删除代码注释
package com.witwicky.util; import java.io.BufferedReader; import java.io.BufferedWriter; import java ...
-
负载均衡的几种算法Java实现代码
轮询 package class2.zookeeper.loadbalance; import java.util.ArrayList; import java.util.HashMap; impor ...
-
《机器学习实战》kNN算法及约会网站代码详解
使用kNN算法进行分类的原理是:从训练集中选出离待分类点最近的kkk个点,在这kkk个点中所占比重最大的分类即为该点所在的分类.通常kkk不超过202020 kNN算法步骤: 计算数据集中的点与待分类 ...
-
数据挖掘之KNN算法(C#实现)
在十大经典数据挖掘算法中,KNN算法算得上是最为简单的一种.该算法是一种惰性学习法(lazy learner),与决策树.朴素贝叶斯这些急切学习法(eager learner)有所区别.惰性学习法仅仅 ...
-
java代码注释规范
java代码注释规范 代码注释是架起程序设计者与程序阅读者之间的通信桥梁,最大限度的提高团队开发合作效率.也是程序代码可维护性的重要环节之一.所以我们不是为写注释而写注释.下面说一下我们在诉求网二 ...
-
[转]java代码注释规范
代码注释是架起程序设计者与程序阅读者之间的通信桥梁,最大限度的提高团队开发合作效率.也是程序代码可维护性的重要环节之一.所以我们不是为写注释而写注释.下面说一下我们在诉求网二期开发中使用的代码注释规范 ...
-
经典KMP算法C++与Java实现代码
前言: KMP算法是一种字符串匹配算法,由Knuth,Morris和Pratt同时发现(简称KMP算法).KMP算法的关键是利用匹配失败后的信息,尽量减少模式串与主串的匹配次数以达到快速匹配的目的.比 ...
-
Eclipse和MyEclipse 手动设置 Java代码 注释模板
一.目的 1. 为什么需要注释规范? 注释规范对于程序员而言尤为重要,有以下几个原因: 一个软件的生命周期中,80%的花费在于维护. 几乎没有任何一个软件,在其整个生命周期中,均由最初的开发人员来维 ...
-
算法代码[置顶] 机器学习实战之KNN算法详解
改章节笔者在深圳喝咖啡的时候突然想到的...之前就有想写几篇关于算法代码的文章,所以回家到以后就奋笔疾书的写出来发表了 前一段时间介绍了Kmeans聚类,而KNN这个算法刚好是聚类以后经常使用的匹配技 ...
随机推荐
-
根据浏览器显示界面大小变换,替换css文件方法
在1024屏幕下,选择适配1024屏幕的css文件, 在大于1024屏幕下,选择适配大屏幕的css文件. 在html中的head标签中引用css文件时,加入media属性. 例: <link r ...
-
python3验证码机器学习
python3验证码机器学习 文档结构为 -- iconset -- ... -- jpg -- captcha.gif -- py -- crack.py 需要的库 pip3 install pil ...
-
windows下安装PhpDocumentor(phpdoc)笔记
PhpDocumentor简介 PHPDocumentor是一个用PHP写的工具,对于有规范注释的php程序,它能够快速生成具有相互参照,索引等功能的API文档.老的版本是phpdoc,从1.3.0开 ...
-
nyoj 230/poj 2513 彩色棒 并查集+字典树+欧拉回路
题目链接:http://acm.nyist.net/JudgeOnline/problem.php?pid=230 题意:给你许许多多的木棍,没条木棍两端有两种颜色,问你在将木棍相连时,接触的端点颜色 ...
-
Tomcat以指定JDK运行
如果一台机器上有多个Tomcat,可能存在不同的Tomcat需要不同版本JDK才能运行的情况,这时候就需要指定JDK来同时运行多个Tomcat了. 在windows环境下以批处理文件方式启动tomca ...
-
RxJava开发精要8 – 与REST无缝结合-RxJava和Retrofit
原文出自<RxJava Essentials> 原文作者 : Ivan Morgillo 译文出自 : 开发技术前线 www.devtf.cn 转载声明: 本译文已授权开发者头条享有独家转 ...
-
POJ1416 Shredding Company(dfs)
题目链接. 分析: 这题从早上调到现在.也不算太麻烦,细节吧. 每个数字都只有两种状态,加入前一序列和不加入前一序列.DFS枚举. #include <iostream> #include ...
-
ASP.NET WEB API 如何使用基于Post的方式传递多个值(二)
前面我曾经写过一篇文章,是基于HttpContext的请求上下文中读取表单参数,其实还可以将其单独拆分出来. 基于Filter的方式 获取表单值:(核心代码) public void OnActi ...
-
c++野指针 之 实战篇
一:今天做poj上的3750那个题,用到了list的erase方法.提交之后总是报runtime error! 纠结了好长时间.曾有一度怀疑过vector的erase和list的erase处理方式不一 ...
-
XCode8中的sizeClass设置
xcode8出来很久了,xcode9都要出来了,项目中由于一直没遇到用到适配屏幕的情况,所以一直也就忽略了这个知识点.今天忽然想起来,就抱着试一试的态度打开了xcode,我去~就我现在了解而言,屏幕大 ...