hdu 5909 Tree Cutting——点分治(树形DP转为序列DP)

时间:2021-12-22 22:59:24

题目:http://acm.hdu.edu.cn/showproblem.php?pid=5909

点分治的话,每次要做一次树形DP;但时间应该是 siz*m2 的。可以用 FWT 变成 siz*mlogm ,但这里写的是把树变成序列来 DP 的方法,应该是 nlogn*m 的。

树上的一个点,如果选,就可以选它的孩子,所以它向它的第一个孩子连边;如果不选,就会跳到它的下一个兄弟或者是父亲的下一个兄弟之类的,向那边连一条边。

做出树的 dfs 序,把边都连在 dfs 序上;其实那个第一条边一定连向自己 dfs 序+1,即使自己没有孩子也是符合的,所以可以不用连了;第二条边可以通过传父亲的连边对象来解决。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=,M=,mod=1e9+;
int T,n,m,w[N],hd[N],xnt,to[N<<],nxt[N<<],siz[N],rt,mn;
int dfn[N],tot,sta[N],top,f[N][M],g[N],nt[N],ans[M]; bool vis[N];
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='')ret=ret*+ch-'',ch=getchar();
return fx?ret:-ret;
}
int Mx(int a,int b){return a>b?a:b;}
int Mn(int a,int b){return a<b?a:b;}
void upd(int &x){x>=mod?x-=mod:;}
void init()
{
xnt=;memset(hd,,sizeof hd);
memset(ans,,sizeof ans); memset(vis,,sizeof vis);
}
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void getrt(int cr,int fa,int s)
{
siz[cr]=; int mx=;
for(int i=hd[cr],v;i;i=nxt[i])
if(!vis[v=to[i]]&&v!=fa)
{
getrt(v,cr,s);siz[cr]+=siz[v];
mx=Mx(mx,siz[v]);
}
mx=Mx(mx,s-siz[cr]);if(mx<mn)mn=mx,rt=cr;
}
void dfs(int cr,int fa)
{
dfn[cr]=++tot;g[tot]=w[cr];
for(int i=hd[cr],v;i;i=nxt[i])
if(!vis[v=to[i]]&&v!=fa)dfs(v,cr);
}
void dfsx(int cr,int fa,int lst)
{
nt[dfn[cr]]=lst;
int l=top+;
for(int i=hd[cr],v;i;i=nxt[i])
if(!vis[v=to[i]]&&v!=fa)sta[++top]=v;
int r=top;
for(int i=hd[cr],v,p0=l;i;i=nxt[i])
if(!vis[v=to[i]]&&v!=fa)
{
dfsx(v,cr,p0==r?lst:dfn[sta[p0+]]);p0++;
}
}
void solve(int cr,int s)
{
vis[cr]=;
tot=;dfs(cr,);top=;dfsx(cr,,s+);
for(int i=;i<=s+;i++)memset(f[i],,sizeof f[i]);
f[][]=;
for(int i=;i<=s;i++)
for(int j=;j<m;j++)
{
if(!f[i][j])continue;
f[i+][j^g[i]]+=f[i][j];upd(f[i+][j^g[i]]);
f[nt[i]][j]+=f[i][j];upd(f[nt[i]][j]);
}
f[s+][]--;//dec the empty
for(int j=,k=s+;j<m;j++)ans[j]+=f[k][j],upd(ans[j]);
for(int i=hd[cr],v,ts;i;i=nxt[i])
if(!vis[v=to[i]])
{
ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);
mn=N;getrt(v,cr,ts);solve(rt,ts);
}
}
int main()
{
T=rdn();
while(T--)
{
n=rdn();m=rdn();for(int i=;i<=n;i++)w[i]=rdn();
init();
for(int i=,u,v;i<n;i++)u=rdn(),v=rdn(),add(u,v),add(v,u);
mn=N;getrt(,,n);solve(rt,n);
for(int i=,j=m-;i<j;i++)printf("%d ",ans[i]);
printf("%d\n",ans[m-]);
}
return ;
}