kd tree学习笔记 (最近邻域查询)

时间:2021-09-08 16:40:19

https://zhuanlan.zhihu.com/p/22557068

http://blog.csdn.net/zhjchengfeng5/article/details/7855241

KD树在算法竞赛中主要用来做各种各样的平面区域查询,包含则累加直接返回,相交则继续递归,相离的没有任何贡献也直接返回。可以处理圆,三角形,矩形等判断起来相对容易的平面区域内的符合加法性质的操作。

比如查询平面内欧几里得距离最近的点的距离。

kdtree其实有点像搜索,暴力+剪枝。

每次从根结点向下搜索,并进行剪枝操作,判断是否有必要继续搜索。

它是通过横一刀,竖一刀,横一刀再竖一刀将平面进行分割,建立二叉树。

建树的复杂度是O(nlogn), 每次用nth_element()在线性时间内取出中位数。 T(n) = 2T(n/2) + O(n) = O(nlogn)

查询复杂度呢? 据第二个链接的博客说最坏是O( sqrt(n) ) 的。并不会分析查询复杂度。

HDU2966 裸kdtree

题意:给平面图上N(1 ≤ N ≤100000)个点,对每个点,找到其他 欧几里德距离 离他最近的点,输出他们之间的距离。保证没有重点。

 #include <bits/stdc++.h>
#define ll long long
using namespace std;
#define N 200010
const ll inf = 1e18;
int n,i,id[N],root,cmp_d;
int x, y;
struct node{int d[],l,r,Max[],Min[],val,sum,f;}t[N];
bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
void umax(int&a,int b){if(a<b)a=b;}
void umin(int&a,int b){if(a>b)a=b;}
void up(int x){
if(t[x].l){
umax(t[x].Max[],t[t[x].l].Max[]);
umin(t[x].Min[],t[t[x].l].Min[]);
umax(t[x].Max[],t[t[x].l].Max[]);
umin(t[x].Min[],t[t[x].l].Min[]);
}
if(t[x].r){
umax(t[x].Max[],t[t[x].r].Max[]);
umin(t[x].Min[],t[t[x].r].Min[]);
umax(t[x].Max[],t[t[x].r].Max[]);
umin(t[x].Min[],t[t[x].r].Min[]);
}
}
int build(int l,int r,int D,int f){
int mid=(l+r)>>;
cmp_d=D,std::nth_element(t+l,t+mid,t+r+,cmp);
id[t[mid].f]=mid;
t[mid].f=f;
t[mid].Max[]=t[mid].Min[]=t[mid].d[];
t[mid].Max[]=t[mid].Min[]=t[mid].d[];
//t[mid].val=t[mid].sum=0;
if(l!=mid)t[mid].l=build(l,mid-,!D,mid);else t[mid].l=;
if(r!=mid)t[mid].r=build(mid+,r,!D,mid);else t[mid].r=;
return up(mid),mid;
} ll dis(ll x1, ll y1, ll x, ll y) {
ll xx = x1-x, yy = y1-y;
return xx*xx+yy*yy;
}
ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离
ll xx = , yy = ;
if(t[p].Max[] < x) xx = x-t[p].Max[];
if(t[p].Min[] > x) xx = t[p].Min[]-x;
if(t[p].Max[] < y) yy = y-t[p].Max[];
if(t[p].Min[] > y) yy = t[p].Min[]-y;
return xx*xx+yy*yy;
}
ll ans;
void query(int p){
ll dl = inf, dr = inf, d = dis(t[p].d[], t[p].d[], x, y);
if(d) ans = min(ans, d); if(t[p].l) dl = dis(t[p].l, x, y);
if(t[p].r) dr = dis(t[p].r, x, y);
if(dl < dr){
if(dl < ans) query(t[p].l);
if(dr < ans) query(t[p].r);
}
else {
if(dr < ans) query(t[p].r);
if(dl < ans) query(t[p].l);
}
} int main(){
int T; scanf("%d", &T);
while(T--){
scanf("%d", &n);
for(int i = ; i <= n; i++){
scanf("%d%d", &t[i].d[], &t[i].d[]);
t[i].f = i;
}
int rt = build(, n, , );
for(int i = ; i <= n; i++){
ans = inf;
x = t[ id[i] ].d[], y = t[ id[i] ].d[];
query(rt);
printf("%lld\n", ans);
}
}
return ;
}

