bzoj1036: [ZJOI2008]树的统计Count
Time Limit: 10 Sec
Memory Limit: 162 MBDescription
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
题目地址: bzoj1036: [ZJOI2008]树的统计Count
题目大意:
题目已经很清楚了
题解:
树链剖分裸题
先把树轻重链剖分(两遍dfs)
然后线段树维护区间最大值和区间和
将剖出来的序列合并一下就好了
AC代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#define inf 0x7fffffff
#define N 30005
using namespace std;
int n,Q,cnt,sz;
int w[N],dep[N],size[N],head[N],fa[N];
int pos[N],top[N];
char ch[10];
struct edge{
int to,next;
}e[N+N];
struct seg{
int l,r,mx,sum;
}t[N<<2];
void add_edge(int u,int v){
e[++cnt]=(edge){v,head[u]};head[u]=cnt;
e[++cnt]=(edge){u,head[v]};head[v]=cnt;
}
void dfs1(int u){
size[u]=1;
for(int i=head[u];i;i=e[i].next){
if(e[i].to==fa[u])continue;
dep[e[i].to]=dep[u]+1;
fa[e[i].to]=u;
dfs1(e[i].to);
size[u]+=size[e[i].to];
}
}
void dfs2(int u,int chain){
int k=0;sz++;
pos[u]=sz;
top[u]=chain;
for(int i=head[u];i;i=e[i].next)
if(dep[e[i].to]>dep[u]&&size[e[i].to]>size[k])
k=e[i].to;
if(k==0)return;
dfs2(k,chain);
for(int i=head[u];i;i=e[i].next)
if(dep[e[i].to]>dep[u]&&k!=e[i].to)
dfs2(e[i].to,e[i].to);
}
void build(int l,int r,int id){
t[id].l=l;t[id].r=r;
if(l==r)return;
int mid=(l+r)>>1;
build(l,mid,id<<1);
build(mid+1,r,id<<1|1);
}
void change(int id,int k,int w){
int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
if(l==r){
t[id].sum=t[id].mx=w;
return;
}
if(k<=mid)change(id<<1,k,w);
else change(id<<1|1,k,w);
t[id].sum=t[id<<1].sum+t[id<<1|1].sum;
t[id].mx=max(t[id<<1].mx,t[id<<1|1].mx);
}
int querysum(int id,int L,int R){
int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
if(l==L&&R==r)return t[id].sum;
if(R<=mid)return querysum(id<<1,L,R);
else if(L>mid)return querysum(id<<1|1,L,R);
else return querysum(id<<1,L,mid)+querysum(id<<1|1,mid+1,R);
}
int querymx(int id,int L,int R){
int l=t[id].l,r=t[id].r,mid=(l+r)>>1;
if(l==L&&R==r)return t[id].mx;
if(R<=mid)return querymx(id<<1,L,R);
else if(L>mid)return querymx(id<<1|1,L,R);
else return max(querymx(id<<1,L,mid),querymx(id<<1|1,mid+1,R));
}
int solvesum(int a,int b){
int sum=0;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]])swap(a,b);
sum+=querysum(1,pos[top[a]],pos[a]);
a=fa[top[a]];
}
if(pos[a]>pos[b])swap(a,b);
sum+=querysum(1,pos[a],pos[b]);
return sum;
}
int solvemx(int a,int b){
int mx=-inf;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]])swap(a,b);
mx=max(mx,querymx(1,pos[top[a]],pos[a]));
a=fa[top[a]];
}
if(pos[a]>pos[b])swap(a,b);
mx=max(mx,querymx(1,pos[a],pos[b]));
return mx;
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v);
}
for(int i=1;i<=n;i++)scanf("%d",&w[i]);
dfs1(1);
dfs2(1,1);
build(1,n,1);
for(int i=1;i<=n;i++)
change(1,pos[i],w[i]);
scanf("%d",&Q);
while(Q--){
int x,y;scanf("%s%d%d",ch+1,&x,&y);
if(ch[1]=='C'){
w[x]=y;
change(1,pos[x],y);
}else
if(ch[2]=='M')
printf("%d\n",solvemx(x,y));
else
printf("%d\n",solvesum(x,y));
}
return 0;
}