域名 | 类型 | 描述 |
dom_elt | kd维的向量 | kd维空间中的一个样本点 |
split | 整数 | 分裂维的序号,也是垂直于分割超面的方向轴序号 |
left | kd-tree | 由位于该结点分割超面左子空间内所有数据点构成的kd-tree |
right | kd-tree | 由位于该结点分割超面右子空间内所有数据点构成的kd-tree |
先以一个简单直观的实例来介绍k-d树算法。假设有6个二维数据点{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},数据点位于二维空间内(如图1中黑点所示)。k-d树算法就是要确定图1中这些分割空间的分割线(多维空间即为分割平面,一般为超平面)。下面就要通过一步步展示k-d树是如何确定这些分割线的。
由于此例简单,数据维度只有2维,所以可以简单地给x,y两个方向轴编号为0,1,也即split={0,1}。
(1)确定split域首先应取的值。分别计算x,y方向上数据的方差,得知x方向上的方差最大,所以split域值首先取0,也就是x轴方向;
(2)确定Node-data的域值。将x轴方向的值2,5,9,4,8,7排序后得到2,4,5,7,8,9,取中间位置(下标为6/2 = 3)的值7作为中值,所以Node-data = (7,2)。这样,该节点的分割超平面就是通过(7,2)并垂直于split = 0(x轴)的直线x = 7;
(3)确定左子空间和右子空间。分割超平面x = 7将整个空间分为两部分,如图2所示。x <= 7的部分为左子空间,包含3个节点{(2,3),(5,4),(4,7)};另一部分为右子空间,包含2个节点{(9,6),(8,1)}。
如算法所述,k-d树的构建是一个递归的过程。然后对左子空间和右子空间内的数据重复根节点的过程就可以得到下一级子节点(5,4)和(9,6)(也就是左右子空间的'根'节点),同时将空间和数据集进一步细分。如此反复直到空间中只包含一个数据点,如图1所示。最后生成的k-d树如图3所示。
算法:createKDTree 构建一棵k-d tree 输入:exm_set 样本集 输出 : Kd, 类型为kd-tree 1. 如果exm_set是空的,则返回空的kd-tree 2.调用分裂结点选择程序(输入是exm_set),返回两个值 dom_elt:= exm_set中的一个样本点 split := 分裂维的序号 3.exm_set_left = {exm∈exm_set – dom_elt && exm[split] <= dom_elt[split]} exm_set_right = {exm∈exm_set – dom_elt && exm[split] > dom_elt[split]} 4.left = createKDTree(exm_set_left) right = createKDTree(exm_set_right)
k-d tree最近邻搜索算法
算法:kdtreeFindNearest /* k-d tree的最近邻搜索 */ 输入:Kd /* k-d tree类型*/ target /* 待查询数据点 */ 输出 : nearest /* 最近邻数据结点 */ dist /* 最近邻和查询点的距离 */ 1. 如果Kd是空的,则设dist为无穷大返回 2. 向下搜索直到叶子结点 pSearch = &Kd while(pSearch != NULL) { pSearch加入到search_path中; if(target[pSearch->split] <= pSearch->dom_elt[pSearch->split]) /* 如果小于就进入左子树 */ { pSearch = pSearch->left; } else { pSearch = pSearch->right; } } 取出search_path最后一个赋给nearest dist = Distance(nearest, target); 3. 回溯搜索路径 while(search_path不为空) { 取出search_path最后一个结点赋给pBack if(pBack->left为空 && pBack->right为空) /* 如果pBack为叶子结点 */ { if( Distance(nearest, target) > Distance(pBack->dom_elt, target) ) { nearest = pBack->dom_elt; dist = Distance(pBack->dom_elt, target); } } else { s = pBack->split; if( abs(pBack->dom_elt[s] - target[s]) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */ { if( Distance(nearest, target) > Distance(pBack->dom_elt, target) ) { nearest = pBack->dom_elt; dist = Distance(pBack->dom_elt, target); } if(target[s] <= pBack->dom_elt[s]) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */ pSearch = pBack->right; else pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */ if(pSearch != NULL) pSearch加入到search_path中 } } }
假设我们的k-d tree就是上面通过样本集{(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)}创建的。
以下是k-d树的C++代码实现,包括建树过程和搜索过程。算法main函数输入k-d树训练实例点,算法会完成建树操作,随后可以输入待查询的目标点,程序将会搜索k-d树找出与输入目标点最近邻的训练实例点。本程序只实现了1近邻搜索,如果要实现k近邻搜索,只需对程序稍作修改。比如可以对每个结点添加一个标记,如果已经输出该结点为最近邻结点,那么就继续查找次近邻的结点,直到输出k个结点后算法结束。
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stack>
#include <vector>

/*
 * 2-d tree (k-d tree with k = 2): builds the tree from a list of (x, y)
 * training points and answers exact 1-nearest-neighbour queries.
 *
 * Standard-library names are qualified with std:: rather than imported with a
 * using-directive: since C++17 the library declares std::data, which makes
 * the unqualified struct name `data` ambiguous when std is pulled in wholesale.
 */

/* A point in the plane. */
struct data
{
    double x = 0;
    double y = 0;
};

/* One node of the 2-d tree. */
struct Tnode
{
    struct data dom_elt;    /* sample point stored in this node */
    int split;              /* splitting dimension: 0 = x axis, 1 = y axis */
    struct Tnode * left;    /* points with coordinate <= dom_elt on the split axis */
    struct Tnode * right;   /* points with coordinate >  dom_elt on the split axis */
};

/* Strict ordering by x coordinate (used to sort along dimension 0). */
bool cmp1(data a, data b){
    return a.x < b.x;
}

/* Strict ordering by y coordinate (used to sort along dimension 1). */
bool cmp2(data a, data b){
    return a.y < b.y;
}

/* True when both coordinates compare exactly equal. */
bool equal(data a, data b){
    return a.x == b.x && a.y == b.y;
}

/*
 * Chooses the splitting dimension and the splitting point for a sample set.
 *
 * The dimension with the larger variance is selected (ties go to dimension 1),
 * the array is sorted IN PLACE along that dimension, and the element at index
 * size / 2 is returned as the splitting point.
 *
 * exm_set      samples; reordered by this call
 * size         number of samples; must be > 0 (exm_set[size / 2] is read)
 * split        out: 0 for x, 1 for y
 * SplitChoice  out: the chosen splitting point
 */
void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice){
    /* Variance via DX = E[X^2] - (E[X])^2, first on x, then on y. */
    double tmp1, tmp2;
    tmp1 = tmp2 = 0;
    for (int i = 0; i < size; ++i)
    {
        tmp1 += 1.0 / (double)size * exm_set[i].x * exm_set[i].x;
        tmp2 += 1.0 / (double)size * exm_set[i].x;
    }
    double v1 = tmp1 - tmp2 * tmp2;

    tmp1 = tmp2 = 0;
    for (int i = 0; i < size; ++i)
    {
        tmp1 += 1.0 / (double)size * exm_set[i].y * exm_set[i].y;
        tmp2 += 1.0 / (double)size * exm_set[i].y;
    }
    double v2 = tmp1 - tmp2 * tmp2;

    split = v1 > v2 ? 0 : 1;

    if (split == 0)
    {
        std::sort(exm_set, exm_set + size, cmp1);
    }
    else
    {
        std::sort(exm_set, exm_set + size, cmp2);
    }

    /* The middle element of the sorted order becomes the node's point. */
    SplitChoice = exm_set[size / 2];
}