BZOJ2648

题意:给出n个点,接下来m个操作,每次插入一个点,或者询问离询问点的最近曼哈顿距离。

 #include <bits/stdc++.h>
#define ll long long
using namespace std;
#define N 1000010
const ll inf = 1e18;
int n,m,i,id[N],root,cmp_d,rt;
int x, y;
struct node{int d[],l,r,Max[],Min[],val,sum,f;}t[N];
bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
void umax(int&a,int b){if(a<b)a=b;}
void umin(int&a,int b){if(a>b)a=b;}
void up(int x){
if(t[x].l){
umax(t[x].Max[],t[t[x].l].Max[]);
umin(t[x].Min[],t[t[x].l].Min[]);
umax(t[x].Max[],t[t[x].l].Max[]);
umin(t[x].Min[],t[t[x].l].Min[]);
}
if(t[x].r){
umax(t[x].Max[],t[t[x].r].Max[]);
umin(t[x].Min[],t[t[x].r].Min[]);
umax(t[x].Max[],t[t[x].r].Max[]);
umin(t[x].Min[],t[t[x].r].Min[]);
}
}
int build(int l,int r,int D,int f){
int mid=(l+r)>>;
cmp_d=D,std::nth_element(t+l,t+mid,t+r+,cmp);
id[t[mid].f]=mid;
t[mid].f=f;
t[mid].Max[]=t[mid].Min[]=t[mid].d[];
t[mid].Max[]=t[mid].Min[]=t[mid].d[];
//t[mid].val=t[mid].sum=0;
if(l!=mid)t[mid].l=build(l,mid-,!D,mid);else t[mid].l=;
if(r!=mid)t[mid].r=build(mid+,r,!D,mid);else t[mid].r=;
return up(mid),mid;
} ll dis(ll x1, ll y1, ll x, ll y) {
return abs(x1-x)+abs(y1-y);
//ll xx = x1-x, yy = y1-y;
//return xx*xx+yy*yy;
}
ll dis(int p, ll x, ll y){//估价函数, 以p为子树的最小距离
ll xx = , yy = ;
if(t[p].Max[] < x) xx = x-t[p].Max[];
if(t[p].Min[] > x) xx = t[p].Min[]-x;
if(t[p].Max[] < y) yy = y-t[p].Max[];
if(t[p].Min[] > y) yy = t[p].Min[]-y;
return xx+yy;
//return xx*xx+yy*yy;
}
ll ans;
void ins(int now, int k, int x){
if(t[x].d[k] >= t[now].d[k]){
if(t[now].r) ins(t[now].r, !k, x);
else
t[now].r = x, t[x].f = now;
}
else {
if(t[now].l) ins(t[now].l, !k, x);
else t[now].l = x, t[x].f = now;
}
up(now);
}
void query(int p){
ll dl = inf, dr = inf, d = dis(t[p].d[], t[p].d[], x, y);
ans = min(ans, d); if(t[p].l) dl = dis(t[p].l, x, y);
if(t[p].r) dr = dis(t[p].r, x, y);
if(dl < dr){
if(dl < ans) query(t[p].l);
if(dr < ans) query(t[p].r);
}
else {
if(dr < ans) query(t[p].r);
if(dl < ans) query(t[p].l);
}
} int main(){
scanf("%d%d", &n, &m);
for(int i = ; i <= n; i++)
scanf("%d%d", &t[i].d[], &t[i].d[]);
rt = build(, n, , );
while(m--){
int op;
scanf("%d%d%d", &op, &x, &y);
if(op == ){
n++;
t[n].l = t[n].r = ;
t[n].Max[] = t[n].Min[] = t[n].d[] = x;
t[n].Max[] = t[n].Min[] = t[n].d[] = y;
ins(rt, , n);
}
else{
ans = inf;
query(rt);
printf("%lld\n", ans);
}
}
return ;
}

BZOJ3053

