POJ3237 树链剖分+线段树

时间:2021-10-20 21:26:12

树链剖分的原理在这篇文章里讲得很详细：http://blog.sina.com.cn/s/blog_6974c8b20100zc61.html


# include<cstdio>
# include<algorithm>
# include<iostream>
# include<string>
# include<cstring>
# include<vector>
# include<queue>
# include<stack>
# include<cctype>

// "Minus infinity" for max-queries is written as -inf; 0x3f3f3f3f is larger
// than any edge weight this program is expected to handle.
const int inf = 0x3f3f3f3f;
// Currently not referenced by the code below.
const int N = 100000;
using namespace std;

/* One input edge: its two endpoints and its weight.
   main() swaps the endpoints so that `from` is the shallower one and `to` the
   deeper one; the weight is then stored at the segment-tree slot of `to`. */
struct edge
{
    int from;
    int to;
    int cost;
};
edge e[10005];  // e[i] is the i-th input edge, 1-based

/* Adjacency-list entry: the neighbouring vertex and the weight of the edge
   leading to it. */
struct node
{
    int to;
    int cost;
};

// Adjacency lists, indexed by vertex number (vertices are 1..n).
vector<node> G[10005];

// Number of vertices in the current test case.
int n;
// Next free segment-tree position; dfs2 hands one out per vertex, and the
// tree is then used over positions [0, counter).
int counter;

// Per-vertex data filled by dfs1 / dfs2.
int fa[10005];   // parent (the root is passed as its own parent)
int dep[10005];  // depth, root = 1
int siz[10005];  // subtree-size accumulator used by dfs1 to choose son[]
int son[10005];  // child chosen as heavy child, 0 if none
int top[10005];  // first vertex of the chain this vertex belongs to
int w[10005];    // segment-tree position; it holds the weight of edge (fa[v], v)

// Segment tree in heap layout (root 0, children 2k+1 / 2k+2).
int maxv[40005]; // maximum under node k
int minv[40005]; // minimum under node k
int flag[40005]; // pending negation for the children of node k

/*
 * Reset all per-test-case state before building a new tree.
 * Must be called after n has been read: only the adjacency lists of
 * vertices 1..n are cleared.
 */
void clean()
{
    counter = 0;
    memset(flag, 0, sizeof(flag));
    memset(minv, 0, sizeof(minv));
    memset(maxv, 0, sizeof(maxv));
    memset(siz, 0, sizeof(siz));
    memset(son, 0, sizeof(son));   // 0 means "no heavy child"
    memset(dep, 0, sizeof(dep));   // dfs1 relies on dep[root] == 0 beforehand
    for (int v = n; v >= 1; --v)
        G[v].clear();
}

void dfs1(int u,int pa)
{
    int maxs=-1;
    dep[u]=dep[pa]+1;
    fa[u]=pa;
    for(int i=0;i<G[u].size();i++)
    {
        node p=G[u][i];
        if(p.to==pa)    continue;
        dfs1(p.to,u);
        siz[u]+=siz[p.to];
        if(siz[p.to]>maxs)
        {
            maxs=siz[p.to];
            son[u]=p.to;
        }
    }
}

void dfs2(int u,int pa,int ancester)
{
    if(!u)  return ;
    w[u]=counter++;
    top[u]=ancester;
    dfs2(son[u],u,ancester);
    for(int i=0;i<G[u].size();i++)
    {
        node p=G[u][i];
        if(p.to==pa||p.to==son[u])    continue;
        dfs2(p.to,u,p.to);
    }
}

/*
 * Segment tree with lazy range negation.
 * Layout: root 0, the children of node k are 2k+1 and 2k+2, and node k covers
 * the half-open position range [l, r).
 * maxv[k] / minv[k] are the max / min of the values under node k.
 * flag[k] == 1 means node k's own maxv/minv are already negated, but its two
 * children have not been yet.
 */

/* Negate everything under node k in O(1): new max = -old min, new min = -old max. */
void neg(int k)
{
    int t1=-minv[k],t2=-maxv[k];
    maxv[k]=t1,minv[k]=t2;
}

