LG3835 【模板】可持久化平衡树

时间:2022-12-17 10:13:00

题意

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作(对于各个以往的历史版本):

  1. 插入x数

  2. 删除x数(若有多个相同的数,因只删除一个,如果没有请忽略该操作)

  3. 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)

  4. 查询排名为x的数

  5. 求x的前驱(前驱定义为小于x,且最大的数,如不存在输出-2147483647)

  6. 求x的后继(后继定义为大于x,且最小的数,如不存在输出2147483647)

和原本平衡树不同的一点是,每一次的任何操作都是基于某一个历史版本,同时生成一个新的版本。(操作3, 4, 5, 6即保持原版本无变化)

每个版本的编号即为操作的序号(版本0即为初始状态,空树)

\(n \leq 5 \times 10^5\)

分析

可以发现非旋Treap的split和merge每次变动的都是一条链。

然后就对这条链可持久化一下就行了。

时空复杂度\(O(n \log n)\)

Hint

注意copy的部分仅限于递归处理的时候,now=0,x=0,y=0这些时候就不用可持久化了,不然会莫名其妙地错。

然后是空间问题,其实3、4、5、6操作不用可持久化,但是平衡树能A就行了,论效率平衡树肯定赶不上其他的做法。

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<algorithm>
#define rg register
#define il inline
#define co const
template<class T>il T read()
{
    rg T data=0;
    rg int w=1;
    rg char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')
            w=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        data=data*10+ch-'0';
        ch=getchar();
    }
    return data*w;
}
template<class T>il T read(rg T&x)
{
    return x=read<T>();
}
using namespace std;

co int MAXN=5e5*50,MAXM=5e5+7; // edit 2

int root[MAXN],tot;
struct Treap
{
    int ch[MAXN][2],siz[MAXN];
    int val[MAXN],pri[MAXN];
    
    il int newnode(rg int v=0)
    {
        ++tot;
        ch[tot][0]=ch[tot][1]=0,siz[tot]=1;
        val[tot]=v,pri[tot]=rand()|rand()<<16;
        return tot;
    }
    
    il void pushup(rg int now)
    {
        siz[now]=siz[ch[now][0]]+1+siz[ch[now][1]];
    }
    
    il void copy(rg int x,rg int y)
    {
        ch[x][0]=ch[y][0],ch[x][1]=ch[y][1],siz[x]=siz[y];
        val[x]=val[y],pri[x]=pri[y];
    }
    
    il void split(rg int now,rg int v,rg int&x,rg int&y)
    {
        if(!now)
        {
            x=y=0;
            return;
        }
        if(val[now]<=v)
        {
            x=newnode();
            copy(x,now);
            split(ch[x][1],v,ch[x][1],y);
            pushup(x);
        }
        else
        {
            y=newnode();
            copy(y,now);
            split(ch[y][0],v,x,ch[y][0]);
            pushup(y);
        }
    }
    
    il int merge(rg int x,rg int y)
    {
        if(!x||!y) // edit 1
            return x+y;
        rg int now=newnode();
        if(pri[x]<pri[y])
        {
            copy(now,x);
            ch[now][1]=merge(ch[now][1],y);
            pushup(now);
        }
        else
        {
            copy(now,y);
            ch[now][0]=merge(x,ch[now][0]);
            pushup(now);
        }
        return now;
    }
    
    il void ins(rg int&now,rg int v)
    {
        rg int x,y;
        split(now,v,x,y);
        now=merge(x,merge(newnode(v),y));
    }
    
    il void del(rg int&now,rg int v)
    {
        rg int x,y,z;
        split(now,v-1,x,y);
        split(y,v,y,z);
        y=merge(ch[y][0],ch[y][1]);
        now=merge(x,merge(y,z));
    }
    
    il int rank(rg int&now,rg int v)
    {
        rg int x,y;
        split(now,v-1,x,y);
        rg int ans=siz[x]+1;
        now=merge(x,y);
        return ans;
    }
    
    il int kth(rg int now,rg int k)
    {
        if(!now)
            return 0;
        while(k)
        {
            if(siz[ch[now][0]]>=k)
                now=ch[now][0];
            else if(siz[ch[now][0]]+1==k)
                return now;
            else
            {
                k-=siz[ch[now][0]]+1;
                now=ch[now][1];
            }
        }
        return now;
    }
    
    il int pre(rg int&now,rg int v)
    {
        rg int x,y;
        split(now,v-1,x,y);
        rg int ans=kth(x,siz[x]);
        now=merge(x,y);
        return ans;
    }
    
    il int suc(rg int&now,rg int v)
    {
        rg int x,y;
        split(now,v,x,y);
        rg int ans=kth(y,1);
        now=merge(x,y);
        return ans;
    }
}T;

