平衡树之splay总结

时间:2022-10-22 04:29:45

前置芝士:

平衡树:可以自平衡的二叉排序树,任然具有 左儿子<父亲<右儿子 的特点,且可保证不会退化成链,保证时间复杂度为(nlogn)

旋转:我的splay中只存在上旋(即将某个节点向上旋转),不区分左旋和右旋

前驱:比某个数小的最大数

后驱:比某个数大的最小数

平衡树的定义:

平衡树之splay总结平衡树之splay总结
ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
//ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
//val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
//rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
View Code

更新:

平衡树之splay总结平衡树之splay总结
void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
View Code

然后是所有平衡树都会用到的旋转操作:

平衡树之splay总结平衡树之splay总结
void rotate(ll v){
    ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
    ch[y][d]=k;prt[k]=y;
    ch[z][chk(y)]=v;prt[v]=z;
    ch[v][d^1]=y;prt[y]=v;
    pushup(y),pushup(v);
}
View Code

接下来就是splay的核心操作------splay操作  本质就是把一个节点旋到某个节点的儿子处(默认为0的儿子,即旋到根节点):

平衡树之splay总结平衡树之splay总结
void splay(ll cur,ll v=0){
    while(prt[cur]!=v){
        ll pr=prt[cur];
        if(prt[pr]!=v){
            if(chk(cur)==chk(pr))rotate(pr);
            else rotate(cur);
        }
        rotate(cur);
    }
    if(!v)root=cur;
}
View Code

插入操作

平衡树之splay总结平衡树之splay总结
void insert(ll x){
    ll cur=root,p=0;
    while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
    if(cur)cnt[cur]++;
    else{
        cur=++decnt;
        if(p)ch[p][x>val[p]]=cur;
        ch[cur][0]=ch[cur][1]=0;
        val[cur]=x;prt[cur]=p;
        size[cur]=cnt[cur]=1;
    }
    splay(cur);
}
View Code

查找操作,即找到某个节点并把他旋转到根节点

平衡树之splay总结平衡树之splay总结
void find(ll x){
    ll cur=root;
    while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
    splay(cur);
}
View Code

求某个树的排名,只需要把他旋转到根节点,排名就是他的左子树的节点数+1

平衡树之splay总结平衡树之splay总结
ll rank(ll x){
    find(x);return size[ch[root][0]];
    //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
}
View Code

求第k大(调用时应该是kth(k+1),原因同上)

平衡树之splay总结平衡树之splay总结
ll kth(ll k){
    ll cur=root;
    while(true){
        pushdown(cur);
        if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
        else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
        else return cur;
    }
    return cur;
}
View Code

求前驱,把这个数旋到根,并在左子树中找最大值

平衡树之splay总结平衡树之splay总结
ll pre(ll x){
    find(x);
    if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
    ll cur=ch[root][0];
    while(ch[cur][1])cur=ch[cur][1];
    return cur;
}
View Code

求后驱,把这个数旋到根,并在右子树中找最小值

平衡树之splay总结平衡树之splay总结
ll succ(ll x){
    find(x);
    if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
    ll cur=ch[root][1];
    while(ch[cur][0])cur=ch[cur][0];
    return cur;
}
View Code

删除某个数,只需要把他的前驱旋转到根,把他的后驱旋转到根的左儿子,因为大于他的前驱,所以他在根的右子树,又因为他小于后驱且除他之外没有小于后驱而大于前驱的数,所以他的后驱的左子树只有他一个节点

平衡树之splay总结平衡树之splay总结
void remove(ll x){
    ll lst=pre(x),nxt=succ(x);
    splay(lst),splay(nxt,lst);
    ll del=ch[nxt][0];
    if(cnt[del]>1){
        cnt[del]--;splay(del);
    }else ch[nxt][0]=0;
}
View Code

区间翻转,打标记就好了

平衡树之splay总结平衡树之splay总结
void reverse(ll l,ll r){   //这个只在所有节点编号为1~n的时候能用
    ll x=kth(l),y=kth(r+2);
    splay(x),splay(y,x);
    rev[ch[y][0]]^=1;
}
View Code

输出序列,就中序遍历一遍就好了