/* Hand node k's pending negation to its children. Only called on inner nodes. */
static void push_down(int k)
{
    if(!flag[k])    return ;
    flag[2*k+1]^=1;
    flag[2*k+2]^=1;
    neg(2*k+1);
    neg(2*k+2);
    flag[k]=0;
}

/* Recompute node k's max/min from its two children. */
static void pull_up(int k)
{
    maxv[k]=max(maxv[2*k+1],maxv[2*k+2]);
    minv[k]=min(minv[2*k+1],minv[2*k+2]);
}

/*
 * Point assignment: set position q to c.
 *   k - current node, covering [l, r)
 */
void change(int q,int k,int l,int r,int c)
{
    if(q<l||q>=r)   return ;
    if(r-l==1)
    {
        maxv[k]=minv[k]=c;
        return ;
    }
    push_down(k);
    int mid=(l+r)/2;
    change(q,2*k+1,l,mid,c);
    change(q,2*k+2,mid,r,c);
    pull_up(k);
}

/*
 * Negate every value in positions [ql, qr).
 *   k - current node, covering [l, r)
 */
void setf(int ql,int qr,int k,int l,int r)
{
    if(ql>=r||qr<=l)    return ;
    if(ql<=l&&qr>=r)
    {
        // Fully covered: update this node now, defer the children.
        flag[k]^=1;
        neg(k);
        return ;
    }
    push_down(k);
    int mid=(l+r)/2;
    setf(ql,qr,2*k+1,l,mid);
    setf(ql,qr,2*k+2,mid,r);
    pull_up(k);
}

/*
 * Maximum over positions [ql, qr).
 * Returns -inf when node k's range [l, r) does not intersect the query.
 */
int query(int ql,int qr,int k,int l,int r)
{
    if(ql>=r||qr<=l)    return -inf;
    if(ql<=l&&qr>=r)    return maxv[k];
    push_down(k);
    int mid=(l+r)/2;
    int res1=query(ql,qr,2*k+1,l,mid);
    int res2=query(ql,qr,2*k+2,mid,r);
    return max(res1,res2);
}

void solve(int va,int vb,int op)
{
    if(op==1)
    {
        change(w[e[va].to],0,0,counter,vb);
        return ;
    }
    int f1=top[va],f2=top[vb],res=-inf;
    while(f1!=f2)
    {
        if(dep[f1]<dep[f2])     {swap(f1,f2);swap(va,vb);}
        if(op==2)   setf(w[f1],w[va]+1,0,0,counter);
        else   res=max(query(w[f1],w[va]+1,0,0,counter),res);
        va=fa[f1];f1=top[va];
    }
    if(va==vb)
    {
        if(op==3)   printf("%d\n",res);
        return ;
    }
    if(dep[va]>dep[vb])  swap(va,vb);
    if(op==2)   setf(w[son[va]],w[vb]+1,0,0,counter);
    else
    {
        res=max(res,query(w[son[va]],w[vb]+1,0,0,counter));
        printf("%d\n",res);
    }
}

int main()
{
    int test;
    scanf("%d",&test);
    while(test--)
    {
        scanf("%d",&n);
        clean();
        for(int i=1;i<n;i++)
        {
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
            e[i]=(edge){a,b,c};
            G[a].push_back((node){b,c});
            G[b].push_back((node){a,c});
        }
        dfs1(1,1);
        dfs2(1,1,1);
        for(int i=1;i<n;i++)
        {
            if(dep[e[i].from]>dep[e[i].to])    swap(e[i].from,e[i].to);
            change(w[e[i].to],0,0,counter,e[i].cost);
        }
        char que[10];
        while(scanf("%s",que)&&que[0]!='D')
        {
            int u,v;
            scanf("%d%d",&u,&v);
            if(que[0]=='C')    solve(u,v,1);
            else if(que[0]=='N')    solve(u,v,2);
            else    solve(u,v,3);
        }
    }
    return 0;
}