/*
 * Recursively builds a 2-d tree from exm_set[0 .. size).
 *
 * exm_set  samples; reordered in place by ChooseSplit
 * size     number of samples
 * T        ignored on input (kept for interface compatibility)
 *
 * Returns the root of the new subtree, or NULL when size <= 0. Nodes are
 * allocated with new; release them with destroy_kdtree.
 *
 * Every sample that compares equal to the chosen splitting point is dropped
 * from both children, so exact duplicates of a node's point are stored once.
 */
Tnode* build_kdtree(data exm_set[], int size, Tnode* T){
    if (size <= 0)
    {
        return NULL;
    }

    int split;
    data dom_elt;
    ChooseSplit(exm_set, size, split, dom_elt);

    /* Dynamically sized partitions: no fixed upper bound on the sample count. */
    std::vector<data> exm_set_left;
    std::vector<data> exm_set_right;
    exm_set_left.reserve(size);
    exm_set_right.reserve(size);

    const double pivot = (split == 0) ? dom_elt.x : dom_elt.y;
    for (int i = 0; i < size; ++i)
    {
        if (equal(exm_set[i], dom_elt))
        {
            continue;
        }
        const double coord = (split == 0) ? exm_set[i].x : exm_set[i].y;
        if (coord <= pivot)
        {
            exm_set_left.push_back(exm_set[i]);
        }
        else if (coord > pivot)
        {
            exm_set_right.push_back(exm_set[i]);
        }
    }

    T = new Tnode;
    T->dom_elt = dom_elt;
    T->split = split;
    T->left = build_kdtree(exm_set_left.data(), static_cast<int>(exm_set_left.size()), NULL);
    T->right = build_kdtree(exm_set_right.data(), static_cast<int>(exm_set_right.size()), NULL);
    return T;
}

