BZOJ2626 JZPFAR及kd-tree入门

时间:2022-06-22 20:49:48

Problem

BZOJ权限题
洛谷

KD-tree简要介绍

有一篇讲得比较好的入门文章

kd-tree就是一个按照面的划分来进行二分节点的一种数据结构,它可以解决一些偏序问题和维护k维空间的点的相关问题。
它就是树的每一层都按照不同的维度把当前面二分,为了平衡,我们一般取中位数。取中位数的一个比较好的函数是nth_element(*l,*m,*r),它是一个algorithm自带的部分排序函数,可以把[l,r)区间内的第m小的值放在mid位,并保证左边的都比第m位小,右边的都比第m位大,它的期望时间复杂度是 O ( n )
其次呢,如果要带插入怎么办,由于同一层的划分依据一样,我们不便打乱它,我们不考虑旋转的数据结构,那么我们就可以用替罪羊来维护这个东西(代码下附)。

其他剪枝技巧
文章中提到我们可以优先选取的划分维度是方差比较大的维度,再如此不断选取更小一点的,轮流来。但是实现比较麻烦,我们一般不用。
更优的估价函数
Code

inline int balance(int x){return tr[x].sz*alpha>=tr[tr[x].ls].sz&&tr[x].sz*alpha>=tr[tr[x].rs].sz;}
void dfs(int k,int num)
{
    if(tr[k].ls) dfs(tr[k].ls,num);
    p[tr[tr[k].ls].sz+num+1]=tr[k].tp;rub[++top]=k;
    if(tr[k].rs) dfs(tr[k].rs,tr[tr[k].ls].sz+num+1);
}
void insert(int &k,point tmp,int wd)
{
    if(!k){k=newnode();tr[k].ls=tr[k].rs=0;tr[k].tp=tmp;pushup(k);return ;}
    if(tmp.x[wd]<=tr[k].tp.x[wd]) insert(tr[k].ls,tmp,wd^1);
    else insert(tr[k].rs,tmp,wd^1);
    pushup(k);
    if(!balance(k)) ptr=&k,ind=wd;//找深度最小需要重构的位置,为了方便,用了指针
}
void ins(point tmp)
{
    ptr=NULL;
    insert(rt,tmp,0);
    if(ptr!=NULL)
    {
        dfs(*ptr,0);
        *ptr=build(1,tr[*ptr].sz,ind);
    }
}

Solution

hdu4347和这道题也很类似。

我们用kd-tree来做,用一个堆来维护可能成为k远的几个点,不断更新答案。优先搜可能更远的子树,再根据估价函数判断需不需要搜另一个子树。注意重载小于号和估价函数不要写错了。注意当现在的节点少于k时仍然不能进入空子树,否则可能会影响答案。

Code

#include <algorithm>
#include <cstdio>
#include <cmath>
#define ok(d,x) (d==h.a[1].dis&&tr[x].mnid<h.a[1].id)
#define abs(x) ((x)>=0?(x):(-(x)))
#define sqr(x) ((ll)(x)*(x))
using namespace std;
typedef long long ll;
const int maxn=100010;
int n,m,mm,rt,cur,WD;
struct point{
    int x[2],id;ll dis;
    bool operator < (const point &t)const
    {
        if(dis==t.dis) return id<t.id;
        return dis>t.dis;
    }
    void getdis(const point &t)
    {
        ll res=0;
        for(int i=0;i<2;i++) res+=sqr(x[i]-t.x[i]);
        dis=res;
    }
}now,p[maxn];
struct data{int mn[2],mx[2],mnid,lc,rc;point tp;}tr[maxn];
struct Heap{
    int top;point a[30];
    void clear(){top=0;}
    void push(point t){a[++top]=t;push_heap(a+1,a+top+1);}
    void pop(){pop_heap(a+1,a+top+1);top--;}
}h;
template <typename Tp> inline void read(Tp &x)
{
    x=0;int f=0;char ch=getchar();
    while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
    if(ch=='-') f=1,ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    if(f) x=-x;
}
template <typename Tp>inline void getmin(Tp &x,Tp y){if(y<x) x=y;}
template <typename Tp>inline void getmax(Tp &x,Tp y){if(y>x) x=y;}
inline int cmp(const point &a,const point &b){return a.x[WD]<b.x[WD];}
void pushup(int x)
{
    int l=tr[x].lc,r=tr[x].rc;
    tr[x].mnid=min(tr[l].mnid,tr[r].mnid);
    for(int i=0;i<2;i++)
    {
        tr[x].mn[i]=tr[x].mx[i]=tr[x].tp.x[i];
        if(l) getmin(tr[x].mn[i],tr[l].mn[i]),getmax(tr[x].mx[i],tr[l].mx[i]);
        if(r) getmin(tr[x].mn[i],tr[r].mn[i]),getmax(tr[x].mx[i],tr[r].mx[i]);
    }
}
int build(int l,int r,int wd)
{
    if(l>r) return 0;
    int mid=(l+r)>>1,x=++cur;
    WD=wd;nth_element(p+l,p+mid,p+r+1,cmp);tr[x].tp=p[mid];
    tr[x].lc=build(l,mid-1,wd^1);tr[x].rc=build(mid+1,r,wd^1);
    pushup(x);
    return x;
}
ll check(int x)
{
    ll res=0;
    for(int i=0;i<2;i++)
      res+=sqr(max(abs(tr[x].mx[i]-now.x[i]),abs(now.x[i]-tr[x].mn[i])));
    return res;
}
void query(int x)
{
    tr[x].tp.getdis(now);h.push(tr[x].tp);
    if(h.top>mm) h.pop();
    ll dl=-1,dr=-1;
    if(tr[x].lc) dl=check(tr[x].lc);
    if(tr[x].rc) dr=check(tr[x].rc);
    if(dl>dr)
    {
        if(tr[x].lc&&(ok(dl,tr[x].lc)||dl>h.a[1].dis||h.top<mm)) query(tr[x].lc);
        if(tr[x].rc&&(ok(dr,tr[x].lc)||dr>h.a[1].dis||h.top<mm)) query(tr[x].rc);
    }
    else
    {
        if(tr[x].rc&&(ok(dl,tr[x].rc)||dr>h.a[1].dis||h.top<mm)) query(tr[x].rc);
        if(tr[x].lc&&(ok(dr,tr[x].lc)||dl>h.a[1].dis||h.top<mm)) query(tr[x].lc);
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("in.txt","r",stdin);
    #endif
    read(n);
    for(int i=1;i<=n;i++){read(p[i].x[0]);read(p[i].x[1]);p[i].id=i;}
    rt=build(1,n,0);read(m);
    for(int i=1;i<=m;i++)
    {
        read(now.x[0]);read(now.x[1]);read(mm);
        query(rt);printf("%d\n",h.a[1].id);
        h.clear();
    }
    return 0;
}