dsu on tree(CF600E Lomsat gelral)

时间:2022-10-13 22:09:55

题意

一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。

dsu on tree

用来解决子树问题
好像不能带修改??

暴力做这个题,就是每次扫一遍子树统计答案
时间\(O(n^2)\)

或者会高级的数据结构解决
空间,编程难度是个挑战

然而\(dsu \ on \ tree\)树上启发式合并则是一个好方法
它通过增加对重儿子子树信息的利用来提高效率

流程:

递归轻儿子
递归重儿子
统计答案
如果该点为它父亲的重儿子,保存信息
否则删除信息

复杂度分析:
每个点被扫到的次数只有它到根的路径上轻边的次数*\(2\)
也就是\(log\)
那么总复杂度为空间\(O(n)\),时间\(O(nlogn)\)

该题代码

# include <bits/stdc++.h>
# define IL inline
# define RG register
# define Fill(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long ll;

IL int Input(){
    RG int x = 0, z = 1; RG char c = getchar();
    for(; c < '0' || c > '9'; c = getchar()) z = c == '-' ? -1 : 1;
    for(; c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    return x * z;
}

const int maxn(1e5 + 5);

int n, first[maxn], cnt, col[maxn], size[maxn], son[maxn], vis[maxn], num[maxn], mx;
ll sum[maxn], ans[maxn];

struct Edge{
    int to, next;
} edge[maxn << 1];

IL void Add(RG int u, RG int v){
    edge[cnt] = (Edge){v, first[u]}, first[u] = cnt++;
}

IL void Dfs(RG int u, RG int ff){
    size[u] = 1;
    for(RG int e = first[u]; e != -1; e = edge[e].next){
        RG int v = edge[e].to;
        if(v != ff){
            Dfs(v, u);
            size[u] += size[v];
            if(size[v] > size[son[u]]) son[u] = v;
        }
    }
}

IL void Update(RG int u, RG int ff, RG int val){
    sum[num[col[u]]] -= col[u];
    num[col[u]] += val;
    sum[num[col[u]]] += col[u];
    if(val > 0) mx = max(mx, num[col[u]]);
    else while(mx && !sum[mx]) --mx;
    for(RG int e = first[u]; e != -1; e = edge[e].next)
        if(edge[e].to != ff && !vis[edge[e].to]) Update(edge[e].to, u, val);
}

IL void Solve(RG int u, RG int ff, RG int op){
    size[u] = 1;
    for(RG int e = first[u]; e != -1; e = edge[e].next)
        if(edge[e].to != ff && edge[e].to != son[u]) Solve(edge[e].to, u, 0);
    if(son[u]) Solve(son[u], u, 1), vis[son[u]] = 1;
    Update(u, ff, 1), vis[son[u]] = 0;
    ans[u] = sum[mx];
    if(!op) Update(u, ff, -1);
}

int main(){
    n = Input();
    for(RG int i = 1; i <= n; ++i) col[i] = Input(), first[i] = -1;
    for(RG int i = 1; i < n; ++i){
        RG int u = Input(), v = Input();
        Add(u, v), Add(v, u);
    }
    Dfs(1, 0), Solve(1, 0, 1);
    for(RG int i = 1; i <= n; ++i) printf("%lld ", ans[i]);
    return 0;
}