BZOJ1036 树的统计

时间:2021-08-21 16:11:59

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

正解:树链剖分+线段树

解题报告:

  维护树上一条路径上的结点权值最大值或和

  没什么好说的,链剖裸题。先树链剖分再根据访问次序建立线段树,用线段树动态维护。

  模板题练手。

 //It is made by jump~
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = ;
const int inf = (<<);
int n;
int total,ecnt;
int U,VV;
int a[MAXN];
int id[MAXN],pre[MAXN];
int top[MAXN],siz[MAXN],zhongerzi[MAXN],father[MAXN],deep[MAXN];
int next[MAXN*],to[MAXN*],first[MAXN];
char ch[]; struct node{
int l,r;
int _max;int _sum;
}jump[MAXN*]; void link(int x,int y){ next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y; } int getint()
{
int w=,q=;
char c=getchar();
while((c<'' || c>'') && c!='-') c=getchar();
if (c=='-') q=, c=getchar();
while (c>='' && c<='') w=w*+c-'', c=getchar();
return q ? -w : w;
} void build(int root,int l,int r){
jump[root].l=l;jump[root].r=r;
if(jump[root].l==jump[root].r) {
jump[root]._sum=jump[root]._max=a[ pre[l] ];
return ;
}
int lc=root*,rc=root*+;
int mid=l+(r-l)/;
build(lc,l,mid); build(rc,mid+,r);
jump[root]._sum=jump[lc]._sum+jump[rc]._sum;
jump[root]._max=max(jump[lc]._max,jump[rc]._max);
} void dfs1(int u,int fa){
siz[u]=;
for(int i=first[u];i;i=next[i]) {
int v=to[i];
if(v!=fa) {
father[v]=u;
deep[v]=deep[u]+;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[ zhongerzi[u] ]) zhongerzi[u]=v;
}
}
} void dfs2(int u,int fa){
id[u]=++total; pre[total]=u;
if(zhongerzi[u]) top[zhongerzi[u]]=top[u],dfs2(zhongerzi[u],u);
for(int i=first[u];i;i=next[i]) {
int v=to[i];
if(v==fa || v==zhongerzi[u]) continue;
top[v]=v;
dfs2(v,u);
}
} int query_sum(int root,int x,int y){
if(jump[root].l>=x && jump[root].r<=y) return jump[root]._sum;
int da=;
int mid=jump[root].l+(jump[root].r-jump[root].l)/;
int lc=root*,rc=root*+;
if(x<=mid) da+=query_sum(lc,x,y);
if(y>mid) da+=query_sum(rc,x,y);
return da;
} int query_max(int root,int x,int y){
if(jump[root].l>=x && jump[root].r<=y) return jump[root]._max;
int da=-inf;
int mid=jump[root].l+(jump[root].r-jump[root].l)/;
int lc=root*,rc=root*+;
if(x<=mid) da=max(da,query_max(lc,x,y));
if(y>mid) da=max(da,query_max(rc,x,y));
return da;
} int find_max(int x,int y){
int f1=top[x],f2=top[y];
int daan=-inf;
while(f1!=f2){
if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
daan=max(daan,query_max(,id[f1],id[x]));
x=father[f1];
f1=top[x];
}
if(deep[x]<deep[y]) swap(x,y);
daan=max(daan,query_max(,id[y],id[x]));
return daan;
} int find_sum(int x,int y){
int f1=top[x],f2=top[y];
int daan=;
while(f1!=f2){
if(deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
daan+=query_sum(,id[f1],id[x]);
x=father[f1]; f1=top[x];
}
if(deep[x]<deep[y]) swap(x,y);
daan+=query_sum(,id[y],id[x]);
return daan;
} void update(int root,int o,int add){
if(jump[root].l==jump[root].r){
jump[root]._sum+=add;
jump[root]._max+=add;return ;
}
int lc=root*,rc=root*+;
int mid=jump[root].l+(jump[root].r-jump[root].l)/;
if(o<=mid) update(lc,o,add); else update(rc,o,add);
jump[root]._sum=jump[lc]._sum+jump[rc]._sum;
jump[root]._max=max(jump[lc]._max,jump[rc]._max);
} int main()
{
n=getint();
int x,y;
for(int i=;i<n;i++){
x=getint();y=getint();
next[++ecnt]=first[x]; first[x]=ecnt; to[ecnt]=y;
next[++ecnt]=first[y]; first[y]=ecnt; to[ecnt]=x;
} deep[]=; dfs1(,);
top[]=; dfs2(,); for(int i=;i<=n;i++) a[i]=getint();
build(,,n);
int Q=getint(); for(int i=;i<=Q;i++){
scanf("%s",ch);
if(ch[]=='M'){
printf("%d\n",find_max(x,y));
}
else if(ch[]=='S'){
x=getint();y=getint();
printf("%d\n",find_sum(x,y));
}
else{
U=getint();VV=getint();
update(,id[U],VV-a[U]);a[U]=VV;
}
}
return ;
}