平衡树之splay总结平衡树之splay总结
void print(ll v){
    if(!v)return;
    pushdown(v);
    print(ch[v][0]);
    if(val[v]!=inf&&val[v]!=inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
    print(ch[v][1]);
}
View Code

综上,splay的代码如下

平衡树之splay总结平衡树之splay总结
namespace splay{
    const ll inf=1ll<<30;
    const ll maxn=200010;
    ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
    ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
    //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
    //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
    //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
    ll chk(ll v){return ch[prt[v]][1]==v;}
    void swap(ll &a,ll &b){a^=b^=a^=b;}
    void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
    void pushdown(ll v){
        if(rev[v]){
            swap(ch[v][0],ch[v][1]);
            rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
            rev[v]=0;
        }
    }
    void rotate(ll v){
        ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
        ch[y][d]=k;prt[k]=y;
        ch[z][chk(y)]=v;prt[v]=z;
        ch[v][d^1]=y;prt[y]=v;
        pushup(y),pushup(v);
    }
    void splay(ll cur,ll v=0){
        while(prt[cur]!=v){
            ll pr=prt[cur];
            if(prt[pr]!=v){
                if(chk(cur)==chk(pr))rotate(pr);
                else rotate(cur);
            }
            rotate(cur);
        }
        if(!v)root=cur;
    }
    void insert(ll x){
        ll cur=root,p=0;
        while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
        if(cur)cnt[cur]++;
        else{
            cur=++decnt;
            if(p)ch[p][x>val[p]]=cur;
            ch[cur][0]=ch[cur][1]=0;
            val[cur]=x;prt[cur]=p;
            size[cur]=cnt[cur]=1;
        }
        splay(cur);
    }
    void find(ll x){
        ll cur=root;
        while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
        splay(cur);
    }
    ll rank(ll x){
        find(x);return size[ch[root][0]];
        //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
    }
    ll kth(ll k){
        ll cur=root;
        while(true){
            pushdown(cur);
            if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
            else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
            else return cur;
        }
        return cur;
    }
    ll pre(ll x){
        find(x);
        if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
        ll cur=ch[root][0];
        while(ch[cur][1])cur=ch[cur][1];
        return cur;
    }
    ll succ(ll x){
        find(x);
        if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
        ll cur=ch[root][1];
        while(ch[cur][0])cur=ch[cur][0];
        return cur;
    }
    void remove(ll x){
        ll lst=pre(x),nxt=succ(x);
        splay(lst),splay(nxt,lst);
        ll del=ch[nxt][0];
        if(cnt[del]>1){
            cnt[del]--;splay(del);
        }else ch[nxt][0]=0;
    }
    void reverse(ll l,ll r){
        ll x=kth(l),y=kth(r+2);
        splay(x),splay(y,x);
        rev[ch[y][0]]^=1;
    }
    void print(ll v){
        if(!v)return;
        pushdown(v);
        print(ch[v][0]);
        if(val[v]!=inf&&val[v]!=inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
        print(ch[v][1]);
    }
}
View Code

Luogu P3369 【模板】普通平衡树

平衡树之splay总结平衡树之splay总结
#include<cstdio>
#define ll long long
namespace splay{
    const ll inf=1ll<<30;
    const ll maxn=200010;
    ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
    ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
    //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
    //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
    //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
    ll chk(ll v){return ch[prt[v]][1]==v;}
    void swap(ll &a,ll &b){a^=b^=a^=b;}
    void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
    void pushdown(ll v){
        if(rev[v]){
            swap(ch[v][0],ch[v][1]);
            rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
            rev[v]=0;
        }
    }
    void rotate(ll v){
        ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
        ch[y][d]=k;prt[k]=y;
        ch[z][chk(y)]=v;prt[v]=z;
        ch[v][d^1]=y;prt[y]=v;
        pushup(y),pushup(v);
    }
    void splay(ll cur,ll v=0){
        while(prt[cur]!=v){
            ll pr=prt[cur];
            if(prt[pr]!=v){
                if(chk(cur)==chk(pr))rotate(pr);
                else rotate(cur);
            }
            rotate(cur);
        }
        if(!v)root=cur;
    }
    void insert(ll x){
        ll cur=root,p=0;
        while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
        if(cur)cnt[cur]++;
        else{
            cur=++decnt;
            if(p)ch[p][x>val[p]]=cur;
            ch[cur][0]=ch[cur][1]=0;
            val[cur]=x;prt[cur]=p;
            size[cur]=cnt[cur]=1;
        }
        splay(cur);
    }
    void find(ll x){
        ll cur=root;
        while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
        splay(cur);
    }
    ll rank(ll x){
        find(x);return size[ch[root][0]];
        //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
    }
    ll kth(ll k){
        ll cur=root;
        while(true){
            pushdown(cur);
            if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
            else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
            else return cur;
        }
        return cur;
    }
    ll pre(ll x){
        find(x);
        if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
        ll cur=ch[root][0];
        while(ch[cur][1])cur=ch[cur][1];
        return cur;
    }
    ll succ(ll x){
        find(x);
        if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
        ll cur=ch[root][1];
        while(ch[cur][0])cur=ch[cur][0];
        return cur;
    }
    void remove(ll x){
        ll lst=pre(x),nxt=succ(x);
        splay(lst),splay(nxt,lst);
        ll del=ch[nxt][0];
        if(cnt[del]>1){
            cnt[del]--;splay(del);
        }else ch[nxt][0]=0;
    }
    void reverse(ll l,ll r){
        ll x=kth(l),y=kth(r+2);
        splay(x),splay(y,x);
        rev[ch[y][0]]^=1;
    }
    void print(ll v){
        if(!v)return;
        pushdown(v);
        print(ch[v][0]);
        if(val[v]!=inf&&val[v]!=-inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
        print(ch[v][1]);
    }
}
using namespace splay;
ll n;
int main(){
    scanf("%lld",&n);
    insert(inf);
    insert(-inf);
    while(n--){
        ll opt,x;
        scanf("%lld%lld",&opt,&x);
        switch(opt){
            case 1:{insert(x);break;}
            case 2:{remove(x);break;}
            case 3:{printf("%lld\n",rank(x));break;}
            case 4:{printf("%lld\n",val[kth(x+1)]);break;}
            case 5:{printf("%lld\n",val[pre(x)]);break;}
            case 6:{printf("%lld\n",val[succ(x)]);break;}
        }
    }
    return 0;
}
View Code

Luogu P3391 【模板】文艺平衡树(Splay)

平衡树之splay总结平衡树之splay总结
#include<cstdio>
#define ll long long
ll n,m;
namespace splay{
    const ll inf=1ll<<30;
    const ll maxn=200010;
    ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
    ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
    //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
    //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
    //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
    ll chk(ll v){return ch[prt[v]][1]==v;}
    void swap(ll &a,ll &b){a^=b^=a^=b;}
    void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
    void pushdown(ll v){
        if(rev[v]){
            swap(ch[v][0],ch[v][1]);
            rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
            rev[v]=0;
        }
    }
    void rotate(ll v){
        ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
        ch[y][d]=k;prt[k]=y;
        ch[z][chk(y)]=v;prt[v]=z;
        ch[v][d^1]=y;prt[y]=v;
        pushup(y),pushup(v);
    }
    void splay(ll cur,ll v=0){
        while(prt[cur]!=v){
            ll pr=prt[cur];
            if(prt[pr]!=v){
                if(chk(cur)==chk(pr))rotate(pr);
                else rotate(cur);
            }
            rotate(cur);
        }
        if(!v)root=cur;
    }
    void insert(ll x){
        ll cur=root,p=0;
        while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
        if(cur)cnt[cur]++;
        else{
            cur=++decnt;
            if(p)ch[p][x>val[p]]=cur;
            ch[cur][0]=ch[cur][1]=0;
            val[cur]=x;prt[cur]=p;
            size[cur]=cnt[cur]=1;
        }
        splay(cur);
    }
    void find(ll x){
        ll cur=root;
        while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
        splay(cur);
    }
    ll rank(ll x){
        find(x);return size[ch[root][0]];
        //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
    }
    ll kth(ll k){
        ll cur=root;
        while(true){
            pushdown(cur);
            if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
            else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
            else return cur;
        }
        return cur;
    }
    ll pre(ll x){
        find(x);
        if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
        ll cur=ch[root][0];
        while(ch[cur][1])cur=ch[cur][1];
        return cur;
    }
    ll succ(ll x){
        find(x);
        if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
        ll cur=ch[root][1];
        while(ch[cur][0])cur=ch[cur][0];
        return cur;
    }
    void remove(ll x){
        ll lst=pre(x),nxt=succ(x);
        splay(lst),splay(nxt,lst);
        ll del=ch[nxt][0];
        if(cnt[del]>1){
            cnt[del]--;splay(del);
        }else ch[nxt][0]=0;
    }
    void reverse(ll l,ll r){
        ll x=kth(l),y=kth(r+2);
        splay(x),splay(y,x);
        rev[ch[y][0]]^=1;
    }
    void print(ll v){
        if(!v)return;
        pushdown(v);
        print(ch[v][0]);
        if(val[v]!=inf&&val[v]!=-inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
        print(ch[v][1]);
    }
}
using namespace splay;
int main(){
    scanf("%lld%lld",&n,&m);
    insert(inf);
    insert(-inf);
    for(ll i=1;i<=n;i++)insert(i);
    while(m--){
        ll l,r;
        scanf("%lld%lld",&l,&r);
        reverse(l,r);
    }
    print(root);
    return 0;
}
View Code