【bzoj1036】[ZJOI2008]树的统计Count

时间:2022-03-08 03:50:40

树链剖分模板题,嗯看一个不知名的神犇里的博客里的一句话感觉说的很对,树链剖分就是将树hash到数组中然后用线段树或者平衡树来维护的数据结构,剖分出轻重链之后维护就好了,也可以结合欧拉序,DFS序来用。
一开始想找着zyf2000的写,发现她的代码又臭又长(不是故意D神犇的,希望神犇看不见),所以结合着hzwer的,自己改成了自己的码风写了一下.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>

using namespace std;
const int N=30010,M=60010,inf=100000000;
int n,q,cnt,sz,te,qx;
int val[N],h[N],size[N],head[N],fa[N];
int num[N],tp[N],tree[N];
struct edge{
    int v,next;
}e[M];
struct seg{
    int l,r,sum,mx;
}tr[100010];
inline int F()
{
    register int aa,bb;register char ch;
    while(ch=getchar(),(ch<'0'||ch>'9')&&ch!='-');ch=='-'?aa=bb=0:(aa=ch-'0',bb=1);
    while(ch=getchar(),ch>='0'&&ch<='9')aa=(aa<<3)+(aa<<1)+ch-'0';return bb?aa:-aa;
}
void add(int u,int v)
{
    e[++te].v=v;
    e[te].next=head[u];
    head[u]=te;
}
void dfs1(int x)
{
    size[x]=1;
    for (int i=head[x];i;i=e[i].next)
    {
        int v=e[i].v;
        if (e[i].v==fa[x])continue;
        h[v]=h[x]+1;
        fa[v]=x;
        dfs1(v);
        size[x]+=size[v];
    }
}
void dfs2(int x,int chain)
{
    int k=0;
    num[x]=++sz;
    tp[x]=chain;
    for (int i=head[x];i;i=e[i].next)
    {
        int v=e[i].v;
        if (h[v]>h[x]&&size[v]>size[k])
        k=v;
    }
    if (k==0)return;
    dfs2(k,chain);
    for (int i=head[x];i;i=e[i].next)
    {
        int v=e[i].v;
        if (h[v]>h[x]&&k!=v)
        dfs2(v,v);
    }
}
void build(int k,int l,int r)//建树 
{
    tr[k].l=l;tr[k].r=r;
    if (l==r){tr[k].mx=val[tree[l]],tr[k].sum=val[tree[l]];return;}
    int mid=(l+r)>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
    tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
}
void change(int k,int x,int y)//线段树单点修改 
{
    int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
    if (l==r){tr[k].sum=tr[k].mx=y;return;}
    if (x<=mid)change(k<<1,x,y);
    else change(k<<1|1,x,y);
    tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
    tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
}
void querysum(int k,int x,int y)
{
    int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
    if (x<=l&&y>=r){qx+=tr[k].sum;return;}
    if (x<=mid)querysum(k<<1,x,y);
    if (y>mid)querysum(k<<1|1,x,y);
}
void querymx(int k,int x,int y)
{
    int l=tr[k].l,r=tr[k].r,mid=(l+r)>>1;
    if (x<=l&&y>=r){qx=max(tr[k].mx,qx);return;}
    if (x<=mid)querymx(k<<1,x,y);
    if (y>mid)querymx(k<<1|1,x,y);
}
void solvesum(int x,int y)
{
    qx=0;
    while (tp[x]!=tp[y])
    {
        if (h[tp[x]]<h[tp[y]])swap(x,y);
        querysum(1,num[tp[x]],num[x]);
        x=fa[tp[x]];
    }
    if (num[x]>num[y])swap(x,y);
    querysum(1,num[x],num[y]);
}
void solvemx(int x,int y)
{   
    qx=-inf;
    while (tp[x]!=tp[y])
    {
        if (h[tp[x]]<h[tp[y]])swap(x,y);
        querymx(1,num[tp[x]],num[x]);
        x=fa[tp[x]];
    }
    if (num[x]>num[y])swap(x,y);
    querymx(1,num[x],num[y]);
}
int main()
{
    memset(head,0,sizeof(head));
    n=F();size[0]=0;
    for (int i=1;i<n;++i)
    {
        int u,v;
        u=F(),v=F();
        add(u,v),add(v,u);
    }
    for (int i=1;i<=n;++i)val[i]=F();
    dfs1(1);
    dfs2(1,1);
    for (int i=1;i<=n;++i)tree[num[i]]=i;
    build(1,1,n);
    q=F();
    char ch[10];
    for (int i=1;i<=q;++i)
    {
        int x,y;scanf("%s%d%d",ch,&x,&y);
        if (ch[0]=='C'){val[x]=y;change(1,num[x],y);}
        else 
        {
            if(ch[1]=='M') solvemx(x,y);
            else  solvesum(x,y);
            printf("%d\n",qx);
        }
    }
}