CF414E,splay维护括号序列,码农题

时间:2021-05-07 00:10:45

传送门
刚看到这题一脸懵逼,尝试抄题解,死活看不懂。百度别人的中文题解,提到说因为子树移动,所以要用splay维护括号序列。
一棵有根树的括号序列定义我也找不到,脑补了一下,可能是对一个点x,进入时,序列尾部插入1,记下 lx 等于这个1的位置,离开时,序列尾部插入-1,记下 rx 等于这个1的位置,然后一个点到根的经过的点数就是前缀和。
强行yy一波后,发现两个点x,y如果不是祖先关系,那么括号序列里 lx ly 中深度最小的点一定是 lca(x,y) 的儿子
具体实现时,splay的每个节点要保存该节点是1还是-1,这个点(在原树中)的深度,以及(splay中)子树中深度的最值。
子树移动时打标记即可。

#include <cstdio>
#include <vector>
using std::vector;
struct pii{
    int first,second;
    bool operator<(const pii&x) const{
        return first==x.first?second<x.second:first<x.first;
    }
    bool operator>(const pii&x)const{
        return first==x.first?second>x.second:first>x.first;
    }
};
pii make_pair(int a,int b){
    return (pii){a,b};
}
const int N=250000;
int n,x,t,st[N],w,xb,y,dfn[N],id[N],b[N],m,j,dep[N],li[N],ri[N],k,o,q,i,xx,yy,zz;
inline void swap(int&a,int&b){
    t=a;
    a=b;
    b=t;
}
inline pii max(pii a,pii b){
    return a>b?a:b;
}
inline pii min(pii a,pii b){
    return a<b?a:b;
}
struct node{
       int l,r,fa,size,sum,add,v;
       pii ma,mi,x; 
       node(int _v=0):v(_v){size=1,l=r=fa=0;}
}a[N];
struct list{
    int rt;
    list(int _rt=0):rt(_rt){}
    inline void pushdown(int i){
        if(a[i].add){
            if(a[i].l)a[a[i].l].add+=a[i].add;
            if(a[i].r)a[a[i].r].add+=a[i].add;
            a[i].x.first+=a[i].add;
            a[i].ma.first+=a[i].add;
            a[i].mi.first+=a[i].add;
            a[i].add=0;
        }
    }
    inline void maintain(int i){
        pushdown(a[i].l);
        pushdown(a[i].r);
        a[i].size=a[a[i].l].size+a[a[i].r].size+1;
        a[i].sum=a[a[i].l].sum+a[a[i].r].sum+a[i].v;
        a[i].ma=max(max(a[a[i].l].ma,a[a[i].r].ma),a[i].x);
        a[i].mi=a[i].x;
        if(a[i].l)a[i].mi=min(a[a[i].l].mi,a[i].mi);
        if(a[i].r)a[i].mi=min(a[a[i].r].mi,a[i].mi);
    }
    void build(int fa,int&x,int l,int r){
        a[x=++xb].fa=fa;
        int m=(l+r)>>1;
        a[x].v=b[m];
        a[x].x=make_pair(dep[dfn[m]],dfn[m]);
        if(b[m]==1)li[dfn[m]]=xb;
            else ri[dfn[m]]=xb;
        if(m!=l)build(x,a[x].l,l,m-1);
        if(m!=r)build(x,a[x].r,m+1,r);
        maintain(x);
    }
    void rorate(int x){
        int y=a[x].fa,z=a[y].fa;        
        if(z){
            if(a[z].l==y)a[z].l=x;
                else a[z].r=x;
        }
        a[x].fa=z;
        a[y].fa=x;
        if(a[y].l==x){
            a[a[x].r].fa=y;
            a[y].l=a[x].r;
            a[x].r=y;
        }else{
            a[a[x].l].fa=y;
            a[y].r=a[x].l;
            a[x].l=y;
        }
        maintain(y);
        maintain(x);
    }
    void splay(int x){
        int y,z;
        for(w=0,y=x;y;y=a[y].fa)st[++w]=y;
        for(z=w;z;--z)pushdown(st[z]);
        while(a[x].fa){
            y=a[x].fa,z=a[y].fa;
            if(z){                
                if((a[y].l==x)^(a[z].l==y))rorate(x);
                    else rorate(y);
            }
            rorate(x);
        }
        rt=x;
    }
    inline int rank(int i){
        int x=i,ans=a[a[x].l].size+1;
        while(a[x].fa){
            if(a[a[x].fa].r==x)ans+=a[a[a[x].fa].l].size+1;
            x=a[x].fa;
        }
        return ans;
    }
    inline int kth(int i){//1<=i<=rt->size
        int x=rt;
        for(;;){
            pushdown(x);
            int y=a[a[x].l].size+1;
            if(y>i)x=a[x].l;
                else if(y<i){
                    i-=y;
                    x=a[x].r;
                }else return x;
        }
    }
    void operator+=(list b){
        if(!b.rt)return;
        if(!rt){
            rt=b.rt;
            return;
        }
        int x=rt;
        pushdown(x);
        while(a[x].r)x=a[x].r,pushdown(x);
        splay(x);
        a[rt].r=b.rt;
        a[b.rt].fa=rt;
        maintain(rt);
    }
    void midout(int i){
        if(i){
            pushdown(i);
            midout(a[i].l);
            printf("%d %d\n",a[i].x.second,a[i].v);
            midout(a[i].r);
        }
    }
    int split(int v){
        if(!v){
            int x=rt;
            rt=0;
            return x;
        }
        splay(kth(v));
        int x=a[rt].r;
        a[rt].r=0;
        a[x].fa=0;
        maintain(rt);
        return x;
    }
    inline int size(){
        return a[rt].size;
    }
    int sum(int i,int l,int r){
        int x=rank(l);
        list c,b;
        if(x>1){
            c=split(x-1);
            b=c.split(r-l+1);
            x=a[c.rt].sum;
            c+=b;
            *this+=c;
        }else{
            b=split(r-l+1);
            x=a[rt].sum;
            *this+=b;
        }
        return x;
    }
    int query(int i,int y){
        pushdown(i);
        pushdown(a[i].r);
        pushdown(a[i].l);
        if(a[a[i].r].ma.first>=y && a[a[i].r].mi.first<=y)return query(a[i].r,y);
        if(a[i].x.first==y)return a[i].x.second;
        return query(a[i].l,y);
    }
    int query(int i,int x,int y){
        splay(x);
        if(a[x].x.first==y)return a[x].x.second;
            else return query(a[x].l,y);
    }
    int psum(int i){
        splay(i);
        return a[i].x.first;
    }
}u,v,p,l;
struct tree{
    vector<int> g[N];
    inline void addedge(int x,int y){
        g[x].push_back(y);
    }
    void dfs(int x,int fa){
        dep[x]=dep[fa]+1;
        b[++xb]=1;
        dfn[xb]=x;
        for(unsigned int i=0;i<g[x].size();++i)dfs(g[x][i],x);
        b[++xb]=-1;
        dfn[xb]=x;
    }
}g;
int main(){
    a[0].size=0;
    scanf("%d%d",&n,&m);
    for(i=1;i<=n;++i){
        scanf("%d",&j);
        while(j--){
            scanf("%d",&x);
            g.addedge(i,x);
        }
    }
    g.dfs(1,0);
    xb=0;
    l.build(0,l.rt,1,n<<1);
    while(m--){
        scanf("%d",&j);
        if(j==1){
            scanf("%d%d",&x,&y);
            if(x==y){
                puts("0");
                continue;
            }
            o=l.psum(li[x]);
            q=l.psum(li[y]);
            xx=l.rank(li[x]);
            yy=l.rank(li[y]);
            if(xx>yy)swap(x,y),swap(o,q),swap(xx,yy);
            if(l.rank(ri[y])<=l.rank(ri[x])){
                printf("%d\n",q-o);
                continue;
            }
            u=l.split(xx-1);
            v=u.split(yy-xx+1);
            k=a[u.rt].mi.second;
            l+=u;
            l+=v;
            i=l.psum(li[k]);
            printf("%d\n",o+q-(i<<1)+2);
        }
        if(j==2){
            scanf("%d%d",&x,&y);
            i=l.query(l.rt,li[x],l.psum(li[x])-y);
            k=l.rank(ri[i]);
            u=l.split(k-1);
            o=l.rank(li[x]);
            v=l.split(o-1);
            p=v.split(v.rank(ri[x]));
            a[v.rt].add+=1-y;
            l+=p;
            l+=v;
            l+=u;
        }
        if(j==3){
            scanf("%d",&x);
            i=l.query(l.rt,ri[1],x+1);
            printf("%d\n",i);
        }
    }
    return 0;
}