kd-tree学习&&hdu2966&&bzoj2648

时间:2022-12-17 17:02:59

先推荐一篇很好的博客:http://blog.csdn.net/jiangshibiao/article/details/34144829

以下是博客文章节选:


*************************************************************************************************************************************

【KD-TREE介绍】在SYC1999大神的“蛊惑”下,我开始接触这种算法

首先,大概的概念可以去百度百科。具体实现,我是看RZZ的代码长大的。

我们可以想象在平面上有N个点。首先,按横坐标排序找到最中间的那个点。然后水平划一条线,把平面分成左右两个部分。再递归调用左右两块。注意,在第二次(偶数次)调用的时候,是找到纵坐标中最中间的点,并垂直画一条线。

这样效率看上去很好。维护的时候有点像线段树。每个点记录它的坐标、它辖管的区间4个方向的极值、它的左右(或上下)的两个点的标号。递归两个子树时,注意要up更新这个点辖管的范围。

[cpp]  view plain  copy   kd-tree学习&&hdu2966&&bzoj2648 kd-tree学习&&hdu2966&&bzoj2648
  1. inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];}  
  2. inline void up(int k,int s)  
  3. {  
  4.   a[k].min[0]=min(a[k].min[0],a[s].min[0]);  
  5.   a[k].max[0]=max(a[k].max[0],a[s].max[0]);  
  6.   a[k].min[1]=min(a[k].min[1],a[s].min[1]);  
  7.   a[k].max[1]=max(a[k].max[1],a[s].max[1]);  
  8. }  
  9. int build(int l,int r,int dd)  
  10. {  
  11.   D=dd;int mid=(l+r)>>1;  
  12.   nth_element(a+l+1,a+mid+1,a+r+1,cmp);  
  13.   a[mid].min[0]=a[mid].max[0]=a[mid].d[0];  
  14.   a[mid].min[1]=a[mid].max[1]=a[mid].d[1];  
  15.   if (l!=mid) a[mid].l=build(l,mid-1,dd^1);  
  16.   if (mid!=r) a[mid].r=build(mid+1,r,dd^1);  
  17.   if (a[mid].l) up(mid,a[mid].l);  
  18.   if (a[mid].r) up(mid,a[mid].r);  
  19.   return mid;  
  20. }  

介绍一下nth_element这个STL。头文件就是algorithm。它相当于快排的一部分,调用格式如上。意思是把第MID个数按cmp放在中间,把比mid“小”的数放在左边,否则放在右边。(注意:不保证左边和右边有序)

上述代码很好理解。

然后先在我要支持加入点,也是类似于线段树的思想:

[cpp]  view plain  copy   kd-tree学习&&hdu2966&&bzoj2648 kd-tree学习&&hdu2966&&bzoj2648
  1. void insert(int k)  
  2. {  
  3.   int p=root;D=0;  
  4.   while (orzSYC)  
  5.   {  
  6.     up(p,k);  
  7.     if (a[k].d[D]<=a[p].d[D]){if (!a[p].l) {a[p].l=k;return;} p=a[p].l;}  
  8.     else {if (!a[p].r) {a[p].r=k;return;} p=a[p].r;}  
  9.     D^=1;  
  10.   }  
  11. }  

为什么我忽然觉得是splay的insert操作?就是每次往某个点的左或右(或者上或下)过去。

比如我们要查询与(x,y)最近的点(曼哈顿距离)与其的距离。

[cpp]  view plain  copy   kd-tree学习&&hdu2966&&bzoj2648 kd-tree学习&&hdu2966&&bzoj2648
  1. int getdis(int k)  
  2. {  
  3.   int res=0;  
  4.   if (x<a[k].min[0]) res+=a[k].min[0]-x;  
  5.   if (x>a[k].max[0]) res+=x-a[k].max[0];  
  6.   if (y<a[k].min[1]) res+=a[k].min[1]-y;  
  7.   if (y>a[k].max[1]) res+=y-a[k].max[1];  
  8.   return res;  
  9. }  
  10. void ask(int k)  
  11. {  
  12.   int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y);  
  13.   if (d0<ans) ans=d0;  
  14.   int dl=(a[k].l)?getdis(a[k].l):INF;  
  15.   int dr=(a[k].r)?getdis(a[k].r):INF;  
  16.   if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);}  
  17.   else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);}  
  18. }  

getdis有点像Astar中的“估价函数”。计算(x,y)与当前点范围的差距有多少,然后按顺序遍历左二子和右儿子。这样,如果更新到最优值,就能及时退出。这种算法在随机数据上是lg的,但是在构造数据上约是sqrt的。

