先推荐一篇很好的博客:http://blog.csdn.net/jiangshibiao/article/details/34144829
以下是博客文章节选:
*************************************************************************************************************************************
【KD-TREE介绍】在SYC1999大神的“蛊惑”下,我开始接触这种算法。
首先,大概的概念可以去百度百科。具体实现,我是看RZZ的代码长大的。
我们可以想象在平面上有N个点。首先,按横坐标排序找到最中间的那个点。然后水平划一条线,把平面分成左右两个部分。再递归调用左右两块。注意,在第二次(偶数次)调用的时候,是找到纵坐标中最中间的点,并垂直画一条线。
这样效率看上去很好。维护的时候有点像线段树。每个点记录它的坐标、它辖管的区间4个方向的极值、它的左右(或上下)的两个点的标号。递归两个子树时,注意要up更新这个点辖管的范围。
- inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];}
- inline void up(int k,int s)
- {
- a[k].min[0]=min(a[k].min[0],a[s].min[0]);
- a[k].max[0]=max(a[k].max[0],a[s].max[0]);
- a[k].min[1]=min(a[k].min[1],a[s].min[1]);
- a[k].max[1]=max(a[k].max[1],a[s].max[1]);
- }
- int build(int l,int r,int dd)
- {
- D=dd;int mid=(l+r)>>1;
- nth_element(a+l+1,a+mid+1,a+r+1,cmp);
- a[mid].min[0]=a[mid].max[0]=a[mid].d[0];
- a[mid].min[1]=a[mid].max[1]=a[mid].d[1];
- if (l!=mid) a[mid].l=build(l,mid-1,dd^1);
- if (mid!=r) a[mid].r=build(mid+1,r,dd^1);
- if (a[mid].l) up(mid,a[mid].l);
- if (a[mid].r) up(mid,a[mid].r);
- return mid;
- }
介绍一下nth_element这个STL。头文件就是algorithm。它相当于快排的一部分,调用格式如上。意思是把第MID个数按cmp放在中间,把比mid“小”的数放在左边,否则放在右边。(注意:不保证左边和右边有序)
上述代码很好理解。
然后先在我要支持加入点,也是类似于线段树的思想:
- void insert(int k)
- {
- int p=root;D=0;
- while (orzSYC)
- {
- up(p,k);
- if (a[k].d[D]<=a[p].d[D]){if (!a[p].l) {a[p].l=k;return;} p=a[p].l;}
- else {if (!a[p].r) {a[p].r=k;return;} p=a[p].r;}
- D^=1;
- }
- }
为什么我忽然觉得是splay的insert操作?就是每次往某个点的左或右(或者上或下)过去。
比如我们要查询与(x,y)最近的点(曼哈顿距离)与其的距离。
- int getdis(int k)
- {
- int res=0;
- if (x<a[k].min[0]) res+=a[k].min[0]-x;
- if (x>a[k].max[0]) res+=x-a[k].max[0];
- if (y<a[k].min[1]) res+=a[k].min[1]-y;
- if (y>a[k].max[1]) res+=y-a[k].max[1];
- return res;
- }
- void ask(int k)
- {
- int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y);
- if (d0<ans) ans=d0;
- int dl=(a[k].l)?getdis(a[k].l):INF;
- int dr=(a[k].r)?getdis(a[k].r):INF;
- if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);}
- else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);}
- }
getdis有点像Astar中的“估价函数”。计算(x,y)与当前点范围的差距有多少,然后按顺序遍历左二子和右儿子。这样,如果更新到最优值,就能及时退出。这种算法在随机数据上是lg的,但是在构造数据上约是sqrt的。
*************************************************************************************************************************************
然后我看了这篇博客http://blog.csdn.net/zhjchengfeng5/article/details/7855241
对kd树有个更好的理解。
然后,我写了kd-tree裸题hdu2966
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<vector> #include<queue> #include<stack> using namespace std; #define rep(i,a,n) for (int i=a;i<n;i++) #define per(i,a,n) for (int i=n-1;i>=a;i--) #define pb push_back #define fi first #define se second typedef vector<int> VI; typedef long long ll; typedef pair<int,int> PII; const ll inf=9e18; const ll mod=1000000007; const int maxn=1e5+100; struct arr { ll d[2]; int id; int l,r; }a[maxn*4]; int D; int idx[maxn]; bool cmp(arr a,arr b) { return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1]; } int build(int l,int r,int dd) { D=dd; int mid=(l+r)/2; nth_element(a+l,a+mid,a+r+1,cmp); idx[a[mid].id]=mid; // if (l!=mid) a[mid].l=build(l,mid-1,dd^1); else a[mid].l=0; if (mid!=r) a[mid].r=build(mid+1,r,dd^1); else a[mid].r=0; return mid; } ll x,y,ans; int root; void query(int k,int w) { ll d0=(a[k].d[0]-x)*(a[k].d[0]-x)+(a[k].d[1]-y)*(a[k].d[1]-y); if (d0&&d0<ans) ans=d0; if(a[k].l&&a[k].r) { bool f=!w? (x<=a[k].d[0]):(y<=a[k].d[1]); ll d=!w? (a[k].d[0]-x)*(a[k].d[0]-x):(a[k].d[1]-y)*(a[k].d[1]-y); query(f? a[k].l:a[k].r,w^1); if(d<ans) query(f? a[k].r:a[k].l,w^1); } else if(a[k].l) query(a[k].l,w^1); else if(a[k].r) query(a[k].r,w^1); } int main() { int cas; scanf("%d",&cas); while(cas--) { int n; scanf("%d",&n); rep(i,1,n+1) scanf("%lld%lld",&a[i].d[0],&a[i].d[1]),a[i].id=i; root=build(1,n,0); rep(i,1,n+1) { ans=inf; x=a[idx[i]].d[0],y=a[idx[i]].d[1]; query(root,0); printf("%lld\n",ans); } } return 0; }
然后写了bzoj2648,这个题直接按照hdu2966写会T,他的重要一个东西就是估价函数,先判断左右子树的估价函数值,若大于此时的ans,那么就不在这里面搜了,相当于搜索的剪枝。
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<vector> #include<queue> #include<stack> using namespace std; #define rep(i,a,n) for (int i=a;i<n;i++) #define per(i,a,n) for (int i=n-1;i>=a;i--) #define pb push_back #define fi first #define se second typedef vector<int> VI; typedef long long ll; typedef pair<int,int> PII; const int inf=0x3fffffff; const ll mod=1000000007; const int maxn=5e5+100; struct arr { int d[2],min[2],max[2]; int l,r; }a[maxn*4]; int D; inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];} inline void up(int k,int s) { a[k].min[0]=min(a[k].min[0],a[s].min[0]); a[k].max[0]=max(a[k].max[0],a[s].max[0]); a[k].min[1]=min(a[k].min[1],a[s].min[1]); a[k].max[1]=max(a[k].max[1],a[s].max[1]); } int build(int l,int r,int dd) { D=dd;int mid=(l+r)>>1; nth_element(a+l+1,a+mid+1,a+r+1,cmp); a[mid].min[0]=a[mid].max[0]=a[mid].d[0]; a[mid].min[1]=a[mid].max[1]=a[mid].d[1]; if (l!=mid) a[mid].l=build(l,mid-1,dd^1); if (mid!=r) a[mid].r=build(mid+1,r,dd^1); if (a[mid].l) up(mid,a[mid].l); if (a[mid].r) up(mid,a[mid].r); return mid; } int x,y,ans; int getdis(int k) //估价函数 { int res=0; if (x<a[k].min[0]) res+=a[k].min[0]-x; if (x>a[k].max[0]) res+=x-a[k].max[0]; if (y<a[k].min[1]) res+=a[k].min[1]-y; if (y>a[k].max[1]) res+=y-a[k].max[1]; return res; } int root; void ask(int k) { int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y); if (d0<ans) ans=d0; int dl=(a[k].l)?getdis(a[k].l):inf; int dr=(a[k].r)?getdis(a[k].r):inf; if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);} else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);} } void query(int k) //调用这个函数会T { int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y); if (d0<ans) ans=d0; if(a[k].l&&a[k].r) { //bool f=!w? x<=a[k].d[0]:y<=a[k].d[1]; //ll d=!w? abs(a[k].d[0]-x):abs(a[k].d[1]-y); ll d1=getdis(a[k].r),d2=getdis(a[k].l); if(d1<d2) { query(a[k].r); if(d2<ans) query(a[k].l); } else { query(a[k].l); if(d2<ans) query(a[k].r); } } else if(a[k].l) query(a[k].l); else if(a[k].r) query(a[k].r); } void insert(int k) //插入函数 { int p=root;D=0; while (1) { up(p,k); if (a[k].d[D]<=a[p].d[D]){if (!a[p].l) {a[p].l=k;return;} p=a[p].l;} else {if (!a[p].r) {a[p].r=k;return;} p=a[p].r;} D^=1; } } int main() { int n,m; scanf("%d%d",&n,&m); rep(i,1,n+1) scanf("%d%d",&a[i].d[0],&a[i].d[1]); root=build(1,n,0); int cnt=n; while(m--) { int op; scanf("%d%d%d",&op,&x,&y); if(op==1) { cnt++; a[cnt].max[0]=a[cnt].min[0]=a[cnt].d[0]=x; a[cnt].max[1]=a[cnt].min[1]=a[cnt].d[1]=y; insert(cnt); } else if(op==2) { ans=inf; query(root); printf("%d\n",ans); } } return 0; }
所以,hdu2966也可以弄一个类似的估价函数,不过交上去跑的好像还慢些==
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<vector> #include<queue> #include<stack> #include<cmath> using namespace std; #define rep(i,a,n) for (int i=a;i<n;i++) #define per(i,a,n) for (int i=n-1;i>=a;i--) #define pb push_back #define fi first #define se second typedef vector<int> VI; typedef long long ll; typedef pair<int,int> PII; const ll inf=9e18; const ll mod=1000000007; const int maxn=1e5+100; struct arr { ll d[2],min[2],max[2]; int l,r,id; }a[maxn*4]; int D; inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];} inline void up(int k,int s) { a[k].min[0]=min(a[k].min[0],a[s].min[0]); a[k].max[0]=max(a[k].max[0],a[s].max[0]); a[k].min[1]=min(a[k].min[1],a[s].min[1]); a[k].max[1]=max(a[k].max[1],a[s].max[1]); } int idx[maxn]; int build(int l,int r,int dd) { D=dd;int mid=(l+r)>>1; nth_element(a+l,a+mid,a+r+1,cmp); idx[a[mid].id]=mid; a[mid].min[0]=a[mid].max[0]=a[mid].d[0]; a[mid].min[1]=a[mid].max[1]=a[mid].d[1]; if (l!=mid) a[mid].l=build(l,mid-1,dd^1); else a[mid].l=0; if (mid!=r) a[mid].r=build(mid+1,r,dd^1); else a[mid].r=0; if (a[mid].l) up(mid,a[mid].l); if (a[mid].r) up(mid,a[mid].r); return mid; } ll x,y,ans; ll getdis(int k) //估价函数 { ll res=0; if (x<a[k].min[0]) res+=(a[k].min[0]-x)*(a[k].min[0]-x); if (x>a[k].max[0]) res+=(x-a[k].max[0])*(x-a[k].max[0]); if (y<a[k].min[1]) res+=(a[k].min[1]-y)*(a[k].min[1]-y); if (y>a[k].max[1]) res+=(y-a[k].max[1])*(y-a[k].max[1]); return res; } int root; void ask(int k) { ll d0=(a[k].d[0]-x)*(a[k].d[0]-x)+(a[k].d[1]-y)*(a[k].d[1]-y); if (d0&&d0<ans) ans=d0; ll dl=(a[k].l)?getdis(a[k].l):inf; ll dr=(a[k].r)?getdis(a[k].r):inf; if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);} else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);} } int main() { int cas; scanf("%d",&cas); while(cas--) { int n; scanf("%d",&n); rep(i,1,n+1) scanf("%lld%lld",&a[i].d[0],&a[i].d[1]),a[i].id=i; root=build(1,n,0); rep(i,1,n+1) { ans=inf; x=a[idx[i]].d[0],y=a[idx[i]].d[1]; ask(root); printf("%lld\n",ans); } } return 0; }