HDU_5692_dfs序+线段树

时间:2023-02-02 21:15:47

http://acm.hdu.edu.cn/showproblem.php?pid=5692

 

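这道题看了题解还搞了一整天。思路是：对整棵树做一次后序遍历（dfs序），按1~n给节点重新编号，这样每个点的子树恰好对应一段连续的编号区间[l, r]（即l和r数组）。从根出发并经过某个点的所有路径，其终点恰好是该点子树中的节点，所以修改一个点的权值就相当于给区间[l, r]整体加上差值，查询就是求区间[l, r]上路径和的最大值，用线段树维护“区间加、区间最大值”即可。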

注意：用递归方式建线段树时，结点数组要开到4n才能保证不越界。

刚开始的更新函数每次都要把区间内的所有结点一路更新到叶子，结果超时了。后来在tree中加了一个add字段（懒标记），记录该结点的子树中还没有下传的增量，等到真正访问这棵子树时再把这个量加到子结点上，非常巧妙。

 

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define INF 0x3f3f3f3f
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;

// Adjacency list of the tree: line[u] holds every neighbour of node u
// (main() inserts each edge in both directions).
vector<int> line[100005];

// n: number of nodes, m: number of operations in the current test case.
int n;
int m;
// Next post-order label handed out by dfs (main() resets it to 1).
int cnt;
// [l[u], r[u]] is the range of post-order labels covered by u's subtree.
int l[100005];
int r[100005];

// w[u]: current value of node u.
long long w[100005];
// init[k]: root-to-node path sum of the node labelled k, used by build().
long long init[100005];

// One node of the segment tree built over the post-order labels.
struct segtree
{
    int left, right;   // label interval [left, right] covered by this node
    long long maxx;    // maximum path sum in the interval (own pending add included)
    long long add;     // lazy increment not yet pushed down to the children
};
segtree tree[400005];

/*
 * Label the tree in post-order and record root-to-node path sums.
 *
 * now : node where the traversal starts (main() passes the root 0)
 * pre : parent of `now`; neighbours equal to it are skipped
 *       (main() passes the root itself, which is never its own neighbour)
 * sum : path sum accumulated above `now`
 *
 * Every visited node u receives a post-order label r[u] taken from the
 * global counter cnt. l[u] is the smallest label inside u's subtree, so the
 * whole subtree occupies the contiguous label range [l[u], r[u]].
 * init[label] stores the sum of w[] along the path from `now` down to the
 * node carrying that label.
 *
 * An explicit stack is used instead of recursion. A path-shaped tree with
 * 1e5 nodes would otherwise recurse 1e5 levels deep and can overflow the
 * default thread stack. The `#pragma comment(linker, "/STACK:...")` setting
 * is an MSVC-specific directive and does not enlarge the stack under GCC.
 */
void dfs(int now,int pre,long long sum)
{
    // Parallel stacks, one entry per open frame: the node, its parent, the
    // path sum including the node, and the index of the next neighbour to try.
    vector<int> node_stk, parent_stk;
    vector<long long> sum_stk;
    vector<size_t> next_stk;

    // In post-order, the first label assigned inside a subtree is the value
    // of cnt at the moment that subtree is entered, so l[] is known on entry.
    l[now] = cnt;
    node_stk.push_back(now);
    parent_stk.push_back(pre);
    sum_stk.push_back(sum + w[now]);
    next_stk.push_back(0);

    while(!node_stk.empty())
    {
        int u = node_stk.back();
        if(next_stk.back() < line[u].size())
        {
            // Advance this frame's cursor before descending.
            int v = line[u][next_stk.back()++];
            if(v == parent_stk.back()) continue;
            long long child_sum = sum_stk.back() + w[v];
            l[v] = cnt;
            node_stk.push_back(v);
            parent_stk.push_back(u);
            sum_stk.push_back(child_sum);
            next_stk.push_back(0);
        }
        else
        {
            // Every child of u is labelled: now u gets its post-order label.
            init[cnt] = sum_stk.back();
            r[u] = cnt++;
            node_stk.pop_back();
            parent_stk.pop_back();
            sum_stk.pop_back();
            next_stk.pop_back();
        }
    }
}

/*
 * Recursively build segment-tree node `pos` over the label range [lo, hi].
 * Leaves take their value from init[]; each inner node stores the larger of
 * its two children's maxima. Every lazy tag in the built range is reset to 0.
 */
void build(int pos,int lo,int hi)
{
    tree[pos].left = lo;
    tree[pos].right = hi;
    tree[pos].add = 0;

    if(lo == hi)
    {
        // Leaf: the path sum of the single node labelled `lo`.
        tree[pos].maxx = init[lo];
        return;
    }

    int mid = (lo + hi) / 2;
    int lc = pos * 2;
    int rc = lc + 1;
    build(lc, lo, mid);
    build(rc, mid + 1, hi);
    tree[pos].maxx = max(tree[lc].maxx, tree[rc].maxx);
}

/*
 * Add v to every label in [ql, qr], which must lie inside node pos's range.
 * Any pending increment on pos is pushed to its children first. A node whose
 * range exactly matches the request absorbs v into both maxx and its lazy tag
 * instead of descending further.
 */
void update(int pos,int ql,int qr,long long v)
{
    int lc = pos * 2;
    int rc = pos * 2 + 1;

    // Push the pending increment down (leaves have no children to receive it).
    long long pending = tree[pos].add;
    if(pending != 0 && tree[pos].left != tree[pos].right)
    {
        tree[lc].maxx += pending;
        tree[lc].add  += pending;
        tree[rc].maxx += pending;
        tree[rc].add  += pending;
        tree[pos].add = 0;
    }

    if(ql == tree[pos].left && qr == tree[pos].right)
    {
        tree[pos].maxx += v;
        tree[pos].add  += v;
        return;
    }

    // Clip the request to each half it overlaps; left half first.
    int mid = (tree[pos].left + tree[pos].right) / 2;
    if(ql <= mid) update(lc, ql, min(qr, mid), v);
    if(qr > mid)  update(rc, max(ql, mid + 1), qr, v);
    tree[pos].maxx = max(tree[lc].maxx, tree[rc].maxx);
}

/*
 * Return the maximum stored value over labels [ql, qr], which must lie
 * inside node pos's range. Pending increments met on the way are pushed to
 * the children so that their maxx values are up to date.
 */
long long getmax(int pos,int ql,int qr)
{
    int lc = pos * 2;
    int rc = pos * 2 + 1;

    // Push the pending increment down (leaves have no children to receive it).
    long long pending = tree[pos].add;
    if(pending != 0 && tree[pos].left != tree[pos].right)
    {
        tree[lc].maxx += pending;
        tree[lc].add  += pending;
        tree[rc].maxx += pending;
        tree[rc].add  += pending;
        tree[pos].add = 0;
    }

    if(ql == tree[pos].left && qr == tree[pos].right)
        return tree[pos].maxx;

    int mid = (tree[pos].left + tree[pos].right) / 2;
    if(qr <= mid) return getmax(lc, ql, qr);
    if(ql > mid)  return getmax(rc, ql, qr);

    // The request straddles mid: combine the best of both halves.
    long long best_left  = getmax(lc, ql, mid);
    long long best_right = getmax(rc, mid + 1, qr);
    return max(best_left, best_right);
}

/*
 * Read T test cases. For each one, read the tree, the node values and m
 * operations:
 *   "0 x y" sets the value of node x to y, i.e. adds y - w[x] to the labels
 *           [l[x], r[x]] of x's subtree;
 *   "1 x"   prints the maximum path sum over the labels of x's subtree.
 */
int main()
{
    int cases;
    scanf("%d",&cases);
    for(int tc = 1; tc <= cases; ++tc)
    {
        printf("Case #%d:\n",tc);
        scanf("%d%d",&n,&m);

        for(int u = 0; u < n; ++u)
            line[u].clear();

        // A tree on n nodes has n - 1 undirected edges.
        for(int e = 1; e < n; ++e)
        {
            int a, b;
            scanf("%d%d",&a,&b);
            line[a].push_back(b);
            line[b].push_back(a);
        }

        for(int u = 0; u < n; ++u)
            scanf("%lld",&w[u]);

        // Post-order labels run from 1 to n.
        cnt = 1;
        dfs(0,0,0);
        build(1,1,n);

        while(m--)
        {
            int op, x;
            scanf("%d",&op);
            if(op != 0)
            {
                scanf("%d",&x);
                printf("%lld\n",getmax(1,l[x],r[x]));
                continue;
            }

            int y;
            scanf("%d%d",&x,&y);
            long long delta = y - w[x];
            w[x] = y;
            update(1,l[x],r[x],delta);
        }
    }
    return 0;
}