[BZOJ1036][ZJOI2008]树的统计Count(树链剖分)

时间:2021-01-14 09:54:53

树链剖分模板题：单点修改点权，查询树上路径的最大点权与点权和。

Code

 

#include <cstdio>
#include <algorithm>
#define MID int mid=(l+r)>>1,ls=id<<1,rs=id<<1|1
#define N 30010
using namespace std;

// Segment-tree node over DFS positions: sum and maximum of the covered weights.
// Defaults (sum 0, max -1000000000) are neutral values for the sum and max combines.
struct node{
	int sum=0;
	int mx=-1000000000;
};
node T[N*4];
// Adjacency-list edge: target vertex and index of the next edge from the same source (0 ends the list).
struct info{int to,nex;}e[N*2];
int n,A[N],tot,head[N];          // vertex count, vertex weights, edges used, first edge index per vertex
int dep[N],fa[N],sz[N],son[N];   // depth, parent, subtree size, heavy child (filled by dfs)
int cnt,tp[N],tw[N],tid[N];      // DFS counter, chain top, weight by DFS position, DFS position of vertex

/**
 * Reads the next decimal integer from stdin.
 *
 * Non-digit characters before the number are skipped; a '-' seen while
 * skipping makes the result negative. The character right after the last
 * digit is consumed. If end of input is reached before any digit, returns 0
 * instead of spinning forever on EOF.
 */
