学习笔记 | treap | splay

时间:2021-11-28 05:41:25

前言

不会数据结构选手深深地感受到了来自treap的恶意QwQ
在听的时候感觉自己听得听懂的??大概只是听懂了它的意思
代码是怎么写都感觉写不好╮(╯﹏╰)╭ 菜啊

treap

一句很好的话总结treap: tree+heap
即 树+堆
维护一棵二叉查找树 同时对每个节点rand一个值使得满足堆的性质TAT
感性理解一下 这样是可以保持treap深度在\(log\)的级别内的 也就保证了它的复杂度

它的基本操作

ZOJ2112|Dynamic Rankings

因为满足二叉排序的性质

  • 插入
  • 删除
  • 找前驱和后继
  • 查找第k大
#include<bits/stdc++.h>
#define fr(i,x,y) for(int i=x;i<=y;++i)
#define rf(i,x,y) for(int i=x;i>=y;--i)
#define ls q[x].ch[0]
#define rs q[x].ch[1]
#define LL long long
using namespace std;
const int N=1e5+10;
const LL inf=3e9;
struct data{
    int ch[2],rd,sz,cnt;
    LL val;
}q[N];
int cnt=0;

void up(int x){
    q[x].sz=q[ls].sz+q[rs].sz+q[x].cnt;
}

void rotate(int &x,int d){
    int son=q[x].ch[d];
    q[x].ch[d]=q[son].ch[d^1];
    q[son].ch[d^1]=x;
   // printf("x%d son%d\n",x,son);
    up(x),up(x=son);
}

void insert(int &x,LL val){
    if(!x){
        x=++cnt;
        q[x].val=val,q[x].sz=q[x].cnt=1;
        q[x].rd=rand()%inf;
        return ;
    }
    q[x].sz++;
    if(q[x].val==val){ q[x].cnt++;return ; }
    int d=q[x].val<val;
    insert(q[x].ch[d],val);
    if(q[q[x].ch[d]].rd<q[x].rd) rotate(x,d);
}

void del(int &x,LL val){
    if(!x) return ;
    if(q[x].val==val){
        if(q[x].cnt>=2) q[x].cnt--,q[x].sz--;
        else {
            if(ls==0||rs==0) x=ls+rs;
            else {
                int d=q[ls].rd>q[rs].rd;
                rotate(x,d),del(x,val);
            }
        }
    }
    else q[x].sz--,del(q[x].ch[val>q[x].val],val);
}

LL qq(int x,LL val){
    if(!x) return -inf;
    if(q[x].val>=val) return qq(ls,val);
    return max(q[x].val,qq(rs,val));
}

LL hj(int x,LL val){
   // printf("%d %lld\n",x,val);
    if(!x) return inf;
    if(q[x].val<=val) return hj(rs,val);
    return min(hj(ls,val),q[x].val);
}

int rk(int x,LL val){
    if(!x) return 0;
    if(q[x].val==val) return q[ls].sz+1;
    if(q[x].val>val) return rk(ls,val);
    if(q[x].val<val) return q[ls].sz+q[x].cnt+rk(rs,val);
}

LL kth(int x,int k){
    int nm=q[ls].sz+q[x].cnt;
    if(k>q[ls].sz&&k<=nm) return q[x].val;
    if(nm>k) return kth(ls,k);
    else return kth(rs,k-nm);
}

int main(){
    int n;scanf("%d",&n);
    int rt=0;
    fr(i,1,n){
        int tp;LL z;
        scanf("%d%lld",&tp,&z);
        //printf("rt=%d\n",rt);
        if(tp==1) insert(rt,z);
        if(tp==2) del(rt,z);
        if(tp==3) printf("%d\n",rk(rt,z));
        if(tp==4) printf("%lld\n",kth(rt,z));
        if(tp==5) printf("%lld\n",qq(rt,z));
        if(tp==6) printf("%lld\n",hj(rt,z));
    }
    return 0;
}

Splay

每次操作 都把要操作的东西放到离根近一点的地方

