上一篇较详细地介绍了k-d树算法。本文来讲解具体的实现代码。
首先是一些数据结构的定义。我们先来定义单个数据,代码如下:
//单个数据向量结构定义
struct _Examplar
{
public:
_Examplar():dom_dims(0){} //数据维度初始化为0
//带有完整的两个参数的constructor,这里const是为了保护原数据不被修改
_Examplar(const std::vector<double> elt, int dims)
{
if(dims > 0)
{
dom_elt = elt;
dom_dims = dims;
}
else
{
dom_dims = 0;
}
}
_Examplar(int dims) //只含有维度信息的constructor
{
if(dims > 0)
{
dom_elt.resize(dims);
dom_dims = dims;
}
else
{
dom_dims = 0;
}
}
_Examplar(const _Examplar& rhs) //copy-constructor
{
if(rhs.dom_dims > 0)
{
dom_elt = rhs.dom_elt;
dom_dims = rhs.dom_dims;
}
else
{
dom_dims = 0;
}
}
_Examplar& operator=(const _Examplar& rhs) //重载"="运算符
{
if(this == &rhs)
return *this;
releaseExamplarMem();
if(rhs.dom_dims > 0)
{
dom_elt = rhs.dom_elt;
dom_dims = rhs.dom_dims;
}
return *this;
}
~_Examplar()
{
}
double& dataAt(int dim) //定义访问控制函数
{
assert(dim < dom_dims);
return dom_elt[dim];
}
double& operator[](int dim) //重载"[]"运算符,实现下标访问
{
return dataAt(dim);
}
const double& dataAt(int dim) const //定义只读访问函数
{
assert(dim < dom_dims);
return dom_elt[dim];
}
const double& operator[](int dim) const //重载"[]"运算符,实现下标只读访问
{
return dataAt(dim);
}
void create(int dims) //创建数据向量
{
releaseExamplarMem();
if(dims > 0)
{
dom_elt.resize(dims); //控制数据向量维度
dom_dims = dims;
}
}
int getDomDims() const //获得数据向量维度信息
{
return dom_dims;
}
void setTo(double val) //数据向量初始化设置
{
if(dom_dims > 0)
{
for(int i=0;i<dom_dims;i++)
{
dom_elt[i] = val;
}
}
}
private:
void releaseExamplarMem() //清除现有数据向量
{
dom_elt.clear();
dom_dims = 0;
}
private:
std::vector<double> dom_elt; //每个数据定义为一个double类型的向量
int dom_dims; //数据向量的维度
};
结构_Examplar定义了单个数据节点的结构,主要包含的信息有:1.数据向量本身;2.数据向量的维度。接下来定义一整个数据集的结构,代码如下:
//数据集结构定义
class ExamplarSet : public TrainData //整个数据集类,由一个抽象类TrainData派生
{
private:
//_Examplar *_ex_set;
std::vector<_Examplar> _ex_set; //定义含有若干个_Examplar类数据向量的数据集
int _size; //数据集大小
int _dims; //数据集中每个数据向量的维度
public:
ExamplarSet():_size(0), _dims(0){}
ExamplarSet(std::vector<_Examplar> ex_set, int size, int dims);
ExamplarSet(int size, int dims);
ExamplarSet(const ExamplarSet& rhs);
ExamplarSet& operator=(const ExamplarSet& rhs);
~ExamplarSet(){}
_Examplar& examplarAt(int idx)
{
assert(idx < _size);
return _ex_set[idx];
}
_Examplar& operator[](int idx)
{
return examplarAt(idx);
}
const _Examplar& examplarAt(int idx) const
{
assert(idx < _size);
return _ex_set[idx];
}
void create(int size, int dims);
int getDims() const { return _dims;}
int getSize() const { return _size;}
_HyperRectangle calculateRange();
bool empty() const
{
return (_size == 0);
}
void sortByDim(int dim); //按某个方向维的排序函数
bool remove(int idx); //去除数据集中排序后指定位置的数据向量
void push_back(const _Examplar& ex) //添加某个数据向量至数据集末尾
{
_ex_set.push_back(ex);
_size++;
}
int readData(char *strFilePath); //从文件读取数据集
private:
void releaseExamplarSetMem() //清除现有数据集
{
_ex_set.clear();
_size = 0;
}
};
类ExamplarSet定义了整个数据集的结构,其包含的主要信息有:1.含有若干个_Examplar类数据向量的数据集;2.数据集的大小;3.每个数据向量的维度。以上两个结构是整个算法两个基本的数据结构,这里的代码只是展示其主要包含的结构信息,详细的定义及函数实现代码请参看附件。
接下来就要定义k-d tree的结构。同样采用上述由点定义到集定义的思路,我们先来定义k-d tree中一个节点结构,代码如下:
//k-d tree节点结构定义
class KDTreeNode
{
private:
int _split_dim; //该节点的最大区分度方向维
_Examplar _dom_elt; //该节点的数据向量
_HyperRectangle _range_hr; //表示数据范围的超矩形结构
public:
KDTreeNode *_left_child, *_right_child, *_parent; //该节点的左右子树和父节点
public:
KDTreeNode():_left_child(0), _right_child(0), _parent(0),
_split_dim(0){}
KDTreeNode(KDTreeNode *left_child, KDTreeNode *right_child,
KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr):
_left_child(left_child), _right_child(right_child), _parent(parent),
_split_dim(split_dim), _dom_elt(dom_elt), _range_hr(range_hr){}
KDTreeNode(const KDTreeNode &rhs);
KDTreeNode& operator=(const KDTreeNode &rhs);
_Examplar& getDomElt() { return _dom_elt; }
_HyperRectangle& getHyperRectangle(){ return _range_hr; }
int& splitDim(){ return _split_dim; }
void create(KDTreeNode *left_child, KDTreeNode *right_child,
KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr);
};
类KDTreeNode就是按照前一篇表1所述定义的。需要注意的是_HyperRectangle这一结构,它表示的就是这一节点所代表的空间范围Range,其定义如下:
struct _HyperRectangle //定义表示数据范围的超矩形结构
{
_Examplar min; //统计数据集中所有数据向量每个维度上最小值组成的一个数据向量
_Examplar max; //统计数据集中所有数据向量每个维度上最大值组成的一个数据向量
_HyperRectangle() {}
_HyperRectangle(_Examplar mx, _Examplar mn)
{
assert (mx.getDomDims() == mn.getDomDims());
min = mn;
max = mx;
}
_HyperRectangle(const _HyperRectangle& rhs)
{
min = rhs.min;
max = rhs.max;
}
_HyperRectangle& operator= (const _HyperRectangle& rhs)
{
if(this == &rhs)
return *this;
min = rhs.min;
max = rhs.max;
return *this;
}
void create(_Examplar mx, _Examplar mn)
{
assert (mx.getDomDims() == mn.getDomDims());
min = mn;
max = mx;
}
};
对于整个数据集来说_HyperRectangle表示的就是对全体的统计范围信息,对部分数据集来说其表示的就是对部分数据的统计范围信息。还是以上篇中实例中的数据{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}为例,_HyperRectangle表示的统计范围如图1所示:
图1 _HyperRectangle表示的统计范围
- 对于根节点(7,2),其所对应的空间范围是整个数据集,所以根节点(7,2)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,1),dom_dims = 2},max = {dom_elt = (9,7),dom_dims = 2};
- 对于中间节点(5,4),其所对应的空间范围是根节点的左子树,所以节点(5,4)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,3),dom_dims = 2},max = {dom_elt = (5,7),dom_dims = 2};
- 对于叶子节点(4,7),其所对应的空间范围是节点本身,所以节点(4,7)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的 数据范围统计得min = {dom_elt = (4,7),dom_dims = 2},max = {dom_elt = (4,7),dom_dims = 2};
最后再进行整个k-d tree结构的定义。代码如下:
class KDTree //k-d tree结构定义
{
public:
KDTreeNode *_root; //k-d tree的根节点
public:
KDTree():_root(NULL){}
void create(const ExamplarSet &exm_set); //创建k-d tree,实际上调用createKDTree
void destroy(); //销毁k-d tree,实际上调用destroyKDTree
~KDTree(){ destroyKDTree(_root); }
std::pair<_Examplar, double> findNearest(_Examplar target); //查找最近邻点函数,返回值是pair类型
//实际是调用findNearest_i
//查找距离在range范围内的近邻点,返回这样近邻点的个数,实际是调用findNearest_range
int findNearest(_Examplar target, double range, std::vector<std::pair<_Examplar, double>> &res_nearest);
private:
KDTreeNode* createKDTree(const ExamplarSet &exm_set);
void destroyKDTree(KDTreeNode *root);
std::pair<_Examplar, double> findNearest_i(KDTreeNode *root, _Examplar target);
int findNearest_range(KDTreeNode *root, _Examplar target, double range,
std::vector<std::pair<_Examplar, double>> &res_nearest);
可见,整个k-d tree结构是由一系列KDTreeNode类的节点构成。整个k-d树的构建算法和基于k-d树的最邻近查找算法主要就是由createKDTree,findNearest_i以及findNearest_range这三个函数完成。代码分别如下:
- createKDTree
//KDTree::是由于定义了KDTree的namespace
KDTree::KDTreeNode* KDTree::KDTree::createKDTree( const ExamplarSet &exm_set )
{
if(exm_set.empty())
return NULL;
ExamplarSet exm_set_copy(exm_set);
int dims = exm_set_copy.getDims();
int size = exm_set_copy.getSize();
//计算每个维的方差,选出方差值最大的维
double var_max = -0.1;
double avg, var;
int dim_max_var = -1;
for(int i=0;i<dims;i++)
{
avg = 0;
var = 0;
//求某一维的总和
for(int j=0;j<size;j++)
{
avg += exm_set_copy[j][i];
}
//求平均
avg /= size;
//求方差
for(int j=0;j<size;j++)
{
var += ( exm_set_copy[j][i] - avg ) *
( exm_set_copy[j][i] - avg );
}
var /= size;
if(var > var_max)
{
var_max = var;
dim_max_var = i;
}
}
//确定节点的数据矢量
_HyperRectangle hr = exm_set_copy.calculateRange(); //统计节点空间范围
exm_set_copy.sortByDim(dim_max_var); //将所有数据向量按最大区分度方向排序
int mid = size / 2;
_Examplar exm_split = exm_set_copy.examplarAt(mid); //取出排序结果的中间节点
exm_set_copy.remove(mid); //将中间节点作为父(根)节点,所有将其从数据集中去除
//确定左右节点
ExamplarSet exm_set_left(0, exm_set_copy.getDims());
ExamplarSet exm_set_right(0, exm_set_copy.getDims());
exm_set_right.remove(0);
int size_new = exm_set_copy.getSize(); //获得子数据空间大小
for(int i=0;i<size_new;i++) //生成左右子节点
{
_Examplar temp = exm_set_copy[i];
if( temp.dataAt(dim_max_var) <
exm_split.dataAt(dim_max_var) )
exm_set_left.push_back(temp);
else
exm_set_right.push_back(temp);
}
KDTreeNode *pNewNode = new KDTreeNode(0, 0, 0, dim_max_var, exm_split, hr);
pNewNode->_left_child = createKDTree(exm_set_left); //递归调用生成左子树
if(pNewNode->_left_child != NULL) //确认左子树父节点
pNewNode->_left_child->_parent = pNewNode;
pNewNode->_right_child = createKDTree(exm_set_right); //递归调用生成右子树
if(pNewNode->_right_child != NULL) //确认右子树父节点
pNewNode->_right_child->_parent = pNewNode;
return pNewNode; //最终返回k-d tree的根节点
}
整个createKDTree函数完全符合上篇中表2所述。注意其中统计节点空间范围calculateRange这一函数,其定义如下:
KDTree::_HyperRectangle KDTree::ExamplarSet::calculateRange()
{
assert(_size > 0);
assert(_dims > 0);
_Examplar mn(_dims);
_Examplar mx(_dims);
for(int j=0;j<_dims;j++)
{
mn.dataAt(j) = (*this)[0][j]; //初始化最小范围向量
mx.dataAt(j) = (*this)[0][j]; //初始化最大范围向量
}
for(int i=1;i<_size;i++) //统计数据集中每一个数据向量
{
for(int j=0;j<_dims;j++)
{
if( (*this)[i][j] < mn[j] ) //比较每一维,寻找最小值
mn[j] = (*this)[i][j];
if( (*this)[i][j] > mx[j] ) //比较每一维,寻找最大值
mx[j] = (*this)[i][j];
}
}
_HyperRectangle hr(mx, mn);
return hr; //返回一个_HyperRectangle结构
}
- findNearest_i
std::pair<KDTree::_Examplar, double> KDTree::KDTree::findNearest_i( KDTreeNode *root, _Examplar target )
{
KDTreeNode *pSearch = root;
//堆栈用于保存搜索路径
std::vector<KDTreeNode*> search_path;
_Examplar nearest;
double max_dist;
while(pSearch != NULL) //首先通过二叉查找得到搜索路径
{
search_path.push_back(pSearch);
int s = pSearch->splitDim();
if(target[s] <= pSearch->getDomElt()[s])
{
pSearch = pSearch->_left_child;
}
else
{
pSearch = pSearch->_right_child;
}
}
nearest = search_path.back()->getDomElt(); //取路径中最后的叶子节点为回溯前的最邻近点
max_dist = Distance_exm(nearest, target);
search_path.pop_back();
//回溯搜索路径
while(!search_path.empty())
{
KDTreeNode *pBack = search_path.back();
search_path.pop_back();
if( pBack->_left_child == NULL && pBack->_right_child == NULL) //如果是叶子节点,就直接比较距离的大小
{
if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
{
nearest = pBack->getDomElt();
max_dist = Distance_exm(pBack->getDomElt(), target);
}
}
else
{
int s = pBack->splitDim();
if( abs(pBack->getDomElt()[s] - target[s]) < max_dist) //以target为圆心,max_dist为半径的圆和分割面如果
{ //有交割,则需要进入另一边子空间搜索
if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) )
{
nearest = pBack->getDomElt();
max_dist = Distance_exm(pBack->getDomElt(), target);
}
if(target[s] <= pBack->getDomElt()[s]) //如果target位于左子空间,就应进入右子空间
pSearch = pBack->_right_child;
else
pSearch = pBack->_left_child; //如果target位于右子空间,就应进入左子空间
if(pSearch != NULL)
search_path.push_back(pSearch); //将新的节点加入search_path中
}
}
}
std::pair<_Examplar, double> res(nearest, max_dist);
return res; //返回包含最邻近点和最近距离的pair
}
- findNearest_range
int KDTree::KDTree::findNearest_range( KDTreeNode *root, _Examplar target, double range,
std::vector<std::pair<_Examplar, double>> &res_nearest )
{
if(root == NULL)
return 0;
double dist_sq, dx;
int ret, added_res = 0;
dist_sq = 0;
dist_sq = Distance_exm(root->getDomElt(), target); //计算搜索路径中每个节点和target的距离
if(dist_sq <= range) { //将范围内的近邻添加到结果向量res_nearest中
std::pair<_Examplar,double> temp(root->getDomElt(), dist_sq);
res_nearest.push_back(temp);
//结果个数+1
added_res = 1;
}
dx = target[root->splitDim()] - root->getDomElt()[root->splitDim()];
//左子树或右子树递归的查找
ret = findNearest_range(dx <= 0.0 ? root->_left_child : root->_right_child, target, range, res_nearest);
//当另外一边可能存在范围内的近邻
if(ret >= 0 && fabs(dx) < range) {
added_res += ret;
ret = findNearest_range(dx <= 0.0 ? root->_right_child : root->_left_child, target, range, res_nearest);
}
added_res += ret;
return added_res; //最终返回范围内的近邻个数
}
依然利用前述实例的数据来做测试,查找(2.1,3.1)和(2,4.5)两点的最近邻,并查找距离在4以内的所有近邻。程序运行结果如下:
图2 查找(2.1,3.1)的结果 图3 查找(2,4.5)的结果
附件:http://files.cnblogs.com/eyeszjwang/kdtree.rar
转载请注明:http://www.cnblogs.com/eyeszjwang/articles/2432465.html