[转载]无旋treap:从好奇到入门(例题:bzoj3224 普通平衡树)

时间:2024-08-23 10:06:32

转载自ZZH大佬,原文:http://www.cnblogs.com/LadyLex/p/7182491.html

今天我们来学习一种新的数据结构:无旋treap。它和splay一样支持区间操作,和treap一样简单易懂,同时还支持可持久化。

无旋treap的节点定义和treap一样,都要同时满足树性质和堆性质,我们还是用rand()来实现平衡

而无旋treap与treap不同的地方,也是其核心,就是它不旋转用两个新的核心函数:merge函数(合并两棵子树)和split函数(分裂出某棵树的前k个节点,并且作为一棵树返回)

首先看merge函数,它是一个递归实现的过程,先看代码:

 1 Treap *merge(Treap *a,Treap *b)
 2 {
 3     if(a==null)return b;
 4     if(b==null)return a;
 5     pushdown(a);pushdown(b);
 6     if(a->key < b->key)
 7         {a->ch[1]=merge(a->ch[1],b);a->update();return a;}
 8     else
 9         {b->ch[0]=merge(a,b->ch[0]);b->update();return b;}
10 }

对于两棵子树a和b,我们可以实现把b树合并到a树中

在合并时,我们首先看他们的根节点谁的键值比较小(我维护的是一个小根堆),并且建立对应的父子关系。

又由于平衡树的中序遍历不变,我们又要把b插在a后面,维持一个确定的中序遍历,

所以我们应该一直把a作为merge函数的前一个参数,b作为后一个参数,这个顺序不能换.

这一个确定的顺序的重要性尤其体现在后续的区间操作中。刚开始的时候可以当板子背下来,但随着打题肯定会逐渐理解。

接下来我们介绍split函数,这也是一个递归实现的过程,还是先看代码:

 1 typedef pair<Treap*,Treap*> D;
 2 D split(Treap *o,int k)
 3 {
 4     if(o==null) return D(null,null);
 5     D y;pushdown(o);
 6     if(o->ch[0]->size>=k)
 7         {y=split(o->ch[0],k);o->ch[0]=y.second;o->update();y.second=o;}
 8     else
 9         {y=split(o->ch[1],k-o->ch[0]->size-1);o->ch[1]=y.first;o->update();y.first=o;}
10     return y;
11 }

我们首先定义一个pair,这样做的好处是同时返回分裂出来的两棵树的根节点指针,我规定第一个是分离完成的树,第二个是剩下的原树。

然后考虑分离前k个的过程:如果o的左儿子有k个以上节点,我们显然应该去左儿子分离。

然后我们会得到分离完成的树和左儿子剩下的树,这时候把左儿子剩下的部分接回节点o,并把新的o作为分离o剩下的原树

如果左儿子节点个数不够,我们就去右儿子分离,过程是相似的,但略有不同,留给读者思考。

有了这两个函数,我们就可以用他们实现一些常用的操作了,比如:

insert=split+newnode+merge+merge

delete=split+split+merge(合并第一个split的first和第二个的second)

等等,其他操作也可以用类似的思路打出来。下面我们用一道例题实战一下。建议读者自己实现代码并充分思考后再核对标程。

3224: Tyvj 1728 普通平衡树

Time Limit: 10 Sec  Memory Limit: 128 MB

Description

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

Input

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

Output

对于操作3,4,5,6每行输出一个数,表示对应答案

Sample Input

10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598

Sample Output

106465
84185
492737

HINT

1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
题解:
这道题本质上只比上面讲的基本操作多了两个函数:查询某个权值的排名和查询某个排名的权值
查询某个权值的排名很简单,在树中递归询问即可
而对于某个权值的排名,我们可以考虑split前k-1个节点,再对第一次split的second进行split,得到第k个节点,并且返回权值
前驱和后继只是上面这两个操作的简单变形,但稍微需要注意一下边界的处理。
代码见下:
 1 #include <cstdio>
 2 #include <algorithm>
 3 #include <cstring>
 4 #include <ctime>
 5 #include <cstdlib>
 6 using namespace std;
 7 const int maxn=100100,inf=0x7fffffff;
 8 struct Treap
 9 {
10     Treap* ch[2];
11     int key,val,size;
12     Treap(int v)
13         {size=1,val=v,key=rand();ch[0]=ch[1]=NULL;}
14     inline void tain()
15         {size=1+(ch[0]?ch[0]->size:0)+(ch[1]?ch[1]->size:0);}
16 }*root;
17 typedef pair<Treap*,Treap*> D;
18 inline int size(Treap *o){return o?o->size:0;}
19 Treap *Merge(Treap *a,Treap* b)
20 {
21     if(!a)return b;
22     if(!b)return a;
23     if(a->key < b->key)
24         {a->ch[1]=Merge(a->ch[1],b);a->tain();return a;}
25     else
26         {b->ch[0]=Merge(a,b->ch[0]);b->tain();return b;}
27 }
28 D Split(Treap *o,int k)
29 {
30     if(!o)return D(NULL,NULL);
31     D y;
32     if(size(o->ch[0])>=k)
33         {y=Split(o->ch[0],k);o->ch[0]=y.second;o->tain();y.second=o;}
34     else
35         {y=Split(o->ch[1],k-size(o->ch[0])-1);o->ch[1]=y.first;o->tain();y.first=o;}
36     return y;
37 }
38 int Getkth(Treap *o,int v)
39 {
40     if(o==NULL)return 0;
41     return(o->val>=v)?Getkth(o->ch[0],v):Getkth(o->ch[1],v)+size(o->ch[0])+1;
42 }
43 inline int Findkth(int k)
44 {
45     D x=Split(root,k-1);
46     D y=Split(x.second,1);
47     Treap *ans=y.first;
48     root=Merge(Merge(x.first,ans),y.second);
49     return ans!=NULL?ans->val:0;
50 }
51 inline void Insert(int v)
52 {
53     int k=Getkth(root,v);
54     D x=Split(root,k);
55     Treap *o=new Treap(v);
56     root=Merge(Merge(x.first,o),x.second);
57 }
58 void Delete(int v)
59 {
60     int k=Getkth(root,v);
61     D x=Split(root,k);
62     D y=Split(x.second,1);
63     root=Merge(x.first,y.second);
64 }
65 int main(){
66     int m,opt,x;scanf("%d",&m);
67     while(m--)
68     {
69         scanf("%d%d",&opt,&x);
70         switch(opt)
71         {
72             case 1:Insert(x);break;
73             case 2:Delete(x);break;
74             case 3:printf("%d\n",Getkth(root,x)+1);break;
75             case 4:printf("%d\n",Findkth(x));break;
76             case 5:printf("%d\n",Findkth(Getkth(root,x)));break;
77             case 6:printf("%d\n",Findkth(Getkth(root,x+1)+1));break;
78         }
79     }
80 }