题目大意:
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。有一些操作:1.把结点u的权值改为t;2.询问从点u到点v的路径上的节点的最大权值 3.询问从点u到点v的路径上的节点的权值和。
思路:
进行轻重树链剖分,再根据每个节点的dfs序建立线段树,维护其最大值以及和,询问时用树剖后的结果将重链作为区间一段一段求和。
代码:
1 #include<cstdio> 2 #include<cstring> 3 #include<iostream> 4 #define M 1000009 5 using namespace std; 6 7 int n,dfn,cnt,to[M],next[M],head[M],size[M],vis[M],deep[M],fa[M],top[M],w[M],mx[M],sum[M],id[M]; 8 9 void add(int x,int y) 10 { 11 to[++cnt]=y,next[cnt]=head[x],head[x]=cnt; 12 } 13 14 void dfs1(int x) 15 { 16 size[x]=vis[x]=1; 17 for (int i=head[x];i;i=next[i]) 18 if (!vis[to[i]]) 19 { 20 deep[to[i]]=deep[x]+1; 21 fa[to[i]]=x; 22 dfs1(to[i]); 23 size[x]+=size[to[i]]; 24 } 25 } 26 27 void dfs2(int x,int chain) 28 { 29 int k=0,i; 30 id[x]=++dfn; 31 top[x]=chain; 32 for (i=head[x];i;i=next[i]) 33 if (deep[to[i]]>deep[x] && size[to[i]]>size[k]) k=to[i]; 34 if (!k) return; 35 dfs2(k,chain); 36 for (i=head[x];i;i=next[i]) 37 if (deep[to[i]]>deep[x] && to[i]!=k) dfs2(to[i],to[i]); 38 } 39 40 int LCA(int x,int y) 41 { 42 for (;top[x]!=top[y];x=fa[top[x]]) 43 if (deep[top[x]]<deep[top[y]]) swap(x,y); 44 return deep[x]<deep[y]?x:y; 45 } 46 47 void change(int l,int r,int x,int y,int cur) 48 { 49 if (l==r) 50 { 51 mx[cur]=sum[cur]=y; 52 return; 53 } 54 int mid=l+r>>1; 55 if (x<=mid) change(l,mid,x,y,cur<<1); 56 else change(mid+1,r,x,y,cur<<1|1); 57 mx[cur]=max(mx[cur<<1],mx[cur<<1|1]); 58 sum[cur]=sum[cur<<1]+sum[cur<<1|1]; 59 } 60 61 int SUM(int L,int R,int l,int r,int cur) 62 { 63 if (l<=L && r>=R) return sum[cur]; 64 int mid=L+R>>1; 65 if (l>mid) return SUM(mid+1,R,l,r,cur<<1|1); 66 else if (r<=mid) return SUM(L,mid,l,r,cur<<1); 67 else return SUM(L,mid,l,r,cur<<1)+SUM(mid+1,R,mid+1,r,cur<<1|1); 68 } 69 70 int MAX(int L,int R,int l,int r,int cur) 71 { 72 if (l<=L && r>=R) return mx[cur]; 73 int mid=L+R>>1; 74 if (l>mid) return MAX(mid+1,R,l,r,cur<<1|1); 75 else if (r<=mid) return MAX(L,mid,l,r,cur<<1); 76 else return max(MAX(L,mid,l,mid,cur<<1),MAX(mid+1,R,mid+1,r,cur<<1|1)); 77 } 78 79 int Sum(int x,int y) 80 { 81 int ans=0; 82 for (;top[x]!=top[y];x=fa[top[x]]) 83 { 84 if (deep[top[x]]<deep[top[y]]) swap(x,y); 85 ans+=SUM(1,n,id[top[x]],id[x],1); 86 } 87 if (deep[x]>deep[y]) swap(x,y); 88 return ans+SUM(1,n,id[x],id[y],1); 89 } 90 91 int Max(int x,int y) 92 { 93 int ans=-999999999; 94 for (;top[x]!=top[y];x=fa[top[x]]) 95 { 96 if (deep[top[x]]<deep[top[y]]) swap(x,y); 97 ans=max(ans,MAX(1,n,id[top[x]],id[x],1)); 98 } 99 if (deep[x]>deep[y]) swap(x,y); 100 return max(ans,MAX(1,n,id[x],id[y],1)); 101 } 102 103 int main() 104 { 105 int m,i,x,y; 106 scanf("%d",&n); 107 for (i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x); 108 dfs1(1); 109 dfs2(1,1); 110 for (i=1;i<=n;i++) scanf("%d",&w[i]),change(1,n,id[i],w[i],1); 111 scanf("%d",&m); 112 for (i=1;i<=m;i++) 113 { 114 char ch[9]; 115 scanf("%s%d%d",ch,&x,&y); 116 if (ch[0]=='C') w[x]=y,change(1,n,id[x],y,1); 117 else 118 if (ch[1]=='S') printf("%d\n",Sum(x,y)); 119 else printf("%d\n",Max(x,y)); 120 } 121 return 0; 122 }