【Splay】bzoj3224 Tyvj 1728 普通平衡树

时间:2022-06-23 15:47:08
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of the node pool. Node id 0 is the null sentinel (a child or
// parent of 0 means "none"), and real nodes get ids 1, 2, ... via ++tot,
// so at most maxn-1 real nodes fit.
const int maxn=1000000;
// Largest 32-bit signed int. Not referenced by the code in this file.
const int INF=2147483647;
// n: number of operations read by main.
int n;
// Splay tree stored as parallel arrays indexed by node id:
//   fa[x]            parent of x (0 for the root)
//   val[x]           key stored at x
//   c[x][0], c[x][1] left / right child (0 = none)
//   siz[x]           total multiplicity of all keys in x's subtree
//   cnt[x]           multiplicity of val[x] (duplicates share one node)
int fa[maxn],val[maxn],c[maxn][2],siz[maxn],cnt[maxn];
// root: id of the current root (0 when empty); tot: last allocated node id.
int root,tot;
/* Recompute siz[x] from x's own multiplicity and its two children's sizes. */
void Maintain(int x)
{
    const int left=c[x][0];
    const int right=c[x][1];
    siz[x]=cnt[x]+siz[left]+siz[right];
}
/*
 * Allocate a fresh leaf holding one copy of key, with parent Fa,
 * and store its id into the slot x (e.g. root or a child pointer).
 */
void NewNode(int &x,int Fa,int key)
{
    const int id=++tot;
    x=id;
    val[id]=key;
    fa[id]=Fa;
    c[id][0]=0;
    c[id][1]=0;
    cnt[id]=1;
    siz[id]=1;
}
/*
 * Single rotation lifting x above its parent.
 * Callers pass flag == true iff x is the LEFT child of its parent:
 *   flag == 1 -> right rotation, flag == 0 -> left rotation.
 * Sizes of the two affected nodes are recomputed (child first).
 * root is not updated here; Splay does that.
 */
void Rotate(int x,bool flag)
{
    const int y=fa[x];            // parent, moves down
    const int z=fa[y];            // grandparent, 0 if y was the root
    const int moved=c[x][flag];   // x's inner subtree, re-hung under y

    c[y][!flag]=moved;
    fa[moved]=y;                  // writes fa[0] when moved is the sentinel

    if(z!=0){
        c[z][c[z][1]==y]=x;       // x takes y's place under z
    }
    fa[x]=z;

    c[x][flag]=y;
    fa[y]=x;

    Maintain(y);
    Maintain(x);
}
/*
 * Rotate x upward until its parent is goal (goal == 0 makes x the root).
 * Uses zig, zig-zig (rotate parent first) and zig-zag (rotate x twice) steps.
 * A null x (0) is ignored.
 */
void Splay(int x,int goal)
{
    if(x==0){
        return;
    }
    while(fa[x]!=goal){
        const int y=fa[x];
        const int z=fa[y];
        if(z!=goal){
            const bool xIsLeft=(c[y][0]==x);
            const bool yIsLeft=(c[z][0]==y);
            if(xIsLeft==yIsLeft){
                Rotate(y,yIsLeft);   // zig-zig: lift the parent first
            }else{
                Rotate(x,xIsLeft);   // zig-zag: first of two rotations of x
            }
        }
        // Zig step, or the second half of zig-zig / zig-zag.
        Rotate(x,c[fa[x]][0]==x);
    }
    // Also needed when no rotation happened (e.g. Insert bumped cnt of the root).
    Maintain(x);
    if(goal==0){
        root=x;
    }
}
/*
 * Walk down from x (default: the root) toward key.
 * Returns the node holding key if present; otherwise the last node on the
 * search path. Returns 0 when the starting node is the sentinel (empty tree).
 * Does not splay.
 */
int Find(int key,int x=root)
{
    for(int next; (next=c[x][val[x]<key])!=0; x=next){
        if(val[x]==key){
            return x;
        }
    }
    return x;
}
/*
 * Insert one copy of key. An existing key just has its count bumped;
 * otherwise a new leaf is attached under the last node on the search path.
 * The touched node is splayed to the root (except for the very first insert).
 */
void Insert(int key)
{
    if(root==0){
        NewNode(root,0,key);
        return;
    }
    int x=Find(key);
    if(val[x]!=key){
        const int dir=(val[x]<key);
        NewNode(c[x][dir],x,key);
        x=c[x][dir];
    }else{
        ++cnt[x];
    }
    Splay(x,0);
}
/* Rightmost (largest-key) node in x's subtree; returns x itself if it has no right child. */
int Findmax(int x=root)
{
    for(; c[x][1]!=0; x=c[x][1]){
        // keep following right children
    }
    return x;
}
/* Leftmost (smallest-key) node in x's subtree; returns x itself if it has no left child. */
int Findmin(int x=root)
{
    for(; c[x][0]!=0; x=c[x][0]){
        // keep following left children
    }
    return x;
}
/*
 * Remove one copy of key.
 * If key is absent, the tree is only re-splayed around the search path;
 * if the tree is empty, nothing happens.
 */
void Delete(int key)
{
    int x=Find(key);
    if(!x){
        // Empty tree: Find returned the sentinel. Without this guard,
        // Delete(0) would match val[0]==0, drive cnt[0] to -1 and then
        // Maintain(0) would set siz[0]=-1, corrupting every later size.
        return;
    }
    Splay(x,0);                 // x is now the root
    if(val[x]!=key){
        return;                 // key not present
    }
    if(--cnt[x]){
        Maintain(x);            // other copies remain; the node stays
        return;
    }
    if(!c[x][0]||!c[x][1]){
        // At most one child: that child (or nothing) becomes the new root.
        root=c[x][0]?c[x][0]:c[x][1];
        if(root){
            fa[root]=0;
        }
        return;
    }
    // Two children: splay the largest key of the left subtree up to be x's
    // left child. It then has no right child, so x's right subtree hangs there.
    int y=Findmax(c[x][0]);
    Splay(y,x);
    c[y][1]=c[x][1];
    fa[c[x][1]]=y;
    root=y;
    fa[y]=0;
    Maintain(y);
}
int Rank(int key)
{
int x=Find(key);
Splay(x,0);
return siz[c[x][0]]+1;
}
int Kth(int K,int x=root)
{
while(1){
int Siz0=siz[c[x][0]];
if(K<=Siz0){
x=c[x][0];
}
else if(K<=Siz0+cnt[x]){
return val[x];
}
else{
K-=(Siz0+cnt[x]);
x=c[x][1];
}
}
}
/*
 * Largest stored value strictly less than key.
 * Returns 0 (the sentinel's val) if there is none or the tree is empty.
 *
 * The last node x on the search path is key itself, or its predecessor,
 * or its successor. After splaying x to the root:
 *   - val[x] < key  -> x is the predecessor;
 *   - otherwise     -> the predecessor is the maximum of x's left subtree.
 * Both the search node and the answer are splayed, which keeps the
 * amortized bound. The old version walked up fa[] without splaying when
 * key was absent.
 */
int GetPre(int key)
{
    int x=Find(key);
    Splay(x,0);
    if(x&&val[x]<key){
        return val[x];
    }
    int y=Findmax(c[x][0]);
    Splay(y,0);
    return val[y];
}
/*
 * Smallest stored value strictly greater than key.
 * Returns 0 (the sentinel's val) if there is none or the tree is empty.
 *
 * The last node x on the search path is key itself, or its predecessor,
 * or its successor. After splaying x to the root:
 *   - val[x] > key  -> x is the successor;
 *   - otherwise     -> the successor is the minimum of x's right subtree.
 * Both the search node and the answer are splayed, which keeps the
 * amortized bound. The old version walked up fa[] without splaying when
 * key was absent.
 */
int GetNex(int key)
{
    int x=Find(key);
    Splay(x,0);
    if(x&&val[x]>key){
        return val[x];
    }
    int y=Findmin(c[x][1]);
    Splay(y,0);
    return val[y];
}
/*
 * Reads n, then n lines "op x":
 *   1 insert x, 2 delete one x, 3 print rank of x,
 *   4 print x-th smallest, 5 print predecessor of x,
 *   any other op prints successor of x.
 */
int main(){
    int op, x;
    scanf("%d",&n);
    for(int q=0;q<n;++q){
        scanf("%d%d",&op,&x);
        switch(op){
        case 1:
            Insert(x);
            break;
        case 2:
            Delete(x);
            break;
        case 3:
            printf("%d\n",Rank(x));
            break;
        case 4:
            printf("%d\n",Kth(x));
            break;
        case 5:
            printf("%d\n",GetPre(x));
            break;
        default:
            printf("%d\n",GetNex(x));
            break;
        }
    }
    return 0;
}