【题目】BZOJ 1095
【题意】给定n个黑白点的树,初始全为黑点,Q次操作翻转一个点的颜色,或询问最远的两个黑点的距离,\(n \leq 10^5,Q \leq 5*10^5\)。
【算法】括号序列+线段树
【题解】参考:konjac
括号序列其实就是入栈出栈序,每个点在进入时加左括号和点编号,退出时加右括号。
这样做的好处:两个点间的括号数(除去匹配的括号)就是两点间路径的长度。
除去匹配的括号后,容易发现两个点间的括号时“)))((("的形式,右括号就是向上一条边,左括号就是向下一条边。
考虑两个区间的合并(只是区间不是线段树区间),记\(a_1\)表示左区间的右括号,\(b_1\)表示左区间的左括号,\(a_2,b_2\)表示右区间的。
\[a+b=a_1+b_2+|a_2-b_1|=max\{(a_1-b_1)+(a_2+b_2),(a_1+b_1)+(b_2-a_2)\}\]
\[a-b=(a_1-b_1)+(a_2-b_2)\]
\[b-a=(b_1-a_1)+(b_2-a_2)\]
那么这道题,用线段树维护括号序列,支持单点修改和区间查询,需要记录以下这些量:
- \(a\):右括号数
- \(b\):左括号数
- \(l_1\):左端点到某个黑点的b+a的最大值
- \(l_2\):左端点到某个黑点的b-a的最大值
- \(r_1\):右端点到某个黑点的a+b的最大值
- \(r_2\):右端点到某个黑点的a-b的最大值
- \(ans\):区间最远的两个黑点的距离
之所以维护这些量是因为:
\[ans=max\{L.ans,R.ans,L.r_1+R.l_2,L.r_2+R.l_1\}\]
答案要么来自左区间或右区间,要么跨越中点。考虑中点分割的左右两部分,如果\(b_1>a_2\)那么答案就是左区间右起的a+b和右区间左起b-a,否则答案是左区间右起的a-b和右区间左起的a+b。
其它的量根据上面的三条合并法则可以快速合并。
初始化:考虑实际意义,左右括号除了ab其它全部-inf,黑点除了ans其它全部0。
复杂度\(O(n \ \ log \ \ n)\)。
注意:左区间维护的是b-a,这样才是取最大值,如果维护a-b就是最小值了。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
bool isdigit(char c){return c>='0'&&c<='9';}
int read(){
int s=0,t=1;char c;
while(!isdigit(c=getchar()))if(c=='-')t=-1;
do{s=s*10+c-'0';}while(isdigit(c=getchar()));
return s*t;
}
const int maxn=300010,inf=0x3f3f3f3f;
int n,tot,ed,a[maxn],pos[maxn],first[maxn];
struct edge{int v,from;}e[maxn*2];
struct tree{int l,r,a,b,l1,l2,r1,r2,ans;}t[maxn*4];
void insert(int u,int v){ed++;e[ed].v=v;e[ed].from=first[u];first[u]=ed;}
void dfs(int x,int fa){
a[++tot]=-1;
a[pos[x]=++tot]=1;
for(int i=first[x];i;i=e[i].from)if(e[i].v!=fa){
dfs(e[i].v,x);
}
a[++tot]=-2;
}
int max(int a,int b){return a<b?b:a;}
void up(int k){
int l=k<<1,r=k<<1|1;
t[k].a=t[l].a+max(0,t[r].a-t[l].b);
t[k].b=t[r].b+max(0,t[l].b-t[r].a);
t[k].l1=max(t[l].l1,max(t[l].a-t[l].b+t[r].l1,t[l].a+t[l].b+t[r].l2));
t[k].l2=max(t[l].l2,t[l].b-t[l].a+t[r].l2);//
t[k].r1=max(t[r].r1,max(t[r].b-t[r].a+t[l].r1,t[r].a+t[r].b+t[l].r2));
t[k].r2=max(t[r].r2,t[r].a-t[r].b+t[l].r2);
t[k].ans=max(max(t[l].ans,t[r].ans),max(t[l].r1+t[r].l2,t[l].r2+t[r].l1));
}
void build(int k,int l,int r){
t[k].l=l;t[k].r=r;
if(l==r){
if(a[l]==-1)t[k]=(tree){l,r,0,1,-inf,-inf,-inf,-inf,-inf};
if(a[l]==-2)t[k]=(tree){l,r,1,0,-inf,-inf,-inf,-inf,-inf};//
if(a[l]==1)t[k]=(tree){l,r,0,0,0,0,0,0,-inf};//
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);build(k<<1|1,mid+1,r);
up(k);
}
void modify(int k,int x,int y){
if(t[k].l==t[k].r){
if(y==1)t[k]=(tree){t[k].l,t[k].r,0,0,0,0,0,0,0};
else t[k]=(tree){t[k].l,t[k].r,0,0,-inf,-inf,-inf,-inf,-inf};
return;
}
int mid=(t[k].l+t[k].r)>>1;
if(x<=mid)modify(k<<1,x,y);else modify(k<<1|1,x,y);
up(k);
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int u=read(),v=read();
insert(u,v);insert(v,u);
}
dfs(1,0);
build(1,1,tot);
int Q=read();
char s[10];
int num=n;//
while(Q--){
scanf("%s",s);
if(s[0]=='G'){
if(num==1)printf("0\n");else if(num==0)printf("-1\n");
else printf("%d\n",t[1].ans);
}else{
int x=read();
modify(1,pos[x],a[pos[x]]^=1);
if(a[pos[x]]==1)num++;else num--;
}
}
return 0;
}