int main()
{
//  freopen(".in","r",stdin);
//  freopen(".out","w",stdout);
    srand(20030506);
    rg int n;
    read(n);
    for(rg int i=1;i<=n;++i)
    {
        rg int v,opt,x;
        read(v);read(opt);read(x);
//      cerr<<"v="<<v<<" opt="<<opt<<" x="<<x<<endl;
        root[i]=root[v];
        if(opt==1) // insert
        {
            T.ins(root[i],x);
        }
        else if(opt==2) // delete
        {
            T.del(root[i],x);
        }
        else if(opt==3) // rank
        {
            printf("%d\n",T.rank(root[i],x));
        }
        else if(opt==4) // kth
        {
            printf("%d\n",T.val[T.kth(root[i],x)]);
        }
        else if(opt==5) // pre
        {
            int ans=T.pre(root[i],x);
            if(ans==0)
                puts("-2147483647");
            else
                printf("%d\n",T.val[ans]);
        }
        else if(opt==6) // suf
        {
            int ans=T.suc(root[i],x);
            if(ans==0)
                puts("2147483647");
            else
                printf("%d\n",T.val[ans]);
        }
    }
    return 0;
}

再分析

然而对这题而言有更优的做法,主席树(可持久化权值线段树)。

将权值离散化,得到了这题较优的做法。

但是要离线,所以也是个问题。不离线的话空间会大一些,问题不大。

时间复杂度\(O(n \log n)\),常数小多了。

Hint

注意调用查询的时候,边界问题。

另外“若有多个相同的数,因只删除一个,如果没有请忽略该操作”。这个神坑点卡了我好久,非旋式Treap会自动忽略不存在的,所以就没管。

再代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<algorithm>
#include<cassert>
#define rg register
#define il inline
#define co const
template<class T>T read()
{
    T data=0;
    int w=1;
    char ch=getchar();
    while(!isdigit(ch))
    {
        if(ch=='-')
            w=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        data=data*10+ch-'0';
        ch=getchar();
    }
    return data*w;
}
template<class T>T read(T&x)
{
    return x=read<T>();
}
using namespace std;
typedef long long ll;

co int MAXN=5e5*20,MAXM=5e5+7;

int v[MAXN],opt[MAXN],x[MAXN];
vector<int>xlist;

int root[MAXM],tot;
struct SegTree
{
    int sumv[MAXN];
    int L[MAXN],R[MAXN];
    
    il void pushup(rg int now)
    {
        sumv[now]=sumv[L[now]]+sumv[R[now]];
//      assert(sumv[now]>=0);
    }
    
    il void copy(rg int x,rg int y)
    {
        sumv[x]=sumv[y];
        L[x]=L[y],R[x]=R[y];
    }
    
    il void modify(rg int&now,rg int l,rg int r,rg int p,rg int v)
    {
        ++tot;
        copy(tot,now);
        now=tot;
        if(l==r)
        {
            sumv[now]+=v;
            return;
        }
        rg int m=(l+r)>>1;
        if(p<=m)
            modify(L[now],l,m,p,v);
        else
            modify(R[now],m+1,r,p,v);
        pushup(now);
    }
    
    il int sum(rg int now,rg int l,rg int r,rg int ql,rg int qr)
    {
        if(ql<=l&&r<=qr)
            return sumv[now];
        rg int m=(l+r)>>1;
        if(qr<=m)
            return sum(L[now],l,m,ql,qr);
        if(ql>=m+1)
            return sum(R[now],m+1,r,ql,qr);
        return sum(L[now],l,m,ql,qr)+sum(R[now],m+1,r,ql,qr);
    }
    
    il int kth(rg int now,rg int l,rg int r,rg int k)
    {
        if(l==r)
            return l;
        rg int m=(l+r)>>1;
        if(sumv[L[now]]>=k)
            return kth(L[now],l,m,k);
        else
        {
            k-=sumv[L[now]];
            return kth(R[now],m+1,r,k);
        }
    }
    
    il int pre(rg int root,rg int p)
    {
        rg int num=p>1?sum(root,1,xlist.size(),1,p-1):0; // edit 1:notice p=1
        if(num==0) // do not exist
            return 0;
        else
            return kth(root,1,xlist.size(),num);
    }
    
    il int suc(rg int root,rg int p)
    {
        rg int num=p<xlist.size()?sum(root,1,xlist.size(),p+1,xlist.size()):0; // edit 2:notice p=xlist.size()
        if(num==0) // do not exist
            return xlist.size()+1;
        else
            return kth(root,1,xlist.size(),sumv[root]-num+1);
    }
}T;

