最近研究KNN,找到了一些优秀的源码,贴出来,做个笔记吧。
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<time.h> typedef struct{//数据维度
double x;
double y;
}data_struct; typedef struct kd_node{
data_struct split_data;//数据结点
int split;//分裂维
struct kd_node *left;//由位于该结点分割超面左子空间内所有数据点构成的kd-tree
struct kd_node *right;//由位于该结点分割超面右子空间内所有数据点构成的kd-tree
}kd_struct; //用于排序
int cmp1( const void *a , const void *b )
{
return (*(data_struct *)a).x > (*(data_struct *)b).x ? :-;
}
//用于排序
int cmp2( const void *a , const void *b )
{
return (*(data_struct *)a).y > (*(data_struct *)b).y ? :-;
}
//计算分裂维和分裂结点
void choose_split(data_struct data_set[],int size,int dimension,int *split,data_struct *split_data)
{
int i;
data_struct *data_temp;
data_temp=(data_struct *)malloc(size*sizeof(data_struct));
for(i=;i<size;i++)
data_temp[i]=data_set[i];
static int count=;//设为静态
*split=(count++)%dimension;//分裂维
if((*split)==) qsort(data_temp,size,sizeof(data_temp[]),cmp1);
else qsort(data_temp,size,sizeof(data_temp[]),cmp2);
*split_data=data_temp[(size-)/];//分裂结点排在中位
}
//判断两个数据点是否相等
int equal(data_struct a,data_struct b){
if(a.x==b.x && a.y==b.y) return ;
else return ;
}
//建立KD树
kd_struct *build_kdtree(data_struct data_set[],int size,int dimension,kd_struct *T)
{
if(size==) return NULL;//递归出口
else{
int sizeleft=,sizeright=;
int i,split;
data_struct split_data;
choose_split(data_set,size,dimension,&split,&split_data);
data_struct data_right[size];
data_struct data_left[size]; if (split==){//x维
for(i=;i<size;++i){
if(!equal(data_set[i],split_data) && data_set[i].x <= split_data.x){//比分裂结点小
data_left[sizeleft].x=data_set[i].x;
data_left[sizeleft].y=data_set[i].y;
sizeleft++;//位于分裂结点的左子空间的结点数
}
else if(!equal(data_set[i],split_data) && data_set[i].x > split_data.x){//比分裂结点大
data_right[sizeright].x=data_set[i].x;
data_right[sizeright].y=data_set[i].y;
sizeright++;//位于分裂结点的右子空间的结点数
}
}
}
else{//y维
for(i=;i<size;++i){
if(!equal(data_set[i],split_data) && data_set[i].y <= split_data.y){
data_left[sizeleft].x=data_set[i].x;
data_left[sizeleft].y=data_set[i].y;
sizeleft++;
}
else if (!equal(data_set[i],split_data) && data_set[i].y > split_data.y){
data_right[sizeright].x = data_set[i].x;
data_right[sizeright].y = data_set[i].y;
sizeright++;
}
}
}
T=(kd_struct *)malloc(sizeof(kd_struct));
T->split_data.x=split_data.x;
T->split_data.y=split_data.y;
T->split=split;
T->left=build_kdtree(data_left,sizeleft,dimension,T->left);//左子空间
T->right=build_kdtree(data_right,sizeright,dimension,T->right);//右子空间
return T;//返回指针
}
}
//计算欧氏距离
double compute_distance(data_struct a,data_struct b){
double tmp=pow(a.x-b.x,2.0)+pow(a.y-b.y,2.0);
return sqrt(tmp);
}
//搜索1近邻
void search_nearest(kd_struct *T,int size,data_struct test,data_struct *nearest_point,double *distance)
{
int path_size;//搜索路径内的指针数目
kd_struct *search_path[size];//搜索路径保存各结点的指针
kd_struct* psearch=T;
data_struct nearest;//最近邻的结点
double dist;//查询结点与最近邻结点的距离
search_path[]=psearch;//初始化搜索路径
path_size=;
while(psearch->left!=NULL || psearch->right!=NULL){
if (psearch->split==){
if(test.x <= psearch->split_data.x)//如果小于就进入左子树
psearch=psearch->left;
else
psearch=psearch->right;
}
else{
if(test.y <= psearch->split_data.y)//如果小于就进入右子树
psearch=psearch->left;
else
psearch=psearch->right;
}
search_path[path_size++]=psearch;//将经过的分裂结点保存在搜索路径中
}
//取出search_path最后一个元素,即叶子结点赋给nearest
nearest.x=search_path[path_size-]->split_data.x;
nearest.y=search_path[path_size-]->split_data.y;
path_size--;//search_path的指针数减一
dist=compute_distance(nearest,test);//计算与该叶子结点的距离作为初始距离 //回溯搜索路径
kd_struct* pback;
while(path_size!=){
pback=search_path[path_size-];//取出search_path最后一个结点赋给pback
path_size--;//search_path的指针数减一 if(pback->left==NULL && pback->right==NULL){//如果pback为叶子结点
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data,test);
}
}
else{//如果pback为分裂结点
int s=pback->split;
if(s==){//x维
if(fabs(pback->split_data.x-test.x)<dist){//若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data, test);
}
if(test.x<=pback->split_data.x)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
psearch=pback->right;
else
psearch=pback->left;//若以查询点位于pback的右子空间,那么就要跳到左子空间去搜索
if(psearch!=NULL)
search_path[path_size++]=psearch;//psearch加入到search_path中
}
}
else {//y维
if(fabs(pback->split_data.y-test.y)<dist){//若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data,test);
}
if(test.y<=pback->split_data.y)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
psearch=pback->right;
else
psearch=pback->left;//若查询点位于pback的的右子空间,那么就要跳到左子空间去搜索
if(psearch!=NULL)
search_path[path_size++]=psearch;//psearch加入到search_path中
}
}
}
} (*nearest_point).x=nearest.x;//最近邻
(*nearest_point).y=nearest.y;
*distance=dist;//距离
} int main()
{
int n=;//数据个数
data_struct nearest_point;
double distance;
kd_struct *root=NULL;
data_struct data_set[]={{,},{,},{,},{,},{,},{,}};//数据集
data_struct test={7.1,2.1};//查询点
root=build_kdtree(data_set,n,,root); search_nearest(root,n,test,&nearest_point,&distance);
printf("nearest neighbor:(%.2f,%.2f)\ndistance:%.2f \n",nearest_point.x,nearest_point.y,distance);
return ;
}
/* x 5,4
/ \
y 2,3 7.2
\ / \
x 4,7 8.1 9.6
*/
看了一些做这方面的文章,把写的不错的几个也收录了。
https://www.joinquant.com/post/2627?f=study&m=math