P3369 【模板】普通平衡树(Treap/SBT)

时间:2023-03-24 12:14:26

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入x数

  2. 删除x数(若有多个相同的数,因只删除一个)

  3. 查询x数的排名(若有多个相同的数,因输出最小的排名)

  4. 查询排名为x的数

  5. 求x的前驱(前驱定义为小于x,且最大的数)

  6. 求x的后继(后继定义为大于x,且最小的数)

输入输出格式

输入格式:

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)

输出格式:

对于操作3,4,5,6每行输出一个数,表示对应答案

输入输出样例

输入样例#1:
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
输出样例#1:
106465
84185
492737

说明

时空限制:1000ms,128M

1.n的数据范围:n<=100000

2.每个数的数据范围:[-1e7,1e7]

来源:Tyvj1728 原名:普通平衡树

在此鸣谢

1.03

treap是一棵修改了结点顺序的二叉查找树

通常树内的每个结点x都有一个关键字值key[x],另外,还要为结点分配一个随机数。

假设所有的优先级是不同的,所有的关键字也是不同的。treap的结点排列成让关键字遵循二叉查找树性质,并且优先级遵

循最小堆顺序性质:
1.如果v是u的左孩子,则key[v] < key[u].
2.如果v是u的右孩子,则key[v] > key[u].
3.如果v是u的孩子,则rand[v] > rand[u].
这两个性质的结合就是为什么这种树被称为“treap”的原因,因为它同时具有二叉查找树和heap的特征。

(1.18整编转载自hzwer)

#include<cstdio>
#include<cstdlib>
#include<ctime>
using namespace std;
const int N=1e6+;
struct tree{
int l,r;//左右儿子节点编号
int num;//当前节点的数字
int s;//以当前节点为根的子树的节点数
int sum;//当前节点的数字的数量
int rnd;//随机优先级
}tr[N];//下标为节点编号
int n,rt,cnt,t1,t2;
void updata(int &k){
int &l=tr[k].l,&r=tr[k].r;
tr[k].s=tr[l].s+tr[r].s+tr[k].sum;
}
void lturn(int &k){
int t=tr[k].r;tr[k].r=tr[t].l;tr[t].l=k;
tr[t].s=tr[k].s;updata(k);k=t;
}
void rturn(int &k){
int t=tr[k].l;tr[k].l=tr[t].r;tr[t].r=k;
tr[t].s=tr[k].s;updata(k);k=t;
}
void insert(int &k,int x){
if(!k){
k=++cnt;tr[k].num=x;tr[k].s=;tr[k].sum++;tr[k].rnd=rand();return ;
}
tr[k].s++;
int &l=tr[k].l,&r=tr[k].r;
if(x<tr[k].num){
insert(l,x);
if(tr[l].rnd<tr[k].rnd) rturn(k);
}
else if(x>tr[k].num){
insert(r,x);
if(tr[r].rnd<tr[k].rnd) lturn(k);
}
else{
tr[k].sum++;return ;
}
}
void del(int &k,int x){
if(!k) return ;
int &l=tr[k].l,&r=tr[k].r;
if(x==tr[k].num){
if(tr[k].sum>){
tr[k].sum--;tr[k].s--;return ;
}
if(l*r==) k=l+r;
else{
if(tr[l].rnd<tr[r].rnd) rturn(k);
else lturn(k);
del(k,x);
}
}
else{
tr[k].s--;
if(x>tr[k].num) del(r,x);
else del(l,x);
}
}
int find1(int &k,int x){
if(!k) return ;
int &l=tr[k].l,&r=tr[k].r;
if(tr[k].num==x) return tr[l].s+;
if(tr[k].num>x) return find1(l,x);
if(tr[k].num<x) return tr[l].s+tr[k].sum+find1(r,x);
}
int find2(int &k,int x){
if(!k) return ;
int &l=tr[k].l,&r=tr[k].r;
if(tr[l].s+<=x&&tr[l].s+tr[k].sum>=x) return tr[k].num;
if(tr[l].s>=x) return find2(l,x);
if(tr[l].s+tr[k].sum<x) return find2(r,x-tr[l].s-tr[k].sum);
}
void pred(int &k,int x){
if(!k) return ;
int &l=tr[k].l,&r=tr[k].r;
if(tr[k].num<x){
t1=tr[k].num;
pred(r,x);
}
else pred(l,x);
}
void succ(int &k,int x){
if(!k) return ;
int &l=tr[k].l,&r=tr[k].r;
if(tr[k].num>x){
t2=tr[k].num;
succ(l,x);
}
else succ(r,x);
}
int main(){
srand(time());
scanf("%d",&n);
for(int i=,opt,x;i<=n;i++){
scanf("%d%d",&opt,&x);t1=t2=;
switch(opt){
case :insert(rt,x);break;
case :del(rt,x);break;
case :printf("%d\n",find1(rt,x));break;
case :printf("%d\n",find2(rt,x));break;
case :pred(rt,x);printf("%d\n",t1);break;
case :succ(rt,x);printf("%d\n",t2);break;
}
}
return ;
}

1.10

返现GNU系统有个pb_ds库,里面有好多bbt,我用rb-tree过掉了。

不过他的bbt不支持重复元素的出现,难道还要hash一下?

那不就失去了他的优越性了?

山人自有妙计。