int main()
{
//  freopen(".in","r",stdin);
//  freopen(".out","w",stdout);
    rg int n;
    read(n);
    for(rg int i=1;i<=n;++i)
    {
        read(v[i]);read(opt[i]);read(x[i]);
        if(opt[i]!=4) // unless kth
            xlist.push_back(x[i]);
    }
    sort(xlist.begin(),xlist.end());
    xlist.erase(unique(xlist.begin(),xlist.end()),xlist.end());
    for(rg int i=1;i<=n;++i)
    {
        if(opt[i]!=4)
            x[i]=lower_bound(xlist.begin(),xlist.end(),x[i])-xlist.begin()+1;
//      cerr<<"x "<<i<<" = "<<x[i]<<endl;
        root[i]=root[v[i]];
        if(opt[i]==1) // insert
        {
            T.modify(root[i],1,xlist.size(),x[i],1);
        }
        else if(opt[i]==2) // delete
        {
            // edit 3:if this val do not exist, you should ignore this operation
            if(T.sum(root[i],1,xlist.size(),x[i],x[i])==0)
                continue;
            T.modify(root[i],1,xlist.size(),x[i],-1);
        }
        else if(opt[i]==3) // rank
        {
            printf("%d\n",1+(x[i]>1?T.sum(root[i],1,xlist.size(),1,x[i]-1):0)); // edit 1: notice x[i]=1
        }
        else if(opt[i]==4) // kth
        {
//          assert(1<=x[i]&&x[i]<=T.sumv[root[i]]);
            printf("%d\n",xlist[T.kth(root[i],1,xlist.size(),x[i])-1]);
        }
        else if(opt[i]==5) // pre
        {
            int ans=T.pre(root[i],x[i]);
            if(ans!=0)
                printf("%d\n",xlist[ans-1]);
            else
                puts("-2147483647");
        }
        else if(opt[i]==6) // suc
        {
            int ans=T.suc(root[i],x[i]);
            if(ans!=xlist.size()+1)
                printf("%d\n",xlist[ans-1]);
            else
                puts("2147483647");
        }
    }
    return 0;
}

三分析

其实树状数组也可以做,并且常数更小。

但是空间就必须提前开出来,并且不离线不行了。

然后不用可持久化,可以搞一个dfs。给时间点连上边,dfs的时候就修改+撤销就行了。

第一次知道这么精妙的做法,那线段树、平衡树貌似都可以这么搞,并且空间复杂度大为减小。

分享一下洛谷全站最快代码,by Mr_Spade

三代码

#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include<cstdio>
#include<algorithm>
#define getchar() in[fin++]
#define putchar(x) out[fout++]=(x)
using std::lower_bound;
using std::sort;
using std::unique;
int fin,fout;
char in[1<<24],out[1<<24];
inline int read()
{
    int res=0;
    bool f=0;
    char x;
    while((x=getchar())<'0'||x>'9')
        f|=x=='-';
    for(;x>='0'&&x<='9';x=getchar())
        res=res*10+x-'0';
    return f?-res:res;
}
inline void write(int x)
{
    if(x<0)
        putchar('-'),x=-x;
    if(x>=10)
        write(x/10);
    putchar(x%10+'0');
    return;
}
const int N=5e5+5;
int n,m,lgn;
int num[N],tot;
int bit[N];
inline void add(int x,int k)
{
    while(x<=n)
        bit[x]+=k,x+=x&-x;
    return;
}
inline int ask(int x)
{
    int res=0;
    while(x)
        res+=bit[x],x&=x-1;
    return res;
}
inline int find(int x)
{
    int res=0;
    for(int i=lgn;~i;i--)
        if((res|1<<i)<=n&&bit[res|1<<i]<x)
            x-=bit[res|=1<<i];
    return res+1;
}
struct oper
{
    int opt,x;
}o[N];
int first[N],next[N];
int ans[N];
void dfs(int now)
{
    int d=0;
    register int go;
    for(go=first[now];go;go=next[go])
    {
        switch(o[go].opt)
        {
            case 1:add(o[go].x,1);break;
            case 2:
                if(ask(o[go].x)^ask(o[go].x-1))
                    add(o[go].x,-1);
                else
                    d=1;
                break;
            case 3:ans[go]=ask(o[go].x-1)+1;break;
            case 4:ans[go]=num[find(o[go].x)];break;
            case 5:
                if(!(d=ask(o[go].x-1)))
                    ans[go]=-0x7fffffff;
                else
                    ans[go]=num[find(d)];
                break;
            case 6:
                if((d=ask(o[go].x))==ask(n))
                    ans[go]=0x7fffffff;
                else
                    ans[go]=num[find(d+1)];
                break;
        }
        dfs(go);
        switch(o[go].opt)
        {
            case 1:add(o[go].x,-1);break;
            case 2:
                if(!d)
                    add(o[go].x,1);
                break;
        }
    }
    return;
}
signed main()
{
    fseek(stdin,0l,2);
    int len=ftell(stdin);
    rewind(stdin);
    fread(in,1,len,stdin);
    int x;
    register int i;
    m=read();
    for(i=1;i<=m;i++)
    {
        next[i]=first[x=read()];first[x]=i;
        o[i].opt=read();o[i].x=read();
        if(o[i].opt^4)
            num[++tot]=o[i].x;
    }
    sort(num+1,num+tot+1);
    n=unique(num+1,num+tot+1)-num-1;
    for(lgn=1;1<<lgn<=n;lgn++);lgn--;
    for(i=1;i<=m;i++)
        if(o[i].opt^4)
            o[i].x=lower_bound(num+1,num+n+1,o[i].x)-num;
    dfs(0);
    for(i=1;i<=m;i++)
        if(o[i].opt!=1&&o[i].opt!=2)
            write(ans[i]),putchar('\n');
    fwrite(out,1,fout,stdout);
    return 0;
}