左旋右旋什么的都跟treap一样

厉害的地方大概是区间操作
如果要操作区间\([L,R]\)
那么把\(L-1\)节点移到根 再将\(R+1\)节点移到\(L-1\)的儿子位置
那么\(R+1\)的左子树就是\([L,R]\)

既然splay可以支持区间操作 那么就是说 可以有一些线段树的操作了
维护最大值最小值 和 子序列的最大和等等

理解了splay的核心思想之后欢迎尝试
play毒瘤模拟题bzoj1500|NOI2005|维修数列

手残党被续了大半天 其实是几个很思博的错误TAT

#include<bits/stdc++.h>
#define fr(i,x,y) for(int i=x;i<=y;++i)
#define rf(i,x,y) for(int i=x;i>=y;--i)
#define LL long long

#define ls ch[x][0]
#define rs ch[x][1]

using namespace std;
const int N=5e5+10,M=4e6+10,inf=1e4;
int top=0,idx=0,cnt=0,root;
int s[M],ch[N][2],f[N],tag[N],add[N],fz[N],q[N];
int vmx[N],v[N],lmx[N],sum[N],sz[N],rmx[N],mx[N];//weihu

void up(int x){
  if(!x) return ;
  sum[x]=sum[ls]+sum[rs]+v[x];
  sz[x]=sz[ls]+sz[rs]+1;
  vmx[x]=max(v[x],max(vmx[ls],vmx[rs]));
  lmx[x]=max(lmx[ls],sum[ls]+v[x]+lmx[rs]);
  rmx[x]=max(rmx[rs],sum[rs]+v[x]+rmx[ls]);
  mx[x]=max(max(max(mx[ls],mx[rs]),max(lmx[x],rmx[x])),v[x]+rmx[ls]+lmx[rs]);
  //printf("xxx%d\n",mx[x]);
}

void rec(int x,int val){
  tag[x]=1,add[x]=val;
  v[x]=vmx[x]=val;
  sum[x]=val*sz[x];
  lmx[x]=rmx[x]=mx[x]=max(val*sz[x],0);
}

void rev(int x){
  fz[x]^=1;
  swap(ls,rs);
  swap(lmx[x],rmx[x]);
}

void down(int x){
  //printf("???x=%d\n",x);
  if(tag[x]){
    if(ls) rec(ls,add[x]);
    if(rs) rec(rs,add[x]);
    tag[x]=0;
  }
  if(fz[x]){
    if(ls) rev(ls);
    if(rs) rev(rs);
    fz[x]=0;
  }
}

void newnode(int &x,int val,int p){
  if(top) x=q[top--];
  else x=++idx;
  sz[x]=1;
  ls=rs=0;
  f[x]=p,v[x]=vmx[x]=sum[x]=val;
  lmx[x]=rmx[x]=mx[x]=max(sum[x],0);
  tag[x]=fz[x]=add[x]=0;
}

void build(int &x,int l,int r,int p){
  if(l>r) return ;
  int mid=(l+r)>>1;
  newnode(x,s[mid],p);
  build(ls,l,mid-1,x),build(rs,mid+1,r,x);
  up(x);
}

void init(){
  vmx[0]=-inf;
  newnode(ch[0][1],-inf,0);
  root=ch[0][1];
  newnode(ch[root][0],-inf,root);
  build(ch[ch[root][0]][1],1,s[0],ch[root][0]);
    up(ch[root][0]);up(root);
}

void print(int x){
  if(!x) return ;
  print(ls);
  printf("%d l=%d r=%d %d sz=%d f=%d\n",v[x],v[ls],v[rs],sum[x],sz[x],v[f[x]]);
  print(rs);
}

void rotate(int x){
  int y=f[x],z=f[y];
  down(x),down(y);
  int k=(ch[z][1]==y),d=(ch[y][1]==x);
  ch[z][k]=x,f[x]=z;
  ch[y][d]=ch[x][d^1],f[ch[x][d^1]]=y;
  ch[x][d^1]=y,f[y]=x;
  up(y),up(x);
}

