// bzoj1036: [ZJOI2008] 树的统计 Count (Tree Statistics)
//
// Date: 2021-10-31 00:51:09
//
// Technique: 树链剖分 (heavy-light decomposition)



#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
using namespace std;
// a[pos]: initial weight of the node whose DFS position (in[]) is pos.
// main fills it via a[in[i]] = read(); build() copies it into the leaves.
int a[30010];
/*
 * Reads one signed decimal integer from stdin.
 *
 * Characters before the first digit are skipped. If any of the skipped
 * characters is '-', the result is negated.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit. (The previous version stored getchar() in a char, so EOF
 *         was never recognised and the skip loop spun forever.)
 */
int read () {
	int c = getchar();	// int, not char: EOF must stay distinguishable
	int sign = 1;
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-')
			sign = -1;
		c = getchar();
	}
	int re = 0;
	while (c >= '0' && c <= '9') {
		re = re * 10 + (c - '0');
		c = getchar();
	}
	return re * sign;
}
// n: number of nodes; q: number of operations.
// u, v: operands of the current edge / operation as read in main. NOTE: change()
// reads the global v as the value to store, so v must hold the new weight when
// change() is called. lca: not referenced anywhere in the visible code.
int n,u,v,q,lca;
// s: operation keyword read with scanf; main dispatches on s[0] == 'C' and
// s[1] == 'M' (presumably CHANGE / QMAX / QSUM per the problem statement).
char s[10];
// Heavy-light decomposition data, indexed by node id:
//   depth[x]  depth of x (main sets depth[1] = 1 for the root)
//   sz[x]     size of x's subtree (filled by dfs_1)
//   son[x]    heavy child of x, i.e. child with the largest subtree (0 if none)
//   fa[x]     parent of x (0 for the root)
//   top[x]    first node of the heavy chain containing x (filled by dfs_2)
//   in[x]     position of x in dfs_2's visiting order (segment-tree index)
// ind: last position handed out by dfs_2 (n when the tree is connected).
int depth[30010],sz[30010],in[30010],son[30010],fa[30010],top[30010],ind;

// Segment-tree node over DFS positions: sum and maximum of the weights in the
// node's range. Node m has children 2m (m << 1) and 2m + 1 (m << 1 | 1); the
// root is 1, and 120010 slots is roughly 4x the 30010 position capacity.
struct LINE_TREE {
	int sum,MAX;
}tr[120010];

// Adjacency lists in "forward star" form: head[x] is the index of the most
// recently added arc leaving x, e[i].next links to the arc added before it
// (-1 ends the list; main fills head with -1 before inserting), and e[i].v is
// the arc's destination. Arc i and arc i ^ 1 are the two directions of one edge.
struct edge {
	int v,next;
}e[60010];
int cnt = -1,head[30010];

// Appends the directed arc from -> to at the front of from's list.
static void push_arc (int from,int to) {
	++cnt;
	e[cnt].v = to;
	e[cnt].next = head[from];
	head[from] = cnt;
}

// Inserts the undirected edge {u, v} as two opposite arcs: u -> v, then v -> u.
void adde (int u,int v) {
	push_arc(u, v);
	push_arc(v, u);
}

// First HLD pass over the subtree rooted at U (U's own depth and parent must
// already be set). For every child it records fa[] and depth[], recurses, then:
//   - keeps son[U] pointing at the child with the strictly largest subtree
//     seen so far (ties keep the earlier child);
//   - accumulates the child's size into sz[U].
// Finally counts U itself in sz[U].
void dfs_1 (int U) {
	for (int edge_id = head[U]; edge_id != -1; edge_id = e[edge_id].next) {
		const int child = e[edge_id].v;
		if (child != fa[U]) {
			fa[child] = U;
			depth[child] = depth[U] + 1;
			dfs_1(child);
			if (sz[child] > sz[son[U]])
				son[U] = child;
			sz[U] += sz[child];
		}
	}
	++sz[U];
}

// Second HLD pass: assigns DFS positions (in[]) and chain heads (top[]).
// The heavy child is visited first, so each heavy chain occupies a contiguous
// run of positions starting at its head. Every light child starts a new chain
// headed by itself.
//   U   node being visited
//   tp  head of the chain U belongs to
void dfs_2 (int U,int tp) {
	in[U] = ++ind;
	top[U] = tp;
	const int heavy = son[U];
	if (heavy != 0)
		dfs_2(heavy, tp);
	for (int edge_id = head[U]; edge_id != -1; edge_id = e[edge_id].next) {
		const int child = e[edge_id].v;
		if (child != fa[U] && child != heavy)
			dfs_2(child, child);
	}
}