//by shenben
#include<cstdio>
#include<iostream>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
tree<ll,null_mapped_type,less<ll>,rb_tree_tag,tree_order_statistics_node_update> bbt;
int n;ll k,ans;
inline int read(){
int x=,f=;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=-;ch=getchar();}
while(ch>=''&&ch<=''){x=x*+ch-'';ch=getchar();}
return x*f;
}
int main(){
freopen("phs.in","r",stdin);
freopen("phs.out","w",stdout);
n=read();
for(int i=,opt;i<=n;i++){
opt=read();k=read();
if(opt==) bbt.insert((k<<)+i);
if(opt==) bbt.erase(bbt.lower_bound(k<<));
if(opt==) printf("%d\n",bbt.order_of_key(k<<)+);
if(opt==) ans=*bbt.find_by_order(k-),printf("%lld\n",ans>>);
if(opt==) ans=*--bbt.lower_bound(k<<),printf("%lld\n",ans>>);
// if(opt==6) ans=*bbt.lower_bound(k+1<<20),printf("%lld\n",ans>>20);
if(opt==) ans=*bbt.upper_bound((k<<)+n),printf("%lld\n",ans>>);
}
return ;
}

1.18

看到hzwer‘blog暴力vector,于是就写过了。

貌似插入是O(√n+1)的

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
int read(){
int x=,f=;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=-;ch=getchar();}
while(ch>=''&&ch<=''){x=x*+ch-'';ch=getchar();}
return x*f;
}
int n;
vector<int>a;
void insert(int x){
a.insert(upper_bound(a.begin(),a.end(),x),x);
}
void del(int x){
a.erase(lower_bound(a.begin(),a.end(),x));
}
int find(int x){
return lower_bound(a.begin(),a.end(),x)-a.begin()+;
}
int main(){
a.reserve();
n=read();
for(int i=,opt,x;i<=n;i++){
opt=read();x=read();
switch(opt){
case :insert(x);break;
case :del(x);break;
case :printf("%d\n",find(x));break;
case :printf("%d\n",a[x-]);break;
case :printf("%d\n",*--lower_bound(a.begin(),a.end(),x));break;
case :printf("%d\n",*upper_bound(a.begin(),a.end(),x));break;
}
}
return ;
}

2.26 splay版

#include<cstdio>
using namespace std;
const int N=1e5+;
const int inf=2e9;
int n,c[N][],fa[N],val[N],cnt[N],siz[N],rt,sz;
inline int read(){
int x=,f=;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=-;ch=getchar();}
while(ch>=''&&ch<=''){x=x*+ch-'';ch=getchar();}
return x*f;
}
void updata(int x){
siz[x]=siz[c[x][]]+siz[c[x][]]+cnt[x];
}
void rotate(int x,int &k){
int y=fa[x],z=fa[y],l,r;
l=(c[y][]==x);r=l^;
if(y==k) k=x;
else c[z][c[z][]==y]=x;
fa[x]=z;fa[y]=x;fa[c[x][r]]=y;
c[y][l]=c[x][r];c[x][r]=y;
updata(y);updata(x);
}
void splay(int x,int &k){
while(x!=k){
int y=fa[x],z=fa[y];
if(y!=k){
if((c[y][]==x)^(c[z][]==y)) rotate(x,k);
else rotate(y,k);
}
rotate(x,k);
}
}
#define l c[k][0]
#define r c[k][1]
void Rank(int v){
int k=rt;if(!rt) return ;
while(c[k][v>val[k]]&&val[k]!=v) k=c[k][v>val[k]];
splay(k,rt);
}
int kth(int rk){
rk++;int k=rt;
if(siz[k]<rk) return ;
for(;;){
if(siz[l]<rk&&siz[l]+cnt[k]>=rk) return val[k];
if(siz[l]>=rk) k=l;
else rk-=siz[l]+cnt[k],k=r;
}
}
void insert(int v){
int k=rt,y=;
while(k&&val[k]!=v) y=k,k=c[k][v>val[k]];
if(k) cnt[k]++;
else{
k=++sz;val[k]=v;siz[k]=cnt[k]=;fa[k]=y;
if(y) c[y][v>val[y]]=k;
}
splay(k,rt);
}
void erase(int v){
Rank(v);int k;
if(cnt[rt]>){cnt[rt]--;siz[rt]--;return ;}
if(!c[rt][]||!c[rt][]){
rt=c[rt][]+c[rt][];
fa[rt]=;
}
else{
k=c[rt][];
while(l) k=l;
siz[k]+=siz[c[rt][]];
fa[c[rt][]]=k;l=c[rt][];
rt=c[rt][];
fa[rt]=;
splay(k,rt);
}
}
int prev(int v){
Rank(v);
if(val[rt]<v) return val[rt];
int k=c[rt][];
while(r) k=r;
return val[k];
}
int succ(int v){
Rank(v);
if(val[rt]>v) return val[rt];
int k=c[rt][];
while(l) k=l;
return val[k];
}
#undef l
#undef r
int main(){
insert(-inf);insert(inf);
n=read();
for(int i=,opt,x;i<=n;i++){
opt=read();x=read();
if(opt==) insert(x);
if(opt==) erase(x);
if(opt==) Rank(x),printf("%d\n",siz[c[rt][]]);
if(opt==) printf("%d\n",kth(x));
if(opt==) printf("%d\n",prev(x));
if(opt==) printf("%d\n",succ(x));
}
return ;
}