树上启发式合并 (dsu on tree)

时间:2022-08-06 13:05:32

这个故事告诉我们,在做一个辣鸡出题人的比赛之前,最好先看看他发明了什么新姿势= =居然直接出了道裸题

参考链接:

http://codeforces.com/blog/entry/44351(原文)

http://blog.csdn.net/QAQ__QAQ/article/details/53455462

这种技巧可以在O(nlogn)的时间内解决绝大多数的无修改子树询问问题。

例1 子树颜色统计

有一棵n个点的有根树,根为1,每个点有一个1~n的颜色,对于每一个点给了一个数k,要询问这个子树中颜色为k的点的个数。n<=500000。

这个例子当然过于trivial,dfs序完一棵主席树就能水过去,不过这不是重点...

我们的目标就是实现一个不那么暴力的东西,可以代替以下代码:

Edg int n,cc[SZ],col[SZ],ks[SZ],anss[SZ];
void edt(int x,int f,int v)
{
cc[col[x]]+=v;
for esb(x,e,b)
if(b!=f) edt(b,x,v);
}
void dfs(int x,int f=0)
{
edt(x,f,1);
anss[x]=cc[ks[x]];
edt(x,f,-1);
for esb(x,e,b)
if(b!=f) dfs(b,x);
}

输入大致如下:

int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",col+i);
for(int i=1;i<=n;i++) scanf("%d",ks+i);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
adde(a,b);
}
dfs(1);
for(int i=1;i<=n;i++) printf("%d\n",anss[i]);
}

以下是一个清真的nlogn做法:

Edg
int n,cc[SZ],col[SZ],ks[SZ],anss[SZ];
int sz[SZ],son[SZ];
void gs(int x,int f=0)
{
sz[x]=1;
for esb(x,e,b)
{
if(b==f) continue;
gs(b,x); sz[x]+=sz[b];
if(sz[b]>sz[son[x]]) son[x]=b;
}
}
int skip=0;
void edt(int x,int f,int v)
{
cc[col[x]]+=v;
for esb(x,e,b)
if(b!=f&&b!=skip) edt(b,x,v);
}
void dfs(int x,int f=0,bool kep=0)
{
for esb(x,e,b)
if(b!=f&&b!=son[x]) dfs(b,x);
if(son[x])
dfs(son[x],x,1), skip=son[x];
edt(x,f,1);
anss[x]=cc[ks[x]];
skip=0;
if(!kep) edt(x,f,-1);
}

(如果操作比较复杂的话建议看看下面的例三代码= =)

这为什么是nlogn的?因为一条重链会使所在子树大小翻一倍。

我们发现这个技巧十分好用,只要兹磁往一个集合里插入是O(1)的,删除元素到空为止是每个O(1)的(注意到这样写的话每次删除是一定会删到全空的),那么对于子树集合的询问就可以做到O(nlogn),不知道比辣鸡莫队高到哪里去了(莫队复杂度nsqrt(n),而且必须也要支持删除O(1))。

例2 Lomsat gelral(cf600E)

n个点的有根树,以1为根,每个点有一种颜色。我们称一种颜色占领了一个子树当且仅当没有其他颜色在这个子树中出现得比它多。求占领每个子树的所有颜色之和。

模板题啦。

#define SZ 666666
Edg
int n,cc[SZ],col[SZ],sz[SZ],son[SZ];
ll anss[SZ];
void gs(int x,int f=0)
{
sz[x]=1;
for esb(x,e,b)
{
if(b==f) continue;
gs(b,x); sz[x]+=sz[b];
if(sz[b]>sz[son[x]]) son[x]=b;
}
}
bool skip[SZ];
int cx=0; ll sum=0;
void edt(int x,int f,int k)
{
cc[col[x]]+=k;
if(k>0&&cc[col[x]]>=cx)
{
if(cc[col[x]]>cx)
sum=0, cx=cc[col[x]];
sum+=col[x];
}
for esb(x,e,b)
if(b!=f&&!skip[b]) edt(b,x,k);
}
void dfs(int x,int f=0,bool kep=0)
{
for esb(x,e,b)
if(b!=f&&b!=son[x]) dfs(b,x);
if(son[x])
dfs(son[x],x,1), skip[son[x]]=1;
edt(x,f,1);
anss[x]=sum;
if(son[x]) skip[son[x]]=0;
if(!kep)
edt(x,f,-1), cx=sum=0;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",col+i);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
adde(a,b);
}
gs(1); dfs(1);
for(int i=1;i<=n;i++) printf("%I64d ",anss[i]);
}

例3 Arpa's letter-marked tree and Mehrdad's Dokhtar-kosh paths(CF741D)

辣鸡出题人还我rating

我们称一个字符串为doubi string当且仅当重排它的字符可以组成一个回文串。

给出一个n个点的有根树,根为1,每条边上有一个字符(只有a~v,别问我为什么),求每个点的子树中所有简单路径可以组成的doubi string中的最长长度。

doubi string显然就是只有0/1个字符出现奇数次的字符串,如果只有a~v的话考虑把每个字符当做一个二进制位,把一个点i到根的路径异或值记为s[i],那么我们就是要对于每个x在子树中找到a和b,使得s[a]^s[b]为0或2的次幂,且dep[a]+dep[b]-dep[lca]*2最大。

那么问题来了,lca如果直接当做x算出来的答案是会变大的...看起来我们需要把这个东西扩展一下,让它只统计不同子树的。这个好办,对于每棵子树先统计再更新就行了。

模板大法好!

#define SZ 1234567
Edgc
int n,sz[SZ],son[SZ],fc[SZ],dep[SZ];
void gs(int x,int f=0)
{
sz[x]=1;
for esb(x,e,b)
{
if(b==f) continue;
fc[b]=fc[x]^vc[e];
dep[b]=dep[x]+1;
gs(b,x); sz[x]+=sz[b];
if(sz[b]>sz[son[x]]) son[x]=b;
}
}
const int S='v'-'a'+1,inf=1e9;
int md[5555555],cans,skip,cdep;
void clr(int x) {md[fc[x]]=-inf;}
void upd(int x)
{
cans=max(cans,md[fc[x]]+dep[x]-cdep*2);
for(int i=0;i<S;i++)
cans=max(cans,md[fc[x]^(1<<i)]+dep[x]-cdep*2);
}
void ins(int x)
{md[fc[x]]=max(md[fc[x]],dep[x]);}
template<void(*func)(int)>
void edt(int x,int f)
{
func(x);
for esb(x,e,b)
if(b!=f&&b!=skip) edt<func>(b,x);
}
int dp[SZ];
void dfs(int x,int f=0,bool kep=0)
{
for esb(x,e,b)
if(b!=f&&b!=son[x]) dfs(b,x);
if(son[x])
dfs(son[x],x,1), skip=son[x];
cdep=dep[x];
for esb(x,e,b) if(b!=f)
dp[x]=max(dp[x],dp[b]);
for esb(x,e,b) if(b!=f&&b!=son[x])
edt<upd>(b,x), edt<ins>(b,x);
upd(x); ins(x);
dp[x]=max(dp[x],cans);
skip=0;
if(!kep) edt<clr>(x,f), cans=-inf;
}
int main()
{
for(int i=0;i<(1<<S);i++)
md[i]=-inf;
scanf("%d",&n);
for(int i=2;i<=n;i++)
{
int x; char c[3];
scanf("%d%s",&x,c);
adde(i,x,1<<(c[0]-'a'));
}
gs(1); dfs(1);
for(int i=1;i<=n;i++)
printf("%d ",dp[i]);
}