bzoj3637(lct)

时间:2023-11-10 12:34:02

又一次把lct写炸了,硬着头皮终于改对了

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
const int maxn=;
struct node{
int fa,ls,rs,is_root;
}tr[maxn*];
int t,tot[maxn*],sum[maxn*],cnt,last[maxn*],pre[maxn*],to[maxn*];//tot表示这个点连出的虚边子树和;
int n,q,col[maxn],f[maxn];
void add(int x,int y){++t;pre[t]=last[x];last[x]=t;to[t]=y;}
void update(int x){
sum[x]=tot[x];
if(tr[x].ls!=&&tr[x].ls!=n+)sum[x]+=sum[tr[x].ls];
if(tr[x].rs!=&&tr[x].rs!=n+)sum[x]+=sum[tr[x].rs];
}
void rx(int x){
int y=tr[x].fa,z=tr[y].fa;
tr[y].ls=tr[x].rs;
if(tr[x].rs!=&&tr[x].rs!=n+)tr[tr[x].rs].fa=y;
tr[x].rs=y;tr[y].fa=x;
tr[x].fa=z;
if(z!=&&z!=n+&&!tr[y].is_root){
if(tr[z].ls==y)tr[z].ls=x;else tr[z].rs=x;
}
if(tr[y].is_root)tr[y].is_root=,tr[x].is_root=;
update(y);update(x);
}
void lx(int x){
int y=tr[x].fa,z=tr[y].fa;
tr[y].rs=tr[x].ls;
if(tr[x].ls!=&&tr[x].rs!=n+)tr[tr[x].ls].fa=y;
tr[x].ls=y;tr[y].fa=x;
tr[x].fa=z;
if(z&&z!=n+&&!tr[y].is_root){
if(tr[z].ls==y)tr[z].ls=x;else tr[z].rs=x;
}
if(tr[y].is_root)tr[y].is_root=,tr[x].is_root=;
update(y);update(x);
}
void splay(int x){
while(!tr[x].is_root){
int y=tr[x].fa,z=tr[y].fa;
if(tr[y].is_root){if(tr[y].ls==x)rx(x);else lx(x);}
else{
if(tr[z].ls==y&&tr[y].ls==x){rx(y);rx(x);}
else if(tr[z].ls==y&&tr[y].rs==x){lx(x);rx(x);}
else if(tr[z].rs==y&&tr[y].ls==x){rx(x);lx(x);}
else {lx(y);lx(x);}
}
}
}
void ace(int x){
for(int p=;x!=&&x!=n+;x=tr[x].fa){
splay(x);
if(tr[x].rs!=&&tr[x].rs!=n+){
tr[tr[x].rs].is_root=;
tot[x]+=sum[tr[x].rs];
}
if(p!=&&p!=n+){
tot[x]-=sum[p];
}
tr[tr[x].rs=p].is_root=;
update(p=x);
}
}
void link(int x,int y){//x是y的父亲
if(x==||x==n+)return;
ace(x);splay(x);splay(y);
tr[y].fa=x;tr[x].rs=y;tr[y].is_root=;//一开始最后这句丢了;
update(x);
}
void cut(int x,int y){//y是x的父亲
if(y==||y==n+)return;
ace(x);splay(x);tr[tr[x].ls].fa=;tr[tr[x].ls].is_root=;tr[x].ls=;update(x);
}
void dfs(int x,int fa){
for(int i=last[x];i;i=pre[i]){
int v=to[i];
if(v==fa)continue;
link(x,v);f[v]=x;
dfs(v,x);
}
}
int query(int x){
int tmp1=x,tmp2;
if(col[x])x+=n+;
ace(x);
splay(x);
while(tr[x].ls){
x=tr[x].ls;
}
splay(x);
if(col[tmp1])tmp2=x-n-;
else tmp2=x;
if(col[tmp2]!=col[tmp1])return sum[tr[x].rs];
else {return sum[x];}
}
int main(){
int x,y,op;
cin>>n;
for(int i=;i<n;++i){
scanf("%d %d",&x,&y);
add(x,y);add(y,x);
}
for(int i=;i<=n;++i){
sum[i]=tot[i]=;
tr[i].is_root=;
}
for(int i=n+;i<=*n+;++i)tr[i].is_root=;
dfs(,);
cin>>q;
for(int i=;i<=q;++i){
scanf("%d %d",&op,&x);
if(op){
if(col[x]){
cut(x+n+,f[x]+n+);
tot[x+n+]-=;sum[x+n+]-=;
tot[x]+=;sum[x]+=;
link(f[x],x);
}
else{
cut(x,f[x]);
tot[x]-=;sum[x]-=;
tot[x+n+]+=;sum[x+n+]+=;
link(f[x]+n+,x+n+);
}
col[x]^=;
}
else{
printf("%d\n",query(x));
}
}
return ;
}