*************************************************************************************************************************************



然后我看了这篇博客http://blog.csdn.net/zhjchengfeng5/article/details/7855241 

对kd树有个更好的理解。



然后,我写了kd-tree裸题hdu2966

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const ll inf=9e18;
const ll mod=1000000007;
const int maxn=1e5+100;
struct arr
{
    ll d[2];
    int id;
    int l,r;
}a[maxn*4];
int D;
int idx[maxn];
bool cmp(arr a,arr b)
{
    return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];
}

int build(int l,int r,int dd)
{
    D=dd;
    int mid=(l+r)/2;
    nth_element(a+l,a+mid,a+r+1,cmp);
    idx[a[mid].id]=mid;  //
    if (l!=mid) a[mid].l=build(l,mid-1,dd^1);
    else a[mid].l=0;
    if (mid!=r) a[mid].r=build(mid+1,r,dd^1);
    else a[mid].r=0;
    return mid;
}
ll x,y,ans;

int root;

void query(int k,int w)
{
    ll d0=(a[k].d[0]-x)*(a[k].d[0]-x)+(a[k].d[1]-y)*(a[k].d[1]-y);
    if (d0&&d0<ans) ans=d0;
    if(a[k].l&&a[k].r)
    {
        bool f=!w? (x<=a[k].d[0]):(y<=a[k].d[1]);
        ll d=!w? (a[k].d[0]-x)*(a[k].d[0]-x):(a[k].d[1]-y)*(a[k].d[1]-y);
        query(f? a[k].l:a[k].r,w^1);
        if(d<ans) query(f? a[k].r:a[k].l,w^1);
    }
    else if(a[k].l) query(a[k].l,w^1);
    else if(a[k].r) query(a[k].r,w^1);
}
int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--)
    {
        int n;
        scanf("%d",&n);
        rep(i,1,n+1)
        scanf("%lld%lld",&a[i].d[0],&a[i].d[1]),a[i].id=i;
        root=build(1,n,0);
        rep(i,1,n+1)
        {
            ans=inf;
            x=a[idx[i]].d[0],y=a[idx[i]].d[1];
            query(root,0);
            printf("%lld\n",ans);
        }
    }
    return 0;
}



然后写了bzoj2648,这个题直接按照hdu2966写会T,他的重要一个东西就是估价函数,先判断左右子树的估价函数值,若大于此时的ans,那么就不在这里面搜了,相当于搜索的剪枝。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const int inf=0x3fffffff;
const ll mod=1000000007;
const int maxn=5e5+100;
struct arr
{
    int d[2],min[2],max[2];
    int l,r;
}a[maxn*4];
int D;
inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];}
inline void up(int k,int s)
{
    a[k].min[0]=min(a[k].min[0],a[s].min[0]);
    a[k].max[0]=max(a[k].max[0],a[s].max[0]);
    a[k].min[1]=min(a[k].min[1],a[s].min[1]);
    a[k].max[1]=max(a[k].max[1],a[s].max[1]);
}
int build(int l,int r,int dd)
{
    D=dd;int mid=(l+r)>>1;
    nth_element(a+l+1,a+mid+1,a+r+1,cmp);
    a[mid].min[0]=a[mid].max[0]=a[mid].d[0];
    a[mid].min[1]=a[mid].max[1]=a[mid].d[1];
    if (l!=mid) a[mid].l=build(l,mid-1,dd^1);
    if (mid!=r) a[mid].r=build(mid+1,r,dd^1);
    if (a[mid].l) up(mid,a[mid].l);
    if (a[mid].r) up(mid,a[mid].r);
    return mid;
}
int x,y,ans;
int getdis(int k)                  //估价函数
{
    int res=0;
    if (x<a[k].min[0]) res+=a[k].min[0]-x;
    if (x>a[k].max[0]) res+=x-a[k].max[0];
    if (y<a[k].min[1]) res+=a[k].min[1]-y;
    if (y>a[k].max[1]) res+=y-a[k].max[1];
    return res;
}
int root;
void ask(int k)
{
    int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y);
    if (d0<ans) ans=d0;
    int dl=(a[k].l)?getdis(a[k].l):inf;
    int dr=(a[k].r)?getdis(a[k].r):inf;
    if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);}
    else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);}
}