void splay(int x,int goal){
  down(x);
  while(f[x]!=goal){
    int y=f[x],z=f[y];
    if(z!=goal)
      (y==ch[z][0])^(x==ch[y][0]) ? rotate(x):rotate(y);
    rotate(x);
  }
  up(x);
  if(!goal) root=x;
}

int find(int x,int k){
  down(x);
  int gg=sz[ls]+1;
 // printf("%d %d %d %d %d\n",x,ls,rs,k,gg);
  if(gg>k) return find(ls,k);
  else if(gg<k) return find(rs,k-gg);
  else return x;
}

void del(int x){
  if(!x) return ;
  q[++top]=x;
  del(ls),del(rs);
}

int main(){
  int n,m;scanf("%d%d",&n,&m);
  fr(i,1,n) scanf("%d",&s[++s[0]]);
  init();
  string S;
  int cnt=n;
  fr(o,1,m){
    cin>>S;
    int pos,tot,c;
    if(S=="INSERT"){
      s[0]=0;
      scanf("%d%d",&pos,&tot);
      fr(i,1,tot) scanf("%d",&s[++s[0]]);
      int L=find(root,pos+1),R=find(root,pos+2);
      splay(L,0),splay(R,L);
      build(ch[R][0],1,s[0],R);
      up(R),up(L);
      cnt+=tot;
    }
    if(S=="DELETE"){
      scanf("%d%d",&pos,&tot);
      int L=find(root,pos),R=find(root,pos+tot+1);
      splay(L,0),splay(R,L);
      del(ch[R][0]);
      ch[R][0]=0;
      up(R),up(L);
      cnt-=tot;
    }
    if(S=="MAKE-SAME"){
      scanf("%d%d%d",&pos,&tot,&c);
      int L=find(root,pos),R=find(root,pos+tot+1);
      splay(L,0),splay(R,L);
      rec(ch[R][0],c);
      down(ch[R][0]);
      up(R),up(L);
    }
    if(S=="REVERSE"){
      scanf("%d%d",&pos,&tot);
      int L=find(root,pos),R=find(root,pos+tot+1);
      splay(L,0),splay(R,L);
      rev(ch[R][0]);
      down(ch[R][0]);
      up(R),up(L);
    }
    if(S=="GET-SUM"){
      scanf("%d%d",&pos,&tot);
      int L=find(root,pos),R=find(root,pos+tot+1);
      splay(L,0),splay(R,L);
      printf("%d\n",sum[ch[R][0]]);
    }
    if(S=="MAX-SUM"){
      int L=find(root,1),R=find(root,cnt+2);
      splay(L,0),splay(R,L);
      if(vmx[ch[R][0]]>=0) printf("%d\n",mx[ch[R][0]]);
      else printf("%d\n",vmx[ch[R][0]]);
    }
  }
  return 0;
}

hdu4441

#include<bits/stdc++.h>
#define fr(i,x,y) for(int i=x;i<=y;++i)
#define rf(i,x,y) for(int i=x;i>=y;--i)
#define LL long long

#define lson x<<1
#define rson x<<1|1

using namespace std;
const int N=2e5+10,M=N<<1;
int nm[N<<2];

string s;
int pos,n;

void change(int x,int l,int r,int L,int k){
  if(l==r) {
    nm[x]+=k;
    return ;
  }
  int mid=(l+r)>>1;
  if(L<=mid) change(lson,l,mid,L,k);
  else change(rson,mid+1,r,L,k);
  nm[x]=nm[lson]+nm[rson];
}

int Ask(int x,int l,int r){
  if(l==r) return l;
  int mid=(l+r)>>1;
  if(nm[lson]!=(mid-l+1)) return Ask(lson,l,mid);
  else return Ask(rson,mid+1,r);
}

int idx=0,root=0;
int ch[M][2],f[M],zf[M][2],sz[M],mk[N][2];
LL sum[M],v[M];

#define ls ch[x][0]
#define rs ch[x][1]

