题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=4897
题意:
给你一棵树,一开始每条边都是白色,有三种操作:
1.将 u - v路径上的边转换颜色
2.将 u - v路径上相邻的边转换颜色
3.求 u - v 路径上黑色边的数量
思路:
好变态的一道树链剖分啊。。。。写自闭了
首先树链剖分,用点表示边.然后就是非常绕的轻重链的维护了,我们用两棵线段树来进行维护,一颗维护重链上的,一棵维护轻链上的标记值。
第一种操作,重链的话直接线段树区间操作i,轻链的话由于第二个操作,它的颜色会被两个点影响(本身和父节点),两个端点颜色不相同时,它才会改变颜色,这里因为他是第一个操作我们直接对它的值进行修改,用另一棵线段树来维护第二个操作对点的标记,对于轻链异或两个端点的标记值再异或第一个操作得到的值就是需要的值。
第二种操作,我们要多建一颗线段树来维护哪点被标记了,需要考虑会有重链上的点与路径相交,上面说了我们重链上的点是用第一颗线段树直接求值的,这种情况下重链是会到影响的,所以我们需要在线段树上更新被影响的点,那么是哪些点会被影响,当轻链向上跳的时候时会有重链在它旁边,这些重边就是被影响的。
第三个操作:重链上的直接用第一颗线段树求,轻链上的点的值就一个一个根据第二棵线段树异或得到。
这道题代码量比较大,写起来要理清思路,要不越写越乱。。。
实现代码:
#include<bits/stdc++.h>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mid int m = (l + r) >> 1
const int M = 5e5+; int sum[M],lazy[M],lazyl[M],head[M],son[M],siz[M];
int dep[M],fa[M],tid[M],top[M],col[M],Xor[M];
int cnt,cnt1,n;
struct node{
int to,next;
}e[M]; void add(int u,int v){
e[++cnt].to = v;e[cnt].next = head[u];head[u] = cnt;
} void dfs(int u,int faz,int deep){
dep[u] = deep;
siz[u] = ;
fa[u] = faz;
for(int i = head[u];i;i=e[i].next){
int v = e[i].to;
if(v == faz) continue;
dfs(v,u,deep+);
siz[u] += siz[v];
if(son[u] == -||siz[v] > siz[son[u]])
son[u] = v;
}
} void dfs1(int u,int t){
top[u] = t;
tid[u] = ++cnt1;
if(son[u] == -) return ;
dfs1(son[u],t);
for(int i = head[u];i;i=e[i].next){
int v = e[i].to;
if(v != son[u]&&v != fa[u])
dfs1(v,v);
}
} void init(){
cnt1 = ; cnt = ;
memset(son,-,sizeof(son));
memset(head,,sizeof(head));
memset(sum,,sizeof(sum));
memset(fa,,sizeof(fa));
memset(col,,sizeof(col));
memset(lazy,,sizeof(lazy));
memset(lazyl,,sizeof(lazyl));
memset(Xor,,sizeof(Xor));
memset(dep,,sizeof(dep));
memset(top,,sizeof(top));
memset(tid,,sizeof(tid));
} void pushup(int rt){
sum[rt] = sum[rt<<] + sum[rt<<|];
} void pushdown(int l,int r,int rt){
if(lazy[rt]){
mid;
if(l != ) sum[rt<<] = m-l+-sum[rt<<];
else sum[rt<<] = m-l-sum[rt<<];
sum[rt<<|] = r-m-sum[rt<<|];
lazy[rt<<] ^= ; lazy[rt<<|] ^= ;
lazy[rt] = ;
}
} void updatew(int L,int R,int l,int r,int rt){
if(L <= l&&R >= r){
lazy[rt] ^= ;
if(l != ) sum[rt] = r-l+-sum[rt];
else sum[rt] = r-l+sum[rt];
return ;
}
pushdown(l,r,rt);
mid;
if(L <= m) updatew(L,R,lson);
if(R > m) updatew(L,R,rson);
pushup(rt);
} int queryw(int L,int R,int l,int r,int rt){
if(L <= l&&R >= r){
return sum[rt];
}
pushdown(l,r,rt);
int ret = ;
mid;
if(L <= m) ret += queryw(L,R,lson);
if(R > m) ret += queryw(L,R,rson);
return ret;
} void pushdownl(int rt){
if(lazyl[rt]){
lazyl[rt<<] ^= ;
lazyl[rt<<|] ^= ;
lazyl[rt] = ;
}
}
void updatel(int L,int R,int l,int r,int rt){
if(L <= l&&R >= r){
lazyl[rt] ^= ;
return ;
}
mid;
if(L <= m) updatel(L,R,lson);
if(R > m) updatel(L,R,rson);
} int queryl(int p,int l,int r,int rt){
if(l == r) {
return Xor[rt]^lazyl[rt];
}
pushdownl(rt);
mid;
if(p <= m) return queryl(p,lson);
else return queryl(p,rson);
} int addl(int x,int y){
int fx = top[x],fy = top[y];
while(fx != fy){
if(dep[fx] < dep[fy]) swap(fx,fy),swap(x,y);
if(son[x]!=-) updatew(tid[son[x]],tid[son[x]],,n,);
updatel(tid[fx],tid[x],,n,);
x = fa[fx]; fx = top[x];
}
if(dep[x] > dep[y]) swap(x,y);
if(son[x] != -&&x == y) updatew(tid[son[x]],tid[son[x]],,n,);
if(x != y&&son[y] != -) updatew(tid[son[y]],tid[son[y]],,n,);
if(x != &&son[fa[x]] == x) updatew(tid[x],tid[x],,n,);
updatel(tid[x],tid[y],,n,);
} void addw(int x,int y){
int fx = top[x],fy = top[y];
while(fx != fy){
if(dep[fx] < dep[fy]) swap(fx,fy),swap(x,y);
if(fx != x) updatew(tid[son[fx]],tid[x],,n,);
col[fx] ^= ; x = fa[fx]; fx = top[x];
}
if(x != y){
if(dep[x] < dep[y]&&son[x] != -) updatew(tid[son[x]],tid[y],,n,);
else if(son[y] != -) updatew(tid[son[y]],tid[x],,n,);
}
} int solve(int x,int y){
int ret = ;
int fx = top[x],fy = top[y];
while(fx != fy){
if(dep[fx] < dep[fy]) swap(x,y),swap(fx,fy);
if(fx != x) ret += queryw(tid[son[fx]],tid[x],,n,);
int num = queryl(tid[fx],,n,)^queryl(tid[fa[fx]],,n,);
ret += num^col[fx];
x = fa[fx]; fx = top[x];
}
if(x != y){
if(dep[x] < dep[y]&&son[x] != -) ret += queryw(tid[son[x]],tid[y],,n,);
else if(son[y] != -) ret += queryw(tid[son[y]],tid[x],,n,);
}
return ret;
} int main()
{
int t,u,v,op,x,y,q;
scanf("%d",&t);
while(t--){
init();
scanf("%d",&n);
for(int i = ;i < n;i ++){
int x,y;
scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
dfs(,,);
dfs1(,);
scanf("%d",&q);
while(q--){
scanf("%d%d%d",&op,&x,&y);
if(op == ){
addw(x,y);
}
else if(op == ){
addl(x,y);
}
else{
if(x == y) printf("0\n");
else
printf("%d\n",solve(x,y));
}
}
}
return ;
}