void query(int k)         //调用这个函数会T
{
    int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y);
    if (d0<ans) ans=d0;
    if(a[k].l&&a[k].r)
    {
        //bool f=!w? x<=a[k].d[0]:y<=a[k].d[1];
        //ll d=!w? abs(a[k].d[0]-x):abs(a[k].d[1]-y);
        ll d1=getdis(a[k].r),d2=getdis(a[k].l);
        if(d1<d2)
        {
            query(a[k].r);
            if(d2<ans) query(a[k].l);
        }
        else
        {
            query(a[k].l);
            if(d2<ans) query(a[k].r);
        }
    }
    else if(a[k].l) query(a[k].l);
    else if(a[k].r) query(a[k].r);
}

void insert(int k)           //插入函数
{
    int p=root;D=0;
    while (1)
    {
        up(p,k);
        if (a[k].d[D]<=a[p].d[D]){if (!a[p].l) {a[p].l=k;return;} p=a[p].l;}
        else {if (!a[p].r) {a[p].r=k;return;} p=a[p].r;}
        D^=1;
    }
}
int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    rep(i,1,n+1)
    scanf("%d%d",&a[i].d[0],&a[i].d[1]);
    root=build(1,n,0);
    int cnt=n;
    while(m--)
    {
        int op;
        scanf("%d%d%d",&op,&x,&y);
        if(op==1)
        {
            cnt++;
            a[cnt].max[0]=a[cnt].min[0]=a[cnt].d[0]=x;
            a[cnt].max[1]=a[cnt].min[1]=a[cnt].d[1]=y;
            insert(cnt);
        }
        else if(op==2)
        {
            ans=inf;
            query(root);
            printf("%d\n",ans);
        }
    }
    return 0;
}


所以,hdu2966也可以弄一个类似的估价函数,不过交上去跑的好像还慢些==

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
using namespace std;
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
typedef long long ll;
typedef pair<int,int> PII;
const ll inf=9e18;
const ll mod=1000000007;
const int maxn=1e5+100;
struct arr
{
    ll d[2],min[2],max[2];
    int l,r,id;
}a[maxn*4];
int D;
inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];}
inline void up(int k,int s)
{
    a[k].min[0]=min(a[k].min[0],a[s].min[0]);
    a[k].max[0]=max(a[k].max[0],a[s].max[0]);
    a[k].min[1]=min(a[k].min[1],a[s].min[1]);
    a[k].max[1]=max(a[k].max[1],a[s].max[1]);
}
int idx[maxn];
int build(int l,int r,int dd)
{
    D=dd;int mid=(l+r)>>1;
    nth_element(a+l,a+mid,a+r+1,cmp);
    idx[a[mid].id]=mid;
    a[mid].min[0]=a[mid].max[0]=a[mid].d[0];
    a[mid].min[1]=a[mid].max[1]=a[mid].d[1];
    if (l!=mid) a[mid].l=build(l,mid-1,dd^1);
    else a[mid].l=0;
    if (mid!=r) a[mid].r=build(mid+1,r,dd^1);
    else a[mid].r=0;
    if (a[mid].l) up(mid,a[mid].l);
    if (a[mid].r) up(mid,a[mid].r);
    return mid;
}
ll x,y,ans;
ll getdis(int k)                  //估价函数
{
    ll res=0;
    if (x<a[k].min[0]) res+=(a[k].min[0]-x)*(a[k].min[0]-x);
    if (x>a[k].max[0]) res+=(x-a[k].max[0])*(x-a[k].max[0]);
    if (y<a[k].min[1]) res+=(a[k].min[1]-y)*(a[k].min[1]-y);
    if (y>a[k].max[1]) res+=(y-a[k].max[1])*(y-a[k].max[1]);
    return res;
}
int root;
void ask(int k)
{
    ll d0=(a[k].d[0]-x)*(a[k].d[0]-x)+(a[k].d[1]-y)*(a[k].d[1]-y);
    if (d0&&d0<ans) ans=d0;
    ll dl=(a[k].l)?getdis(a[k].l):inf;
    ll dr=(a[k].r)?getdis(a[k].r):inf;
    if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);}
    else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);}
}

int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--)
    {
        int n;
        scanf("%d",&n);
        rep(i,1,n+1)
        scanf("%lld%lld",&a[i].d[0],&a[i].d[1]),a[i].id=i;
        root=build(1,n,0);
        rep(i,1,n+1)
        {
            ans=inf;
            x=a[idx[i]].d[0],y=a[idx[i]].d[1];
            ask(root);
            printf("%lld\n",ans);
        }
    }
    return 0;
}