bzoj1036: [ZJOI2008]树的统计Count——树链剖分

时间:2023-01-06 09:52:38

bzoj1036: [ZJOI2008]树的统计Count:http://www.lydsy.com/JudgeOnline/problem.php?id=1036

1036: [ZJOI2008]树的统计Count

Time Limit: 10 Sec   Memory Limit: 162 M
[ Submit][ Status][ Discuss]

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

思路:没做过树链剖分,,然后学习了一番。。参考

树链剖分就是把树拆成一系列链,然后用数据结构对链进行维护。

通常的剖分方法是轻重链剖分,所谓轻重链就是对于节点u的所有子结点v,size[v]最大的v与u的边是重边,其它边是轻边,其中size[v]是以v为根的子树的节点个数,全部由重边组成的路径是重路径,根据论文上的证明,任意一点到根的路径上存在不超过logn条轻边和logn条重路径。

这样我们考虑用数据结构来维护重路径上的查询,轻边直接查询。

通常用来维护的数据结构是线段树,splay较少见。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <map>
#include <cmath>

using namespace std;
const int MAXN = 30005;
const int inf = 0x3f3f3f3f;

struct Node {
int v;
int link;
} e[MAXN<<1];
int head[MAXN],valu[MAXN];
int si[MAXN];//size[v]是以v为根的子树的节点个数
int fa[MAXN];//fa[x]是x的父亲
int de[MAXN];//表示每个节点的深度
int pos[MAXN];//pos[x]表示x在线段树中的编号
int bl[MAXN];//bl[x]表示x所在这条重链的顶端节点(top)

int tol;
void add(int u,int v) {
e[++tol].v = v,e[tol].link = head[u],head[u] = tol;
e[++tol].v = u,e[tol].link = head[v],head[v] = tol;
}

//第一遍dfs求出树每个结点的深度deep[x],其为根的子树大小size[x],fa[x]代表x的父亲
void dfs1(int x) {
si[x] = 1;
for(int i = head[x]; i; i = e[i].link) {
if(e[i].v == fa[x])continue;
de[e[i].v] = de[x] + 1;
fa[e[i].v] = x;
dfs1(e[i].v);
si[x] += si[e[i].v];
}
}
int cnt;
//第二遍dfs: 根节点为起点,向下拓展构建重链,选择最大的一个子树的根继承当前重链
//其余节点,都以该节点为起点向下重新拉一条重链
//给每个结点分配一个位置编号,每条重链就相当于一段区间,用数据结构去维护。
//把所有的重链首尾相接,放到同一个数据结构上,然后维护这一个整体即可
void dfs2(int x,int y) {
int k = 0;
cnt++;
pos[x] = cnt;//分配x结点在线段树中的编号
bl[x] = y;//bl[x]表示x所在这条重链的顶端节点(top)
for(int i = head[x]; i; i = e[i].link)
if(de[e[i].v] > de[x] && si[e[i].v] > si[k])
k = e[i].v;
if(k == 0)return;
dfs2(k,y);
for(int i = head[x]; i; i = e[i].link)
if(de[e[i].v] > de[x] && k != e[i].v)
dfs2(e[i].v,e[i].v);
}
struct Tree {
int l,r,mx,sum;
} seg[MAXN<<2];

void build(int l,int r,int rt) {
seg[rt].l = l;
seg[rt].r = r;
// seg[rt].mx = seg[rt].sum = 0;
if(l == r)return;
int mid = l+r>>1;
build(l,mid,rt<<1);
build(mid+1,r,rt<<1|1);
}

void update(int x,int w,int rt) {
int l = seg[rt].l,r = seg[rt].r,mid = l+r>>1;
if(l == r) {
seg[rt].sum = seg[rt].mx = w;
return;
}
if(x <= mid)update(x,w,rt<<1);
else update(x,w,rt<<1|1);
seg[rt].mx = max(seg[rt<<1].mx,seg[rt<<1|1].mx);
seg[rt].sum = seg[rt<<1].sum + seg[rt<<1|1].sum;
}
int querysum(int L,int R,int rt) {
int l = seg[rt].l,r = seg[rt].r,mid = l+r>>1;
if(L <= l && r <= R)return seg[rt].sum;
int ans = 0;
if(L <= mid)ans += querysum(L,R,rt<<1);
if(R > mid)ans += querysum(L,R,rt<<1|1);
return ans;
}
int querymx(int L,int R,int rt) {
int l = seg[rt].l,r = seg[rt].r,mid = l+r>>1;
if(L <= l && r <= R)return seg[rt].mx;
if(R <= mid)return querymx(L,R,rt<<1);
else if(L > mid)return querymx(L,R,rt<<1|1);
else return max(querymx(L,mid,rt<<1),querymx(mid+1,R,rt<<1|1));
}
int solvesum(int L,int R) {
int sum = 0;
while(bl[L] != bl[R]) {
if(de[bl[L]] < de[bl[R]])swap(L,R);
sum += querysum(pos[bl[L]],pos[L],1);
L = fa[bl[L]];
}
if(pos[L] > pos[R]) swap(L,R);
sum += querysum(pos[L],pos[R],1);
return sum;
}
int solvemx(int L,int R) {
int mx = -inf;
while(bl[L] != bl[R]) {
if(de[bl[L]] < de[bl[R]])swap(L,R);
mx = max(mx,querymx(pos[bl[L]],pos[L],1));
L = fa[bl[L]];
}
if(pos[L] > pos[R])swap(L,R);
mx = max(mx,querymx(pos[L],pos[R],1));
return mx;
}
int main() {
// freopen("in.txt","r",stdin);
char ss[10];
int u,v,n,q;
scanf("%d",&n);
for(int i = 1; i < n; i++) {
scanf("%d%d",&u,&v);
add(u,v);
}
for(int i = 1; i <= n; i++)
scanf("%d",&valu[i]);
dfs1(1);
dfs2(1,1);
build(1,n,1);
for(int i = 1; i <= n; i++)
update(pos[i],valu[i],1);
scanf("%d",&q);
while(q--) {
scanf("%s%d%d",ss,&u,&v);
if(ss[0] == 'C') {
valu[u] = v;
update(pos[u],v,1);
} else if(ss[1] == 'M')printf("%d\n",solvemx(u,v));
else printf("%d\n",solvesum(u,v));
}
return 0;
}