题意:k维坐标系下的最近的m个点。直接对于每一个询问都在kdtree中询问m次最近点,每次找到一个最近点对需要把它记录下来,用堆维护即可。

 #include <bits/stdc++.h>
#define ll long long
#define mp make_pair using namespace std;
#define N 50010
const ll inf = 1e18;
int n,m,k,i,id[N],root,cmp_d,rt;
int x, y, num;
struct node{int d[],l,r,Max[],Min[],val,sum,f;}t[N];
bool cmp(const node&a,const node&b){return a.d[cmp_d]<b.d[cmp_d];}
void umax(int&a,int b){if(a<b)a=b;}
void umin(int&a,int b){if(a>b)a=b;}
void up(int x){
for(int i = ; i < k; i++){
if(t[x].l){
umax(t[x].Max[i],t[t[x].l].Max[i]);
umin(t[x].Min[i],t[t[x].l].Min[i]);
}
if(t[x].r){
umax(t[x].Max[i],t[t[x].r].Max[i]);
umin(t[x].Min[i],t[t[x].r].Min[i]);
}
}
}
int build(int l,int r,int D,int f){
int mid=(l+r)>>;
cmp_d=D,std::nth_element(t+l,t+mid,t+r+,cmp);
id[t[mid].f]=mid;
t[mid].f=f;
for(int i = ; i < k; i++)
t[mid].Max[i]=t[mid].Min[i]=t[mid].d[i];
//t[mid].Max[1]=t[mid].Min[1]=t[mid].d[1];
//t[mid].val=t[mid].sum=0;
if(l!=mid)t[mid].l=build(l,mid-,(D+)%k,mid);else t[mid].l=;
if(r!=mid)t[mid].r=build(mid+,r,(D+)%k,mid);else t[mid].r=;
return up(mid),mid;
}
int qx[];
ll dis(int p){//估价函数, 以p为子树的最小距离
ll ret = , ans = ;
for(int i = ; i < k; i++) {
ret = ;
if(t[p].Max[i] < qx[i]) ret = qx[i]-t[p].Max[i];
if(t[p].Min[i] > qx[i]) ret = t[p].Min[i]-qx[i];
ans += ret*ret;
}
return ans;
}
ll getdis(int p){
ll ans = ;
for(int i = ; i < k; i++)
ans += (qx[i]-t[p].d[i])*(qx[i]-t[p].d[i]);
return ans;
}
void ins(int now, int k, int x){
if(t[x].d[k] >= t[now].d[k]){
if(t[now].r) ins(t[now].r, !k, x);
else
t[now].r = x, t[x].f = now;
}
else {
if(t[now].l) ins(t[now].l, !k, x);
else t[now].l = x, t[x].f = now;
}
up(now);
}
ll ret;
multiset< pair<int, int> > ans;
void query(int p){
ll dl = inf, dr = inf, d = getdis(p);
ans.insert( mp((int)d, p) );
if(ans.size() > num){
multiset< pair<int, int> >::iterator it = ans.end();
it--;
ans.erase(it);
}
ret = (*ans.rbegin()).first;
if(t[p].l) dl = dis(t[p].l);
if(t[p].r) dr = dis(t[p].r);
if(dl < dr){
if(dl < ret||ans.size() < num) query(t[p].l);
if(dr < ret||ans.size() < num) query(t[p].r);
}
else {
if(dr < ret||ans.size() < num) query(t[p].r);
if(dl < ret||ans.size() < num) query(t[p].l);
}
} int main(){
while(~scanf("%d%d", &n, &k)){
for(int i = ; i <= n; i++){
for(int j = ; j < k; j++)
scanf("%d", &t[i].d[j]);
}
rt = build(, n, , );
scanf("%d", &m);
while(m--){
for(int i = ; i < k; i++)
scanf("%d", qx+i);
scanf("%d", &num);
ans.clear();
query(rt);
printf ("the closest %d points are:\n", num);
for(multiset< pair<int, int> >::iterator it = ans.begin(); it != ans.end(); it++){
int pos = (*it).second;
for(int i = ; i < k; i++)
printf("%d%c", t[pos].d[i], " \n"[i == k-]);
}
}
}
return ;
}