1036: [ZJOI2008]树的统计Count
Time Limit: 10 Sec Memory Limit: 162 MBSubmit: 19800 Solved: 8055
[Submit][Status][Discuss]
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
分析:树链剖分模板题.比较容易写挂的两个地方:1.求答案的时候比较链顶的深度 2.对重儿子进行第二次dfs的时候链顶不变.
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 1000010,inf = 0x7fffffff; int sum[maxn << 2],maxx[maxn << 2],son[maxn],sizee[maxn],top[maxn],id[maxn],idx[maxn],head[maxn],to[maxn * 2],nextt[maxn * 2],tot = 1; int n,q,v[maxn],dep[maxn],cnt,f[maxn]; char s[10]; void add(int x,int y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void dfs(int u,int fa,int d) { dep[u] = d; f[u] = fa; sizee[u] = 1; for (int i = head[u];i;i = nextt[i]) { int v = to[i]; if (v == fa) continue; dfs(v,u,d + 1); sizee[u] += sizee[v]; if (sizee[v] >= sizee[son[u]]) son[u] = v; } } void dfs2(int u,int topp) { top[u] = topp; id[u] = ++cnt; idx[cnt] = u; if (son[u]) dfs2(son[u],topp); for(int i = head[u];i;i = nextt[i]) { int v = to[i]; if(v == f[u] || v == son[u]) continue; dfs2(v,v); } } void pushup(int o) { sum[o] = sum[o * 2] + sum[o * 2 + 1]; maxx[o] = max(maxx[o * 2],maxx[o * 2 + 1]); } void build(int o,int l,int r) { if (l == r) { sum[o] = maxx[o] = v[idx[l]]; return; } int mid = (l + r) >> 1; build(o * 2,l,mid); build(o * 2 + 1,mid + 1,r); pushup(o); } void update(int o,int l,int r,int cur,int v) { if (l == r) { sum[o] = maxx[o] = v; return; } int mid = (l + r) >> 1; if (cur <= mid) update(o * 2,l,mid,cur,v); if (cur > mid) update(o * 2 + 1,mid + 1,r,cur,v); pushup(o); } int query1(int o,int l,int r,int x,int y) { if (x <= l && r <= y) return sum[o]; int mid = (l + r) >> 1,res = 0; if (x <= mid) res += query1(o * 2,l,mid,x,y); if(y > mid) res += query1(o * 2 + 1,mid + 1,r,x,y); return res; } int query2(int o,int l,int r,int x,int y) { if (x <= l && r <= y) return maxx[o]; int mid = (l + r) >> 1,res = -inf; if (x <= mid) res = max(res,query2(o * 2,l,mid,x,y)); if (y > mid) res = max(res,query2(o * 2 +1,mid + 1,r,x,y)); return res; } int solve1(int x,int y) { int ans = 0; if (dep[x] < dep[y]) swap(x,y); while (top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); ans += query1(1,1,n,id[top[x]],id[x]); x = f[top[x]]; } if (dep[x] < dep[y]) swap(x,y); ans += query1(1,1,n,id[y],id[x]); return ans; } int solve2(int x,int y) { int ans = -inf; if (dep[x] < dep[y]) swap(x,y); //printf("%d %d\n",x,y); while (top[x] != top[y]) { if(dep[top[x]] < dep[top[y]]) swap(x,y); ans = max(ans,query2(1,1,n,id[top[x]],id[x])); x = f[top[x]]; } if (dep[x] < dep[y]) swap(x,y); ans = max(ans,query2(1,1,n,id[y],id[x])); return ans; } int main() { scanf("%d",&n); for(int i = 1; i < n; i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } for (int i = 1; i <= n; i++) scanf("%d",&v[i]); dfs(1,0,1); dfs2(1,1); build(1,1,n); scanf("%d",&q); for(int i = 1; i <= q; i++) { scanf("%s",s); int u,v; if (!strcmp(s,"CHANGE")) { scanf("%d%d",&u,&v); update(1,1,n,id[u],v); } else if (!strcmp(s,"QMAX")) { scanf("%d%d",&u,&v); printf("%d\n",solve2(u,v)); } else if(!strcmp(s,"QSUM")) { scanf("%d%d",&u,&v); printf("%d\n",solve1(u,v)); } } return 0; }