BZOJ1208 [HNOI2004]宠物收养所 splay

时间:2025-03-08 16:07:08

原文链接http://www.cnblogs.com/zhouzhendong/p/8085803.html


题目传送门 - BZOJ1208


题意概括

  有两种数,依次加入。

  规则为下:

  如果当前剩余的为同种数(或者没有数字),那么直接加入该数。

  否则找到与剩余的数中与当前数差的绝对值最小的(如果有多个一样小的,选择原值最小的),然后ans+=abs(差),并把这两个数都弄没。


题解

  splay裸题。


代码

#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cmath>
using namespace std;
const int N=80005,mod=1000000;
int fa[N],son[N][2],val[N],root,total=0;
void spt_clear(){
root=total=0;
memset(fa,0,sizeof fa);
memset(son,0,sizeof son);
}
int wson(int x){
return son[fa[x]][1]==x;
}
void rotate(int x){
if (fa[x]==0)
return;
int y=fa[x],z=fa[y],L=wson(x),R=L^1;
if (z)
son[z][wson(y)]=x;
fa[x]=z,fa[y]=x,fa[son[x][R]]=y;
son[y][L]=son[x][R],son[x][R]=y;
}
void splay(int x,int rt){
if (!x)
return;
if (!rt)
root=x;
for (int y=fa[x];fa[x];rotate(x),y=fa[x])
if (fa[y])
rotate(wson(x)==wson(y)?y:x);
}
int findpre(int v,int rt){
if (!rt)
return 0;
if (v==val[rt])
return rt;
if (v>val[rt]){
int x=findpre(v,son[rt][1]);
return x?x:rt;
}
return findpre(v,son[rt][0]);
}
int findpre(int v){
int res=findpre(v,root);
splay(res,0);
return res;
}
int findnxt(int v,int rt){
if (!rt)
return 0;
if (v==val[rt])
return rt;
if (v<val[rt]){
int x=findnxt(v,son[rt][0]);
return x?x:rt;
}
return findnxt(v,son[rt][1]);
}
int findnxt(int v){
int res=findnxt(v,root);
splay(res,0);
return res;
}
int find(int v,int rt){
if (!rt)
return 0;
if (v==val[rt])
return rt;
return find(v,son[rt][v>val[rt]]);
}
int findmax(int rt){
return son[rt][1]?findmax(son[rt][1]):rt;
}
void insert(int v,int &x,int pre){
if (x)
return insert(v,son[x][v>val[x]],x);
fa[x=++total]=pre,val[x]=v;
splay(x,0);
}
void erase(int v){
int x=find(v,root),rt;
splay(x,0);
if (!son[x][0]&&!son[x][1])
return spt_clear();
if (!son[x][0]||!son[x][1]){
int &s=son[x][(bool)son[x][1]];
fa[root=s]=0;
s=0;
return;
}
rt=findmax(son[x][0]);
son[x][0]=fa[son[x][0]]=0;
splay(rt,0);
fa[son[rt][1]=son[x][1]]=rt;
son[x][1]=0;
}
int n,op,v,ans=0,nowop;
int main(){
spt_clear();
scanf("%d",&n);
for (int i=1;i<=n;i++){
scanf("%d%d",&op,&v);
if (root==0){
nowop=op;
insert(v,root,0);
continue;
}
if (op==nowop)
insert(v,root,0);
else {
int pre=findpre(v),nxt=findnxt(v),cv;
if (!pre)
cv=val[nxt];
else if (!nxt)
cv=val[pre];
else
cv=abs(val[nxt]-v)<abs(v-val[pre])?val[nxt]:val[pre];
ans=(ans+abs(v-cv))%mod;
erase(cv);
}
}
printf("%d",ans);
return 0;
}