inline int read(){
    int x=0,f=1;
    int ch=getchar();  // int, not char: keeps EOF distinct from every data byte
    while(ch<'0'||ch>'9'){
        if(ch==EOF) return 0;  // no digits left; the original looped forever here
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}

// Adds the directed edge u -> v by prepending it to u's singly linked edge list
// (head[u] holds the newest edge index, e[].nex chains to older ones).
inline void Link(int u,int v){
	++tot;
	e[tot].to=v;
	e[tot].nex=head[u];
	head[u]=tot;
}

// Point assignment on the segment tree: node id covers positions [l,r];
// sets position p to x, then recomputes sum and max on the way back up.
void update(int l,int r,int id,int p,int x){
	if(l==r){
		T[id].sum=x;
		T[id].mx=x;
		return;
	}
	const int mid=(l+r)>>1;
	const int lc=id<<1, rc=id<<1|1;
	if(p>mid) update(mid+1,r,rc,p,x);
	else update(l,mid,lc,p,x);
	T[id].mx=std::max(T[lc].mx,T[rc].mx);
	T[id].sum=T[lc].sum+T[rc].sum;
}

// Sum of values at positions in [L,R] within node id's range [l,r];
// a part of [l,r] outside [L,R] contributes 0.
int querySum(int l,int r,int id,int L,int R){
	if(L<=l&&r<=R) return T[id].sum;
	const int mid=(l+r)>>1;
	const int fromLeft = L<=mid ? querySum(l,mid,id<<1,L,R) : 0;
	const int fromRight = R>mid ? querySum(mid+1,r,id<<1|1,L,R) : 0;
	return fromLeft+fromRight;
}

// Maximum value at positions in [L,R] within node id's range [l,r];
// the result never drops below the -1000000000 floor.
int queryMx(int l,int r,int id,int L,int R){
	if(L<=l&&r<=R) return T[id].mx;
	const int kFloor=-1000000000;
	const int mid=(l+r)>>1;
	const int fromLeft = L<=mid ? queryMx(l,mid,id<<1,L,R) : kFloor;
	const int fromRight = R>mid ? queryMx(mid+1,r,id<<1|1,L,R) : kFloor;
	return std::max(std::max(kFloor,fromLeft),fromRight);
}

// First heavy-light pass over the subtree of u (pre = u's parent, skipped).
// Fills, for every child v, fa[v] and dep[v]; computes sz[u] (subtree size);
// and sets son[u] to the child with the strictly largest subtree, taking the
// first such child in adjacency order on ties. son[u] is untouched for a leaf.
void dfs(int u,int pre){
	sz[u]=1;
	int heavySize=0;
	for(int i=head[u];i!=0;i=e[i].nex){
		const int v=e[i].to;
		if(v!=pre){
			fa[v]=u;
			dep[v]=dep[u]+1;
			dfs(v,u);
			sz[u]+=sz[v];
			if(heavySize<sz[v]){
				heavySize=sz[v];
				son[u]=v;
			}
		}
	}
}

// Second heavy-light pass: assigns u its chain top tp[u] and DFS position tid[u],
// and records u's weight at that position in tw[]. The heavy child is visited
// first so each heavy chain occupies a contiguous range of positions.
void dddfs(int u,int top){
	tp[u]=top;
	tid[u]=++cnt;
	tw[tid[u]]=A[u];
	if(son[u]!=0){
		dddfs(son[u],top);  // heavy child stays on the current chain
		for(int i=head[u];i!=0;i=e[i].nex){
			const int v=e[i].to;
			if(v==fa[u]||v==son[u]) continue;
			dddfs(v,v);  // every light child heads a new chain
		}
	}
}

// Sum of vertex weights on the tree path u..v (both ends included).
// While u and v are on different chains, the endpoint whose chain top is
// deeper contributes its chain segment and jumps above that chain's top.
inline int qRangeSum(int u,int v){
	int total=0;
	for(;tp[u]!=tp[v];u=fa[tp[u]]){
		if(dep[tp[v]]>dep[tp[u]]) swap(u,v);
		total+=querySum(1,n,1,tid[tp[u]],tid[u]);
	}
	// Same chain now: query from the shallower endpoint to the deeper one.
	if(dep[u]<=dep[v]) total+=querySum(1,n,1,tid[u],tid[v]);
	else total+=querySum(1,n,1,tid[v],tid[u]);
	return total;
}

// Maximum vertex weight on the tree path u..v (both ends included),
// never below the -1000000000 floor. Same chain-climbing scheme as the path sum.
inline int qRangeMx(int u,int v){
	int best=-1000000000;
	for(;tp[u]!=tp[v];u=fa[tp[u]]){
		if(dep[tp[v]]>dep[tp[u]]) swap(u,v);  // lift the endpoint whose chain top is deeper
		best=max(best,queryMx(1,n,1,tid[tp[u]],tid[u]));
	}
	// Same chain now: query from the shallower endpoint to the deeper one.
	const int last = dep[u]<=dep[v] ? queryMx(1,n,1,tid[u],tid[v])
	                                : queryMx(1,n,1,tid[v],tid[u]);
	return max(best,last);
}

// Builds segment-tree node id for positions [l,r] from tw[] (weights in DFS order).
void build(int l,int r,int id){
	if(l!=r){
		const int mid=(l+r)>>1;
		const int lc=id<<1, rc=id<<1|1;
		build(l,mid,lc);
		build(mid+1,r,rc);
		T[id].sum=T[lc].sum+T[rc].sum;
		T[id].mx=std::max(T[lc].mx,T[rc].mx);
		return;
	}
	T[id].sum=tw[l];
	T[id].mx=tw[l];
}

// Reads n, then n-1 undirected edges, then the n vertex weights; runs both
// heavy-light passes from root 1 and builds the segment tree over positions 1..n.
inline void Init(){
	n=read();
	for(int k=1;k<=n-1;++k){
		const int a=read();
		const int b=read();
		Link(a,b);
		Link(b,a);
	}
	for(int i=1;i<=n;++i) A[i]=read();
	dfs(1,0);
	dddfs(1,1);
	build(1,n,1);
}

// Reads the query count m, then processes up to m queries "<cmd> x y" from stdin.
// Dispatch is on the command's second character only (per BZOJ1036's CHANGE/QMAX/QSUM):
//   'H' -> set the weight of vertex x to y
//   'M' -> print the maximum weight on the path x..y
//   otherwise -> print the sum of weights on the path x..y
// Stops early if the input ends or a query line cannot be parsed.
inline void solve(){
	int m=read();
	char s[10];
	int x,y;
	while(m--){
		// %9s bounds the write into s; no trailing "\n" so scanf never blocks
		// waiting for input after the last query.
		if(scanf("%9s%d%d",s,&x,&y)!=3) break;
		if(s[1]=='H') update(1,n,1,tid[x],y);
		else if(s[1]=='M') printf("%d\n",qRangeMx(x,y));
		else printf("%d\n",qRangeSum(x,y));
	}
}

// Entry point: read the tree and weights, then answer all queries.
int main(){
	Init();
	solve();
	return 0;
}