题目大意:给定一棵 N 个节点的有根树,1 号节点是树的根节点,每个节点有一个颜色。求对于每个节点来说,能够支配整棵子树的颜色之和是多少。支配的定义为对于以 i 为根的子树,该颜色出现的次数不小于任何其他颜色出现的次数。
题解:学会了 dsu on tree。
树上启发式合并算法是一种对暴力的优化算法。对于暴力算法来说,直接遍历每个节点,再遍历该节点对应的子树寻找答案,时间复杂度显然为 \(O(n^2)\)。考虑进行优化,显然对于父节点来说,遍历到的最后一棵子树的贡献可以不用消去,直接加在父节点对应的子树上即可。根据这条性质,每次都选择 i 的重儿子进行直接累加答案贡献,重儿子的定义和求法与树剖中重儿子相同。可以看出对于树上每个节点,在统计答案的时候仅仅遍历了子树内的所有轻边,即:重儿子对应的子树再合并时不需要再次遍历。从全局的角度来说,对于第 i 个节点为根的子树,子树在统计答案时被遍历的次数和该节点到根节点路径上的轻边个数成正比。一共只有 \(O(logn)\) 条轻边,因此,遍历 n 个节点的时间复杂度上界为 \(O(nlogn)\)。
代码如下
#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define all(x) x.begin(),x.end()
#define cls(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef pair<int,int> P;
const int dx[]={0,1,0,-1};
const int dy[]={1,0,-1,0};
const int mod=1e9+7;
const int inf=0x3f3f3f3f;
const int maxn=1e5+10;
const double eps=1e-6;
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
inline ll sqr(ll x){return x*x;}
inline ll fpow(ll a,ll b,ll c){ll ret=1%c;for(;b;b>>=1,a=a*a%c)if(b&1)ret=ret*a%c;return ret;}
inline ll read(){
ll x=0,f=1;char ch;
do{ch=getchar();if(ch=='-')f=-1;}while(!isdigit(ch));
do{x=x*10+ch-'0';ch=getchar();}while(isdigit(ch));
return f*x;
}
/*------------------------------------------------------------*/
vector<int> G[maxn];
int size[maxn],son[maxn];
int n,cor[maxn],cnt[maxn];
bool skip[maxn];
ll mx,now,ans[maxn];
void getsize(int u,int fa){
size[u]=1;
for(auto v:G[u]){
if(v==fa)continue;
getsize(v,u);
if(size[v]>size[son[u]])son[u]=v;
size[u]+=size[v];
}
}
void add(int u,int fa,int val){
cnt[cor[u]]+=val;
if(val>0){
if(cnt[cor[u]]>mx)now=cor[u],mx=cnt[cor[u]];
else if(cnt[cor[u]]==mx)now+=cor[u];
}
for(auto v:G[u]){
if(v==fa||skip[v])continue;
add(v,u,val);
}
}
void dfs(int u,int fa,bool keep){
for(auto v:G[u]){
if(v==fa||v==son[u])continue;
dfs(v,u,0);
}
if(son[u])dfs(son[u],u,1),skip[son[u]]=1;
add(u,fa,1);
ans[u]=now;
if(son[u])skip[son[u]]=0;
if(!keep)add(u,fa,-1),mx=now=0;
}
void read_and_parse(){
n=read();
for(int i=1;i<=n;i++)cor[i]=read();
for(int i=1,x,y;i<n;i++){
x=read(),y=read();
G[x].pb(y),G[y].pb(x);
}
}
void solve(){
getsize(1,0);
dfs(1,0,1);
for(int i=1;i<=n;i++)printf("%lld%c",ans[i],i==n?'\n':' ');
}
int main(){
read_and_parse();
solve();
return 0;
}