void up(int x){
  if(!x) return ;
  sum[x]=sum[ls]+sum[rs]+v[x];
  sz[x]=sz[ls]+sz[rs]+1;
  zf[x][0]=zf[ls][0]+zf[rs][0]+(v[x]>0);
  zf[x][1]=zf[ls][1]+zf[rs][1]+(v[x]<0);
}

int newnode(int& x,LL val,int p){
  x=++idx;
  f[x]=p,ls=rs=0,sz[x]=1;
  sum[x]=v[x]=val;
  zf[x][0]=(val>0),zf[x][1]=(val<0);
  return idx;
}

void init(){
  fr(i,1,n*4) nm[i]=0;
  root=idx=0;
  newnode(root,0,0);
  newnode(ch[root][1],0,root);
  up(ch[root][1]),up(root);
}

int find(int x,int k){
  int gg=sz[ls]+1;
  if(gg>k) return find(ls,k);
  else if(gg==k) return x;
  else return find(rs,k-gg);
}

void rotate(int x){
  int y=f[x],z=f[y];
  int d=(ch[z][1]==y),k=(ch[y][1]==x);
  ch[z][d]=x,f[x]=z;
  ch[y][k]=ch[x][k^1],f[ch[x][k^1]]=y;
  ch[x][k^1]=y,f[y]=x;
  up(y),up(x);
}

void splay(int x,int goal){
  while(f[x]!=goal){
    int y=f[x],z=f[y];
    if(z!=goal)
      (ch[z][1]==y)^(ch[y][1]==x) ? rotate(x) : rotate(y);
    rotate(x);
  }
  if(!goal) root=x;
}

int Fu(int x,int k){
  int nm=zf[ls][1],d=(v[x]<0);
  if(nm>=k) return Fu(ls,k);
  else if(nm+d==k) return x;
  else return Fu(rs,k-d-nm);
}

void print(int x){
  if(!x) return ;
  print(ls);
  printf("fa=%lld ls=%lld rs=%lld v=%lld\n",v[f[x]],v[ls],v[rs],v[x]);
  print(rs);
}

void delet(int pos){
  int L=pos;
  splay(L,0);
  int gg=sz[ch[L][0]];
  int Z=find(root,gg),Y=find(root,gg+2);
  splay(Z,0),splay(Y,Z);
  ch[Y][0]=0;
  up(Y),up(Z);
}

int main(){
  int cse=0;
  while(scanf("%d",&n)==1){
    init();
    printf("Case #%d:\n",++cse);
    fr(o,1,n){
      cin>>s;scanf("%d",&pos);
      if(s=="insert"){
        int val=Ask(1,1,n);
        int L=find(root,pos+1),R=find(root,pos+2);
        splay(L,0),splay(R,L);
        mk[val][0]=newnode(ch[R][0],val,R);
        up(R),up(L);
        int gg=zf[ch[root][0]][0]+(v[root]>0);
        if(zf[root][1]<=gg){
         int zz=sz[root]-1;
         int l=find(root,zz),r=find(root,zz+1);
         splay(l,0),splay(r,l);
         mk[val][1]=newnode(ch[r][0],-val,r);
         up(r),up(l);
        } else {
          gg++;
          int l=Fu(root,gg);
          splay(l,0);
          int pos=sz[ch[l][0]];
          int Z=find(root,pos),Y=find(root,pos+1);
          splay(Z,0),splay(Y,Z);
          mk[val][1]=newnode(ch[Y][0],-val,Y);
          up(Y),up(Z);
        }
        change(1,1,n,val,1);
      }
      if(s=="remove")
        delet(mk[pos][0]),delet(mk[pos][1]),change(1,1,n,pos,-1);
      if(s=="query"){
        int L=mk[pos][0],R=mk[pos][1];
        splay(L,0),splay(R,L);
        printf("%lld\n",sum[ch[R][0]]);
      }
    }
  }
  return 0;
}

为什么一写splay就写错..
数据结构题好难调啊啊w(゚Д゚)w
终于调过啦!!真是开心~