Bzoj 1036: [ZJOI2008]树的统计Count 树链剖分,LCT

时间:2022-02-09 10:31:51

1036: [ZJOI2008]树的统计Count

Time Limit: 10 Sec  Memory Limit: 162 MB
Submit: 11102  Solved: 4490
[Submit][Status][Discuss]

Description

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

HINT

Source

 题解:
树链剖分或LCT模版题。
我写了个树剖:
 #include<bits/stdc++.h>
using namespace std;
#define lson k*2,l,mid
#define rson k*2+1,mid+1,r
#define MAXN 30010
#define INF 1e9
struct node
{
int begin,end,next;
}edge[MAXN*];
struct NODE
{
int left,right,sum,mx,val;
}tree[*MAXN];
int cnt,Head[MAXN],val[MAXN],deep[MAXN],size[MAXN],P[MAXN][],belong[MAXN],pos[MAXN],SIZE,n;
bool vis[MAXN];
void addedge(int bb,int ee)
{
edge[++cnt].begin=bb;edge[cnt].end=ee;edge[cnt].next=Head[bb];Head[bb]=cnt;
}
void addedge1(int bb,int ee)
{
addedge(bb,ee);addedge(ee,bb);
}
int read()
{
int s=,fh=;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')fh=-;ch=getchar();}
while(ch>=''&&ch<=''){s=s*+(ch-'');ch=getchar();}
return s*fh;
}
void dfs1(int u)
{
int i,v;
size[u]=;vis[u]=true;
for(i=Head[u];i!=-;i=edge[i].next)
{
v=edge[i].end;
if(vis[v]==false)
{
deep[v]=deep[u]+;
P[v][]=u;
dfs1(v);
size[u]+=size[v];
}
}
}
void Ycl()
{
int i,j;
for(j=;(<<j)<=n;j++)
{
for(i=;i<=n;i++)
{
if(P[i][j-]!=-)P[i][j]=P[P[i][j-]][j-];
}
}
}
void dfs2(int u,int chain)
{
int k=,i,v;
pos[u]=++SIZE;belong[u]=chain;
for(i=Head[u];i!=-;i=edge[i].next)
{
v=edge[i].end;
if(deep[v]>deep[u]&&size[v]>size[k])k=v;
}
if(k==)return;
dfs2(k,chain);
for(i=Head[u];i!=-;i=edge[i].next)
{
v=edge[i].end;
if(deep[v]>deep[u]&&v!=k)dfs2(v,v);
}
}
int LCA(int x,int y)
{
int i,j;
if(deep[x]<deep[y])swap(x,y);
for(i=;(<<i)<=deep[x];i++);i--;
for(j=i;j>=;j--)if(deep[x]-(<<j)>=deep[y])x=P[x][j];
if(x==y)return x;
for(j=i;j>=;j--)
{
if(P[x][j]!=-&&P[x][j]!=P[y][j])
{
x=P[x][j];
y=P[y][j];
}
}
return P[x][];
}
void Pushup(int k)
{
tree[k].sum=tree[k*].sum+tree[k*+].sum;
tree[k].mx=max(tree[k*].mx,tree[k*+].mx);
}
void Build(int k,int l,int r)
{
tree[k].left=l;tree[k].right=r;
if(l==r)return;
int mid=(l+r)/;
Build(lson);
Build(rson);
}
void Change(int k,int P,int V)
{
if(tree[k].left==tree[k].right)
{
tree[k].val=tree[k].sum=tree[k].mx=V;
return;
}
int mid=(tree[k].left+tree[k].right)/;
if(P<=mid)Change(k*,P,V);
else Change(k*+,P,V);
Pushup(k);
}
int Query_max(int k,int l,int r)
{
if(l<=tree[k].left&&tree[k].right<=r)return tree[k].mx;
int mid=(tree[k].left+tree[k].right)/;
if(r<=mid)return Query_max(k*,l,r);
else if(l>mid)return Query_max(k*+,l,r);
else return max(Query_max(k*,l,mid),Query_max(k*+,mid+,r));
}
int Query_sum(int k,int l,int r)
{
if(l<=tree[k].left&&tree[k].right<=r)return tree[k].sum;
int mid=(tree[k].left+tree[k].right)/;
if(r<=mid)return Query_sum(k*,l,r);
else if(l>mid)return Query_sum(k*+,l,r);
else return Query_sum(k*,l,mid)+Query_sum(k*+,mid+,r);
}
int solve_max(int x,int f)
{
int MAX=-INF;
while(belong[x]!=belong[f])
{
MAX=max(MAX,Query_max(,pos[belong[x]],pos[x]));
x=P[belong[x]][];
}
MAX=max(MAX,Query_max(,pos[f],pos[x]));
return MAX;
}
int solve_sum(int x,int f)
{
int SUM=;
while(belong[x]!=belong[f])
{
SUM+=Query_sum(,pos[belong[x]],pos[x]);
x=P[belong[x]][];
}
SUM+=Query_sum(,pos[f],pos[x]);
return SUM;
}
int main()
{
freopen("bzoj_1036.in","r",stdin);
freopen("bzoj_1036.out","w",stdout);
int m,i,bb,ee,u,v,t,lca;
char zs[];
n=read();
memset(Head,-,sizeof(Head));cnt=;
for(i=;i<n;i++)
{
bb=read();ee=read();addedge1(bb,ee);
}
memset(P,-,sizeof(P));SIZE=;
dfs1();Ycl();
dfs2(,);
for(i=;i<=n;i++)val[i]=read();
Build(,,n);
for(i=;i<=n;i++)Change(,pos[i],val[i]);
m=read();
for(i=;i<=m;i++)
{
scanf("\n%s",zs);
if(zs[]=='C'){u=read();t=read();val[u]=t;Change(,pos[u],val[u]);}
else
{
if(zs[]=='M')
{
u=read();v=read();
lca=LCA(u,v);
printf("%d\n",max(solve_max(u,lca),solve_max(v,lca)));
}
else {u=read();v=read();lca=LCA(u,v);printf("%d\n",solve_sum(u,lca)+solve_sum(v,lca)-val[lca]);}
}
}
fclose(stdin);
fclose(stdout);
return ;
}