void push_up (int m) {
	if(tr[m << 1].MAX > tr[m << 1 | 1].MAX)
		tr[m].MAX = tr[m << 1].MAX;
	else tr[m].MAX = tr[m << 1 | 1].MAX;
	tr[m].sum = tr[m << 1].sum + tr[m << 1 | 1].sum;
}

// Builds the subtree of segment-tree node m covering positions [l, r]:
// leaves take their weight from a[], and inner nodes are combined by push_up.
void build (int l,int r,int m) {
	if (l != r) {
		const int mid = (l + r) / 2;
		build(l, mid, m << 1);
		build(mid + 1, r, m << 1 | 1);
		push_up(m);
	} else {
		tr[m].sum = tr[m].MAX = a[l];
	}
}

// Point assignment: stores `val` at DFS position `goal` and recomputes sum/MAX
// on every ancestor.
//   goal  position to update (in[] of the node)
//   l, r  range covered by segment-tree node m (start with 1, ind, 1)
//   val   new weight. It defaults to the global v, so the existing call
//         change(in[u], 1, ind, 1) keeps its old meaning.
// A goal outside [l, r] is ignored. Previously such a goal (e.g. a node id that
// dfs_2 never visited, whose in[] is still 0) made the recursion run forever
// and index past tr[].
void change (int goal,int l,int r,int m,int val = v) {
	if (goal < l || goal > r)
		return;
	if (l == r) {
		tr[m].sum = tr[m].MAX = val;
		return;
	}
	int mid = (l + r) / 2;
	if (goal <= mid)
		change(goal,l,mid,m << 1,val);
	else change(goal,mid + 1,r,m << 1 | 1,val);
	push_up(m);
}

// Maximum weight over DFS positions [L, R].
// Node m covers [l, r]; callers must keep l <= L <= R <= r.
int Max (int L,int R,int l,int r,int m) {
	if (L == l && R == r)
		return tr[m].MAX;
	const int mid = (l + r) / 2;
	const int lc = m << 1;
	const int rc = lc | 1;
	if (L > mid)	// query lies entirely in the right half
		return Max(L, R, mid + 1, r, rc);
	if (R <= mid)	// query lies entirely in the left half
		return Max(L, R, l, mid, lc);
	// query straddles mid: split it at the midpoint
	const int left_best = Max(L, mid, l, mid, lc);
	const int right_best = Max(mid + 1, R, mid + 1, r, rc);
	return left_best < right_best ? right_best : left_best;
}

// Sum of weights over DFS positions [L, R].
// Node m covers [l, r]; callers must keep l <= L <= R <= r.
int Sum (int L,int R,int l,int r,int m) {
	if (L == l && R == r)
		return tr[m].sum;
	const int mid = (l + r) / 2;
	const int lc = m << 1;
	const int rc = lc | 1;
	if (L > mid)	// query lies entirely in the right half
		return Sum(L, R, mid + 1, r, rc);
	if (R <= mid)	// query lies entirely in the left half
		return Sum(L, R, l, mid, lc);
	// query straddles mid: add the two halves
	const int left_part = Sum(L, mid, l, mid, lc);
	return left_part + Sum(mid + 1, R, mid + 1, r, rc);
}

int main () {
	freopen("FAQ.in","r",stdin);
	freopen("FAQ.in","w",stdout);
	memset(head,-1,sizeof head);
	n = read();
	for(int i = 1;i <= n - 1;++i) {
		u = read();
		v = read();
		adde(u,v);
	}
	depth[1] = 1;
	dfs_1(1);
	dfs_2(1,1);
	for(int i = 1;i <= n;i++)
		a[in[i]] = read();
	build(1,ind,1);
	q = read();
	while(q--) {
		scanf("%s",s);
		u = read();
		v = read();
		if(s[0] == 'C')
			change(in[u],1,ind,1);
		else {
			if(s[1] == 'M') {
				int re = -0x3f3f3f3f;
				
				int tu = top[u];
				int tv = top[v];
				while(tu != tv) {
					if(depth[tu] < depth[tv]) {
						swap(u,v);
						swap(tu,tv);
					}
					re = max(re,Max(in[top[u]],in[u],1,ind,1));
					u = fa[tu];
					tu = top[u];
				}
				if(depth[u] > depth[v])
					swap(u,v);
				re = max(re,Max(in[u],in[v],1,ind,1));
				printf("%d\n",re);
				
			}
			else {
				int re = 0;
				int tu = top[u];
				int tv = top[v];
				while(tu != tv) {
					if(depth[tu] < depth[tv]) {
						swap(u,v);
						swap(tu,tv);
					}
					re += Sum(in[top[u]],in[u],1,ind,1);
					u = fa[tu];
					tu = top[u];
				}
				if(depth[u] > depth[v])
					swap(u,v);
				re += Sum(in[u],in[v],1,ind,1);
				printf("%d\n",re);
			}
		}
	}
	return 0;
}