/* Frees every node of the tree rooted at T (NULL is accepted). */
void destroy_kdtree(Tnode* T){
    if (T == NULL)
    {
        return;
    }
    destroy_kdtree(T->left);
    destroy_kdtree(T->right);
    delete T;
}

/* Euclidean distance between two points. */
double Distance(data a, data b){
    double tmp = (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
    return sqrt(tmp);
}

/*
 * Searches the subtree rooted at node and improves (best, bestDist) whenever a
 * strictly closer point is found.
 *
 * The child on the target's side of the splitting line is always searched
 * first. The other child is searched only when the circle of radius bestDist
 * around the target crosses the splitting line, because otherwise no point
 * beyond the line can be closer than the current best.
 */
static void searchSubtree(const Tnode* node, const data &target, data &best, double &bestDist){
    if (node == NULL)
    {
        return;
    }

    const double d = Distance(node->dom_elt, target);
    if (d < bestDist)
    {
        bestDist = d;
        best = node->dom_elt;
    }

    /* Signed offset of the target from the splitting line of this node. */
    const double diff = (node->split == 0) ? target.x - node->dom_elt.x
                                           : target.y - node->dom_elt.y;
    const Tnode* nearChild = (diff <= 0) ? node->left : node->right;
    const Tnode* farChild = (diff <= 0) ? node->right : node->left;

    searchSubtree(nearChild, target, best, bestDist);
    if (fabs(diff) < bestDist)
    {
        searchSubtree(farChild, target, best, bestDist);
    }
}

/*
 * Finds the training point closest to target.
 *
 * Kd            root of the tree (may be NULL)
 * target        query point
 * nearestpoint  out: the nearest stored point; left unchanged for an empty tree
 * distance      out: Euclidean distance to nearestpoint, or +infinity for an
 *               empty tree
 */
void searchNearest(Tnode * Kd, data target, data &nearestpoint, double & distance){
    if (Kd == NULL)
    {
        distance = std::numeric_limits<double>::infinity();
        return;
    }

    /* Seed the search with the root so the result is always a stored point. */
    data nearest = Kd->dom_elt;
    double dist = Distance(nearest, target);
    searchSubtree(Kd, target, nearest, dist);

    nearestpoint.x = nearest.x;
    nearestpoint.y = nearest.y;
    distance = dist;
}

/*
 * Reads training points until the pair "-1 -1" (or end of input), builds the
 * tree, then answers nearest-neighbour queries until input ends.
 */
int main(){
    std::vector<data> exm_set;   /* grows as needed; no fixed capacity */
    double x, y;
    std::cout<<"Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop."<<std::endl;
    while (std::cin>>x>>y){
        /* Only the full pair -1 -1 terminates input, so (-1, y) stays a valid point. */
        if (x == -1 && y == -1)
        {
            break;
        }
        data point;
        point.x = x;
        point.y = y;
        exm_set.push_back(point);
    }

    struct Tnode * root = build_kdtree(exm_set.data(), static_cast<int>(exm_set.size()), NULL);
    if (root == NULL)
    {
        std::cout<<"No training data was entered."<<std::endl;
        return 0;
    }

    data nearestpoint;
    double distance;
    data target;
    std::cout <<"Enter search point"<<std::endl;
    while (std::cin>>target.x>>target.y)
    {
        searchNearest(root, target, nearestpoint, distance);
        std::cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<std::endl;
        std::cout <<"Enter search point"<<std::endl;
    }

    destroy_kdtree(root);
    return 0;
}