Codeforces 600E Lomsat gelral [dsu on tree(树上启发式合并)]

时间:2021-12-23 11:05:39

题意

给出一棵树,1为根节点,每个点都有一个颜色。求每个点所在子树内所有出现次数最多的颜色的和。
n<=100000

分析

新学习了一种姿势叫dsu on tree,大概意思就是树上启发式合并吧。
dsu on tree大概是用来解决这样一类问题:需要多次查询某棵子树内的某个值(必须要离线)。像这题就是需要查询每棵子树出现颜色次数最多的颜色的和。
怎么做呢?
首先我们把这棵树的dfs序求出来,顺便轻重链剖分一下,用一个全局数组统计当前答案。
每次到达一个节点,先访问其所有轻节点,每次暴力把轻节点子树的贡献删掉。然后访问重儿子,然后暴力把所有轻儿子的子树和自己的贡献加进去,最后查询当前点的询问即可。
复杂度的话,因为每个点到根的路径上,重链的数量不超过logn条,所以每个点最多被暴力加入logn次,于是总的复杂度就是O(nlogn).

伪代码

find the BigChild of each vertex
dfs(u, fa, keep)
dfs(LightChild, u, 0)
dfs(BigChild, u, 1), big[BigChild] = 1
update(u, fa, 1) //calculate the contribution of u's LightChild's SubTree
update the ans of u
big[BigChild] = 0
if keep == 0
update(u, fa, -1) //remove the contributino of u's SubTree

update(u, fa, val)
calculate u'
s information
update(v : (u, v) and !big[v], u, val)

具体到这题的话,只要开两个数组sum[i]表示出现了i次的颜色的和,tot[i]表示颜色i出现的次数,然后就变成模板题了。

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;

typedef long long LL;

const int N=100005;

int n,mx[N],mn[N],sz,dfn[N],tot[N],col[N],last[N],size[N],Mx,cnt;
struct edge{int to,next;}e[N*2];
LL sum[N],ans[N];

int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}

void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}

void dfs(int x,int fa)
{
mn[x]=++sz;dfn[sz]=x;size[x]=1;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa) continue;
dfs(e[i].to,x);
size[x]+=size[e[i].to];
}
mx[x]=sz;
}

void add(int x)
{
for (int i=mn[x];i<=mx[x];i++)
{
int y=col[dfn[i]];
sum[tot[y]]-=y;
sum[++tot[y]]+=y;
Mx=max(Mx,tot[y]);
}
}

void del(int x)
{
for (int i=mn[x];i<=mx[x];i++)
{
int y=col[dfn[i]];
sum[tot[y]]-=y;
sum[--tot[y]]+=y;
if (!sum[Mx]) Mx--;
}
}

void solve(int x,int fa)
{
int k=0;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&size[e[i].to]>size[k]) k=e[i].to;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&e[i].to!=k) solve(e[i].to,x),del(e[i].to);
if (k) solve(k,x);
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa&&e[i].to!=k) add(e[i].to);
sum[tot[col[x]]]-=col[x];
sum[++tot[col[x]]]+=col[x];
Mx=max(Mx,tot[col[x]]);
ans[x]=sum[Mx];
}

int main()
{
n=read();
for (int i=1;i<=n;i++) col[i]=read(),sum[0]+=i;
for (int i=1;i<n;i++)
{
int x=read(),y=read();
addedge(x,y);
}
dfs(1,0);
solve(1,0);
for (int i=1;i<=n;i++) printf("%I64d ",ans[i]);
return 0;
}