BZOJ1036(ZJOI2008)[树的统计Count]--树链剖分+线段树

时间:2022-04-01 09:52:52

【链接】
bzoj1036

【题目大意】
给你一个有n个节点的树,节点的编号为1~n,每个节点有一个权值w。我们对这棵树进行三个操作:1. CHANGE u t : 把结点u的权值改为t。 2. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 。3. QSUM u v: 询问从点u到点v的路径上的节点的权值和 Ps:从点u到点v的路径上的节点包括u和v本身

【解题报告】
此题其实就是树链剖分+线段树的模板题,详见树链剖分总结

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=30005,maxm=60005,maxv=120005,INF=((1<<30)-1)*2+1;
int n,Q,tot,num[maxn],lnk[maxn],son[maxm],nxt[maxm],fa[maxn],sonH[maxn],s[maxn],dep[maxn],top[maxn],id[maxn],who[maxn];
struct LT
{
int l[maxv],r[maxv],sum[maxv],MAX[maxv];
void Pushup(int d){sum[d]=sum[d<<1]+sum[(d<<1)+1]; MAX[d]=max(MAX[d<<1],MAX[(d<<1)+1]);}
void Build(int d,int L,int R)
{
l[d]=L; r[d]=R;
if (L==R) {sum[d]=MAX[d]=num[who[L]]; return;}
int mid=(R-L>>1)+L;
Build(d<<1,L,mid); Build((d<<1)+1,mid+1,R);
Pushup(d);
}
void Insert(int d,int where,int p)
{
if (l[d]==r[d]) {sum[d]=MAX[d]=p; return;}
int mid=(r[d]-l[d]>>1)+l[d];
if (where<=mid) Insert(d<<1,where,p);
else Insert((d<<1)+1,where,p);
Pushup(d);
}
int Query_sum(int d,int L,int R)
{
if (l[d]==L&&r[d]==R) return sum[d];
int mid=(r[d]-l[d]>>1)+l[d];
if (R<=mid) return Query_sum(d<<1,L,R);
else if (L>mid) return Query_sum((d<<1)+1,L,R);
else return Query_sum(d<<1,L,mid)+Query_sum((d<<1)+1,mid+1,R);
}
int Query_MAX(int d,int L,int R)
{
if (l[d]==L&&r[d]==R) return MAX[d];
int mid=(r[d]-l[d]>>1)+l[d];
if (R<=mid) return Query_MAX(d<<1,L,R);
else if (L>mid) return Query_MAX((d<<1)+1,L,R);
else return max(Query_MAX(d<<1,L,mid),Query_MAX((d<<1)+1,mid+1,R));
}
}tr;//线段树
inline int Read()
{
int res=0,f=1;
char ch=getchar(),cc=ch;
while (ch<'0'||ch>'9') cc=ch,ch=getchar();
if (cc=='-') f=-1;
while (ch>='0'&&ch<='9') res=res*10+ch-48,ch=getchar();
return res*f;
}
void Add(int x,int y)
{
son[++tot]=y; nxt[tot]=lnk[x]; lnk[x]=tot;
}
void Dfs(int x)
{
s[x]=1; dep[x]=dep[fa[x]]+1;
for (int j=lnk[x]; j; j=nxt[j])
if (son[j]!=fa[x])
{
fa[son[j]]=x;
Dfs(son[j]);
s[x]+=s[son[j]];
if (s[son[j]]>s[sonH[x]]) sonH[x]=son[j];
}
}
void HLD(int x,int lst)
{
top[x]=lst; who[++tot]=x; id[x]=tot;
if (sonH[x]) HLD(sonH[x],lst);
for (int j=lnk[x]; j; j=nxt[j])
if (son[j]!=fa[x]&&son[j]!=sonH[x]) HLD(son[j],son[j]);
}
int Ask_sum(int x,int y)//询问总和
{
int sum=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
sum+=tr.Query_sum(1,id[top[x]],id[x]);
x=fa[top[x]];
}
if (id[x]>id[y]) swap(x,y);
sum+=tr.Query_sum(1,id[x],id[y]);
return sum;
}
int Ask_MAX(int x,int y)//询问最大值
{
int MAX=-INF;
while(top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
MAX=max(MAX,tr.Query_MAX(1,id[top[x]],id[x]));
x=fa[top[x]];
}
if (id[x]>id[y]) swap(x,y);
MAX=max(MAX,tr.Query_MAX(1,id[x],id[y]));
return MAX;
}
int main()
{
freopen("1036.in","r",stdin);
freopen("1036.out","w",stdout);
n=Read(); tot=0;
memset(lnk,0,sizeof(lnk));
for (int i=1,x,y; i<n; i++) x=Read(),y=Read(),Add(x,y),Add(y,x);
for (int i=1; i<=n; i++) num[i]=Read();
memset(dep,0,sizeof(dep));
tot=0; Dfs(1); HLD(1,1); tr.Build(1,1,n);
Q=Read();
for (int i=1; i<=Q; i++)
{
char ch=getchar();
while (ch!='Q'&&ch!='C') ch=getchar();
char c2=getchar();
int x=Read(),y=Read();
if (ch=='C') tr.Insert(1,id[x],y);
else if (c2=='S') printf("%d\n",Ask_sum(x,y));
else printf("%d\n",Ask_MAX(x,y));
}
return 0;
}